initial checkin
[org.ibex.nanogoat.git] / src / org / bouncycastle / crypto / engines / RSAEngine.java
1 package org.bouncycastle.crypto.engines;
2
3 import java.math.BigInteger;
4
5 import org.bouncycastle.crypto.CipherParameters;
6 import org.bouncycastle.crypto.DataLengthException;
7 import org.bouncycastle.crypto.AsymmetricBlockCipher;
8 import org.bouncycastle.crypto.params.RSAKeyParameters;
9 import org.bouncycastle.crypto.params.RSAPrivateCrtKeyParameters;
10
11 /**
12  * this does your basic RSA algorithm.
13  */
14 public class RSAEngine
15     implements AsymmetricBlockCipher
16 {
17     private RSAKeyParameters        key;
18     private boolean                 forEncryption;
19
20     /**
21      * initialise the RSA engine.
22      *
23      * @param forEncryption true if we are encrypting, false otherwise.
24      * @param param the necessary RSA key parameters.
25      */
26     public void init(
27         boolean             forEncryption,
28         CipherParameters    param)
29     {
30         this.key = (RSAKeyParameters)param;
31         this.forEncryption = forEncryption;
32     }
33
34     /**
35      * Return the maximum size for an input block to this engine.
36      * For RSA this is always one byte less than the key size on
37      * encryption, and the same length as the key size on decryption.
38      *
39      * @return maximum size for an input block.
40      */
41     public int getInputBlockSize()
42     {
43         int     bitSize = key.getModulus().bitLength();
44
45         if (forEncryption)
46         {
47             return (bitSize + 7) / 8 - 1;
48         }
49         else
50         {
51             return (bitSize + 7) / 8;
52         }
53     }
54
55     /**
56      * Return the maximum size for an output block to this engine.
57      * For RSA this is always one byte less than the key size on
58      * decryption, and the same length as the key size on encryption.
59      *
60      * @return maximum size for an output block.
61      */
62     public int getOutputBlockSize()
63     {
64         int     bitSize = key.getModulus().bitLength();
65
66         if (forEncryption)
67         {
68             return (bitSize + 7) / 8;
69         }
70         else
71         {
72             return (bitSize + 7) / 8 - 1;
73         }
74     }
75
76     /**
77      * Process a single block using the basic RSA algorithm.
78      *
79      * @param in the input array.
80      * @param inOff the offset into the input buffer where the data starts.
81      * @param inLen the length of the data to be processed.
82      * @return the result of the RSA process.
83      * @exception DataLengthException the input block is too large.
84      */
85     public byte[] processBlock(
86         byte[]  in,
87         int     inOff,
88         int     inLen)
89     {
90         if (inLen > (getInputBlockSize() + 1))
91         {
92             throw new DataLengthException("input too large for RSA cipher.\n");
93         }
94         else if (inLen == (getInputBlockSize() + 1) && (in[inOff] & 0x80) != 0)
95         {
96             throw new DataLengthException("input too large for RSA cipher.\n");
97         }
98
99         byte[]  block;
100
101         if (inOff != 0 || inLen != in.length)
102         {
103             block = new byte[inLen];
104
105             System.arraycopy(in, inOff, block, 0, inLen);
106         }
107         else
108         {
109             block = in;
110         }
111
112         BigInteger  input = new BigInteger(1, block);
113         byte[]      output;
114
115         if (key instanceof RSAPrivateCrtKeyParameters)
116         {
117             //
118             // we have the extra factors, use the Chinese Remainder Theorem - the author
119             // wishes to express his thanks to Dirk Bonekaemper at rtsffm.com for 
120             // advice regarding the expression of this.
121             //
122             RSAPrivateCrtKeyParameters crtKey = (RSAPrivateCrtKeyParameters)key;
123
124             BigInteger d = crtKey.getExponent();
125             BigInteger p = crtKey.getP();
126             BigInteger q = crtKey.getQ();
127             BigInteger dP = crtKey.getDP();
128             BigInteger dQ = crtKey.getDQ();
129             BigInteger qInv = crtKey.getQInv();
130     
131             BigInteger mP, mQ, h, m;
132     
133             // mP = ((input mod p) ^ dP)) mod p
134             mP = (input.remainder(p)).modPow(dP, p);
135     
136             // mQ = ((input mod q) ^ dQ)) mod q
137             mQ = (input.remainder(q)).modPow(dQ, q);
138     
139             // h = qInv * (mP - mQ) mod p
140             h = mP.subtract(mQ);
141             h = h.multiply(qInv);
142             h = h.mod(p);               // mod (in Java) returns the positive residual
143     
144             // m = h * q + mQ
145             m = h.multiply(q);
146             m = m.add(mQ);
147     
148             output = m.toByteArray();
149         }
150         else
151         {
152             output = input.modPow(
153                         key.getExponent(), key.getModulus()).toByteArray();
154         }
155
156         if (forEncryption)
157         {
158             if (output[0] == 0 && output.length > getOutputBlockSize())        // have ended up with an extra zero byte, copy down.
159             {
160                 byte[]  tmp = new byte[output.length - 1];
161
162                 System.arraycopy(output, 1, tmp, 0, tmp.length);
163
164                 return tmp;
165             }
166
167             if (output.length < getOutputBlockSize())     // have ended up with less bytes than normal, lengthen
168             {
169                 byte[]  tmp = new byte[getOutputBlockSize()];
170
171                 System.arraycopy(output, 0, tmp, tmp.length - output.length, output.length);
172
173                 return tmp;
174             }
175         }
176         else
177         {
178             if (output[0] == 0)        // have ended up with an extra zero byte, copy down.
179             {
180                 byte[]  tmp = new byte[output.length - 1];
181
182                 System.arraycopy(output, 1, tmp, 0, tmp.length);
183
184                 return tmp;
185             }
186         }
187         return output;
188     }
189 }