add comment
[ghc-hetmet.git] / compiler / simplCore / FloatOut.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
3 %
4 \section[FloatOut]{Float bindings outwards (towards the top level)}
5
6 ``Long-distance'' floating of bindings towards the top level.
7
8 \begin{code}
9 module FloatOut ( floatOutwards ) where
10
11 import CoreSyn
12 import CoreUtils
13 import CoreArity        ( etaExpand )
14 import CoreMonad        ( FloatOutSwitches(..) )
15
16 import DynFlags         ( DynFlags, DynFlag(..) )
17 import ErrUtils         ( dumpIfSet_dyn )
18 import CostCentre       ( dupifyCC, CostCentre )
19 import Id               ( Id, idType, idArity, isBottomingId )
20 import Type             ( isUnLiftedType )
21 import SetLevels        ( Level(..), LevelledExpr, LevelledBind,
22                           setLevels, isTopLvl )
23 import UniqSupply       ( UniqSupply )
24 import Bag
25 import Util
26 import Maybes
27 import Outputable
28 import FastString
29 import qualified Data.IntMap as M
30
31 #include "HsVersions.h"
32 \end{code}
33
34         -----------------
35         Overall game plan
36         -----------------
37
38 The Big Main Idea is:
39
40         To float out sub-expressions that can thereby get outside
41         a non-one-shot value lambda, and hence may be shared.
42
43
44 To achieve this we may need to do two thing:
45
46    a) Let-bind the sub-expression:
47
48         f (g x)  ==>  let lvl = f (g x) in lvl
49
50       Now we can float the binding for 'lvl'.  
51
52    b) More than that, we may need to abstract wrt a type variable
53
54         \x -> ... /\a -> let v = ...a... in ....
55
56       Here the binding for v mentions 'a' but not 'x'.  So we
57       abstract wrt 'a', to give this binding for 'v':
58
59             vp = /\a -> ...a...
60             v  = vp a
61
62       Now the binding for vp can float out unimpeded.
63       I can't remember why this case seemed important enough to
64       deal with, but I certainly found cases where important floats
65       didn't happen if we did not abstract wrt tyvars.
66
67 With this in mind we can also achieve another goal: lambda lifting.
68 We can make an arbitrary (function) binding float to top level by
69 abstracting wrt *all* local variables, not just type variables, leaving
70 a binding that can be floated right to top level.  Whether or not this
71 happens is controlled by a flag.
72
73
74 Random comments
75 ~~~~~~~~~~~~~~~
76
77 At the moment we never float a binding out to between two adjacent
78 lambdas.  For example:
79
80 @
81         \x y -> let t = x+x in ...
82 ===>
83         \x -> let t = x+x in \y -> ...
84 @
85 Reason: this is less efficient in the case where the original lambda
86 is never partially applied.
87
88 But there's a case I've seen where this might not be true.  Consider:
89 @
90 elEm2 x ys
91   = elem' x ys
92   where
93     elem' _ []  = False
94     elem' x (y:ys)      = x==y || elem' x ys
95 @
96 It turns out that this generates a subexpression of the form
97 @
98         \deq x ys -> let eq = eqFromEqDict deq in ...
99 @
100 vwhich might usefully be separated to
101 @
102         \deq -> let eq = eqFromEqDict deq in \xy -> ...
103 @
104 Well, maybe.  We don't do this at the moment.
105
106
107 %************************************************************************
108 %*                                                                      *
109 \subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
110 %*                                                                      *
111 %************************************************************************
112
113 \begin{code}
114 floatOutwards :: FloatOutSwitches
115               -> DynFlags
116               -> UniqSupply 
117               -> [CoreBind] -> IO [CoreBind]
118
119 floatOutwards float_sws dflags us pgm
120   = do {
121         let { annotated_w_levels = setLevels float_sws pgm us ;
122               (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)
123             } ;
124
125         dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
126                   (vcat (map ppr annotated_w_levels));
127
128         let { (tlets, ntlets, lams) = get_stats (sum_stats fss) };
129
130         dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:"
131                 (hcat [ int tlets,  ptext (sLit " Lets floated to top level; "),
132                         int ntlets, ptext (sLit " Lets floated elsewhere; from "),
133                         int lams,   ptext (sLit " Lambda groups")]);
134
135         return (concat binds_s')
136     }
137
138 floatTopBind :: LevelledBind -> (FloatStats, [CoreBind])
139 floatTopBind bind
140   = case (floatBind bind) of { (fs, floats) ->
141     (fs, bagToList (flattenFloats floats)) }
142 \end{code}
143
144 %************************************************************************
145 %*                                                                      *
146 \subsection[FloatOut-Bind]{Floating in a binding (the business end)}
147 %*                                                                      *
148 %************************************************************************
149
150 \begin{code}
151 floatBind :: LevelledBind -> (FloatStats, FloatBinds)
152 floatBind (NonRec (TB var level) rhs)
153   = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
154
155         -- A tiresome hack: 
156         -- see Note [Bottoming floats: eta expansion] in SetLevels
157     let rhs'' | isBottomingId var = etaExpand (idArity var) rhs'
158               | otherwise         = rhs'
159
160     in (fs, rhs_floats `plusFloats` unitFloat level (NonRec var rhs'')) }
161
162 floatBind (Rec pairs)
163   = case floatList do_pair pairs of { (fs, rhs_floats, new_pairs) ->
164         -- NB: the rhs floats may contain references to the 
165         -- bound things.  For example
166         --      f = ...(let v = ...f... in b) ...
167     if not (isTopLvl dest_lvl) then
168         -- Find which bindings float out at least one lambda beyond this one
169         -- These ones can't mention the binders, because they couldn't 
170         -- be escaping a major level if so.
171         -- The ones that are not going further can join the letrec;
172         -- they may not be mutually recursive but the occurrence analyser will
173         -- find that out. In our example we make a Rec thus:
174         --      v = ...f...
175         --      f = ... b ...
176         case (partitionByMajorLevel dest_lvl rhs_floats) of { (floats', heres) ->
177         (fs, floats' `plusFloats` unitFloat dest_lvl 
178                  (Rec (floatsToBindPairs heres new_pairs))) }
179     else
180         -- For top level, no need to partition; just make them all recursive
181         -- (And the partition wouldn't work because they'd all end up in floats')
182         (fs, unitFloat dest_lvl
183                  (Rec (floatsToBindPairs (flattenFloats rhs_floats) new_pairs)))  }
184   where
185     (((TB _ dest_lvl), _) : _) = pairs
186
187     do_pair (TB name level, rhs)
188       = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
189         (fs, rhs_floats, (name, rhs')) }
190
191 ---------------
192 floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
193 floatList _ [] = (zeroStats, emptyFloats, [])
194 floatList f (a:as) = case f a            of { (fs_a,  binds_a,  b)  ->
195                      case floatList f as of { (fs_as, binds_as, bs) ->
196                      (fs_a `add_stats` fs_as, binds_a `plusFloats`  binds_as, b:bs) }}
197 \end{code}
198
199
200 %************************************************************************
201
202 \subsection[FloatOut-Expr]{Floating in expressions}
203 %*                                                                      *
204 %************************************************************************
205
206 \begin{code}
207 floatExpr, floatRhs, floatCaseAlt
208          :: Level
209          -> LevelledExpr
210          -> (FloatStats, FloatBinds, CoreExpr)
211
212 floatCaseAlt lvl arg    -- Used rec rhss, and case-alternative rhss
213   = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
214     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
215         -- Dump bindings that aren't going to escape from a lambda;
216         -- in particular, we must dump the ones that are bound by 
217         -- the rec or case alternative
218     (fsa, floats', install heres arg') }}
219
220 -----------------
221 floatRhs lvl arg        -- Used for nested non-rec rhss, and fn args
222                         -- See Note [Floating out of RHS]
223   = floatExpr lvl arg
224
225 -----------------
226 floatExpr _ (Var v)   = (zeroStats, emptyFloats, Var v)
227 floatExpr _ (Type ty) = (zeroStats, emptyFloats, Type ty)
228 floatExpr _ (Coercion co) = (zeroStats, emptyFloats, Coercion co)
229 floatExpr _ (Lit lit) = (zeroStats, emptyFloats, Lit lit)
230           
231 floatExpr lvl (App e a)
232   = case (floatExpr      lvl e) of { (fse, floats_e, e') ->
233     case (floatRhs lvl a)       of { (fsa, floats_a, a') ->
234     (fse `add_stats` fsa, floats_e `plusFloats` floats_a, App e' a') }}
235
236 floatExpr _ lam@(Lam (TB _ lam_lvl) _)
237   = let (bndrs_w_lvls, body) = collectBinders lam
238         bndrs                = [b | TB b _ <- bndrs_w_lvls]
239         -- All the binders have the same level
240         -- See SetLevels.lvlLamBndrs
241     in
242     case (floatExpr lam_lvl body) of { (fs, floats, body1) ->
243
244         -- Dump anything that is captured by this lambda
245         -- Eg  \x -> ...(\y -> let v = <blah> in ...)...
246         -- We'll have the binding (v = <blah>) in the floats,
247         -- but must dump it at the lambda-x
248     case (partitionByLevel lam_lvl floats)      of { (floats1, heres) ->
249     (add_to_stats fs floats1, floats1, mkLams bndrs (install heres body1))
250     }}
251
252 floatExpr lvl (Note note@(SCC cc) expr)
253   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
254     let
255         -- Annotate bindings floated outwards past an scc expression
256         -- with the cc.  We mark that cc as "duplicated", though.
257
258         annotated_defns = wrapCostCentre (dupifyCC cc) floating_defns
259     in
260     (fs, annotated_defns, Note note expr') }
261
262 floatExpr lvl (Note note expr)  -- Other than SCCs
263   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
264     (fs, floating_defns, Note note expr') }
265
266 floatExpr lvl (Cast expr co)
267   = case (floatExpr lvl expr)   of { (fs, floating_defns, expr') ->
268     (fs, floating_defns, Cast expr' co) }
269
270 floatExpr lvl (Let (NonRec (TB bndr bndr_lvl) rhs) body)
271   | isUnLiftedType (idType bndr)  -- Treat unlifted lets just like a case
272                                   -- I.e. floatExpr for rhs, floatCaseAlt for body
273   = case floatExpr lvl rhs          of { (_, rhs_floats, rhs') ->
274     case floatCaseAlt bndr_lvl body of { (fs, body_floats, body') ->
275     (fs, rhs_floats `plusFloats` body_floats, Let (NonRec bndr rhs') body') }}
276
277 floatExpr lvl (Let bind body)
278   = case (floatBind bind)     of { (fsb, bind_floats) ->
279     case (floatExpr lvl body) of { (fse, body_floats, body') ->
280     case partitionByMajorLevel lvl (bind_floats `plusFloats` body_floats) 
281                               of { (floats, heres) ->
282         -- See Note [Avoiding unnecessary floating]
283     (add_stats fsb fse, floats, install heres body')  } } }
284
285 floatExpr lvl (Case scrut (TB case_bndr case_lvl) ty alts)
286   = case floatExpr lvl scrut    of { (fse, fde, scrut') ->
287     case floatList float_alt alts       of { (fsa, fda, alts')  ->
288     (add_stats fse fsa, fda `plusFloats` fde, Case scrut' case_bndr ty alts')
289     }}
290   where
291         -- Use floatCaseAlt for the alternatives, so that we
292         -- don't gratuitiously float bindings out of the RHSs
293     float_alt (con, bs, rhs)
294         = case (floatCaseAlt case_lvl rhs)      of { (fs, rhs_floats, rhs') ->
295           (fs, rhs_floats, (con, [b | TB b _ <- bs], rhs')) }
296 \end{code}
297
298 Note [Avoiding unnecessary floating]
299 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
300 In general we want to avoid floating a let unnecessarily, because
301 it might worsen strictness:
302     let 
303        x = ...(let y = e in y+y)....
304 Here y is demanded.  If we float it outside the lazy 'x=..' then
305 we'd have to zap its demand info, and it may never be restored.
306
307 So at a 'let' we leave the binding right where the are unless
308 the binding will escape a value lambda.  That's what the 
309 partitionByMajorLevel does in the floatExpr (Let ...) case.
310
311 Notice, though, that we must take care to drop any bindings
312 from the body of the let that depend on the staying-put bindings.
313
314 We used instead to do the partitionByMajorLevel on the RHS of an '=',
315 in floatRhs.  But that was quite tiresome.  We needed to test for
316 values or trival rhss, because (in particular) we don't want to insert
317 new bindings between the "=" and the "\".  E.g.
318         f = \x -> let <bind> in <body>
319 We do not want
320         f = let <bind> in \x -> <body>
321 (a) The simplifier will immediately float it further out, so we may
322         as well do so right now; in general, keeping rhss as manifest 
323         values is good
324 (b) If a float-in pass follows immediately, it might add yet more
325         bindings just after the '='.  And some of them might (correctly)
326         be strict even though the 'let f' is lazy, because f, being a value,
327         gets its demand-info zapped by the simplifier.
328 And even all that turned out to be very fragile, and broke
329 altogether when profiling got in the way.
330
331 So now we do the partition right at the (Let..) itself.
332
333 %************************************************************************
334 %*                                                                      *
335 \subsection{Utility bits for floating stats}
336 %*                                                                      *
337 %************************************************************************
338
339 I didn't implement this with unboxed numbers.  I don't want to be too
340 strict in this stuff, as it is rarely turned on.  (WDP 95/09)
341
342 \begin{code}
343 data FloatStats
344   = FlS Int  -- Number of top-floats * lambda groups they've been past
345         Int  -- Number of non-top-floats * lambda groups they've been past
346         Int  -- Number of lambda (groups) seen
347
348 get_stats :: FloatStats -> (Int, Int, Int)
349 get_stats (FlS a b c) = (a, b, c)
350
351 zeroStats :: FloatStats
352 zeroStats = FlS 0 0 0
353
354 sum_stats :: [FloatStats] -> FloatStats
355 sum_stats xs = foldr add_stats zeroStats xs
356
357 add_stats :: FloatStats -> FloatStats -> FloatStats
358 add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
359   = FlS (a1 + a2) (b1 + b2) (c1 + c2)
360
361 add_to_stats :: FloatStats -> FloatBinds -> FloatStats
362 add_to_stats (FlS a b c) (FB tops others)
363   = FlS (a + lengthBag tops) (b + lengthBag (flattenMajor others)) (c + 1)
364 \end{code}
365
366
367 %************************************************************************
368 %*                                                                      *
369 \subsection{Utility bits for floating}
370 %*                                                                      *
371 %************************************************************************
372
373 Note [Representation of FloatBinds]
374 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
375 The FloatBinds types is somewhat important.  We can get very large numbers
376 of floating bindings, often all destined for the top level.  A typical example
377 is     x = [4,2,5,2,5, .... ]
378 Then we get lots of small expressions like (fromInteger 4), which all get
379 lifted to top level.  
380
381 The trouble is that  
382   (a) we partition these floating bindings *at every binding site* 
383   (b) SetLevels introduces a new bindings site for every float
384 So we had better not look at each binding at each binding site!
385
386 That is why MajorEnv is represented as a finite map.
387
388 We keep the bindings destined for the *top* level separate, because
389 we float them out even if they don't escape a *value* lambda; see
390 partitionByMajorLevel.
391
392
393 \begin{code}
394 type FloatBind = CoreBind       -- INVARIANT: a FloatBind is always lifted
395
396 data FloatBinds  = FB !(Bag FloatBind)          -- Destined for top level
397                       !MajorEnv                 -- Levels other than top
398      -- See Note [Representation of FloatBinds]
399
400 instance Outputable FloatBinds where
401   ppr (FB fbs env) = ptext (sLit "FB") <+> (braces $ vcat
402                        [ ptext (sLit "binds =") <+> ppr fbs
403                        , ptext (sLit "env =") <+> ppr env ])
404
405 type MajorEnv = M.IntMap MinorEnv                       -- Keyed by major level
406 type MinorEnv = M.IntMap (Bag FloatBind)                -- Keyed by minor level
407
408 flattenFloats :: FloatBinds -> Bag FloatBind
409 flattenFloats (FB tops others) = tops `unionBags` flattenMajor others
410
411 flattenMajor :: MajorEnv -> Bag FloatBind
412 flattenMajor = M.fold (unionBags . flattenMinor) emptyBag
413
414 flattenMinor :: MinorEnv -> Bag FloatBind
415 flattenMinor = M.fold unionBags emptyBag
416
417 emptyFloats :: FloatBinds
418 emptyFloats = FB emptyBag M.empty
419
420 unitFloat :: Level -> FloatBind -> FloatBinds
421 unitFloat lvl@(Level major minor) b 
422   | isTopLvl lvl = FB (unitBag b) M.empty
423   | otherwise    = FB emptyBag (M.singleton major (M.singleton minor (unitBag b)))
424
425 plusFloats :: FloatBinds -> FloatBinds -> FloatBinds
426 plusFloats (FB t1 b1) (FB t2 b2) = FB (t1 `unionBags` t2) (b1 `plusMajor` b2)
427
428 plusMajor :: MajorEnv -> MajorEnv -> MajorEnv
429 plusMajor = M.unionWith plusMinor
430
431 plusMinor :: MinorEnv -> MinorEnv -> MinorEnv
432 plusMinor = M.unionWith unionBags
433
434 floatsToBindPairs :: Bag FloatBind -> [(Id,CoreExpr)] -> [(Id,CoreExpr)]
435 floatsToBindPairs floats binds = foldrBag add binds floats
436   where
437    add (Rec pairs)         binds = pairs ++ binds
438    add (NonRec binder rhs) binds = (binder,rhs) : binds
439
440 install :: Bag FloatBind -> CoreExpr -> CoreExpr
441 install defn_groups expr
442   = foldrBag install_group expr defn_groups
443   where
444     install_group defns body = Let defns body
445
446 partitionByMajorLevel, partitionByLevel
447         :: Level                -- Partitioning level
448         -> FloatBinds           -- Defns to be divided into 2 piles...
449         -> (FloatBinds,         -- Defns  with level strictly < partition level,
450             Bag FloatBind)      -- The rest
451
452 --       ---- partitionByMajorLevel ----
453 -- Float it if we escape a value lambda, *or* if we get to the top level
454 -- If we can get to the top level, say "yes" anyway. This means that 
455 --      x = f e
456 -- transforms to 
457 --    lvl = e
458 --    x = f lvl
459 -- which is as it should be
460
461 partitionByMajorLevel (Level major _) (FB tops defns)
462   = (FB tops outer, heres `unionBags` flattenMajor inner)
463   where
464     (outer, mb_heres, inner) = M.splitLookup major defns
465     heres = case mb_heres of 
466                Nothing -> emptyBag
467                Just h  -> flattenMinor h
468
469 partitionByLevel (Level major minor) (FB tops defns)
470   = (FB tops (outer_maj `plusMajor` M.singleton major outer_min),
471      here_min `unionBags` flattenMinor inner_min 
472               `unionBags` flattenMajor inner_maj)
473
474   where
475     (outer_maj, mb_here_maj, inner_maj) = M.splitLookup major defns
476     (outer_min, mb_here_min, inner_min) = case mb_here_maj of
477                                             Nothing -> (M.empty, Nothing, M.empty)
478                                             Just min_defns -> M.splitLookup minor min_defns
479     here_min = mb_here_min `orElse` emptyBag
480
481 wrapCostCentre :: CostCentre -> FloatBinds -> FloatBinds
482 wrapCostCentre cc (FB tops defns)
483   = FB (wrap_defns tops) (M.map (M.map wrap_defns) defns)
484   where
485     wrap_defns = mapBag wrap_one 
486     wrap_one (NonRec binder rhs) = NonRec binder (mkSCC cc rhs)
487     wrap_one (Rec pairs)         = Rec (mapSnd (mkSCC cc) pairs)
488 \end{code}