make verify certs public
[org.ibex.crypto.git] / src / org / ibex / net / SSL.java
index 7078e81..f53a1ae 100644 (file)
 
 package org.ibex.net;
 
-import org.ibex.der.DER.Exception;
-import org.ibex.der.DER.InputStream;
-import org.ibex.x509.X509Certificate;
-import org.ibex.x509.RSAPublicKey;
-import org.ibex.x509.X509Name;
-import org.ibex.crypto.HMAC;
-import org.ibex.crypto.PKCS1;
-import org.ibex.crypto.RC4;
-import org.ibex.crypto.RSA;
-import org.ibex.crypto.Digest;
-import org.ibex.crypto.MD5;
-import org.ibex.crypto.SHA1;
-
+import org.ibex.crypto.*;
 import java.security.SecureRandom;
 
 import java.net.Socket;
@@ -40,12 +28,23 @@ import java.util.Vector;
 // FEATURE: Server socket
 
 public class SSL extends Socket {
+    public static final byte TLS_RSA_WITH_AES_256_CBC_SHA = 0x35;
+    public static final byte TLS_RSA_WITH_AES_128_CBC_SHA = 0x2f;
+    public static final byte SSL_RSA_WITH_RC4_128_SHA = 0x05;
+    public static final byte SSL_RSA_WITH_RC4_128_MD5 = 0x04;
+    
+    private static final byte[] DEFAULT_CIPHER_PREFS = new byte[]{
+        TLS_RSA_WITH_AES_256_CBC_SHA,TLS_RSA_WITH_AES_128_CBC_SHA,
+        SSL_RSA_WITH_RC4_128_SHA,SSL_RSA_WITH_RC4_128_MD5
+    };
+    
     private String hostname;
     
     private int negotiated;
     
     private boolean tls = true;
     private boolean sha;
+    private int aes;
     
     private final DataInputStream rawIS;
     private final DataOutputStream rawOS;
@@ -59,8 +58,8 @@ public class SSL extends Socket {
     private Digest serverWriteMACDigest;        
     private byte[] masterSecret;
     
-    private RC4 writeRC4;
-    private RC4 readRC4;
+    private Cipher writeCipher;
+    private Cipher readCipher;
     
     private long serverSequenceNumber;
     private long clientSequenceNumber;
@@ -87,6 +86,11 @@ public class SSL extends Socket {
     private byte[] readRecordBuf = new byte[16384+20];   // 20 == sizeof(sha1 hash)
     private byte[] readRecordScratch = new byte[16384+20];
     
+    // These are only uses for CBCs
+    private int blockSize; // block size (0 for stream ciphers)
+    private byte[] padBuf;
+    private byte[] prevBlock;
+    
     private ByteArrayOutputStream handshakesBuffer;
     
     // End Buffers
@@ -126,14 +130,16 @@ public class SSL extends Socket {
 
     public synchronized void setTLS(boolean b) { if(negotiated!=0) throw new IllegalStateException("already negotiated"); tls = b; }
     
-    public void negotiate() throws IOException { negotiate(null); }
-    public synchronized void negotiate(State state) throws IOException {
+    public void negotiate() throws IOException { negotiate(null,null); }
+    public void negotiate(State state) throws IOException { negotiate(state,null); }
+    public void negotiate(byte[] cipherPrefs) throws IOException { negotiate(null,cipherPrefs); }
+    private void negotiate(State state, byte[] cipherPrefs) throws IOException {
         if(negotiated != 0) throw new IllegalStateException("already negotiated");
         
         handshakesBuffer = new ByteArrayOutputStream();
         
         try {
-            sendClientHello(state != null ? state.sessionID : null);
+            sendClientHello(state != null ? state.sessionID : null, cipherPrefs);
             flush();
             debug("sent ClientHello (" + (tls?"TLSv1.0":"SSLv3.0")+")");
             
@@ -183,7 +189,7 @@ public class SSL extends Socket {
     }
     
     private void negotiateNew() throws IOException {
-        X509Certificate[] certs = receiveServerCertificates();
+        X509.Certificate[] certs = receiveServerCertificates();
         debug("got Certificate");
         
         boolean gotCertificateRequest = false;
@@ -245,12 +251,14 @@ public class SSL extends Socket {
     public boolean isActive() { return !closed; }
     public boolean isNegotiated() { return (negotiated&3) == 3; }
     
-    private void sendClientHello(byte[] sessionID) throws IOException {
+    private void sendClientHello(byte[] sessionID, byte[] cipherPrefs) throws IOException {
         if(sessionID != null && sessionID.length > 256) throw new IllegalArgumentException("sessionID");
-        // 2 = version, 32 = randomvalue, 1 = sessionID size, 2 = cipher list size, 4 = the two ciphers,
+        if(cipherPrefs == null) cipherPrefs = DEFAULT_CIPHER_PREFS;
+        else if(cipherPrefs.length > 4) throw new IllegalArgumentException("too many cipherPrefs");
+        // 2 = version, 32 = randomvalue, 1 = sessionID size, 2 = cipher list size, 8 = the four ciphers,
         // 2 = compression length/no compression
         int p = 0;
-        byte[] buf = new byte[2+32+1+(sessionID == null ? 0 : sessionID.length)+2+2+4];
+        byte[] buf = new byte[2+32+1+(sessionID == null ? 0 : sessionID.length)+2+(cipherPrefs.length * 2)+2];
         buf[p++] = 0x03; // major version
         buf[p++] = tls ? (byte)0x01 : (byte)0x00;
         
@@ -267,16 +275,18 @@ public class SSL extends Socket {
         buf[p++] = sessionID != null ? (byte)sessionID.length : 0;
         if(sessionID != null && sessionID.length != 0) System.arraycopy(sessionID,0,buf,p,sessionID.length);
         p += sessionID != null ? sessionID.length : 0;
-        buf[p++] = 0x00; // 4 bytes of ciphers
-        buf[p++] = 0x04;
-        buf[p++] = 0x00; // SSL_RSA_WITH_RC4_128_SHA
-        buf[p++] = 0x05;
-        buf[p++] = 0x00; // SSL_RSA_WITH_RC4_128_MD5
-        buf[p++] = 0x04; 
-        
-        buf[p++] = 0x01;
-        buf[p++] = 0x00;
-                
+        
+        buf[p++] = 0x00; // 8 bytes of ciphers
+        buf[p++] = (byte)(cipherPrefs.length * 2);
+        
+        for(int i=0;i<cipherPrefs.length;i++) {
+            buf[p++] = 0x00;
+            buf[p++] = cipherPrefs[i];
+        }
+        
+        buf[p++] = 0x01; // compression length
+        buf[p++] = 0x00; // no compression
+        
         sendHandshake((byte)1,buf);
         flush();
     }
@@ -300,16 +310,19 @@ public class SSL extends Socket {
         p += sessionID.length;
         int cipher = ((buf[p]&0xff)<<8) | (buf[p+1]&0xff);
         p += 2;
-        switch(cipher) {
-            case 0x0004: sha = false; debug("Using SSL_RSA_WITH_RC4_128_MD5"); break;
-            case 0x0005: sha = true;  debug("Using SSL_RSA_WITH_RC4_128_SHA"); break;
+        if((cipher>>>8)!=0) throw new Exn("Unsupported cipher " + cipher);
+        switch(cipher&0xff) {
+            case SSL_RSA_WITH_RC4_128_MD5:     sha = false; aes=0;    debug("Using SSL_RSA_WITH_RC4_128_MD5"); break;
+            case SSL_RSA_WITH_RC4_128_SHA:     sha = true;  aes=0;    debug("Using SSL_RSA_WITH_RC4_128_SHA"); break;
+            case TLS_RSA_WITH_AES_128_CBC_SHA: sha = true; aes = 128; debug("Using TLS_RSA_WITH_AES_128_CBC_SHA"); break;
+            case TLS_RSA_WITH_AES_256_CBC_SHA: sha = true; aes = 256; debug("Using TLS_RSA_WITH_AES_256_CBC_SHA"); break;
             default: throw new Exn("Unsupported cipher " + cipher);
         }
         mac = new byte[sha ? 20 : 16];
         if(buf[p++] != 0x0) throw new Exn("unsupported compression " + buf[p-1]);
     }
     
-    private X509Certificate[] receiveServerCertificates() throws IOException {
+    private X509.Certificate[] receiveServerCertificates() throws IOException {
         byte[] buf = readHandshake();
         if(buf[0] != 11) throw new Exn("expected a Certificate message");
         if((((buf[4]&0xff)<<16)|((buf[5]&0xff)<<8)|((buf[6]&0xff)<<0)) != buf.length-7) throw new Exn("size mismatch in Certificate message");
@@ -318,21 +331,21 @@ public class SSL extends Socket {
         
         for(int i=p;i<buf.length-3;i+=((buf[p+0]&0xff)<<16)|((buf[p+1]&0xff)<<8)|((buf[p+2]&0xff)<<0)) count++;
         if(count == 0) throw new Exn("server didn't provide any certificates");
-        X509Certificate[] certs = new X509Certificate[count];
+        X509.Certificate[] certs = new X509.Certificate[count];
         count = 0;
         while(p < buf.length) {
             int len = ((buf[p+0]&0xff)<<16)|((buf[p+1]&0xff)<<8)|((buf[p+2]&0xff)<<0);
             p += 3;
             if(p + len > buf.length) throw new Exn("Certificate message cut short");
-            certs[count++] = new X509Certificate(new ByteArrayInputStream(buf,p,len));
+            certs[count++] = new X509.Certificate(new ByteArrayInputStream(buf,p,len));
             p += len;
         }
         return certs;
     }
     
-    private void sendClientKeyExchange(X509Certificate serverCert) throws IOException {
+    private void sendClientKeyExchange(X509.Certificate serverCert) throws IOException {
         byte[] encryptedPreMasterSecret;
-        RSAPublicKey pks = serverCert.getRSAPublicKey();
+        RSA.PublicKey pks = serverCert.getRSAPublicKey();
         PKCS1 pkcs1 = new PKCS1(new RSA(pks.modulus,pks.exponent,false),random);
         encryptedPreMasterSecret = pkcs1.encode(preMasterSecret);
         byte[] buf;
@@ -375,10 +388,11 @@ public class SSL extends Socket {
     
     public void initCrypto() {
         byte[] keyMaterial;
+        byte[] ivBlock;
         
         if(tls) {
             keyMaterial = tlsPRF(
-                    (mac.length + 16 + 0)*2, // MAC len + key len + iv len
+                    (mac.length + (aes==256 ? 32 : 16) + (aes==0 ? 0 : 16))*2, // MAC len + key len + iv len
                     masterSecret,
                     getBytes("key expansion"),
                     concat(serverRandom,clientRandom)
@@ -392,28 +406,38 @@ public class SSL extends Socket {
                         md5(new byte[][] { masterSecret,
                                 sha1(new byte[][] { crap, masterSecret, serverRandom, clientRandom }) }) });
             }            
+            if(aes != 0) throw new Error("should never happen");
         }
 
         byte[] clientWriteMACSecret = new byte[mac.length];
         byte[] serverWriteMACSecret = new byte[mac.length];
-        byte[] clientWriteKey = new byte[16];
-        byte[] serverWriteKey = new byte[16];
+        byte[] clientWriteKey = new byte[aes==256 ? 32 : 16];
+        byte[] serverWriteKey = new byte[aes==256 ? 32 : 16];
+        byte[] clientWriteIV = aes == 0 ? null : new byte[16];
+        byte[] serverWriteIV = aes == 0 ? null : new byte[16];
         
         int p = 0;
         System.arraycopy(keyMaterial, p, clientWriteMACSecret, 0, mac.length); p += mac.length;
         System.arraycopy(keyMaterial, p, serverWriteMACSecret, 0, mac.length); p += mac.length;
-        System.arraycopy(keyMaterial, p, clientWriteKey, 0, 16); p += 16; 
-        System.arraycopy(keyMaterial, p, serverWriteKey, 0, 16); p += 16;
+        System.arraycopy(keyMaterial, p, clientWriteKey, 0, clientWriteKey.length); p += clientWriteKey.length; 
+        System.arraycopy(keyMaterial, p, serverWriteKey, 0, serverWriteKey.length); p += serverWriteKey.length;
+        if(clientWriteIV != null) System.arraycopy(keyMaterial,p,clientWriteIV,0,16); p += 16;
+        if(serverWriteIV != null) System.arraycopy(keyMaterial,p,serverWriteIV,0,16); p += 16;
         
         Digest inner;
         
-        writeRC4 = new RC4(clientWriteKey);
+        writeCipher = aes==0 ? (Cipher)new RC4(clientWriteKey) : (Cipher)new CBC(new AES(clientWriteKey,false),16,clientWriteIV,false);
         inner = sha ? (Digest)new SHA1() : (Digest)new MD5();
         clientWriteMACDigest = tls ? (Digest) new HMAC(inner,clientWriteMACSecret) : (Digest)new SSLv3HMAC(inner,clientWriteMACSecret);
         
-        readRC4 = new RC4(serverWriteKey);
+        readCipher = aes==0 ? (Cipher)new RC4(serverWriteKey) : (Cipher)new CBC(new AES(serverWriteKey,true),16,serverWriteIV,true);
         inner = sha ? (Digest)new SHA1() : (Digest)new MD5();
         serverWriteMACDigest = tls ? (Digest)new HMAC(inner,serverWriteMACSecret) : (Digest)new SSLv3HMAC(inner,serverWriteMACSecret);
+        
+        if(aes != 0) {
+            blockSize = 16; // aes block size == 16
+            padBuf = new byte[mac.length+blockSize]; // worse case, 15 bytes of data, mac, and padding byte
+        }
     }
     
     private void sendFinished() throws IOException {
@@ -492,10 +516,27 @@ public class SSL extends Socket {
             if((negotiated & 1) != 0) {
                 computeMAC(proto,payload,off,len,clientWriteMACDigest,clientSequenceNumber);
                 // FEATURE: Encode in place
-                writeRC4.process(payload,off,sendRecordBuf,0,len);
-                writeRC4.process(mac,0,sendRecordBuf,len,macLength);
-                rawOS.writeShort(len + macLength);
-                rawOS.write(sendRecordBuf,0, len +macLength);
+                if(blockSize != 0) {
+                    int firstShot = len & ~(blockSize-1);
+                    if(firstShot != 0) writeCipher.process(payload,off,sendRecordBuf,0,firstShot);
+                    int _off = off + firstShot; // offset in sendRecordBuf
+                    int _len = len - firstShot; // length of data in padBuf
+                    if(_len != 0) System.arraycopy(payload,_off,padBuf,0,_len);
+                    System.arraycopy(mac,0,padBuf,_len,macLength);
+                    _len += macLength;
+                    int extra = blockSize - (_len&~blockSize);
+                    if(extra == 0) extra = 16;
+                    for(int i=0;i<extra;i++) padBuf[_len+i] = (byte)(extra-1);
+                    _len += extra;
+                    writeCipher.process(padBuf,0,sendRecordBuf,_off,_len);
+                    rawOS.writeShort(firstShot+_len);
+                    rawOS.write(sendRecordBuf,0,firstShot+_len);
+                } else {
+                    writeCipher.process(payload,off,sendRecordBuf,0,len);
+                    writeCipher.process(mac,0,sendRecordBuf,len,macLength);
+                    rawOS.writeShort(len + macLength);
+                    rawOS.write(sendRecordBuf,0, len +macLength);
+                }
                 clientSequenceNumber++;
             } else {
                 rawOS.writeShort(len);
@@ -505,7 +546,7 @@ public class SSL extends Socket {
             off += len;
         }
     }
-    
+
     private byte[] readHandshake() throws IOException {
         if(handshakeDataLength == 0) {
             handshakeDataStart = 0;
@@ -554,9 +595,19 @@ public class SSL extends Socket {
             }
             
             if((negotiated & 2) != 0) {
-                if(len < macLength) throw new Exn("packet size < macLength");
                 // FEATURE: Decode in place
-                readRC4.process(readRecordScratch,0,readRecordBuf,0,len);
+                if(blockSize != 0 && (len % blockSize) != 0) throw new Exn("input not a multiple of the cipher's block size");
+                readCipher.process(readRecordScratch,0,readRecordBuf,0,len);
+                
+                if(blockSize != 0) {
+                    if(!tls) throw new Error("should never happen");
+                    int padding = readRecordBuf[len-1]&0xff;
+                    if(padding >= blockSize) throw new Exn("invalid padding length: " + padding);
+                    for(int i=0;i<padding;i++) if(readRecordBuf[len-padding-1] != padding) throw new Exn("bad padding");
+                    len -= (padding+1);
+                }
+                if(len < macLength) throw new Exn("packet size < macLength");
+                
                 computeMAC(proto,readRecordBuf,0,len-macLength,serverWriteMACDigest,serverSequenceNumber);
                 for(int i=0;i<macLength;i++)
                     if(mac[i] != readRecordBuf[len-macLength+i])
@@ -684,7 +735,7 @@ public class SSL extends Socket {
      A(i) = HMAC_hash(secret, A(i-1))
      */           
     private byte[] tlsPRF(int size,byte[] secret, byte[] label, byte[] seed) {
-        if(size > 112) throw new IllegalArgumentException("size > 112");
+        if(size > 140) throw new IllegalArgumentException("size > 140: " + size);
         seed = concat(label,seed);
         
         int half_length = (secret.length + 1) / 2;
@@ -696,8 +747,8 @@ public class SSL extends Socket {
         Digest hmac_md5 = new HMAC(new MD5(),s1);
         Digest hmac_sha = new HMAC(new SHA1(),s2);
         
-        byte[] md5out = new byte[112];
-        byte[] shaout = new byte[120];
+        byte[] md5out = new byte[144];
+        byte[] shaout = new byte[140];
         byte[] digest = new byte[20];
         int n;
         
@@ -733,7 +784,37 @@ public class SSL extends Socket {
         return ret;
     }
 
-    public static class SSLv3HMAC implements Digest {
+    public static class CBC implements Cipher {
+        private final Cipher c;
+        private final byte[] iv;
+        private final int blockSize;
+        private final boolean reverse;
+        public CBC(Cipher c, int blockSize, byte[] iv, boolean reverse) {
+            this.c = c;
+            this.blockSize = blockSize;
+            this.iv = iv;
+            this.reverse = reverse;
+            if(iv.length != blockSize) throw new IllegalArgumentException("iv.length != blockSize");
+        }
+        
+        public void process(byte[] in, int inp, byte[] out, int outp, int len) {
+            if((len % blockSize) != 0) throw new IllegalArgumentException("buffer must be a multiple of block size");
+            while(len != 0) {
+                if(!reverse) {
+                    for(int i=0;i<blockSize;i++) iv[i] ^= in[inp+i]; // mangle the cleartext
+                    c.process(iv,0,out,outp,blockSize); // process it
+                    System.arraycopy(out,outp,iv,0,blockSize); // copy the block to iv
+                } else {
+                    c.process(in,inp,out,outp,blockSize); // process the ciphertext
+                    for(int i=0;i<blockSize;i++) out[outp+i] ^= iv[i]; // mangle the cleartext
+                    System.arraycopy(in,inp,iv,0,blockSize); // copy the ciphertext to iv
+                }
+                inp+=16; outp+=16; len-=16;
+            }
+        }
+    }
+
+    public static class SSLv3HMAC extends Digest {
         private final Digest h;
         private final byte[] digest;
         private final byte[] key;
@@ -766,6 +847,9 @@ public class SSL extends Socket {
             h.doFinal(out,off);
             reset();
         }
+        protected void processWord(byte[] in, int inOff) {}
+        protected void processLength(long bitLength) {}
+        protected void processBlock() {}
     }
     
     //
@@ -904,10 +988,10 @@ public class SSL extends Socket {
     }
     
     public static boolean debugOn = false;
-    private static void debug(Object o) { if(debugOn) System.err.println("[BriSSL-Debug] " + o.toString()); }
-    private static void log(Object o) { System.err.println("[BriSSL] " + o.toString()); }
+    private static void debug(Object o) { if(debugOn) System.err.println("[IbexSSL-Debug] " + o.toString()); }
+    private static void log(Object o) { System.err.println("[IbexSSL] " + o.toString()); }
             
-    private static void verifyCerts(X509Certificate[] certs) throws DER.Exception, Exn {
+    public static void verifyCerts(X509.Certificate[] certs) throws DER.Exception, Exn {
         try {
             verifyCerts_(certs);
         } catch(RuntimeException e) {
@@ -916,20 +1000,17 @@ public class SSL extends Socket {
         }
     }
     
-    private static void verifyCerts_(X509Certificate[] certs) throws DER.Exception, Exn {
-        boolean ignoreLast = false;
+    private static void verifyCerts_(X509.Certificate[] certs) throws DER.Exception, Exn {
+        int last = certs.length-1;
         for(int i=0;i<certs.length;i++) {
             debug("Cert " + i + ": " + certs[i].subject + " ok");
             if(!certs[i].isValid())
                 throw new Exn("Certificate " + i + " in certificate chain is not valid (" + certs[i].startDate + " - " + certs[i].endDate + ")");
             if(i != 0) {
-                X509Certificate.BC bc = certs[i].basicContraints;
+                X509.Certificate.BC bc = certs[i].basicContraints;
                 if(bc == null) {
-                    if(i == certs.length - 1) {
-                        ignoreLast = true;
-                        break;
-                    }
-                    throw new Exn("CA-cert lacks Basic Constraints");
+                    last = i;
+                    break;
                 } else {
                     if(!bc.isCA) throw new Exn("non-CA certificate used for signing");
                     if(bc.pathLenConstraint != null && bc.pathLenConstraint.longValue() < i-1) throw new Exn("CA cert can't be used this deep");
@@ -943,18 +1024,18 @@ public class SSL extends Socket {
             }
         }
         
-        X509Certificate cert = certs[ignoreLast ? certs.length - 2 : certs.length-1];
+        X509.Certificate cert = certs[last];
         
-        RSAPublicKey pks = (RSAPublicKey) caKeys.get(cert.issuer);
+        RSA.PublicKey pks = (RSA.PublicKey) caKeys.get(cert.issuer);
         if(pks == null) throw new Exn("Certificate is signed by an unknown CA (" + cert.issuer + ")");
         if(!cert.isSignedWith(pks)) throw new Exn("Certificate is not signed by its CA");
         log("" + cert.subject + " is signed by " + cert.issuer);
     }
     
     public static void addCACert(byte[] b) throws IOException { addCACert(new ByteArrayInputStream(b)); }
-    public static void addCACert(InputStream is) throws IOException { addCACert(new X509Certificate(is)); }
-    public static void addCACert(X509Certificate cert) throws DER.Exception { addCAKey(cert.subject,cert.getRSAPublicKey()); }
-    public static void addCAKey(X509Name subject, RSAPublicKey pks)  {
+    public static void addCACert(InputStream is) throws IOException { addCACert(new X509.Certificate(is)); }
+    public static void addCACert(X509.Certificate cert) throws DER.Exception { addCAKey(cert.subject,cert.getRSAPublicKey()); }
+    public static void addCAKey(X509.Name subject, RSA.PublicKey pks)  {
         synchronized(caKeys) {
             if(caKeys.get(subject) != null)
                 throw new IllegalArgumentException(subject.toString() + " already exists!");
@@ -986,8 +1067,8 @@ public class SSL extends Socket {
                 Vector seq = (Vector) new DER.InputStream(is).readObject();
                 for(Enumeration e = seq.elements(); e.hasMoreElements();) {
                     Vector seq2 = (Vector) e.nextElement();
-                    X509Name subject = new X509Name(seq2.elementAt(0));
-                    RSAPublicKey pks = new RSAPublicKey(seq2.elementAt(1));
+                    X509.Name subject = new X509.Name(seq2.elementAt(0));
+                    RSA.PublicKey pks = new RSA.PublicKey(seq2.elementAt(1));
                     addCAKey(subject,pks);
                 }
                 return seq.size();
@@ -1011,7 +1092,7 @@ public class SSL extends Socket {
     }
     
     public interface VerifyCallback {
-        public boolean checkCerts(X509Certificate[] certs, String hostname, Exn exn);
+        public boolean checkCerts(X509.Certificate[] certs, String hostname, Exn exn);
     }
     
     // Helper methods