[project @ 1997-03-14 07:52:06 by simonpj]
[ghc-hetmet.git] / ghc / compiler / simplStg / LambdaLift.lhs
1 %
2 % (c) The AQUA Project, Glasgow University, 1994-1996
3 %
4 \section[LambdaLift]{A STG-code lambda lifter}
5
6 \begin{code}
7 #include "HsVersions.h"
8
9 module LambdaLift ( liftProgram ) where
10
11 IMP_Ubiq(){-uitous-}
12
13 import StgSyn
14
15 import Bag              ( emptyBag, unionBags, unitBag, snocBag, bagToList )
16 import Id               ( idType, mkSysLocal, addIdArity, 
17                           mkIdSet, unitIdSet, minusIdSet, setIdVisibility,
18                           unionManyIdSets, idSetToList, SYN_IE(IdSet),
19                           nullIdEnv, growIdEnvList, lookupIdEnv, SYN_IE(IdEnv)
20                         )
21 import IdInfo           ( ArityInfo, exactArity )
22 import SrcLoc           ( noSrcLoc )
23 import Type             ( splitForAllTy, mkForAllTys, mkFunTys )
24 import UniqSupply       ( getUnique, splitUniqSupply )
25 import Util             ( zipEqual, panic, assertPanic )
26 \end{code}
27
28 This is the lambda lifter.  It turns lambda abstractions into
29 supercombinators on a selective basis:
30
31 * Let-no-escaped bindings are never lifted. That's one major reason
32   why the lambda lifter is done in STG.
33
34 * Non-recursive bindings whose RHS is a lambda abstractions are lifted,
35   provided all the occurrences of the bound variable is in a function
36   postition.  In this example, f will be lifted:
37
38         let
39           f = \x -> e
40         in
41         ..(f a1)...(f a2)...
42   thus
43
44     $f p q r x = e      -- Supercombinator
45
46         ..($f p q r a1)...($f p q r a2)...
47
48   NOTE that the original binding is eliminated.
49
50   But in this case, f won't be lifted:
51
52         let
53           f = \x -> e
54         in
55         ..(g f)...(f a2)...
56
57   Why? Because we have to heap-allocate a closure for f thus:
58
59     $f p q r x = e      -- Supercombinator
60
61         let
62           f = $f p q r
63         in
64         ..(g f)...($f p q r a2)..
65
66   so it might as well be the original lambda abstraction.
67
68   We also do not lift if the function has an occurrence with no arguments, e.g.
69
70         let
71           f = \x -> e
72         in f
73
74   as this form is more efficient than if we create a partial application
75
76   $f p q r x = e      -- Supercombinator
77
78         f p q r
79
80 * Recursive bindings *all* of whose RHSs are lambda abstractions are
81   lifted iff
82         - all the occurrences of all the binders are in a function position
83         - there aren't ``too many'' free variables.
84
85   Same reasoning as before for the function-position stuff.  The ``too many
86   free variable'' part comes from considering the (potentially many)
87   recursive calls, which may now have lots of free vars.
88
89 Recent Observations:
90
91 * 2 might be already ``too many'' variables to abstract.
92   The problem is that the increase in the number of free variables
93   of closures refering to the lifted function (which is always # of
94   abstracted args - 1) may increase heap allocation a lot.
95   Expeiments are being done to check this...
96
97 * We do not lambda lift if the function has at least one occurrence
98   without any arguments. This caused lots of problems. Ex:
99   h = \ x -> ... let y = ...
100                  in let let f = \x -> ...y...
101                     in f
102   ==>
103   f = \y x -> ...y...
104   h = \ x -> ... let y = ...
105                  in f y
106
107   now f y is a partial application, so it will be updated, and this
108   is Bad.
109
110
111 --- NOT RELEVANT FOR STG ----
112 * All ``lone'' lambda abstractions are lifted.  Notably this means lambda
113   abstractions:
114         - in a case alternative: case e of True -> (\x->b)
115         - in the body of a let:  let x=e in (\y->b)
116 -----------------------------
117
118 %************************************************************************
119 %*                                                                      *
120 \subsection[Lift-expressions]{The main function: liftExpr}
121 %*                                                                      *
122 %************************************************************************
123
124 \begin{code}
125 liftProgram :: Module -> UniqSupply -> [StgBinding] -> [StgBinding]
126 liftProgram mod us prog = concat (runLM mod Nothing us (mapLM liftTopBind prog))
127
128
129 liftTopBind :: StgBinding -> LiftM [StgBinding]
130 liftTopBind (StgNonRec id rhs)
131   = dontLiftRhs rhs             `thenLM` \ (rhs', rhs_info) ->
132     returnLM (getScBinds rhs_info ++ [StgNonRec id rhs'])
133
134 liftTopBind (StgRec pairs)
135   = mapAndUnzipLM dontLiftRhs rhss      `thenLM` \ (rhss', rhs_infos) ->
136     returnLM ([co_rec_ify (StgRec (ids `zip` rhss') :
137                            getScBinds (unionLiftInfos rhs_infos))
138              ])
139   where
140    (ids, rhss) = unzip pairs
141 \end{code}
142
143
144 \begin{code}
145 liftExpr :: StgExpr
146          -> LiftM (StgExpr, LiftInfo)
147
148
149 liftExpr expr@(StgCon con args lvs) = returnLM (expr, emptyLiftInfo)
150 liftExpr expr@(StgPrim op args lvs) = returnLM (expr, emptyLiftInfo)
151
152 liftExpr expr@(StgApp (StgLitArg lit) args lvs) = returnLM (expr, emptyLiftInfo)
153 liftExpr expr@(StgApp (StgConArg con) args lvs) = returnLM (expr, emptyLiftInfo)
154 liftExpr expr@(StgApp (StgVarArg v)  args lvs)
155   = lookUp v            `thenLM` \ ~(sc, sc_args) ->    -- NB the ~.  We don't want to
156                                                         -- poke these bindings too early!
157     returnLM (StgApp (StgVarArg sc) (map StgVarArg sc_args ++ args) lvs,
158               emptyLiftInfo)
159         -- The lvs field is probably wrong, but we reconstruct it
160         -- anyway following lambda lifting
161
162 liftExpr (StgCase scrut lv1 lv2 uniq alts)
163   = liftExpr scrut      `thenLM` \ (scrut', scrut_info) ->
164     lift_alts alts      `thenLM` \ (alts', alts_info) ->
165     returnLM (StgCase scrut' lv1 lv2 uniq alts', scrut_info `unionLiftInfo` alts_info)
166   where
167     lift_alts (StgAlgAlts ty alg_alts deflt)
168         = mapAndUnzipLM lift_alg_alt alg_alts   `thenLM` \ (alg_alts', alt_infos) ->
169           lift_deflt deflt                      `thenLM` \ (deflt', deflt_info) ->
170           returnLM (StgAlgAlts ty alg_alts' deflt', foldr unionLiftInfo deflt_info alt_infos)
171
172     lift_alts (StgPrimAlts ty prim_alts deflt)
173         = mapAndUnzipLM lift_prim_alt prim_alts `thenLM` \ (prim_alts', alt_infos) ->
174           lift_deflt deflt                      `thenLM` \ (deflt', deflt_info) ->
175           returnLM (StgPrimAlts ty prim_alts' deflt', foldr unionLiftInfo deflt_info alt_infos)
176
177     lift_alg_alt (con, args, use_mask, rhs)
178         = liftExpr rhs          `thenLM` \ (rhs', rhs_info) ->
179           returnLM ((con, args, use_mask, rhs'), rhs_info)
180
181     lift_prim_alt (lit, rhs)
182         = liftExpr rhs  `thenLM` \ (rhs', rhs_info) ->
183           returnLM ((lit, rhs'), rhs_info)
184
185     lift_deflt StgNoDefault = returnLM (StgNoDefault, emptyLiftInfo)
186     lift_deflt (StgBindDefault var used rhs)
187         = liftExpr rhs  `thenLM` \ (rhs', rhs_info) ->
188           returnLM (StgBindDefault var used rhs', rhs_info)
189 \end{code}
190
191 Now the interesting cases.  Let no escape isn't lifted.  We turn it
192 back into a let, to play safe, because we have to redo that pass after
193 lambda anyway.
194
195 \begin{code}
196 liftExpr (StgLetNoEscape _ _ (StgNonRec binder rhs) body)
197   = dontLiftRhs rhs     `thenLM` \ (rhs', rhs_info) ->
198     liftExpr body       `thenLM` \ (body', body_info) ->
199     returnLM (StgLet (StgNonRec binder rhs') body',
200               rhs_info `unionLiftInfo` body_info)
201
202 liftExpr (StgLetNoEscape _ _ (StgRec pairs) body)
203   = liftExpr body                       `thenLM` \ (body', body_info) ->
204     mapAndUnzipLM dontLiftRhs rhss      `thenLM` \ (rhss', rhs_infos) ->
205     returnLM (StgLet (StgRec (zipEqual "liftExpr" binders rhss')) body',
206               foldr unionLiftInfo body_info rhs_infos)
207   where
208    (binders,rhss) = unzip pairs
209 \end{code}
210
211 \begin{code}
212 liftExpr (StgLet (StgNonRec binder rhs) body)
213   | not (isLiftable rhs)
214   = dontLiftRhs rhs     `thenLM` \ (rhs', rhs_info) ->
215     liftExpr body       `thenLM` \ (body', body_info) ->
216     returnLM (StgLet (StgNonRec binder rhs') body',
217               rhs_info `unionLiftInfo` body_info)
218
219   | otherwise   -- It's a lambda
220   =     -- Do the body of the let
221     fixLM (\ ~(sc_inline, _, _) ->
222       addScInlines [binder] [sc_inline] (
223         liftExpr body
224       )                 `thenLM` \ (body', body_info) ->
225
226         -- Deal with the RHS
227       dontLiftRhs rhs           `thenLM` \ (rhs', rhs_info) ->
228
229         -- All occurrences in function position, so lambda lift
230       getFinalFreeVars (rhsFreeVars rhs)    `thenLM` \ final_free_vars ->
231
232       mkScPieces final_free_vars (binder,rhs')  `thenLM` \ (sc_inline, sc_bind) ->
233
234       returnLM (sc_inline,
235                 body',
236                 nonRecScBind rhs_info sc_bind `unionLiftInfo` body_info)
237
238     )                   `thenLM` \ (_, expr', final_info) ->
239
240     returnLM (expr', final_info)
241
242 liftExpr (StgLet (StgRec pairs) body)
243 --[Andre-testing]
244   | not (all isLiftableRec rhss)
245   = liftExpr body                       `thenLM` \ (body', body_info) ->
246     mapAndUnzipLM dontLiftRhs rhss      `thenLM` \ (rhss', rhs_infos) ->
247     returnLM (StgLet (StgRec (zipEqual "liftExpr2" binders rhss')) body',
248               foldr unionLiftInfo body_info rhs_infos)
249
250   | otherwise   -- All rhss are liftable
251   = -- Do the body of the let
252     fixLM (\ ~(sc_inlines, _, _) ->
253       addScInlines binders sc_inlines   (
254
255       liftExpr body                     `thenLM` \ (body', body_info) ->
256       mapAndUnzipLM dontLiftRhs rhss    `thenLM` \ (rhss', rhs_infos) ->
257       let
258         -- Find the free vars of all the rhss,
259         -- excluding the binders themselves.
260         rhs_free_vars = unionManyIdSets (map rhsFreeVars rhss)
261                         `minusIdSet`
262                         mkIdSet binders
263
264         rhs_info      = unionLiftInfos rhs_infos
265       in
266       getFinalFreeVars rhs_free_vars    `thenLM` \ final_free_vars ->
267
268       mapAndUnzipLM (mkScPieces final_free_vars) (binders `zip` rhss')
269                                         `thenLM` \ (sc_inlines, sc_pairs) ->
270       returnLM (sc_inlines,
271                 body',
272                 recScBind rhs_info sc_pairs `unionLiftInfo` body_info)
273
274     ))                  `thenLM` \ (_, expr', final_info) ->
275
276     returnLM (expr', final_info)
277   where
278     (binders,rhss)    = unzip pairs
279 \end{code}
280
281 \begin{code}
282 liftExpr (StgSCC ty cc expr)
283   = liftExpr expr `thenLM` \ (expr2, expr_info) ->
284     returnLM (StgSCC ty cc expr2, expr_info)
285 \end{code}
286
287 A binding is liftable if it's a *function* (args not null) and never
288 occurs in an argument position.
289
290 \begin{code}
291 isLiftable :: StgRhs -> Bool
292
293 isLiftable (StgRhsClosure _ (StgBinderInfo arg_occ _ _ _ unapplied_occ) fvs _ args _)
294
295   -- Experimental evidence suggests we should lift only if we will be
296   -- abstracting up to 4 fvs.
297
298   = if not (null args   ||      -- Not a function
299          unapplied_occ  ||      -- Has an occ with no args at all
300          arg_occ        ||      -- Occurs in arg position
301          length fvs > 4         -- Too many free variables
302         )
303     then {-trace ("LL: " ++ show (length fvs))-} True
304     else False
305 isLiftable other_rhs = False
306
307 isLiftableRec :: StgRhs -> Bool
308
309 -- this is just the same as for non-rec, except we only lift to
310 -- abstract up to 1 argument this avoids undoing Static Argument
311 -- Transformation work
312
313 {- Andre's longer comment about isLiftableRec: 1996/01:
314
315 A rec binding is "liftable" (according to our heuristics) if:
316 * It is a function,
317 * all occurrences have arguments,
318 * does not occur in an argument position and
319 * has up to *2* free variables (including the rec binding variable
320   itself!)
321
322 The point is: my experiments show that SAT is more important than LL.
323 Therefore if we still want to do LL, for *recursive* functions, we do
324 not want LL to undo what SAT did.  We do this by avoiding LL recursive
325 functions that have more than 2 fvs, since if this recursive function
326 was created by SAT (we don't know!), it would have at least 3 fvs: one
327 for the rec binding itself and 2 more for the static arguments (note:
328 this matches with the choice of performing SAT to have at least 2
329 static arguments, if we change things there we should change things
330 here).
331 -}
332
333 isLiftableRec (StgRhsClosure _ (StgBinderInfo arg_occ _ _ _ unapplied_occ) fvs _ args _)
334   = if not (null args   ||      -- Not a function
335          unapplied_occ  ||      -- Has an occ with no args at all
336          arg_occ        ||      -- Occurs in arg position
337          length fvs > 2         -- Too many free variables
338         )
339     then {-trace ("LLRec: " ++ show (length fvs))-} True
340     else False
341 isLiftableRec other_rhs = False
342
343 rhsFreeVars :: StgRhs -> IdSet
344 rhsFreeVars (StgRhsClosure _ _ fvs _ _ _) = mkIdSet fvs
345 rhsFreeVars other                         = panic "rhsFreeVars"
346 \end{code}
347
348 dontLiftRhs is like liftExpr, except that it does not lift a top-level
349 lambda abstraction.  It is used for the right-hand sides of
350 definitions where we've decided *not* to lift: for example, top-level
351 ones or mutually-recursive ones where not all are lambdas.
352
353 \begin{code}
354 dontLiftRhs :: StgRhs -> LiftM (StgRhs, LiftInfo)
355
356 dontLiftRhs rhs@(StgRhsCon cc v args) = returnLM (rhs, emptyLiftInfo)
357
358 dontLiftRhs (StgRhsClosure cc bi fvs upd args body)
359   = liftExpr body       `thenLM` \ (body', body_info) ->
360     returnLM (StgRhsClosure cc bi fvs upd args body', body_info)
361 \end{code}
362
363 \begin{code}
364 mkScPieces :: IdSet             -- Extra args for the supercombinator
365            -> (Id, StgRhs)      -- The processed RHS and original Id
366            -> LiftM ((Id,[Id]),         -- Replace abstraction with this;
367                                                 -- the set is its free vars
368                      (Id,StgRhs))       -- Binding for supercombinator
369
370 mkScPieces extra_arg_set (id, StgRhsClosure cc bi _ upd args body)
371   = ASSERT( n_args > 0 )
372         -- Construct the rhs of the supercombinator, and its Id
373     newSupercombinator sc_ty arity  `thenLM` \ sc_id ->
374     returnLM ((sc_id, extra_args), (sc_id, sc_rhs))
375   where
376     n_args     = length args
377     extra_args = idSetToList extra_arg_set
378     arity      = n_args + length extra_args
379
380         -- Construct the supercombinator type
381     type_of_original_id = idType id
382     extra_arg_tys       = map idType extra_args
383     (tyvars, rest)      = splitForAllTy type_of_original_id
384     sc_ty               = mkForAllTys tyvars (mkFunTys extra_arg_tys rest)
385
386     sc_rhs = StgRhsClosure cc bi [] upd (extra_args ++ args) body
387 \end{code}
388
389
390 %************************************************************************
391 %*                                                                      *
392 \subsection[Lift-monad]{The LiftM monad}
393 %*                                                                      *
394 %************************************************************************
395
396 The monad is used only to distribute global stuff, and the unique supply.
397
398 \begin{code}
399 type LiftM a =  Module 
400              -> LiftFlags
401              -> UniqSupply
402              -> (IdEnv                          -- Domain = candidates for lifting
403                        (Id,                     -- The supercombinator
404                         [Id])                   -- Args to apply it to
405                  )
406              -> a
407
408
409 type LiftFlags = Maybe Int      -- No of fvs reqd to float recursive
410                                 -- binding; Nothing == infinity
411
412
413 runLM :: Module -> LiftFlags -> UniqSupply -> LiftM a -> a
414 runLM mod flags us m = m mod flags us nullIdEnv
415
416 thenLM :: LiftM a -> (a -> LiftM b) -> LiftM b
417 thenLM m k mod ci us idenv
418   = k (m mod ci us1 idenv) mod ci us2 idenv
419   where
420     (us1, us2) = splitUniqSupply us
421
422 returnLM :: a -> LiftM a
423 returnLM a mod ci us idenv = a
424
425 fixLM :: (a -> LiftM a) -> LiftM a
426 fixLM k mod ci us idenv = r
427                        where
428                          r = k r mod ci us idenv
429
430 mapLM :: (a -> LiftM b) -> [a] -> LiftM [b]
431 mapLM f [] = returnLM []
432 mapLM f (a:as) = f a            `thenLM` \ r ->
433                  mapLM f as     `thenLM` \ rs ->
434                  returnLM (r:rs)
435
436 mapAndUnzipLM :: (a -> LiftM (b,c)) -> [a] -> LiftM ([b],[c])
437 mapAndUnzipLM f []     = returnLM ([],[])
438 mapAndUnzipLM f (a:as) = f a                    `thenLM` \ (b,c) ->
439                          mapAndUnzipLM f as     `thenLM` \ (bs,cs) ->
440                          returnLM (b:bs, c:cs)
441 \end{code}
442
443 \begin{code}
444 newSupercombinator :: Type
445                    -> Int               -- Arity
446                    -> LiftM Id
447
448 newSupercombinator ty arity mod ci us idenv
449   = setIdVisibility mod (mkSysLocal SLIT("sc") uniq ty noSrcLoc)
450     `addIdArity` exactArity arity
451         -- ToDo: rm the addIdArity?  Just let subsequent stg-saturation pass do it?
452   where
453     uniq = getUnique us
454
455 lookUp :: Id -> LiftM (Id,[Id])
456 lookUp v mod ci us idenv
457   = case (lookupIdEnv idenv v) of
458       Just result -> result
459       Nothing     -> (v, [])
460
461 addScInlines :: [Id] -> [(Id,[Id])] -> LiftM a -> LiftM a
462 addScInlines ids values m mod ci us idenv
463   = m mod ci us idenv'
464   where
465     idenv' = growIdEnvList idenv (ids `zip_lazy` values)
466
467     -- zip_lazy zips two things together but matches lazily on the
468     -- second argument.  This is important, because the ids are know here,
469     -- but the things they are bound to are decided only later
470     zip_lazy [] _           = []
471     zip_lazy (x:xs) ~(y:ys) = (x,y) : zip_lazy xs ys
472
473
474 -- The free vars reported by the free-var analyser will include
475 -- some ids, f, which are to be replaced by ($f a b c), where $f
476 -- is the supercombinator.  Hence instead of f being a free var,
477 -- {a,b,c} are.
478 --
479 -- Example
480 --      let
481 --         f a = ...y1..y2.....
482 --      in
483 --      let
484 --         g b = ...f...z...
485 --      in
486 --      ...
487 --
488 --  Here the free vars of g are {f,z}; but f will be lambda-lifted
489 --  with free vars {y1,y2}, so the "real~ free vars of g are {y1,y2,z}.
490
491 getFinalFreeVars :: IdSet -> LiftM IdSet
492
493 getFinalFreeVars free_vars mod ci us idenv
494   = unionManyIdSets (map munge_it (idSetToList free_vars))
495   where
496     munge_it :: Id -> IdSet     -- Takes a free var and maps it to the "real"
497                                 -- free var
498     munge_it id = case (lookupIdEnv idenv id) of
499                     Just (_, args) -> mkIdSet args
500                     Nothing        -> unitIdSet id
501 \end{code}
502
503
504 %************************************************************************
505 %*                                                                      *
506 \subsection[Lift-info]{The LiftInfo type}
507 %*                                                                      *
508 %************************************************************************
509
510 \begin{code}
511 type LiftInfo = Bag StgBinding  -- Float to top
512
513 emptyLiftInfo = emptyBag
514
515 unionLiftInfo :: LiftInfo -> LiftInfo -> LiftInfo
516 unionLiftInfo binds1 binds2 = binds1 `unionBags` binds2
517
518 unionLiftInfos :: [LiftInfo] -> LiftInfo
519 unionLiftInfos infos = foldr unionLiftInfo emptyLiftInfo infos
520
521 mkScInfo :: StgBinding -> LiftInfo
522 mkScInfo bind = unitBag bind
523
524 nonRecScBind :: LiftInfo                -- From body of supercombinator
525              -> (Id, StgRhs)    -- Supercombinator and its rhs
526              -> LiftInfo
527 nonRecScBind binds (sc_id,sc_rhs) = binds `snocBag` (StgNonRec sc_id sc_rhs)
528
529
530 -- In the recursive case, all the SCs from the RHSs of the recursive group
531 -- are dealing with might potentially mention the new, recursive SCs.
532 -- So we flatten the whole lot into a single recursive group.
533
534 recScBind :: LiftInfo                   -- From body of supercombinator
535            -> [(Id,StgRhs)]     -- Supercombinator rhs
536            -> LiftInfo
537
538 recScBind binds pairs = unitBag (co_rec_ify (StgRec pairs : bagToList binds))
539
540 co_rec_ify :: [StgBinding] -> StgBinding
541 co_rec_ify binds = StgRec (concat (map f binds))
542   where
543     f (StgNonRec id rhs) = [(id,rhs)]
544     f (StgRec pairs)     = pairs
545
546
547 getScBinds :: LiftInfo -> [StgBinding]
548 getScBinds binds = bagToList binds
549
550 looksLikeSATRhs [(f,StgRhsClosure _ _ _ _ ls _)] (StgApp (StgVarArg f') args _)
551   = (f == f') && (length args == length ls)
552 looksLikeSATRhs _ _ = False
553 \end{code}