Bottom extraction: float out bottoming expressions to top level
[ghc-hetmet.git] / compiler / simplCore / FloatOut.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
3 %
4 \section[FloatOut]{Float bindings outwards (towards the top level)}
5
6 ``Long-distance'' floating of bindings towards the top level.
7
8 \begin{code}
9 module FloatOut ( floatOutwards ) where
10
11 import CoreSyn
12 import CoreUtils
13 import CoreArity        ( etaExpand )
14
15 import DynFlags ( DynFlags, DynFlag(..), FloatOutSwitches(..) )
16 import ErrUtils         ( dumpIfSet_dyn )
17 import CostCentre       ( dupifyCC, CostCentre )
18 import Id               ( Id, idType, idArity, isBottomingId )
19 import Type             ( isUnLiftedType )
20 import SetLevels        ( Level(..), LevelledExpr, LevelledBind,
21                           setLevels, isTopLvl, tOP_LEVEL )
22 import UniqSupply       ( UniqSupply )
23 import Bag
24 import Util
25 import Maybes
26 import UniqFM
27 import Outputable
28 import FastString
29 \end{code}
30
31         -----------------
32         Overall game plan
33         -----------------
34
35 The Big Main Idea is:
36
37         To float out sub-expressions that can thereby get outside
38         a non-one-shot value lambda, and hence may be shared.
39
40
41 To achieve this we may need to do two thing:
42
43    a) Let-bind the sub-expression:
44
45         f (g x)  ==>  let lvl = f (g x) in lvl
46
47       Now we can float the binding for 'lvl'.  
48
49    b) More than that, we may need to abstract wrt a type variable
50
51         \x -> ... /\a -> let v = ...a... in ....
52
53       Here the binding for v mentions 'a' but not 'x'.  So we
54       abstract wrt 'a', to give this binding for 'v':
55
56             vp = /\a -> ...a...
57             v  = vp a
58
59       Now the binding for vp can float out unimpeded.
60       I can't remember why this case seemed important enough to
61       deal with, but I certainly found cases where important floats
62       didn't happen if we did not abstract wrt tyvars.
63
64 With this in mind we can also achieve another goal: lambda lifting.
65 We can make an arbitrary (function) binding float to top level by
66 abstracting wrt *all* local variables, not just type variables, leaving
67 a binding that can be floated right to top level.  Whether or not this
68 happens is controlled by a flag.
69
70
71 Random comments
72 ~~~~~~~~~~~~~~~
73
74 At the moment we never float a binding out to between two adjacent
75 lambdas.  For example:
76
77 @
78         \x y -> let t = x+x in ...
79 ===>
80         \x -> let t = x+x in \y -> ...
81 @
82 Reason: this is less efficient in the case where the original lambda
83 is never partially applied.
84
85 But there's a case I've seen where this might not be true.  Consider:
86 @
87 elEm2 x ys
88   = elem' x ys
89   where
90     elem' _ []  = False
91     elem' x (y:ys)      = x==y || elem' x ys
92 @
93 It turns out that this generates a subexpression of the form
94 @
95         \deq x ys -> let eq = eqFromEqDict deq in ...
96 @
97 vwhich might usefully be separated to
98 @
99         \deq -> let eq = eqFromEqDict deq in \xy -> ...
100 @
101 Well, maybe.  We don't do this at the moment.
102
103
104 %************************************************************************
105 %*                                                                      *
106 \subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
107 %*                                                                      *
108 %************************************************************************
109
110 \begin{code}
111 floatOutwards :: FloatOutSwitches
112               -> DynFlags
113               -> UniqSupply 
114               -> [CoreBind] -> IO [CoreBind]
115
116 floatOutwards float_sws dflags us pgm
117   = do {
118         let { annotated_w_levels = setLevels float_sws pgm us ;
119               (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)
120             } ;
121
122         dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
123                   (vcat (map ppr annotated_w_levels));
124
125         let { (tlets, ntlets, lams) = get_stats (sum_stats fss) };
126
127         dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:"
128                 (hcat [ int tlets,  ptext (sLit " Lets floated to top level; "),
129                         int ntlets, ptext (sLit " Lets floated elsewhere; from "),
130                         int lams,   ptext (sLit " Lambda groups")]);
131
132         return (concat binds_s')
133     }
134
135 floatTopBind :: LevelledBind -> (FloatStats, [CoreBind])
136 floatTopBind bind
137   = case (floatBind bind) of { (fs, floats) ->
138     (fs, bagToList (flattenFloats floats))
139     }
140 \end{code}
141
142 %************************************************************************
143 %*                                                                      *
144 \subsection[FloatOut-Bind]{Floating in a binding (the business end)}
145 %*                                                                      *
146 %************************************************************************
147
148 \begin{code}
149 floatBind :: LevelledBind -> (FloatStats, FloatBinds)
150
151 floatBind (NonRec (TB var level) rhs)
152   = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
153
154         -- A tiresome hack: 
155         -- see Note [Bottoming floats: eta expansion] in SetLevels
156     let rhs'' | isBottomingId var = etaExpand (idArity var) rhs'
157               | otherwise         = rhs'
158
159     in (fs, rhs_floats `plusFloats` unitFloat level (NonRec var rhs'')) }
160
161 floatBind bind@(Rec pairs)
162   = case (unzip3 (map do_pair pairs)) of { (fss, rhss_floats, new_pairs) ->
163     let rhs_floats = foldr1 plusFloats rhss_floats in
164
165     if not (isTopLvl bind_dest_lvl) then
166         -- Find which bindings float out at least one lambda beyond this one
167         -- These ones can't mention the binders, because they couldn't 
168         -- be escaping a major level if so.
169         -- The ones that are not going further can join the letrec;
170         -- they may not be mutually recursive but the occurrence analyser will
171         -- find that out.
172         case (partitionByMajorLevel bind_dest_lvl rhs_floats) of { (floats', heres) ->
173         (sum_stats fss, 
174          floats' `plusFloats` unitFloat bind_dest_lvl 
175                                 (Rec (floatsToBindPairs heres new_pairs))) }
176     else
177         -- In a recursive binding, *destined for* the top level
178         -- (only), the rhs floats may contain references to the 
179         -- bound things.  For example
180         --      f = ...(let v = ...f... in b) ...
181         --  might get floated to
182         --      v = ...f...
183         --      f = ... b ...
184         -- and hence we must (pessimistically) make all the floats recursive
185         -- with the top binding.  Later dependency analysis will unravel it.
186         --
187         -- This can only happen for bindings destined for the top level,
188         -- because only then will partitionByMajorLevel allow through a binding
189         -- that only differs in its minor level
190         (sum_stats fss, unitFloat tOP_LEVEL
191                            (Rec (floatsToBindPairs (flattenFloats rhs_floats) new_pairs)))
192     }
193   where
194     bind_dest_lvl = getBindLevel bind
195
196     do_pair (TB name level, rhs)
197       = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
198         (fs, rhs_floats, (name, rhs'))
199         }
200 \end{code}
201
202 %************************************************************************
203
204 \subsection[FloatOut-Expr]{Floating in expressions}
205 %*                                                                      *
206 %************************************************************************
207
208 \begin{code}
209 floatExpr, floatRhs, floatCaseAlt
210          :: Level
211          -> LevelledExpr
212          -> (FloatStats, FloatBinds, CoreExpr)
213
214 floatCaseAlt lvl arg    -- Used rec rhss, and case-alternative rhss
215   = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
216     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
217         -- Dump bindings that aren't going to escape from a lambda;
218         -- in particular, we must dump the ones that are bound by 
219         -- the rec or case alternative
220     (fsa, floats', install heres arg') }}
221
222 floatRhs lvl arg        -- Used for nested non-rec rhss, and fn args
223                         -- See Note [Floating out of RHS]
224   = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
225     if exprIsCheap arg' then    
226         (fsa, floats, arg')
227     else
228     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
229     (fsa, floats', install heres arg') }}
230
231 -- Note [Floating out of RHSs]
232 -- ~~~~~~~~~~~~~~~~~~~~~~~~~~~
233 -- Dump bindings that aren't going to escape from a lambda
234 -- This isn't a scoping issue (the binder isn't in scope in the RHS 
235 --      of a non-rec binding)
236 -- Rather, it is to avoid floating the x binding out of
237 --      f (let x = e in b)
238 -- unnecessarily.  But we first test for values or trival rhss,
239 -- because (in particular) we don't want to insert new bindings between
240 -- the "=" and the "\".  E.g.
241 --      f = \x -> let <bind> in <body>
242 -- We do not want
243 --      f = let <bind> in \x -> <body>
244 -- (a) The simplifier will immediately float it further out, so we may
245 --      as well do so right now; in general, keeping rhss as manifest 
246 --      values is good
247 -- (b) If a float-in pass follows immediately, it might add yet more
248 --      bindings just after the '='.  And some of them might (correctly)
249 --      be strict even though the 'let f' is lazy, because f, being a value,
250 --      gets its demand-info zapped by the simplifier.
251 --
252 -- We use exprIsCheap because that is also what's used by the simplifier
253 -- to decide whether to float a let out of a let
254
255 floatExpr _ (Var v)   = (zeroStats, emptyFloats, Var v)
256 floatExpr _ (Type ty) = (zeroStats, emptyFloats, Type ty)
257 floatExpr _ (Lit lit) = (zeroStats, emptyFloats, Lit lit)
258           
259 floatExpr lvl (App e a)
260   = case (floatExpr      lvl e) of { (fse, floats_e, e') ->
261     case (floatRhs lvl a)       of { (fsa, floats_a, a') ->
262     (fse `add_stats` fsa, floats_e `plusFloats` floats_a, App e' a') }}
263
264 floatExpr _ lam@(Lam _ _)
265   = let
266         (bndrs_w_lvls, body) = collectBinders lam
267         bndrs                = [b | TB b _ <- bndrs_w_lvls]
268         lvls                 = [l | TB _ l <- bndrs_w_lvls]
269
270         -- For the all-tyvar case we are prepared to pull 
271         -- the lets out, to implement the float-out-of-big-lambda
272         -- transform; but otherwise we only float bindings that are
273         -- going to escape a value lambda.
274         -- In particular, for one-shot lambdas we don't float things
275         -- out; we get no saving by so doing.
276         partition_fn | all isTyVar bndrs = partitionByLevel
277                      | otherwise         = partitionByMajorLevel
278     in
279     case (floatExpr (last lvls) body) of { (fs, floats, body') ->
280
281         -- Dump any bindings which absolutely cannot go any further
282     case (partition_fn (head lvls) floats)      of { (floats', heres) ->
283
284     (add_to_stats fs floats', floats', mkLams bndrs (install heres body'))
285     }}
286
287 floatExpr lvl (Note note@(SCC cc) expr)
288   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
289     let
290         -- Annotate bindings floated outwards past an scc expression
291         -- with the cc.  We mark that cc as "duplicated", though.
292
293         annotated_defns = wrapCostCentre (dupifyCC cc) floating_defns
294     in
295     (fs, annotated_defns, Note note expr') }
296
297 floatExpr lvl (Note note expr)  -- Other than SCCs
298   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
299     (fs, floating_defns, Note note expr') }
300
301 floatExpr lvl (Cast expr co)
302   = case (floatExpr lvl expr)   of { (fs, floating_defns, expr') ->
303     (fs, floating_defns, Cast expr' co) }
304
305 floatExpr lvl (Let (NonRec (TB bndr bndr_lvl) rhs) body)
306   | isUnLiftedType (idType bndr)  -- Treat unlifted lets just like a case
307                                   -- I.e. floatExpr for rhs, floatCaseAlt for body
308   = case floatExpr lvl rhs          of { (_, rhs_floats, rhs') ->
309     case floatCaseAlt bndr_lvl body of { (fs, body_floats, body') ->
310     (fs, rhs_floats `plusFloats` body_floats, Let (NonRec bndr rhs') body') }}
311
312 floatExpr lvl (Let bind body)
313   = case (floatBind bind)     of { (fsb, bind_floats) ->
314     case (floatExpr lvl body) of { (fse, body_floats, body') ->
315     (add_stats fsb fse,
316      bind_floats `plusFloats` body_floats,
317      body')  }}
318
319 floatExpr lvl (Case scrut (TB case_bndr case_lvl) ty alts)
320   = case floatExpr lvl scrut    of { (fse, fde, scrut') ->
321     case floatList float_alt alts       of { (fsa, fda, alts')  ->
322     (add_stats fse fsa, fda `plusFloats` fde, Case scrut' case_bndr ty alts')
323     }}
324   where
325         -- Use floatCaseAlt for the alternatives, so that we
326         -- don't gratuitiously float bindings out of the RHSs
327     float_alt (con, bs, rhs)
328         = case (floatCaseAlt case_lvl rhs)      of { (fs, rhs_floats, rhs') ->
329           (fs, rhs_floats, (con, [b | TB b _ <- bs], rhs')) }
330
331
332 floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
333 floatList _ [] = (zeroStats, emptyFloats, [])
334 floatList f (a:as) = case f a            of { (fs_a,  binds_a,  b)  ->
335                      case floatList f as of { (fs_as, binds_as, bs) ->
336                      (fs_a `add_stats` fs_as, binds_a `plusFloats`  binds_as, b:bs) }}
337
338 getBindLevel :: Bind (TaggedBndr Level) -> Level
339 getBindLevel (NonRec (TB _ lvl) _)       = lvl
340 getBindLevel (Rec (((TB _ lvl), _) : _)) = lvl
341 getBindLevel (Rec [])                    = panic "getBindLevel Rec []"
342 \end{code}
343
344 %************************************************************************
345 %*                                                                      *
346 \subsection{Utility bits for floating stats}
347 %*                                                                      *
348 %************************************************************************
349
350 I didn't implement this with unboxed numbers.  I don't want to be too
351 strict in this stuff, as it is rarely turned on.  (WDP 95/09)
352
353 \begin{code}
354 data FloatStats
355   = FlS Int  -- Number of top-floats * lambda groups they've been past
356         Int  -- Number of non-top-floats * lambda groups they've been past
357         Int  -- Number of lambda (groups) seen
358
359 get_stats :: FloatStats -> (Int, Int, Int)
360 get_stats (FlS a b c) = (a, b, c)
361
362 zeroStats :: FloatStats
363 zeroStats = FlS 0 0 0
364
365 sum_stats :: [FloatStats] -> FloatStats
366 sum_stats xs = foldr add_stats zeroStats xs
367
368 add_stats :: FloatStats -> FloatStats -> FloatStats
369 add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
370   = FlS (a1 + a2) (b1 + b2) (c1 + c2)
371
372 add_to_stats :: FloatStats -> FloatBinds -> FloatStats
373 add_to_stats (FlS a b c) (FB tops others)
374   = FlS (a + lengthBag tops) (b + lengthBag (flattenMajor others)) (c + 1)
375 \end{code}
376
377
378 %************************************************************************
379 %*                                                                      *
380 \subsection{Utility bits for floating}
381 %*                                                                      *
382 %************************************************************************
383
384 Note [Representation of FloatBinds]
385 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
386 The FloatBinds types is somewhat important.  We can get very large numbers
387 of floating bindings, often all destined for the top level.  A typical example
388 is     x = [4,2,5,2,5, .... ]
389 Then we get lots of small expressions like (fromInteger 4), which all get
390 lifted to top level.  
391
392 The trouble is that  
393   (a) we partition these floating bindings *at every binding site* 
394   (b) SetLevels introduces a new bindings site for every float
395 So we had better not look at each binding at each binding site!
396
397 That is why MajorEnv is represented as a finite map.
398
399 We keep the bindings destined for the *top* level separate, because
400 we float them out even if they don't escape a *value* lambda; see
401 partitionByMajorLevel.
402
403
404 \begin{code}
405 type FloatBind = CoreBind       -- INVARIANT: a FloatBind is always lifted
406
407 data FloatBinds  = FB !(Bag FloatBind)          -- Destined for top level
408                       !MajorEnv                 -- Levels other than top
409      -- See Note [Representation of FloatBinds]
410
411 type MajorEnv = UniqFM MinorEnv                 -- Keyed by major level
412 type MinorEnv = UniqFM (Bag FloatBind)          -- Keyed by minor level
413
414 flattenFloats :: FloatBinds -> Bag FloatBind
415 flattenFloats (FB tops others) = tops `unionBags` flattenMajor others
416
417 flattenMajor :: MajorEnv -> Bag FloatBind
418 flattenMajor = foldUFM (unionBags . flattenMinor) emptyBag
419
420 flattenMinor :: MinorEnv -> Bag FloatBind
421 flattenMinor = foldUFM unionBags emptyBag
422
423 emptyFloats :: FloatBinds
424 emptyFloats = FB emptyBag emptyUFM
425
426 unitFloat :: Level -> FloatBind -> FloatBinds
427 unitFloat lvl@(Level major minor) b 
428   | isTopLvl lvl = FB (unitBag b) emptyUFM
429   | otherwise    = FB emptyBag (unitUFM major (unitUFM minor (unitBag b)))
430
431 plusFloats :: FloatBinds -> FloatBinds -> FloatBinds
432 plusFloats (FB t1 b1) (FB t2 b2) = FB (t1 `unionBags` t2) (b1 `plusMajor` b2)
433
434 plusMajor :: MajorEnv -> MajorEnv -> MajorEnv
435 plusMajor = plusUFM_C plusMinor
436
437 plusMinor :: MinorEnv -> MinorEnv -> MinorEnv
438 plusMinor = plusUFM_C unionBags
439
440 floatsToBindPairs :: Bag FloatBind -> [(Id,CoreExpr)] -> [(Id,CoreExpr)]
441 floatsToBindPairs floats binds = foldrBag add binds floats
442   where
443    add (Rec pairs)         binds = pairs ++ binds
444    add (NonRec binder rhs) binds = (binder,rhs) : binds
445
446 install :: Bag FloatBind -> CoreExpr -> CoreExpr
447 install defn_groups expr
448   = foldrBag install_group expr defn_groups
449   where
450     install_group defns body = Let defns body
451
452 partitionByMajorLevel, partitionByLevel
453         :: Level                -- Partitioning level
454         -> FloatBinds           -- Defns to be divided into 2 piles...
455         -> (FloatBinds,         -- Defns  with level strictly < partition level,
456             Bag FloatBind)      -- The rest
457
458 --       ---- partitionByMajorLevel ----
459 -- Float it if we escape a value lambda, *or* if we get to the top level
460 -- If we can get to the top level, say "yes" anyway. This means that 
461 --      x = f e
462 -- transforms to 
463 --    lvl = e
464 --    x = f lvl
465 -- which is as it should be
466
467 partitionByMajorLevel (Level major _) (FB tops defns)
468   = (FB tops outer, heres `unionBags` flattenMajor inner)
469   where
470     (outer, mb_heres, inner) = splitUFM defns major
471     heres = case mb_heres of 
472                Nothing -> emptyBag
473                Just h  -> flattenMinor h
474
475 partitionByLevel (Level major minor) (FB tops defns)
476   = (FB tops (outer_maj `plusMajor` unitUFM major outer_min),
477      here_min `unionBags` flattenMinor inner_min 
478               `unionBags` flattenMajor inner_maj)
479
480   where
481     (outer_maj, mb_here_maj, inner_maj) = splitUFM defns major
482     (outer_min, mb_here_min, inner_min) = case mb_here_maj of
483                                             Nothing -> (emptyUFM, Nothing, emptyUFM)
484                                             Just min_defns -> splitUFM min_defns minor
485     here_min = mb_here_min `orElse` emptyBag
486
487 wrapCostCentre :: CostCentre -> FloatBinds -> FloatBinds
488 wrapCostCentre cc (FB tops defns)
489   = FB (wrap_defns tops) (mapUFM (mapUFM wrap_defns) defns)
490   where
491     wrap_defns = mapBag wrap_one 
492     wrap_one (NonRec binder rhs) = NonRec binder (mkSCC cc rhs)
493     wrap_one (Rec pairs)         = Rec (mapSnd (mkSCC cc) pairs)
494 \end{code}
495