[project @ 1996-01-08 20:28:12 by partain]
[ghc-hetmet.git] / ghc / compiler / simplCore / SATMonad.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1995
3 %
4 %************************************************************************
5 %*                                                                      *
6 \section[SATMonad]{The Static Argument Transformation pass Monad}
7 %*                                                                      *
8 %************************************************************************
9
10 \begin{code}
11 #include "HsVersions.h"
12
13 module SATMonad (
14         SATInfo(..), updSAEnv,
15         SatM(..), initSAT, emptyEnvSAT,
16         returnSAT, thenSAT, thenSAT_, mapSAT, getSATInfo, newSATName,
17         getArgLists, Arg(..), insSAEnv, saTransform,
18
19         SATEnv(..), isStatic, dropStatics,
20
21         Id, UniType, SplitUniqSupply, PlainCoreExpr(..)
22     ) where
23
24 import AbsUniType       ( mkTyVarTy, mkSigmaTy, TyVarTemplate,
25                           extractTyVarsFromTy, splitType, splitTyArgs,
26                           glueTyArgs, instantiateTy, TauType(..),
27                           Class, ThetaType(..), SigmaType(..),
28                           InstTyEnv(..)
29                         )
30 import IdEnv
31 import Id               ( mkSysLocal, getIdUniType )
32 import Maybes           ( Maybe(..) )
33 import PlainCore
34 import SrcLoc           ( SrcLoc, mkUnknownSrcLoc )
35 import SplitUniq
36 import Unique
37 import Util
38
39 infixr 9 `thenSAT`, `thenSAT_`
40 \end{code}
41
42 %************************************************************************
43 %*                                                                      *
44 \subsection{Static Argument Transformation Environment}
45 %*                                                                      *
46 %************************************************************************
47
48 \begin{code}
49 type SATEnv = IdEnv SATInfo
50
51 type SATInfo = ([Arg UniType],[Arg Id])
52
53 data Arg a = Static a | NotStatic
54     deriving Eq
55
56 delOneFromSAEnv v us env
57   = ((), delOneFromIdEnv env v)
58
59 updSAEnv :: Maybe (Id,SATInfo) -> SatM ()
60 updSAEnv Nothing
61   = returnSAT ()
62 updSAEnv (Just (b,(tyargs,args)))
63   = getSATInfo b      `thenSAT` (\ r ->
64     case r of
65       Nothing              -> returnSAT ()
66       Just (tyargs',args') -> delOneFromSAEnv b `thenSAT_`
67                               insSAEnv b (checkArgs tyargs tyargs',
68                                           checkArgs args args')
69     )
70
71 checkArgs as [] = notStatics (length as)
72 checkArgs [] as = notStatics (length as)
73 checkArgs (a:as) (a':as') | a == a' = a:checkArgs as as'
74 checkArgs (_:as) (_:as') = NotStatic:checkArgs as as'
75
76 notStatics :: Int -> [Arg a]
77 notStatics n = nOfThem n NotStatic
78
79 insSAEnv :: Id -> SATInfo -> SatM ()
80 insSAEnv b info us env
81   = ((), addOneToIdEnv env b info)
82 \end{code}
83
84 %************************************************************************
85 %*                                                                      *
86 \subsection{Static Argument Transformation Monad}
87 %*                                                                      *
88 %************************************************************************
89
90 Two items of state to thread around: a UniqueSupply and a SATEnv.
91
92 \begin{code}
93 type SatM result
94   =  SplitUniqSupply -> SATEnv -> (result, SATEnv)
95
96 initSAT :: SatM a -> SplitUniqSupply -> a
97
98 initSAT f us = fst (f us nullIdEnv)
99
100 thenSAT m k us env
101   = case splitUniqSupply us     of { (s1, s2) ->
102     case m s1 env               of { (m_result, menv) ->
103     k m_result s2 menv }}
104
105 thenSAT_ m k us env
106   = case splitUniqSupply us     of { (s1, s2) ->
107     case m s1 env               of { (_, menv) ->
108     k s2 menv }}
109
110 emptyEnvSAT :: SatM ()
111 emptyEnvSAT us _ = ((), nullIdEnv)
112
113 returnSAT v us env = (v, env)
114
115 mapSAT f []     = returnSAT []
116 mapSAT f (x:xs)
117   = f x         `thenSAT` \ x'  ->
118     mapSAT f xs `thenSAT` \ xs' ->
119     returnSAT (x':xs')
120 \end{code}
121
122 %************************************************************************
123 %*                                                                      *
124 \subsection{Utility Functions}
125 %*                                                                      *
126 %************************************************************************
127
128 \begin{code}
129 getSATInfo :: Id -> SatM (Maybe SATInfo)
130 getSATInfo var us env
131   = (lookupIdEnv env var, env)
132
133 newSATName :: Id -> UniType -> SatM Id
134 newSATName id ty us env
135   = case (getSUnique us) of { unique ->
136     (mkSysLocal new_str unique ty mkUnknownSrcLoc, env) }
137   where
138     new_str = getOccurrenceName id _APPEND_ SLIT("_sat")
139
140 getArgLists :: PlainCoreExpr -> ([Arg UniType],[Arg Id])
141 getArgLists expr
142   = let
143         (tvs, lambda_bounds, body) = digForLambdas expr
144     in
145     ([ Static (mkTyVarTy tv) | tv <- tvs ],
146      [ Static v              | v <- lambda_bounds ])
147
148 dropArgs :: PlainCoreExpr -> PlainCoreExpr
149 dropArgs (CoLam v e)    = dropArgs e
150 dropArgs (CoTyLam ty e) = dropArgs e
151 dropArgs e              = e
152
153 \end{code}
154
155 We implement saTransform using shadowing of binders, that is
156 we transform
157 map = \f as -> case as of
158                  [] -> []
159                  (a':as') -> let x = f a'
160                                  y = map f as'
161                              in x:y
162 to
163 map = \f as -> let map = \f as -> map' as
164                in let rec map' = \as -> case as of
165                                           [] -> []
166                                           (a':as') -> let x = f a'
167                                                           y = map f as'
168                                                       in x:y
169                   in map' as
170
171 the inner map should get inlined and eliminated.
172 \begin{code}
173 saTransform :: Id -> PlainCoreExpr -> SatM PlainCoreBinding
174 saTransform binder rhs
175   = getSATInfo binder `thenSAT` \ r ->
176     case r of
177       -- [Andre] test: do it only if we have more than one static argument.
178       --Just (tyargs,args) | any isStatic args 
179       Just (tyargs,args) | length (filter isStatic args) > 1
180         -> newSATName binder (new_ty tyargs args)  `thenSAT` \ binder' ->
181            mkNewRhs binder binder' tyargs args rhs `thenSAT` \ new_rhs ->
182            trace ("SAT "++ show (length (filter isStatic args))) (
183            returnSAT (CoNonRec binder new_rhs)
184            )
185       _ -> returnSAT (CoRec [(binder, rhs)])
186   where
187     mkNewRhs binder binder' tyargs args rhs
188       = let
189             non_static_args :: [Id]
190             non_static_args
191                = get_nsa args (snd (getArgLists rhs))
192                where
193                  get_nsa :: [Arg a] -> [Arg a] -> [a]
194                  get_nsa [] _ = []
195                  get_nsa _ [] = []
196                  get_nsa (NotStatic:args) (Static v:as) = v:get_nsa args as
197                  get_nsa (_:args)         (_:as)        =   get_nsa args as
198
199             local_body = foldl CoApp (CoVar binder')
200                                 [CoVarAtom a | a <- non_static_args]
201
202             nonrec_rhs = origLams local_body
203
204             -- HACK! The following is a fake SysLocal binder with 
205             -- *the same* unique as binder.
206             -- the reason for this is the following:
207             -- this binder *will* get inlined but if it happen to be
208             -- a top level binder it is never removed as dead code,
209             -- therefore we have to remove that information (of it being
210             -- top-level or exported somehow.
211             -- A better fix is to use binder directly but with the TopLevel
212             -- tag (or Exported tag) modified.
213             fake_binder = mkSysLocal 
214                             (getOccurrenceName binder _APPEND_ SLIT("_fsat")) 
215                             (getTheUnique binder)
216                             (getIdUniType binder) 
217                             mkUnknownSrcLoc
218             rec_body = mkCoLam non_static_args 
219                                ( CoLet (CoNonRec fake_binder nonrec_rhs)
220                                  {-in-} (dropArgs rhs))
221         in
222         returnSAT (
223             origLams (CoLet (CoRec [(binder',rec_body)]) {-in-} local_body)
224         )
225       where
226         origLams = origLams' rhs
227                  where 
228                    origLams' (CoLam v e)     e' = mkCoLam v  (origLams' e e')
229                    origLams' (CoTyLam ty e)  e' = CoTyLam ty (origLams' e e')
230                    origLams' _               e' = e'
231
232     new_ty tyargs args
233       = instantiateTy (mk_inst_tyenv tyargs tv_tmpl) 
234                       (mkSigmaTy tv_tmpl' dict_tys' tau_ty')
235       where
236         -- get type info for the local function:
237         (tv_tmpl, dict_tys, tau_ty) = (splitType . getIdUniType) binder
238         (reg_arg_tys, res_type)     = splitTyArgs tau_ty
239
240         -- now, we drop the ones that are
241         -- static, that is, the ones we will not pass to the local function
242         l            = length dict_tys
243         tv_tmpl'     = dropStatics tyargs tv_tmpl
244         dict_tys'    = dropStatics (take l args) dict_tys
245         reg_arg_tys' = dropStatics (drop l args) reg_arg_tys
246         tau_ty'      = glueTyArgs reg_arg_tys' res_type
247
248         mk_inst_tyenv []                    _ = []
249         mk_inst_tyenv (Static s:args) (t:ts)  = (t,s) : mk_inst_tyenv args ts
250         mk_inst_tyenv (_:args)      (_:ts)    = mk_inst_tyenv args ts
251
252 dropStatics [] t = t
253 dropStatics (Static _:args) (t:ts) = dropStatics args ts
254 dropStatics (_:args)        (t:ts) = t:dropStatics args ts
255
256 isStatic :: Arg a -> Bool
257 isStatic NotStatic = False
258 isStatic _         = True
259 \end{code}