[project @ 1999-10-29 13:59:52 by sof]
[ghc-hetmet.git] / ghc / lib / misc / SocketPrim.lhs
index 1d2f9c8..35420b8 100644 (file)
@@ -77,20 +77,24 @@ module SocketPrim (
     packSocketType,
     packSockAddr, unpackSockAddr
 
+    , withSocketsDo  -- :: IO a -> IO a
+
 ) where
  
 import GlaExts
 import ST
 import Ix
-import Weak        ( addForeignFinaliser )
+import Weak        ( addForeignFinalizer )
 import PrelIOBase  -- IOError, Handle representation
 import PrelHandle
+import PrelConc            ( threadWaitRead, threadWaitWrite )
 import Foreign
+import Addr        ( nullAddr )
 
 import IO
 import IOExts      ( IORef, newIORef, readIORef, writeIORef )
 import CString      ( unpackNBytesBAIO,
-                     unpackCString, unpackCStringIO,
+                     unpackCStringIO,
                      unpackCStringLenIO,
                      allocChars
                    )
@@ -188,7 +192,7 @@ instance Num PortNumber where
    signum n  = mkPortNumber (signum (ntohs n))
 
 data SockAddr          -- C Names                              
-#ifndef cygwin32_TARGET_OS
+#if !defined(cygwin32_TARGET_OS) && !defined(mingw32_TARGET_OS)
   = SockAddrUnix        -- struct sockaddr_un
         String          -- sun_path
   | SockAddrInet       -- struct sockaddr_in
@@ -266,7 +270,7 @@ bindSocket :: Socket        -- Unconnected Socket
           -> IO ()
 
 bindSocket (MkSocket s _family _stype _protocol socketStatus) addr = do
-#ifndef cygwin32_TARGET_OS
+#if !defined(cygwin32_TARGET_OS) && !defined(mingw32_TARGET_OS)
  let isDomainSocket = if _family == AF_UNIX then 1 else (0::Int)
 #else
  let isDomainSocket = 0
@@ -278,7 +282,7 @@ bindSocket (MkSocket s _family _stype _protocol socketStatus) addr = do
         show currentStatus))
   else do
    addr' <- packSockAddr addr
-   let (_,sz) = boundsOfByteArray addr'
+   let (_,sz) = boundsOfMutableByteArray addr'
    status <- _ccall_ bindSocket s addr' sz (isDomainSocket::Int)
    case (status::Int) of
      -1 -> constructErrorAndFail "bindSocket"
@@ -301,7 +305,7 @@ connect :: Socket   -- Unconnected Socket
        -> IO ()
 
 connect (MkSocket s _family _stype _protocol socketStatus) addr = do
-#ifndef cygwin32_TARGET_OS
+#if !defined(mingw32_TARGET_OS) && !defined(cygwin32_TARGET_OS)
  let isDomainSocket = if _family == AF_UNIX then 1 else (0::Int)
 #else
  let isDomainSocket = 0
@@ -313,10 +317,12 @@ connect (MkSocket s _family _stype _protocol socketStatus) addr = do
          show currentStatus))
   else do
    addr' <- packSockAddr addr
-   let (_,sz) = boundsOfByteArray addr'
+   let (_,sz) = boundsOfMutableByteArray addr'
    status <- _ccall_ connectSocket s addr' sz (isDomainSocket::Int)
    case (status::Int) of
      -1 -> constructErrorAndFail "connect"
+     -6 -> do threadWaitWrite s >> writeIORef socketStatus Connected
+          -- ToDo: check for error with getsockopt
      _  -> writeIORef socketStatus Connected
 \end{code}
        
@@ -368,16 +374,28 @@ accept sock@(MkSocket s family stype protocol status) = do
         show currentStatus))
    else do
      (ptr, sz) <- allocSockAddr family
-     int_star <- stToIO (newIntArray (0,1))
+     int_star <- stToIO (newIntArray ((0::Int),1))
      stToIO (writeIntArray int_star 0 sz)
+     new_sock <- accept_socket s ptr int_star
+     a_sz <- stToIO (readIntArray int_star 0)
+     addr <- unpackSockAddr ptr a_sz
+     new_status <- newIORef Connected
+     return ((MkSocket new_sock family stype protocol new_status), addr)
+
+accept_socket :: Int 
+       -> MutableByteArray RealWorld Int
+       -> MutableByteArray RealWorld Int
+       -> IO Int
+
+accept_socket s ptr int_star = do
      new_sock <- _ccall_ acceptSocket s ptr int_star
      case (new_sock::Int) of
          -1 -> constructErrorAndFail "accept"
-         _  -> do
-               a_sz <- stToIO (readIntArray int_star 0)
-               addr <- unpackSockAddr ptr a_sz
-               new_status <- newIORef Connected
-               return ((MkSocket new_sock family stype protocol new_status), addr)
+
+               -- wait if there are no pending connections
+         -5 -> threadWaitRead s >> accept_socket s ptr int_star
+
+         _  -> return new_sock
 \end{code}
 
 %************************************************************************
@@ -425,7 +443,7 @@ sendTo (MkSocket s _family _stype _protocol status) xs addr = do
           show currentStatus))
    else do
     addr' <- packSockAddr addr
-    let (_,sz) = boundsOfByteArray addr'
+    let (_,sz) = boundsOfMutableByteArray addr'
     nbytes <- _ccall_ sendTo__ s xs (length xs) addr' sz
     case (nbytes::Int) of
       -1 -> constructErrorAndFail "sendTo"
@@ -511,7 +529,7 @@ getPeerName   :: Socket -> IO SockAddr
 
 getPeerName (MkSocket s family _ _ _) = do
  (ptr, a_sz) <- allocSockAddr family
- int_star <- stToIO (newIntArray (0,1))
+ int_star <- stToIO (newIntArray ((0::Int),1))
  stToIO (writeIntArray int_star 0 a_sz)
  status <- _ccall_ getPeerName s ptr int_star
  case (status::Int) of
@@ -524,7 +542,7 @@ getSocketName :: Socket -> IO SockAddr
 
 getSocketName (MkSocket s family _ _ _) = do
  (ptr, a_sz) <- allocSockAddr family
- int_star <- stToIO (newIntArray (0,1))
+ int_star <- stToIO (newIntArray ((0::Int),1))
  stToIO (writeIntArray int_star 0 a_sz)
  rc <- _ccall_ getSockName s ptr int_star
  case (rc::Int) of
@@ -555,7 +573,9 @@ data SocketOption
     | RecvBuffer    {- SO_RCVBUF    -}
     | KeepAlive     {- SO_KEEPALIVE -}
     | OOBInline     {- SO_OOBINLINE -}
+#if !defined(cygwin32_TARGET_OS) && !defined(mingw32_TARGET_OS)
     | MaxSegment    {- TCP_MAXSEG   -}
+#endif
     | NoDelay       {- TCP_NODELAY  -}
 --    | Linger        {- SO_LINGER    -}
 #if 0
@@ -567,6 +587,15 @@ data SocketOption
     | UseLoopBack   {- SO_USELOOPBACK -}  -- not used, I believe.
 #endif
 
+socketOptLevel :: SocketOption -> Int
+socketOptLevel so = 
+  case so of
+#if !defined(cygwin32_TARGET_OS) && !defined(mingw32_TARGET_OS)
+    MaxSegment   -> ``IPPROTO_TCP''
+#endif
+    NoDelay      -> ``IPPROTO_TCP''
+    _            -> ``SOL_SOCKET''
+
 packSocketOption :: SocketOption -> Int
 packSocketOption so =
   case so of
@@ -580,7 +609,9 @@ packSocketOption so =
     RecvBuffer    -> ``SO_RCVBUF''
     KeepAlive     -> ``SO_KEEPALIVE''
     OOBInline     -> ``SO_OOBINLINE''
+#if !defined(cygwin32_TARGET_OS) && !defined(mingw32_TARGET_OS)
     MaxSegment    -> ``TCP_MAXSEG''
+#endif
     NoDelay       -> ``TCP_NODELAY''
 #if 0
     ReusePort     -> ``SO_REUSEPORT''  -- BSD only?
@@ -596,7 +627,10 @@ setSocketOption :: Socket
                -> Int           -- Option Value
                -> IO ()
 setSocketOption (MkSocket s _ _ _ _) so v = do
-   rc <- _ccall_ setSocketOption__ s (packSocketOption so) v
+   rc <- _ccall_ setSocketOption__ s 
+               (packSocketOption so) 
+               (socketOptLevel so) 
+               v 
    if rc /= (0::Int)
     then constructErrorAndFail "setSocketOption"
     else return ()
@@ -605,7 +639,9 @@ getSocketOption :: Socket
                -> SocketOption  -- Option Name
                -> IO Int         -- Option Value
 getSocketOption (MkSocket s _ _ _ _) so = do
-   rc <- _ccall_ getSocketOption__ s (packSocketOption so)
+   rc <- _ccall_ getSocketOption__ s 
+               (packSocketOption so)
+               (socketOptLevel so)
    if rc == -1 -- let's just hope that value isn't taken..
     then constructErrorAndFail "getSocketOption"
     else return rc
@@ -708,7 +744,7 @@ unpackFamily family = (range (AF_UNSPEC, AF_IPX))!!family
 
 #endif
 
-#if cygwin32_TARGET_OS
+#if defined(cygwin32_TARGET_OS) || defined(mingw32_TARGET_OS)
  
 data Family = 
          AF_UNSPEC     -- unspecified
@@ -951,14 +987,13 @@ packSocketType stype = 1 + (index (Stream, SeqPacket) stype)
 
 -- This is for a box running cygwin32 toolchain.
 
-#if defined(cygwin32_TARGET_OS)
+#if defined(mingw32_TARGET_OS) || defined(cygwin32_TARGET_OS)
 data SocketType = 
          Stream 
        | Datagram
        | Raw 
        | RDM       -- reliably delivered msg
        | SeqPacket
-       | Packet
        deriving (Eq, Ord, Ix, Show)
        
 packSocketType stype =
@@ -968,7 +1003,6 @@ packSocketType stype =
    Raw       -> ``SOCK_RAW''
    RDM       -> ``SOCK_RDM'' 
    SeqPacket -> ``SOCK_SEQPACKET''
-   Packet    -> ``SOCK_PACKET''
 
 #endif
 
@@ -1081,7 +1115,7 @@ sIsWritable = sIsReadable -- sort of.
 -------------------------------------------------------------------------------
 
 sIsAcceptable :: Socket -> IO Bool
-#ifndef cygwin32_TARGET_OS
+#if !defined(cygwin32_TARGET_OS) && !defined(mingw32_TARGET_OS)
 sIsAcceptable (MkSocket _ AF_UNIX Stream _ status) = do
     value <- readIORef status
     return (value == Connected || value == Bound || value == Listening)
@@ -1127,16 +1161,16 @@ Marshaling and allocation helper functions:
 
 allocSockAddr :: Family -> IO (MutableByteArray RealWorld Int, Int)
 
-#ifndef cygwin32_TARGET_OS
+#if !defined(cygwin32_TARGET_OS) && !defined(mingw32_TARGET_OS)
 allocSockAddr AF_UNIX = do
     ptr <- allocChars ``sizeof(struct sockaddr_un)''
-    let (_,sz) = boundsOfByteArray ptr
+    let (_,sz) = boundsOfMutableByteArray ptr
     return (ptr, sz)
 #endif
 
 allocSockAddr AF_INET = do
     ptr <- allocChars ``sizeof(struct sockaddr_in)''
-    let (_,sz) = boundsOfByteArray ptr
+    let (_,sz) = boundsOfMutableByteArray ptr
     return (ptr, sz)
 
 -------------------------------------------------------------------------------
@@ -1145,14 +1179,14 @@ unpackSockAddr :: MutableByteArray RealWorld Int -> Int -> IO SockAddr
 unpackSockAddr arr len = do
     fam <- _casm_ ``%r = ((struct sockaddr*)%0)->sa_family;'' arr
     case unpackFamily fam of
-#ifndef cygwin32_TARGET_OS
+#if !defined(cygwin32_TARGET_OS) && !defined(mingw32_TARGET_OS)
        AF_UNIX -> unpackSockAddrUnix arr (len - ``sizeof(short)'')
 #endif
        AF_INET -> unpackSockAddrInet arr
 
 -------------------------------------------------------------------------------
 
-#ifndef cygwin32_TARGET_OS
+#if !defined(cygwin32_TARGET_OS) && !defined(mingw32_TARGET_OS)
 
 {-
   sun_path is *not* NULL terminated, hence we *do* need to know the
@@ -1178,7 +1212,7 @@ unpackSockAddrInet ptr = do
 
 
 packSockAddr :: SockAddr -> IO (MutableByteArray RealWorld Int)
-#ifndef cygwin32_TARGET_OS
+#if !defined(cygwin32_TARGET_OS) && !defined(mingw32_TARGET_OS)
 packSockAddr (SockAddrUnix path) = do
     (ptr,_) <- allocSockAddr AF_UNIX
     _casm_ ``(((struct sockaddr_un *)%0)->sun_family) = AF_UNIX;''    ptr
@@ -1204,14 +1238,23 @@ it subsequently.
 socketToHandle :: Socket -> IOMode -> IO Handle
 
 socketToHandle (MkSocket fd _ _ _ _) m = do
-    fileobj <- _ccall_ openFd fd (file_mode::Int) (flush_on_close::Int)
-    fo <- makeForeignObj fileobj
-    addForeignFinaliser fo (freeFileObject fo)
-    mkBuffer__ fo 0  -- not buffered
-    hndl <- newHandle (Handle__ fo htype NoBuffering socket_str)
-    return hndl
+    fileobj <- _ccall_ openFd fd (file_mode::Int) (file_flags::Int)
+    if fileobj == nullAddr then
+       ioError (userError "socketHandle: Failed to open file desc")
+     else do
+       fo <- mkForeignObj fileobj
+       addForeignFinalizer fo (freeFileObject fo)
+       mkBuffer__ fo 0  -- not buffered
+       hndl <- newHandle (Handle__ fo htype NoBuffering socket_str)
+       return hndl
  where
   socket_str = "<socket: "++show fd
+#if defined(mingw32_TARGET_OS)
+  file_flags = flush_on_close + 1024{-I'm a socket fd, me!-}
+#else
+  file_flags = flush_on_close
+#endif
+
   (flush_on_close, file_mode) =
    case m of 
            AppendMode    -> (1, 0)
@@ -1231,3 +1274,28 @@ socketToHandle (MkSocket s family stype protocol status) m =
 #endif
 \end{code}
 
+If you're using WinSock, the programmer has to call a startup
+routine before starting to use the goods. So, if you want to
+stay portable across all ghc-supported platforms, you have to
+use @withSocketsDo@...:
+
+\begin{code}
+withSocketsDo :: IO a -> IO a
+#if !defined(HAVE_WINSOCK_H) || defined(cygwin32_TARGET_OS)
+withSocketsDo x = x
+#else
+withSocketsDo act = do
+   x <- initWinSock
+   if ( x /= 0 ) then
+     ioError (userError "Failed to initialise WinSock")
+    else do
+      v <- act
+      shutdownWinSock
+      return v
+
+foreign import "initWinSock" initWinSock :: IO Int
+foreign import "shutdownWinSock" shutdownWinSock :: IO ()
+
+#endif
+
+\end{code}