2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
4 \section[FloatOut]{Float bindings outwards (towards the top level)}
6 ``Long-distance'' floating of bindings towards the top level.
9 module FloatOut ( floatOutwards ) where
11 #include "HsVersions.h"
14 import CoreUtils ( mkSCC, exprIsValue, exprIsTrivial )
16 import DynFlags ( DynFlags, DynFlag(..), FloatOutSwitches(..) )
17 import ErrUtils ( dumpIfSet_dyn )
18 import CostCentre ( dupifyCC, CostCentre )
19 import Id ( Id, idType )
20 import Type ( isUnLiftedType )
21 import CoreLint ( showPass, endPass )
22 import SetLevels ( Level(..), LevelledExpr, LevelledBind,
23 setLevels, ltMajLvl, ltLvl, isTopLvl )
24 import UniqSupply ( UniqSupply )
25 import List ( partition )
27 import Util ( notNull )
36 To float out sub-expressions that can thereby get outside
37 a non-one-shot value lambda, and hence may be shared.
40 To achieve this we may need to do two things:
42 a) Let-bind the sub-expression:
44 f (g x) ==> let lvl = f (g x) in lvl
46 Now we can float the binding for 'lvl'.
48 b) More than that, we may need to abstract wrt a type variable
50 \x -> ... /\a -> let v = ...a... in ....
52 Here the binding for v mentions 'a' but not 'x'. So we
53 abstract wrt 'a', to give this binding for 'v':
58 Now the binding for vp can float out unimpeded.
59 I can't remember why this case seemed important enough to
60 deal with, but I certainly found cases where important floats
61 didn't happen if we did not abstract wrt tyvars.
63 With this in mind we can also achieve another goal: lambda lifting.
64 We can make an arbitrary (function) binding float to top level by
65 abstracting wrt *all* local variables, not just type variables, leaving
66 a binding that can be floated right to top level. Whether or not this
67 happens is controlled by a flag.
73 At the moment we never float a binding out to between two adjacent lambdas. For example:
77 \x y -> let t = x+x in ...
79 \x -> let t = x+x in \y -> ...
81 Reason: this is less efficient in the case where the original lambda
82 is never partially applied.
84 But there's a case I've seen where this might not be true. Consider:
90 elem' x (y:ys) = x==y || elem' x ys
92 It turns out that this generates a subexpression of the form
94 \deq x ys -> let eq = eqFromEqDict deq in ...
96 which might usefully be separated to
98 \deq -> let eq = eqFromEqDict deq in \xy -> ...
100 Well, maybe. We don't do this at the moment.
-- A floating binding: a Core binding paired with a Level.
103 type FloatBind = (Level, CoreBind) -- INVARIANT: a FloatBind is always lifted
-- A list of floating bindings, as threaded through the float* functions below.
104 type FloatBinds = [FloatBind]
107 %************************************************************************
109 \subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
111 %************************************************************************
-- Entry point of the float-out pass.  It announces the pass (showPass),
-- annotates the program with levels (setLevels), floats every top-level
-- binding (floatTopBind), dumps the levelled program and the float
-- statistics under their respective flags, and finishes with endPass on
-- the concatenated result bindings.
-- NOTE(review): the embedded numbering jumps 114 -> 117, 119 -> 121,
-- 124 -> 127, 138 -> 141 and 143 -> 145, so part of the type signature,
-- the line that opens the body, the end of the first `let`, the line that
-- introduces the local helpers and one pp_not equation are not visible in
-- this listing -- confirm against the full file.
114 floatOutwards :: FloatOutSwitches
117 -> [CoreBind] -> IO [CoreBind]
119 floatOutwards float_sws dflags us pgm
121 showPass dflags float_msg ;
-- Level annotation followed by the actual floating; fss collects one
-- FloatStats per top-level binding, binds_s' one list of bindings each.
123 let { annotated_w_levels = setLevels float_sws pgm us ;
124 (fss, binds_s') = unzip (map floatTopBind annotated_w_levels)
127 dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
128 (vcat (map ppr annotated_w_levels));
-- Summed statistics, unpacked for the report below.
130 let { (tlets, ntlets, lams) = get_stats (sum_stats fss) };
132 dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:"
133 (hcat [ int tlets, ptext SLIT(" Lets floated to top level; "),
134 int ntlets, ptext SLIT(" Lets floated elsewhere; from "),
135 int lams, ptext SLIT(" Lambda groups")]);
137 endPass dflags float_msg Opt_D_verbose_core2core (concat binds_s')
138 {- no specific flag for dumping float-out -}
-- Pass banner: "Float out" followed by a rendering of the two switches.
141 float_msg = showSDoc (text "Float out" <+> parens (sws float_sws))
142 sws (FloatOutSw lam const) = pp_not lam <+> text "lambdas" <> comma <+>
143 pp_not const <+> text "constants"
145 pp_not False = text "not"
-- Float one top-level binding.  The result is the statistics together
-- with the bindings that replace the original one at top level.

-- Non-recursive: whatever floated out of the right-hand side is placed
-- in front of the rebuilt binding.
floatTopBind bind@(NonRec _ _)
  = case floatBind bind of
      (stats, floats, new_bind) ->
        (stats, floatsToBinds floats ++ [new_bind])

-- Recursive: any floats are folded into the group itself, ahead of the
-- original pairs.  Floats are not expected here, hence the WARN.
floatTopBind bind@(Rec _)
  = case floatBind bind of
      (stats, floats, Rec pairs') ->
        WARN( notNull floats, ppr bind $$ ppr floats )
        (stats, [Rec (floatsToBindPairs floats ++ pairs')])
158 %************************************************************************
160 \subsection[FloatOut-Bind]{Floating in a binding (the business end)}
162 %************************************************************************
-- Float the right-hand side(s) of a single levelled binding.  Returns the
-- statistics, the bindings that float out of the right-hand side(s), and
-- the binding rebuilt with its binder tags removed.
-- NOTE(review): the embedded numbering skips 175, 179, 181, 187-188, 195,
-- 197-198 and 204 onwards, so the branch separator of the `if`, the start
-- of the result in the top-level branch, the line introducing the local
-- definitions and the closing braces are not visible in this listing --
-- confirm against the full file.
166 floatBind :: LevelledBind
167 -> (FloatStats, FloatBinds, CoreBind)
-- Non-recursive binding: float the rhs at the binder's own level.
169 floatBind (NonRec (TB name level) rhs)
170 = case (floatNonRecRhs level rhs) of { (fs, rhs_floats, rhs') ->
171 (fs, rhs_floats, NonRec name rhs') }
-- Recursive group: every pair goes through do_pair; the three result
-- components are then collected with unzip3.
173 floatBind bind@(Rec pairs)
174 = case (unzip3 (map do_pair pairs)) of { (fss, rhss_floats, new_pairs) ->
176 if not (isTopLvl bind_dest_level) then
177 -- Standard case; the floated bindings can't mention the
178 -- binders, because they couldn't be escaping a major level
180 (sum_stats fss, concat rhss_floats, Rec new_pairs)
182 -- In a recursive binding, *destined for* the top level
183 -- (only), the rhs floats may contain references to the
184 -- bound things. For example
185 -- f = ...(let v = ...f... in b) ...
186 -- might get floated to
189 -- and hence we must (pessimistically) make all the floats recursive
190 -- with the top binding. Later dependency analysis will unravel it.
192 -- This can only happen for bindings destined for the top level,
193 -- because only then will partitionByMajorLevel allow through a binding
194 -- that only differs in its minor level
196 Rec (new_pairs ++ floatsToBindPairs (concat rhss_floats)))
-- Level of the binding, as given by getBindLevel.
199 bind_dest_level = getBindLevel bind
-- Float one rhs of the group at its binder's level and strip the tag.
201 do_pair (TB name level, rhs)
202 = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
203 (fs, rhs_floats, (name, rhs'))
207 %************************************************************************
209 \subsection[FloatOut-Expr]{Floating in expressions}
211 %************************************************************************
-- Shared type signature of the three expression floaters.  Each returns
-- the statistics, the bindings floating out, and the rebuilt expression.
-- NOTE(review): the embedded numbering skips 215-216, which presumably
-- carry the argument types (every use passes a level and then an
-- expression) -- confirm against the full file.
214 floatExpr, floatRhs, floatNonRecRhs
217 -> (FloatStats, FloatBinds, CoreExpr)
-- Float inside a right-hand side of a recursive binding or of a case
-- alternative.  Floats that partitionByMajorLevel does not let past 'lvl'
-- are re-installed around the result straight away; in particular that
-- covers anything bound by the rec or by the alternative itself.
floatRhs lvl arg
  = case floatExpr lvl arg of
      (stats, floats, arg') ->
        case partitionByMajorLevel lvl floats of
          (escaping, staying) ->
            (stats, escaping, install staying arg')
-- Float inside the rhs of a nested non-recursive binding or a function
-- argument.  The rhs is floated first; the result then depends on whether
-- the floated rhs is a value or trivial (exprIsValue / exprIsTrivial).
-- NOTE(review): the embedded numbering skips 237, 241 and 247-248, so the
-- result of the `then` branch and the `else` keyword are not visible in
-- this listing -- confirm against the full file.
227 floatNonRecRhs lvl arg -- Used for nested non-rec rhss, and fn args
228 = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
229 -- Dump bindings that aren't going to escape from a lambda
230 -- This isn't a scoping issue (the binder isn't in scope in the RHS of a non-rec binding)
231 -- Rather, it is to avoid floating the x binding out of
232 -- f (let x = e in b)
233 -- unnecessarily. But we first test for values or trivial rhss,
234 -- because (in particular) we don't want to insert new bindings between
235 -- the "=" and the "\". E.g.
236 -- f = \x -> let <bind> in <body>
238 -- f = let <bind> in \x -> <body>
239 -- (a) The simplifier will immediately float it further out, so we may
240 -- as well do so right now; in general, keeping rhss as manifest
242 -- (b) If a float-in pass follows immediately, it might add yet more
243 -- bindings just after the '='. And some of them might (correctly)
244 -- be strict even though the 'let f' is lazy, because f, being a value,
245 -- gets its demand-info zapped by the simplifier.
246 if exprIsValue arg' || exprIsTrivial arg' then
-- Otherwise: keep here the floats that partitionByMajorLevel holds back.
249 case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
250 (fsa, floats', install heres arg') }}
-- Variables, types and literals: nothing to count, nothing to float;
-- the node is simply rebuilt.
floatExpr _ (Var var)  = (zeroStats, [], Var var)
floatExpr _ (Type typ) = (zeroStats, [], Type typ)
floatExpr _ (Lit l)    = (zeroStats, [], Lit l)
-- Application: the function part is floated as an ordinary expression,
-- the argument like a non-recursive rhs; stats are added and the floats
-- of the function come before those of the argument.
floatExpr lvl (App fun arg)
  = case floatExpr lvl fun of
      (fun_stats, fun_floats, fun') ->
        case floatNonRecRhs lvl arg of
          (arg_stats, arg_floats, arg') ->
            ( fun_stats `add_stats` arg_stats
            , fun_floats ++ arg_floats
            , App fun' arg' )
-- Lambda: the whole run of binders is handled at once (collectBinders).
-- The body is floated at the level of the last binder; the resulting
-- floats are split against the level of the first binder, the ones that
-- stay are re-installed around the body, and the lambdas are rebuilt with
-- mkLams.  The escaping floats are also counted via add_to_stats.
-- NOTE(review): the embedded numbering skips 262, 275 and 282, so the
-- tokens introducing the local bindings and closing the two cases are not
-- visible in this listing -- confirm against the full file.
261 floatExpr lvl lam@(Lam _ _)
263 (bndrs_w_lvls, body) = collectBinders lam
264 bndrs = [b | TB b _ <- bndrs_w_lvls]
-- head/last of lvls are used below; presumably lvls is non-empty because
-- the matched expression is a Lam -- TODO confirm against collectBinders.
265 lvls = [l | TB b l <- bndrs_w_lvls]
267 -- For the all-tyvar case we are prepared to pull
268 -- the lets out, to implement the float-out-of-big-lambda
269 -- transform; but otherwise we only float bindings that are
270 -- going to escape a value lambda.
271 -- In particular, for one-shot lambdas we don't float things
272 -- out; we get no saving by so doing.
273 partition_fn | all isTyVar bndrs = partitionByLevel
274 | otherwise = partitionByMajorLevel
276 case (floatExpr (last lvls) body) of { (fs, floats, body') ->
278 -- Dump any bindings which absolutely cannot go any further
279 case (partition_fn (head lvls) floats) of { (floats', heres) ->
281 (add_to_stats fs floats', floats', mkLams bndrs (install heres body'))
-- SCC note: float the body, then wrap the right-hand side of every
-- binding that floats out past the scc in the (duplicated) cost centre
-- using mkSCC.  The note itself is kept around the rebuilt body.
-- NOTE(review): the embedded numbering skips 286, 291, 293, 298 and
-- 301-302, so the let/in around annotated_defns, the lines introducing
-- the local helpers, and the head of the ann_bind equation for recursive
-- groups are not visible in this listing -- confirm against the full file.
284 floatExpr lvl (Note note@(SCC cc) expr)
285 = case (floatExpr lvl expr) of { (fs, floating_defns, expr') ->
287 -- Annotate bindings floated outwards past an scc expression
288 -- with the cc. We mark that cc as "duplicated", though.
290 annotated_defns = annotate (dupifyCC cc) floating_defns
292 (fs, annotated_defns, Note note expr') }
-- annotate keeps each float's level and rewrites only the binding.
294 annotate :: CostCentre -> FloatBinds -> FloatBinds
296 annotate dupd_cc defn_groups
297 = [ (level, ann_bind floater) | (level, floater) <- defn_groups ]
299 ann_bind (NonRec binder rhs)
300 = NonRec binder (mkSCC dupd_cc rhs)
303 = Rec [(binder, mkSCC dupd_cc rhs) | (binder, rhs) <- pairs]
-- InlineMe note: nothing floats out of it.  The body is processed with
-- the InlineCtxt level; any floats that still turn up (ordinary lets that
-- were already in the body) are re-installed right here.  Checking once
-- at this point is cheaper than testing for the inline context at every
-- let.  See notes in SetLevels.
floatExpr _ (Note InlineMe expr)
  = case floatExpr InlineCtxt expr of
      (stats, inner_floats, expr') ->
        (stats, [], Note InlineMe (install inner_floats expr'))
-- Any remaining kind of note: transparent to floating; the note is
-- re-attached to the rebuilt body and all floats pass through.
floatExpr lvl (Note note expr)
  = case floatExpr lvl expr of
      (stats, floats, expr') ->
        (stats, floats, Note note expr')
-- Non-recursive let with an unlifted binder: treated just like a case.
-- The rhs plays the role of the scrutinee (plain floatExpr) and the body
-- that of an alternative (floatRhs at the binder's level), so the binding
-- itself never joins the floats.
--
-- The statistics of the rhs and of the body are combined.  Previously the
-- inner pattern re-bound the same name, so the rhs statistics were
-- silently dropped, unlike in the App and Case equations.
floatExpr lvl (Let (NonRec (TB bndr bndr_lvl) rhs) body)
  | isUnLiftedType (idType bndr)
  = case floatExpr lvl rhs of
      (rhs_stats, rhs_floats, rhs') ->
        case floatRhs bndr_lvl body of
          (body_stats, body_floats, body') ->
            ( rhs_stats `add_stats` body_stats
            , rhs_floats ++ body_floats
            , Let (NonRec bndr rhs') body' )
-- General let: the binding is processed by floatBind and then itself
-- added to the floats, tagged with the level from getBindLevel, between
-- the floats from its right-hand side(s) and those from the body.
-- NOTE(review): the embedded numbering skips 327, 329 and 330, so the
-- first and last components of the result tuple and the line introducing
-- bind_lvl are not visible in this listing -- confirm against the full file.
324 floatExpr lvl (Let bind body)
325 = case (floatBind bind) of { (fsb, rhs_floats, bind') ->
326 case (floatExpr lvl body) of { (fse, body_floats, body') ->
328 rhs_floats ++ [(bind_lvl, bind')] ++ body_floats,
331 bind_lvl = getBindLevel bind
-- Case: float the scrutinee at the current level and every alternative
-- at the level of the case binder.  Statistics are added; the floats of
-- the alternatives are listed before those of the scrutinee.
floatExpr lvl (Case scrut (TB case_bndr case_lvl) ty alts)
  = case floatExpr lvl scrut of
      (scrut_stats, scrut_floats, scrut') ->
        case floatList float_alt alts of
          (alts_stats, alts_floats, alts') ->
            ( add_stats scrut_stats alts_stats
            , alts_floats ++ scrut_floats
            , Case scrut' case_bndr ty alts' )
  where
    -- floatRhs (rather than floatExpr) is used for the alternatives so
    -- that bindings are not gratuitously floated out of their rhss.
    -- The pattern binders lose their level tags.
    float_alt (con, bs, rhs)
      = case floatRhs case_lvl rhs of
          (alt_stats, alt_floats, rhs') ->
            (alt_stats, alt_floats, (con, [b | TB b _ <- bs], rhs'))
-- Apply a floating function to every element of a list, adding up the
-- statistics and concatenating the floats in list order.
floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
floatList f xs = foldr step (zeroStats, [], []) xs
  where
    -- The element is processed before the result for the tail is inspected.
    step x rest
      = case f x of
          (stats_x, floats_x, y) ->
            case rest of
              (stats_rest, floats_rest, ys) ->
                (stats_x `add_stats` stats_rest, floats_x ++ floats_rest, y : ys)
353 %************************************************************************
355 \subsection{Utility bits for floating stats}
357 %************************************************************************
359 I didn't implement this with unboxed numbers. I don't want to be too
360 strict in this stuff, as it is rarely turned on. (WDP 95/09)
-- Constructor of the statistics record: three counters, read back as a
-- triple by get_stats and updated by add_stats / add_to_stats.
-- NOTE(review): the embedded numbering skips 363, which presumably holds
-- the head of the data declaration -- confirm against the full file.
364 = FlS Int -- Number of top-floats * lambda groups they've been past
365 Int -- Number of non-top-floats * lambda groups they've been past
366 Int -- Number of lambda (groups) seen
-- Unpack the three counters: top-level floats, other floats, lambda groups.
get_stats :: FloatStats -> (Int, Int, Int)
get_stats (FlS tops others lams) = (tops, others, lams)
-- The empty statistics: all three counters at zero.
zeroStats :: FloatStats
zeroStats = FlS 0 0 0
-- Add up a list of statistics, starting from zeroStats.
sum_stats :: [FloatStats] -> FloatStats
sum_stats = foldr add_stats zeroStats
-- Counter-wise sum of two statistics.
add_stats :: FloatStats -> FloatStats -> FloatStats
add_stats (FlS tops1 others1 lams1) (FlS tops2 others2 lams2)
  = FlS (tops1 + tops2) (others1 + others2) (lams1 + lams2)
-- Record one lambda group together with the floats that got past it:
-- floats whose level is the top level bump the first counter, all other
-- floats the second, and the lambda-group counter goes up by one.
add_to_stats (FlS tops others lams) floats
  = FlS (tops + length going_to_top) (others + length going_elsewhere) (lams + 1)
  where
    (going_to_top, going_elsewhere) = partition (\(my_lvl, _) -> isTopLvl my_lvl) floats
386 %************************************************************************
388 \subsection{Utility bits for floating}
390 %************************************************************************
-- The level attached to a binding: the tag of the binder for a
-- non-recursive binding, the tag of the first binder for a recursive
-- group.  An empty recursive group carries no level at all; it is
-- reported with an explicit message instead of an anonymous
-- pattern-match failure.
getBindLevel :: LevelledBind -> Level
getBindLevel (NonRec (TB _ lvl) _)       = lvl
getBindLevel (Rec (((TB _ lvl), _) : _)) = lvl
getBindLevel (Rec [])                    = error "FloatOut.getBindLevel: empty Rec group"
-- Both partitioning functions split a list of floats against a context
-- level: the first component holds the floats that carry on outwards,
-- the second those that must be installed here.
partitionByMajorLevel, partitionByLevel
  :: Level                      -- Partitioning level
  -> FloatBinds                 -- Defns to be divided into 2 piles...
  -> (FloatBinds, FloatBinds)   -- (defns that float further, the rest)

-- A float carries on if its level is below the context level according
-- to ltMajLvl (it escapes a value lambda), or if it is headed for the
-- top level; in the latter case the answer is "yes" regardless.
partitionByMajorLevel ctxt_lvl defns
  = partition escapes defns
  where
    escapes (my_lvl, _) = my_lvl `ltMajLvl` ctxt_lvl || isTopLvl my_lvl
-- Finer-grained split: a float carries on whenever its level is below
-- the context level according to ltLvl.
partitionByLevel ctxt_lvl defns
  = partition (\(my_lvl, _) -> my_lvl `ltLvl` ctxt_lvl) defns
-- Forget the levels and keep the bindings, in the same order.
floatsToBinds :: FloatBinds -> [CoreBind]
floatsToBinds floats = [ bind | (_, bind) <- floats ]
-- Flatten floats into (binder, rhs) pairs, e.g. for inclusion in a
-- recursive group: a Rec contributes all of its pairs, a NonRec one pair.
floatsToBindPairs :: FloatBinds -> [(Id,CoreExpr)]
floatsToBindPairs floats = concatMap pairs_of floats
  where
    pairs_of (_, bind)
      = case bind of
          Rec pairs         -> pairs
          NonRec binder rhs -> [(binder, rhs)]
-- Wrap an expression in the given floats as nested lets; the first float
-- in the list ends up outermost.  The levels are ignored.
install :: FloatBinds -> CoreExpr -> CoreExpr
install defn_groups expr
  = foldr (\(_, defns) body -> Let defns body) expr defn_groups