2a51a2100e5a546e397cea5fd2122b071c0bf989
[ghc-hetmet.git] / compiler / simplCore / FloatOut.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
3 %
4 \section[FloatOut]{Float bindings outwards (towards the top level)}
5
6 ``Long-distance'' floating of bindings towards the top level.
7
8 \begin{code}
9 module FloatOut ( floatOutwards ) where
10
11 import CoreSyn
12 import CoreUtils
13 import CoreArity        ( etaExpand )
14 import CoreMonad        ( FloatOutSwitches(..) )
15
16 import DynFlags         ( DynFlags, DynFlag(..) )
17 import ErrUtils         ( dumpIfSet_dyn )
18 import CostCentre       ( dupifyCC, CostCentre )
19 import Id               ( Id, idType, idArity, isBottomingId )
20 import Type             ( isUnLiftedType )
21 import SetLevels        ( Level(..), LevelledExpr, LevelledBind,
22                           setLevels, isTopLvl )
23 import UniqSupply       ( UniqSupply )
24 import Bag
25 import Util
26 import Maybes
27 import Outputable
28 import FastString
29 import qualified Data.IntMap as M
30
31 #include "HsVersions.h"
32 \end{code}
33
34         -----------------
35         Overall game plan
36         -----------------
37
38 The Big Main Idea is:
39
40         To float out sub-expressions that can thereby get outside
41         a non-one-shot value lambda, and hence may be shared.
42
43
44 To achieve this we may need to do two thing:
45
46    a) Let-bind the sub-expression:
47
48         f (g x)  ==>  let lvl = f (g x) in lvl
49
50       Now we can float the binding for 'lvl'.  
51
52    b) More than that, we may need to abstract wrt a type variable
53
54         \x -> ... /\a -> let v = ...a... in ....
55
56       Here the binding for v mentions 'a' but not 'x'.  So we
57       abstract wrt 'a', to give this binding for 'v':
58
59             vp = /\a -> ...a...
60             v  = vp a
61
62       Now the binding for vp can float out unimpeded.
63       I can't remember why this case seemed important enough to
64       deal with, but I certainly found cases where important floats
65       didn't happen if we did not abstract wrt tyvars.
66
67 With this in mind we can also achieve another goal: lambda lifting.
68 We can make an arbitrary (function) binding float to top level by
69 abstracting wrt *all* local variables, not just type variables, leaving
70 a binding that can be floated right to top level.  Whether or not this
71 happens is controlled by a flag.
72
73
74 Random comments
75 ~~~~~~~~~~~~~~~
76
77 At the moment we never float a binding out to between two adjacent
78 lambdas.  For example:
79
80 @
81         \x y -> let t = x+x in ...
82 ===>
83         \x -> let t = x+x in \y -> ...
84 @
85 Reason: this is less efficient in the case where the original lambda
86 is never partially applied.
87
88 But there's a case I've seen where this might not be true.  Consider:
89 @
90 elEm2 x ys
91   = elem' x ys
92   where
93     elem' _ []  = False
94     elem' x (y:ys)      = x==y || elem' x ys
95 @
96 It turns out that this generates a subexpression of the form
97 @
98         \deq x ys -> let eq = eqFromEqDict deq in ...
99 @
100 vwhich might usefully be separated to
101 @
102         \deq -> let eq = eqFromEqDict deq in \xy -> ...
103 @
104 Well, maybe.  We don't do this at the moment.
105
106
107 %************************************************************************
108 %*                                                                      *
109 \subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
110 %*                                                                      *
111 %************************************************************************
112
113 \begin{code}
114 floatOutwards :: FloatOutSwitches
115               -> DynFlags
116               -> UniqSupply 
117               -> [CoreBind] -> IO [CoreBind]
118
119 floatOutwards float_sws dflags us pgm
120   = do {
121         let { annotated_w_levels = setLevels float_sws pgm us ;
122               (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)
123             } ;
124
125         dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
126                   (vcat (map ppr annotated_w_levels));
127
128         let { (tlets, ntlets, lams) = get_stats (sum_stats fss) };
129
130         dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:"
131                 (hcat [ int tlets,  ptext (sLit " Lets floated to top level; "),
132                         int ntlets, ptext (sLit " Lets floated elsewhere; from "),
133                         int lams,   ptext (sLit " Lambda groups")]);
134
135         return (concat binds_s')
136     }
137
138 floatTopBind :: LevelledBind -> (FloatStats, [CoreBind])
139 floatTopBind bind
140   = case (floatBind bind) of { (fs, floats) ->
141     (fs, bagToList (flattenFloats floats)) }
142 \end{code}
143
144 %************************************************************************
145 %*                                                                      *
146 \subsection[FloatOut-Bind]{Floating in a binding (the business end)}
147 %*                                                                      *
148 %************************************************************************
149
150 \begin{code}
151 floatBind :: LevelledBind -> (FloatStats, FloatBinds)
152 floatBind (NonRec (TB var level) rhs)
153   = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
154
155         -- A tiresome hack: 
156         -- see Note [Bottoming floats: eta expansion] in SetLevels
157     let rhs'' | isBottomingId var = etaExpand (idArity var) rhs'
158               | otherwise         = rhs'
159
160     in (fs, rhs_floats `plusFloats` unitFloat level (NonRec var rhs'')) }
161
162 floatBind (Rec pairs)
163   = case floatList do_pair pairs of { (fs, rhs_floats, new_pairs) ->
164         -- NB: the rhs floats may contain references to the 
165         -- bound things.  For example
166         --      f = ...(let v = ...f... in b) ...
167     if not (isTopLvl dest_lvl) then
168         -- Find which bindings float out at least one lambda beyond this one
169         -- These ones can't mention the binders, because they couldn't 
170         -- be escaping a major level if so.
171         -- The ones that are not going further can join the letrec;
172         -- they may not be mutually recursive but the occurrence analyser will
173         -- find that out. In our example we make a Rec thus:
174         --      v = ...f...
175         --      f = ... b ...
176         case (partitionByMajorLevel dest_lvl rhs_floats) of { (floats', heres) ->
177         (fs, floats' `plusFloats` unitFloat dest_lvl 
178                  (Rec (floatsToBindPairs heres new_pairs))) }
179     else
180         -- For top level, no need to partition; just make them all recursive
181         -- (And the partition wouldn't work because they'd all end up in floats')
182         (fs, unitFloat dest_lvl
183                  (Rec (floatsToBindPairs (flattenFloats rhs_floats) new_pairs)))  }
184   where
185     (((TB _ dest_lvl), _) : _) = pairs
186
187     do_pair (TB name level, rhs)
188       = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
189         (fs, rhs_floats, (name, rhs')) }
190
191 ---------------
192 floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
193 floatList _ [] = (zeroStats, emptyFloats, [])
194 floatList f (a:as) = case f a            of { (fs_a,  binds_a,  b)  ->
195                      case floatList f as of { (fs_as, binds_as, bs) ->
196                      (fs_a `add_stats` fs_as, binds_a `plusFloats`  binds_as, b:bs) }}
197 \end{code}
198
199
200 %************************************************************************
201
202 \subsection[FloatOut-Expr]{Floating in expressions}
203 %*                                                                      *
204 %************************************************************************
205
206 \begin{code}
207 floatExpr, floatRhs, floatCaseAlt
208          :: Level
209          -> LevelledExpr
210          -> (FloatStats, FloatBinds, CoreExpr)
211
212 floatCaseAlt lvl arg    -- Used rec rhss, and case-alternative rhss
213   = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
214     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
215         -- Dump bindings that aren't going to escape from a lambda;
216         -- in particular, we must dump the ones that are bound by 
217         -- the rec or case alternative
218     (fsa, floats', install heres arg') }}
219
220 -----------------
221 floatRhs lvl arg        -- Used for nested non-rec rhss, and fn args
222                         -- See Note [Floating out of RHS]
223   = floatExpr lvl arg
224
225 -----------------
226 floatExpr _ (Var v)   = (zeroStats, emptyFloats, Var v)
227 floatExpr _ (Type ty) = (zeroStats, emptyFloats, Type ty)
228 floatExpr _ (Lit lit) = (zeroStats, emptyFloats, Lit lit)
229           
230 floatExpr lvl (App e a)
231   = case (floatExpr      lvl e) of { (fse, floats_e, e') ->
232     case (floatRhs lvl a)       of { (fsa, floats_a, a') ->
233     (fse `add_stats` fsa, floats_e `plusFloats` floats_a, App e' a') }}
234
235 floatExpr _ lam@(Lam (TB _ lam_lvl) _)
236   = let (bndrs_w_lvls, body) = collectBinders lam
237         bndrs                = [b | TB b _ <- bndrs_w_lvls]
238         -- All the binders have the same level
239         -- See SetLevels.lvlLamBndrs
240     in
241     case (floatExpr lam_lvl body) of { (fs, floats, body1) ->
242
243         -- Dump anything that is captured by this lambda
244         -- Eg  \x -> ...(\y -> let v = <blah> in ...)...
245         -- We'll have the binding (v = <blah>) in the floats,
246         -- but must dump it at the lambda-x
247     case (partitionByLevel lam_lvl floats)      of { (floats1, heres) ->
248     (add_to_stats fs floats1, floats1, mkLams bndrs (install heres body1))
249     }}
250
251 floatExpr lvl (Note note@(SCC cc) expr)
252   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
253     let
254         -- Annotate bindings floated outwards past an scc expression
255         -- with the cc.  We mark that cc as "duplicated", though.
256
257         annotated_defns = wrapCostCentre (dupifyCC cc) floating_defns
258     in
259     (fs, annotated_defns, Note note expr') }
260
261 floatExpr lvl (Note note expr)  -- Other than SCCs
262   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
263     (fs, floating_defns, Note note expr') }
264
265 floatExpr lvl (Cast expr co)
266   = case (floatExpr lvl expr)   of { (fs, floating_defns, expr') ->
267     (fs, floating_defns, Cast expr' co) }
268
269 floatExpr lvl (Let (NonRec (TB bndr bndr_lvl) rhs) body)
270   | isUnLiftedType (idType bndr)  -- Treat unlifted lets just like a case
271                                   -- I.e. floatExpr for rhs, floatCaseAlt for body
272   = case floatExpr lvl rhs          of { (_, rhs_floats, rhs') ->
273     case floatCaseAlt bndr_lvl body of { (fs, body_floats, body') ->
274     (fs, rhs_floats `plusFloats` body_floats, Let (NonRec bndr rhs') body') }}
275
276 floatExpr lvl (Let bind body)
277   = case (floatBind bind)     of { (fsb, bind_floats) ->
278     case (floatExpr lvl body) of { (fse, body_floats, body') ->
279     case partitionByMajorLevel lvl (bind_floats `plusFloats` body_floats) 
280                               of { (floats, heres) ->
281         -- See Note [Avoiding unnecessary floating]
282     (add_stats fsb fse, floats, install heres body')  } } }
283
284 floatExpr lvl (Case scrut (TB case_bndr case_lvl) ty alts)
285   = case floatExpr lvl scrut    of { (fse, fde, scrut') ->
286     case floatList float_alt alts       of { (fsa, fda, alts')  ->
287     (add_stats fse fsa, fda `plusFloats` fde, Case scrut' case_bndr ty alts')
288     }}
289   where
290         -- Use floatCaseAlt for the alternatives, so that we
291         -- don't gratuitiously float bindings out of the RHSs
292     float_alt (con, bs, rhs)
293         = case (floatCaseAlt case_lvl rhs)      of { (fs, rhs_floats, rhs') ->
294           (fs, rhs_floats, (con, [b | TB b _ <- bs], rhs')) }
295 \end{code}
296
297 Note [Avoiding unnecessary floating]
298 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
299 In general we want to avoid floating a let unnecessarily, because
300 it might worsen strictness:
301     let 
302        x = ...(let y = e in y+y)....
303 Here y is demanded.  If we float it outside the lazy 'x=..' then
304 we'd have to zap its demand info, and it may never be restored.
305
306 So at a 'let' we leave the binding right where the are unless
307 the binding will escape a value lambda.  That's what the 
308 partitionByMajorLevel does in the floatExpr (Let ...) case.
309
310 Notice, though, that we must take care to drop any bindings
311 from the body of the let that depend on the staying-put bindings.
312
313 We used instead to do the partitionByMajorLevel on the RHS of an '=',
314 in floatRhs.  But that was quite tiresome.  We needed to test for
315 values or trival rhss, because (in particular) we don't want to insert
316 new bindings between the "=" and the "\".  E.g.
317         f = \x -> let <bind> in <body>
318 We do not want
319         f = let <bind> in \x -> <body>
320 (a) The simplifier will immediately float it further out, so we may
321         as well do so right now; in general, keeping rhss as manifest 
322         values is good
323 (b) If a float-in pass follows immediately, it might add yet more
324         bindings just after the '='.  And some of them might (correctly)
325         be strict even though the 'let f' is lazy, because f, being a value,
326         gets its demand-info zapped by the simplifier.
327 And even all that turned out to be very fragile, and broke
328 altogether when profiling got in the way.
329
330 So now we do the partition right at the (Let..) itself.
331
332 %************************************************************************
333 %*                                                                      *
334 \subsection{Utility bits for floating stats}
335 %*                                                                      *
336 %************************************************************************
337
338 I didn't implement this with unboxed numbers.  I don't want to be too
339 strict in this stuff, as it is rarely turned on.  (WDP 95/09)
340
341 \begin{code}
342 data FloatStats
343   = FlS Int  -- Number of top-floats * lambda groups they've been past
344         Int  -- Number of non-top-floats * lambda groups they've been past
345         Int  -- Number of lambda (groups) seen
346
347 get_stats :: FloatStats -> (Int, Int, Int)
348 get_stats (FlS a b c) = (a, b, c)
349
350 zeroStats :: FloatStats
351 zeroStats = FlS 0 0 0
352
353 sum_stats :: [FloatStats] -> FloatStats
354 sum_stats xs = foldr add_stats zeroStats xs
355
356 add_stats :: FloatStats -> FloatStats -> FloatStats
357 add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
358   = FlS (a1 + a2) (b1 + b2) (c1 + c2)
359
360 add_to_stats :: FloatStats -> FloatBinds -> FloatStats
361 add_to_stats (FlS a b c) (FB tops others)
362   = FlS (a + lengthBag tops) (b + lengthBag (flattenMajor others)) (c + 1)
363 \end{code}
364
365
366 %************************************************************************
367 %*                                                                      *
368 \subsection{Utility bits for floating}
369 %*                                                                      *
370 %************************************************************************
371
372 Note [Representation of FloatBinds]
373 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
374 The FloatBinds types is somewhat important.  We can get very large numbers
375 of floating bindings, often all destined for the top level.  A typical example
376 is     x = [4,2,5,2,5, .... ]
377 Then we get lots of small expressions like (fromInteger 4), which all get
378 lifted to top level.  
379
380 The trouble is that  
381   (a) we partition these floating bindings *at every binding site* 
382   (b) SetLevels introduces a new bindings site for every float
383 So we had better not look at each binding at each binding site!
384
385 That is why MajorEnv is represented as a finite map.
386
387 We keep the bindings destined for the *top* level separate, because
388 we float them out even if they don't escape a *value* lambda; see
389 partitionByMajorLevel.
390
391
392 \begin{code}
393 type FloatBind = CoreBind       -- INVARIANT: a FloatBind is always lifted
394
395 data FloatBinds  = FB !(Bag FloatBind)          -- Destined for top level
396                       !MajorEnv                 -- Levels other than top
397      -- See Note [Representation of FloatBinds]
398
399 instance Outputable FloatBinds where
400   ppr (FB fbs env) = ptext (sLit "FB") <+> (braces $ vcat
401                        [ ptext (sLit "binds =") <+> ppr fbs
402                        , ptext (sLit "env =") <+> ppr env ])
403
404 type MajorEnv = M.IntMap MinorEnv                       -- Keyed by major level
405 type MinorEnv = M.IntMap (Bag FloatBind)                -- Keyed by minor level
406
407 flattenFloats :: FloatBinds -> Bag FloatBind
408 flattenFloats (FB tops others) = tops `unionBags` flattenMajor others
409
410 flattenMajor :: MajorEnv -> Bag FloatBind
411 flattenMajor = M.fold (unionBags . flattenMinor) emptyBag
412
413 flattenMinor :: MinorEnv -> Bag FloatBind
414 flattenMinor = M.fold unionBags emptyBag
415
416 emptyFloats :: FloatBinds
417 emptyFloats = FB emptyBag M.empty
418
419 unitFloat :: Level -> FloatBind -> FloatBinds
420 unitFloat lvl@(Level major minor) b 
421   | isTopLvl lvl = FB (unitBag b) M.empty
422   | otherwise    = FB emptyBag (M.singleton major (M.singleton minor (unitBag b)))
423
424 plusFloats :: FloatBinds -> FloatBinds -> FloatBinds
425 plusFloats (FB t1 b1) (FB t2 b2) = FB (t1 `unionBags` t2) (b1 `plusMajor` b2)
426
427 plusMajor :: MajorEnv -> MajorEnv -> MajorEnv
428 plusMajor = M.unionWith plusMinor
429
430 plusMinor :: MinorEnv -> MinorEnv -> MinorEnv
431 plusMinor = M.unionWith unionBags
432
433 floatsToBindPairs :: Bag FloatBind -> [(Id,CoreExpr)] -> [(Id,CoreExpr)]
434 floatsToBindPairs floats binds = foldrBag add binds floats
435   where
436    add (Rec pairs)         binds = pairs ++ binds
437    add (NonRec binder rhs) binds = (binder,rhs) : binds
438
439 install :: Bag FloatBind -> CoreExpr -> CoreExpr
440 install defn_groups expr
441   = foldrBag install_group expr defn_groups
442   where
443     install_group defns body = Let defns body
444
445 partitionByMajorLevel, partitionByLevel
446         :: Level                -- Partitioning level
447         -> FloatBinds           -- Defns to be divided into 2 piles...
448         -> (FloatBinds,         -- Defns  with level strictly < partition level,
449             Bag FloatBind)      -- The rest
450
451 --       ---- partitionByMajorLevel ----
452 -- Float it if we escape a value lambda, *or* if we get to the top level
453 -- If we can get to the top level, say "yes" anyway. This means that 
454 --      x = f e
455 -- transforms to 
456 --    lvl = e
457 --    x = f lvl
458 -- which is as it should be
459
460 partitionByMajorLevel (Level major _) (FB tops defns)
461   = (FB tops outer, heres `unionBags` flattenMajor inner)
462   where
463     (outer, mb_heres, inner) = M.splitLookup major defns
464     heres = case mb_heres of 
465                Nothing -> emptyBag
466                Just h  -> flattenMinor h
467
468 partitionByLevel (Level major minor) (FB tops defns)
469   = (FB tops (outer_maj `plusMajor` M.singleton major outer_min),
470      here_min `unionBags` flattenMinor inner_min 
471               `unionBags` flattenMajor inner_maj)
472
473   where
474     (outer_maj, mb_here_maj, inner_maj) = M.splitLookup major defns
475     (outer_min, mb_here_min, inner_min) = case mb_here_maj of
476                                             Nothing -> (M.empty, Nothing, M.empty)
477                                             Just min_defns -> M.splitLookup minor min_defns
478     here_min = mb_here_min `orElse` emptyBag
479
480 wrapCostCentre :: CostCentre -> FloatBinds -> FloatBinds
481 wrapCostCentre cc (FB tops defns)
482   = FB (wrap_defns tops) (M.map (M.map wrap_defns) defns)
483   where
484     wrap_defns = mapBag wrap_one 
485     wrap_one (NonRec binder rhs) = NonRec binder (mkSCC cc rhs)
486     wrap_one (Rec pairs)         = Rec (mapSnd (mkSCC cc) pairs)
487 \end{code}