d65f7bd17e646b668ebce87eb3dfde880bdccb4c
[ghc-hetmet.git] / compiler / simplCore / FloatOut.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
3 %
4 \section[FloatOut]{Float bindings outwards (towards the top level)}
5
6 ``Long-distance'' floating of bindings towards the top level.
7
8 \begin{code}
9 module FloatOut ( floatOutwards ) where
10
11 import CoreSyn
12 import CoreUtils
13
14 import DynFlags ( DynFlags, DynFlag(..), FloatOutSwitches(..) )
15 import ErrUtils         ( dumpIfSet_dyn )
16 import CostCentre       ( dupifyCC, CostCentre )
17 import Id               ( Id, idType )
18 import Type             ( isUnLiftedType )
19 import SetLevels        ( Level(..), LevelledExpr, LevelledBind,
20                           setLevels, isTopLvl, tOP_LEVEL )
21 import UniqSupply       ( UniqSupply )
22 import Bag
23 import Util
24 import Maybes
25 import UniqFM
26 import Outputable
27 import FastString
28 \end{code}
29
30         -----------------
31         Overall game plan
32         -----------------
33
34 The Big Main Idea is:
35
36         To float out sub-expressions that can thereby get outside
37         a non-one-shot value lambda, and hence may be shared.
38
39
40 To achieve this we may need to do two thing:
41
42    a) Let-bind the sub-expression:
43
44         f (g x)  ==>  let lvl = f (g x) in lvl
45
46       Now we can float the binding for 'lvl'.  
47
48    b) More than that, we may need to abstract wrt a type variable
49
50         \x -> ... /\a -> let v = ...a... in ....
51
52       Here the binding for v mentions 'a' but not 'x'.  So we
53       abstract wrt 'a', to give this binding for 'v':
54
55             vp = /\a -> ...a...
56             v  = vp a
57
58       Now the binding for vp can float out unimpeded.
59       I can't remember why this case seemed important enough to
60       deal with, but I certainly found cases where important floats
61       didn't happen if we did not abstract wrt tyvars.
62
63 With this in mind we can also achieve another goal: lambda lifting.
64 We can make an arbitrary (function) binding float to top level by
65 abstracting wrt *all* local variables, not just type variables, leaving
66 a binding that can be floated right to top level.  Whether or not this
67 happens is controlled by a flag.
68
69
70 Random comments
71 ~~~~~~~~~~~~~~~
72
73 At the moment we never float a binding out to between two adjacent
74 lambdas.  For example:
75
76 @
77         \x y -> let t = x+x in ...
78 ===>
79         \x -> let t = x+x in \y -> ...
80 @
81 Reason: this is less efficient in the case where the original lambda
82 is never partially applied.
83
84 But there's a case I've seen where this might not be true.  Consider:
85 @
86 elEm2 x ys
87   = elem' x ys
88   where
89     elem' _ []  = False
90     elem' x (y:ys)      = x==y || elem' x ys
91 @
92 It turns out that this generates a subexpression of the form
93 @
94         \deq x ys -> let eq = eqFromEqDict deq in ...
95 @
96 vwhich might usefully be separated to
97 @
98         \deq -> let eq = eqFromEqDict deq in \xy -> ...
99 @
100 Well, maybe.  We don't do this at the moment.
101
102
103 %************************************************************************
104 %*                                                                      *
105 \subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
106 %*                                                                      *
107 %************************************************************************
108
109 \begin{code}
110 floatOutwards :: FloatOutSwitches
111               -> DynFlags
112               -> UniqSupply 
113               -> [CoreBind] -> IO [CoreBind]
114
115 floatOutwards float_sws dflags us pgm
116   = do {
117         let { annotated_w_levels = setLevels float_sws pgm us ;
118               (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)
119             } ;
120
121         dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
122                   (vcat (map ppr annotated_w_levels));
123
124         let { (tlets, ntlets, lams) = get_stats (sum_stats fss) };
125
126         dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:"
127                 (hcat [ int tlets,  ptext (sLit " Lets floated to top level; "),
128                         int ntlets, ptext (sLit " Lets floated elsewhere; from "),
129                         int lams,   ptext (sLit " Lambda groups")]);
130
131         return (concat binds_s')
132     }
133
134 floatTopBind :: LevelledBind -> (FloatStats, [CoreBind])
135 floatTopBind bind
136   = case (floatBind bind) of { (fs, floats) ->
137     (fs, bagToList (flattenFloats floats))
138     }
139 \end{code}
140
141 %************************************************************************
142 %*                                                                      *
143 \subsection[FloatOut-Bind]{Floating in a binding (the business end)}
144 %*                                                                      *
145 %************************************************************************
146
147
148 \begin{code}
149 floatBind :: LevelledBind -> (FloatStats, FloatBinds)
150
151 floatBind (NonRec (TB name level) rhs)
152   = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
153     (fs, rhs_floats `plusFloats` unitFloat level (NonRec name rhs')) }
154
155 floatBind bind@(Rec pairs)
156   = case (unzip3 (map do_pair pairs)) of { (fss, rhss_floats, new_pairs) ->
157     let rhs_floats = foldr1 plusFloats rhss_floats in
158
159     if not (isTopLvl bind_dest_lvl) then
160         -- Find which bindings float out at least one lambda beyond this one
161         -- These ones can't mention the binders, because they couldn't 
162         -- be escaping a major level if so.
163         -- The ones that are not going further can join the letrec;
164         -- they may not be mutually recursive but the occurrence analyser will
165         -- find that out.
166         case (partitionByMajorLevel bind_dest_lvl rhs_floats) of { (floats', heres) ->
167         (sum_stats fss, 
168          floats' `plusFloats` unitFloat bind_dest_lvl 
169                                 (Rec (floatsToBindPairs heres new_pairs))) }
170     else
171         -- In a recursive binding, *destined for* the top level
172         -- (only), the rhs floats may contain references to the 
173         -- bound things.  For example
174         --      f = ...(let v = ...f... in b) ...
175         --  might get floated to
176         --      v = ...f...
177         --      f = ... b ...
178         -- and hence we must (pessimistically) make all the floats recursive
179         -- with the top binding.  Later dependency analysis will unravel it.
180         --
181         -- This can only happen for bindings destined for the top level,
182         -- because only then will partitionByMajorLevel allow through a binding
183         -- that only differs in its minor level
184         (sum_stats fss, unitFloat tOP_LEVEL
185                            (Rec (floatsToBindPairs (flattenFloats rhs_floats) new_pairs)))
186     }
187   where
188     bind_dest_lvl = getBindLevel bind
189
190     do_pair (TB name level, rhs)
191       = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
192         (fs, rhs_floats, (name, rhs'))
193         }
194 \end{code}
195
196 %************************************************************************
197
198 \subsection[FloatOut-Expr]{Floating in expressions}
199 %*                                                                      *
200 %************************************************************************
201
202 \begin{code}
203 floatExpr, floatRhs, floatCaseAlt
204          :: Level
205          -> LevelledExpr
206          -> (FloatStats, FloatBinds, CoreExpr)
207
208 floatCaseAlt lvl arg    -- Used rec rhss, and case-alternative rhss
209   = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
210     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
211         -- Dump bindings that aren't going to escape from a lambda;
212         -- in particular, we must dump the ones that are bound by 
213         -- the rec or case alternative
214     (fsa, floats', install heres arg') }}
215
216 floatRhs lvl arg        -- Used for nested non-rec rhss, and fn args
217                         -- See Note [Floating out of RHS]
218   = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
219     if exprIsCheap arg' then    
220         (fsa, floats, arg')
221     else
222     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
223     (fsa, floats', install heres arg') }}
224
225 -- Note [Floating out of RHSs]
226 -- ~~~~~~~~~~~~~~~~~~~~~~~~~~~
227 -- Dump bindings that aren't going to escape from a lambda
228 -- This isn't a scoping issue (the binder isn't in scope in the RHS 
229 --      of a non-rec binding)
230 -- Rather, it is to avoid floating the x binding out of
231 --      f (let x = e in b)
232 -- unnecessarily.  But we first test for values or trival rhss,
233 -- because (in particular) we don't want to insert new bindings between
234 -- the "=" and the "\".  E.g.
235 --      f = \x -> let <bind> in <body>
236 -- We do not want
237 --      f = let <bind> in \x -> <body>
238 -- (a) The simplifier will immediately float it further out, so we may
239 --      as well do so right now; in general, keeping rhss as manifest 
240 --      values is good
241 -- (b) If a float-in pass follows immediately, it might add yet more
242 --      bindings just after the '='.  And some of them might (correctly)
243 --      be strict even though the 'let f' is lazy, because f, being a value,
244 --      gets its demand-info zapped by the simplifier.
245 --
246 -- We use exprIsCheap because that is also what's used by the simplifier
247 -- to decide whether to float a let out of a let
248
249 floatExpr _ (Var v)   = (zeroStats, emptyFloats, Var v)
250 floatExpr _ (Type ty) = (zeroStats, emptyFloats, Type ty)
251 floatExpr _ (Lit lit) = (zeroStats, emptyFloats, Lit lit)
252           
253 floatExpr lvl (App e a)
254   = case (floatExpr      lvl e) of { (fse, floats_e, e') ->
255     case (floatRhs lvl a)       of { (fsa, floats_a, a') ->
256     (fse `add_stats` fsa, floats_e `plusFloats` floats_a, App e' a') }}
257
258 floatExpr _ lam@(Lam _ _)
259   = let
260         (bndrs_w_lvls, body) = collectBinders lam
261         bndrs                = [b | TB b _ <- bndrs_w_lvls]
262         lvls                 = [l | TB _ l <- bndrs_w_lvls]
263
264         -- For the all-tyvar case we are prepared to pull 
265         -- the lets out, to implement the float-out-of-big-lambda
266         -- transform; but otherwise we only float bindings that are
267         -- going to escape a value lambda.
268         -- In particular, for one-shot lambdas we don't float things
269         -- out; we get no saving by so doing.
270         partition_fn | all isTyVar bndrs = partitionByLevel
271                      | otherwise         = partitionByMajorLevel
272     in
273     case (floatExpr (last lvls) body) of { (fs, floats, body') ->
274
275         -- Dump any bindings which absolutely cannot go any further
276     case (partition_fn (head lvls) floats)      of { (floats', heres) ->
277
278     (add_to_stats fs floats', floats', mkLams bndrs (install heres body'))
279     }}
280
281 floatExpr lvl (Note note@(SCC cc) expr)
282   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
283     let
284         -- Annotate bindings floated outwards past an scc expression
285         -- with the cc.  We mark that cc as "duplicated", though.
286
287         annotated_defns = wrapCostCentre (dupifyCC cc) floating_defns
288     in
289     (fs, annotated_defns, Note note expr') }
290
291 floatExpr lvl (Note note expr)  -- Other than SCCs
292   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
293     (fs, floating_defns, Note note expr') }
294
295 floatExpr lvl (Cast expr co)
296   = case (floatExpr lvl expr)   of { (fs, floating_defns, expr') ->
297     (fs, floating_defns, Cast expr' co) }
298
299 floatExpr lvl (Let (NonRec (TB bndr bndr_lvl) rhs) body)
300   | isUnLiftedType (idType bndr)        -- Treat unlifted lets just like a case
301                                 -- I.e. floatExpr for rhs, floatCaseAlt for body
302   = case floatExpr lvl rhs          of { (_, rhs_floats, rhs') ->
303     case floatCaseAlt bndr_lvl body of { (fs, body_floats, body') ->
304     (fs, rhs_floats `plusFloats` body_floats, Let (NonRec bndr rhs') body') }}
305
306 floatExpr lvl (Let bind body)
307   = case (floatBind bind)     of { (fsb, bind_floats) ->
308     case (floatExpr lvl body) of { (fse, body_floats, body') ->
309     (add_stats fsb fse,
310      bind_floats `plusFloats` body_floats,
311      body')  }}
312
313 floatExpr lvl (Case scrut (TB case_bndr case_lvl) ty alts)
314   = case floatExpr lvl scrut    of { (fse, fde, scrut') ->
315     case floatList float_alt alts       of { (fsa, fda, alts')  ->
316     (add_stats fse fsa, fda `plusFloats` fde, Case scrut' case_bndr ty alts')
317     }}
318   where
319         -- Use floatCaseAlt for the alternatives, so that we
320         -- don't gratuitiously float bindings out of the RHSs
321     float_alt (con, bs, rhs)
322         = case (floatCaseAlt case_lvl rhs)      of { (fs, rhs_floats, rhs') ->
323           (fs, rhs_floats, (con, [b | TB b _ <- bs], rhs')) }
324
325
326 floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
327 floatList _ [] = (zeroStats, emptyFloats, [])
328 floatList f (a:as) = case f a            of { (fs_a,  binds_a,  b)  ->
329                      case floatList f as of { (fs_as, binds_as, bs) ->
330                      (fs_a `add_stats` fs_as, binds_a `plusFloats`  binds_as, b:bs) }}
331
332 getBindLevel :: Bind (TaggedBndr Level) -> Level
333 getBindLevel (NonRec (TB _ lvl) _)       = lvl
334 getBindLevel (Rec (((TB _ lvl), _) : _)) = lvl
335 getBindLevel (Rec [])                    = panic "getBindLevel Rec []"
336 \end{code}
337
338 %************************************************************************
339 %*                                                                      *
340 \subsection{Utility bits for floating stats}
341 %*                                                                      *
342 %************************************************************************
343
344 I didn't implement this with unboxed numbers.  I don't want to be too
345 strict in this stuff, as it is rarely turned on.  (WDP 95/09)
346
347 \begin{code}
348 data FloatStats
349   = FlS Int  -- Number of top-floats * lambda groups they've been past
350         Int  -- Number of non-top-floats * lambda groups they've been past
351         Int  -- Number of lambda (groups) seen
352
353 get_stats :: FloatStats -> (Int, Int, Int)
354 get_stats (FlS a b c) = (a, b, c)
355
356 zeroStats :: FloatStats
357 zeroStats = FlS 0 0 0
358
359 sum_stats :: [FloatStats] -> FloatStats
360 sum_stats xs = foldr add_stats zeroStats xs
361
362 add_stats :: FloatStats -> FloatStats -> FloatStats
363 add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
364   = FlS (a1 + a2) (b1 + b2) (c1 + c2)
365
366 add_to_stats :: FloatStats -> FloatBinds -> FloatStats
367 add_to_stats (FlS a b c) (FB tops others)
368   = FlS (a + lengthBag tops) (b + lengthBag (flattenMajor others)) (c + 1)
369 \end{code}
370
371
372 %************************************************************************
373 %*                                                                      *
374 \subsection{Utility bits for floating}
375 %*                                                                      *
376 %************************************************************************
377
378 Note [Representation of FloatBinds]
379 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
380 The FloatBinds types is somewhat important.  We can get very large numbers
381 of floating bindings, often all destined for the top level.  A typical example
382 is     x = [4,2,5,2,5, .... ]
383 Then we get lots of small expressions like (fromInteger 4), which all get
384 lifted to top level.  
385
386 The trouble is that  
387   (a) we partition these floating bindings *at every binding site* 
388   (b) SetLevels introduces a new bindings site for every float
389 So we had better not look at each binding at each binding site!
390
391 That is why MajorEnv is represented as a finite map.
392
393 We keep the bindings destined for the *top* level separate, because
394 we float them out even if they don't escape a *value* lambda; see
395 partitionByMajorLevel.
396
397
398 \begin{code}
399 type FloatBind = CoreBind       -- INVARIANT: a FloatBind is always lifted
400
401 data FloatBinds  = FB !(Bag FloatBind)          -- Destined for top level
402                       !MajorEnv                 -- Levels other than top
403      -- See Note [Representation of FloatBinds]
404
405 type MajorEnv = UniqFM MinorEnv                 -- Keyed by major level
406 type MinorEnv = UniqFM (Bag FloatBind)          -- Keyed by minor level
407
408 flattenFloats :: FloatBinds -> Bag FloatBind
409 flattenFloats (FB tops others) = tops `unionBags` flattenMajor others
410
411 flattenMajor :: MajorEnv -> Bag FloatBind
412 flattenMajor = foldUFM (unionBags . flattenMinor) emptyBag
413
414 flattenMinor :: MinorEnv -> Bag FloatBind
415 flattenMinor = foldUFM unionBags emptyBag
416
417 emptyFloats :: FloatBinds
418 emptyFloats = FB emptyBag emptyUFM
419
420 unitFloat :: Level -> FloatBind -> FloatBinds
421 unitFloat lvl@(Level major minor) b 
422   | isTopLvl lvl = FB (unitBag b) emptyUFM
423   | otherwise    = FB emptyBag (unitUFM major (unitUFM minor (unitBag b)))
424
425 plusFloats :: FloatBinds -> FloatBinds -> FloatBinds
426 plusFloats (FB t1 b1) (FB t2 b2) = FB (t1 `unionBags` t2) (b1 `plusMajor` b2)
427
428 plusMajor :: MajorEnv -> MajorEnv -> MajorEnv
429 plusMajor = plusUFM_C plusMinor
430
431 plusMinor :: MinorEnv -> MinorEnv -> MinorEnv
432 plusMinor = plusUFM_C unionBags
433
434 floatsToBindPairs :: Bag FloatBind -> [(Id,CoreExpr)] -> [(Id,CoreExpr)]
435 floatsToBindPairs floats binds = foldrBag add binds floats
436   where
437    add (Rec pairs)         binds = pairs ++ binds
438    add (NonRec binder rhs) binds = (binder,rhs) : binds
439
440 install :: Bag FloatBind -> CoreExpr -> CoreExpr
441 install defn_groups expr
442   = foldrBag install_group expr defn_groups
443   where
444     install_group defns body = Let defns body
445
446 partitionByMajorLevel, partitionByLevel
447         :: Level                -- Partitioning level
448         -> FloatBinds           -- Defns to be divided into 2 piles...
449         -> (FloatBinds,         -- Defns  with level strictly < partition level,
450             Bag FloatBind)      -- The rest
451
452 --       ---- partitionByMajorLevel ----
453 -- Float it if we escape a value lambda, *or* if we get to the top level
454 -- If we can get to the top level, say "yes" anyway. This means that 
455 --      x = f e
456 -- transforms to 
457 --    lvl = e
458 --    x = f lvl
459 -- which is as it should be
460
461 partitionByMajorLevel (Level major _) (FB tops defns)
462   = (FB tops outer, heres `unionBags` flattenMajor inner)
463   where
464     (outer, mb_heres, inner) = splitUFM defns major
465     heres = case mb_heres of 
466                Nothing -> emptyBag
467                Just h  -> flattenMinor h
468
469 partitionByLevel (Level major minor) (FB tops defns)
470   = (FB tops (outer_maj `plusMajor` unitUFM major outer_min),
471      here_min `unionBags` flattenMinor inner_min 
472               `unionBags` flattenMajor inner_maj)
473
474   where
475     (outer_maj, mb_here_maj, inner_maj) = splitUFM defns major
476     (outer_min, mb_here_min, inner_min) = case mb_here_maj of
477                                             Nothing -> (emptyUFM, Nothing, emptyUFM)
478                                             Just min_defns -> splitUFM min_defns minor
479     here_min = mb_here_min `orElse` emptyBag
480
481 wrapCostCentre :: CostCentre -> FloatBinds -> FloatBinds
482 wrapCostCentre cc (FB tops defns)
483   = FB (wrap_defns tops) (mapUFM (mapUFM wrap_defns) defns)
484   where
485     wrap_defns = mapBag wrap_one 
486     wrap_one (NonRec binder rhs) = NonRec binder (mkSCC cc rhs)
487     wrap_one (Rec pairs)         = Rec (mapSnd (mkSCC cc) pairs)
488 \end{code}
489