2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
4 \section[FloatOut]{Float bindings outwards (towards the top level)}
6 ``Long-distance'' floating of bindings towards the top level.
9 module FloatOut ( floatOutwards ) where
13 import CoreArity ( etaExpand )
14 import CoreMonad ( FloatOutSwitches(..) )
16 import DynFlags ( DynFlags, DynFlag(..) )
17 import ErrUtils ( dumpIfSet_dyn )
18 import CostCentre ( dupifyCC, CostCentre )
19 import Id ( Id, idType, idArity, isBottomingId )
20 import Type ( isUnLiftedType )
21 import SetLevels ( Level(..), LevelledExpr, LevelledBind,
23 import UniqSupply ( UniqSupply )
29 import qualified Data.IntMap as M
31 #include "HsVersions.h"
40 To float out sub-expressions that can thereby get outside
41 a non-one-shot value lambda, and hence may be shared.
44 To achieve this we may need to do two things:
46 a) Let-bind the sub-expression:
48 f (g x) ==> let lvl = f (g x) in lvl
50 Now we can float the binding for 'lvl'.
52 b) More than that, we may need to abstract wrt a type variable
54 \x -> ... /\a -> let v = ...a... in ....
56 Here the binding for v mentions 'a' but not 'x'. So we
57 abstract wrt 'a', to give this binding for 'v':
62 Now the binding for vp can float out unimpeded.
63 I can't remember why this case seemed important enough to
64 deal with, but I certainly found cases where important floats
65 didn't happen if we did not abstract wrt tyvars.
67 With this in mind we can also achieve another goal: lambda lifting.
68 We can make an arbitrary (function) binding float to top level by
69 abstracting wrt *all* local variables, not just type variables, leaving
70 a binding that can be floated right to top level. Whether or not this
71 happens is controlled by a flag.
77 At the moment we never float a binding out to between two adjacent
81 \x y -> let t = x+x in ...
83 \x -> let t = x+x in \y -> ...
85 Reason: this is less efficient in the case where the original lambda
86 is never partially applied.
88 But there's a case I've seen where this might not be true. Consider:
94 elem' x (y:ys) = x==y || elem' x ys
96 It turns out that this generates a subexpression of the form
98 \deq x ys -> let eq = eqFromEqDict deq in ...
100 which might usefully be separated to
102 \deq -> let eq = eqFromEqDict deq in \xy -> ...
104 Well, maybe. We don't do this at the moment.
107 %************************************************************************
109 \subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
111 %************************************************************************
-- Entry point of the pass.  Annotates the program with levels (setLevels),
-- floats every top-level binding with floatTopBind, hands the levelled
-- program and the float statistics to dumpIfSet_dyn, and returns the
-- concatenation of the per-binding result lists.
114 floatOutwards :: FloatOutSwitches
117 -> [CoreBind] -> IO [CoreBind]
119 floatOutwards float_sws dflags us pgm
121 let { annotated_w_levels = setLevels float_sws pgm us ;
122 (fss, binds_s') = unzip (map floatTopBind annotated_w_levels)
125 dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
126 (vcat (map ppr annotated_w_levels));
128 let { (tlets, ntlets, lams) = get_stats (sum_stats fss) };
130 dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:"
131 (hcat [ int tlets, ptext (sLit " Lets floated to top level; "),
132 int ntlets, ptext (sLit " Lets floated elsewhere; from "),
133 int lams, ptext (sLit " Lambda groups")]);
135 return (concat binds_s')
-- Float a single top-level binding: run floatBind on it and flatten all
-- the resulting floats, whatever their level, into one list of bindings.
138 floatTopBind :: LevelledBind -> (FloatStats, [CoreBind])
140 = case (floatBind bind) of { (fs, floats) ->
141 (fs, bagToList (flattenFloats floats)) }
144 %************************************************************************
146 \subsection[FloatOut-Bind]{Floating in a binding (the business end)}
148 %************************************************************************
-- Float a (levelled) binding.  The result carries the statistics gathered
-- while processing the right-hand side(s) and a FloatBinds that contains
-- both the floats coming out of the RHS(s) and the binding itself, filed
-- under its destination level.
--
-- Non-recursive case: the RHS is processed with floatRhs; a bottoming
-- binder has its RHS eta-expanded to the binder's arity.
151 floatBind :: LevelledBind -> (FloatStats, FloatBinds)
152 floatBind (NonRec (TB var level) rhs)
153 = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
156 -- see Note [Bottoming floats: eta expansion] in SetLevels
157 let rhs'' | isBottomingId var = etaExpand (idArity var) rhs'
160 in (fs, rhs_floats `plusFloats` unitFloat level (NonRec var rhs'')) }
-- Recursive case: every pair is processed with do_pair (via floatList).
-- The destination level is taken from the first binder of the group.
162 floatBind (Rec pairs)
163 = case floatList do_pair pairs of { (fs, rhs_floats, new_pairs) ->
164 -- NB: the rhs floats may contain references to the
165 -- bound things. For example
166 -- f = ...(let v = ...f... in b) ...
167 if not (isTopLvl dest_lvl) then
168 -- Find which bindings float out at least one lambda beyond this one
169 -- These ones can't mention the binders, because they couldn't
170 -- be escaping a major level if so.
171 -- The ones that are not going further can join the letrec;
172 -- they may not be mutually recursive but the occurrence analyser will
173 -- find that out. In our example we make a Rec thus:
176 case (partitionByMajorLevel dest_lvl rhs_floats) of { (floats', heres) ->
177 (fs, floats' `plusFloats` unitFloat dest_lvl
178 (Rec (floatsToBindPairs heres new_pairs))) }
180 -- For top level, no need to partition; just make them all recursive
181 -- (And the partition wouldn't work because they'd all end up in floats')
182 (fs, unitFloat dest_lvl
183 (Rec (floatsToBindPairs (flattenFloats rhs_floats) new_pairs))) }
185 (((TB _ dest_lvl), _) : _) = pairs
-- do_pair floats one RHS of the group at that binder's own level and
-- strips the level annotation from the binder.
187 do_pair (TB name level, rhs)
188 = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
189 (fs, rhs_floats, (name, rhs')) }
-- Apply a floating function to every element of a list, left to right.
-- The statistics are summed with add_stats, the floats are merged with
-- plusFloats, and the per-element results are returned in input order.
-- An empty list yields zeroStats, emptyFloats and [].
192 floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
193 floatList _ [] = (zeroStats, emptyFloats, [])
194 floatList f (a:as) = case f a of { (fs_a, binds_a, b) ->
195 case floatList f as of { (fs_as, binds_as, bs) ->
196 (fs_a `add_stats` fs_as, binds_a `plusFloats` binds_as, b:bs) }}
200 %************************************************************************
202 \subsection[FloatOut-Expr]{Floating in expressions}
204 %************************************************************************
207 floatExpr, floatRhs, floatCaseAlt
210 -> (FloatStats, FloatBinds, CoreExpr)
-- Float an expression with floatExpr, then split the floats at the major
-- level of 'lvl': those that go further out are returned, the rest are
-- re-installed as lets around the expression.
212 floatCaseAlt lvl arg -- Used for rec rhss, and case-alternative rhss
213 = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
214 case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
215 -- Dump bindings that aren't going to escape from a lambda;
216 -- in particular, we must dump the ones that are bound by
217 -- the rec or case alternative
218 (fsa, floats', install heres arg') }}
221 floatRhs lvl arg -- Used for nested non-rec rhss, and fn args
222 -- See Note [Floating out of RHS]
-- Atomic expressions: returned unchanged, with no floats and zero stats.
226 floatExpr _ (Var v) = (zeroStats, emptyFloats, Var v)
227 floatExpr _ (Type ty) = (zeroStats, emptyFloats, Type ty)
228 floatExpr _ (Coercion co) = (zeroStats, emptyFloats, Coercion co)
229 floatExpr _ (Lit lit) = (zeroStats, emptyFloats, Lit lit)
-- Application: the function is processed with floatExpr and the argument
-- with floatRhs; their stats and floats are combined.
231 floatExpr lvl (App e a)
232 = case (floatExpr lvl e) of { (fse, floats_e, e') ->
233 case (floatRhs lvl a) of { (fsa, floats_a, a') ->
234 (fse `add_stats` fsa, floats_e `plusFloats` floats_a, App e' a') }}
-- Lambda group: the incoming level is ignored; the body is floated at the
-- level recorded on the first binder.  Floats that cannot pass this
-- lambda (partitionByLevel) are re-installed around the body; the ones
-- that escape are returned and counted in the stats by add_to_stats.
236 floatExpr _ lam@(Lam (TB _ lam_lvl) _)
237 = let (bndrs_w_lvls, body) = collectBinders lam
238 bndrs = [b | TB b _ <- bndrs_w_lvls]
239 -- All the binders have the same level
240 -- See SetLevels.lvlLamBndrs
242 case (floatExpr lam_lvl body) of { (fs, floats, body1) ->
244 -- Dump anything that is captured by this lambda
245 -- Eg \x -> ...(\y -> let v = <blah> in ...)...
246 -- We'll have the binding (v = <blah>) in the floats,
247 -- but must dump it at the lambda-x
248 case (partitionByLevel lam_lvl floats) of { (floats1, heres) ->
249 (add_to_stats fs floats1, floats1, mkLams bndrs (install heres body1))
-- SCC note: float the body, then wrap the right-hand side of every
-- floated binding in the (dupified) cost centre via wrapCostCentre.
252 floatExpr lvl (Note note@(SCC cc) expr)
253 = case (floatExpr lvl expr) of { (fs, floating_defns, expr') ->
255 -- Annotate bindings floated outwards past an scc expression
256 -- with the cc. We mark that cc as "duplicated", though.
258 annotated_defns = wrapCostCentre (dupifyCC cc) floating_defns
260 (fs, annotated_defns, Note note expr') }
-- Any other note is transparent: float the body and re-attach the note.
262 floatExpr lvl (Note note expr) -- Other than SCCs
263 = case (floatExpr lvl expr) of { (fs, floating_defns, expr') ->
264 (fs, floating_defns, Note note expr') }
-- Cast is transparent too: float the body and re-apply the coercion.
266 floatExpr lvl (Cast expr co)
267 = case (floatExpr lvl expr) of { (fs, floating_defns, expr') ->
268 (fs, floating_defns, Cast expr' co) }
-- Unlifted non-recursive let.  The binding itself is never floated: the
-- RHS is processed with floatExpr and the body with floatCaseAlt at the
-- binder's level, and the let is rebuilt in place.
-- The statistics of the RHS and of the body are both reported (combined
-- with add_stats), mirroring how the Case clause combines the stats of
-- the scrutinee and the alternatives.
270 floatExpr lvl (Let (NonRec (TB bndr bndr_lvl) rhs) body)
271 | isUnLiftedType (idType bndr) -- Treat unlifted lets just like a case
272 -- I.e. floatExpr for rhs, floatCaseAlt for body
273 = case floatExpr lvl rhs of { (fs_rhs, rhs_floats, rhs') ->
274 case floatCaseAlt bndr_lvl body of { (fs_body, body_floats, body') ->
275 (fs_rhs `add_stats` fs_body, rhs_floats `plusFloats` body_floats, Let (NonRec bndr rhs') body') }}
-- General let: float the binding (floatBind) and the body, pool all the
-- floats, and split them at the major level of 'lvl'.  Floats that go
-- further out are returned; the rest are re-installed around the body.
277 floatExpr lvl (Let bind body)
278 = case (floatBind bind) of { (fsb, bind_floats) ->
279 case (floatExpr lvl body) of { (fse, body_floats, body') ->
280 case partitionByMajorLevel lvl (bind_floats `plusFloats` body_floats)
281 of { (floats, heres) ->
282 -- See Note [Avoiding unnecessary floating]
283 (add_stats fsb fse, floats, install heres body') } } }
-- Case: the scrutinee is floated with floatExpr at the incoming level and
-- each alternative with float_alt at the case binder's level; level
-- annotations are stripped from the case binder and pattern binders.
285 floatExpr lvl (Case scrut (TB case_bndr case_lvl) ty alts)
286 = case floatExpr lvl scrut of { (fse, fde, scrut') ->
287 case floatList float_alt alts of { (fsa, fda, alts') ->
288 (add_stats fse fsa, fda `plusFloats` fde, Case scrut' case_bndr ty alts')
291 -- Use floatCaseAlt for the alternatives, so that we
292 -- don't gratuitously float bindings out of the RHSs
293 float_alt (con, bs, rhs)
294 = case (floatCaseAlt case_lvl rhs) of { (fs, rhs_floats, rhs') ->
295 (fs, rhs_floats, (con, [b | TB b _ <- bs], rhs')) }
298 Note [Avoiding unnecessary floating]
299 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
300 In general we want to avoid floating a let unnecessarily, because
301 it might worsen strictness:
303 x = ...(let y = e in y+y)....
304 Here y is demanded. If we float it outside the lazy 'x=..' then
305 we'd have to zap its demand info, and it may never be restored.
307 So at a 'let' we leave the binding right where it is unless
308 the binding will escape a value lambda. That's what the
309 partitionByMajorLevel does in the floatExpr (Let ...) case.
311 Notice, though, that we must take care to drop any bindings
312 from the body of the let that depend on the staying-put bindings.
314 We used instead to do the partitionByMajorLevel on the RHS of an '=',
315 in floatRhs. But that was quite tiresome. We needed to test for
316 values or trivial rhss, because (in particular) we don't want to insert
317 new bindings between the "=" and the "\". E.g.
318 f = \x -> let <bind> in <body>
320 f = let <bind> in \x -> <body>
321 (a) The simplifier will immediately float it further out, so we may
322 as well do so right now; in general, keeping rhss as manifest
324 (b) If a float-in pass follows immediately, it might add yet more
325 bindings just after the '='. And some of them might (correctly)
326 be strict even though the 'let f' is lazy, because f, being a value,
327 gets its demand-info zapped by the simplifier.
328 And even all that turned out to be very fragile, and broke
329 altogether when profiling got in the way.
331 So now we do the partition right at the (Let..) itself.
333 %************************************************************************
335 \subsection{Utility bits for floating stats}
337 %************************************************************************
339 I didn't implement this with unboxed numbers. I don't want to be too
340 strict in this stuff, as it is rarely turned on. (WDP 95/09)
344 = FlS Int -- Number of top-floats * lambda groups they've been past
345 Int -- Number of non-top-floats * lambda groups they've been past
346 Int -- Number of lambda (groups) seen
-- Unpack the three counters of a FloatStats, in constructor order.
348 get_stats :: FloatStats -> (Int, Int, Int)
349 get_stats (FlS a b c) = (a, b, c)
-- The all-zero statistics record; the unit of add_stats.
351 zeroStats :: FloatStats
352 zeroStats = FlS 0 0 0
-- Add up a list of statistics records; the empty list gives zeroStats.
354 sum_stats :: [FloatStats] -> FloatStats
355 sum_stats xs = foldr add_stats zeroStats xs
-- Component-wise sum of two statistics records.
357 add_stats :: FloatStats -> FloatStats -> FloatStats
358 add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
359 = FlS (a1 + a2) (b1 + b2) (c1 + c2)
-- Record that the given floats have passed one more lambda group: the
-- first counter grows by the number of top-level-bound floats, the second
-- by the number of all other floats, and the lambda-group counter by one.
361 add_to_stats :: FloatStats -> FloatBinds -> FloatStats
362 add_to_stats (FlS a b c) (FB tops others)
363 = FlS (a + lengthBag tops) (b + lengthBag (flattenMajor others)) (c + 1)
367 %************************************************************************
369 \subsection{Utility bits for floating}
371 %************************************************************************
373 Note [Representation of FloatBinds]
374 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
375 The FloatBinds type is somewhat important. We can get very large numbers
376 of floating bindings, often all destined for the top level. A typical example
377 is x = [4,2,5,2,5, .... ]
378 Then we get lots of small expressions like (fromInteger 4), which all get
382 (a) we partition these floating bindings *at every binding site*
383 (b) SetLevels introduces a new binding site for every float
384 So we had better not look at each binding at each binding site!
386 That is why MajorEnv is represented as a finite map.
388 We keep the bindings destined for the *top* level separate, because
389 we float them out even if they don't escape a *value* lambda; see
390 partitionByMajorLevel.
-- A single floating binding is just a Core binding.
394 type FloatBind = CoreBind -- INVARIANT: a FloatBind is always lifted

-- The set of bindings currently floating outwards.  Both fields are
-- strict: a bag of bindings bound for the top level, and a two-level map
-- (major level, then minor level) holding everything else.
396 data FloatBinds = FB !(Bag FloatBind) -- Destined for top level
397 !MajorEnv -- Levels other than top
398 -- See Note [Representation of FloatBinds]
-- Pretty-printer: shows the top-level bag under "binds =" and the level
-- map under "env =", wrapped in braces after the tag "FB".
400 instance Outputable FloatBinds where
401 ppr (FB fbs env) = ptext (sLit "FB") <+> (braces $ vcat
402 [ ptext (sLit "binds =") <+> ppr fbs
403 , ptext (sLit "env =") <+> ppr env ])
-- Nested finite maps: the outer map is indexed by major level, and each
-- entry is an inner map from minor level to the bag of bindings there.
405 type MajorEnv = M.IntMap MinorEnv -- Keyed by major level
406 type MinorEnv = M.IntMap (Bag FloatBind) -- Keyed by minor level
-- Forget all level information: the top-level bag unioned with every
-- binding stored in the level map.
408 flattenFloats :: FloatBinds -> Bag FloatBind
409 flattenFloats (FB tops others) = tops `unionBags` flattenMajor others
-- Union of all bindings at every major (and hence every minor) level.
411 flattenMajor :: MajorEnv -> Bag FloatBind
412 flattenMajor = M.fold (unionBags . flattenMinor) emptyBag
-- Union of the bags stored at every minor level of one major level.
414 flattenMinor :: MinorEnv -> Bag FloatBind
415 flattenMinor = M.fold unionBags emptyBag
-- No floating bindings at all; the unit of plusFloats.
417 emptyFloats :: FloatBinds
418 emptyFloats = FB emptyBag M.empty
-- A FloatBinds holding exactly one binding, filed under the given level:
-- in the top-level bag if the level is top level, otherwise in the level
-- map at (major, minor).
420 unitFloat :: Level -> FloatBind -> FloatBinds
421 unitFloat lvl@(Level major minor) b
422 | isTopLvl lvl = FB (unitBag b) M.empty
423 | otherwise = FB emptyBag (M.singleton major (M.singleton minor (unitBag b)))
-- Merge two sets of floats: union the top-level bags and merge the level
-- maps with plusMajor.
425 plusFloats :: FloatBinds -> FloatBinds -> FloatBinds
426 plusFloats (FB t1 b1) (FB t2 b2) = FB (t1 `unionBags` t2) (b1 `plusMajor` b2)
-- Merge two major-level maps; entries at the same major level are merged
-- with plusMinor.
428 plusMajor :: MajorEnv -> MajorEnv -> MajorEnv
429 plusMajor = M.unionWith plusMinor
-- Merge two minor-level maps; bags at the same minor level are unioned.
431 plusMinor :: MinorEnv -> MinorEnv -> MinorEnv
432 plusMinor = M.unionWith unionBags
-- Turn a bag of floated bindings into (binder, rhs) pairs and put them in
-- front of the given pairs.  A Rec contributes all of its pairs, a NonRec
-- contributes one.
434 floatsToBindPairs :: Bag FloatBind -> [(Id,CoreExpr)] -> [(Id,CoreExpr)]
435 floatsToBindPairs floats binds = foldrBag add binds floats
437 add (Rec pairs) binds = pairs ++ binds
438 add (NonRec binder rhs) binds = (binder,rhs) : binds
-- Wrap an expression in one Let per binding group in the bag (a right
-- fold over the bag, so the expression ends up innermost).
440 install :: Bag FloatBind -> CoreExpr -> CoreExpr
441 install defn_groups expr
442 = foldrBag install_group expr defn_groups
444 install_group defns body = Let defns body
-- Both partitioning functions split a FloatBinds into the part that keeps
-- floating (returned as a FloatBinds) and the part that must be bound
-- here (returned as a flat bag).  Top-level-bound floats always stay in
-- the first part.
446 partitionByMajorLevel, partitionByLevel
447 :: Level -- Partitioning level
448 -> FloatBinds -- Defns to be divided into 2 piles...
449 -> (FloatBinds, -- Defns with level strictly < partition level,
450 Bag FloatBind) -- The rest
452 -- ---- partitionByMajorLevel ----
453 -- Float it if we escape a value lambda, *or* if we get to the top level
454 -- If we can get to the top level, say "yes" anyway. This means that
459 -- which is as it should be
-- Only the major level is compared (the minor level is ignored): entries
-- with a smaller major level keep floating, while entries at this major
-- level or a larger one are flattened into the "here" bag.
461 partitionByMajorLevel (Level major _) (FB tops defns)
462 = (FB tops outer, heres `unionBags` flattenMajor inner)
464 (outer, mb_heres, inner) = M.splitLookup major defns
465 heres = case mb_heres of
467 Just h -> flattenMinor h
-- Split at an exact (major, minor) level.  Keeps floating: everything at
-- a smaller major level, plus entries at the same major level with a
-- smaller minor level.  Bound here: the entry at exactly (major, minor),
-- larger minor levels within this major level, and all larger major
-- levels.  Top-level-bound floats always keep floating.
469 partitionByLevel (Level major minor) (FB tops defns)
470 = (FB tops (outer_maj `plusMajor` M.singleton major outer_min),
471 here_min `unionBags` flattenMinor inner_min
472 `unionBags` flattenMajor inner_maj)
475 (outer_maj, mb_here_maj, inner_maj) = M.splitLookup major defns
476 (outer_min, mb_here_min, inner_min) = case mb_here_maj of
477 Nothing -> (M.empty, Nothing, M.empty)
478 Just min_defns -> M.splitLookup minor min_defns
479 here_min = mb_here_min `orElse` emptyBag
-- Apply mkSCC with the given cost centre to the right-hand side of every
-- floating binding, both in the top-level bag and at every level of the
-- level map.  Binders and levels are left untouched.
481 wrapCostCentre :: CostCentre -> FloatBinds -> FloatBinds
482 wrapCostCentre cc (FB tops defns)
483 = FB (wrap_defns tops) (M.map (M.map wrap_defns) defns)
485 wrap_defns = mapBag wrap_one
486 wrap_one (NonRec binder rhs) = NonRec binder (mkSCC cc rhs)
487 wrap_one (Rec pairs) = Rec (mapSnd (mkSCC cc) pairs)