Revive the static argument transformation
[ghc-hetmet.git] / compiler / simplCore / SAT.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
3 %
4
5 %************************************************************************
6
7                 Static Argument Transformation pass
8
9 %************************************************************************
10
11 May be seen as removing invariants from loops:
12 Arguments of recursive functions that do not change in recursive
13 calls are removed from the recursion, which is done locally
14 and only passes the arguments which effectively change.
15
16 Example:
17 map = /\ ab -> \f -> \xs -> case xs of
18                  []       -> []
19                  (a:b) -> f a : map f b
20
21 as map is recursively called with the same argument f (unmodified)
22 we transform it to
23
24 map = /\ ab -> \f -> \xs -> let map' ys = case ys of
25                        []     -> []
26                        (a:b) -> f a : map' b
27                 in map' xs
28
29 Notice that for a compiler that uses lambda lifting this is
30 useless as map' will be transformed back to what map was.
31
32 We could possibly do the same for big lambdas, but we don't as
33 they will eventually be removed in later stages of the compiler,
34 therefore there is no penalty in keeping them.
35
36 We only apply the SAT when the number of static args is > 2. This
37 produces few bad cases.  See
38         should_transform 
39 in saTransform.
40
41 Here are the headline nofib results:
42                   Size    Allocs   Runtime
43 Min             +0.0%    -13.7%    -21.4%
44 Max             +0.1%     +0.0%     +5.4%
45 Geometric Mean  +0.0%     -0.2%     -6.9%
46
47 The previous patch, to fix polymorphic floatout demand signatures, is
48 essential to make this work well!
49
50
51 \begin{code}
52
53 module SAT ( doStaticArgs ) where
54
55 import DynFlags
56 import Var
57 import VarEnv
58 import CoreSyn
59 import CoreLint
60 import Type
61 import TcType
62 import Id
63 import UniqSupply
64 import Unique
65 import Util
66
67 import Data.List
68 import Panic
69 import FastString
70
71 #include "HsVersions.h"
72 \end{code}
73
74 \begin{code}
75 doStaticArgs :: DynFlags -> UniqSupply -> [CoreBind] -> IO [CoreBind]
76 doStaticArgs dflags us binds = do
77     showPass dflags "Static argument"
78     let binds' = snd $ mapAccumL sat_bind_threaded_us us binds
79     endPass dflags "Static argument" Opt_D_verbose_core2core binds'
80   where
81     sat_bind_threaded_us us bind = 
82         let (us1, us2) = splitUniqSupply us 
83         in (us1, runSAT (satBind bind) us2)
84 \end{code}
85 \begin{code}
86 -- We don't bother to SAT recursive groups since it can lead
87 -- to massive code expansion: see Andre Santos' thesis for details.
88 -- This means we only apply the actual SAT to Rec groups of one element,
89 -- but we want to recurse into the others anyway to discover other binds
90 satBind :: CoreBind -> SatM CoreBind
91 satBind (NonRec binder expr) = do
92     expr' <- satExpr expr
93     return (NonRec binder expr')
94 satBind (Rec [(binder, rhs)]) = do
95     insSAEnvFromBinding binder rhs
96     rhs' <- satExpr rhs
97     saTransform binder rhs'
98 satBind (Rec pairs) = do
99     let (binders, rhss) = unzip pairs
100     rhss' <- mapM satExpr rhss
101     return (Rec (zipEqual "satBind" binders rhss'))
102 \end{code}
103 \begin{code}
104 emptySATInfo :: Id -> Maybe (Id, SATInfo)
105 emptySATInfo v = Just (v, ([], []))
106
107 satExpr :: CoreExpr -> SatM CoreExpr
108 satExpr var@(Var v) = do
109     updSAEnv (emptySATInfo v)
110     return var
111
112 satExpr lit@(Lit _) = do
113     return lit
114
115 satExpr (Lam binders body) = do
116     body' <- satExpr body
117     return (Lam binders body')
118
119 satExpr app@(App _ _) = do
120     getAppArgs app
121
122 satExpr (Case expr bndr ty alts) = do
123     expr' <- satExpr expr
124     alts' <- mapM satAlt alts
125     return (Case expr' bndr ty alts')
126   where
127     satAlt (con, bndrs, expr) = do
128         expr' <- satExpr expr
129         return (con, bndrs, expr')
130
131 satExpr (Let bind body) = do
132     body' <- satExpr body
133     bind' <- satBind bind
134     return (Let bind' body')
135
136 satExpr (Note note expr) = do
137     expr' <- satExpr expr
138     return (Note note expr')
139
140 satExpr ty@(Type _) = do
141     return ty
142
143 satExpr (Cast expr coercion) = do
144     expr' <- satExpr expr
145     return (Cast expr' coercion)
146 \end{code}
147
148 \begin{code}
149 getAppArgs :: CoreExpr -> SatM CoreExpr
150 getAppArgs app = do
151     (app', result) <- get app
152     updSAEnv result
153     return app'
154   where
155     get :: CoreExpr -> SatM (CoreExpr, Maybe (Id, SATInfo))
156     get (App e (Type ty)) = do
157         (e', result) <- get e
158         return
159             (App e' (Type ty),
160             case result of
161                 Nothing            -> Nothing
162                 Just (v, (tv, lv)) -> Just (v, (tv ++ [Static ty], lv)))
163
164     get (App e a) = do
165         (e', result) <- get e
166         a' <- satExpr a
167         
168         let si = case a' of
169                     Var v -> Static v
170                     _     -> NotStatic
171         return
172             (App e' a',
173             case result of
174                 Just (v, (tv, lv))  -> Just (v, (tv, lv ++ [si]))
175                 Nothing             -> Nothing)
176
177     get var@(Var v) = do
178         return (var, emptySATInfo v)
179
180     get e = do
181         e' <- satExpr e
182         return (e', Nothing)
183 \end{code}
184
185 %************************************************************************
186
187         Environment
188
189 %************************************************************************
190
191 \begin{code}
192 data SATEnv = SatEnv { idSATInfo :: IdEnv SATInfo }
193
194 emptyEnv :: SATEnv
195 emptyEnv = SatEnv { idSATInfo = emptyVarEnv }
196
197 type SATInfo = ([Staticness Type], [Staticness Id])
198
199 data Staticness a = Static a | NotStatic
200
201 delOneFromSAEnv :: Id -> SatM ()
202 delOneFromSAEnv v = modifyEnv $ \env -> env { idSATInfo = delVarEnv (idSATInfo env) v }
203
204 updSAEnv :: Maybe (Id, SATInfo) -> SatM ()
205 updSAEnv Nothing = do
206     return ()
207 updSAEnv (Just (b, (tyargs, args))) = do
208     r <- getSATInfo b
209     case r of
210       Nothing               -> return ()
211       Just (tyargs', args') -> do
212           delOneFromSAEnv b
213           insSAEnv b (checkArgs (eqWith coreEqType) tyargs tyargs',
214                       checkArgs (eqWith (==)) args args')
215   where eqWith _  NotStatic  NotStatic  = True
216         eqWith eq (Static x) (Static y) = x `eq` y
217         eqWith _  _          _          = False
218
219 checkArgs :: (Staticness a -> Staticness a -> Bool) -> [Staticness a] -> [Staticness a] -> [Staticness a]
220 checkArgs _  as [] = notStatics (length as)
221 checkArgs _  [] as = notStatics (length as)
222 checkArgs eq (a:as) (a':as') | a `eq` a' = a:checkArgs eq as as'
223 checkArgs eq (_:as) (_:as') = NotStatic:checkArgs eq as as'
224
225 notStatics :: Int -> [Staticness a]
226 notStatics n = nOfThem n NotStatic
227
228 insSAEnv :: Id -> SATInfo -> SatM ()
229 insSAEnv b info = modifyEnv $ \env -> env { idSATInfo = extendVarEnv (idSATInfo env) b info }
230
231 insSAEnvFromBinding :: Id -> CoreExpr -> SatM ()
232 insSAEnvFromBinding bndr e = insSAEnv bndr (getArgLists e)
233 \end{code}
234
235 %************************************************************************
236
237         Static Argument Transformation Monad
238
239 %************************************************************************
240
241 Two items of state to thread around: a UniqueSupply and a SATEnv.
242
243 \begin{code}
244 newtype SatM result
245   = SatM (UniqSupply -> SATEnv -> (result, SATEnv))
246
247 instance Monad SatM where
248     (>>=) = thenSAT
249     (>>) = thenSAT_
250     return = returnSAT
251
252 runSAT :: SatM a -> UniqSupply -> a
253 runSAT (SatM f) us = fst $ f us emptyEnv
254
255 thenSAT :: SatM a -> (a -> SatM b) -> SatM b
256 thenSAT (SatM m) k
257   = SatM $ \us env -> 
258     case splitUniqSupply us    of { (s1, s2) ->
259     case m s1 env              of { (m_result, menv) ->
260     case k m_result            of { (SatM k') ->
261     k' s2 menv }}}
262
263 thenSAT_ :: SatM a -> SatM b -> SatM b
264 thenSAT_ (SatM m) (SatM k)
265   = SatM $ \us env ->
266     case splitUniqSupply us    of { (s1, s2) ->
267     case m s1 env               of { (_, menv) ->
268     k s2 menv }}
269
270 returnSAT :: a -> SatM a
271 returnSAT v = withEnv $ \env -> (v, env)
272
273 modifyEnv :: (SATEnv -> SATEnv) -> SatM ()
274 modifyEnv f = SatM $ \_ env -> ((), f env)
275
276 withEnv :: (SATEnv -> (b, SATEnv)) -> SatM b
277 withEnv f = SatM $ \_ env -> f env
278
279 projectFromEnv :: (SATEnv -> a) -> SatM a
280 projectFromEnv f = withEnv (\env -> (f env, env))
281 \end{code}
282
283 %************************************************************************
284
285                 Utility Functions
286
287 %************************************************************************
288
289 \begin{code}
290 getSATInfo :: Id -> SatM (Maybe SATInfo)
291 getSATInfo var = projectFromEnv $ \env -> lookupVarEnv (idSATInfo env) var
292
293 newSATName :: Id -> Type -> SatM Id
294 newSATName _ ty
295   = SatM $ \us env -> (mkSysLocal FSLIT("$sat") (uniqFromSupply us) ty, env)
296
297 getArgLists :: CoreExpr -> ([Staticness Type], [Staticness Id])
298 getArgLists expr
299   = let
300     (tvs, lambda_bounds, _) = collectTyAndValBinders expr
301     in
302     ([ Static (mkTyVarTy tv) | tv <- tvs ],
303      [ Static v              | v <- lambda_bounds ])
304
305 \end{code}
306
307 We implement saTransform using shadowing of binders, that is
308 we transform
309 map = \f as -> case as of
310          [] -> []
311          (a':as') -> let x = f a'
312                  y = map f as'
313                  in x:y
314 to
315 map = \f as -> let map = \f as -> map' as
316            in let rec map' = \as -> case as of
317                       [] -> []
318                       (a':as') -> let x = f a'
319                               y = map f as'
320                               in x:y
321           in map' as
322
323 the inner map should get inlined and eliminated.
324
325 \begin{code}
326 saTransform :: Id -> CoreExpr -> SatM CoreBind
327 saTransform binder rhs = do
328     r <- getSATInfo binder
329     case r of
330       Just (tyargs, args) | should_transform args
331         -> do
332             -- In order to get strictness information on this new binder
333             -- we need to make sure this stage happens >before< the analysis
334             binder' <- newSATName binder (mkSATLamTy tyargs args)
335             new_rhs <- mkNewRhs binder binder' args rhs
336             return (NonRec binder new_rhs)
337       _ -> return (Rec [(binder, rhs)])
338   where
339     should_transform args
340       = staticArgsLength > 1            -- THIS IS THE DECISION POINT
341       where staticArgsLength = length (filter isStatic args)
342     
343     mkNewRhs binder binder' args rhs = let
344         non_static_args :: [Id]
345         non_static_args = get_nsa args rhs_val_binders
346           where
347             get_nsa :: [Staticness a] -> [a] -> [a]
348             get_nsa [] _ = []
349             get_nsa _ [] = []
350             get_nsa (NotStatic:args) (v:as) = v:get_nsa args as
351             get_nsa (_:args)         (_:as) =   get_nsa args as
352
353         -- To do the transformation, the game plan is to:
354         -- 1. Create a small nonrecursive RHS that takes the
355         --    original arguments to the function but discards
356         --    the ones that are static and makes a call to the
357         --    SATed version with the remainder. We intend that
358         --    this will be inlined later, removing the overhead
359         -- 2. Bind this nonrecursive RHS over the original body
360         --    WITH THE SAME UNIQUE as the original body so that
361         --    any recursive calls to the original now go via
362         --    the small wrapper
363         -- 3. Rebind the original function to a new one which contains
364         --    our SATed function and just makes a call to it:
365         --    we call the thing making this call the local body
366
367         local_body = mkApps (Var binder') [Var a | a <- non_static_args]
368
369         nonrec_rhs = mkOrigLam local_body
370
371         -- HACK! The following is a fake SysLocal binder with
372         --  *the same* unique as binder.
373         -- the reason for this is the following:
374         -- this binder *will* get inlined but if it happen to be
375         -- a top level binder it is never removed as dead code,
376         -- therefore we have to remove that information (of it being
377         -- top-level or exported somehow.)
378         -- A better fix is to use binder directly but with the TopLevel
379         -- tag (or Exported tag) modified.
380         fake_binder = mkSysLocal FSLIT("sat")
381                 (getUnique binder)
382                 (idType binder)
383         rec_body = mkLams non_static_args
384                    (Let (NonRec fake_binder nonrec_rhs) {-in-} rhs_body)
385         in return (mkOrigLam (Let (Rec [(binder', rec_body)]) {-in-} local_body))
386       where
387         (rhs_binders, rhs_body) = collectBinders rhs
388         rhs_val_binders = filter isId rhs_binders
389         
390         mkOrigLam = mkLams rhs_binders
391
392     mkSATLamTy tyargs args
393       = substTy (mk_inst_tyenv tyargs tv_tmpl)
394                 (mkSigmaTy tv_tmpl' theta_tys' tau_ty')
395       where
396           -- get type info for the local function:
397           (tv_tmpl, theta_tys, tau_ty) = (tcSplitSigmaTy . idType) binder
398           (reg_arg_tys, res_type)      = splitFunTys tau_ty
399
400           -- now, we drop the ones that are
401           -- static, that is, the ones we will not pass to the local function
402           tv_tmpl'     = dropStatics tyargs tv_tmpl
403
404           -- Extract the args that correspond to the theta tys (e.g. dictionaries) and argument tys (normal values)
405           (args1, args2) = splitAtList theta_tys args
406           theta_tys'     = dropStatics args1 theta_tys
407           reg_arg_tys'   = dropStatics args2 reg_arg_tys
408
409           -- Piece the function type back together from our static-filtered components
410           tau_ty'        = mkFunTys reg_arg_tys' res_type
411
412           mk_inst_tyenv :: [Staticness Type] -> [TyVar] -> TvSubst
413           mk_inst_tyenv []              _      = emptyTvSubst
414           mk_inst_tyenv (Static s:args) (t:ts) = extendTvSubst (mk_inst_tyenv args ts) t s
415           mk_inst_tyenv (_:args)        (_:ts) = mk_inst_tyenv args ts
416           mk_inst_tyenv _               _      = panic "mk_inst_tyenv"
417
418 dropStatics :: [Staticness a] -> [b] -> [b]
419 dropStatics [] t = t
420 dropStatics (Static _:args) (_:ts) = dropStatics args ts
421 dropStatics (_:args)        (t:ts) = t:dropStatics args ts
422 dropStatics _               _      = panic "dropStatics"
423
424 isStatic :: Staticness a -> Bool
425 isStatic NotStatic = False
426 isStatic _         = True
427 \end{code}