f342664baeaa86e787c19b88f3062b636ca95c08
[ghc-hetmet.git] / ghc / compiler / simplStg / LambdaLift.lhs
1 %
2 % (c) The AQUA Project, Glasgow University, 1994-1996
3 %
4 \section[LambdaLift]{A STG-code lambda lifter}
5
6 \begin{code}
7 module LambdaLift ( liftProgram ) where
8
9 #include "HsVersions.h"
10
11 import StgSyn
12
13 import Bag              ( Bag, emptyBag, unionBags, unitBag, snocBag, bagToList )
14 import MkId             ( mkSysLocal )
15 import Id               ( idType, addIdArity, 
16                           mkIdSet, unitIdSet, minusIdSet, setIdVisibility,
17                           unionManyIdSets, idSetToList, IdSet,
18                           nullIdEnv, growIdEnvList, lookupIdEnv, IdEnv,
19                           Id
20                         )
21 import IdInfo           ( ArityInfo, exactArity )
22 import Name             ( Module )
23 import SrcLoc           ( noSrcLoc )
24 import Type             ( splitForAllTys, mkForAllTys, mkFunTys, Type )
25 import UniqSupply       ( getUnique, splitUniqSupply, UniqSupply )
26 import Util             ( zipEqual, panic, assertPanic )
27 \end{code}
28
29 This is the lambda lifter.  It turns lambda abstractions into
30 supercombinators on a selective basis:
31
32 * Let-no-escaped bindings are never lifted. That's one major reason
33   why the lambda lifter is done in STG.
34
35 * Non-recursive bindings whose RHS is a lambda abstractions are lifted,
36   provided all the occurrences of the bound variable is in a function
37   postition.  In this example, f will be lifted:
38
39         let
40           f = \x -> e
41         in
42         ..(f a1)...(f a2)...
43   thus
44
45     $f p q r x = e      -- Supercombinator
46
47         ..($f p q r a1)...($f p q r a2)...
48
49   NOTE that the original binding is eliminated.
50
51   But in this case, f won't be lifted:
52
53         let
54           f = \x -> e
55         in
56         ..(g f)...(f a2)...
57
58   Why? Because we have to heap-allocate a closure for f thus:
59
60     $f p q r x = e      -- Supercombinator
61
62         let
63           f = $f p q r
64         in
65         ..(g f)...($f p q r a2)..
66
67   so it might as well be the original lambda abstraction.
68
69   We also do not lift if the function has an occurrence with no arguments, e.g.
70
71         let
72           f = \x -> e
73         in f
74
75   as this form is more efficient than if we create a partial application
76
77   $f p q r x = e      -- Supercombinator
78
79         f p q r
80
81 * Recursive bindings *all* of whose RHSs are lambda abstractions are
82   lifted iff
83         - all the occurrences of all the binders are in a function position
84         - there aren't ``too many'' free variables.
85
86   Same reasoning as before for the function-position stuff.  The ``too many
87   free variable'' part comes from considering the (potentially many)
88   recursive calls, which may now have lots of free vars.
89
90 Recent Observations:
91
92 * 2 might be already ``too many'' variables to abstract.
93   The problem is that the increase in the number of free variables
94   of closures refering to the lifted function (which is always # of
95   abstracted args - 1) may increase heap allocation a lot.
96   Expeiments are being done to check this...
97
98 * We do not lambda lift if the function has at least one occurrence
99   without any arguments. This caused lots of problems. Ex:
100   h = \ x -> ... let y = ...
101                  in let let f = \x -> ...y...
102                     in f
103   ==>
104   f = \y x -> ...y...
105   h = \ x -> ... let y = ...
106                  in f y
107
108   now f y is a partial application, so it will be updated, and this
109   is Bad.
110
111
112 --- NOT RELEVANT FOR STG ----
113 * All ``lone'' lambda abstractions are lifted.  Notably this means lambda
114   abstractions:
115         - in a case alternative: case e of True -> (\x->b)
116         - in the body of a let:  let x=e in (\y->b)
117 -----------------------------
118
119 %************************************************************************
120 %*                                                                      *
121 \subsection[Lift-expressions]{The main function: liftExpr}
122 %*                                                                      *
123 %************************************************************************
124
125 \begin{code}
126 liftProgram :: Module -> UniqSupply -> [StgBinding] -> [StgBinding]
127 liftProgram mod us prog = concat (runLM mod Nothing us (mapLM liftTopBind prog))
128
129
130 liftTopBind :: StgBinding -> LiftM [StgBinding]
131 liftTopBind (StgNonRec id rhs)
132   = dontLiftRhs rhs             `thenLM` \ (rhs', rhs_info) ->
133     returnLM (getScBinds rhs_info ++ [StgNonRec id rhs'])
134
135 liftTopBind (StgRec pairs)
136   = mapAndUnzipLM dontLiftRhs rhss      `thenLM` \ (rhss', rhs_infos) ->
137     returnLM ([co_rec_ify (StgRec (ids `zip` rhss') :
138                            getScBinds (unionLiftInfos rhs_infos))
139              ])
140   where
141    (ids, rhss) = unzip pairs
142 \end{code}
143
144
145 \begin{code}
146 liftExpr :: StgExpr
147          -> LiftM (StgExpr, LiftInfo)
148
149
150 liftExpr expr@(StgCon con args lvs) = returnLM (expr, emptyLiftInfo)
151 liftExpr expr@(StgPrim op args lvs) = returnLM (expr, emptyLiftInfo)
152
153 liftExpr expr@(StgApp (StgLitArg lit) args lvs) = returnLM (expr, emptyLiftInfo)
154 liftExpr expr@(StgApp (StgConArg con) args lvs) = returnLM (expr, emptyLiftInfo)
155 liftExpr expr@(StgApp (StgVarArg v)  args lvs)
156   = lookUp v            `thenLM` \ ~(sc, sc_args) ->    -- NB the ~.  We don't want to
157                                                         -- poke these bindings too early!
158     returnLM (StgApp (StgVarArg sc) (map StgVarArg sc_args ++ args) lvs,
159               emptyLiftInfo)
160         -- The lvs field is probably wrong, but we reconstruct it
161         -- anyway following lambda lifting
162
163 liftExpr (StgCase scrut lv1 lv2 uniq alts)
164   = liftExpr scrut      `thenLM` \ (scrut', scrut_info) ->
165     lift_alts alts      `thenLM` \ (alts', alts_info) ->
166     returnLM (StgCase scrut' lv1 lv2 uniq alts', scrut_info `unionLiftInfo` alts_info)
167   where
168     lift_alts (StgAlgAlts ty alg_alts deflt)
169         = mapAndUnzipLM lift_alg_alt alg_alts   `thenLM` \ (alg_alts', alt_infos) ->
170           lift_deflt deflt                      `thenLM` \ (deflt', deflt_info) ->
171           returnLM (StgAlgAlts ty alg_alts' deflt', foldr unionLiftInfo deflt_info alt_infos)
172
173     lift_alts (StgPrimAlts ty prim_alts deflt)
174         = mapAndUnzipLM lift_prim_alt prim_alts `thenLM` \ (prim_alts', alt_infos) ->
175           lift_deflt deflt                      `thenLM` \ (deflt', deflt_info) ->
176           returnLM (StgPrimAlts ty prim_alts' deflt', foldr unionLiftInfo deflt_info alt_infos)
177
178     lift_alg_alt (con, args, use_mask, rhs)
179         = liftExpr rhs          `thenLM` \ (rhs', rhs_info) ->
180           returnLM ((con, args, use_mask, rhs'), rhs_info)
181
182     lift_prim_alt (lit, rhs)
183         = liftExpr rhs  `thenLM` \ (rhs', rhs_info) ->
184           returnLM ((lit, rhs'), rhs_info)
185
186     lift_deflt StgNoDefault = returnLM (StgNoDefault, emptyLiftInfo)
187     lift_deflt (StgBindDefault var used rhs)
188         = liftExpr rhs  `thenLM` \ (rhs', rhs_info) ->
189           returnLM (StgBindDefault var used rhs', rhs_info)
190 \end{code}
191
192 Now the interesting cases.  Let no escape isn't lifted.  We turn it
193 back into a let, to play safe, because we have to redo that pass after
194 lambda anyway.
195
196 \begin{code}
197 liftExpr (StgLetNoEscape _ _ (StgNonRec binder rhs) body)
198   = dontLiftRhs rhs     `thenLM` \ (rhs', rhs_info) ->
199     liftExpr body       `thenLM` \ (body', body_info) ->
200     returnLM (StgLet (StgNonRec binder rhs') body',
201               rhs_info `unionLiftInfo` body_info)
202
203 liftExpr (StgLetNoEscape _ _ (StgRec pairs) body)
204   = liftExpr body                       `thenLM` \ (body', body_info) ->
205     mapAndUnzipLM dontLiftRhs rhss      `thenLM` \ (rhss', rhs_infos) ->
206     returnLM (StgLet (StgRec (zipEqual "liftExpr" binders rhss')) body',
207               foldr unionLiftInfo body_info rhs_infos)
208   where
209    (binders,rhss) = unzip pairs
210 \end{code}
211
212 \begin{code}
213 liftExpr (StgLet (StgNonRec binder rhs) body)
214   | not (isLiftable rhs)
215   = dontLiftRhs rhs     `thenLM` \ (rhs', rhs_info) ->
216     liftExpr body       `thenLM` \ (body', body_info) ->
217     returnLM (StgLet (StgNonRec binder rhs') body',
218               rhs_info `unionLiftInfo` body_info)
219
220   | otherwise   -- It's a lambda
221   =     -- Do the body of the let
222     fixLM (\ ~(sc_inline, _, _) ->
223       addScInlines [binder] [sc_inline] (
224         liftExpr body
225       )                 `thenLM` \ (body', body_info) ->
226
227         -- Deal with the RHS
228       dontLiftRhs rhs           `thenLM` \ (rhs', rhs_info) ->
229
230         -- All occurrences in function position, so lambda lift
231       getFinalFreeVars (rhsFreeVars rhs)    `thenLM` \ final_free_vars ->
232
233       mkScPieces final_free_vars (binder,rhs')  `thenLM` \ (sc_inline, sc_bind) ->
234
235       returnLM (sc_inline,
236                 body',
237                 nonRecScBind rhs_info sc_bind `unionLiftInfo` body_info)
238
239     )                   `thenLM` \ (_, expr', final_info) ->
240
241     returnLM (expr', final_info)
242
243 liftExpr (StgLet (StgRec pairs) body)
244 --[Andre-testing]
245   | not (all isLiftableRec rhss)
246   = liftExpr body                       `thenLM` \ (body', body_info) ->
247     mapAndUnzipLM dontLiftRhs rhss      `thenLM` \ (rhss', rhs_infos) ->
248     returnLM (StgLet (StgRec (zipEqual "liftExpr2" binders rhss')) body',
249               foldr unionLiftInfo body_info rhs_infos)
250
251   | otherwise   -- All rhss are liftable
252   = -- Do the body of the let
253     fixLM (\ ~(sc_inlines, _, _) ->
254       addScInlines binders sc_inlines   (
255
256       liftExpr body                     `thenLM` \ (body', body_info) ->
257       mapAndUnzipLM dontLiftRhs rhss    `thenLM` \ (rhss', rhs_infos) ->
258       let
259         -- Find the free vars of all the rhss,
260         -- excluding the binders themselves.
261         rhs_free_vars = unionManyIdSets (map rhsFreeVars rhss)
262                         `minusIdSet`
263                         mkIdSet binders
264
265         rhs_info      = unionLiftInfos rhs_infos
266       in
267       getFinalFreeVars rhs_free_vars    `thenLM` \ final_free_vars ->
268
269       mapAndUnzipLM (mkScPieces final_free_vars) (binders `zip` rhss')
270                                         `thenLM` \ (sc_inlines, sc_pairs) ->
271       returnLM (sc_inlines,
272                 body',
273                 recScBind rhs_info sc_pairs `unionLiftInfo` body_info)
274
275     ))                  `thenLM` \ (_, expr', final_info) ->
276
277     returnLM (expr', final_info)
278   where
279     (binders,rhss)    = unzip pairs
280 \end{code}
281
282 \begin{code}
283 liftExpr (StgSCC ty cc expr)
284   = liftExpr expr `thenLM` \ (expr2, expr_info) ->
285     returnLM (StgSCC ty cc expr2, expr_info)
286 \end{code}
287
288 A binding is liftable if it's a *function* (args not null) and never
289 occurs in an argument position.
290
291 \begin{code}
292 isLiftable :: StgRhs -> Bool
293
294 isLiftable (StgRhsClosure _ (StgBinderInfo arg_occ _ _ _ unapplied_occ) fvs _ args _)
295
296   -- Experimental evidence suggests we should lift only if we will be
297   -- abstracting up to 4 fvs.
298
299   = if not (null args   ||      -- Not a function
300          unapplied_occ  ||      -- Has an occ with no args at all
301          arg_occ        ||      -- Occurs in arg position
302          length fvs > 4         -- Too many free variables
303         )
304     then {-trace ("LL: " ++ show (length fvs))-} True
305     else False
306 isLiftable other_rhs = False
307
308 isLiftableRec :: StgRhs -> Bool
309
310 -- this is just the same as for non-rec, except we only lift to
311 -- abstract up to 1 argument this avoids undoing Static Argument
312 -- Transformation work
313
314 {- Andre's longer comment about isLiftableRec: 1996/01:
315
316 A rec binding is "liftable" (according to our heuristics) if:
317 * It is a function,
318 * all occurrences have arguments,
319 * does not occur in an argument position and
320 * has up to *2* free variables (including the rec binding variable
321   itself!)
322
323 The point is: my experiments show that SAT is more important than LL.
324 Therefore if we still want to do LL, for *recursive* functions, we do
325 not want LL to undo what SAT did.  We do this by avoiding LL recursive
326 functions that have more than 2 fvs, since if this recursive function
327 was created by SAT (we don't know!), it would have at least 3 fvs: one
328 for the rec binding itself and 2 more for the static arguments (note:
329 this matches with the choice of performing SAT to have at least 2
330 static arguments, if we change things there we should change things
331 here).
332 -}
333
334 isLiftableRec (StgRhsClosure _ (StgBinderInfo arg_occ _ _ _ unapplied_occ) fvs _ args _)
335   = if not (null args   ||      -- Not a function
336          unapplied_occ  ||      -- Has an occ with no args at all
337          arg_occ        ||      -- Occurs in arg position
338          length fvs > 2         -- Too many free variables
339         )
340     then {-trace ("LLRec: " ++ show (length fvs))-} True
341     else False
342 isLiftableRec other_rhs = False
343
344 rhsFreeVars :: StgRhs -> IdSet
345 rhsFreeVars (StgRhsClosure _ _ fvs _ _ _) = mkIdSet fvs
346 rhsFreeVars other                         = panic "rhsFreeVars"
347 \end{code}
348
349 dontLiftRhs is like liftExpr, except that it does not lift a top-level
350 lambda abstraction.  It is used for the right-hand sides of
351 definitions where we've decided *not* to lift: for example, top-level
352 ones or mutually-recursive ones where not all are lambdas.
353
354 \begin{code}
355 dontLiftRhs :: StgRhs -> LiftM (StgRhs, LiftInfo)
356
357 dontLiftRhs rhs@(StgRhsCon cc v args) = returnLM (rhs, emptyLiftInfo)
358
359 dontLiftRhs (StgRhsClosure cc bi fvs upd args body)
360   = liftExpr body       `thenLM` \ (body', body_info) ->
361     returnLM (StgRhsClosure cc bi fvs upd args body', body_info)
362 \end{code}
363
364 \begin{code}
365 mkScPieces :: IdSet             -- Extra args for the supercombinator
366            -> (Id, StgRhs)      -- The processed RHS and original Id
367            -> LiftM ((Id,[Id]),         -- Replace abstraction with this;
368                                                 -- the set is its free vars
369                      (Id,StgRhs))       -- Binding for supercombinator
370
371 mkScPieces extra_arg_set (id, StgRhsClosure cc bi _ upd args body)
372   = ASSERT( n_args > 0 )
373         -- Construct the rhs of the supercombinator, and its Id
374     newSupercombinator sc_ty arity  `thenLM` \ sc_id ->
375     returnLM ((sc_id, extra_args), (sc_id, sc_rhs))
376   where
377     n_args     = length args
378     extra_args = idSetToList extra_arg_set
379     arity      = n_args + length extra_args
380
381         -- Construct the supercombinator type
382     type_of_original_id = idType id
383     extra_arg_tys       = map idType extra_args
384     (tyvars, rest)      = splitForAllTys type_of_original_id
385     sc_ty               = mkForAllTys tyvars (mkFunTys extra_arg_tys rest)
386
387     sc_rhs = StgRhsClosure cc bi [] upd (extra_args ++ args) body
388 \end{code}
389
390
391 %************************************************************************
392 %*                                                                      *
393 \subsection[Lift-monad]{The LiftM monad}
394 %*                                                                      *
395 %************************************************************************
396
397 The monad is used only to distribute global stuff, and the unique supply.
398
399 \begin{code}
400 type LiftM a =  Module 
401              -> LiftFlags
402              -> UniqSupply
403              -> (IdEnv                          -- Domain = candidates for lifting
404                        (Id,                     -- The supercombinator
405                         [Id])                   -- Args to apply it to
406                  )
407              -> a
408
409
410 type LiftFlags = Maybe Int      -- No of fvs reqd to float recursive
411                                 -- binding; Nothing == infinity
412
413
414 runLM :: Module -> LiftFlags -> UniqSupply -> LiftM a -> a
415 runLM mod flags us m = m mod flags us nullIdEnv
416
417 thenLM :: LiftM a -> (a -> LiftM b) -> LiftM b
418 thenLM m k mod ci us idenv
419   = k (m mod ci us1 idenv) mod ci us2 idenv
420   where
421     (us1, us2) = splitUniqSupply us
422
423 returnLM :: a -> LiftM a
424 returnLM a mod ci us idenv = a
425
426 fixLM :: (a -> LiftM a) -> LiftM a
427 fixLM k mod ci us idenv = r
428                        where
429                          r = k r mod ci us idenv
430
431 mapLM :: (a -> LiftM b) -> [a] -> LiftM [b]
432 mapLM f [] = returnLM []
433 mapLM f (a:as) = f a            `thenLM` \ r ->
434                  mapLM f as     `thenLM` \ rs ->
435                  returnLM (r:rs)
436
437 mapAndUnzipLM :: (a -> LiftM (b,c)) -> [a] -> LiftM ([b],[c])
438 mapAndUnzipLM f []     = returnLM ([],[])
439 mapAndUnzipLM f (a:as) = f a                    `thenLM` \ (b,c) ->
440                          mapAndUnzipLM f as     `thenLM` \ (bs,cs) ->
441                          returnLM (b:bs, c:cs)
442 \end{code}
443
444 \begin{code}
445 newSupercombinator :: Type
446                    -> Int               -- Arity
447                    -> LiftM Id
448
449 newSupercombinator ty arity mod ci us idenv
450   = setIdVisibility (Just mod) uniq (mkSysLocal SLIT("sc") uniq ty noSrcLoc)
451     `addIdArity` exactArity arity
452         -- ToDo: rm the addIdArity?  Just let subsequent stg-saturation pass do it?
453   where
454     uniq = getUnique us
455
456 lookUp :: Id -> LiftM (Id,[Id])
457 lookUp v mod ci us idenv
458   = case (lookupIdEnv idenv v) of
459       Just result -> result
460       Nothing     -> (v, [])
461
462 addScInlines :: [Id] -> [(Id,[Id])] -> LiftM a -> LiftM a
463 addScInlines ids values m mod ci us idenv
464   = m mod ci us idenv'
465   where
466     idenv' = growIdEnvList idenv (ids `zip_lazy` values)
467
468     -- zip_lazy zips two things together but matches lazily on the
469     -- second argument.  This is important, because the ids are know here,
470     -- but the things they are bound to are decided only later
471     zip_lazy [] _           = []
472     zip_lazy (x:xs) ~(y:ys) = (x,y) : zip_lazy xs ys
473
474
475 -- The free vars reported by the free-var analyser will include
476 -- some ids, f, which are to be replaced by ($f a b c), where $f
477 -- is the supercombinator.  Hence instead of f being a free var,
478 -- {a,b,c} are.
479 --
480 -- Example
481 --      let
482 --         f a = ...y1..y2.....
483 --      in
484 --      let
485 --         g b = ...f...z...
486 --      in
487 --      ...
488 --
489 --  Here the free vars of g are {f,z}; but f will be lambda-lifted
490 --  with free vars {y1,y2}, so the "real~ free vars of g are {y1,y2,z}.
491
492 getFinalFreeVars :: IdSet -> LiftM IdSet
493
494 getFinalFreeVars free_vars mod ci us idenv
495   = unionManyIdSets (map munge_it (idSetToList free_vars))
496   where
497     munge_it :: Id -> IdSet     -- Takes a free var and maps it to the "real"
498                                 -- free var
499     munge_it id = case (lookupIdEnv idenv id) of
500                     Just (_, args) -> mkIdSet args
501                     Nothing        -> unitIdSet id
502 \end{code}
503
504
505 %************************************************************************
506 %*                                                                      *
507 \subsection[Lift-info]{The LiftInfo type}
508 %*                                                                      *
509 %************************************************************************
510
511 \begin{code}
512 type LiftInfo = Bag StgBinding  -- Float to top
513
514 emptyLiftInfo = emptyBag
515
516 unionLiftInfo :: LiftInfo -> LiftInfo -> LiftInfo
517 unionLiftInfo binds1 binds2 = binds1 `unionBags` binds2
518
519 unionLiftInfos :: [LiftInfo] -> LiftInfo
520 unionLiftInfos infos = foldr unionLiftInfo emptyLiftInfo infos
521
522 mkScInfo :: StgBinding -> LiftInfo
523 mkScInfo bind = unitBag bind
524
525 nonRecScBind :: LiftInfo                -- From body of supercombinator
526              -> (Id, StgRhs)    -- Supercombinator and its rhs
527              -> LiftInfo
528 nonRecScBind binds (sc_id,sc_rhs) = binds `snocBag` (StgNonRec sc_id sc_rhs)
529
530
531 -- In the recursive case, all the SCs from the RHSs of the recursive group
532 -- are dealing with might potentially mention the new, recursive SCs.
533 -- So we flatten the whole lot into a single recursive group.
534
535 recScBind :: LiftInfo                   -- From body of supercombinator
536            -> [(Id,StgRhs)]     -- Supercombinator rhs
537            -> LiftInfo
538
539 recScBind binds pairs = unitBag (co_rec_ify (StgRec pairs : bagToList binds))
540
541 co_rec_ify :: [StgBinding] -> StgBinding
542 co_rec_ify binds = StgRec (concat (map f binds))
543   where
544     f (StgNonRec id rhs) = [(id,rhs)]
545     f (StgRec pairs)     = pairs
546
547
548 getScBinds :: LiftInfo -> [StgBinding]
549 getScBinds binds = bagToList binds
550
551 looksLikeSATRhs [(f,StgRhsClosure _ _ _ _ ls _)] (StgApp (StgVarArg f') args _)
552   = (f == f') && (length args == length ls)
553 looksLikeSATRhs _ _ = False
554 \end{code}