Move all the CoreToDo stuff into CoreMonad
[ghc-hetmet.git] / compiler / simplCore / FloatOut.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
3 %
4 \section[FloatOut]{Float bindings outwards (towards the top level)}
5
6 ``Long-distance'' floating of bindings towards the top level.
7
8 \begin{code}
9 module FloatOut ( floatOutwards ) where
10
11 import CoreSyn
12 import CoreUtils
13 import CoreArity        ( etaExpand )
14 import CoreMonad        ( FloatOutSwitches(..) )
15
16 import DynFlags         ( DynFlags, DynFlag(..) )
17 import ErrUtils         ( dumpIfSet_dyn )
18 import CostCentre       ( dupifyCC, CostCentre )
19 import Id               ( Id, idType, idArity, isBottomingId )
20 import Type             ( isUnLiftedType )
21 import SetLevels        ( Level(..), LevelledExpr, LevelledBind,
22                           setLevels, isTopLvl, tOP_LEVEL )
23 import UniqSupply       ( UniqSupply )
24 import Bag
25 import Util
26 import Maybes
27 import UniqFM
28 import Outputable
29 import FastString
30 \end{code}
31
32         -----------------
33         Overall game plan
34         -----------------
35
36 The Big Main Idea is:
37
38         To float out sub-expressions that can thereby get outside
39         a non-one-shot value lambda, and hence may be shared.
40
41
42 To achieve this we may need to do two thing:
43
44    a) Let-bind the sub-expression:
45
46         f (g x)  ==>  let lvl = f (g x) in lvl
47
48       Now we can float the binding for 'lvl'.  
49
50    b) More than that, we may need to abstract wrt a type variable
51
52         \x -> ... /\a -> let v = ...a... in ....
53
54       Here the binding for v mentions 'a' but not 'x'.  So we
55       abstract wrt 'a', to give this binding for 'v':
56
57             vp = /\a -> ...a...
58             v  = vp a
59
60       Now the binding for vp can float out unimpeded.
61       I can't remember why this case seemed important enough to
62       deal with, but I certainly found cases where important floats
63       didn't happen if we did not abstract wrt tyvars.
64
65 With this in mind we can also achieve another goal: lambda lifting.
66 We can make an arbitrary (function) binding float to top level by
67 abstracting wrt *all* local variables, not just type variables, leaving
68 a binding that can be floated right to top level.  Whether or not this
69 happens is controlled by a flag.
70
71
72 Random comments
73 ~~~~~~~~~~~~~~~
74
75 At the moment we never float a binding out to between two adjacent
76 lambdas.  For example:
77
78 @
79         \x y -> let t = x+x in ...
80 ===>
81         \x -> let t = x+x in \y -> ...
82 @
83 Reason: this is less efficient in the case where the original lambda
84 is never partially applied.
85
86 But there's a case I've seen where this might not be true.  Consider:
87 @
88 elEm2 x ys
89   = elem' x ys
90   where
91     elem' _ []  = False
92     elem' x (y:ys)      = x==y || elem' x ys
93 @
94 It turns out that this generates a subexpression of the form
95 @
96         \deq x ys -> let eq = eqFromEqDict deq in ...
97 @
98 vwhich might usefully be separated to
99 @
100         \deq -> let eq = eqFromEqDict deq in \xy -> ...
101 @
102 Well, maybe.  We don't do this at the moment.
103
104
105 %************************************************************************
106 %*                                                                      *
107 \subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
108 %*                                                                      *
109 %************************************************************************
110
111 \begin{code}
112 floatOutwards :: FloatOutSwitches
113               -> DynFlags
114               -> UniqSupply 
115               -> [CoreBind] -> IO [CoreBind]
116
117 floatOutwards float_sws dflags us pgm
118   = do {
119         let { annotated_w_levels = setLevels float_sws pgm us ;
120               (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)
121             } ;
122
123         dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
124                   (vcat (map ppr annotated_w_levels));
125
126         let { (tlets, ntlets, lams) = get_stats (sum_stats fss) };
127
128         dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:"
129                 (hcat [ int tlets,  ptext (sLit " Lets floated to top level; "),
130                         int ntlets, ptext (sLit " Lets floated elsewhere; from "),
131                         int lams,   ptext (sLit " Lambda groups")]);
132
133         return (concat binds_s')
134     }
135
136 floatTopBind :: LevelledBind -> (FloatStats, [CoreBind])
137 floatTopBind bind
138   = case (floatBind bind) of { (fs, floats) ->
139     (fs, bagToList (flattenFloats floats))
140     }
141 \end{code}
142
143 %************************************************************************
144 %*                                                                      *
145 \subsection[FloatOut-Bind]{Floating in a binding (the business end)}
146 %*                                                                      *
147 %************************************************************************
148
149 \begin{code}
150 floatBind :: LevelledBind -> (FloatStats, FloatBinds)
151
152 floatBind (NonRec (TB var level) rhs)
153   = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
154
155         -- A tiresome hack: 
156         -- see Note [Bottoming floats: eta expansion] in SetLevels
157     let rhs'' | isBottomingId var = etaExpand (idArity var) rhs'
158               | otherwise         = rhs'
159
160     in (fs, rhs_floats `plusFloats` unitFloat level (NonRec var rhs'')) }
161
162 floatBind bind@(Rec pairs)
163   = case (unzip3 (map do_pair pairs)) of { (fss, rhss_floats, new_pairs) ->
164     let rhs_floats = foldr1 plusFloats rhss_floats in
165
166     if not (isTopLvl bind_dest_lvl) then
167         -- Find which bindings float out at least one lambda beyond this one
168         -- These ones can't mention the binders, because they couldn't 
169         -- be escaping a major level if so.
170         -- The ones that are not going further can join the letrec;
171         -- they may not be mutually recursive but the occurrence analyser will
172         -- find that out.
173         case (partitionByMajorLevel bind_dest_lvl rhs_floats) of { (floats', heres) ->
174         (sum_stats fss, 
175          floats' `plusFloats` unitFloat bind_dest_lvl 
176                                 (Rec (floatsToBindPairs heres new_pairs))) }
177     else
178         -- In a recursive binding, *destined for* the top level
179         -- (only), the rhs floats may contain references to the 
180         -- bound things.  For example
181         --      f = ...(let v = ...f... in b) ...
182         --  might get floated to
183         --      v = ...f...
184         --      f = ... b ...
185         -- and hence we must (pessimistically) make all the floats recursive
186         -- with the top binding.  Later dependency analysis will unravel it.
187         --
188         -- This can only happen for bindings destined for the top level,
189         -- because only then will partitionByMajorLevel allow through a binding
190         -- that only differs in its minor level
191         (sum_stats fss, unitFloat tOP_LEVEL
192                            (Rec (floatsToBindPairs (flattenFloats rhs_floats) new_pairs)))
193     }
194   where
195     bind_dest_lvl = getBindLevel bind
196
197     do_pair (TB name level, rhs)
198       = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
199         (fs, rhs_floats, (name, rhs'))
200         }
201 \end{code}
202
203 %************************************************************************
204
205 \subsection[FloatOut-Expr]{Floating in expressions}
206 %*                                                                      *
207 %************************************************************************
208
209 \begin{code}
210 floatExpr, floatRhs, floatCaseAlt
211          :: Level
212          -> LevelledExpr
213          -> (FloatStats, FloatBinds, CoreExpr)
214
215 floatCaseAlt lvl arg    -- Used rec rhss, and case-alternative rhss
216   = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
217     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
218         -- Dump bindings that aren't going to escape from a lambda;
219         -- in particular, we must dump the ones that are bound by 
220         -- the rec or case alternative
221     (fsa, floats', install heres arg') }}
222
223 floatRhs lvl arg        -- Used for nested non-rec rhss, and fn args
224                         -- See Note [Floating out of RHS]
225   = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
226     if exprIsCheap arg' then    
227         (fsa, floats, arg')
228     else
229     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
230     (fsa, floats', install heres arg') }}
231
232 -- Note [Floating out of RHSs]
233 -- ~~~~~~~~~~~~~~~~~~~~~~~~~~~
234 -- Dump bindings that aren't going to escape from a lambda
235 -- This isn't a scoping issue (the binder isn't in scope in the RHS 
236 --      of a non-rec binding)
237 -- Rather, it is to avoid floating the x binding out of
238 --      f (let x = e in b)
239 -- unnecessarily.  But we first test for values or trival rhss,
240 -- because (in particular) we don't want to insert new bindings between
241 -- the "=" and the "\".  E.g.
242 --      f = \x -> let <bind> in <body>
243 -- We do not want
244 --      f = let <bind> in \x -> <body>
245 -- (a) The simplifier will immediately float it further out, so we may
246 --      as well do so right now; in general, keeping rhss as manifest 
247 --      values is good
248 -- (b) If a float-in pass follows immediately, it might add yet more
249 --      bindings just after the '='.  And some of them might (correctly)
250 --      be strict even though the 'let f' is lazy, because f, being a value,
251 --      gets its demand-info zapped by the simplifier.
252 --
253 -- We use exprIsCheap because that is also what's used by the simplifier
254 -- to decide whether to float a let out of a let
255
256 floatExpr _ (Var v)   = (zeroStats, emptyFloats, Var v)
257 floatExpr _ (Type ty) = (zeroStats, emptyFloats, Type ty)
258 floatExpr _ (Lit lit) = (zeroStats, emptyFloats, Lit lit)
259           
260 floatExpr lvl (App e a)
261   = case (floatExpr      lvl e) of { (fse, floats_e, e') ->
262     case (floatRhs lvl a)       of { (fsa, floats_a, a') ->
263     (fse `add_stats` fsa, floats_e `plusFloats` floats_a, App e' a') }}
264
265 floatExpr _ lam@(Lam _ _)
266   = let
267         (bndrs_w_lvls, body) = collectBinders lam
268         bndrs                = [b | TB b _ <- bndrs_w_lvls]
269         lvls                 = [l | TB _ l <- bndrs_w_lvls]
270
271         -- For the all-tyvar case we are prepared to pull 
272         -- the lets out, to implement the float-out-of-big-lambda
273         -- transform; but otherwise we only float bindings that are
274         -- going to escape a value lambda.
275         -- In particular, for one-shot lambdas we don't float things
276         -- out; we get no saving by so doing.
277         partition_fn | all isTyVar bndrs = partitionByLevel
278                      | otherwise         = partitionByMajorLevel
279     in
280     case (floatExpr (last lvls) body) of { (fs, floats, body') ->
281
282         -- Dump any bindings which absolutely cannot go any further
283     case (partition_fn (head lvls) floats)      of { (floats', heres) ->
284
285     (add_to_stats fs floats', floats', mkLams bndrs (install heres body'))
286     }}
287
288 floatExpr lvl (Note note@(SCC cc) expr)
289   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
290     let
291         -- Annotate bindings floated outwards past an scc expression
292         -- with the cc.  We mark that cc as "duplicated", though.
293
294         annotated_defns = wrapCostCentre (dupifyCC cc) floating_defns
295     in
296     (fs, annotated_defns, Note note expr') }
297
298 floatExpr lvl (Note note expr)  -- Other than SCCs
299   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
300     (fs, floating_defns, Note note expr') }
301
302 floatExpr lvl (Cast expr co)
303   = case (floatExpr lvl expr)   of { (fs, floating_defns, expr') ->
304     (fs, floating_defns, Cast expr' co) }
305
306 floatExpr lvl (Let (NonRec (TB bndr bndr_lvl) rhs) body)
307   | isUnLiftedType (idType bndr)  -- Treat unlifted lets just like a case
308                                   -- I.e. floatExpr for rhs, floatCaseAlt for body
309   = case floatExpr lvl rhs          of { (_, rhs_floats, rhs') ->
310     case floatCaseAlt bndr_lvl body of { (fs, body_floats, body') ->
311     (fs, rhs_floats `plusFloats` body_floats, Let (NonRec bndr rhs') body') }}
312
313 floatExpr lvl (Let bind body)
314   = case (floatBind bind)     of { (fsb, bind_floats) ->
315     case (floatExpr lvl body) of { (fse, body_floats, body') ->
316     (add_stats fsb fse,
317      bind_floats `plusFloats` body_floats,
318      body')  }}
319
320 floatExpr lvl (Case scrut (TB case_bndr case_lvl) ty alts)
321   = case floatExpr lvl scrut    of { (fse, fde, scrut') ->
322     case floatList float_alt alts       of { (fsa, fda, alts')  ->
323     (add_stats fse fsa, fda `plusFloats` fde, Case scrut' case_bndr ty alts')
324     }}
325   where
326         -- Use floatCaseAlt for the alternatives, so that we
327         -- don't gratuitiously float bindings out of the RHSs
328     float_alt (con, bs, rhs)
329         = case (floatCaseAlt case_lvl rhs)      of { (fs, rhs_floats, rhs') ->
330           (fs, rhs_floats, (con, [b | TB b _ <- bs], rhs')) }
331
332
333 floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
334 floatList _ [] = (zeroStats, emptyFloats, [])
335 floatList f (a:as) = case f a            of { (fs_a,  binds_a,  b)  ->
336                      case floatList f as of { (fs_as, binds_as, bs) ->
337                      (fs_a `add_stats` fs_as, binds_a `plusFloats`  binds_as, b:bs) }}
338
339 getBindLevel :: Bind (TaggedBndr Level) -> Level
340 getBindLevel (NonRec (TB _ lvl) _)       = lvl
341 getBindLevel (Rec (((TB _ lvl), _) : _)) = lvl
342 getBindLevel (Rec [])                    = panic "getBindLevel Rec []"
343 \end{code}
344
345 %************************************************************************
346 %*                                                                      *
347 \subsection{Utility bits for floating stats}
348 %*                                                                      *
349 %************************************************************************
350
351 I didn't implement this with unboxed numbers.  I don't want to be too
352 strict in this stuff, as it is rarely turned on.  (WDP 95/09)
353
354 \begin{code}
355 data FloatStats
356   = FlS Int  -- Number of top-floats * lambda groups they've been past
357         Int  -- Number of non-top-floats * lambda groups they've been past
358         Int  -- Number of lambda (groups) seen
359
360 get_stats :: FloatStats -> (Int, Int, Int)
361 get_stats (FlS a b c) = (a, b, c)
362
363 zeroStats :: FloatStats
364 zeroStats = FlS 0 0 0
365
366 sum_stats :: [FloatStats] -> FloatStats
367 sum_stats xs = foldr add_stats zeroStats xs
368
369 add_stats :: FloatStats -> FloatStats -> FloatStats
370 add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
371   = FlS (a1 + a2) (b1 + b2) (c1 + c2)
372
373 add_to_stats :: FloatStats -> FloatBinds -> FloatStats
374 add_to_stats (FlS a b c) (FB tops others)
375   = FlS (a + lengthBag tops) (b + lengthBag (flattenMajor others)) (c + 1)
376 \end{code}
377
378
379 %************************************************************************
380 %*                                                                      *
381 \subsection{Utility bits for floating}
382 %*                                                                      *
383 %************************************************************************
384
385 Note [Representation of FloatBinds]
386 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
387 The FloatBinds types is somewhat important.  We can get very large numbers
388 of floating bindings, often all destined for the top level.  A typical example
389 is     x = [4,2,5,2,5, .... ]
390 Then we get lots of small expressions like (fromInteger 4), which all get
391 lifted to top level.  
392
393 The trouble is that  
394   (a) we partition these floating bindings *at every binding site* 
395   (b) SetLevels introduces a new bindings site for every float
396 So we had better not look at each binding at each binding site!
397
398 That is why MajorEnv is represented as a finite map.
399
400 We keep the bindings destined for the *top* level separate, because
401 we float them out even if they don't escape a *value* lambda; see
402 partitionByMajorLevel.
403
404
405 \begin{code}
406 type FloatBind = CoreBind       -- INVARIANT: a FloatBind is always lifted
407
408 data FloatBinds  = FB !(Bag FloatBind)          -- Destined for top level
409                       !MajorEnv                 -- Levels other than top
410      -- See Note [Representation of FloatBinds]
411
412 type MajorEnv = UniqFM MinorEnv                 -- Keyed by major level
413 type MinorEnv = UniqFM (Bag FloatBind)          -- Keyed by minor level
414
415 flattenFloats :: FloatBinds -> Bag FloatBind
416 flattenFloats (FB tops others) = tops `unionBags` flattenMajor others
417
418 flattenMajor :: MajorEnv -> Bag FloatBind
419 flattenMajor = foldUFM (unionBags . flattenMinor) emptyBag
420
421 flattenMinor :: MinorEnv -> Bag FloatBind
422 flattenMinor = foldUFM unionBags emptyBag
423
424 emptyFloats :: FloatBinds
425 emptyFloats = FB emptyBag emptyUFM
426
427 unitFloat :: Level -> FloatBind -> FloatBinds
428 unitFloat lvl@(Level major minor) b 
429   | isTopLvl lvl = FB (unitBag b) emptyUFM
430   | otherwise    = FB emptyBag (unitUFM major (unitUFM minor (unitBag b)))
431
432 plusFloats :: FloatBinds -> FloatBinds -> FloatBinds
433 plusFloats (FB t1 b1) (FB t2 b2) = FB (t1 `unionBags` t2) (b1 `plusMajor` b2)
434
435 plusMajor :: MajorEnv -> MajorEnv -> MajorEnv
436 plusMajor = plusUFM_C plusMinor
437
438 plusMinor :: MinorEnv -> MinorEnv -> MinorEnv
439 plusMinor = plusUFM_C unionBags
440
441 floatsToBindPairs :: Bag FloatBind -> [(Id,CoreExpr)] -> [(Id,CoreExpr)]
442 floatsToBindPairs floats binds = foldrBag add binds floats
443   where
444    add (Rec pairs)         binds = pairs ++ binds
445    add (NonRec binder rhs) binds = (binder,rhs) : binds
446
447 install :: Bag FloatBind -> CoreExpr -> CoreExpr
448 install defn_groups expr
449   = foldrBag install_group expr defn_groups
450   where
451     install_group defns body = Let defns body
452
453 partitionByMajorLevel, partitionByLevel
454         :: Level                -- Partitioning level
455         -> FloatBinds           -- Defns to be divided into 2 piles...
456         -> (FloatBinds,         -- Defns  with level strictly < partition level,
457             Bag FloatBind)      -- The rest
458
459 --       ---- partitionByMajorLevel ----
460 -- Float it if we escape a value lambda, *or* if we get to the top level
461 -- If we can get to the top level, say "yes" anyway. This means that 
462 --      x = f e
463 -- transforms to 
464 --    lvl = e
465 --    x = f lvl
466 -- which is as it should be
467
468 partitionByMajorLevel (Level major _) (FB tops defns)
469   = (FB tops outer, heres `unionBags` flattenMajor inner)
470   where
471     (outer, mb_heres, inner) = splitUFM defns major
472     heres = case mb_heres of 
473                Nothing -> emptyBag
474                Just h  -> flattenMinor h
475
476 partitionByLevel (Level major minor) (FB tops defns)
477   = (FB tops (outer_maj `plusMajor` unitUFM major outer_min),
478      here_min `unionBags` flattenMinor inner_min 
479               `unionBags` flattenMajor inner_maj)
480
481   where
482     (outer_maj, mb_here_maj, inner_maj) = splitUFM defns major
483     (outer_min, mb_here_min, inner_min) = case mb_here_maj of
484                                             Nothing -> (emptyUFM, Nothing, emptyUFM)
485                                             Just min_defns -> splitUFM min_defns minor
486     here_min = mb_here_min `orElse` emptyBag
487
488 wrapCostCentre :: CostCentre -> FloatBinds -> FloatBinds
489 wrapCostCentre cc (FB tops defns)
490   = FB (wrap_defns tops) (mapUFM (mapUFM wrap_defns) defns)
491   where
492     wrap_defns = mapBag wrap_one 
493     wrap_one (NonRec binder rhs) = NonRec binder (mkSCC cc rhs)
494     wrap_one (Rec pairs)         = Rec (mapSnd (mkSCC cc) pairs)
495 \end{code}
496