remove empty dir
[ghc-hetmet.git] / rts / gmp / mpz / powm.c
1 /* mpz_powm(res,base,exp,mod) -- Set RES to (base**exp) mod MOD.
2
3 Copyright (C) 1991, 1993, 1994, 1996, 1997, 2000 Free Software Foundation, Inc.
4 Contributed by Paul Zimmermann.
5
6 This file is part of the GNU MP Library.
7
8 The GNU MP Library is free software; you can redistribute it and/or modify
9 it under the terms of the GNU Lesser General Public License as published by
10 the Free Software Foundation; either version 2.1 of the License, or (at your
11 option) any later version.
12
13 The GNU MP Library is distributed in the hope that it will be useful, but
14 WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
15 or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
16 License for more details.
17
18 You should have received a copy of the GNU Lesser General Public License
19 along with the GNU MP Library; see the file COPYING.LIB.  If not, write to
20 the Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
21 MA 02111-1307, USA. */
22
23 #include "gmp.h"
24 #include "gmp-impl.h"
25 #include "longlong.h"
26 #ifdef BERKELEY_MP
27 #include "mp.h"
28 #endif
29
30
31 /* set c <- (a*b)/R^n mod m c has to have at least (2n) allocated limbs */
32 static void
33 #if __STDC__
34 mpz_redc (mpz_ptr c, mpz_srcptr a, mpz_srcptr b, mpz_srcptr m, mp_limb_t Nprim)
35 #else
36 mpz_redc (c, a, b, m, Nprim)
37      mpz_ptr c;
38      mpz_srcptr a;
39      mpz_srcptr b;
40      mpz_srcptr m;
41      mp_limb_t Nprim;
42 #endif
43 {
44   mp_ptr cp, mp = PTR (m);
45   mp_limb_t cy, cout = 0;
46   mp_limb_t q;
47   size_t j, n = ABSIZ (m);
48
49   ASSERT (ALLOC (c) >= 2 * n);
50
51   mpz_mul (c, a, b);
52   cp = PTR (c);
53   j = ABSIZ (c);
54   MPN_ZERO (cp + j, 2 * n - j);
55   for (j = 0; j < n; j++)
56     {
57       q = cp[0] * Nprim;
58       cy = mpn_addmul_1 (cp, mp, n, q);
59       cout += mpn_add_1 (cp + n, cp + n, n - j, cy);
60       cp++;
61     }
62   cp -= n;
63   if (cout)
64     {
65       cy = cout - mpn_sub_n (cp, cp + n, mp, n);
66       while (cy)
67         cy -= mpn_sub_n (cp, cp, mp, n);
68     }
69   else
70     MPN_COPY (cp, cp + n, n);
71   MPN_NORMALIZE (cp, n);
72   SIZ (c) = SIZ (c) < 0 ? -n : n;
73 }
74
75 /* average number of calls to redc for an exponent of n bits
76    with the sliding window algorithm of base 2^k: the optimal is
77    obtained for the value of k which minimizes 2^(k-1)+n/(k+1):
78
79    n\k    4     5     6     7     8
80    128    156*  159   171   200   261
81    256    309   307*  316   343   403
82    512    617   607*  610   632   688
83    1024   1231  1204  1195* 1207  1256
84    2048   2461  2399  2366  2360* 2396
85    4096   4918  4787  4707  4665* 4670
86 */
87 \f
88 #ifndef BERKELEY_MP
89 void
90 #if __STDC__
91 mpz_powm (mpz_ptr res, mpz_srcptr base, mpz_srcptr e, mpz_srcptr mod)
92 #else
93 mpz_powm (res, base, e, mod)
94      mpz_ptr res;
95      mpz_srcptr base;
96      mpz_srcptr e;
97      mpz_srcptr mod;
98 #endif
99 #else /* BERKELEY_MP */
100 void
101 #if __STDC__
102 pow (mpz_srcptr base, mpz_srcptr e, mpz_srcptr mod, mpz_ptr res)
103 #else
104 pow (base, e, mod, res)
105      mpz_srcptr base;
106      mpz_srcptr e;
107      mpz_srcptr mod;
108      mpz_ptr res;
109 #endif
110 #endif /* BERKELEY_MP */
111 {
112   mp_limb_t invm, *ep, c, mask;
113   mpz_t xx, *g;
114   mp_size_t n, i, K, j, l, k;
115   int sh;
116   int use_redc;
117
118 #ifdef POWM_DEBUG
119   mpz_t exp;
120   mpz_init (exp);
121 #endif
122
123   n = ABSIZ (mod);
124
125   if (n == 0)
126     DIVIDE_BY_ZERO;
127
128   if (SIZ (e) == 0)
129     {
130       /* Exponent is zero, result is 1 mod MOD, i.e., 1 or 0
131          depending on if MOD equals 1.  */
132       SIZ(res) = (ABSIZ (mod) == 1 && (PTR(mod))[0] == 1) ? 0 : 1;
133       PTR(res)[0] = 1;
134       return;
135     }
136
137   /* Use REDC instead of usual reduction for sizes < POWM_THRESHOLD.
138      In REDC each modular multiplication costs about 2*n^2 limbs operations,
139      whereas using usual reduction it costs 3*K(n), where K(n) is the cost of a
140      multiplication using Karatsuba, and a division is assumed to cost 2*K(n),
141      for example using Burnikel-Ziegler's algorithm. This gives a theoretical
142      threshold of a*KARATSUBA_SQR_THRESHOLD, with a=(3/2)^(1/(2-ln(3)/ln(2))) ~
143      2.66.  */
144   /* For now, also disable REDC when MOD is even, as the inverse can't
145      handle that.  */
146
147 #ifndef POWM_THRESHOLD
148 #define POWM_THRESHOLD  ((8 * KARATSUBA_SQR_THRESHOLD) / 3)
149 #endif
150
151   use_redc = (n < POWM_THRESHOLD && PTR(mod)[0] % 2 != 0);
152   if (use_redc)
153     {
154       /* invm = -1/m mod 2^BITS_PER_MP_LIMB, must have m odd */
155       modlimb_invert (invm, PTR(mod)[0]);
156       invm = -invm;
157     }
158
159   /* determines optimal value of k */
160   l = ABSIZ (e) * BITS_PER_MP_LIMB; /* number of bits of exponent */
161   k = 1;
162   K = 2;
163   while (2 * l > K * (2 + k * (3 + k)))
164     {
165       k++;
166       K *= 2;
167     }
168
169   g = (mpz_t *) (*_mp_allocate_func) (K / 2 * sizeof (mpz_t));
170   /* compute x*R^n where R=2^BITS_PER_MP_LIMB */
171   mpz_init (g[0]);
172   if (use_redc)
173     {
174       mpz_mul_2exp (g[0], base, n * BITS_PER_MP_LIMB);
175       mpz_mod (g[0], g[0], mod);
176     }
177   else
178     mpz_mod (g[0], base, mod);
179
180   /* compute xx^g for odd g < 2^k */
181   mpz_init (xx);
182   if (use_redc)
183     {
184       _mpz_realloc (xx, 2 * n);
185       mpz_redc (xx, g[0], g[0], mod, invm); /* xx = x^2*R^n */
186     }
187   else
188     {
189       mpz_mul (xx, g[0], g[0]);
190       mpz_mod (xx, xx, mod);
191     }
192   for (i = 1; i < K / 2; i++)
193     {
194       mpz_init (g[i]);
195       if (use_redc)
196         {
197           _mpz_realloc (g[i], 2 * n);
198           mpz_redc (g[i], g[i - 1], xx, mod, invm); /* g[i] = x^(2i+1)*R^n */
199         }
200       else
201         {
202           mpz_mul (g[i], g[i - 1], xx);
203           mpz_mod (g[i], g[i], mod);
204         }
205     }
206
207   /* now starts the real stuff */
208   mask = (mp_limb_t) ((1<<k) - 1);
209   ep = PTR (e);
210   i = ABSIZ (e) - 1;                    /* current index */
211   c = ep[i];                            /* current limb */
212   count_leading_zeros (sh, c);
213   sh = BITS_PER_MP_LIMB - sh;           /* significant bits in ep[i] */
214   sh -= k;                              /* index of lower bit of ep[i] to take into account */
215   if (sh < 0)
216     {                                   /* k-sh extra bits are needed */
217       if (i > 0)
218         {
219           i--;
220           c = (c << (-sh)) | (ep[i] >> (BITS_PER_MP_LIMB + sh));
221           sh += BITS_PER_MP_LIMB;
222         }
223     }
224   else
225     c = c >> sh;
226 #ifdef POWM_DEBUG
227   printf ("-1/m mod 2^%u = %lu\n", BITS_PER_MP_LIMB, invm);
228   mpz_set_ui (exp, c);
229 #endif
230   j=0;
231   while (c % 2 == 0)
232     {
233       j++;
234       c = (c >> 1);
235     }
236   mpz_set (xx, g[c >> 1]);
237   while (j--)
238     {
239       if (use_redc)
240         mpz_redc (xx, xx, xx, mod, invm);
241       else
242         {
243           mpz_mul (xx, xx, xx);
244           mpz_mod (xx, xx, mod);
245         }
246     }
247
248 #ifdef POWM_DEBUG
249   printf ("x^"); mpz_out_str (0, 10, exp);
250   printf ("*2^%u mod m = ", n * BITS_PER_MP_LIMB); mpz_out_str (0, 10, xx);
251   putchar ('\n');
252 #endif
253
254   while (i > 0 || sh > 0)
255     {
256       c = ep[i];
257       sh -= k;
258       l = k;                            /* number of bits treated */
259       if (sh < 0)
260         {
261           if (i > 0)
262             {
263               i--;
264               c = (c << (-sh)) | (ep[i] >> (BITS_PER_MP_LIMB + sh));
265               sh += BITS_PER_MP_LIMB;
266             }
267           else
268             {
269               l += sh;                  /* may be less bits than k here */
270               c = c & ((1<<l) - 1);
271             }
272         }
273       else
274         c = c >> sh;
275       c = c & mask;
276
277       /* this while loop implements the sliding window improvement */
278       while ((c & (1 << (k - 1))) == 0 && (i > 0 || sh > 0))
279         {
280           if (use_redc) mpz_redc (xx, xx, xx, mod, invm);
281           else
282             {
283               mpz_mul (xx, xx, xx);
284               mpz_mod (xx, xx, mod);
285             }
286           if (sh)
287             {
288               sh--;
289               c = (c<<1) + ((ep[i]>>sh) & 1);
290             }
291           else
292             {
293               i--;
294               sh = BITS_PER_MP_LIMB - 1;
295               c = (c<<1) + (ep[i]>>sh);
296             }
297         }
298
299 #ifdef POWM_DEBUG
300       printf ("l=%u c=%lu\n", l, c);
301       mpz_mul_2exp (exp, exp, k);
302       mpz_add_ui (exp, exp, c);
303 #endif
304
305       /* now replace xx by xx^(2^k)*x^c */
306       if (c != 0)
307         {
308           j = 0;
309           while (c % 2 == 0)
310             {
311               j++;
312               c = c >> 1;
313             }
314           /* c0 = c * 2^j, i.e. xx^(2^k)*x^c = (A^(2^(k - j))*c)^(2^j) */
315           l -= j;
316           while (l--)
317             if (use_redc) mpz_redc (xx, xx, xx, mod, invm);
318             else
319               {
320                 mpz_mul (xx, xx, xx);
321                 mpz_mod (xx, xx, mod);
322               }
323           if (use_redc)
324             mpz_redc (xx, xx, g[c >> 1], mod, invm);
325           else
326             {
327               mpz_mul (xx, xx, g[c >> 1]);
328               mpz_mod (xx, xx, mod);
329             }
330         }
331       else
332         j = l;                          /* case c=0 */
333       while (j--)
334         {
335           if (use_redc)
336             mpz_redc (xx, xx, xx, mod, invm);
337           else
338             {
339               mpz_mul (xx, xx, xx);
340               mpz_mod (xx, xx, mod);
341             }
342         }
343 #ifdef POWM_DEBUG
344       printf ("x^"); mpz_out_str (0, 10, exp);
345       printf ("*2^%u mod m = ", n * BITS_PER_MP_LIMB); mpz_out_str (0, 10, xx);
346       putchar ('\n');
347 #endif
348     }
349
350   /* now convert back xx to xx/R^n */
351   if (use_redc)
352     {
353       mpz_set_ui (g[0], 1);
354       mpz_redc (xx, xx, g[0], mod, invm);
355       if (mpz_cmp (xx, mod) >= 0)
356         mpz_sub (xx, xx, mod);
357     }
358   mpz_set (res, xx);
359
360   mpz_clear (xx);
361   for (i = 0; i < K / 2; i++)
362     mpz_clear (g[i]);
363   (*_mp_free_func) (g, K / 2 * sizeof (mpz_t));
364 }