[project @ 1998-11-26 09:17:22 by sof]
[ghc-hetmet.git] / ghc / runtime / gmp / mpn_mul.c
1 /* mpn_mul -- Multiply two natural numbers.
2
3 Copyright (C) 1991, 1992 Free Software Foundation, Inc.
4
5 This file is part of the GNU MP Library.
6
7 The GNU MP Library is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2, or (at your option)
10 any later version.
11
12 The GNU MP Library is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with the GNU MP Library; see the file COPYING.  If not, write to
19 the Free Software Foundation, 675 Mass Ave, Cambridge, MA 02139, USA.  */
20
21 #include "gmp.h"
22 #include "gmp-impl.h"
23 #include "longlong.h"
24
25 #ifdef GMP_DEBUG /* partain: was DEBUG */
26 #define MPN_MUL_VERIFY(res_ptr,res_size,op1_ptr,op1_size,op2_ptr,op2_size) \
27   mpn_mul_verify (res_ptr, res_size, op1_ptr, op1_size, op2_ptr, op2_size)
28
29 #include <stdio.h>
30 static void
31 mpn_mul_verify (res_ptr, res_size, op1_ptr, op1_size, op2_ptr, op2_size)
32      mp_ptr res_ptr, op1_ptr, op2_ptr;
33      mp_size res_size, op1_size, op2_size;
34 {
35   mp_ptr tmp_ptr;
36   mp_size tmp_size;
37   tmp_ptr = alloca ((op1_size + op2_size) * BYTES_PER_MP_LIMB);
38   if (op1_size >= op2_size)
39     tmp_size = mpn_mul_classic (tmp_ptr,
40                                  op1_ptr, op1_size, op2_ptr, op2_size);
41   else
42     tmp_size = mpn_mul_classic (tmp_ptr,
43                                  op2_ptr, op2_size, op1_ptr, op1_size);
44   if (tmp_size != res_size
45       || mpn_cmp (tmp_ptr, res_ptr, tmp_size) != 0)
46     {
47       fprintf (stderr, "GNU MP internal error: Wrong result in mpn_mul.\n");
48       fprintf (stderr, "op1{%d} = ", op1_size); mpn_dump (op1_ptr, op1_size);
49       fprintf (stderr, "op2{%d} = ", op2_size); mpn_dump (op2_ptr, op2_size);
50       abort ();
51     }
52 }
53 #else
54 #define MPN_MUL_VERIFY(a,b,c,d,e,f)
55 #endif
56
57 /* Multiply the natural numbers u (pointed to by UP, with USIZE limbs)
58    and v (pointed to by VP, with VSIZE limbs), and store the result at
59    PRODP.  USIZE + VSIZE limbs are always stored, but if the input
60    operands are normalized, the return value will reflect the true
61    result size (which is either USIZE + VSIZE, or USIZE + VSIZE -1).
62
63    NOTE: The space pointed to by PRODP is overwritten before finished
64    with U and V, so overlap is an error.
65
66    Argument constraints:
67    1. USIZE >= VSIZE.
68    2. PRODP != UP and PRODP != VP, i.e. the destination
69       must be distinct from the multiplier and the multiplicand.  */
70
71 /* If KARATSUBA_THRESHOLD is not already defined, define it to a
72    value which is good on most machines.  */
73 #ifndef KARATSUBA_THRESHOLD
74 #define KARATSUBA_THRESHOLD 8
75 #endif
76
77 /* The code can't handle KARATSUBA_THRESHOLD smaller than 4.  */
78 #if KARATSUBA_THRESHOLD < 4
79 #undef KARATSUBA_THRESHOLD
80 #define KARATSUBA_THRESHOLD 4
81 #endif
82
83 mp_size
84 #ifdef __STDC__
85 mpn_mul (mp_ptr prodp,
86           mp_srcptr up, mp_size usize,
87           mp_srcptr vp, mp_size vsize)
88 #else
89 mpn_mul (prodp, up, usize, vp, vsize)
90      mp_ptr prodp;
91      mp_srcptr up;
92      mp_size usize;
93      mp_srcptr vp;
94      mp_size vsize;
95 #endif
96 {
97   mp_size n;
98   mp_size prod_size;
99   mp_limb cy;
100
101   if (vsize < KARATSUBA_THRESHOLD)
102     {
103       /* Handle simple cases with traditional multiplication.
104
105          This is the most critical code of the entire function.  All
106          multiplies rely on this, both small and huge.  Small ones arrive
107          here immediately.  Huge ones arrive here as this is the base case
108          for the recursive algorithm below.  */
109       mp_size i, j;
110       mp_limb prod_low, prod_high;
111       mp_limb cy_limb;
112       mp_limb v_limb;
113
114       if (vsize == 0)
115         return 0;
116
117       /* Offset UP and PRODP so that the inner loop can be faster.  */
118       up += usize;
119       prodp += usize;
120
121       /* Multiply by the first limb in V separately, as the result can
122          be stored (not added) to PROD.  We also avoid a loop for zeroing.  */
123       v_limb = vp[0];
124       if (v_limb <= 1)
125         {
126           if (v_limb == 1)
127             MPN_COPY (prodp - usize, up - usize, usize);
128           else
129             MPN_ZERO (prodp - usize, usize);
130           cy_limb = 0;
131         }
132       else
133         {
134           cy_limb = 0;
135           j = -usize;
136           do
137             {
138               umul_ppmm (prod_high, prod_low, up[j], v_limb);
139               add_ssaaaa (cy_limb, prodp[j], prod_high, prod_low, 0, cy_limb);
140               j++;
141             }
142           while (j < 0);
143         }
144
145       prodp[0] = cy_limb;
146       prodp++;
147
148       /* For each iteration in the outer loop, multiply one limb from
149          U with one limb from V, and add it to PROD.  */
150       for (i = 1; i < vsize; i++)
151         {
152           v_limb = vp[i];
153           if (v_limb <= 1)
154             {
155               cy_limb = 0;
156               if (v_limb == 1)
157                 cy_limb = mpn_add (prodp - usize,
158                                     prodp - usize, usize, up - usize, usize);
159             }
160           else
161             {
162               cy_limb = 0;
163               j = -usize;
164
165               do
166                 {
167                   umul_ppmm (prod_high, prod_low, up[j], v_limb);
168                   add_ssaaaa (cy_limb, prod_low,
169                               prod_high, prod_low, 0, cy_limb);
170                   add_ssaaaa (cy_limb, prodp[j],
171                               cy_limb, prod_low, 0, prodp[j]);
172                   j++;
173                 }
174               while (j < 0);
175             }
176
177           prodp[0] = cy_limb;
178           prodp++;
179         }
180
181       return usize + vsize - (cy_limb == 0);
182     }
183
184   n = (usize + 1) / 2;
185
186   /* Is USIZE larger than 1.5 times VSIZE?  Avoid Karatsuba's algorithm.  */
187   if (2 * usize > 3 * vsize)
188     {
189       /* If U has at least twice as many limbs as V.  Split U in two
190          pieces, U1 and U0, such that U = U0 + U1*(2**BITS_PER_MP_LIMB)**N,
191          and recursively multiply the two pieces separately with V.  */
192
193       mp_size u0_size;
194       mp_ptr tmp;
195       mp_size tmp_size;
196
197       /* V1 (the high part of V) is zero.  */
198
199       /* Calculate the length of U0.  It is normally equal to n, but
200          of course not for sure.  */
201       for (u0_size = n; u0_size > 0 && up[u0_size - 1] == 0; u0_size--)
202         ;
203
204       /* Perform (U0 * V).  */
205       if (u0_size >= vsize)
206         prod_size = mpn_mul (prodp, up, u0_size, vp, vsize);
207       else
208         prod_size = mpn_mul (prodp, vp, vsize, up, u0_size);
209       MPN_MUL_VERIFY (prodp, prod_size, up, u0_size, vp, vsize);
210
211       /* We have to zero-extend the lower partial product to n limbs,
212          since the mpn_add some lines below expect the first n limbs
213          to be well defined.  (This is normally a no-op.  It may
214          do something when U1 has many leading 0 limbs.) */
215       while (prod_size < n)
216         prodp[prod_size++] = 0;
217
218       tmp = (mp_ptr) alloca ((usize + vsize - n) * BYTES_PER_MP_LIMB);
219
220       /* Perform (U1 * V).  Make sure the first source argument to mpn_mul
221          is not less than the second source argument.  */
222       if (vsize <= usize - n)
223         tmp_size = mpn_mul (tmp, up + n, usize - n, vp, vsize);
224       else
225         tmp_size = mpn_mul (tmp, vp, vsize, up + n, usize - n);
226       MPN_MUL_VERIFY (tmp, tmp_size, up + n, usize - n, vp, vsize);
227
228       /* In this addition hides a potentially large copying of TMP.  */
229       if (prod_size - n >= tmp_size)
230         cy = mpn_add (prodp + n, prodp + n, prod_size - n, tmp, tmp_size);
231       else
232         cy = mpn_add (prodp + n, tmp, tmp_size, prodp + n, prod_size - n);
233       if (cy)
234         abort (); /* prodp[prod_size] = cy; */
235
236       alloca (0);
237       return tmp_size + n;
238     }
239   else
240     {
241       /* Karatsuba's divide-and-conquer algorithm.
242
243          Split U in two pieces, U1 and U0, such that
244          U = U0 + U1*(B**n),
245          and V in V1 and V0, such that
246          V = V0 + V1*(B**n).
247
248          UV is then computed recursively using the identity
249
250                 2n   n        n                   n
251          UV = (B  + B )U V + B (U -U )(V -V ) + (B + 1)U V
252                         1 1      1  0   0  1            0 0
253
254          Where B = 2**BITS_PER_MP_LIMB.
255        */
256
257       /* It's possible to decrease the temporary allocation by using the
258          prodp area for temporary storage of the middle term, and doing
259          that recursive multiplication first.  (Do this later.)  */
260
261       mp_size u0_size;
262       mp_size v0_size;
263       mp_size u0v0_size;
264       mp_size u1v1_size;
265       mp_ptr temp;
266       mp_size temp_size;
267       mp_size utem_size;
268       mp_size vtem_size;
269       mp_ptr ptem;
270       mp_size ptem_size;
271       int negflg;
272       mp_ptr pp;
273
274       pp = (mp_ptr) alloca (4 * n * BYTES_PER_MP_LIMB);
275
276       /* Calculate the lengths of U0 and V0.  They are normally equal
277          to n, but of course not for sure.  */
278       for (u0_size = n; u0_size > 0 && up[u0_size - 1] == 0; u0_size--)
279         ;
280       for (v0_size = n; v0_size > 0 && vp[v0_size - 1] == 0; v0_size--)
281         ;
282
283       /*** 1. PROD]2n..0] := U0 x V0
284             (Recursive call to mpn_mul may NOT overwrite input operands.)
285              ________________  ________________
286             |________________||____U0 x V0_____|  */
287
288       if (u0_size >= v0_size)
289         u0v0_size = mpn_mul (pp, up, u0_size, vp, v0_size);
290       else
291         u0v0_size = mpn_mul (pp, vp, v0_size, up, u0_size);
292       MPN_MUL_VERIFY (pp, u0v0_size, up, u0_size, vp, v0_size);
293
294       /* Zero-extend to 2n limbs. */
295       while (u0v0_size < 2 * n)
296         pp[u0v0_size++] = 0;
297
298
299       /*** 2. PROD]4n..2n] := U1 x V1
300             (Recursive call to mpn_mul may NOT overwrite input operands.)
301              ________________  ________________
302             |_____U1 x V1____||____U0 x V0_____|  */
303
304       u1v1_size = mpn_mul (pp + 2*n,
305                              up + n, usize - n,
306                              vp + n, vsize - n);
307       MPN_MUL_VERIFY (pp + 2*n, u1v1_size,
308                       up + n, usize - n, vp + n, vsize - n);
309       prod_size = 2 * n + u1v1_size;
310
311
312       /*** 3. PTEM]2n..0] := (U1-U0) x (V0-V1)
313             (Recursive call to mpn_mul may overwrite input operands.)
314              ________________
315             |_(U1-U0)(V0-V1)_|  */
316
317       temp = (mp_ptr) alloca ((2 * n + 1) * BYTES_PER_MP_LIMB);
318       if (usize - n > u0_size
319           || (usize - n == u0_size
320               && mpn_cmp (up + n, up, u0_size) >= 0))
321         {
322           utem_size = usize - n
323             + mpn_sub (temp, up + n, usize - n, up, u0_size);
324           negflg = 0;
325         }
326       else
327         {
328           utem_size = u0_size
329             + mpn_sub (temp, up, u0_size, up + n, usize - n);
330           negflg = 1;
331         }
332       if (vsize - n > v0_size
333           || (vsize - n == v0_size
334               && mpn_cmp (vp + n, vp, v0_size) >= 0))
335         {
336           vtem_size = vsize - n
337             + mpn_sub (temp + n, vp + n, vsize - n, vp, v0_size);
338           negflg ^= 1;
339         }
340       else
341         {
342           vtem_size = v0_size
343             + mpn_sub (temp + n, vp, v0_size, vp + n, vsize - n);
344           /* No change of NEGFLG.  */
345         }
346       ptem = (mp_ptr) alloca (2 * n * BYTES_PER_MP_LIMB);
347       if (utem_size >= vtem_size)
348         ptem_size = mpn_mul (ptem, temp, utem_size, temp + n, vtem_size);
349       else
350         ptem_size = mpn_mul (ptem, temp + n, vtem_size, temp, utem_size);
351       MPN_MUL_VERIFY (ptem, ptem_size, temp, utem_size, temp + n, vtem_size);
352
353       /*** 4. TEMP]2n..0] := PROD]2n..0] + PROD]4n..2n]
354               ________________
355              |_____U1 x V1____|
356               ________________
357              |_____U0_x_V0____|  */
358
359       cy = mpn_add (temp, pp, 2*n, pp + 2*n, u1v1_size);
360       if (cy != 0)
361         {
362           temp[2*n] = cy;
363           temp_size = 2*n + 1;
364         }
365       else
366         {
367           /* Normalize temp.  pp[2*n-1] might have been zero in the
368              mpn_add call above, and thus temp might be unnormalized.  */
369           for (temp_size = 2*n; temp_size > 0 && temp[temp_size - 1] == 0;
370                temp_size--)
371             ;
372         }
373
374       if (prod_size - n >= temp_size)
375         cy = mpn_add (pp + n, pp + n, prod_size - n, temp, temp_size);
376       else
377         {
378           /* This is a weird special case that should not happen (often)!  */
379           cy = mpn_add (pp + n, temp, temp_size, pp + n, prod_size - n);
380           prod_size = temp_size + n;
381         }
382       if (cy != 0)
383         {
384           pp[prod_size] = cy;
385           prod_size++;
386         }
387 #ifdef GMP_DEBUG  /* partain: was DEBUG */
388       if (prod_size > 4 * n)
389         abort();
390 #endif
391       if (negflg)
392         prod_size = prod_size
393           + mpn_sub (pp + n, pp + n, prod_size - n, ptem, ptem_size);
394       else
395         {
396           if (prod_size - n < ptem_size)
397             abort();
398           cy = mpn_add (pp + n, pp + n, prod_size - n, ptem, ptem_size);
399           if (cy != 0)
400             {
401               pp[prod_size] = cy;
402               prod_size++;
403 #ifdef GMP_DEBUG /* partain: was DEBUG */
404               if (prod_size > 4 * n)
405                 abort();
406 #endif
407             }
408         }
409
410       MPN_COPY (prodp, pp, prod_size);
411       alloca (0);
412       return prod_size;
413     }
414 }