5e406d175f1b5edc5da77bff913d0374aa6edf2c
[ghc-hetmet.git] / ghc / compiler / simplStg / LambdaLift.lhs
1 %
2 % (c) The AQUA Project, Glasgow University, 1994-1995
3 %
4 \section[LambdaLift]{A STG-code lambda lifter}
5
6 \begin{code}
7 #include "HsVersions.h"
8
9 module LambdaLift ( liftProgram ) where
10
11 import StgSyn
12
13 import AbsUniType       ( mkForallTy, splitForalls, glueTyArgs,
14                           UniType, RhoType(..), TauType(..)
15                         )
16 import Bag
17 import Id               ( mkSysLocal, getIdUniType, addIdArity, Id )
18 import IdEnv
19 import Maybes
20 import SplitUniq
21 import SrcLoc           ( mkUnknownSrcLoc, SrcLoc )
22 import UniqSet
23 import Util
24 \end{code}
25
26 This is the lambda lifter.  It turns lambda abstractions into
27 supercombinators on a selective basis:
28
29 * Let-no-escaped bindings are never lifted. That's one major reason
30   why the lambda lifter is done in STG.
31
32 * Non-recursive bindings whose RHS is a lambda abstractions are lifted,
33   provided all the occurrences of the bound variable is in a function
34   postition.  In this example, f will be lifted:
35         
36         let     
37           f = \x -> e
38         in
39         ..(f a1)...(f a2)...
40   thus
41
42     $f p q r x = e      -- Supercombinator
43
44         ..($f p q r a1)...($f p q r a2)...
45
46   NOTE that the original binding is eliminated.
47
48   But in this case, f won't be lifted:
49
50         let     
51           f = \x -> e
52         in
53         ..(g f)...(f a2)...
54
55   Why? Because we have to heap-allocate a closure for f thus:
56
57     $f p q r x = e      -- Supercombinator
58
59         let
60           f = $f p q r
61         in 
62         ..(g f)...($f p q r a2)..
63
64   so it might as well be the original lambda abstraction.
65
66   We also do not lift if the function has an occurrence with no arguments, e.g.
67   
68         let
69           f = \x -> e
70         in f
71         
72   as this form is more efficient than if we create a partial application
73
74   $f p q r x = e      -- Supercombinator
75
76         f p q r
77
78 * Recursive bindings *all* of whose RHSs are lambda abstractions are
79   lifted iff
80         - all the occurrences of all the binders are in a function position
81         - there aren't ``too many'' free variables.
82
83   Same reasoning as before for the function-position stuff.  The ``too many
84   free variable'' part comes from considering the (potentially many) 
85   recursive calls, which may now have lots of free vars.
86
87 Recent Observations:
88 * 2 might be already ``too many'' variables to abstract.
89   The problem is that the increase in the number of free variables
90   of closures refering to the lifted function (which is always # of
91   abstracted args - 1) may increase heap allocation a lot.
92   Expeiments are being done to check this...
93 * We do not lambda lift if the function has at least one occurrence
94   without any arguments. This caused lots of problems. Ex:
95   h = \ x -> ... let y = ...
96                  in let let f = \x -> ...y...
97                     in f
98   ==> 
99   f = \y x -> ...y...
100   h = \ x -> ... let y = ...
101                  in f y
102   
103   now f y is a partial application, so it will be updated, and this
104   is Bad.
105
106
107 --- NOT RELEVANT FOR STG ----
108 * All ``lone'' lambda abstractions are lifted.  Notably this means lambda 
109   abstractions:
110         - in a case alternative: case e of True -> (\x->b)
111         - in the body of a let:  let x=e in (\y->b)
112 -----------------------------
113
114 %************************************************************************
115 %*                                                                      *
116 \subsection[Lift-expressions]{The main function: liftExpr}
117 %*                                                                      *
118 %************************************************************************
119
120 \begin{code}
121 liftProgram :: SplitUniqSupply -> [PlainStgBinding] -> [PlainStgBinding]
122 liftProgram us prog = concat (runLM Nothing us (mapLM liftTopBind prog))
123
124
125 liftTopBind :: PlainStgBinding -> LiftM [PlainStgBinding]
126 liftTopBind (StgNonRec id rhs)
127   = dontLiftRhs rhs             `thenLM` \ (rhs', rhs_info) ->
128     returnLM (getScBinds rhs_info ++ [StgNonRec id rhs'])
129
130 liftTopBind (StgRec pairs)
131   = mapAndUnzipLM dontLiftRhs rhss      `thenLM` \ (rhss', rhs_infos) ->
132     returnLM ([co_rec_ify (StgRec (ids `zip` rhss') :
133                            getScBinds (unionLiftInfos rhs_infos))
134              ])
135   where
136    (ids, rhss) = unzip pairs
137 \end{code}
138
139
140 \begin{code}
141 liftExpr :: PlainStgExpr
142          -> LiftM (PlainStgExpr, LiftInfo)
143
144
145 liftExpr expr@(StgConApp con args lvs) = returnLM (expr, emptyLiftInfo)
146 liftExpr expr@(StgPrimApp op args lvs) = returnLM (expr, emptyLiftInfo)
147
148 liftExpr expr@(StgApp (StgLitAtom lit) args lvs) = returnLM (expr, emptyLiftInfo)
149 liftExpr expr@(StgApp (StgVarAtom v)  args lvs)
150   = lookup v            `thenLM` \ ~(sc, sc_args) ->    -- NB the ~.  We don't want to
151                                                         -- poke these bindings too early!
152     returnLM (StgApp (StgVarAtom sc) (map StgVarAtom sc_args ++ args) lvs,
153               emptyLiftInfo)
154         -- The lvs field is probably wrong, but we reconstruct it 
155         -- anyway following lambda lifting
156
157 liftExpr (StgCase scrut lv1 lv2 uniq alts)
158   = liftExpr scrut      `thenLM` \ (scrut', scrut_info) ->
159     lift_alts alts      `thenLM` \ (alts', alts_info) ->
160     returnLM (StgCase scrut' lv1 lv2 uniq alts', scrut_info `unionLiftInfo` alts_info)
161   where
162     lift_alts (StgAlgAlts ty alg_alts deflt)
163         = mapAndUnzipLM lift_alg_alt alg_alts   `thenLM` \ (alg_alts', alt_infos) ->
164           lift_deflt deflt                      `thenLM` \ (deflt', deflt_info) ->
165           returnLM (StgAlgAlts ty alg_alts' deflt', foldr unionLiftInfo deflt_info alt_infos)
166
167     lift_alts (StgPrimAlts ty prim_alts deflt)
168         = mapAndUnzipLM lift_prim_alt prim_alts `thenLM` \ (prim_alts', alt_infos) ->
169           lift_deflt deflt                      `thenLM` \ (deflt', deflt_info) ->
170           returnLM (StgPrimAlts ty prim_alts' deflt', foldr unionLiftInfo deflt_info alt_infos)
171
172     lift_alg_alt (con, args, use_mask, rhs)
173         = liftExpr rhs          `thenLM` \ (rhs', rhs_info) ->
174           returnLM ((con, args, use_mask, rhs'), rhs_info)
175
176     lift_prim_alt (lit, rhs)
177         = liftExpr rhs  `thenLM` \ (rhs', rhs_info) ->
178           returnLM ((lit, rhs'), rhs_info)
179
180     lift_deflt StgNoDefault = returnLM (StgNoDefault, emptyLiftInfo)
181     lift_deflt (StgBindDefault var used rhs)
182         = liftExpr rhs  `thenLM` \ (rhs', rhs_info) ->
183           returnLM (StgBindDefault var used rhs', rhs_info)
184 \end{code}
185
186 Now the interesting cases.  Let no escape isn't lifted.  We turn it
187 back into a let, to play safe, because we have to redo that pass after
188 lambda anyway.
189
190 \begin{code}
191 liftExpr (StgLetNoEscape _ _ (StgNonRec binder rhs) body)
192   = dontLiftRhs rhs     `thenLM` \ (rhs', rhs_info) ->
193     liftExpr body       `thenLM` \ (body', body_info) ->
194     returnLM (StgLet (StgNonRec binder rhs') body', 
195               rhs_info `unionLiftInfo` body_info)
196
197 liftExpr (StgLetNoEscape _ _ (StgRec pairs) body)
198   = liftExpr body                       `thenLM` \ (body', body_info) ->
199     mapAndUnzipLM dontLiftRhs rhss      `thenLM` \ (rhss', rhs_infos) ->
200     returnLM (StgLet (StgRec (binders `zipEqual` rhss')) body',
201               foldr unionLiftInfo body_info rhs_infos)
202   where
203    (binders,rhss) = unzip pairs
204 \end{code}
205
206 \begin{code}
207 liftExpr (StgLet (StgNonRec binder rhs) body)
208   | not (isLiftable rhs)
209   = dontLiftRhs rhs     `thenLM` \ (rhs', rhs_info) ->
210     liftExpr body       `thenLM` \ (body', body_info) ->
211     returnLM (StgLet (StgNonRec binder rhs') body', 
212               rhs_info `unionLiftInfo` body_info)
213
214   | otherwise   -- It's a lambda
215   =     -- Do the body of the let
216     fixLM (\ ~(sc_inline, _, _) ->
217       addScInlines [binder] [sc_inline] (
218         liftExpr body   
219       )                 `thenLM` \ (body', body_info) ->
220
221         -- Deal with the RHS
222       dontLiftRhs rhs           `thenLM` \ (rhs', rhs_info) -> 
223
224         -- All occurrences in function position, so lambda lift
225       getFinalFreeVars (rhsFreeVars rhs)    `thenLM` \ final_free_vars ->
226
227       mkScPieces final_free_vars (binder,rhs')  `thenLM` \ (sc_inline, sc_bind) -> 
228
229       returnLM (sc_inline, 
230                 body', 
231                 nonRecScBind rhs_info sc_bind `unionLiftInfo` body_info)
232
233     )                   `thenLM` \ (_, expr', final_info) ->
234
235     returnLM (expr', final_info)
236
237 liftExpr (StgLet (StgRec pairs) body)
238 --[Andre-testing]  
239   | not (all isLiftableRec rhss)
240   = liftExpr body                       `thenLM` \ (body', body_info) ->
241     mapAndUnzipLM dontLiftRhs rhss      `thenLM` \ (rhss', rhs_infos) ->
242     returnLM (StgLet (StgRec (binders `zipEqual` rhss')) body',
243               foldr unionLiftInfo body_info rhs_infos)
244
245   | otherwise   -- All rhss are liftable
246   = -- Do the body of the let
247     fixLM (\ ~(sc_inlines, _, _) ->
248       addScInlines binders sc_inlines   (
249
250       liftExpr body                     `thenLM` \ (body', body_info) ->
251       mapAndUnzipLM dontLiftRhs rhss    `thenLM` \ (rhss', rhs_infos) ->
252       let
253         -- Find the free vars of all the rhss, 
254         -- excluding the binders themselves.
255         rhs_free_vars = unionManyUniqSets (map rhsFreeVars rhss)
256                         `minusUniqSet`
257                         mkUniqSet binders
258
259         rhs_info      = unionLiftInfos rhs_infos
260       in
261       getFinalFreeVars rhs_free_vars    `thenLM` \ final_free_vars ->
262
263       mapAndUnzipLM (mkScPieces final_free_vars) (binders `zip` rhss')
264                                         `thenLM` \ (sc_inlines, sc_pairs) ->
265       returnLM (sc_inlines, 
266                 body', 
267                 recScBind rhs_info sc_pairs `unionLiftInfo` body_info)
268
269     ))                  `thenLM` \ (_, expr', final_info) ->
270
271     returnLM (expr', final_info)
272   where
273     (binders,rhss)    = unzip pairs
274 \end{code}
275
276 \begin{code}
277 liftExpr (StgSCC ty cc expr)
278   = liftExpr expr `thenLM` \ (expr2, expr_info) ->
279     returnLM (StgSCC ty cc expr2, expr_info)
280 \end{code}
281
282 A binding is liftable if it's a *function* (args not null) and never
283 occurs in an argument position.
284
285 \begin{code}
286 isLiftable :: PlainStgRhs -> Bool
287
288 isLiftable (StgRhsClosure _ (StgBinderInfo arg_occ _ _ _ unapplied_occ) fvs _ args _) 
289
290   -- Experimental evidence suggests we should lift only if we will be
291   -- abstracting up to 4 fvs.
292
293   = if not (null args   ||      -- Not a function
294          unapplied_occ  ||      -- Has an occ with no args at all
295          arg_occ        ||      -- Occurs in arg position
296          length fvs > 4         -- Too many free variables
297         )
298     then {-trace ("LL: " ++ show (length fvs))-} True
299     else False
300 isLiftable other_rhs = False
301
302 isLiftableRec :: PlainStgRhs -> Bool
303
304 -- this is just the same as for non-rec, except we only lift to
305 -- abstract up to 1 argument this avoids undoing Static Argument
306 -- Transformation work
307
308 {- Andre's longer comment about isLiftableRec: 1996/01:
309
310 A rec binding is "liftable" (according to our heuristics) if: 
311 * It is a function, 
312 * all occurrences have arguments, 
313 * does not occur in an argument position and
314 * has up to *2* free variables (including the rec binding variable
315   itself!)
316
317 The point is: my experiments show that SAT is more important than LL.
318 Therefore if we still want to do LL, for *recursive* functions, we do
319 not want LL to undo what SAT did.  We do this by avoiding LL recursive
320 functions that have more than 2 fvs, since if this recursive function
321 was created by SAT (we don't know!), it would have at least 3 fvs: one
322 for the rec binding itself and 2 more for the static arguments (note:
323 this matches with the choice of performing SAT to have at least 2
324 static arguments, if we change things there we should change things
325 here).
326 -}
327
328 isLiftableRec (StgRhsClosure _ (StgBinderInfo arg_occ _ _ _ unapplied_occ) fvs _ args _) 
329   = if not (null args   ||      -- Not a function
330          unapplied_occ  ||      -- Has an occ with no args at all
331          arg_occ        ||      -- Occurs in arg position
332          length fvs > 2         -- Too many free variables
333         )
334     then {-trace ("LLRec: " ++ show (length fvs))-} True
335     else False
336 isLiftableRec other_rhs = False
337
338 rhsFreeVars :: PlainStgRhs -> IdSet
339 rhsFreeVars (StgRhsClosure _ _ fvs _ _ _) = mkUniqSet fvs
340 rhsFreeVars other                         = panic "rhsFreeVars"
341 \end{code}
342
343 dontLiftRhs is like liftExpr, except that it does not lift a top-level
344 lambda abstraction.  It is used for the right-hand sides of
345 definitions where we've decided *not* to lift: for example, top-level
346 ones or mutually-recursive ones where not all are lambdas.
347
348 \begin{code}
349 dontLiftRhs :: PlainStgRhs -> LiftM (PlainStgRhs, LiftInfo)
350
351 dontLiftRhs rhs@(StgRhsCon cc v args) = returnLM (rhs, emptyLiftInfo)
352
353 dontLiftRhs (StgRhsClosure cc bi fvs upd args body) 
354   = liftExpr body       `thenLM` \ (body', body_info) ->
355     returnLM (StgRhsClosure cc bi fvs upd args body', body_info)
356 \end{code}
357
358 \begin{code}
359 mkScPieces :: IdSet             -- Extra args for the supercombinator
360            -> (Id, PlainStgRhs) -- The processed RHS and original Id
361            -> LiftM ((Id,[Id]),         -- Replace abstraction with this;
362                                                 -- the set is its free vars
363                      (Id,PlainStgRhs))  -- Binding for supercombinator
364
365 mkScPieces extra_arg_set (id, StgRhsClosure cc bi _ upd args body)
366   = ASSERT( n_args > 0 )
367         -- Construct the rhs of the supercombinator, and its Id
368     -- this trace blackholes sometimes, don't use it
369     -- trace ("LL " ++ show (length (uniqSetToList extra_arg_set))) (
370     newSupercombinator sc_ty arity  `thenLM` \ sc_id ->
371
372     returnLM ((sc_id, extra_args), (sc_id, sc_rhs))
373     --)
374   where
375     n_args     = length args
376     extra_args = uniqSetToList extra_arg_set
377     arity      = n_args + length extra_args
378
379         -- Construct the supercombinator type
380     type_of_original_id = getIdUniType id
381     extra_arg_tys       = map getIdUniType extra_args
382     (tyvars, rest)      = splitForalls type_of_original_id
383     sc_ty               = mkForallTy tyvars (glueTyArgs extra_arg_tys rest)
384
385     sc_rhs = StgRhsClosure cc bi [] upd (extra_args ++ args) body
386 \end{code}
387
388
389 %************************************************************************
390 %*                                                                      *
391 \subsection[Lift-monad]{The LiftM monad}
392 %*                                                                      *
393 %************************************************************************
394
395 The monad is used only to distribute global stuff, and the unique supply.
396
397 \begin{code}
398 type LiftM a =  LiftFlags
399              -> SplitUniqSupply
400              -> (IdEnv                          -- Domain = candidates for lifting
401                        (Id,                     -- The supercombinator
402                         [Id])                   -- Args to apply it to
403                  )
404              -> a
405
406
407 type LiftFlags = Maybe Int      -- No of fvs reqd to float recursive
408                                 -- binding; Nothing == infinity
409
410
411 runLM :: LiftFlags -> SplitUniqSupply -> LiftM a -> a
412 runLM flags us m = m flags us nullIdEnv
413
414 thenLM :: LiftM a -> (a -> LiftM b) -> LiftM b
415 thenLM m k ci us idenv
416   = k (m ci us1 idenv) ci us2 idenv
417   where
418     (us1, us2) = splitUniqSupply us
419
420 returnLM :: a -> LiftM a
421 returnLM a ci us idenv = a
422
423 fixLM :: (a -> LiftM a) -> LiftM a
424 fixLM k ci us idenv = r
425                        where
426                          r = k r ci us idenv
427
428 mapLM :: (a -> LiftM b) -> [a] -> LiftM [b]
429 mapLM f [] = returnLM []
430 mapLM f (a:as) = f a            `thenLM` \ r ->
431                  mapLM f as     `thenLM` \ rs ->
432                  returnLM (r:rs)
433
434 mapAndUnzipLM :: (a -> LiftM (b,c)) -> [a] -> LiftM ([b],[c])
435 mapAndUnzipLM f []     = returnLM ([],[])
436 mapAndUnzipLM f (a:as) = f a                    `thenLM` \ (b,c) ->
437                          mapAndUnzipLM f as     `thenLM` \ (bs,cs) ->
438                          returnLM (b:bs, c:cs)
439 \end{code}
440
441 \begin{code}
442 newSupercombinator :: UniType 
443                    -> Int               -- Arity
444                    -> LiftM Id
445
446 newSupercombinator ty arity ci us idenv
447   = (mkSysLocal SLIT("sc") uniq ty mkUnknownSrcLoc)     -- ToDo: improve location
448     `addIdArity` arity
449         -- ToDo: rm the addIdArity?  Just let subsequent stg-saturation pass do it?
450   where
451     uniq = getSUnique us
452     
453 lookup :: Id -> LiftM (Id,[Id])
454 lookup v ci us idenv 
455   = case lookupIdEnv idenv v of
456         Just result -> result
457         Nothing     -> (v, [])
458
459 addScInlines :: [Id] -> [(Id,[Id])] -> LiftM a -> LiftM a
460 addScInlines ids values m ci us idenv
461   = m ci us idenv'
462   where
463     idenv' = growIdEnvList idenv (ids `zip_lazy` values)
464
465     -- zip_lazy zips two things together but matches lazily on the
466     -- second argument.  This is important, because the ids are know here,
467     -- but the things they are bound to are decided only later
468     zip_lazy [] _           = []
469     zip_lazy (x:xs) ~(y:ys) = (x,y) : zip_lazy xs ys
470
471
472 -- The free vars reported by the free-var analyser will include
473 -- some ids, f, which are to be replaced by ($f a b c), where $f
474 -- is the supercombinator.  Hence instead of f being a free var,
475 -- {a,b,c} are.
476 --
477 -- Example
478 --      let
479 --         f a = ...y1..y2.....
480 --      in
481 --      let
482 --         g b = ...f...z...
483 --      in
484 --      ...
485 --
486 --  Here the free vars of g are {f,z}; but f will be lambda-lifted
487 --  with free vars {y1,y2}, so the "real~ free vars of g are {y1,y2,z}.
488
489 getFinalFreeVars :: IdSet -> LiftM IdSet
490
491 getFinalFreeVars free_vars ci us idenv 
492   = unionManyUniqSets (map munge_it (uniqSetToList free_vars))
493   where
494     munge_it :: Id -> IdSet     -- Takes a free var and maps it to the "real"
495                                 -- free var
496     munge_it id = case lookupIdEnv idenv id of
497                         Just (_, args) -> mkUniqSet args
498                         Nothing        -> singletonUniqSet id
499   
500 \end{code}
501
502
503 %************************************************************************
504 %*                                                                      *
505 \subsection[Lift-info]{The LiftInfo type}
506 %*                                                                      *
507 %************************************************************************
508
509 \begin{code}
510 type LiftInfo = Bag PlainStgBinding     -- Float to top
511
512 emptyLiftInfo = emptyBag
513                         
514 unionLiftInfo :: LiftInfo -> LiftInfo -> LiftInfo
515 unionLiftInfo binds1 binds2 = binds1 `unionBags` binds2
516
517 unionLiftInfos :: [LiftInfo] -> LiftInfo
518 unionLiftInfos infos = foldr unionLiftInfo emptyLiftInfo infos
519
520 mkScInfo :: PlainStgBinding -> LiftInfo
521 mkScInfo bind = unitBag bind
522
523 nonRecScBind :: LiftInfo                -- From body of supercombinator
524              -> (Id, PlainStgRhs)       -- Supercombinator and its rhs
525              -> LiftInfo
526 nonRecScBind binds (sc_id,sc_rhs) = binds `snocBag` (StgNonRec sc_id sc_rhs)
527
528
529 -- In the recursive case, all the SCs from the RHSs of the recursive group
530 -- are dealing with might potentially mention the new, recursive SCs.
531 -- So we flatten the whole lot into a single recursive group.
532
533 recScBind :: LiftInfo                   -- From body of supercombinator
534            -> [(Id,PlainStgRhs)]        -- Supercombinator rhs
535            -> LiftInfo
536
537 recScBind binds pairs = unitBag (co_rec_ify (StgRec pairs : bagToList binds))
538
539 co_rec_ify :: [PlainStgBinding] -> PlainStgBinding
540 co_rec_ify binds = StgRec (concat (map f binds))
541   where
542     f (StgNonRec id rhs) = [(id,rhs)]
543     f (StgRec pairs)     = pairs
544
545
546 getScBinds :: LiftInfo -> [PlainStgBinding]
547 getScBinds binds = bagToList binds
548
549 looksLikeSATRhs [(f,StgRhsClosure _ _ _ _ ls _)] (StgApp (StgVarAtom f') args _)
550   = (f == f') && (length args == length ls)
551 looksLikeSATRhs _ _ = False
552 \end{code}