3fec375c7d6713e07f52a51239f166ae78785f39
[ghc-base.git] / Control / Monad / State.hs
1 -----------------------------------------------------------------------------
2 -- |
3 -- Module      :  Control.Monad.State
4 -- Copyright   :  (c) Andy Gill 2001,
5 --                (c) Oregon Graduate Institute of Science and Technology, 2001
6 -- License     :  BSD-style (see the file libraries/base/LICENSE)
7 -- 
8 -- Maintainer  :  libraries@haskell.org
9 -- Stability   :  experimental
10 -- Portability :  non-portable (multi-param classes, functional dependencies)
11 --
12 -- State monads.
13 --
14 --        This module is inspired by the paper
15 --        /Functional Programming with Overloading and
16 --            Higher-Order Polymorphism/, 
17 --          Mark P Jones (<http://www.cse.ogi.edu/~mpj/>)
18 --                Advanced School of Functional Programming, 1995.
19 --
20 -- See below for examples.
21
22 -----------------------------------------------------------------------------
23
24 module Control.Monad.State (
25         -- * MonadState class
26         MonadState(..),
27         modify,
28         gets,
29         -- * The State Monad
30         State(..),
31         evalState,
32         execState,
33         mapState,
34         withState,
35         -- * The StateT Monad
36         StateT(..),
37         evalStateT,
38         execStateT,
39         mapStateT,
40         withStateT,
41         module Control.Monad,
42         module Control.Monad.Fix,
43         module Control.Monad.Trans,
44         -- * Examples
45         -- $examples
46   ) where
47
48 import Prelude
49
50 import Control.Monad
51 import Control.Monad.Fix
52 import Control.Monad.Trans
53 import Control.Monad.Reader
54 import Control.Monad.Writer
55
56 -- ---------------------------------------------------------------------------
57 -- | /get/ returns the state from the internals of the monad.
58 --
59 -- /put/ replaces the state inside the monad.
60
61 class (Monad m) => MonadState s m | m -> s where
62         get :: m s
63         put :: s -> m ()
64
65 -- | Monadic state transformer.
66 --
67 --      Maps an old state to a new state inside a state monad.
68 --      The old state is thrown away.
69 --
70 -- >      Main> :t modify ((+1) :: Int -> Int)
71 -- >      modify (...) :: (MonadState Int a) => a ()
72 --
73 --      This says that @modify (+1)@ acts over any
74 --      Monad that is a member of the @MonadState@ class,
75 --      with an @Int@ state.
76
77 modify :: (MonadState s m) => (s -> s) -> m ()
78 modify f = do
79         s <- get
80         put (f s)
81
82 -- | Gets specific component of the state, using a projection function
83 -- supplied.
84
85 gets :: (MonadState s m) => (s -> a) -> m a
86 gets f = do
87         s <- get
88         return (f s)
89
90 -- ---------------------------------------------------------------------------
91 -- | A parameterizable state monad where /s/ is the type of the state
92 -- to carry and /a/ is the type of the /return value/.
93
94 newtype State s a = State { runState :: s -> (a, s) }
95
96 -- The State Monad structure is parameterized over just the state.
97
98 instance Functor (State s) where
99         fmap f m = State $ \s -> let
100                 (a, s') = runState m s
101                 in (f a, s')
102
103 instance Monad (State s) where
104         return a = State $ \s -> (a, s)
105         m >>= k  = State $ \s -> let
106                 (a, s') = runState m s
107                 in runState (k a) s'
108
109 instance MonadFix (State s) where
110         mfix f = State $ \s -> let (a, s') = runState (f a) s in (a, s')
111
112 instance MonadState s (State s) where
113         get   = State $ \s -> (s, s)
114         put s = State $ \_ -> ((), s)
115
116 -- |Evaluate this state monad with the given initial state,throwing
117 -- away the final state.  Very much like @fst@ composed with
118 -- @runstate@.
119
120 evalState :: State s a -- ^The state to evaluate
121           -> s         -- ^An initial value
122           -> a         -- ^The return value of the state application
123 evalState m s = fst (runState m s)
124
125 -- |Execute this state and return the new state, throwing away the
126 -- return value.  Very much like @snd@ composed with
127 -- @runstate@.
128
129 execState :: State s a -- ^The state to evaluate
130           -> s         -- ^An initial value
131           -> s         -- ^The new state
132 execState m s = snd (runState m s)
133
134 -- |Map a stateful computation from one (return value, state) pair to
135 -- another.  For instance, to convert numberTree from a function that
136 -- returns a tree to a function that returns the sum of the numbered
137 -- tree (see the Examples section for numberTree and sumTree) you may
138 -- write:
139 --
140 -- > sumNumberedTree :: (Eq a) => Tree a -> State (Table a) Int
141 -- > sumNumberedTree = mapState (\ (t, tab) -> (sumTree t, tab))  . numberTree
142
143 mapState :: ((a, s) -> (b, s)) -> State s a -> State s b
144 mapState f m = State $ f . runState m
145
146 -- |Apply this function to this state and return the resulting state.
147 withState :: (s -> s) -> State s a -> State s a
148 withState f m = State $ runState m . f
149
150 -- ---------------------------------------------------------------------------
151 -- | A parameterizable state monad for encapsulating an inner
152 -- monad.
153 --
154 -- The StateT Monad structure is parameterized over two things:
155 --
156 --   * s - The state.
157 --
158 --   * m - The inner monad.
159 --
160 -- Here are some examples of use:
161 --
162 -- (Parser from ParseLib with Hugs)
163 --
164 -- >  type Parser a = StateT String [] a
165 -- >     ==> StateT (String -> [(a,String)])
166 --
167 -- For example, item can be written as:
168 --
169 -- >   item = do (x:xs) <- get
170 -- >          put xs
171 -- >          return x
172 -- >
173 -- >   type BoringState s a = StateT s Indentity a
174 -- >        ==> StateT (s -> Identity (a,s))
175 -- >
176 -- >   type StateWithIO s a = StateT s IO a
177 -- >        ==> StateT (s -> IO (a,s))
178 -- >
179 -- >   type StateWithErr s a = StateT s Maybe a
180 -- >        ==> StateT (s -> Maybe (a,s))
181
182 newtype StateT s m a = StateT { runStateT :: s -> m (a,s) }
183
184 instance (Monad m) => Functor (StateT s m) where
185         fmap f m = StateT $ \s -> do
186                 (x, s') <- runStateT m s
187                 return (f x, s')
188
189 instance (Monad m) => Monad (StateT s m) where
190         return a = StateT $ \s -> return (a, s)
191         m >>= k  = StateT $ \s -> do
192                 (a, s') <- runStateT m s
193                 runStateT (k a) s'
194         fail str = StateT $ \_ -> fail str
195
196 instance (MonadPlus m) => MonadPlus (StateT s m) where
197         mzero       = StateT $ \_ -> mzero
198         m `mplus` n = StateT $ \s -> runStateT m s `mplus` runStateT n s
199
200 instance (MonadFix m) => MonadFix (StateT s m) where
201         mfix f = StateT $ \s -> mfix $ \ ~(a, _) -> runStateT (f a) s
202
203 instance (Monad m) => MonadState s (StateT s m) where
204         get   = StateT $ \s -> return (s, s)
205         put s = StateT $ \_ -> return ((), s)
206
207 instance MonadTrans (StateT s) where
208         lift m = StateT $ \s -> do
209                 a <- m
210                 return (a, s)
211
212 instance (MonadIO m) => MonadIO (StateT s m) where
213         liftIO = lift . liftIO
214
215 instance (MonadReader r m) => MonadReader r (StateT s m) where
216         ask       = lift ask
217         local f m = StateT $ \s -> local f (runStateT m s)
218
219 instance (MonadWriter w m) => MonadWriter w (StateT s m) where
220         tell     = lift . tell
221         listen m = StateT $ \s -> do
222                 ((a, s'), w) <- listen (runStateT m s)
223                 return ((a, w), s')
224         pass   m = StateT $ \s -> pass $ do
225                 ((a, f), s') <- runStateT m s
226                 return ((a, s'), f)
227
228 -- |Similar to 'evalState'
229 evalStateT :: (Monad m) => StateT s m a -> s -> m a
230 evalStateT m s = do
231         (a, _) <- runStateT m s
232         return a
233
234 -- |Similar to 'execState'
235 execStateT :: (Monad m) => StateT s m a -> s -> m s
236 execStateT m s = do
237         (_, s') <- runStateT m s
238         return s'
239
240 -- |Similar to 'mapState'
241 mapStateT :: (m (a, s) -> n (b, s)) -> StateT s m a -> StateT s n b
242 mapStateT f m = StateT $ f . runStateT m
243
244 -- |Similar to 'withState'
245 withStateT :: (s -> s) -> StateT s m a -> StateT s m a
246 withStateT f m = StateT $ runStateT m . f
247
248 -- ---------------------------------------------------------------------------
249 -- MonadState instances for other monad transformers
250
251 instance (MonadState s m) => MonadState s (ReaderT r m) where
252         get = lift get
253         put = lift . put
254
255 instance (Monoid w, MonadState s m) => MonadState s (WriterT w m) where
256         get = lift get
257         put = lift . put
258
259 -- ---------------------------------------------------------------------------
260 -- $examples
261 -- A function to increment a counter.  Taken from the paper
262 -- /Generalising Monads to Arrows/, John
263 -- Hughes (<http://www.math.chalmers.se/~rjmh/>), November 1998:
264 --
265 -- > tick :: State Int Int
266 -- > tick = do n <- get
267 -- >           put (n+1)
268 -- >           return n
269 --
270 -- Add one to the given number using the state monad:
271 --
272 -- > plusOne :: Int -> Int
273 -- > plusOne n = execState tick n
274 --
275 -- A contrived addition example. Works only with positive numbers:
276 --
277 -- > plus :: Int -> Int -> Int
278 -- > plus n x = execState (sequence $ replicate n tick) x
279 --
280 -- An example from /The Craft of Functional Programming/, Simon
281 -- Thompson (<http://www.cs.kent.ac.uk/people/staff/sjt/>),
282 -- Addison-Wesley 1999: \"Given an arbitrary tree, transform it to a
283 -- tree of integers in which the original elements are replaced by
284 -- natural numbers, starting from 0.  The same element has to be
285 -- replaced by the same number at every occurrence, and when we meet
286 -- an as-yet-unvisited element we have to find a 'new' number to match
287 -- it with:\"
288 --
289 -- > data Tree a = Nil | Node a (Tree a) (Tree a) deriving (Show, Eq)
290 -- > type Table a = [a]
291 --
292 -- > numberTree :: Eq a => Tree a -> State (Table a) (Tree Int)
293 -- > numberTree Nil = return Nil
294 -- > numberTree (Node x t1 t2) 
295 -- >        =  do num <- numberNode x
296 -- >              nt1 <- numberTree t1
297 -- >              nt2 <- numberTree t2
298 -- >              return (Node num nt1 nt2)
299 -- >     where 
300 -- >     numberNode :: Eq a => a -> State (Table a) Int
301 -- >     numberNode x
302 -- >        = do table <- get
303 -- >             (newTable, newPos) <- return (nNode x table)
304 -- >             put newTable
305 -- >             return newPos
306 -- >     nNode::  (Eq a) => a -> Table a -> (Table a, Int)
307 -- >     nNode x table
308 -- >        = case (findIndexInList (== x) table) of
309 -- >          Nothing -> (table ++ [x], length table)
310 -- >          Just i  -> (table, i)
311 -- >     findIndexInList :: (a -> Bool) -> [a] -> Maybe Int
312 -- >     findIndexInList = findIndexInListHelp 0
313 -- >     findIndexInListHelp _ _ [] = Nothing
314 -- >     findIndexInListHelp count f (h:t)
315 -- >        = if (f h)
316 -- >          then Just count
317 -- >          else findIndexInListHelp (count+1) f t
318 --
319 -- numTree applies numberTree with an initial state:
320 --
321 -- > numTree :: (Eq a) => Tree a -> Tree Int
322 -- > numTree t = evalState (numberTree t) []
323 --
324 -- > testTree = Node "Zero" (Node "One" (Node "Two" Nil Nil) (Node "One" (Node "Zero" Nil Nil) Nil)) Nil
325 -- > numTree testTree => Node 0 (Node 1 (Node 2 Nil Nil) (Node 1 (Node 0 Nil Nil) Nil)) Nil
326 --
327 -- sumTree is a little helper function that does not use the State monad:
328 --
329 -- > sumTree :: (Num a) => Tree a -> a
330 -- > sumTree Nil = 0
331 -- > sumTree (Node e t1 t2) = e + (sumTree t1) + (sumTree t2)