aes/cbc support in ssl.java
[org.ibex.crypto.git] / src / org / ibex / net / SSL.java
1 /*
2  * org.ibex.net.SSL - By Brian Alliet
3  * Copyright (C) 2004 Brian Alliet
4  * 
5  * Based on TinySSL by Adam Megacz
6  * Copyright (C) 2003 Adam Megacz <adam@xwt.org> all rights reserved.
7  * 
8  * You may modify, copy, and redistribute this code under the terms of
9  * the GNU Lesser General Public License version 2.1, with the exception
10  * of the portion of clause 6a after the semicolon (aka the "obnoxious
11  * relink clause")
12  */
13
14 package org.ibex.net;
15
16 import org.ibex.crypto.*;
17 import java.security.SecureRandom;
18
19 import java.net.Socket;
20 import java.net.SocketException;
21
22 import java.io.*;
23 import java.util.Enumeration;
24 import java.util.Hashtable;
25 import java.util.Random;
26 import java.util.Vector;
27
28 // FEATURE: Server socket
29
30 public class SSL extends Socket {
31     public static final byte TLS_RSA_WITH_AES_256_CBC_SHA = 0x35;
32     public static final byte TLS_RSA_WITH_AES_128_CBC_SHA = 0x2f;
33     public static final byte SSL_RSA_WITH_RC4_128_SHA = 0x05;
34     public static final byte SSL_RSA_WITH_RC4_128_MD5 = 0x04;
35     
36     private static final byte[] DEFAULT_CIPHER_PREFS = new byte[]{
37         TLS_RSA_WITH_AES_256_CBC_SHA,TLS_RSA_WITH_AES_128_CBC_SHA,
38         SSL_RSA_WITH_RC4_128_SHA,SSL_RSA_WITH_RC4_128_MD5
39     };
40     
41     private String hostname;
42     
43     private int negotiated;
44     
45     private boolean tls = true;
46     private boolean sha;
47     private int aes;
48     
49     private final DataInputStream rawIS;
50     private final DataOutputStream rawOS;
51     
52     private final InputStream sslIS;
53     private final OutputStream sslOS;
54     
55     private byte[] sessionID;
56
57     private Digest clientWriteMACDigest;        
58     private Digest serverWriteMACDigest;        
59     private byte[] masterSecret;
60     
61     private Cipher writeCipher;
62     private Cipher readCipher;
63     
64     private long serverSequenceNumber;
65     private long clientSequenceNumber;
66     
67     private int warnings;
68     private boolean closed;
69     
70     // These are only used during negotiation
71     private byte[] serverRandom;
72     private byte[] clientRandom;
73     private byte[] preMasterSecret;
74     
75     // Buffers
76     private byte[] mac;
77     
78     private byte[] pending = new byte[16384];
79     private int pendingStart;
80     private int pendingLength;
81
82     private byte[] sendRecordBuf = new byte[16384];
83     
84     private int handshakeDataStart;
85     private int handshakeDataLength;
86     private byte[] readRecordBuf = new byte[16384+20];   // 20 == sizeof(sha1 hash)
87     private byte[] readRecordScratch = new byte[16384+20];
88     
89     // These are only uses for CBCs
90     private int blockSize; // block size (0 for stream ciphers)
91     private byte[] padBuf;
92     private byte[] prevBlock;
93     
94     private ByteArrayOutputStream handshakesBuffer;
95     
96     // End Buffers
97     
98     // Static variables
99     private final static byte[] pad1 = new byte[48];
100     private final static byte[] pad2 = new byte[48];
101     private final static byte[] pad1_sha = new byte[40];
102     private final static byte[] pad2_sha = new byte[40];
103     
104     static {
105         for(int i=0; i<pad1.length; i++) pad1[i] = (byte)0x36;
106         for(int i=0; i<pad2.length; i++) pad2[i] = (byte)0x5C;
107         for(int i=0; i<pad1_sha.length; i++) pad1_sha[i] = (byte)0x36;
108         for(int i=0; i<pad2_sha.length; i++) pad2_sha[i] = (byte)0x5C;
109     }
110     
111     private final static Hashtable caKeys = new Hashtable();
112     private static VerifyCallback verifyCallback;
113     
114     //
115     // Constructors
116     //
117     public SSL(String host) throws IOException { this(host,443); }
118     public SSL(String host, int port) throws IOException { this(host,port,true); }
119     public SSL(String host, int port, boolean negotiate) throws IOException { this(host,port,negotiate,null); }
120     public SSL(String host, int port, State state) throws IOException { this(host,port,true,state); }
121     public SSL(String host, int port, boolean negotiate, State state) throws IOException {
122         super(host,port);
123         hostname = host;
124         rawIS = new DataInputStream(new BufferedInputStream(super.getInputStream()));
125         rawOS = new DataOutputStream(new BufferedOutputStream(super.getOutputStream()));
126         sslIS = new SSLInputStream();
127         sslOS = new SSLOutputStream();
128         if(negotiate) negotiate(state);
129     }
130
131     public synchronized void setTLS(boolean b) { if(negotiated!=0) throw new IllegalStateException("already negotiated"); tls = b; }
132     
133     public void negotiate() throws IOException { negotiate(null,null); }
134     public void negotiate(State state) throws IOException { negotiate(state,null); }
135     public void negotiate(byte[] cipherPrefs) throws IOException { negotiate(null,cipherPrefs); }
136     private void negotiate(State state, byte[] cipherPrefs) throws IOException {
137         if(negotiated != 0) throw new IllegalStateException("already negotiated");
138         
139         handshakesBuffer = new ByteArrayOutputStream();
140         
141         try {
142             sendClientHello(state != null ? state.sessionID : null, cipherPrefs);
143             flush();
144             debug("sent ClientHello (" + (tls?"TLSv1.0":"SSLv3.0")+")");
145             
146             receiveServerHello();
147             debug("got ServerHello (" + (tls?"TLSv1.0":"SSLv3.0")+")");
148             
149             boolean resume = 
150                 state != null && sessionID.length == state.sessionID.length && 
151                 eq(state.sessionID,0,sessionID,0,sessionID.length);
152             
153             if(resume) 
154                 negotiateResume(state);
155             else
156                 negotiateNew();
157             
158             // we're done with these now
159             clientRandom = serverRandom = preMasterSecret = null;
160             handshakesBuffer = null;
161             
162             log("Negotiation with " + hostname + " complete (" + (tls?"TLSv1.0":"SSLv3.0")+")");
163         } finally {
164             if((negotiated & 3) != 3) {
165                 negotiated = 0;
166                 try { super.close(); } catch(IOException e) { /* ignore */ }
167                 closed = true;
168             }
169         }
170     }
171     
172     private void negotiateResume(State state) throws IOException {
173         masterSecret = state.masterSecret;
174         
175         initCrypto();
176         log("initializec crypto");
177         
178         receiveChangeCipherSpec();
179         debug("Received ChangeCipherSpec");
180         negotiated |= 2;
181         receieveFinished();
182         debug("Received Finished");
183         
184         sendChangeCipherSpec();
185         debug("Sent ChangeCipherSpec");
186         negotiated |= 1;
187         sendFinished();
188         debug("Sent Finished");
189     }
190     
191     private void negotiateNew() throws IOException {
192         X509.Certificate[] certs = receiveServerCertificates();
193         debug("got Certificate");
194         
195         boolean gotCertificateRequest = false;
196         OUTER: for(;;) {
197             byte[] buf = readHandshake();
198             switch(buf[0]) {
199             case 14: // ServerHelloDone
200                 if(buf.length != 4) throw new Exn("ServerHelloDone contained trailing garbage");
201                 debug("got ServerHelloDone");
202                 break OUTER;
203             case 13: // CertificateRequest
204                 debug("Got a CertificateRequest message but we don't suport client certificates");
205                 gotCertificateRequest = true;
206                 break;
207             default:
208                 throw new Exn("unknown handshake type " + buf[0]);
209             }
210         }
211         
212         if(gotCertificateRequest)
213             sendHandshake((byte)11,new byte[3]); // send empty cert list
214         
215         try {
216             if(!hostname.equalsIgnoreCase(certs[0].getCN()))
217                 throw new Exn("Certificate is for " + certs[0].getCN() + " not " + hostname);
218             verifyCerts(certs);
219         } catch(Exn e) {
220             if(verifyCallback == null) throw e;
221             synchronized(SSL.class) {
222                 if(!verifyCallback.checkCerts(certs,hostname,e)) throw e;
223             }
224         }
225         
226         computeMasterSecret();
227         
228         sendClientKeyExchange(certs[0]);
229         debug("sent ClientKeyExchange");
230         
231         initCrypto();
232         
233         sendChangeCipherSpec();
234         debug("sent ChangeCipherSpec");
235         negotiated |= 1;
236         sendFinished();
237         debug("sent Finished");
238         flush();
239         
240         receiveChangeCipherSpec();
241         debug("got ChangeCipherSpec");
242         negotiated |= 2;
243         receieveFinished();
244         debug("got Finished");
245     }
246     
247     public State getSessionState() {
248         if((negotiated&3)!=3 || !closed || warnings != 0) return null;
249         return new State(sessionID,masterSecret);
250     }
251     public boolean isActive() { return !closed; }
252     public boolean isNegotiated() { return (negotiated&3) == 3; }
253     
254     private void sendClientHello(byte[] sessionID, byte[] cipherPrefs) throws IOException {
255         if(sessionID != null && sessionID.length > 256) throw new IllegalArgumentException("sessionID");
256         if(cipherPrefs == null) cipherPrefs = DEFAULT_CIPHER_PREFS;
257         else if(cipherPrefs.length > 4) throw new IllegalArgumentException("too many cipherPrefs");
258         // 2 = version, 32 = randomvalue, 1 = sessionID size, 2 = cipher list size, 8 = the four ciphers,
259         // 2 = compression length/no compression
260         int p = 0;
261         byte[] buf = new byte[2+32+1+(sessionID == null ? 0 : sessionID.length)+2+(cipherPrefs.length * 2)+2];
262         buf[p++] = 0x03; // major version
263         buf[p++] = tls ? (byte)0x01 : (byte)0x00;
264         
265         clientRandom = new byte[32];
266         int now = (int)(System.currentTimeMillis() / 1000L);
267         new Random().nextBytes(clientRandom);
268         clientRandom[0] = (byte)(now>>>24);
269         clientRandom[1] = (byte)(now>>>16);
270         clientRandom[2] = (byte)(now>>>8);
271         clientRandom[3] = (byte)(now>>>0);
272         System.arraycopy(clientRandom,0,buf,p,32);
273         p += 32;
274         
275         buf[p++] = sessionID != null ? (byte)sessionID.length : 0;
276         if(sessionID != null && sessionID.length != 0) System.arraycopy(sessionID,0,buf,p,sessionID.length);
277         p += sessionID != null ? sessionID.length : 0;
278         
279         buf[p++] = 0x00; // 8 bytes of ciphers
280         buf[p++] = (byte)(cipherPrefs.length * 2);
281         
282         for(int i=0;i<cipherPrefs.length;i++) {
283             buf[p++] = 0x00;
284             buf[p++] = cipherPrefs[i];
285         }
286         
287         buf[p++] = 0x01; // compression length
288         buf[p++] = 0x00; // no compression
289         
290         sendHandshake((byte)1,buf);
291         flush();
292     }
293     
294     private void receiveServerHello() throws IOException {
295         // ServerHello
296         byte[] buf = readHandshake();
297         if(buf[0] != 2) throw new Exn("expected a ServerHello message");
298         
299         if(buf.length < 6 + 32 + 1) throw new Exn("ServerHello too small");
300         if(buf.length < 6 + 32 + 1 + buf[6+32] + 3) throw new Exn("ServerHello too small " + buf.length+" "+buf[6+32]); 
301         
302         if(buf[4] != 0x03 || !(buf[5]==0x00 || buf[5]==0x01)) throw new Exn("server wants to use version " + buf[4] + "." + buf[5]);
303         tls = buf[5] == 0x01;
304         int p = 6;
305         serverRandom = new byte[32];
306         System.arraycopy(buf,p,serverRandom,0,32);
307         p += 32;
308         sessionID = new byte[buf[p++]&0xff];
309         if(sessionID.length != 0) System.arraycopy(buf,p,sessionID,0,sessionID.length);
310         p += sessionID.length;
311         int cipher = ((buf[p]&0xff)<<8) | (buf[p+1]&0xff);
312         p += 2;
313         if((cipher>>>8)!=0) throw new Exn("Unsupported cipher " + cipher);
314         switch(cipher&0xff) {
315             case SSL_RSA_WITH_RC4_128_MD5:     sha = false; aes=0;    debug("Using SSL_RSA_WITH_RC4_128_MD5"); break;
316             case SSL_RSA_WITH_RC4_128_SHA:     sha = true;  aes=0;    debug("Using SSL_RSA_WITH_RC4_128_SHA"); break;
317             case TLS_RSA_WITH_AES_128_CBC_SHA: sha = true; aes = 128; debug("Using TLS_RSA_WITH_AES_128_CBC_SHA"); break;
318             case TLS_RSA_WITH_AES_256_CBC_SHA: sha = true; aes = 256; debug("Using TLS_RSA_WITH_AES_256_CBC_SHA"); break;
319             default: throw new Exn("Unsupported cipher " + cipher);
320         }
321         mac = new byte[sha ? 20 : 16];
322         if(buf[p++] != 0x0) throw new Exn("unsupported compression " + buf[p-1]);
323     }
324     
325     private X509.Certificate[] receiveServerCertificates() throws IOException {
326         byte[] buf = readHandshake();
327         if(buf[0] != 11) throw new Exn("expected a Certificate message");
328         if((((buf[4]&0xff)<<16)|((buf[5]&0xff)<<8)|((buf[6]&0xff)<<0)) != buf.length-7) throw new Exn("size mismatch in Certificate message");
329         int p = 7;
330         int count = 0;
331         
332         for(int i=p;i<buf.length-3;i+=((buf[p+0]&0xff)<<16)|((buf[p+1]&0xff)<<8)|((buf[p+2]&0xff)<<0)) count++;
333         if(count == 0) throw new Exn("server didn't provide any certificates");
334         X509.Certificate[] certs = new X509.Certificate[count];
335         count = 0;
336         while(p < buf.length) {
337             int len = ((buf[p+0]&0xff)<<16)|((buf[p+1]&0xff)<<8)|((buf[p+2]&0xff)<<0);
338             p += 3;
339             if(p + len > buf.length) throw new Exn("Certificate message cut short");
340             certs[count++] = new X509.Certificate(new ByteArrayInputStream(buf,p,len));
341             p += len;
342         }
343         return certs;
344     }
345     
346     private void sendClientKeyExchange(X509.Certificate serverCert) throws IOException {
347         byte[] encryptedPreMasterSecret;
348         RSA.PublicKey pks = serverCert.getRSAPublicKey();
349         PKCS1 pkcs1 = new PKCS1(new RSA(pks.modulus,pks.exponent,false),random);
350         encryptedPreMasterSecret = pkcs1.encode(preMasterSecret);
351         byte[] buf;
352         if(tls) {
353             buf = new byte[encryptedPreMasterSecret.length+2];
354             buf[0] = (byte) (encryptedPreMasterSecret.length>>>8);
355             buf[1] = (byte) (encryptedPreMasterSecret.length>>>0);
356             System.arraycopy(encryptedPreMasterSecret,0,buf,2,encryptedPreMasterSecret.length);
357         } else {
358             // ugh... netscape didn't send the length bytes and now every SSLv3 implementation
359             // must implement this bug
360             buf = encryptedPreMasterSecret;
361         }
362         sendHandshake((byte)16,buf);
363     }
364     
365     private void sendChangeCipherSpec() throws IOException {
366         sendRecord((byte)20,new byte[] { 0x01 });
367     }
368     
369     private void computeMasterSecret() {
370         preMasterSecret = new byte[48];
371         preMasterSecret[0] = 0x03; // version_high
372         preMasterSecret[1] = tls ? (byte) 0x01 : (byte) 0x00; // version_low
373         randomBytes(preMasterSecret,2,46);
374         
375         if(tls) {
376             masterSecret = tlsPRF(48,preMasterSecret,getBytes("master secret"),concat(clientRandom,serverRandom));
377         } else {
378             masterSecret = concat(new byte[][] {
379                     md5(new byte[][] { preMasterSecret,
380                             sha1(new byte[][] { new byte[] { 0x41 }, preMasterSecret, clientRandom, serverRandom })}),
381                             md5(new byte[][] { preMasterSecret,
382                                     sha1(new byte[][] { new byte[] { 0x42, 0x42 }, preMasterSecret, clientRandom, serverRandom })}),
383                                     md5(new byte[][] { preMasterSecret,
384                                             sha1(new byte[][] { new byte[] { 0x43, 0x43, 0x43 }, preMasterSecret, clientRandom, serverRandom })})
385             } );    
386         }
387     }
388     
389     public void initCrypto() {
390         byte[] keyMaterial;
391         byte[] ivBlock;
392         
393         if(tls) {
394             keyMaterial = tlsPRF(
395                     (mac.length + (aes==256 ? 32 : 16) + (aes==0 ? 0 : 16))*2, // MAC len + key len + iv len
396                     masterSecret,
397                     getBytes("key expansion"),
398                     concat(serverRandom,clientRandom)
399             );
400         } else {
401             keyMaterial = new byte[] { };
402             for(int i=0; keyMaterial.length < 72; i++) {
403                 byte[] crap = new byte[i + 1];
404                 for(int j=0; j<crap.length; j++) crap[j] = (byte)(((byte)0x41) + ((byte)i));
405                 keyMaterial = concat(new byte[][] { keyMaterial,
406                         md5(new byte[][] { masterSecret,
407                                 sha1(new byte[][] { crap, masterSecret, serverRandom, clientRandom }) }) });
408             }            
409             if(aes != 0) throw new Error("should never happen");
410         }
411
412         byte[] clientWriteMACSecret = new byte[mac.length];
413         byte[] serverWriteMACSecret = new byte[mac.length];
414         byte[] clientWriteKey = new byte[aes==256 ? 32 : 16];
415         byte[] serverWriteKey = new byte[aes==256 ? 32 : 16];
416         byte[] clientWriteIV = aes == 0 ? null : new byte[16];
417         byte[] serverWriteIV = aes == 0 ? null : new byte[16];
418         
419         int p = 0;
420         System.arraycopy(keyMaterial, p, clientWriteMACSecret, 0, mac.length); p += mac.length;
421         System.arraycopy(keyMaterial, p, serverWriteMACSecret, 0, mac.length); p += mac.length;
422         System.arraycopy(keyMaterial, p, clientWriteKey, 0, clientWriteKey.length); p += clientWriteKey.length; 
423         System.arraycopy(keyMaterial, p, serverWriteKey, 0, serverWriteKey.length); p += serverWriteKey.length;
424         if(clientWriteIV != null) System.arraycopy(keyMaterial,p,clientWriteIV,0,16); p += 16;
425         if(serverWriteIV != null) System.arraycopy(keyMaterial,p,serverWriteIV,0,16); p += 16;
426         
427         Digest inner;
428         
429         writeCipher = aes==0 ? (Cipher)new RC4(clientWriteKey) : (Cipher)new CBC(new AES(clientWriteKey,false),16,clientWriteIV,false);
430         inner = sha ? (Digest)new SHA1() : (Digest)new MD5();
431         clientWriteMACDigest = tls ? (Digest) new HMAC(inner,clientWriteMACSecret) : (Digest)new SSLv3HMAC(inner,clientWriteMACSecret);
432         
433         readCipher = aes==0 ? (Cipher)new RC4(serverWriteKey) : (Cipher)new CBC(new AES(serverWriteKey,true),16,serverWriteIV,true);
434         inner = sha ? (Digest)new SHA1() : (Digest)new MD5();
435         serverWriteMACDigest = tls ? (Digest)new HMAC(inner,serverWriteMACSecret) : (Digest)new SSLv3HMAC(inner,serverWriteMACSecret);
436         
437         if(aes != 0) {
438             blockSize = 16; // aes block size == 16
439             padBuf = new byte[mac.length+blockSize]; // worse case, 15 bytes of data, mac, and padding byte
440         }
441     }
442     
443     private void sendFinished() throws IOException {
444         byte[] handshakes = handshakesBuffer.toByteArray();
445         if(tls) {
446             sendHandshake((byte)20, tlsPRF(
447                     12,
448                     masterSecret,
449                     getBytes("client finished"),
450                     concat(md5(handshakes),sha1(handshakes))));
451             
452         } else {
453             sendHandshake((byte)20, concat(new byte[][] { 
454                     md5(new byte[][] { masterSecret, pad2, 
455                                        md5(new byte[][] { handshakes, new byte[] { (byte)0x43, (byte)0x4C, (byte)0x4E, (byte)0x54 },
456                                                           masterSecret, pad1 }) }),
457                     sha1(new byte[][] { masterSecret, pad2_sha,
458                                        sha1(new byte[][] { handshakes, new byte[] { (byte)0x43, (byte)0x4C, (byte)0x4E, (byte)0x54 },
459                                                           masterSecret, pad1_sha } ) })
460                 }));
461         }
462     }
463         
464     private void receiveChangeCipherSpec() throws IOException {    
465         int size = readRecord((byte)20);
466         if(size == -1) throw new Exn("got eof when expecting a ChangeCipherSpec message");
467         if(size != 1 || readRecordBuf[0] != 0x01) throw new Exn("Invalid ChangeCipherSpec message");
468     }
469     
470     private void receieveFinished() throws IOException {
471         byte[] handshakes = handshakesBuffer.toByteArray();
472         byte[] buf = readHandshake();
473         if(buf[0] != 20) throw new Exn("expected a Finished message");
474         byte[] expected;
475         
476         if(tls) {
477             if(buf.length != 4 + 12) throw new Exn("Finished message too short");
478             expected = tlsPRF(
479                     12,masterSecret,
480                     getBytes("server finished"),
481                     concat(md5(handshakes),sha1(handshakes)));
482         } else {
483             if(buf.length != 4 + 16 +20) throw new Exn("Finished message too short");
484             expected = concat(new byte[][] {
485                     md5(new byte[][] { masterSecret, pad2,
486                             md5(new byte[][] { handshakes, new byte[] { (byte)0x53, (byte)0x52, (byte)0x56, (byte)0x52 },
487                                     masterSecret, pad1 }) }),
488                                     sha1(new byte[][] { masterSecret, pad2_sha,
489                                             sha1(new byte[][] { handshakes, new byte[] { (byte)0x53, (byte)0x52, (byte)0x56, (byte)0x52 },
490                                                     masterSecret, pad1_sha } ) } ) } );
491         }
492         if(!eq(expected,0,buf,4,expected.length)) throw new Exn("server finished message mismatch");
493     }
494     
495     private void flush() throws IOException { rawOS.flush(); }
496
497     private void sendHandshake(byte type, byte[] payload) throws IOException {
498         if(payload.length > (1<<24)) throw new IllegalArgumentException("payload.length");
499         byte[] buf = new byte[4+payload.length];
500         buf[0] = type;
501         buf[1] = (byte)(payload.length>>>16);
502         buf[2] = (byte)(payload.length>>>8);
503         buf[3] = (byte)(payload.length>>>0);
504         System.arraycopy(payload,0,buf,4,payload.length);
505         handshakesBuffer.write(buf);
506         sendRecord((byte)22,buf);
507     }
508     
509     private void sendRecord(byte proto, byte[] buf) throws IOException { sendRecord(proto,buf,0,buf.length); }
510     private void sendRecord(byte proto, byte[] payload, int off, int totalLen) throws IOException {
511         int macLength = (negotiated & 1) != 0 ? mac.length : 0;
512         while(totalLen > 0) {
513             int len = min(totalLen,16384-macLength);
514             rawOS.writeByte(proto);
515             rawOS.writeShort(tls ? 0x0301 : 0x0300);
516             if((negotiated & 1) != 0) {
517                 computeMAC(proto,payload,off,len,clientWriteMACDigest,clientSequenceNumber);
518                 // FEATURE: Encode in place
519                 if(blockSize != 0) {
520                     int firstShot = len & ~(blockSize-1);
521                     if(firstShot != 0) writeCipher.process(payload,off,sendRecordBuf,0,firstShot);
522                     int _off = off + firstShot; // offset in sendRecordBuf
523                     int _len = len - firstShot; // length of data in padBuf
524                     if(_len != 0) System.arraycopy(payload,_off,padBuf,0,_len);
525                     System.arraycopy(mac,0,padBuf,_len,macLength);
526                     _len += macLength;
527                     int extra = blockSize - (_len&~blockSize);
528                     if(extra == 0) extra = 16;
529                     for(int i=0;i<extra;i++) padBuf[_len+i] = (byte)(extra-1);
530                     _len += extra;
531                     writeCipher.process(padBuf,0,sendRecordBuf,_off,_len);
532                     rawOS.writeShort(firstShot+_len);
533                     rawOS.write(sendRecordBuf,0,firstShot+_len);
534                 } else {
535                     writeCipher.process(payload,off,sendRecordBuf,0,len);
536                     writeCipher.process(mac,0,sendRecordBuf,len,macLength);
537                     rawOS.writeShort(len + macLength);
538                     rawOS.write(sendRecordBuf,0, len +macLength);
539                 }
540                 clientSequenceNumber++;
541             } else {
542                 rawOS.writeShort(len);
543                 rawOS.write(payload,off,len);
544             }
545             totalLen -= len;
546             off += len;
547         }
548     }
549
550     public static char[] hexDigit = new char[] { '0','1','2','3','4','5','6','7','8','9','a','b','c','d','e','f' };
551     public static String hex(byte[] b, int len) {
552         StringBuffer sb = new StringBuffer(len);
553         for(int i=0;i<len;i++) {
554             if((i%8)==0) sb.append("\n");
555             else if(i!=0) sb.append(":");
556             sb.append(hexDigit[(b[i]&0xf0)>>4]).append(hexDigit[b[i]&0xf]);
557         }
558         return sb.toString();
559     }
560         
561     private byte[] readHandshake() throws IOException {
562         if(handshakeDataLength == 0) {
563             handshakeDataStart = 0;
564             handshakeDataLength = readRecord((byte)22);
565             if(handshakeDataLength == -1) throw new Exn("got eof when expecting a handshake packet");
566         }
567         byte[] buf = readRecordBuf;
568         int len = ((buf[handshakeDataStart+1]&0xff)<<16)|((buf[handshakeDataStart+2]&0xff)<<8)|((buf[handshakeDataStart+3]&0xff)<<0);
569         // Handshake messages can theoretically span multiple records, but in practice this does not occur
570         if(len > handshakeDataLength) {
571             sendAlert(true,10); // 10 == unexpected message
572             throw new Exn("handshake message size too large " + len + " vs " + (handshakeDataLength-handshakeDataStart));
573         }
574         byte[] ret = new byte[4+len];
575         System.arraycopy(buf,handshakeDataStart,ret,0,ret.length);
576         handshakeDataLength -= ret.length;
577         handshakeDataStart += ret.length;
578         handshakesBuffer.write(ret);
579         return ret;
580     }
581     
582     private int readRecord(byte reqProto) throws IOException {
583         int macLength = (negotiated & 2) != 0 ? mac.length : 0;
584         for(;;) {
585             byte proto;
586             int version, len;
587             
588             try {
589                 proto = rawIS.readByte();
590             } catch(EOFException e) {
591                 // this may or may not be an error. it is up to the application protocol
592                 closed = true;
593                 super.close();
594                 throw new PrematureCloseExn();
595             }
596             try {
597                 version = rawIS.readShort();
598                 if(version != 0x0300 && version != 0x0301) throw new Exn("invalid version ");
599                 len = rawIS.readShort();
600                 if(len <= 0 || len > 16384+((negotiated&2)!=0 ? macLength : 0)) throw new Exn("invalid length " + len);
601                 rawIS.readFully((negotiated&2)!=0 ? readRecordScratch : readRecordBuf,0,len);
602             } catch(EOFException e) {
603                 // an EOF here is always an error (we don't pass the EOF back on to the app
604                 // because it isn't a "legitimate" eof)
605                 throw new Exn("Hit EOF too early");
606             }
607             
608             if((negotiated & 2) != 0) {
609                 // FEATURE: Decode in place
610                 if(blockSize != 0 && (len % blockSize) != 0) throw new Exn("input not a multiple of the cipher's block size");
611                 readCipher.process(readRecordScratch,0,readRecordBuf,0,len);
612                 
613                 if(blockSize != 0) {
614                     if(!tls) throw new Error("should never happen");
615                     int padding = readRecordBuf[len-1]&0xff;
616                     if(padding >= blockSize) throw new Exn("invalid padding length: " + padding);
617                     for(int i=0;i<padding;i++) if(readRecordBuf[len-padding-1] != padding) throw new Exn("bad padding");
618                     len -= (padding+1);
619                 }
620                 if(len < macLength) throw new Exn("packet size < macLength");
621                 
622                 computeMAC(proto,readRecordBuf,0,len-macLength,serverWriteMACDigest,serverSequenceNumber);
623                 for(int i=0;i<macLength;i++)
624                     if(mac[i] != readRecordBuf[len-macLength+i])
625                         throw new Exn("mac mismatch");
626                 len -= macLength;
627                 serverSequenceNumber++;
628             }
629             
630             if(proto == reqProto) return len;
631             
632             switch(proto) {
633                 case 21: { // ALERT
634                     if(len != 2) throw new Exn("invalid lengh for alert");
635                     int level = readRecordBuf[0];
636                     int desc = readRecordBuf[1];
637                     if(level == 1) {
638                         if(desc == 0) { // CloseNotify
639                             debug("Server requested connection closure");
640                             try {
641                                 sendCloseNotify();
642                             } catch(SocketException e) { /* incomplete close, thats ok */ }
643                             closed = true;
644                             super.close();
645                             return -1;
646                         } else {
647                             warnings++;
648                             log("SSL ALERT WARNING: desc: " + desc);
649                         }
650                     } else if(level == 2) {
651                         throw new Exn("SSL ALERT FATAL: desc: " +desc);
652                     } else {
653                         throw new Exn("invalid alert level");
654                     }
655                     break;
656                 }
657                 case 22: { // Handshake
658                     int type = readRecordBuf[0];
659                     int hslen = ((readRecordBuf[1]&0xff)<<16)|((readRecordBuf[2]&0xff)<<8)|((readRecordBuf[3]&0xff)<<0);
660                     if(hslen > len - 4) throw new Exn("Multiple sequential handshake messages received after negotiation");
661                     if(type == 0) { // HellloRequest
662                         if(tls) sendAlert(false,100); // politely refuse, 100 == NoRegnegotiation
663                     } else {
664                         throw new Exn("Unexpected Handshake type: " + type);
665                     }
666                 }
667                 default: throw new Exn("Unexpected protocol: " + proto);
668             }
669         }
670     }
671     
672     private static void longToBytes(long l, byte[] buf, int off) {
673         for(int i=0;i<8;i++) buf[off+i] = (byte)(l>>>(8*(7-i)));
674     }
675     private void computeMAC(byte proto, byte[] payload, int off, int len, Digest digest, long sequenceNumber) {
676         if(tls) {
677             longToBytes(sequenceNumber,mac,0);
678             mac[8] = proto;
679             mac[9] = 0x03; // version
680             mac[10] = 0x01;
681             mac[11] = (byte)(len>>>8);
682             mac[12] = (byte)(len>>>0);
683             
684             digest.update(mac,0,13);
685             digest.update(payload,off,len);
686             digest.doFinal(mac,0);
687         } else {
688             longToBytes(sequenceNumber, mac, 0);
689             mac[8] = proto;
690             mac[9] = (byte)(len>>>8);
691             mac[10] = (byte)(len>>>0);
692             
693             digest.update(mac, 0, 11);
694             digest.update(payload, off, len);
695             digest.doFinal(mac, 0);
696         }
697     }
698     
699     private void sendCloseNotify() throws IOException { sendRecord((byte)21, new byte[] { 0x01, 0x00 }); }
700     private void sendAlert(boolean fatal, int message) throws IOException {
701         byte[] buf = new byte[] { fatal ? (byte)2 :(byte)1, (byte)message };
702         sendRecord((byte)21,buf);
703         flush();
704     }
705     
706     //
707     // Hash functions
708     //
709     
710     // Shared digest objects
711     private MD5 masterMD5 = new MD5();
712     private SHA1 masterSHA1 = new SHA1();
713     
714     private byte[] md5(byte[] in) { return md5( new byte[][] { in }); }
715     private byte[] md5(byte[][] inputs) {
716         masterMD5.reset();
717         for(int i=0; i<inputs.length; i++) masterMD5.update(inputs[i], 0, inputs[i].length);
718         byte[] ret = new byte[masterMD5.getDigestSize()];
719         masterMD5.doFinal(ret, 0);
720         return ret;
721     }
722     
723     private byte[] sha1(byte[] in)  { return sha1(new byte[][] { in }); }
724     private byte[] sha1(byte[][] inputs) {
725         masterSHA1.reset();
726         for(int i=0; i<inputs.length; i++) masterSHA1.update(inputs[i], 0, inputs[i].length);
727         byte[] ret = new byte[masterSHA1.getDigestSize()];
728         masterSHA1.doFinal(ret, 0);
729         return ret;
730     }
731     
732     /*  RFC-2246
733      PRF(secret, label, seed) = P_MD5(S1, label + seed) XOR P_SHA-1(S2, label + seed);
734      L_S = length in bytes of secret;
735      L_S1 = L_S2 = ceil(L_S / 2);
736      
737      The secret is partitioned into two halves (with the possibility of
738      one shared byte) as described above, S1 taking the first L_S1 bytes
739      and S2 the last L_S2 bytes.
740      
741      P_hash(secret, seed) = HMAC_hash(secret, A(1) + seed) +
742      HMAC_hash(secret, A(2) + seed) +
743      HMAC_hash(secret, A(3) + seed) + ...
744      
745      A(0) = seed
746      A(i) = HMAC_hash(secret, A(i-1))
747      */           
748     private byte[] tlsPRF(int size,byte[] secret, byte[] label, byte[] seed) {
749         if(size > 140) throw new IllegalArgumentException("size > 140: " + size);
750         seed = concat(label,seed);
751         
752         int half_length = (secret.length + 1) / 2;
753         byte[] s1 = new byte[half_length];
754         System.arraycopy(secret,0,s1,0,half_length);
755         byte[] s2 = new byte[half_length];
756         System.arraycopy(secret,secret.length - half_length, s2, 0, half_length);
757
758         Digest hmac_md5 = new HMAC(new MD5(),s1);
759         Digest hmac_sha = new HMAC(new SHA1(),s2);
760         
761         byte[] md5out = new byte[144];
762         byte[] shaout = new byte[140];
763         byte[] digest = new byte[20];
764         int n;
765         
766         n = 0;
767         hmac_md5.update(seed,0,seed.length);
768         hmac_md5.doFinal(digest,0);
769         
770         // digest == md5_a_1
771         while(n < size) {
772             hmac_md5.update(digest,0,16);
773             hmac_md5.update(seed,0,seed.length);
774             hmac_md5.doFinal(md5out,n);
775             hmac_md5.update(digest,0,16);
776             hmac_md5.doFinal(digest,0);
777             n += 16;
778         }
779         
780         n = 0;
781         hmac_sha.update(seed,0,seed.length);
782         hmac_sha.doFinal(digest,0);
783         
784         while(n < size) {
785             hmac_sha.update(digest,0,20);
786             hmac_sha.update(seed,0,seed.length);
787             hmac_sha.doFinal(shaout,n);
788             hmac_sha.update(digest,0,20);
789             hmac_sha.doFinal(digest,0);
790             n += 20;
791          }
792             
793         byte[] ret = new byte[size];
794         for(int i=0;i<size;i++) ret[i] = (byte)(md5out[i] ^ shaout[i]);
795         return ret;
796     }
797
798     public static class CBC implements Cipher {
799         private final Cipher c;
800         private final byte[] iv;
801         private final int blockSize;
802         private final boolean reverse;
803         public CBC(Cipher c, int blockSize, byte[] iv, boolean reverse) {
804             this.c = c;
805             this.blockSize = blockSize;
806             this.iv = iv;
807             this.reverse = reverse;
808             if(iv.length != blockSize) throw new IllegalArgumentException("iv.length != blockSize");
809         }
810         
811         public void process(byte[] in, int inp, byte[] out, int outp, int len) {
812             if((len % blockSize) != 0) throw new IllegalArgumentException("buffer must be a multiple of block size");
813             while(len != 0) {
814                 if(!reverse) {
815                     for(int i=0;i<blockSize;i++) iv[i] ^= in[inp+i]; // mangle the cleartext
816                     c.process(iv,0,out,outp,blockSize); // process it
817                     System.arraycopy(out,outp,iv,0,blockSize); // copy the block to iv
818                 } else {
819                     c.process(in,inp,out,outp,blockSize); // process the ciphertext
820                     for(int i=0;i<blockSize;i++) out[outp+i] ^= iv[i]; // mangle the cleartext
821                     System.arraycopy(in,inp,iv,0,blockSize); // copy the ciphertext to iv
822                 }
823                 inp+=16; outp+=16; len-=16;
824             }
825         }
826     }
827
828     public static class SSLv3HMAC extends Digest {
829         private final Digest h;
830         private final byte[] digest;
831         private final byte[] key;
832         private final int padSize;
833         
834         public int getDigestSize() { return h.getDigestSize(); }
835         
836         public SSLv3HMAC(Digest h, byte[] key) {
837             this.h = h;
838             this.key = key;
839             switch(h.getDigestSize()) {
840                 case 16: padSize = 48; break;
841                 case 20: padSize = 40; break;
842                 default: throw new IllegalArgumentException("unsupported digest size");
843             }
844             digest = new byte[h.getDigestSize()];
845             reset();
846         }
847         public void reset() {
848             h.reset();
849             h.update(key,0,key.length);
850             h.update(pad1,0,padSize);
851         }
852         public void update(byte[] b, int off, int len) { h.update(b,off,len); }
853         public void doFinal(byte[] out, int off){
854             h.doFinal(digest,0);
855             h.update(key,0,key.length);
856             h.update(pad2,0,padSize);
857             h.update(digest,0,digest.length);
858             h.doFinal(out,off);
859             reset();
860         }
861         protected void processWord(byte[] in, int inOff) {}
862         protected void processLength(long bitLength) {}
863         protected void processBlock() {}
864     }
865     
866     //
867     // Static Methods
868     //
869     
870     private static SecureRandom random = new SecureRandom();
871     public static synchronized void randomBytes(byte[] buf, int off, int len) {
872         byte[] bytes =  new byte[len];
873         random.nextBytes(bytes);
874         System.arraycopy(bytes,0,buf,off,len);
875     }
876     
877     public static byte[] concat(byte[] a, byte[] b) { return concat(new byte[][] { a, b }); }
878     public static byte[] concat(byte[] a, byte[] b, byte[] c) { return concat(new byte[][] { a, b, c }); }
879     public static byte[] concat(byte[][] inputs) {
880         int total = 0;
881         for(int i=0; i<inputs.length; i++) total += inputs[i].length;
882         byte[] ret = new byte[total];
883         for(int i=0,pos=0; i<inputs.length;pos+=inputs[i].length,i++)
884             System.arraycopy(inputs[i], 0, ret, pos, inputs[i].length);
885         return ret;
886     }
887     
888     public static byte[] getBytes(String s) {
889         try {
890             return s.getBytes("US-ASCII");
891         } catch (UnsupportedEncodingException e) {
892             return null; // will never happen
893         }
894     }
895     
896     public static boolean eq(byte[] a, int aoff, byte[] b, int boff, int len){
897         for(int i=0;i<len;i++) if(a[aoff+i] != b[boff+i]) return false;
898         return true;
899     }
900     
901     //
902     // InputStream/OutputStream/Socket interfaces
903     //
904     public OutputStream getOutputStream() { return sslOS; }
905     public InputStream getInputStream() { return sslIS; }
906     public synchronized void close() throws IOException {
907         if(!closed) {
908             if(negotiated != 0) {
909                 sendCloseNotify();
910                 flush();
911                 // don't bother sending a close_notify back to the server 
912                 // this is an incomplete close which is allowed by the spec
913             }
914             super.close();
915             closed = true;
916         }
917     }
918     
919     private int read(byte[] buf, int off, int len) throws IOException {
920         if(pendingLength == 0) {
921             if(closed) return -1;
922             int readLen = readRecord((byte)23);
923             if(readLen == -1) return -1; // EOF
924             len = min(len,readLen);
925             System.arraycopy(readRecordBuf,0,buf,off,len);
926             if(readLen > len) System.arraycopy(readRecordBuf,len,pending,0,readLen-len);
927             pendingStart = 0;
928             pendingLength = readLen - len;
929             return len;
930         } else {
931             len = min(len,pendingLength);
932             System.arraycopy(pending,pendingStart,buf,off,len);
933             pendingLength -= len;
934             pendingStart += len;
935             return len;
936         }
937     }
938     
939     private void write(byte[] buf, int off, int len) throws IOException {
940         if(closed) throw new SocketException("Socket closed");
941         sendRecord((byte)23,buf,off,len);
942         flush();
943     }
944     
945     private class SSLInputStream extends InputStream {
946         public int available() throws IOException {
947             synchronized(SSL.this) {
948                 return negotiated != 0 ? pendingLength : rawIS.available();
949             }
950         }
951         public int read() throws IOException {
952             synchronized(SSL.this) {
953                 if(negotiated==0) return rawIS.read();
954                 if(pendingLength > 0) {
955                     pendingLength--;
956                     return pending[pendingStart++];
957                 } else {
958                     byte[] buf = new byte[1];
959                     int n = read(buf);
960                     return n == -1 ? -1 : buf[0]&0xff;
961                 }
962             }
963         }
964         public int read(byte[] buf, int off, int len) throws IOException {
965             synchronized(SSL.this) {
966                 return negotiated!=0 ? SSL.this.read(buf,off,len) : rawIS.read(buf,off,len);
967             }
968         }
969         public long skip(long n) throws IOException {
970             synchronized(SSL.this) {
971                 if(negotiated==0) return rawIS.skip(n);
972                 if(pendingLength > 0) {
973                     n = min((int)n,pendingLength);
974                     pendingLength -= n;
975                     pendingStart += n;
976                     return n;
977                 }
978                 return super.skip(n);
979             }
980         }
981     }
982     
983     private class SSLOutputStream extends OutputStream {
984         public void flush() throws IOException { rawOS.flush(); }
985         public void write(int b) throws IOException { write(new byte[] { (byte)b }); }
986         public void write(byte[] buf, int off, int len) throws IOException {
987             synchronized(SSL.this) {
988                 if(negotiated!=0)
989                     SSL.this.write(buf,off,len);
990                 else
991                     rawOS.write(buf,off,len);
992             }
993         }
994     }
995     
996     public static class Exn extends IOException { public Exn(String s) { super(s); } }
997     public static class PrematureCloseExn extends Exn {
998         public PrematureCloseExn() { super("Connection was closed by the remote WITHOUT a close_noify"); }
999     }
1000     
1001     public static boolean debugOn = false;
1002     private static void debug(Object o) { if(debugOn) System.err.println("[BriSSL-Debug] " + o.toString()); }
1003     private static void log(Object o) { System.err.println("[BriSSL] " + o.toString()); }
1004             
1005     private static void verifyCerts(X509.Certificate[] certs) throws DER.Exception, Exn {
1006         try {
1007             verifyCerts_(certs);
1008         } catch(RuntimeException e) {
1009             e.printStackTrace();
1010             throw new Exn("Error while verifying certificates: " + e);
1011         }
1012     }
1013     
1014     private static void verifyCerts_(X509.Certificate[] certs) throws DER.Exception, Exn {
1015         int last = certs.length-1;
1016         for(int i=0;i<certs.length;i++) {
1017             debug("Cert " + i + ": " + certs[i].subject + " ok");
1018             if(!certs[i].isValid())
1019                 throw new Exn("Certificate " + i + " in certificate chain is not valid (" + certs[i].startDate + " - " + certs[i].endDate + ")");
1020             if(i != 0) {
1021                 X509.Certificate.BC bc = certs[i].basicContraints;
1022                 if(bc == null) {
1023                     last = i;
1024                     break;
1025                 } else {
1026                     if(!bc.isCA) throw new Exn("non-CA certificate used for signing");
1027                     if(bc.pathLenConstraint != null && bc.pathLenConstraint.longValue() < i-1) throw new Exn("CA cert can't be used this deep");
1028                 }
1029             }
1030             if(i != certs.length - 1) {
1031                 if(!certs[i].issuer.equals(certs[i+1].subject))
1032                     throw new Exn("Issuer for certificate " + i + " does not match next in chain");
1033                 if(!certs[i].isSignedBy(certs[i+1]))
1034                     throw new Exn("Certificate " + i + " in chain is not signed by the next certificate");
1035             }
1036         }
1037         
1038         X509.Certificate cert = certs[last];
1039         
1040         RSA.PublicKey pks = (RSA.PublicKey) caKeys.get(cert.issuer);
1041         if(pks == null) throw new Exn("Certificate is signed by an unknown CA (" + cert.issuer + ")");
1042         if(!cert.isSignedWith(pks)) throw new Exn("Certificate is not signed by its CA");
1043         log("" + cert.subject + " is signed by " + cert.issuer);
1044     }
1045     
1046     public static void addCACert(byte[] b) throws IOException { addCACert(new ByteArrayInputStream(b)); }
1047     public static void addCACert(InputStream is) throws IOException { addCACert(new X509.Certificate(is)); }
1048     public static void addCACert(X509.Certificate cert) throws DER.Exception { addCAKey(cert.subject,cert.getRSAPublicKey()); }
1049     public static void addCAKey(X509.Name subject, RSA.PublicKey pks)  {
1050         synchronized(caKeys) {
1051             if(caKeys.get(subject) != null)
1052                 throw new IllegalArgumentException(subject.toString() + " already exists!");
1053             caKeys.put(subject,pks);
1054         }
1055     }
1056     
1057     static {
1058         try {
1059             // This will force a <clinit> which'll load the certs
1060             Class.forName("org.ibex.net.ssl.RootCerts");
1061             log("Loaded root keys from org.ibex.net.ssl.RootCerts");
1062         } catch(ClassNotFoundException e) {
1063             InputStream is = SSL.class.getClassLoader().getResourceAsStream("org.ibex/net/ssl/rootcerts.dat");
1064             if(is != null) {
1065                 try {
1066                     addCompactCAKeys(is);
1067                     log("Loaded root certs from rootcerts.dat");
1068                 } catch(IOException e2) {
1069                     log("Error loading certs from rootcerts.dat: " + e2.getMessage()); 
1070                 }
1071             }
1072         }
1073     }
1074         
1075     public static int addCompactCAKeys(InputStream is) throws IOException {
1076         synchronized(caKeys) {
1077             try {
1078                 Vector seq = (Vector) new DER.InputStream(is).readObject();
1079                 for(Enumeration e = seq.elements(); e.hasMoreElements();) {
1080                     Vector seq2 = (Vector) e.nextElement();
1081                     X509.Name subject = new X509.Name(seq2.elementAt(0));
1082                     RSA.PublicKey pks = new RSA.PublicKey(seq2.elementAt(1));
1083                     addCAKey(subject,pks);
1084                 }
1085                 return seq.size();
1086             } catch(RuntimeException e) {
1087                 e.printStackTrace();
1088                 throw new IOException("error while reading stream: " + e);
1089             }
1090         }
1091     }
1092     
1093     public static synchronized void setVerifyCallback(VerifyCallback cb) { verifyCallback = cb; }
1094     
1095     // State Info
1096     public static class State {
1097         byte[] sessionID;
1098         byte[] masterSecret;
1099         State(byte[] sessionID, byte[] masterSecret) {
1100             this.sessionID = sessionID;
1101             this.masterSecret = masterSecret;
1102         }
1103     }
1104     
1105     public interface VerifyCallback {
1106         public boolean checkCerts(X509.Certificate[] certs, String hostname, Exn exn);
1107     }
1108     
1109     // Helper methods
1110     private static final int min(int a, int b) { return a < b ? a : b; }
1111 }