#include <tommath.h>
#ifdef BN_MP_TOOM_SQR_C
/* LibTomMath, multiple-precision integer library -- Tom St Denis
 *
 * LibTomMath is a library that provides multiple-precision
 * integer arithmetic as well as number theoretic functionality.
 *
 * The library was designed directly after the MPI library by
 * Michael Fromberger but has been written from scratch with
 * additional optimizations in place.
 *
 * Tom St Denis, tomstdenis@gmail.com, http://libtom.org
 * 
 * 2008 - Complete rewrite by Marco Bodrato, bodrato@mail.dm.unipi.it
 *
 * Copyright 2008, Marco Bodrato, http://bodrato.it/
 *
 * This file is a replacement for the original implementation.
 * This is NOT public domain, it is released under the GNU GPL.
 *
 * The Toom-Cook implementation for LibTomMath is free software; you
 * can redistribute it and/or modify it under the terms of the GNU
 * General Public License as published by the Free Software
 * Foundation; either version 3 of the License, or (at your option)
 * any later version.
 * 
 * The Toom-Cook implementation is distributed in the hope that it
 * will be useful, but WITHOUT ANY WARRANTY; without even the implied
 * warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 * See the GNU Lesser General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with the GNU MP Library; see the file COPYING.  If not, write
 * to the Free Software Foundation, Inc., 51 Franklin Street, Fifth
 * Floor, Boston, MA 02110-1301, USA.
 */

/* Squaring using the Toom-Cook 3-way algorithm.
 * 
 * The same algorithm used for Toom multiplication, with only one
 * evaluation. Sequence described on http://bodrato.it/toom-cook/ .
 * Some shortcuts and some variable renaming/reuse were added to
 * avoid unnecessary allocations.  w0=a0; w2=a1; c=a2=w4.
 */

/* Computes b = a*a with the Toom-Cook 3-way scheme.
 *
 * a is cut into three pieces of B = ceil(a->used / 3) digits,
 * a = a2*x^2 + a1*x + a0 (x standing for B digits).  The polynomial is
 * evaluated at five points, every evaluation is squared, and the five
 * coefficients of the square are recovered by interpolation:
 *
 *   w0 = a0^2                 (point 0)
 *   w1 = (a2 - a1 + a0)^2     (point -1)
 *   w2 = (a2 + a1 + a0)^2     (point 1)
 *   w3 = (4a2 + 2a1 + a0)^2   (point 2)
 *   w4 = a2^2                 (point infinity, kept in the result c)
 *
 * To save allocations the temporaries are reused: before the squarings
 * w0 holds a0, w2 holds a1 and the destination c holds a2.
 *
 * a - value to square; only its digits are read, the result is always
 *     given a non-negative sign.
 * b - destination.
 *
 * Returns the status of the last operation performed (MP_OKAY when all
 * of them succeed), MP_MEM when a temporary cannot be allocated, or the
 * code reported by the first failing helper.  All temporaries are
 * released on every path.
 *
 * NOTE(review): for a->used == 1 the top piece would get a negative
 * digit count; callers presumably dispatch here only for inputs above
 * the Toom squaring cutoff -- confirm.
 */
int
mp_toom_sqr(mp_int *a, mp_int *b)
{
  mp_int w0, w1, w2, w3, *c;
  int res, B, x, olduse;
  mp_digit *tmpa, *tmpx;

  c = b;

  /* B = number of digits in each of the two lower pieces */
  B = (a->used + 2) / 3;

  /* default the return code to an error */
  res = MP_MEM;

  /* init size all the temps */
  if (mp_init_size (&w0, B<<1) != MP_OKAY)
    goto END;
  if (mp_init_size (&w1, (B<<1)+1) != MP_OKAY)
    goto W0;
  if (mp_init_size (&w2, (B<<1)+1) != MP_OKAY)
    goto W1;
  if (mp_init_size (&w3, (B<<1)+1) != MP_OKAY)
    goto W2;

  /* init result; on failure leave through ERR so that the four
   * temporaries allocated above are released instead of leaked
   */
  if (c->alloc < (a->used << 1) - 1 ) {
    if ((res = mp_grow (c, a->used << 1)) != MP_OKAY)
      goto ERR;
  }
  c->sign = MP_ZPOS;

  /* now split the digits: w0 = a0, w2 = a1, c = a2.
   * olduse remembers how many digits c held before it is reused, so
   * that the leftovers can be wiped once the split is done.
   */
  olduse = c->used;
  w0.used = w2.used = B;
  c->used = a->used - B*2;

  /* we copy the digits directly instead of using higher level functions
   * since we also need to shift the digits.  The source pointer runs
   * strictly ahead of the write position in c, so this also works when
   * a and b are the same object.
   */
  tmpa = a->dp;

  tmpx = w0.dp;
  for (x = 0; x < B; x++) {
    *tmpx++ = *tmpa++;
  }

  tmpx = w2.dp;
  for (x = 0; x < B; x++) {
    *tmpx++ = *tmpa++;
  }

  tmpx = c->dp;
  for (x = 0; x < c->used; x++) {
    *tmpx++ = *tmpa++;
  }

  /* clear whatever c held above its new digit count; without this,
   * digits of a previously longer value in b could survive above the
   * final result (the rest of the library is expected to rely on
   * unused digits being zero -- NOTE(review): confirm)
   */
  for (x = c->used; x < olduse; x++) {
    *tmpx++ = 0;
  }

  /* only need to clamp the lower words since by definition the
   * upper words a2 must have a known number of digits
   */
  mp_clamp (&w0);
  mp_clamp (&w2);

  /* w3 = a0+a2 */
  if ((res = s_mp_add(&w0, c, &w3)) != MP_OKAY)
    goto ERR;

  /* w1 = w3-a1 = a2-a1+a0 (signed subtraction, may be negative) */
  if ((res = mp_sub(&w3, &w2, &w1)) != MP_OKAY)
    goto ERR;

  /* w3 = w3+a1 = a2+a1+a0 */
  if ((res = s_mp_add(&w3, &w2, &w3)) != MP_OKAY)
    goto ERR;

  /* w2 = w3**2 (a1 is not needed any more, so w2 can be overwritten) */
  if ((res = mp_sqr(&w3, &w2)) != MP_OKAY)
    goto ERR;

  /* w3 = (w3+a2)*2 -a0 = 4a2+2a1+a0 */
  if ((res = s_mp_add(&w3, c, &w3)) != MP_OKAY)
    goto ERR;
  if ((res = mp_mul_2(&w3, &w3)) != MP_OKAY)
    goto ERR;
  if ((res = s_mp_sub(&w3, &w0, &w3)) != MP_OKAY)
    goto ERR;

  /* w1 = w1**2 */
  if ((res = mp_sqr(&w1, &w1)) != MP_OKAY)
    goto ERR;

  /* w3 = w3**2 */
  if ((res = mp_sqr(&w3, &w3)) != MP_OKAY)
    goto ERR;

  /* w0 = a0**2 */
  if ((res = mp_sqr(&w0, &w0)) != MP_OKAY)
    goto ERR;

  /* w4 = a2**2 */
  if ((res = mp_sqr(c, c)) != MP_OKAY)
    goto ERR;

  /* now solve the matrix

       0  0  0  0  1
       1 -1  1 -1  1
       1  1  1  1  1
       16 8  4  2  1
       1  0  0  0  0

       using 8 subtractions, 3 shifts,
             1 division by 3. (one shift is replaced by a repeated subtraction)
  */

  /* w3 = (w3-w1)/3  -> [5 3 1 1 0] */
  if ((res = mp_sub(&w3, &w1, &w3)) != MP_OKAY)
    goto ERR;
  if ((res = mp_div_3(&w3, &w3, NULL)) != MP_OKAY)
    goto ERR;

  /* w1 = (w2-w1)/2  -> [0 1 0 1 0] */
  if ((res = mp_sub(&w2, &w1, &w1)) != MP_OKAY)
    goto ERR;
  if ((res = mp_div_2(&w1, &w1)) != MP_OKAY)
    goto ERR;

  /* w2 = w2-w0  -> [1 1 1 1 0] */
  if ((res = s_mp_sub(&w2, &w0, &w2)) != MP_OKAY) /* all positive */
    goto ERR;

  /* w3 = (w3-w2)/2 - w4*2
   * w4 is subtracted twice instead of being doubled first, so that no
   * extra temporary is needed to hold 2*w4
   */
  if ((res = s_mp_sub(&w3, &w2, &w3)) != MP_OKAY) /* all positive */
    goto ERR;
  if ((res = mp_div_2(&w3, &w3)) != MP_OKAY)
    goto ERR;
  if ((res = s_mp_sub(&w3, c, &w3)) != MP_OKAY) /* all positive */
    goto ERR;
  if ((res = s_mp_sub(&w3, c, &w3)) != MP_OKAY) /* all positive */
    goto ERR;

  /* w2 = w2-w1-w4 */
  if ((res = s_mp_sub(&w2, c, &w2)) != MP_OKAY) /* all positive */
    goto ERR;
  if ((res = s_mp_sub(&w2, &w1, &w2)) != MP_OKAY) /* all positive */
    goto ERR;

  /* w1 = w1-w3 */
  if ((res = s_mp_sub(&w1, &w3, &w1)) != MP_OKAY) /* all positive */
    goto ERR;

  /* at this point shift W[n] by B*n; done Horner style on c:
   * c = (((w4*x + w3)*x + w2)*x + w1)*x + w0
   */
  if ((res = mp_lshd(c, B)) != MP_OKAY)
    goto ERR;
  if ((res = s_mp_add(&w3, c, c)) != MP_OKAY)
    goto ERR;
  if ((res = mp_lshd(c, B)) != MP_OKAY)
    goto ERR;
  if ((res = s_mp_add(&w2, c, c)) != MP_OKAY)
    goto ERR;
  if ((res = mp_lshd(c, B)) != MP_OKAY)
    goto ERR;
  if ((res = s_mp_add(&w1, c, c)) != MP_OKAY)
    goto ERR;
  if ((res = mp_lshd(c, B)) != MP_OKAY)
    goto ERR;
  if ((res = s_mp_add(&w0, c, c)) != MP_OKAY)
    goto ERR;

  /* success falls through the cleanup chain as well, with res holding
   * the status of the last addition
   */
ERR:mp_clear(&w3);
 W2:mp_clear(&w2);
 W1:mp_clear(&w1);
 W0:mp_clear(&w0);
END:
  return res;
}

#endif

/* $Source: /cvs/libtom/libtommath/bn_mp_toom_sqr.c,v $ */
/* $Revision: 1.5 $ */
/* $Date: 2008/12/15 13:14:15 $ */

