38967fe78199ba80e5983d43a91bd8d9b99450ea
[ghc-hetmet.git] / ghc / compiler / simplStg / LambdaLift.lhs
1 %
2 % (c) The AQUA Project, Glasgow University, 1994-1996
3 %
4 \section[LambdaLift]{A STG-code lambda lifter}
5
6 \begin{code}
7 #include "HsVersions.h"
8
9 module LambdaLift ( liftProgram ) where
10
11 IMP_Ubiq(){-uitous-}
12
13 import StgSyn
14
15 import Bag              ( Bag, emptyBag, unionBags, unitBag, snocBag, bagToList )
16 import Id               ( idType, mkSysLocal, addIdArity, 
17                           mkIdSet, unitIdSet, minusIdSet, setIdVisibility,
18                           unionManyIdSets, idSetToList, SYN_IE(IdSet),
19                           nullIdEnv, growIdEnvList, lookupIdEnv, SYN_IE(IdEnv),
20                           SYN_IE(Id)
21                         )
22 import IdInfo           ( ArityInfo, exactArity )
23 import Name             ( SYN_IE(Module) )
24 import SrcLoc           ( noSrcLoc )
25 import Type             ( splitForAllTy, mkForAllTys, mkFunTys, SYN_IE(Type) )
26 import UniqSupply       ( getUnique, splitUniqSupply, UniqSupply )
27 import Util             ( zipEqual, panic, assertPanic )
28 \end{code}
29
30 This is the lambda lifter.  It turns lambda abstractions into
31 supercombinators on a selective basis:
32
33 * Let-no-escaped bindings are never lifted. That's one major reason
34   why the lambda lifter is done in STG.
35
36 * Non-recursive bindings whose RHS is a lambda abstractions are lifted,
37   provided all the occurrences of the bound variable is in a function
38   postition.  In this example, f will be lifted:
39
40         let
41           f = \x -> e
42         in
43         ..(f a1)...(f a2)...
44   thus
45
46     $f p q r x = e      -- Supercombinator
47
48         ..($f p q r a1)...($f p q r a2)...
49
50   NOTE that the original binding is eliminated.
51
52   But in this case, f won't be lifted:
53
54         let
55           f = \x -> e
56         in
57         ..(g f)...(f a2)...
58
59   Why? Because we have to heap-allocate a closure for f thus:
60
61     $f p q r x = e      -- Supercombinator
62
63         let
64           f = $f p q r
65         in
66         ..(g f)...($f p q r a2)..
67
68   so it might as well be the original lambda abstraction.
69
70   We also do not lift if the function has an occurrence with no arguments, e.g.
71
72         let
73           f = \x -> e
74         in f
75
76   as this form is more efficient than if we create a partial application
77
78   $f p q r x = e      -- Supercombinator
79
80         f p q r
81
82 * Recursive bindings *all* of whose RHSs are lambda abstractions are
83   lifted iff
84         - all the occurrences of all the binders are in a function position
85         - there aren't ``too many'' free variables.
86
87   Same reasoning as before for the function-position stuff.  The ``too many
88   free variable'' part comes from considering the (potentially many)
89   recursive calls, which may now have lots of free vars.
90
91 Recent Observations:
92
93 * 2 might be already ``too many'' variables to abstract.
94   The problem is that the increase in the number of free variables
95   of closures refering to the lifted function (which is always # of
96   abstracted args - 1) may increase heap allocation a lot.
97   Expeiments are being done to check this...
98
99 * We do not lambda lift if the function has at least one occurrence
100   without any arguments. This caused lots of problems. Ex:
101   h = \ x -> ... let y = ...
102                  in let let f = \x -> ...y...
103                     in f
104   ==>
105   f = \y x -> ...y...
106   h = \ x -> ... let y = ...
107                  in f y
108
109   now f y is a partial application, so it will be updated, and this
110   is Bad.
111
112
113 --- NOT RELEVANT FOR STG ----
114 * All ``lone'' lambda abstractions are lifted.  Notably this means lambda
115   abstractions:
116         - in a case alternative: case e of True -> (\x->b)
117         - in the body of a let:  let x=e in (\y->b)
118 -----------------------------
119
120 %************************************************************************
121 %*                                                                      *
122 \subsection[Lift-expressions]{The main function: liftExpr}
123 %*                                                                      *
124 %************************************************************************
125
126 \begin{code}
127 liftProgram :: Module -> UniqSupply -> [StgBinding] -> [StgBinding]
128 liftProgram mod us prog = concat (runLM mod Nothing us (mapLM liftTopBind prog))
129
130
131 liftTopBind :: StgBinding -> LiftM [StgBinding]
132 liftTopBind (StgNonRec id rhs)
133   = dontLiftRhs rhs             `thenLM` \ (rhs', rhs_info) ->
134     returnLM (getScBinds rhs_info ++ [StgNonRec id rhs'])
135
136 liftTopBind (StgRec pairs)
137   = mapAndUnzipLM dontLiftRhs rhss      `thenLM` \ (rhss', rhs_infos) ->
138     returnLM ([co_rec_ify (StgRec (ids `zip` rhss') :
139                            getScBinds (unionLiftInfos rhs_infos))
140              ])
141   where
142    (ids, rhss) = unzip pairs
143 \end{code}
144
145
146 \begin{code}
147 liftExpr :: StgExpr
148          -> LiftM (StgExpr, LiftInfo)
149
150
151 liftExpr expr@(StgCon con args lvs) = returnLM (expr, emptyLiftInfo)
152 liftExpr expr@(StgPrim op args lvs) = returnLM (expr, emptyLiftInfo)
153
154 liftExpr expr@(StgApp (StgLitArg lit) args lvs) = returnLM (expr, emptyLiftInfo)
155 liftExpr expr@(StgApp (StgConArg con) args lvs) = returnLM (expr, emptyLiftInfo)
156 liftExpr expr@(StgApp (StgVarArg v)  args lvs)
157   = lookUp v            `thenLM` \ ~(sc, sc_args) ->    -- NB the ~.  We don't want to
158                                                         -- poke these bindings too early!
159     returnLM (StgApp (StgVarArg sc) (map StgVarArg sc_args ++ args) lvs,
160               emptyLiftInfo)
161         -- The lvs field is probably wrong, but we reconstruct it
162         -- anyway following lambda lifting
163
164 liftExpr (StgCase scrut lv1 lv2 uniq alts)
165   = liftExpr scrut      `thenLM` \ (scrut', scrut_info) ->
166     lift_alts alts      `thenLM` \ (alts', alts_info) ->
167     returnLM (StgCase scrut' lv1 lv2 uniq alts', scrut_info `unionLiftInfo` alts_info)
168   where
169     lift_alts (StgAlgAlts ty alg_alts deflt)
170         = mapAndUnzipLM lift_alg_alt alg_alts   `thenLM` \ (alg_alts', alt_infos) ->
171           lift_deflt deflt                      `thenLM` \ (deflt', deflt_info) ->
172           returnLM (StgAlgAlts ty alg_alts' deflt', foldr unionLiftInfo deflt_info alt_infos)
173
174     lift_alts (StgPrimAlts ty prim_alts deflt)
175         = mapAndUnzipLM lift_prim_alt prim_alts `thenLM` \ (prim_alts', alt_infos) ->
176           lift_deflt deflt                      `thenLM` \ (deflt', deflt_info) ->
177           returnLM (StgPrimAlts ty prim_alts' deflt', foldr unionLiftInfo deflt_info alt_infos)
178
179     lift_alg_alt (con, args, use_mask, rhs)
180         = liftExpr rhs          `thenLM` \ (rhs', rhs_info) ->
181           returnLM ((con, args, use_mask, rhs'), rhs_info)
182
183     lift_prim_alt (lit, rhs)
184         = liftExpr rhs  `thenLM` \ (rhs', rhs_info) ->
185           returnLM ((lit, rhs'), rhs_info)
186
187     lift_deflt StgNoDefault = returnLM (StgNoDefault, emptyLiftInfo)
188     lift_deflt (StgBindDefault var used rhs)
189         = liftExpr rhs  `thenLM` \ (rhs', rhs_info) ->
190           returnLM (StgBindDefault var used rhs', rhs_info)
191 \end{code}
192
193 Now the interesting cases.  Let no escape isn't lifted.  We turn it
194 back into a let, to play safe, because we have to redo that pass after
195 lambda anyway.
196
197 \begin{code}
198 liftExpr (StgLetNoEscape _ _ (StgNonRec binder rhs) body)
199   = dontLiftRhs rhs     `thenLM` \ (rhs', rhs_info) ->
200     liftExpr body       `thenLM` \ (body', body_info) ->
201     returnLM (StgLet (StgNonRec binder rhs') body',
202               rhs_info `unionLiftInfo` body_info)
203
204 liftExpr (StgLetNoEscape _ _ (StgRec pairs) body)
205   = liftExpr body                       `thenLM` \ (body', body_info) ->
206     mapAndUnzipLM dontLiftRhs rhss      `thenLM` \ (rhss', rhs_infos) ->
207     returnLM (StgLet (StgRec (zipEqual "liftExpr" binders rhss')) body',
208               foldr unionLiftInfo body_info rhs_infos)
209   where
210    (binders,rhss) = unzip pairs
211 \end{code}
212
213 \begin{code}
214 liftExpr (StgLet (StgNonRec binder rhs) body)
215   | not (isLiftable rhs)
216   = dontLiftRhs rhs     `thenLM` \ (rhs', rhs_info) ->
217     liftExpr body       `thenLM` \ (body', body_info) ->
218     returnLM (StgLet (StgNonRec binder rhs') body',
219               rhs_info `unionLiftInfo` body_info)
220
221   | otherwise   -- It's a lambda
222   =     -- Do the body of the let
223     fixLM (\ ~(sc_inline, _, _) ->
224       addScInlines [binder] [sc_inline] (
225         liftExpr body
226       )                 `thenLM` \ (body', body_info) ->
227
228         -- Deal with the RHS
229       dontLiftRhs rhs           `thenLM` \ (rhs', rhs_info) ->
230
231         -- All occurrences in function position, so lambda lift
232       getFinalFreeVars (rhsFreeVars rhs)    `thenLM` \ final_free_vars ->
233
234       mkScPieces final_free_vars (binder,rhs')  `thenLM` \ (sc_inline, sc_bind) ->
235
236       returnLM (sc_inline,
237                 body',
238                 nonRecScBind rhs_info sc_bind `unionLiftInfo` body_info)
239
240     )                   `thenLM` \ (_, expr', final_info) ->
241
242     returnLM (expr', final_info)
243
244 liftExpr (StgLet (StgRec pairs) body)
245 --[Andre-testing]
246   | not (all isLiftableRec rhss)
247   = liftExpr body                       `thenLM` \ (body', body_info) ->
248     mapAndUnzipLM dontLiftRhs rhss      `thenLM` \ (rhss', rhs_infos) ->
249     returnLM (StgLet (StgRec (zipEqual "liftExpr2" binders rhss')) body',
250               foldr unionLiftInfo body_info rhs_infos)
251
252   | otherwise   -- All rhss are liftable
253   = -- Do the body of the let
254     fixLM (\ ~(sc_inlines, _, _) ->
255       addScInlines binders sc_inlines   (
256
257       liftExpr body                     `thenLM` \ (body', body_info) ->
258       mapAndUnzipLM dontLiftRhs rhss    `thenLM` \ (rhss', rhs_infos) ->
259       let
260         -- Find the free vars of all the rhss,
261         -- excluding the binders themselves.
262         rhs_free_vars = unionManyIdSets (map rhsFreeVars rhss)
263                         `minusIdSet`
264                         mkIdSet binders
265
266         rhs_info      = unionLiftInfos rhs_infos
267       in
268       getFinalFreeVars rhs_free_vars    `thenLM` \ final_free_vars ->
269
270       mapAndUnzipLM (mkScPieces final_free_vars) (binders `zip` rhss')
271                                         `thenLM` \ (sc_inlines, sc_pairs) ->
272       returnLM (sc_inlines,
273                 body',
274                 recScBind rhs_info sc_pairs `unionLiftInfo` body_info)
275
276     ))                  `thenLM` \ (_, expr', final_info) ->
277
278     returnLM (expr', final_info)
279   where
280     (binders,rhss)    = unzip pairs
281 \end{code}
282
283 \begin{code}
284 liftExpr (StgSCC ty cc expr)
285   = liftExpr expr `thenLM` \ (expr2, expr_info) ->
286     returnLM (StgSCC ty cc expr2, expr_info)
287 \end{code}
288
289 A binding is liftable if it's a *function* (args not null) and never
290 occurs in an argument position.
291
292 \begin{code}
293 isLiftable :: StgRhs -> Bool
294
295 isLiftable (StgRhsClosure _ (StgBinderInfo arg_occ _ _ _ unapplied_occ) fvs _ args _)
296
297   -- Experimental evidence suggests we should lift only if we will be
298   -- abstracting up to 4 fvs.
299
300   = if not (null args   ||      -- Not a function
301          unapplied_occ  ||      -- Has an occ with no args at all
302          arg_occ        ||      -- Occurs in arg position
303          length fvs > 4         -- Too many free variables
304         )
305     then {-trace ("LL: " ++ show (length fvs))-} True
306     else False
307 isLiftable other_rhs = False
308
309 isLiftableRec :: StgRhs -> Bool
310
311 -- this is just the same as for non-rec, except we only lift to
312 -- abstract up to 1 argument this avoids undoing Static Argument
313 -- Transformation work
314
315 {- Andre's longer comment about isLiftableRec: 1996/01:
316
317 A rec binding is "liftable" (according to our heuristics) if:
318 * It is a function,
319 * all occurrences have arguments,
320 * does not occur in an argument position and
321 * has up to *2* free variables (including the rec binding variable
322   itself!)
323
324 The point is: my experiments show that SAT is more important than LL.
325 Therefore if we still want to do LL, for *recursive* functions, we do
326 not want LL to undo what SAT did.  We do this by avoiding LL recursive
327 functions that have more than 2 fvs, since if this recursive function
328 was created by SAT (we don't know!), it would have at least 3 fvs: one
329 for the rec binding itself and 2 more for the static arguments (note:
330 this matches with the choice of performing SAT to have at least 2
331 static arguments, if we change things there we should change things
332 here).
333 -}
334
335 isLiftableRec (StgRhsClosure _ (StgBinderInfo arg_occ _ _ _ unapplied_occ) fvs _ args _)
336   = if not (null args   ||      -- Not a function
337          unapplied_occ  ||      -- Has an occ with no args at all
338          arg_occ        ||      -- Occurs in arg position
339          length fvs > 2         -- Too many free variables
340         )
341     then {-trace ("LLRec: " ++ show (length fvs))-} True
342     else False
343 isLiftableRec other_rhs = False
344
345 rhsFreeVars :: StgRhs -> IdSet
346 rhsFreeVars (StgRhsClosure _ _ fvs _ _ _) = mkIdSet fvs
347 rhsFreeVars other                         = panic "rhsFreeVars"
348 \end{code}
349
350 dontLiftRhs is like liftExpr, except that it does not lift a top-level
351 lambda abstraction.  It is used for the right-hand sides of
352 definitions where we've decided *not* to lift: for example, top-level
353 ones or mutually-recursive ones where not all are lambdas.
354
355 \begin{code}
356 dontLiftRhs :: StgRhs -> LiftM (StgRhs, LiftInfo)
357
358 dontLiftRhs rhs@(StgRhsCon cc v args) = returnLM (rhs, emptyLiftInfo)
359
360 dontLiftRhs (StgRhsClosure cc bi fvs upd args body)
361   = liftExpr body       `thenLM` \ (body', body_info) ->
362     returnLM (StgRhsClosure cc bi fvs upd args body', body_info)
363 \end{code}
364
365 \begin{code}
366 mkScPieces :: IdSet             -- Extra args for the supercombinator
367            -> (Id, StgRhs)      -- The processed RHS and original Id
368            -> LiftM ((Id,[Id]),         -- Replace abstraction with this;
369                                                 -- the set is its free vars
370                      (Id,StgRhs))       -- Binding for supercombinator
371
372 mkScPieces extra_arg_set (id, StgRhsClosure cc bi _ upd args body)
373   = ASSERT( n_args > 0 )
374         -- Construct the rhs of the supercombinator, and its Id
375     newSupercombinator sc_ty arity  `thenLM` \ sc_id ->
376     returnLM ((sc_id, extra_args), (sc_id, sc_rhs))
377   where
378     n_args     = length args
379     extra_args = idSetToList extra_arg_set
380     arity      = n_args + length extra_args
381
382         -- Construct the supercombinator type
383     type_of_original_id = idType id
384     extra_arg_tys       = map idType extra_args
385     (tyvars, rest)      = splitForAllTy type_of_original_id
386     sc_ty               = mkForAllTys tyvars (mkFunTys extra_arg_tys rest)
387
388     sc_rhs = StgRhsClosure cc bi [] upd (extra_args ++ args) body
389 \end{code}
390
391
392 %************************************************************************
393 %*                                                                      *
394 \subsection[Lift-monad]{The LiftM monad}
395 %*                                                                      *
396 %************************************************************************
397
398 The monad is used only to distribute global stuff, and the unique supply.
399
400 \begin{code}
401 type LiftM a =  Module 
402              -> LiftFlags
403              -> UniqSupply
404              -> (IdEnv                          -- Domain = candidates for lifting
405                        (Id,                     -- The supercombinator
406                         [Id])                   -- Args to apply it to
407                  )
408              -> a
409
410
411 type LiftFlags = Maybe Int      -- No of fvs reqd to float recursive
412                                 -- binding; Nothing == infinity
413
414
415 runLM :: Module -> LiftFlags -> UniqSupply -> LiftM a -> a
416 runLM mod flags us m = m mod flags us nullIdEnv
417
418 thenLM :: LiftM a -> (a -> LiftM b) -> LiftM b
419 thenLM m k mod ci us idenv
420   = k (m mod ci us1 idenv) mod ci us2 idenv
421   where
422     (us1, us2) = splitUniqSupply us
423
424 returnLM :: a -> LiftM a
425 returnLM a mod ci us idenv = a
426
427 fixLM :: (a -> LiftM a) -> LiftM a
428 fixLM k mod ci us idenv = r
429                        where
430                          r = k r mod ci us idenv
431
432 mapLM :: (a -> LiftM b) -> [a] -> LiftM [b]
433 mapLM f [] = returnLM []
434 mapLM f (a:as) = f a            `thenLM` \ r ->
435                  mapLM f as     `thenLM` \ rs ->
436                  returnLM (r:rs)
437
438 mapAndUnzipLM :: (a -> LiftM (b,c)) -> [a] -> LiftM ([b],[c])
439 mapAndUnzipLM f []     = returnLM ([],[])
440 mapAndUnzipLM f (a:as) = f a                    `thenLM` \ (b,c) ->
441                          mapAndUnzipLM f as     `thenLM` \ (bs,cs) ->
442                          returnLM (b:bs, c:cs)
443 \end{code}
444
445 \begin{code}
446 newSupercombinator :: Type
447                    -> Int               -- Arity
448                    -> LiftM Id
449
450 newSupercombinator ty arity mod ci us idenv
451   = setIdVisibility (Just mod) uniq (mkSysLocal SLIT("sc") uniq ty noSrcLoc)
452     `addIdArity` exactArity arity
453         -- ToDo: rm the addIdArity?  Just let subsequent stg-saturation pass do it?
454   where
455     uniq = getUnique us
456
457 lookUp :: Id -> LiftM (Id,[Id])
458 lookUp v mod ci us idenv
459   = case (lookupIdEnv idenv v) of
460       Just result -> result
461       Nothing     -> (v, [])
462
463 addScInlines :: [Id] -> [(Id,[Id])] -> LiftM a -> LiftM a
464 addScInlines ids values m mod ci us idenv
465   = m mod ci us idenv'
466   where
467     idenv' = growIdEnvList idenv (ids `zip_lazy` values)
468
469     -- zip_lazy zips two things together but matches lazily on the
470     -- second argument.  This is important, because the ids are know here,
471     -- but the things they are bound to are decided only later
472     zip_lazy [] _           = []
473     zip_lazy (x:xs) ~(y:ys) = (x,y) : zip_lazy xs ys
474
475
476 -- The free vars reported by the free-var analyser will include
477 -- some ids, f, which are to be replaced by ($f a b c), where $f
478 -- is the supercombinator.  Hence instead of f being a free var,
479 -- {a,b,c} are.
480 --
481 -- Example
482 --      let
483 --         f a = ...y1..y2.....
484 --      in
485 --      let
486 --         g b = ...f...z...
487 --      in
488 --      ...
489 --
490 --  Here the free vars of g are {f,z}; but f will be lambda-lifted
491 --  with free vars {y1,y2}, so the "real~ free vars of g are {y1,y2,z}.
492
493 getFinalFreeVars :: IdSet -> LiftM IdSet
494
495 getFinalFreeVars free_vars mod ci us idenv
496   = unionManyIdSets (map munge_it (idSetToList free_vars))
497   where
498     munge_it :: Id -> IdSet     -- Takes a free var and maps it to the "real"
499                                 -- free var
500     munge_it id = case (lookupIdEnv idenv id) of
501                     Just (_, args) -> mkIdSet args
502                     Nothing        -> unitIdSet id
503 \end{code}
504
505
506 %************************************************************************
507 %*                                                                      *
508 \subsection[Lift-info]{The LiftInfo type}
509 %*                                                                      *
510 %************************************************************************
511
512 \begin{code}
513 type LiftInfo = Bag StgBinding  -- Float to top
514
515 emptyLiftInfo = emptyBag
516
517 unionLiftInfo :: LiftInfo -> LiftInfo -> LiftInfo
518 unionLiftInfo binds1 binds2 = binds1 `unionBags` binds2
519
520 unionLiftInfos :: [LiftInfo] -> LiftInfo
521 unionLiftInfos infos = foldr unionLiftInfo emptyLiftInfo infos
522
523 mkScInfo :: StgBinding -> LiftInfo
524 mkScInfo bind = unitBag bind
525
526 nonRecScBind :: LiftInfo                -- From body of supercombinator
527              -> (Id, StgRhs)    -- Supercombinator and its rhs
528              -> LiftInfo
529 nonRecScBind binds (sc_id,sc_rhs) = binds `snocBag` (StgNonRec sc_id sc_rhs)
530
531
532 -- In the recursive case, all the SCs from the RHSs of the recursive group
533 -- are dealing with might potentially mention the new, recursive SCs.
534 -- So we flatten the whole lot into a single recursive group.
535
536 recScBind :: LiftInfo                   -- From body of supercombinator
537            -> [(Id,StgRhs)]     -- Supercombinator rhs
538            -> LiftInfo
539
540 recScBind binds pairs = unitBag (co_rec_ify (StgRec pairs : bagToList binds))
541
542 co_rec_ify :: [StgBinding] -> StgBinding
543 co_rec_ify binds = StgRec (concat (map f binds))
544   where
545     f (StgNonRec id rhs) = [(id,rhs)]
546     f (StgRec pairs)     = pairs
547
548
549 getScBinds :: LiftInfo -> [StgBinding]
550 getScBinds binds = bagToList binds
551
552 looksLikeSATRhs [(f,StgRhsClosure _ _ _ _ ls _)] (StgApp (StgVarArg f') args _)
553   = (f == f') && (length args == length ls)
554 looksLikeSATRhs _ _ = False
555 \end{code}