2003/02/12 06:21:04
[org.ibex.core.git] / src / org / bouncycastle / crypto / engines / RSAEngine.java
1 package org.bouncycastle.crypto.engines;
2
3 import java.math.BigInteger;
4
5 import org.bouncycastle.crypto.CipherParameters;
6 import org.bouncycastle.crypto.DataLengthException;
7 import org.bouncycastle.crypto.AsymmetricBlockCipher;
8 import org.bouncycastle.crypto.params.AsymmetricKeyParameter;
9 import org.bouncycastle.crypto.params.RSAKeyParameters;
10 import org.bouncycastle.crypto.params.RSAPrivateCrtKeyParameters;
11
12 /**
13  * this does your basic RSA algorithm.
14  */
15 public class RSAEngine
16     implements AsymmetricBlockCipher
17 {
18     private RSAKeyParameters        key;
19     private boolean                 forEncryption;
20
21     /**
22      * initialise the RSA engine.
23      *
24      * @param forEncryption true if we are encrypting, false otherwise.
25      * @param param the necessary RSA key parameters.
26      */
27     public void init(
28         boolean             forEncryption,
29         CipherParameters    param)
30     {
31         this.key = (RSAKeyParameters)param;
32         this.forEncryption = forEncryption;
33     }
34
35     /**
36      * Return the maximum size for an input block to this engine.
37      * For RSA this is always one byte less than the key size on
38      * encryption, and the same length as the key size on decryption.
39      *
40      * @return maximum size for an input block.
41      */
42     public int getInputBlockSize()
43     {
44         int     bitSize = key.getModulus().bitLength();
45
46         if (forEncryption)
47         {
48             return (bitSize + 7) / 8 - 1;
49         }
50         else
51         {
52             return (bitSize + 7) / 8;
53         }
54     }
55
56     /**
57      * Return the maximum size for an output block to this engine.
58      * For RSA this is always one byte less than the key size on
59      * decryption, and the same length as the key size on encryption.
60      *
61      * @return maximum size for an output block.
62      */
63     public int getOutputBlockSize()
64     {
65         int     bitSize = key.getModulus().bitLength();
66
67         if (forEncryption)
68         {
69             return (bitSize + 7) / 8;
70         }
71         else
72         {
73             return (bitSize + 7) / 8 - 1;
74         }
75     }
76
77     /**
78      * Process a single block using the basic RSA algorithm.
79      *
80      * @param in the input array.
81      * @param inOff the offset into the input buffer where the data starts.
82      * @param inLen the length of the data to be processed.
83      * @return the result of the RSA process.
84      * @exception DataLengthException the input block is too large.
85      */
86     public byte[] processBlock(
87         byte[]  in,
88         int     inOff,
89         int     inLen)
90     {
91         if (inLen > (getInputBlockSize() + 1))
92         {
93             throw new DataLengthException("input too large for RSA cipher.\n");
94         }
95         else if (inLen == (getInputBlockSize() + 1) && (in[inOff] & 0x80) != 0)
96         {
97             throw new DataLengthException("input too large for RSA cipher.\n");
98         }
99
100         byte[]  block;
101
102         if (inOff != 0 || inLen != in.length)
103         {
104             block = new byte[inLen];
105
106             System.arraycopy(in, inOff, block, 0, inLen);
107         }
108         else
109         {
110             block = in;
111         }
112
113         BigInteger  input = new BigInteger(1, block);
114         byte[]      output;
115
116         if (key instanceof RSAPrivateCrtKeyParameters)
117         {
118             //
119             // we have the extra factors, use the Chinese Remainder Theorem - the author
120             // wishes to express his thanks to Dirk Bonekaemper at rtsffm.com for 
121             // advice regarding the expression of this.
122             //
123             RSAPrivateCrtKeyParameters crtKey = (RSAPrivateCrtKeyParameters)key;
124
125             BigInteger d = crtKey.getExponent();
126             BigInteger p = crtKey.getP();
127             BigInteger q = crtKey.getQ();
128             BigInteger dP = crtKey.getDP();
129             BigInteger dQ = crtKey.getDQ();
130             BigInteger qInv = crtKey.getQInv();
131     
132             BigInteger mP, mQ, h, m;
133     
134             // mP = ((input mod p) ^ dP)) mod p
135             mP = (input.remainder(p)).modPow(dP, p);
136     
137             // mQ = ((input mod q) ^ dQ)) mod q
138             mQ = (input.remainder(q)).modPow(dQ, q);
139     
140             // h = qInv * (mP - mQ) mod p
141             h = mP.subtract(mQ);
142             h = h.multiply(qInv);
143             h = h.mod(p);               // mod (in Java) returns the positive residual
144     
145             // m = h * q + mQ
146             m = h.multiply(q);
147             m = m.add(mQ);
148     
149             output = m.toByteArray();
150         }
151         else
152         {
153             output = input.modPow(
154                         key.getExponent(), key.getModulus()).toByteArray();
155         }
156
157         if (forEncryption)
158         {
159             if (output[0] == 0 && output.length > getOutputBlockSize())        // have ended up with an extra zero byte, copy down.
160             {
161                 byte[]  tmp = new byte[output.length - 1];
162
163                 System.arraycopy(output, 1, tmp, 0, tmp.length);
164
165                 return tmp;
166             }
167
168             if (output.length < getOutputBlockSize())     // have ended up with less bytes than normal, lengthen
169             {
170                 byte[]  tmp = new byte[getOutputBlockSize()];
171
172                 System.arraycopy(output, 0, tmp, tmp.length - output.length, output.length);
173
174                 return tmp;
175             }
176         }
177         else
178         {
179             if (output[0] == 0)        // have ended up with an extra zero byte, copy down.
180             {
181                 byte[]  tmp = new byte[output.length - 1];
182
183                 System.arraycopy(output, 1, tmp, 0, tmp.length);
184
185                 return tmp;
186             }
187         }
188         return output;
189     }
190 }