Fix #4514 - IO manager deadlock
[ghc-base.git] / Control / Concurrent.hs
index 2d9cf57..b49f7db 100644 (file)
@@ -28,6 +28,7 @@ module Control.Concurrent (
 
         forkIO,
 #ifdef __GLASGOW_HASKELL__
+        forkIOUnmasked,
         killThread,
         throwTo,
 #endif
@@ -46,6 +47,7 @@ module Control.Concurrent (
         threadDelay,            -- :: Int -> IO ()
         threadWaitRead,         -- :: Int -> IO ()
         threadWaitWrite,        -- :: Int -> IO ()
+        closeFd,                -- :: (Int -> IO ()) -> Int -> IO ()
 #endif
 
         -- * Communication abstractions
@@ -98,9 +100,9 @@ import Control.Exception.Base as Exception
 #ifdef __GLASGOW_HASKELL__
 import GHC.Exception
 import GHC.Conc         ( ThreadId(..), myThreadId, killThread, yield,
-                          threadDelay, forkIO, childHandler )
+                          threadDelay, forkIO, forkIOUnmasked, childHandler )
 import qualified GHC.Conc
-import GHC.IO           ( IO(..), unsafeInterleaveIO )
+import GHC.IO           ( IO(..), unsafeInterleaveIO, unsafeUnmask )
 import GHC.IORef        ( newIORef, readIORef, writeIORef )
 import GHC.Base
 
@@ -357,13 +359,15 @@ failNonThreaded = fail $ "RTS doesn't support multiple OS threads "
 forkOS action0
     | rtsSupportsBoundThreads = do
         mv <- newEmptyMVar
-        b <- Exception.blocked
+        b <- Exception.getMaskingState
         let
-            -- async exceptions are blocked in the child if they are blocked
+            -- async exceptions are masked in the child if they are masked
             -- in the parent, as for forkIO (see #1048). forkOS_createThread
-            -- creates a thread with exceptions blocked by default.
-            action1 | b = action0
-                    | otherwise = unblock action0
+            -- creates a thread with exceptions masked by default.
+            action1 = case b of
+                        Unmasked -> unsafeUnmask action0
+                        MaskedInterruptible -> action0
+                        MaskedUninterruptible -> uninterruptibleMask_ action0
 
             action_plus = Exception.catch action1 childHandler
 
@@ -403,13 +407,10 @@ runInBoundThread action
             else do
                 ref <- newIORef undefined
                 let action_plus = Exception.try action >>= writeIORef ref
-                resultOrException <-
-                    bracket (newStablePtr action_plus)
-                            freeStablePtr
-                            (\cEntry -> forkOS_entry_reimported cEntry >> readIORef ref)
-                case resultOrException of
-                    Left exception -> Exception.throw (exception :: SomeException)
-                    Right result -> return result
+                bracket (newStablePtr action_plus)
+                        freeStablePtr
+                        (\cEntry -> forkOS_entry_reimported cEntry >> readIORef ref) >>=
+                  unsafeResult
     | otherwise = failNonThreaded
 
 {- | 
@@ -422,20 +423,27 @@ performance loss due to the use of bound threads. A program that
 doesn't need it's main thread to be bound and makes /heavy/ use of concurrency
 (e.g. a web server), might want to wrap it's @main@ action in
 @runInUnboundThread@.
+
+Note that exceptions which are thrown to the current thread are thrown in turn
+to the thread that is executing the given computation. This ensures there's
+always a way of killing the forked thread.
 -}
 runInUnboundThread :: IO a -> IO a
 
 runInUnboundThread action = do
-    bound <- isCurrentThreadBound
-    if bound
-        then do
-            mv <- newEmptyMVar
-            forkIO (Exception.try action >>= putMVar mv)
-            takeMVar mv >>= \ei -> case ei of
-                Left exception -> Exception.throw (exception :: SomeException)
-                Right result -> return result
-        else action
-
+  bound <- isCurrentThreadBound
+  if bound
+    then do
+      mv <- newEmptyMVar
+      mask $ \restore -> do
+        tid <- forkIO $ Exception.try (restore action) >>= putMVar mv
+        let wait = takeMVar mv `Exception.catch` \(e :: SomeException) ->
+                     Exception.throwTo tid e >> wait
+        wait >>= unsafeResult
+    else action
+
+unsafeResult :: Either SomeException a -> IO a
+unsafeResult = either Exception.throwIO return
 #endif /* __GLASGOW_HASKELL__ */
 
 #ifdef __GLASGOW_HASKELL__
@@ -444,6 +452,9 @@ runInUnboundThread action = do
 
 -- | Block the current thread until data is available to read on the
 -- given file descriptor (GHC only).
+--
+-- This will throw an 'IOError' if the file descriptor was closed
+-- while this thread was blocked.
 threadWaitRead :: Fd -> IO ()
 threadWaitRead fd
 #ifdef mingw32_HOST_OS
@@ -453,7 +464,8 @@ threadWaitRead fd
   -- and this only works with -threaded.
   | threaded  = withThread (waitFd fd 0)
   | otherwise = case fd of
-                  0 -> do hWaitForInput stdin (-1); return ()
+                  0 -> do _ <- hWaitForInput stdin (-1)
+                          return ()
                         -- hWaitForInput does work properly, but we can only
                         -- do this for stdin since we know its FD.
                   _ -> error "threadWaitRead requires -threaded on Windows, or use System.IO.hWaitForInput"
@@ -463,6 +475,9 @@ threadWaitRead fd
 
 -- | Block the current thread until data can be written to the
 -- given file descriptor (GHC only).
+--
+-- This will throw an 'IOError' if the file descriptor was closed
+-- while this thread was blocked.
 threadWaitWrite :: Fd -> IO ()
 threadWaitWrite fd
 #ifdef mingw32_HOST_OS
@@ -472,13 +487,31 @@ threadWaitWrite fd
   = GHC.Conc.threadWaitWrite fd
 #endif
 
+-- | Close a file descriptor in a concurrency-safe way (GHC only).  If
+-- you are using 'threadWaitRead' or 'threadWaitWrite' to perform
+-- blocking I\/O, you /must/ use this function to close file
+-- descriptors, or blocked threads may not be woken.
+--
+-- Any threads that are blocked on the file descriptor via
+-- 'threadWaitRead' or 'threadWaitWrite' will be unblocked by having
+-- IO exceptions thrown.
+closeFd :: (Fd -> IO ())        -- ^ Low-level action that performs the real close.
+        -> Fd                   -- ^ File descriptor to close.
+        -> IO ()
+closeFd close fd
+#ifdef mingw32_HOST_OS
+  = close fd
+#else
+  = GHC.Conc.closeFd close fd
+#endif
+
 #ifdef mingw32_HOST_OS
 foreign import ccall unsafe "rtsSupportsBoundThreads" threaded :: Bool
 
 withThread :: IO a -> IO a
 withThread io = do
   m <- newEmptyMVar
-  forkIO $ try io >>= putMVar m
+  _ <- mask_ $ forkIO $ try io >>= putMVar m
   x <- takeMVar m
   case x of
     Right a -> return a
@@ -486,9 +519,8 @@ withThread io = do
 
 waitFd :: Fd -> CInt -> IO ()
 waitFd fd write = do
-   throwErrnoIfMinus1 "fdReady" $
-        fdReady (fromIntegral fd) write (fromIntegral iNFINITE) 0
-   return ()
+   throwErrnoIfMinus1_ "fdReady" $
+        fdReady (fromIntegral fd) write iNFINITE 0
 
 iNFINITE :: CInt
 iNFINITE = 0xFFFFFFFF -- urgh