2002/03/21 01:19:32
[org.ibex.core.git] / src / org / bouncycastle / crypto / engines / RSAEngine.java
diff --git a/src/org/bouncycastle/crypto/engines/RSAEngine.java b/src/org/bouncycastle/crypto/engines/RSAEngine.java
new file mode 100644 (file)
index 0000000..2cbac54
--- /dev/null
@@ -0,0 +1,195 @@
+package org.bouncycastle.crypto.engines;
+
+import java.math.BigInteger;
+
+import org.bouncycastle.crypto.CipherParameters;
+import org.bouncycastle.crypto.DataLengthException;
+import org.bouncycastle.crypto.AsymmetricBlockCipher;
+import org.bouncycastle.crypto.params.AsymmetricKeyParameter;
+import org.bouncycastle.crypto.params.RSAKeyParameters;
+import org.bouncycastle.crypto.params.RSAPrivateCrtKeyParameters;
+
+/**
+ * this does your basic RSA algorithm.
+ */
+public class RSAEngine
+    implements AsymmetricBlockCipher
+{
+    private RSAKeyParameters        key;
+    private boolean                 forEncryption;
+
+    /**
+     * initialise the RSA engine.
+     *
+     * @param forEncryption treu if we are encrypting, false otherwise.
+     * @param param the necessary RSA key parameters.
+     */
+    public void init(
+        boolean             forEncryption,
+        CipherParameters    param)
+    {
+        this.key = (RSAKeyParameters)param;
+        this.forEncryption = forEncryption;
+    }
+
+    /**
+     * Return the maximum size for an input block to this engine.
+     * For RSA this is always one byte less than the key size on
+     * encryption, and the same length as the key size on decryption.
+     *
+     * @return maximum size for an input block.
+     */
+    public int getInputBlockSize()
+    {
+        int     bitSize = key.getModulus().bitLength();
+
+        if (forEncryption)
+        {
+            if ((bitSize % 8) == 0)
+            {
+                return bitSize / 8 - 1;
+            }
+
+            return bitSize / 8;
+        }
+        else
+        {
+            return (bitSize + 7) / 8;
+        }
+    }
+
+    /**
+     * Return the maximum size for an output block to this engine.
+     * For RSA this is always one byte less than the key size on
+     * decryption, and the same length as the key size on encryption.
+     *
+     * @return maximum size for an input block.
+     */
+    public int getOutputBlockSize()
+    {
+        int     bitSize = key.getModulus().bitLength();
+
+        if (forEncryption)
+        {
+            return ((bitSize - 1) + 7) / 8;
+        }
+        else
+        {
+            return (bitSize - 7) / 8;
+        }
+    }
+
+    /**
+     * Process a single block using the basic RSA algorithm.
+     *
+     * @param in the input array.
+     * @param inOff the offset into the input buffer where the data starts.
+     * @param inLen the length of the data to be processed.
+     * @return the result of the RSA process.
+     * @exception DataLengthException the input block is too large.
+     */
+    public byte[] processBlock(
+        byte[]  in,
+        int     inOff,
+        int     inLen)
+    {
+        if (inLen > (getInputBlockSize() + 1))
+        {
+            throw new DataLengthException("input too large for RSA cipher.\n");
+        }
+        else if (inLen == (getInputBlockSize() + 1) && (in[inOff] & 0x80) != 0)
+        {
+            throw new DataLengthException("input too large for RSA cipher.\n");
+        }
+
+        byte[]  block;
+
+        if (inOff != 0 || inLen != in.length)
+        {
+            block = new byte[inLen];
+
+            System.arraycopy(in, inOff, block, 0, inLen);
+        }
+        else
+        {
+            block = in;
+        }
+
+        BigInteger  input = new BigInteger(1, block);
+        byte[]      output;
+
+        if (key instanceof RSAPrivateCrtKeyParameters)
+        {
+            //
+            // we have the extra factors, use the Chinese Remainder Theorem - the author
+            // wishes to express his thanks to Dirk Bonekaemper at rtsffm.com for 
+            // advice regarding the expression of this.
+            //
+            RSAPrivateCrtKeyParameters crtKey = (RSAPrivateCrtKeyParameters)key;
+
+            BigInteger d = crtKey.getExponent();
+            BigInteger p = crtKey.getP();
+            BigInteger q = crtKey.getQ();
+            BigInteger dP = crtKey.getDP();
+            BigInteger dQ = crtKey.getDQ();
+            BigInteger qInv = crtKey.getQInv();
+    
+            BigInteger mP, mQ, h, m;
+    
+            // mP = ((input mod p) ^ dP)) mod p
+            mP = (input.remainder(p)).modPow(dP, p);
+    
+            // mQ = ((input mod q) ^ dQ)) mod q
+            mQ = (input.remainder(q)).modPow(dQ, q);
+    
+            // h = qInv * (mP - mQ) mod p
+            h = mP.subtract(mQ);
+            h = h.multiply(qInv);
+            h = h.mod(p);               // mod (in Java) returns the positive residual
+    
+            // m = h * q + mQ
+            m = h.multiply(q);
+            m = m.add(mQ);
+    
+            output = m.toByteArray();
+        }
+        else
+        {
+            output = input.modPow(
+                        key.getExponent(), key.getModulus()).toByteArray();
+        }
+
+        if (forEncryption)
+        {
+            if (output[0] == 0 && output.length > getOutputBlockSize())        // have ended up with an extra zero byte, copy down.
+            {
+                byte[]  tmp = new byte[output.length - 1];
+
+                System.arraycopy(output, 1, tmp, 0, tmp.length);
+
+                return tmp;
+            }
+
+            if (output.length < getOutputBlockSize())     // have ended up with less bytes than normal, lengthen
+            {
+                byte[]  tmp = new byte[getOutputBlockSize()];
+
+                System.arraycopy(output, 0, tmp, tmp.length - output.length, output.length);
+
+                return tmp;
+            }
+        }
+        else
+        {
+            if (output[0] == 0)        // have ended up with an extra zero byte, copy down.
+            {
+                byte[]  tmp = new byte[output.length - 1];
+
+                System.arraycopy(output, 1, tmp, 0, tmp.length);
+
+                return tmp;
+            }
+        }
+        return output;
+    }
+}