aee36a240e8d3ac38a18c093d8d7399211367830
[ghc-base.git] / Control / Monad.hs
1 {-# OPTIONS -fno-implicit-prelude #-}
2 -----------------------------------------------------------------------------
3 -- |
4 -- Module      :  Control.Monad
5 -- Copyright   :  (c) The University of Glasgow 2001
6 -- License     :  BSD-style (see the file libraries/core/LICENSE)
7 -- 
8 -- Maintainer  :  libraries@haskell.org
9 -- Stability   :  provisional
10 -- Portability :  portable
11 --
12 -----------------------------------------------------------------------------
13
14 module Control.Monad
15     ( MonadPlus (   -- class context: Monad
16           mzero     -- :: (MonadPlus m) => m a
17         , mplus     -- :: (MonadPlus m) => m a -> m a -> m a
18         )
19     , join          -- :: (Monad m) => m (m a) -> m a
20     , guard         -- :: (MonadPlus m) => Bool -> m ()
21     , when          -- :: (Monad m) => Bool -> m () -> m ()
22     , unless        -- :: (Monad m) => Bool -> m () -> m ()
23     , ap            -- :: (Monad m) => m (a -> b) -> m a -> m b
24     , msum          -- :: (MonadPlus m) => [m a] -> m a
25     , filterM       -- :: (Monad m) => (a -> m Bool) -> [a] -> m [a]
26     , mapAndUnzipM  -- :: (Monad m) => (a -> m (b,c)) -> [a] -> m ([b], [c])
27     , zipWithM      -- :: (Monad m) => (a -> b -> m c) -> [a] -> [b] -> m [c]
28     , zipWithM_     -- :: (Monad m) => (a -> b -> m c) -> [a] -> [b] -> m ()
29     , foldM         -- :: (Monad m) => (a -> b -> m a) -> a -> [b] -> m a 
30     
31     , liftM         -- :: (Monad m) => (a -> b) -> (m a -> m b)
32     , liftM2        -- :: (Monad m) => (a -> b -> c) -> (m a -> m b -> m c)
33     , liftM3        -- :: ...
34     , liftM4        -- :: ...
35     , liftM5        -- :: ...
36
37     , Monad((>>=), (>>), return, fail)
38     , Functor(fmap)
39
40     , mapM          -- :: (Monad m) => (a -> m b) -> [a] -> m [b]
41     , mapM_         -- :: (Monad m) => (a -> m b) -> [a] -> m ()
42     , sequence      -- :: (Monad m) => [m a] -> m [a]
43     , sequence_     -- :: (Monad m) => [m a] -> m ()
44     , (=<<)         -- :: (Monad m) => (a -> m b) -> m a -> m b
45     ) where
46
47 import Data.Maybe
48
49 #ifdef __GLASGOW_HASKELL__
50 import GHC.List
51 import GHC.Base
52 #endif
53
54 infixr 1 =<<
55
56 -- -----------------------------------------------------------------------------
57 -- Prelude monad functions
58
59 {-# SPECIALISE (=<<) :: (a -> [b]) -> [a] -> [b] #-}
60 (=<<)           :: Monad m => (a -> m b) -> m a -> m b
61 f =<< x         = x >>= f
62
63 sequence       :: Monad m => [m a] -> m [a] 
64 {-# INLINE sequence #-}
65 sequence ms = foldr k (return []) ms
66             where
67               k m m' = do { x <- m; xs <- m'; return (x:xs) }
68
69 sequence_        :: Monad m => [m a] -> m () 
70 {-# INLINE sequence_ #-}
71 sequence_ ms     =  foldr (>>) (return ()) ms
72
73 mapM            :: Monad m => (a -> m b) -> [a] -> m [b]
74 {-# INLINE mapM #-}
75 mapM f as       =  sequence (map f as)
76
77 mapM_           :: Monad m => (a -> m b) -> [a] -> m ()
78 {-# INLINE mapM_ #-}
79 mapM_ f as      =  sequence_ (map f as)
80
81 -- -----------------------------------------------------------------------------
82 -- Monadic classes: MonadPlus
83
84 class Monad m => MonadPlus m where
85    mzero :: m a
86    mplus :: m a -> m a -> m a
87
88 instance MonadPlus [] where
89    mzero = []
90    mplus = (++)
91
92 instance MonadPlus Maybe where
93    mzero = Nothing
94
95    Nothing `mplus` ys  = ys
96    xs      `mplus` _ys = xs
97
98 -- -----------------------------------------------------------------------------
99 -- Functions mandated by the Prelude
100
101 guard           :: (MonadPlus m) => Bool -> m ()
102 guard True      =  return ()
103 guard False     =  mzero
104
105 -- This subsumes the list-based filter function.
106
107 filterM          :: (Monad m) => (a -> m Bool) -> [a] -> m [a]
108 filterM _ []     =  return []
109 filterM p (x:xs) =  do
110    flg <- p x
111    ys  <- filterM p xs
112    return (if flg then x:ys else ys)
113
114 -- This subsumes the list-based concat function.
115
116 msum        :: MonadPlus m => [m a] -> m a
117 {-# INLINE msum #-}
118 msum        =  foldr mplus mzero
119
120 -- -----------------------------------------------------------------------------
121 -- Other monad functions
122
123 join              :: (Monad m) => m (m a) -> m a
124 join x            =  x >>= id
125
126 mapAndUnzipM      :: (Monad m) => (a -> m (b,c)) -> [a] -> m ([b], [c])
127 mapAndUnzipM f xs =  sequence (map f xs) >>= return . unzip
128
129 zipWithM          :: (Monad m) => (a -> b -> m c) -> [a] -> [b] -> m [c]
130 zipWithM f xs ys  =  sequence (zipWith f xs ys)
131
132 zipWithM_         :: (Monad m) => (a -> b -> m c) -> [a] -> [b] -> m ()
133 zipWithM_ f xs ys =  sequence_ (zipWith f xs ys)
134
135 foldM             :: (Monad m) => (a -> b -> m a) -> a -> [b] -> m a
136 foldM _ a []      =  return a
137 foldM f a (x:xs)  =  f a x >>= \fax -> foldM f fax xs
138
139 unless            :: (Monad m) => Bool -> m () -> m ()
140 unless p s        =  if p then return () else s
141
142 when              :: (Monad m) => Bool -> m () -> m ()
143 when p s          =  if p then s else return ()
144
145 ap                :: (Monad m) => m (a -> b) -> m a -> m b
146 ap                =  liftM2 id
147
148 liftM   :: (Monad m) => (a1 -> r) -> m a1 -> m r
149 liftM2  :: (Monad m) => (a1 -> a2 -> r) -> m a1 -> m a2 -> m r
150 liftM3  :: (Monad m) => (a1 -> a2 -> a3 -> r) -> m a1 -> m a2 -> m a3 -> m r
151 liftM4  :: (Monad m) => (a1 -> a2 -> a3 -> a4 -> r) -> m a1 -> m a2 -> m a3 -> m a4 -> m r
152 liftM5  :: (Monad m) => (a1 -> a2 -> a3 -> a4 -> a5 -> r) -> m a1 -> m a2 -> m a3 -> m a4 -> m a5 -> m r
153
154 liftM f m1              = do { x1 <- m1; return (f x1) }
155 liftM2 f m1 m2          = do { x1 <- m1; x2 <- m2; return (f x1 x2) }
156 liftM3 f m1 m2 m3       = do { x1 <- m1; x2 <- m2; x3 <- m3; return (f x1 x2 x3) }
157 liftM4 f m1 m2 m3 m4    = do { x1 <- m1; x2 <- m2; x3 <- m3; x4 <- m4; return (f x1 x2 x3 x4) }
158 liftM5 f m1 m2 m3 m4 m5 = do { x1 <- m1; x2 <- m2; x3 <- m3; x4 <- m4; x5 <- m5; return (f x1 x2 x3 x4 x5) }