udp support
[nestedvm.git] / src / org / ibex / nestedvm / UnixRuntime.java
index 8fb08f9..0b50ee3 100644 (file)
@@ -158,6 +158,9 @@ public abstract class UnixRuntime extends Runtime implements Cloneable {
             case SYS_accept: return sys_accept(a,b,c);
             case SYS_shutdown: return sys_shutdown(a,b);
             case SYS_sysctl: return sys_sysctl(a,b,c,d,e,f);
             case SYS_accept: return sys_accept(a,b,c);
             case SYS_shutdown: return sys_shutdown(a,b);
             case SYS_sysctl: return sys_sysctl(a,b,c,d,e,f);
+            case SYS_sendto: return sys_sendto(a,b,c,d,e,f);
+            case SYS_recvfrom: return sys_recvfrom(a,b,c,d,e,f);
+            case SYS_select: return sys_select(a,b,c,d,e);
 
             default: return super._syscall(syscall,a,b,c,d,e,f);
         }
 
             default: return super._syscall(syscall,a,b,c,d,e,f);
         }
@@ -558,8 +561,6 @@ public abstract class UnixRuntime extends Runtime implements Cloneable {
         return n;
     }
     
         return n;
     }
     
-    // FEATURE: UDP is totally broken
-    
     static class SocketFD extends FD {
         public static final int TYPE_STREAM = 0;
         public static final int TYPE_DGRAM = 1;
     static class SocketFD extends FD {
         public static final int TYPE_STREAM = 0;
         public static final int TYPE_DGRAM = 1;
@@ -569,19 +570,31 @@ public abstract class UnixRuntime extends Runtime implements Cloneable {
         
         int flags;
         int options;
         
         int flags;
         int options;
-        Object o;
+        
+        Socket s;
+        ServerSocket ss;
+        DatagramSocket ds;
+        
         InetAddress bindAddr;
         int bindPort = -1;
         InetAddress bindAddr;
         int bindPort = -1;
+        InetAddress connectAddr;
+        int connectPort = -1;
+        
         DatagramPacket dp;
         InputStream is;
         OutputStream os; 
         
         DatagramPacket dp;
         InputStream is;
         OutputStream os; 
         
-        public SocketFD(int type) { flags = type; }
+        private static final byte[] EMPTY = new byte[0];
+        public SocketFD(int type) {
+               flags = type;
+               if(type == TYPE_DGRAM)
+                       dp = new DatagramPacket(EMPTY,0);
+        }
         
         public void setOptions() {
             try {
         
         public void setOptions() {
             try {
-                if(o != null && type() == TYPE_STREAM && !listen()) {
-                    Platform.socketSetKeepAlive((Socket)o,(options & SO_KEEPALIVE) != 0);
+                if(s != null && type() == TYPE_STREAM && !listen()) {
+                    Platform.socketSetKeepAlive(s,(options & SO_KEEPALIVE) != 0);
                 }
             } catch(SocketException e) {
                 if(STDERR_DIAG) e.printStackTrace();
                 }
             } catch(SocketException e) {
                 if(STDERR_DIAG) e.printStackTrace();
@@ -589,65 +602,84 @@ public abstract class UnixRuntime extends Runtime implements Cloneable {
         }
         
         public void _close() {
         }
         
         public void _close() {
-            if(o != null) {
-                try {
-                    if(type() == TYPE_STREAM) {
-                        if(listen()) ((ServerSocket)o).close();
-                        else ((Socket)o).close();
-                    } else {
-                        ((DatagramSocket)o).close();
-                    }
-                } catch(IOException e) {
-                    /* ignore */
-                }
+            try {
+               if(s != null) s.close();
+               if(ss != null) ss.close();
+               if(ds != null) ds.close();
+            } catch(IOException e) {
+                /* ignore */
             }
         }
         
         public int read(byte[] a, int off, int length) throws ErrnoException {
             }
         }
         
         public int read(byte[] a, int off, int length) throws ErrnoException {
-            if(type() == TYPE_STREAM) {
-                if(is == null) throw new ErrnoException(EPIPE);
-                try {
-                    int n = is.read(a,off,length);
-                    return n < 0 ? 0 : n;
-                } catch(IOException e) {
-                    throw new ErrnoException(EIO);
-                }
-            } else {
-                if(off != 0) throw new IllegalArgumentException("off must be 0");
-                DatagramSocket ds = (DatagramSocket) o;
-                dp.setData(a);
-                dp.setLength(length);
-                try {
-                    ds.receive(dp);
-                } catch(IOException e) {
-                    throw new ErrnoException(EIO);
-                }
-                return dp.getLength();
+            if(type() == TYPE_DGRAM) return recvfrom(a,off,length,null,null);
+            if(is == null) throw new ErrnoException(EPIPE);
+            try {
+                int n = is.read(a,off,length);
+                return n < 0 ? 0 : n;
+            } catch(IOException e) {
+                throw new ErrnoException(EIO);
             }
         }    
         
             }
         }    
         
+        public int recvfrom(byte[] a, int off, int length, InetAddress[] sockAddr, int[] port) throws ErrnoException {
+               if(type() == TYPE_STREAM) return read(a,off,length);
+               
+               if(off != 0) throw new IllegalArgumentException("off must be 0");
+               dp.setData(a);
+               dp.setLength(length);
+               try {
+                       if(ds == null) ds = new DatagramSocket();
+                       ds.receive(dp);
+               } catch(IOException e) {
+                       if(STDERR_DIAG) e.printStackTrace();
+                       throw new ErrnoException(EIO);
+               }
+               if(sockAddr != null) {
+                       sockAddr[0] = dp.getAddress();
+                       port[0] = dp.getPort();
+               }
+               return dp.getLength();
+        }
+        
         public int write(byte[] a, int off, int length) throws ErrnoException {
         public int write(byte[] a, int off, int length) throws ErrnoException {
-            if(type() == TYPE_STREAM) {
-                if(os == null) throw new ErrnoException(EPIPE);
-                try {
-                    os.write(a,off,length);
-                    return length;
-                } catch(IOException e) {
-                    throw new ErrnoException(EIO);
-                }
-            } else {
-                if(off != 0) throw new IllegalArgumentException("off must be 0");
-                DatagramSocket ds = (DatagramSocket) o;
-                dp.setData(a);
-                dp.setLength(length);
-                try {
-                    ds.send(dp);
-                } catch(IOException e) {
-                    throw new ErrnoException(EIO);
-                }
-                return dp.getLength();
+            if(type() == TYPE_DGRAM) return  sendto(a,off,length,null,-1);
+
+            if(os == null) throw new ErrnoException(EPIPE);
+            try {
+                os.write(a,off,length);
+                return length;
+            } catch(IOException e) {
+                throw new ErrnoException(EIO);
             }
         }
             }
         }
+        
+        public int sendto(byte[] a, int off, int length, InetAddress destAddr, int destPort) throws ErrnoException {
+               if(off != 0) throw new IllegalArgumentException("off must be 0");
+               if(type() == TYPE_STREAM) return write(a,off,length);
+               
+               if(destAddr == null) {
+                       destAddr = connectAddr;
+                       destPort = connectPort;
+                       
+                       if(destAddr == null) throw new ErrnoException(ENOTCONN);
+               }
+               
+               dp.setAddress(destAddr);
+               dp.setPort(destPort);
+               dp.setData(a);
+               dp.setLength(length);
+               
+               try {
+                       if(ds == null) ds = new DatagramSocket();
+                       ds.send(dp);
+               } catch(IOException e) {
+                       if(STDERR_DIAG) e.printStackTrace();
+                       if("Network is unreachable".equals(e.getMessage())) throw new ErrnoException(EHOSTUNREACH);
+                       throw new ErrnoException(EIO);
+               }
+               return dp.getLength();
+        }
 
         public int flags() { return O_RDWR; }
         public FStat _fstat() { return new SocketFStat(); }
 
         public int flags() { return O_RDWR; }
         public FStat _fstat() { return new SocketFStat(); }
@@ -669,7 +701,7 @@ public abstract class UnixRuntime extends Runtime implements Cloneable {
     private int sys_connect(int fdn, int addr, int namelen) throws ErrnoException, FaultException {
         SocketFD fd = getSocketFD(fdn);
         
     private int sys_connect(int fdn, int addr, int namelen) throws ErrnoException, FaultException {
         SocketFD fd = getSocketFD(fdn);
         
-        if(fd.type() == SocketFD.TYPE_STREAM && fd.o != null) return -EISCONN;
+        if(fd.type() == SocketFD.TYPE_STREAM && (fd.s != null || fd.ss != null)) return -EISCONN;
         int word1 = memRead(addr);
         if( ((word1 >>> 16)&0xff) != AF_INET) return -EAFNOSUPPORT;
         int port = word1 & 0xffff;
         int word1 = memRead(addr);
         if( ((word1 >>> 16)&0xff) != AF_INET) return -EAFNOSUPPORT;
         int port = word1 & 0xffff;
@@ -683,22 +715,21 @@ public abstract class UnixRuntime extends Runtime implements Cloneable {
             return -EADDRNOTAVAIL;
         }
         
             return -EADDRNOTAVAIL;
         }
         
+        fd.connectAddr = inetAddr;
+        fd.connectPort = port;
+        
         try {
             switch(fd.type()) {
                 case SocketFD.TYPE_STREAM: {
                     Socket s = new Socket(inetAddr,port);
         try {
             switch(fd.type()) {
                 case SocketFD.TYPE_STREAM: {
                     Socket s = new Socket(inetAddr,port);
-                    fd.o = s;
+                    fd.s = s;
                     fd.setOptions();
                     fd.is = s.getInputStream();
                     fd.os = s.getOutputStream();
                     break;
                 }
                     fd.setOptions();
                     fd.is = s.getInputStream();
                     fd.os = s.getOutputStream();
                     break;
                 }
-                case SocketFD.TYPE_DGRAM: {
-                    if(fd.dp == null) fd.dp = new DatagramPacket(null,0);
-                    fd.dp.setAddress(inetAddr);
-                    fd.dp.setPort(port);
+                case SocketFD.TYPE_DGRAM:
                     break;
                     break;
-                }
                 default:
                     throw new Error("should never happen");
             }
                 default:
                     throw new Error("should never happen");
             }
@@ -778,7 +809,7 @@ public abstract class UnixRuntime extends Runtime implements Cloneable {
     private int sys_bind(int fdn, int addr, int namelen) throws FaultException, ErrnoException {
         SocketFD fd = getSocketFD(fdn);
         
     private int sys_bind(int fdn, int addr, int namelen) throws FaultException, ErrnoException {
         SocketFD fd = getSocketFD(fdn);
         
-        if(fd.type() == SocketFD.TYPE_STREAM && fd.o != null) return -EISCONN;
+        if(fd.type() == SocketFD.TYPE_STREAM && (fd.s != null || fd.ss != null)) return -EISCONN;
         int word1 = memRead(addr);
         if( ((word1 >>> 16)&0xff) != AF_INET) return -EAFNOSUPPORT;
         int port = word1 & 0xffff;
         int word1 = memRead(addr);
         if( ((word1 >>> 16)&0xff) != AF_INET) return -EAFNOSUPPORT;
         int port = word1 & 0xffff;
@@ -801,10 +832,9 @@ public abstract class UnixRuntime extends Runtime implements Cloneable {
                 return 0;
             }
             case SocketFD.TYPE_DGRAM: {
                 return 0;
             }
             case SocketFD.TYPE_DGRAM: {
-                DatagramSocket s = (DatagramSocket) fd.o;
-                if(s != null) s.close();
+                if(fd.ds != null) fd.ds.close();
                 try {
                 try {
-                    fd.o = inetAddr != null ? new DatagramSocket(port,inetAddr) : new DatagramSocket(port);
+                    fd.ds = inetAddr != null ? new DatagramSocket(port,inetAddr) : new DatagramSocket(port);
                 } catch(IOException e) {
                     return -EADDRINUSE;
                 }
                 } catch(IOException e) {
                     return -EADDRINUSE;
                 }
@@ -818,11 +848,11 @@ public abstract class UnixRuntime extends Runtime implements Cloneable {
     private int sys_listen(int fdn, int backlog) throws ErrnoException {
         SocketFD fd = getSocketFD(fdn);
         if(fd.type() != SocketFD.TYPE_STREAM) return -EOPNOTSUPP;
     private int sys_listen(int fdn, int backlog) throws ErrnoException {
         SocketFD fd = getSocketFD(fdn);
         if(fd.type() != SocketFD.TYPE_STREAM) return -EOPNOTSUPP;
-        if(fd.o != null) return -EISCONN;
+        if(fd.ss != null || fd.s != null) return -EISCONN;
         if(fd.bindPort < 0) return -EOPNOTSUPP;
         
         try {
         if(fd.bindPort < 0) return -EOPNOTSUPP;
         
         try {
-            fd.o = new ServerSocket(fd.bindPort,backlog,fd.bindAddr);
+            fd.ss = new ServerSocket(fd.bindPort,backlog,fd.bindAddr);
             fd.flags |= SocketFD.LISTEN;
             return 0;
         } catch(IOException e) {
             fd.flags |= SocketFD.LISTEN;
             return 0;
         } catch(IOException e) {
@@ -838,7 +868,7 @@ public abstract class UnixRuntime extends Runtime implements Cloneable {
 
         int size = memRead(lenaddr);
         
 
         int size = memRead(lenaddr);
         
-        ServerSocket s = (ServerSocket) fd.o;
+        ServerSocket s = fd.ss;
         Socket client;
         try {
             client = s.accept();
         Socket client;
         try {
             client = s.accept();
@@ -854,7 +884,7 @@ public abstract class UnixRuntime extends Runtime implements Cloneable {
         }
         
         SocketFD clientFD = new SocketFD(SocketFD.TYPE_STREAM);
         }
         
         SocketFD clientFD = new SocketFD(SocketFD.TYPE_STREAM);
-        clientFD.o = client;
+        clientFD.s = client;
         try {
             clientFD.is = client.getInputStream();
             clientFD.os = client.getOutputStream();
         try {
             clientFD.is = client.getInputStream();
             clientFD.os = client.getOutputStream();
@@ -869,9 +899,9 @@ public abstract class UnixRuntime extends Runtime implements Cloneable {
     private int sys_shutdown(int fdn, int how) throws ErrnoException {
         SocketFD fd = getSocketFD(fdn);
         if(fd.type() != SocketFD.TYPE_STREAM || fd.listen()) return -EOPNOTSUPP;
     private int sys_shutdown(int fdn, int how) throws ErrnoException {
         SocketFD fd = getSocketFD(fdn);
         if(fd.type() != SocketFD.TYPE_STREAM || fd.listen()) return -EOPNOTSUPP;
-        if(fd.o == null) return -ENOTCONN;
+        if(fd.s == null) return -ENOTCONN;
         
         
-        Socket s = (Socket) fd.o;
+        Socket s = fd.s;
         
         try {
             if(how == SHUT_RD || how == SHUT_RDWR) Platform.socketHalfClose(s,false);
         
         try {
             if(how == SHUT_RD || how == SHUT_RDWR) Platform.socketHalfClose(s,false);
@@ -883,6 +913,58 @@ public abstract class UnixRuntime extends Runtime implements Cloneable {
         return 0;
     }
     
         return 0;
     }
     
+    private int sys_sendto(int fdn, int addr, int count, int flags, int destAddr, int socklen) throws ErrnoException,ReadFaultException {
+       SocketFD fd = getSocketFD(fdn);
+       if(flags != 0) throw new ErrnoException(EINVAL);
+       
+       int word1 = memRead(destAddr);
+       if( ((word1 >>> 16)&0xff) != AF_INET) return -EAFNOSUPPORT;
+       int port = word1 & 0xffff;
+       InetAddress inetAddr;
+               byte[] ip = new byte[4];
+               copyin(destAddr+4,ip,4);
+               try {
+                       inetAddr = Platform.inetAddressFromBytes(ip);
+               } catch(UnknownHostException e) {
+                       return -EADDRNOTAVAIL;
+               }
+       
+       count = Math.min(count,MAX_CHUNK);
+       byte[] buf = byteBuf(count);
+       copyin(addr,buf,count);
+       try {
+               return fd.sendto(buf,0,count,inetAddr,port);
+       } catch(ErrnoException e) {
+               if(e.errno == EPIPE) exit(128+13,true);
+               throw e;
+       }
+    }
+    
+    private int sys_recvfrom(int fdn, int addr, int count, int flags, int sourceAddr, int socklenAddr) throws ErrnoException, FaultException {
+       SocketFD fd = getSocketFD(fdn);
+       if(flags != 0) throw new ErrnoException(EINVAL);
+       
+       InetAddress[] inetAddr = sourceAddr == 0 ? null : new InetAddress[1];
+       int[] port = sourceAddr == 0 ? null : new int[1];
+       
+       count = Math.min(count,MAX_CHUNK);
+       byte[] buf = byteBuf(count);
+       int n = fd.recvfrom(buf,0,count,inetAddr,port);
+       copyout(buf,addr,n);
+       
+       if(sourceAddr != 0) {
+               memWrite(sourceAddr,(AF_INET << 16) | port[0]);
+               byte[] ip = inetAddr[0].getAddress();
+               copyout(ip,sourceAddr+4,4);
+       }
+       
+       return n;
+    }
+    
+    private int sys_select(int n, int readFDs, int writeFDs, int exceptFDs, int timevalAddr) throws ReadFaultException, ErrnoException {
+       return -ENOSYS;
+    }
+    
     private static String hostName() {
         try {
             return InetAddress.getLocalHost().getHostName();
     private static String hostName() {
         try {
             return InetAddress.getLocalHost().getHostName();