[project @ 2003-06-03 22:26:44 by diatchki]
[ghc-base.git] / Control / Monad / X / StateT.hs
1 -----------------------------------------------------------------------------
2 -- |
3 -- Module      :  Control.Monad.State
4 -- Copyright   :  (c) Andy Gill 2001,
5 --                (c) Oregon Graduate Institute of Science and Technology, 2001
6 -- License     :  BSD-style (see the file libraries/base/LICENSE)
7 -- 
8 -- Maintainer  :  libraries@haskell.org
9 -- Stability   :  experimental
10 -- Portability :  non-portable (multi-param classes, functional dependencies)
11 --
12 -- State monads.
13 --
14 --        This module is inspired by the paper
15 --        /Functional Programming with Overloading and
16 --            Higher-Order Polymorphism/, 
17 --          Mark P Jones (<http://www.cse.ogi.edu/~mpj/>)
18 --                Advanced School of Functional Programming, 1995.
19 --
20 -- See below for examples.
21
22 -----------------------------------------------------------------------------
23
24 module Control.Monad.X.StateT (
25         StateT,
26         runState,
27         runStateS,
28         runStateT,
29         evalStateT,
30         execStateT,
31         mapStateT,
32         withStateT,
33         module T
34   ) where
35
36 import Prelude (Functor(..),Monad(..),(.),fst)
37
38 import Control.Monad
39 import Control.Monad.X.Trans as T
40 import Control.Monad.X.Utils
41 import Control.Monad.X.Types(StateT(..))
42
43 instance MonadTrans (StateT s) where
44   lift m    = S (\s -> liftM (\a -> (a,s)) m)
45
46 instance HasBaseMonad m n => HasBaseMonad (StateT s m) n where
47   inBase    = inBase'
48
49 instance (Monad m) => Functor (StateT s m) where
50   fmap      = liftM
51
52 instance (Monad m) => Monad (StateT s m) where
53   return    = return'
54   m >>= k   = S (\s -> do (a, s') <- m $$ s
55                           k a $$ s')
56   fail      = fail'
57
58
59 runState      :: Monad m => s -> StateT s m a -> m a
60 runState s m  = liftM fst (runStateS s m)
61
62 runStateS     :: s -> StateT s m a -> m (a,s)
63 runStateS s m = m $$ s
64
65
66 runStateT   :: StateT s m a -> s -> m (a,s)
67 runStateT   = ($$)
68
69 evalStateT :: (Monad m) => StateT s m a -> s -> m a
70 evalStateT m s = do
71         (a, _) <- m $$ s
72         return a
73
74 execStateT :: (Monad m) => StateT s m a -> s -> m s
75 execStateT m s = do
76         (_, s') <- m $$ s
77         return s'
78
79 mapStateT :: (m (a, s) -> n (b, s)) -> StateT s m a -> StateT s n b
80 mapStateT f m = S (f . (m $$))
81
82 withStateT :: (s -> s) -> StateT s m a -> StateT s m a
83 withStateT f m = S ((m $$) . f)
84
85 ($$)          = unS
86
87
88 instance (MonadReader r m) => MonadReader r (StateT s m) where
89   ask         = ask'
90   local       = local' mapStateT
91
92 instance (MonadWriter w m) => MonadWriter w (StateT s m) where
93   tell        = tell'
94   listen      = listen2' S unS (\w (a,s) -> ((a,w),s)) 
95
96 instance (Monad m) => MonadState s (StateT s m) where
97   get         = S (\s -> return (s, s))
98   put s       = S (\_ -> return ((), s))
99
100 instance (MonadError e m) => MonadError e (StateT s m) where
101   throwError  = throwError'
102   catchError  = catchError2' S ($$)
103
104 instance (MonadPlus m) => MonadPlus (StateT s m) where
105   mzero       = mzero'
106   mplus       = mplus2' S ($$)
107
108 -- 'findAll' does not affect the state
109 -- if interested in the state as well as the result, use 
110 -- `get` before `findAll`.
111 -- e.g. findAllSt m = findAll (do x <- m; y <- get; reutrn (x,y))
112 instance MonadNondet m => MonadNondet (StateT s m) where
113   findAll m   = S (\s -> liftM (\xs -> (fmap fst xs,s)) (findAll (m $$ s)))
114   commit      = mapStateT commit
115
116 instance MonadResume m => MonadResume (StateT s m) where
117   delay       = mapStateT delay
118   force       = mapStateT force
119
120 -- jumping undoes changes to the state state
121 instance MonadCont m => MonadCont (StateT s m) where
122   callCC      = callCC2' S unS (\a s -> (a,s))
123
124