[project @ 1999-01-23 17:34:37 by sof]
[ghc-hetmet.git] / ghc / compiler / simplStg / LambdaLift.lhs
1 %
2 % (c) The AQUA Project, Glasgow University, 1994-1998
3 %
4 \section[LambdaLift]{A STG-code lambda lifter}
5
6 \begin{code}
7 module LambdaLift ( liftProgram ) where
8
9 #include "HsVersions.h"
10
11 import StgSyn
12
13 import Bag              ( Bag, emptyBag, unionBags, unitBag, snocBag, bagToList )
14 import Id               ( mkUserId, idType, setIdArity, Id )
15 import VarSet
16 import VarEnv
17 import IdInfo           ( exactArity )
18 import Name             ( Module, mkTopName )
19 import Type             ( splitForAllTys, mkForAllTys, mkFunTys, Type )
20 import UniqSupply       ( uniqFromSupply, splitUniqSupply, UniqSupply )
21 import Util             ( zipEqual )
22 import Panic            ( panic, assertPanic )
23 \end{code}
24
25 This is the lambda lifter.  It turns lambda abstractions into
26 supercombinators on a selective basis:
27
28 * Let-no-escaped bindings are never lifted. That's one major reason
29   why the lambda lifter is done in STG.
30
31 * Non-recursive bindings whose RHS is a lambda abstractions are lifted,
32   provided all the occurrences of the bound variable is in a function
33   postition.  In this example, f will be lifted:
34
35         let
36           f = \x -> e
37         in
38         ..(f a1)...(f a2)...
39   thus
40
41     $f p q r x = e      -- Supercombinator
42
43         ..($f p q r a1)...($f p q r a2)...
44
45   NOTE that the original binding is eliminated.
46
47   But in this case, f won't be lifted:
48
49         let
50           f = \x -> e
51         in
52         ..(g f)...(f a2)...
53
54   Why? Because we have to heap-allocate a closure for f thus:
55
56     $f p q r x = e      -- Supercombinator
57
58         let
59           f = $f p q r
60         in
61         ..(g f)...($f p q r a2)..
62
63   so it might as well be the original lambda abstraction.
64
65   We also do not lift if the function has an occurrence with no arguments, e.g.
66
67         let
68           f = \x -> e
69         in f
70
71   as this form is more efficient than if we create a partial application
72
73   $f p q r x = e      -- Supercombinator
74
75         f p q r
76
77 * Recursive bindings *all* of whose RHSs are lambda abstractions are
78   lifted iff
79         - all the occurrences of all the binders are in a function position
80         - there aren't ``too many'' free variables.
81
82   Same reasoning as before for the function-position stuff.  The ``too many
83   free variable'' part comes from considering the (potentially many)
84   recursive calls, which may now have lots of free vars.
85
86 Recent Observations:
87
88 * 2 might be already ``too many'' variables to abstract.
89   The problem is that the increase in the number of free variables
90   of closures refering to the lifted function (which is always # of
91   abstracted args - 1) may increase heap allocation a lot.
92   Expeiments are being done to check this...
93
94 * We do not lambda lift if the function has at least one occurrence
95   without any arguments. This caused lots of problems. Ex:
96   h = \ x -> ... let y = ...
97                  in let let f = \x -> ...y...
98                     in f
99   ==>
100   f = \y x -> ...y...
101   h = \ x -> ... let y = ...
102                  in f y
103
104   now f y is a partial application, so it will be updated, and this
105   is Bad.
106
107
108 --- NOT RELEVANT FOR STG ----
109 * All ``lone'' lambda abstractions are lifted.  Notably this means lambda
110   abstractions:
111         - in a case alternative: case e of True -> (\x->b)
112         - in the body of a let:  let x=e in (\y->b)
113 -----------------------------
114
115 %************************************************************************
116 %*                                                                      *
117 \subsection[Lift-expressions]{The main function: liftExpr}
118 %*                                                                      *
119 %************************************************************************
120
121 \begin{code}
122 liftProgram :: Module -> UniqSupply -> [StgBinding] -> [StgBinding]
123 liftProgram mod us prog = concat (runLM mod Nothing us (mapLM liftTopBind prog))
124
125
126 liftTopBind :: StgBinding -> LiftM [StgBinding]
127 liftTopBind (StgNonRec id rhs)
128   = dontLiftRhs rhs             `thenLM` \ (rhs', rhs_info) ->
129     returnLM (getScBinds rhs_info ++ [StgNonRec id rhs'])
130
131 liftTopBind (StgRec pairs)
132   = mapAndUnzipLM dontLiftRhs rhss      `thenLM` \ (rhss', rhs_infos) ->
133     returnLM ([co_rec_ify (StgRec (ids `zip` rhss') :
134                            getScBinds (unionLiftInfos rhs_infos))
135              ])
136   where
137    (ids, rhss) = unzip pairs
138 \end{code}
139
140
141 \begin{code}
142 liftExpr :: StgExpr
143          -> LiftM (StgExpr, LiftInfo)
144
145
146 liftExpr expr@(StgCon con args _) = returnLM (expr, emptyLiftInfo)
147
148 liftExpr expr@(StgApp v args)
149   = lookUp v            `thenLM` \ ~(sc, sc_args) ->    -- NB the ~.  We don't want to
150                                                         -- poke these bindings too early!
151     returnLM (StgApp sc (map StgVarArg sc_args ++ args),
152               emptyLiftInfo)
153         -- The lvs field is probably wrong, but we reconstruct it
154         -- anyway following lambda lifting
155
156 liftExpr (StgCase scrut lv1 lv2 bndr srt alts)
157   = liftExpr scrut      `thenLM` \ (scrut', scrut_info) ->
158     lift_alts alts      `thenLM` \ (alts', alts_info) ->
159     returnLM (StgCase scrut' lv1 lv2 bndr srt alts', scrut_info `unionLiftInfo` alts_info)
160   where
161     lift_alts (StgAlgAlts ty alg_alts deflt)
162         = mapAndUnzipLM lift_alg_alt alg_alts   `thenLM` \ (alg_alts', alt_infos) ->
163           lift_deflt deflt                      `thenLM` \ (deflt', deflt_info) ->
164           returnLM (StgAlgAlts ty alg_alts' deflt', foldr unionLiftInfo deflt_info alt_infos)
165
166     lift_alts (StgPrimAlts ty prim_alts deflt)
167         = mapAndUnzipLM lift_prim_alt prim_alts `thenLM` \ (prim_alts', alt_infos) ->
168           lift_deflt deflt                      `thenLM` \ (deflt', deflt_info) ->
169           returnLM (StgPrimAlts ty prim_alts' deflt', foldr unionLiftInfo deflt_info alt_infos)
170
171     lift_alg_alt (con, args, use_mask, rhs)
172         = liftExpr rhs          `thenLM` \ (rhs', rhs_info) ->
173           returnLM ((con, args, use_mask, rhs'), rhs_info)
174
175     lift_prim_alt (lit, rhs)
176         = liftExpr rhs  `thenLM` \ (rhs', rhs_info) ->
177           returnLM ((lit, rhs'), rhs_info)
178
179     lift_deflt StgNoDefault = returnLM (StgNoDefault, emptyLiftInfo)
180     lift_deflt (StgBindDefault rhs)
181         = liftExpr rhs  `thenLM` \ (rhs', rhs_info) ->
182           returnLM (StgBindDefault rhs', rhs_info)
183 \end{code}
184
185 Now the interesting cases.  Let no escape isn't lifted.  We turn it
186 back into a let, to play safe, because we have to redo that pass after
187 lambda anyway.
188
189 \begin{code}
190 liftExpr (StgLetNoEscape _ _ (StgNonRec binder rhs) body)
191   = dontLiftRhs rhs     `thenLM` \ (rhs', rhs_info) ->
192     liftExpr body       `thenLM` \ (body', body_info) ->
193     returnLM (StgLet (StgNonRec binder rhs') body',
194               rhs_info `unionLiftInfo` body_info)
195
196 liftExpr (StgLetNoEscape _ _ (StgRec pairs) body)
197   = liftExpr body                       `thenLM` \ (body', body_info) ->
198     mapAndUnzipLM dontLiftRhs rhss      `thenLM` \ (rhss', rhs_infos) ->
199     returnLM (StgLet (StgRec (zipEqual "liftExpr" binders rhss')) body',
200               foldr unionLiftInfo body_info rhs_infos)
201   where
202    (binders,rhss) = unzip pairs
203 \end{code}
204
205 \begin{code}
206 liftExpr (StgLet (StgNonRec binder rhs) body)
207   | not (isLiftable rhs)
208   = dontLiftRhs rhs     `thenLM` \ (rhs', rhs_info) ->
209     liftExpr body       `thenLM` \ (body', body_info) ->
210     returnLM (StgLet (StgNonRec binder rhs') body',
211               rhs_info `unionLiftInfo` body_info)
212
213   | otherwise   -- It's a lambda
214   =     -- Do the body of the let
215     fixLM (\ ~(sc_inline, _, _) ->
216       addScInlines [binder] [sc_inline] (
217         liftExpr body
218       )                 `thenLM` \ (body', body_info) ->
219
220         -- Deal with the RHS
221       dontLiftRhs rhs           `thenLM` \ (rhs', rhs_info) ->
222
223         -- All occurrences in function position, so lambda lift
224       getFinalFreeVars (rhsFreeVars rhs)    `thenLM` \ final_free_vars ->
225
226       mkScPieces final_free_vars (binder,rhs')  `thenLM` \ (sc_inline, sc_bind) ->
227
228       returnLM (sc_inline,
229                 body',
230                 nonRecScBind rhs_info sc_bind `unionLiftInfo` body_info)
231
232     )                   `thenLM` \ (_, expr', final_info) ->
233
234     returnLM (expr', final_info)
235
236 liftExpr (StgLet (StgRec pairs) body)
237 --[Andre-testing]
238   | not (all isLiftableRec rhss)
239   = liftExpr body                       `thenLM` \ (body', body_info) ->
240     mapAndUnzipLM dontLiftRhs rhss      `thenLM` \ (rhss', rhs_infos) ->
241     returnLM (StgLet (StgRec (zipEqual "liftExpr2" binders rhss')) body',
242               foldr unionLiftInfo body_info rhs_infos)
243
244   | otherwise   -- All rhss are liftable
245   = -- Do the body of the let
246     fixLM (\ ~(sc_inlines, _, _) ->
247       addScInlines binders sc_inlines   (
248
249       liftExpr body                     `thenLM` \ (body', body_info) ->
250       mapAndUnzipLM dontLiftRhs rhss    `thenLM` \ (rhss', rhs_infos) ->
251       let
252         -- Find the free vars of all the rhss,
253         -- excluding the binders themselves.
254         rhs_free_vars = unionVarSets (map rhsFreeVars rhss)
255                         `minusVarSet`
256                         mkVarSet binders
257
258         rhs_info      = unionLiftInfos rhs_infos
259       in
260       getFinalFreeVars rhs_free_vars    `thenLM` \ final_free_vars ->
261
262       mapAndUnzipLM (mkScPieces final_free_vars) (binders `zip` rhss')
263                                         `thenLM` \ (sc_inlines, sc_pairs) ->
264       returnLM (sc_inlines,
265                 body',
266                 recScBind rhs_info sc_pairs `unionLiftInfo` body_info)
267
268     ))                  `thenLM` \ (_, expr', final_info) ->
269
270     returnLM (expr', final_info)
271   where
272     (binders,rhss)    = unzip pairs
273 \end{code}
274
275 \begin{code}
276 liftExpr (StgSCC cc expr)
277   = liftExpr expr `thenLM` \ (expr2, expr_info) ->
278     returnLM (StgSCC cc expr2, expr_info)
279 \end{code}
280
281 A binding is liftable if it's a *function* (args not null) and never
282 occurs in an argument position.
283
284 \begin{code}
285 isLiftable :: StgRhs -> Bool
286
287 isLiftable (StgRhsClosure _ (StgBinderInfo arg_occ _ _ _ unapplied_occ) _ fvs _ args _)
288
289   -- Experimental evidence suggests we should lift only if we will be
290   -- abstracting up to 4 fvs.
291
292   = if not (null args   ||      -- Not a function
293          unapplied_occ  ||      -- Has an occ with no args at all
294          arg_occ        ||      -- Occurs in arg position
295          length fvs > 4         -- Too many free variables
296         )
297     then {-trace ("LL: " ++ show (length fvs))-} True
298     else False
299 isLiftable other_rhs = False
300
301 isLiftableRec :: StgRhs -> Bool
302
303 -- this is just the same as for non-rec, except we only lift to
304 -- abstract up to 1 argument this avoids undoing Static Argument
305 -- Transformation work
306
307 {- Andre's longer comment about isLiftableRec: 1996/01:
308
309 A rec binding is "liftable" (according to our heuristics) if:
310 * It is a function,
311 * all occurrences have arguments,
312 * does not occur in an argument position and
313 * has up to *2* free variables (including the rec binding variable
314   itself!)
315
316 The point is: my experiments show that SAT is more important than LL.
317 Therefore if we still want to do LL, for *recursive* functions, we do
318 not want LL to undo what SAT did.  We do this by avoiding LL recursive
319 functions that have more than 2 fvs, since if this recursive function
320 was created by SAT (we don't know!), it would have at least 3 fvs: one
321 for the rec binding itself and 2 more for the static arguments (note:
322 this matches with the choice of performing SAT to have at least 2
323 static arguments, if we change things there we should change things
324 here).
325 -}
326
327 isLiftableRec (StgRhsClosure _ (StgBinderInfo arg_occ _ _ _ unapplied_occ) _ fvs _ args _)
328   = if not (null args   ||      -- Not a function
329          unapplied_occ  ||      -- Has an occ with no args at all
330          arg_occ        ||      -- Occurs in arg position
331          length fvs > 2         -- Too many free variables
332         )
333     then {-trace ("LLRec: " ++ show (length fvs))-} True
334     else False
335 isLiftableRec other_rhs = False
336
337 rhsFreeVars :: StgRhs -> IdSet
338 rhsFreeVars (StgRhsClosure _ _ _ fvs _ _ _) = mkVarSet fvs
339 rhsFreeVars other                         = panic "rhsFreeVars"
340 \end{code}
341
342 dontLiftRhs is like liftExpr, except that it does not lift a top-level
343 lambda abstraction.  It is used for the right-hand sides of
344 definitions where we've decided *not* to lift: for example, top-level
345 ones or mutually-recursive ones where not all are lambdas.
346
347 \begin{code}
348 dontLiftRhs :: StgRhs -> LiftM (StgRhs, LiftInfo)
349
350 dontLiftRhs rhs@(StgRhsCon cc v args) = returnLM (rhs, emptyLiftInfo)
351
352 dontLiftRhs (StgRhsClosure cc bi srt fvs upd args body)
353   = liftExpr body       `thenLM` \ (body', body_info) ->
354     returnLM (StgRhsClosure cc bi srt fvs upd args body', body_info)
355 \end{code}
356
357 \begin{code}
358 mkScPieces :: IdSet             -- Extra args for the supercombinator
359            -> (Id, StgRhs)      -- The processed RHS and original Id
360            -> LiftM ((Id,[Id]),         -- Replace abstraction with this;
361                                                 -- the set is its free vars
362                      (Id,StgRhs))       -- Binding for supercombinator
363
364 mkScPieces extra_arg_set (id, StgRhsClosure cc bi srt _ upd args body)
365   = ASSERT( n_args > 0 )
366         -- Construct the rhs of the supercombinator, and its Id
367     newSupercombinator sc_ty arity  `thenLM` \ sc_id ->
368     returnLM ((sc_id, extra_args), (sc_id, sc_rhs))
369   where
370     n_args     = length args
371     extra_args = varSetElems extra_arg_set
372     arity      = n_args + length extra_args
373
374         -- Construct the supercombinator type
375     type_of_original_id = idType id
376     extra_arg_tys       = map idType extra_args
377     (tyvars, rest)      = splitForAllTys type_of_original_id
378     sc_ty               = mkForAllTys tyvars (mkFunTys extra_arg_tys rest)
379
380     sc_rhs = StgRhsClosure cc bi srt [] upd (extra_args ++ args) body
381 \end{code}
382
383
384 %************************************************************************
385 %*                                                                      *
386 \subsection[Lift-monad]{The LiftM monad}
387 %*                                                                      *
388 %************************************************************************
389
390 The monad is used only to distribute global stuff, and the unique supply.
391
392 \begin{code}
393 type LiftM a =  Module 
394              -> LiftFlags
395              -> UniqSupply
396              -> (IdEnv                          -- Domain = candidates for lifting
397                        (Id,                     -- The supercombinator
398                         [Id])                   -- Args to apply it to
399                  )
400              -> a
401
402
403 type LiftFlags = Maybe Int      -- No of fvs reqd to float recursive
404                                 -- binding; Nothing == infinity
405
406
407 runLM :: Module -> LiftFlags -> UniqSupply -> LiftM a -> a
408 runLM mod flags us m = m mod flags us emptyVarEnv
409
410 thenLM :: LiftM a -> (a -> LiftM b) -> LiftM b
411 thenLM m k mod ci us idenv
412   = k (m mod ci us1 idenv) mod ci us2 idenv
413   where
414     (us1, us2) = splitUniqSupply us
415
416 returnLM :: a -> LiftM a
417 returnLM a mod ci us idenv = a
418
419 fixLM :: (a -> LiftM a) -> LiftM a
420 fixLM k mod ci us idenv = r
421                        where
422                          r = k r mod ci us idenv
423
424 mapLM :: (a -> LiftM b) -> [a] -> LiftM [b]
425 mapLM f [] = returnLM []
426 mapLM f (a:as) = f a            `thenLM` \ r ->
427                  mapLM f as     `thenLM` \ rs ->
428                  returnLM (r:rs)
429
430 mapAndUnzipLM :: (a -> LiftM (b,c)) -> [a] -> LiftM ([b],[c])
431 mapAndUnzipLM f []     = returnLM ([],[])
432 mapAndUnzipLM f (a:as) = f a                    `thenLM` \ (b,c) ->
433                          mapAndUnzipLM f as     `thenLM` \ (bs,cs) ->
434                          returnLM (b:bs, c:cs)
435 \end{code}
436
437 \begin{code}
438 newSupercombinator :: Type
439                    -> Int               -- Arity
440                    -> LiftM Id
441
442 newSupercombinator ty arity mod ci us idenv
443   = mkUserId (mkTopName uniq mod SLIT("_ll")) ty
444     `setIdArity` exactArity arity
445         -- ToDo: rm the setIdArity?  Just let subsequent stg-saturation pass do it?
446   where
447     uniq = uniqFromSupply us
448
449 lookUp :: Id -> LiftM (Id,[Id])
450 lookUp v mod ci us idenv
451   = case (lookupVarEnv idenv v) of
452       Just result -> result
453       Nothing     -> (v, [])
454
455 addScInlines :: [Id] -> [(Id,[Id])] -> LiftM a -> LiftM a
456 addScInlines ids values m mod ci us idenv
457   = m mod ci us idenv'
458   where
459     idenv' = extendVarEnvList idenv (ids `zip_lazy` values)
460
461     -- zip_lazy zips two things together but matches lazily on the
462     -- second argument.  This is important, because the ids are know here,
463     -- but the things they are bound to are decided only later
464     zip_lazy [] _           = []
465     zip_lazy (x:xs) ~(y:ys) = (x,y) : zip_lazy xs ys
466
467
468 -- The free vars reported by the free-var analyser will include
469 -- some ids, f, which are to be replaced by ($f a b c), where $f
470 -- is the supercombinator.  Hence instead of f being a free var,
471 -- {a,b,c} are.
472 --
473 -- Example
474 --      let
475 --         f a = ...y1..y2.....
476 --      in
477 --      let
478 --         g b = ...f...z...
479 --      in
480 --      ...
481 --
482 --  Here the free vars of g are {f,z}; but f will be lambda-lifted
483 --  with free vars {y1,y2}, so the "real~ free vars of g are {y1,y2,z}.
484
485 getFinalFreeVars :: IdSet -> LiftM IdSet
486
487 getFinalFreeVars free_vars mod ci us idenv
488   = unionVarSets (map munge_it (varSetElems free_vars))
489   where
490     munge_it :: Id -> IdSet     -- Takes a free var and maps it to the "real"
491                                 -- free var
492     munge_it id = case (lookupVarEnv idenv id) of
493                     Just (_, args) -> mkVarSet args
494                     Nothing        -> unitVarSet id
495 \end{code}
496
497
498 %************************************************************************
499 %*                                                                      *
500 \subsection[Lift-info]{The LiftInfo type}
501 %*                                                                      *
502 %************************************************************************
503
504 \begin{code}
505 type LiftInfo = Bag StgBinding  -- Float to top
506
507 emptyLiftInfo = emptyBag
508
509 unionLiftInfo :: LiftInfo -> LiftInfo -> LiftInfo
510 unionLiftInfo binds1 binds2 = binds1 `unionBags` binds2
511
512 unionLiftInfos :: [LiftInfo] -> LiftInfo
513 unionLiftInfos infos = foldr unionLiftInfo emptyLiftInfo infos
514
515 mkScInfo :: StgBinding -> LiftInfo
516 mkScInfo bind = unitBag bind
517
518 nonRecScBind :: LiftInfo                -- From body of supercombinator
519              -> (Id, StgRhs)    -- Supercombinator and its rhs
520              -> LiftInfo
521 nonRecScBind binds (sc_id,sc_rhs) = binds `snocBag` (StgNonRec sc_id sc_rhs)
522
523
524 -- In the recursive case, all the SCs from the RHSs of the recursive group
525 -- are dealing with might potentially mention the new, recursive SCs.
526 -- So we flatten the whole lot into a single recursive group.
527
528 recScBind :: LiftInfo                   -- From body of supercombinator
529            -> [(Id,StgRhs)]     -- Supercombinator rhs
530            -> LiftInfo
531
532 recScBind binds pairs = unitBag (co_rec_ify (StgRec pairs : bagToList binds))
533
534 co_rec_ify :: [StgBinding] -> StgBinding
535 co_rec_ify binds = StgRec (concat (map f binds))
536   where
537     f (StgNonRec id rhs) = [(id,rhs)]
538     f (StgRec pairs)     = pairs
539
540
541 getScBinds :: LiftInfo -> [StgBinding]
542 getScBinds binds = bagToList binds
543
544 looksLikeSATRhs [(f,StgRhsClosure _ _ _ _ _ ls _)] (StgApp f' args)
545   = (f == f') && (length args == length ls)
546 looksLikeSATRhs _ _ = False
547 \end{code}