2ef83ae73e7d743de6ed9d6d59119bc4d4d9a475
[ghc-base.git] / Control / Concurrent.hs
1 -----------------------------------------------------------------------------
2 -- |
3 -- Module      :  Control.Concurrent
4 -- Copyright   :  (c) The University of Glasgow 2001
5 -- License     :  BSD-style (see the file libraries/core/LICENSE)
6 -- 
7 -- Maintainer  :  libraries@haskell.org
8 -- Stability   :  experimental
9 -- Portability :  non-portable
10 --
11 -- A common interface to a collection of useful concurrency
12 -- abstractions.
13 --
14 -----------------------------------------------------------------------------
15
16 module Control.Concurrent (
17         module Control.Concurrent.Chan,
18         module Control.Concurrent.CVar,
19         module Control.Concurrent.MVar,
20         module Control.Concurrent.QSem,
21         module Control.Concurrent.QSemN,
22         module Control.Concurrent.SampleVar,
23
24         forkIO,                 -- :: IO () -> IO ()
25         yield,                  -- :: IO ()
26
27 #ifdef __GLASGOW_HASKELL__
28         ThreadId,
29
30         -- Forking and suchlike
31         myThreadId,             -- :: IO ThreadId
32         killThread,             -- :: ThreadId -> IO ()
33         throwTo,                -- :: ThreadId -> Exception -> IO ()
34
35         threadDelay,            -- :: Int -> IO ()
36         threadWaitRead,         -- :: Int -> IO ()
37         threadWaitWrite,        -- :: Int -> IO ()
38 #endif
39
40          -- merging of streams
41         mergeIO,                -- :: [a]   -> [a] -> IO [a]
42         nmergeIO                -- :: [[a]] -> IO [a]
43     ) where
44
45 import Prelude
46
47 import Control.Exception as Exception
48
49 #ifdef __GLASGOW_HASKELL__
50 import GHC.Conc
51 import GHC.TopHandler   ( reportStackOverflow, reportError )
52 import GHC.IOBase       ( IO(..) )
53 import GHC.IOBase       ( unsafeInterleaveIO )
54 import GHC.Base
55 #endif
56
57 #ifdef __HUGS__
58 import IOExts ( unsafeInterleaveIO )
59 import ConcBase
60 #endif
61
62 import Control.Concurrent.MVar
63 import Control.Concurrent.CVar
64 import Control.Concurrent.Chan
65 import Control.Concurrent.QSem
66 import Control.Concurrent.QSemN
67 import Control.Concurrent.SampleVar
68
69 -- Thread Ids, specifically the instances of Eq and Ord for these things.
70 -- The ThreadId type itself is defined in std/PrelConc.lhs.
71
72 -- Rather than define a new primitve, we use a little helper function
73 -- cmp_thread in the RTS.
74
75 #ifdef __GLASGOW_HASKELL__
76 foreign import ccall unsafe "cmp_thread" cmp_thread :: Addr# -> Addr# -> Int
77 -- Returns -1, 0, 1
78
79 cmpThread :: ThreadId -> ThreadId -> Ordering
80 cmpThread (ThreadId t1) (ThreadId t2) = 
81    case cmp_thread (unsafeCoerce# t1) (unsafeCoerce# t2) of
82       -1 -> LT
83       0  -> EQ
84       _  -> GT -- must be 1
85
86 instance Eq ThreadId where
87    t1 == t2 = 
88       case t1 `cmpThread` t2 of
89          EQ -> True
90          _  -> False
91
92 instance Ord ThreadId where
93    compare = cmpThread
94
95 foreign import ccall unsafe "rts_getThreadId" getThreadId :: Addr# -> Int
96
97 instance Show ThreadId where
98    showsPrec d (ThreadId t) = 
99         showString "ThreadId " . 
100         showsPrec d (getThreadId (unsafeCoerce# t))
101
102 forkIO :: IO () -> IO ThreadId
103 forkIO action = IO $ \ s -> 
104    case (fork# action_plus s) of (# s1, id #) -> (# s1, ThreadId id #)
105  where
106   action_plus = Exception.catch action childHandler
107
108 childHandler :: Exception -> IO ()
109 childHandler err = Exception.catch (real_handler err) childHandler
110
111 real_handler :: Exception -> IO ()
112 real_handler ex =
113   case ex of
114         -- ignore thread GC and killThread exceptions:
115         BlockedOnDeadMVar            -> return ()
116         AsyncException ThreadKilled  -> return ()
117
118         -- report all others:
119         AsyncException StackOverflow -> reportStackOverflow False
120         ErrorCall s -> reportError False s
121         other       -> reportError False (showsPrec 0 other "\n")
122
123 #endif /* __GLASGOW_HASKELL__ */
124
125
126 max_buff_size :: Int
127 max_buff_size = 1
128
129 mergeIO :: [a] -> [a] -> IO [a]
130 nmergeIO :: [[a]] -> IO [a]
131
132 mergeIO ls rs
133  = newEmptyMVar                >>= \ tail_node ->
134    newMVar tail_node           >>= \ tail_list ->
135    newQSem max_buff_size       >>= \ e ->
136    newMVar 2                   >>= \ branches_running ->
137    let
138     buff = (tail_list,e)
139    in
140     forkIO (suckIO branches_running buff ls) >>
141     forkIO (suckIO branches_running buff rs) >>
142     takeMVar tail_node  >>= \ val ->
143     signalQSem e        >>
144     return val
145
146 type Buffer a 
147  = (MVar (MVar [a]), QSem)
148
149 suckIO :: MVar Int -> Buffer a -> [a] -> IO ()
150
151 suckIO branches_running buff@(tail_list,e) vs
152  = case vs of
153         [] -> takeMVar branches_running >>= \ val ->
154               if val == 1 then
155                  takeMVar tail_list     >>= \ node ->
156                  putMVar node []        >>
157                  putMVar tail_list node
158               else      
159                  putMVar branches_running (val-1)
160         (x:xs) ->
161                 waitQSem e                       >>
162                 takeMVar tail_list               >>= \ node ->
163                 newEmptyMVar                     >>= \ next_node ->
164                 unsafeInterleaveIO (
165                         takeMVar next_node  >>= \ y ->
166                         signalQSem e        >>
167                         return y)                >>= \ next_node_val ->
168                 putMVar node (x:next_node_val)   >>
169                 putMVar tail_list next_node      >>
170                 suckIO branches_running buff xs
171
172 nmergeIO lss
173  = let
174     len = length lss
175    in
176     newEmptyMVar          >>= \ tail_node ->
177     newMVar tail_node     >>= \ tail_list ->
178     newQSem max_buff_size >>= \ e ->
179     newMVar len           >>= \ branches_running ->
180     let
181      buff = (tail_list,e)
182     in
183     mapIO (\ x -> forkIO (suckIO branches_running buff x)) lss >>
184     takeMVar tail_node  >>= \ val ->
185     signalQSem e        >>
186     return val
187   where
188     mapIO f xs = sequence (map f xs)