[project @ 2006-01-10 10:23:16 by simonmar]
[haskell-directory.git] / GHC / Conc.lhs
index 4286566..896df03 100644 (file)
 -----------------------------------------------------------------------------
 
 -- No: #hide, because bits of this module are exposed by the stm package.
+-- However, we don't want this module to be the home location for the
+-- bits it exports, we'd rather have Control.Concurrent and the other
+-- higher level modules be the home.  Hence:
+
 -- #not-home
 module GHC.Conc
        ( ThreadId(..)
 
        -- Forking and suchlike
+       , forkIO        -- :: IO a -> IO ThreadId
+       , childHandler  -- :: Exception -> IO ()
        , myThreadId    -- :: IO ThreadId
        , killThread    -- :: ThreadId -> IO ()
        , throwTo       -- :: ThreadId -> Exception -> IO ()
@@ -30,6 +36,7 @@ module GHC.Conc
 
        -- Waiting
        , threadDelay           -- :: Int -> IO ()
+       , registerDelay         -- :: Int -> IO (TVar Bool)
        , threadWaitRead        -- :: Int -> IO ()
        , threadWaitWrite       -- :: Int -> IO ()
 
@@ -52,6 +59,7 @@ module GHC.Conc
         , catchSTM      -- :: STM a -> (Exception -> STM a) -> STM a
        , TVar          -- abstract
        , newTVar       -- :: a -> STM (TVar a)
+       , newTVarIO     -- :: a -> STM (TVar a)
        , readTVar      -- :: TVar a -> STM a
        , writeTVar     -- :: a -> TVar a -> STM ()
        , unsafeIOToSTM -- :: IO a -> STM a
@@ -64,6 +72,10 @@ module GHC.Conc
        , asyncReadBA   -- :: Int -> Int -> Int -> Int -> MutableByteArray# RealWorld -> IO (Int, Int)
        , asyncWriteBA  -- :: Int -> Int -> Int -> Int -> MutableByteArray# RealWorld -> IO (Int, Int)
 #endif
+
+#ifndef mingw32_HOST_OS
+       , ensureIOManagerIsRunning
+#endif
         ) where
 
 import System.Posix.Types
@@ -71,6 +83,10 @@ import System.Posix.Internals
 import Foreign
 import Foreign.C
 
+#ifndef __HADDOCK__
+import {-# SOURCE #-} GHC.TopHandler ( reportError, reportStackOverflow )
+#endif
+
 import Data.Maybe
 
 import GHC.Base
@@ -78,7 +94,7 @@ import GHC.IOBase
 import GHC.Num         ( Num(..) )
 import GHC.Real                ( fromIntegral, quot )
 import GHC.Base                ( Int(..) )
-import GHC.Exception    ( Exception(..), AsyncException(..) )
+import GHC.Exception    ( catchException, Exception(..), AsyncException(..) )
 import GHC.Pack                ( packCString# )
 import GHC.Ptr          ( Ptr(..), plusPtr, FunPtr(..) )
 import GHC.STRef
@@ -116,7 +132,34 @@ This misfeature will hopefully be corrected at a later date.
 it defines 'ThreadId' as a synonym for ().
 -}
 
---forkIO has now been hoisted out into the Concurrent library.
+{- |
+This sparks off a new thread to run the 'IO' computation passed as the
+first argument, and returns the 'ThreadId' of the newly created
+thread.
+
+The new thread will be a lightweight thread; if you want to use a foreign
+library that uses thread-local storage, use 'forkOS' instead.
+-}
+forkIO :: IO () -> IO ThreadId
+forkIO action = IO $ \ s -> 
+   case (fork# action_plus s) of (# s1, id #) -> (# s1, ThreadId id #)
+ where
+  action_plus = catchException action childHandler
+
+childHandler :: Exception -> IO ()
+childHandler err = catchException (real_handler err) childHandler
+
+real_handler :: Exception -> IO ()
+real_handler ex =
+  case ex of
+       -- ignore thread GC and killThread exceptions:
+       BlockedOnDeadMVar            -> return ()
+       BlockedIndefinitely          -> return ()
+       AsyncException ThreadKilled  -> return ()
+
+       -- report all others:
+       AsyncException StackOverflow -> reportStackOverflow
+       other       -> reportError other
 
 {- | 'killThread' terminates the given thread (GHC only).
 Any work already done by the thread isn\'t
@@ -138,7 +181,13 @@ target thread.  The calling thread can thus be certain that the target
 thread has received the exception.  This is a useful property to know
 when dealing with race conditions: eg. if there are two threads that
 can kill each other, it is guaranteed that only one of the threads
-will get to kill the other. -}
+will get to kill the other.
+
+If the target thread is currently making a foreign call, then the
+exception will not be raised (and hence 'throwTo' will not return)
+until the call has completed.  This is the case regardless of whether
+the call is inside a 'block' or not.
+ -}
 throwTo :: ThreadId -> Exception -> IO ()
 throwTo (ThreadId id) ex = IO $ \ s ->
    case (killThread# id ex s) of s1 -> (# s1, () #)
@@ -273,6 +322,15 @@ newTVar val = STM $ \s1# ->
     case newTVar# val s1# of
         (# s2#, tvar# #) -> (# s2#, TVar tvar# #)
 
+-- |@IO@ version of 'newTVar'.  This is useful for creating top-level
+-- 'TVar's using 'System.IO.Unsafe.unsafePerformIO', because using
+-- 'atomically' inside 'System.IO.Unsafe.unsafePerformIO' isn't
+-- possible.
+newTVarIO :: a -> IO (TVar a)
+newTVarIO val = IO $ \s1# ->
+    case newTVar# val s1# of
+        (# s2#, tvar# #) -> (# s2#, TVar tvar# #)
+
 -- |Return the current value stored in a TVar
 readTVar :: TVar a -> STM a
 readTVar (TVar tvar#) = STM $ \s# -> readTVar# tvar# s#
@@ -319,16 +377,34 @@ newMVar value =
 -- empty, 'takeMVar' will wait until it is full.  After a 'takeMVar', 
 -- the 'MVar' is left empty.
 -- 
--- If several threads are competing to take the same 'MVar', one is chosen
--- to continue at random when the 'MVar' becomes full.
+-- There are two further important properties of 'takeMVar':
+--
+--   * 'takeMVar' is single-wakeup.  That is, if there are multiple
+--     threads blocked in 'takeMVar', and the 'MVar' becomes full,
+--     only one thread will be woken up.  The runtime guarantees that
+--     the woken thread completes its 'takeMVar' operation.
+--
+--   * When multiple threads are blocked on an 'MVar', they are
+--     woken up in FIFO order.  This is useful for providing
+--     fairness properties of abstractions built using 'MVar's.
+--
 takeMVar :: MVar a -> IO a
 takeMVar (MVar mvar#) = IO $ \ s# -> takeMVar# mvar# s#
 
 -- |Put a value into an 'MVar'.  If the 'MVar' is currently full,
 -- 'putMVar' will wait until it becomes empty.
 --
--- If several threads are competing to fill the same 'MVar', one is
--- chosen to continue at random when the 'MVar' becomes empty.
+-- There are two further important properties of 'putMVar':
+--
+--   * 'putMVar' is single-wakeup.  That is, if there are multiple
+--     threads blocked in 'putMVar', and the 'MVar' becomes empty,
+--     only one thread will be woken up.  The runtime guarantees that
+--     the woken thread completes its 'putMVar' operation.
+--
+--   * When multiple threads are blocked on an 'MVar', they are
+--     woken up in FIFO order.  This is useful for providing
+--     fairness properties of abstractions built using 'MVar's.
+--
 putMVar  :: MVar a -> a -> IO ()
 putMVar (MVar mvar#) x = IO $ \ s# ->
     case putMVar# mvar# x s# of
@@ -466,9 +542,17 @@ threadDelay time
        case delay# time# s of { s -> (# s, () #)
        }}
 
+registerDelay usecs 
+#ifndef mingw32_HOST_OS
+  | threaded = waitForDelayEventSTM usecs
+  | otherwise = error "registerDelay: requires -threaded"
+#else
+  = error "registerDelay: not currently supported on Windows"
+#endif
+
 -- On Windows, we just make a safe call to 'Sleep' to implement threadDelay.
 #ifdef mingw32_HOST_OS
-foreign import ccall safe "Sleep" c_Sleep :: CInt -> IO ()
+foreign import stdcall safe "Sleep" c_Sleep :: CInt -> IO ()
 #endif
 
 foreign import ccall unsafe "rtsSupportsBoundThreads" threaded :: Bool
@@ -512,7 +596,8 @@ data IOReq
   | Write  {-# UNPACK #-} !Fd {-# UNPACK #-} !(MVar ())
 
 data DelayReq
-  = Delay  {-# UNPACK #-} !Int {-# UNPACK #-} !(MVar ())
+  = Delay    {-# UNPACK #-} !Int {-# UNPACK #-} !(MVar ())
+  | DelaySTM {-# UNPACK #-} !Int {-# UNPACK #-} !(TVar Bool)
 
 pendingEvents :: IORef [IOReq]
 pendingDelays :: IORef [DelayReq]
@@ -520,31 +605,33 @@ pendingDelays :: IORef [DelayReq]
 {-# NOINLINE pendingEvents #-}
 {-# NOINLINE pendingDelays #-}
 (pendingEvents,pendingDelays) = unsafePerformIO $ do
-  startIOServiceThread
+  startIOManagerThread
   reqs <- newIORef []
   dels <- newIORef []
   return (reqs, dels)
        -- the first time we schedule an IO request, the service thread
        -- will be created (cool, huh?)
 
-startIOServiceThread :: IO ()
-startIOServiceThread = do
+ensureIOManagerIsRunning :: IO ()
+ensureIOManagerIsRunning 
+  | threaded  = seq pendingEvents $ return ()
+  | otherwise = return ()
+
+startIOManagerThread :: IO ()
+startIOManagerThread = do
         allocaArray 2 $ \fds -> do
-       throwErrnoIfMinus1 "startIOServiceThread" (c_pipe fds)
+       throwErrnoIfMinus1 "startIOManagerThread" (c_pipe fds)
        rd_end <- peekElemOff fds 0
        wr_end <- peekElemOff fds 1
        writeIORef stick (fromIntegral wr_end)
-       quickForkIO $ do
+       c_setIOManagerPipe wr_end
+       forkIO $ do
            allocaBytes sizeofFdSet   $ \readfds -> do
            allocaBytes sizeofFdSet   $ \writefds -> do 
            allocaBytes sizeofTimeVal $ \timeval -> do
            service_loop (fromIntegral rd_end) readfds writefds timeval [] []
        return ()
 
--- XXX: move real forkIO here from Control.Concurrent?
-quickForkIO action = IO $ \s ->
-   case (fork# action s) of (# s1, id #) -> (# s1, ThreadId id #)
-
 service_loop
    :: Fd               -- listen to this for wakeup calls
    -> Ptr CFdSet
@@ -569,29 +656,42 @@ service_loop wakeup readfds writefds ptimeval old_reqs old_delays = do
   fdSet wakeup readfds
   maxfd <- buildFdSets 0 readfds writefds reqs
 
-  -- check the current time and wake up any thread in threadDelay whose
-  -- timeout has expired.  Also find the timeout value for the select() call.
-  now <- getTicksOfDay
-  (delays', timeout) <- getDelay now ptimeval delays
-
   -- perform the select()
-  let do_select = do
+  let do_select delays = do
+         -- check the current time and wake up any thread in
+         -- threadDelay whose timeout has expired.  Also find the
+         -- timeout value for the select() call.
+         now <- getTicksOfDay
+         (delays', timeout) <- getDelay now ptimeval delays
+
          res <- c_select ((max wakeup maxfd)+1) readfds writefds 
                        nullPtr timeout
          if (res == -1)
             then do
                err <- getErrno
                if err == eINTR
-                       then do_select
-                       else return res
+                       then do_select delays'
+                       else return (res,delays')
             else
-               return res
-  res <- do_select
+               return (res,delays')
+
+  (res,delays') <- do_select delays
   -- ToDo: check result
 
-  b <- takeMVar prodding
-  if b then alloca $ \p -> do c_read (fromIntegral wakeup) p 1; return ()
-       else return ()
+  b <- fdIsSet wakeup readfds
+  if b == 0 
+    then return ()
+    else alloca $ \p -> do 
+           c_read (fromIntegral wakeup) p 1; return ()
+           s <- peek p         
+           if (s == 0xff) 
+             then return ()
+             else do handler_tbl <- peek handlers
+                     sp <- peekElemOff handler_tbl (fromIntegral s)
+                     forkIO (do io <- deRefStablePtr sp; io)
+                     return ()
+
+  takeMVar prodding
   putMVar prodding False
 
   reqs' <- completeRequests reqs readfds writefds []
@@ -610,20 +710,29 @@ prodServiceThread = do
   b <- takeMVar prodding
   if (not b) 
     then do fd <- readIORef stick
-           with 42 $ \pbuf -> do c_write (fromIntegral fd) pbuf 1; return ()
+           with 0xff $ \pbuf -> do c_write (fromIntegral fd) pbuf 1; return ()
     else return ()
   putMVar prodding True
 
+foreign import ccall "&signal_handlers" handlers :: Ptr (Ptr (StablePtr (IO ())))
+
+foreign import ccall "setIOManagerPipe"
+  c_setIOManagerPipe :: CInt -> IO ()
+
 -- -----------------------------------------------------------------------------
 -- IO requests
 
 buildFdSets maxfd readfds writefds [] = return maxfd
-buildFdSets maxfd readfds writefds (Read fd m : reqs) = do
-  fdSet fd readfds
-  buildFdSets (max maxfd fd) readfds writefds reqs
-buildFdSets maxfd readfds writefds (Write fd m : reqs) = do
-  fdSet fd writefds
-  buildFdSets (max maxfd fd) readfds writefds reqs
+buildFdSets maxfd readfds writefds (Read fd m : reqs)
+  | fd >= fD_SETSIZE =  error "buildFdSets: file descriptor out of range"
+  | otherwise        =  do
+       fdSet fd readfds
+        buildFdSets (max maxfd fd) readfds writefds reqs
+buildFdSets maxfd readfds writefds (Write fd m : reqs)
+  | fd >= fD_SETSIZE =  error "buildFdSets: file descriptor out of range"
+  | otherwise        =  do
+       fdSet fd writefds
+       buildFdSets (max maxfd fd) readfds writefds reqs
 
 completeRequests [] _ _ reqs' = return reqs'
 completeRequests (Read fd m : reqs) readfds writefds reqs' = do
@@ -667,24 +776,41 @@ waitForDelayEvent usecs = do
   prodServiceThread
   takeMVar m
 
+-- Delays for use in STM
+waitForDelayEventSTM :: Int -> IO (TVar Bool)
+waitForDelayEventSTM usecs = do
+   t <- atomically $ newTVar False
+   now <- getTicksOfDay
+   let target = now + usecs `quot` tick_usecs
+   atomicModifyIORef pendingDelays (\xs -> (DelaySTM target t : xs, ()))
+   prodServiceThread
+   return t  
+    
 -- Walk the queue of pending delays, waking up any that have passed
 -- and return the smallest delay to wait for.  The queue of pending
 -- delays is kept ordered.
 getDelay :: Ticks -> Ptr CTimeVal -> [DelayReq] -> IO ([DelayReq], Ptr CTimeVal)
 getDelay now ptimeval [] = return ([],nullPtr)
-getDelay now ptimeval all@(Delay time m : rest)
-  | now >= time = do
+getDelay now ptimeval all@(d : rest) 
+  = case d of
+     Delay time m | now >= time -> do
        putMVar m ()
        getDelay now ptimeval rest
-  | otherwise = do
-       setTimevalTicks ptimeval (time - now)
+     DelaySTM time t | now >= time -> do
+       atomically $ writeTVar t True
+       getDelay now ptimeval rest
+     _otherwise -> do
+       setTimevalTicks ptimeval (delayTime d - now)
        return (all,ptimeval)
 
 insertDelay :: DelayReq -> [DelayReq] -> [DelayReq]
-insertDelay d@(Delay time m) [] = [d]
-insertDelay d1@(Delay time m) ds@(d2@(Delay time' m') : rest)
-  | time <= time' = d1 : ds
-  | otherwise     = d2 : insertDelay d1 rest
+insertDelay d [] = [d]
+insertDelay d1 ds@(d2 : rest)
+  | delayTime d1 <= delayTime d2 = d1 : ds
+  | otherwise                    = d2 : insertDelay d1 rest
+
+delayTime (Delay t _) = t
+delayTime (DelaySTM t _) = t
 
 type Ticks = Int
 tick_freq  = 50 :: Ticks  -- accuracy of threadDelay (ticks per sec)
@@ -712,6 +838,9 @@ foreign import ccall safe "select"
   c_select :: Fd -> Ptr CFdSet -> Ptr CFdSet -> Ptr CFdSet -> Ptr CTimeVal
            -> IO CInt
 
+foreign import ccall unsafe "hsFD_SETSIZE"
+  fD_SETSIZE :: Fd
+
 foreign import ccall unsafe "hsFD_CLR"
   fdClr :: Fd -> Ptr CFdSet -> IO ()