Super-monster patch implementing the new typechecker -- at last
[ghc-hetmet.git] / compiler / simplCore / FloatOut.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
3 %
4 \section[FloatOut]{Float bindings outwards (towards the top level)}
5
6 ``Long-distance'' floating of bindings towards the top level.
7
8 \begin{code}
9 module FloatOut ( floatOutwards ) where
10
11 import CoreSyn
12 import CoreUtils
13 import CoreArity        ( etaExpand )
14 import CoreMonad        ( FloatOutSwitches(..) )
15
16 import DynFlags         ( DynFlags, DynFlag(..) )
17 import ErrUtils         ( dumpIfSet_dyn )
18 import CostCentre       ( dupifyCC, CostCentre )
19 import Id               ( Id, idType, idArity, isBottomingId )
20 import Type             ( isUnLiftedType )
21 import SetLevels        ( Level(..), LevelledExpr, LevelledBind,
22                           setLevels, isTopLvl )
23 import UniqSupply       ( UniqSupply )
24 import Bag
25 import Util
26 import Maybes
27 import UniqFM
28 import Outputable
29 import FastString
30 \end{code}
31
32         -----------------
33         Overall game plan
34         -----------------
35
36 The Big Main Idea is:
37
38         To float out sub-expressions that can thereby get outside
39         a non-one-shot value lambda, and hence may be shared.
40
41
42 To achieve this we may need to do two thing:
43
44    a) Let-bind the sub-expression:
45
46         f (g x)  ==>  let lvl = f (g x) in lvl
47
48       Now we can float the binding for 'lvl'.  
49
50    b) More than that, we may need to abstract wrt a type variable
51
52         \x -> ... /\a -> let v = ...a... in ....
53
54       Here the binding for v mentions 'a' but not 'x'.  So we
55       abstract wrt 'a', to give this binding for 'v':
56
57             vp = /\a -> ...a...
58             v  = vp a
59
60       Now the binding for vp can float out unimpeded.
61       I can't remember why this case seemed important enough to
62       deal with, but I certainly found cases where important floats
63       didn't happen if we did not abstract wrt tyvars.
64
65 With this in mind we can also achieve another goal: lambda lifting.
66 We can make an arbitrary (function) binding float to top level by
67 abstracting wrt *all* local variables, not just type variables, leaving
68 a binding that can be floated right to top level.  Whether or not this
69 happens is controlled by a flag.
70
71
72 Random comments
73 ~~~~~~~~~~~~~~~
74
75 At the moment we never float a binding out to between two adjacent
76 lambdas.  For example:
77
78 @
79         \x y -> let t = x+x in ...
80 ===>
81         \x -> let t = x+x in \y -> ...
82 @
83 Reason: this is less efficient in the case where the original lambda
84 is never partially applied.
85
86 But there's a case I've seen where this might not be true.  Consider:
87 @
88 elEm2 x ys
89   = elem' x ys
90   where
91     elem' _ []  = False
92     elem' x (y:ys)      = x==y || elem' x ys
93 @
94 It turns out that this generates a subexpression of the form
95 @
96         \deq x ys -> let eq = eqFromEqDict deq in ...
97 @
98 vwhich might usefully be separated to
99 @
100         \deq -> let eq = eqFromEqDict deq in \xy -> ...
101 @
102 Well, maybe.  We don't do this at the moment.
103
104
105 %************************************************************************
106 %*                                                                      *
107 \subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
108 %*                                                                      *
109 %************************************************************************
110
111 \begin{code}
112 floatOutwards :: FloatOutSwitches
113               -> DynFlags
114               -> UniqSupply 
115               -> [CoreBind] -> IO [CoreBind]
116
117 floatOutwards float_sws dflags us pgm
118   = do {
119         let { annotated_w_levels = setLevels float_sws pgm us ;
120               (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)
121             } ;
122
123         dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
124                   (vcat (map ppr annotated_w_levels));
125
126         let { (tlets, ntlets, lams) = get_stats (sum_stats fss) };
127
128         dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:"
129                 (hcat [ int tlets,  ptext (sLit " Lets floated to top level; "),
130                         int ntlets, ptext (sLit " Lets floated elsewhere; from "),
131                         int lams,   ptext (sLit " Lambda groups")]);
132
133         return (concat binds_s')
134     }
135
136 floatTopBind :: LevelledBind -> (FloatStats, [CoreBind])
137 floatTopBind bind
138   = case (floatBind bind) of { (fs, floats) ->
139     (fs, bagToList (flattenFloats floats)) }
140 \end{code}
141
142 %************************************************************************
143 %*                                                                      *
144 \subsection[FloatOut-Bind]{Floating in a binding (the business end)}
145 %*                                                                      *
146 %************************************************************************
147
148 \begin{code}
149 floatBind :: LevelledBind -> (FloatStats, FloatBinds)
150 floatBind (NonRec (TB var level) rhs)
151   = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
152
153         -- A tiresome hack: 
154         -- see Note [Bottoming floats: eta expansion] in SetLevels
155     let rhs'' | isBottomingId var = etaExpand (idArity var) rhs'
156               | otherwise         = rhs'
157
158     in (fs, rhs_floats `plusFloats` unitFloat level (NonRec var rhs'')) }
159
160 floatBind (Rec pairs)
161   = case floatList do_pair pairs of { (fs, rhs_floats, new_pairs) ->
162         -- NB: the rhs floats may contain references to the 
163         -- bound things.  For example
164         --      f = ...(let v = ...f... in b) ...
165     if not (isTopLvl dest_lvl) then
166         -- Find which bindings float out at least one lambda beyond this one
167         -- These ones can't mention the binders, because they couldn't 
168         -- be escaping a major level if so.
169         -- The ones that are not going further can join the letrec;
170         -- they may not be mutually recursive but the occurrence analyser will
171         -- find that out. In our example we make a Rec thus:
172         --      v = ...f...
173         --      f = ... b ...
174         case (partitionByMajorLevel dest_lvl rhs_floats) of { (floats', heres) ->
175         (fs, floats' `plusFloats` unitFloat dest_lvl 
176                  (Rec (floatsToBindPairs heres new_pairs))) }
177     else
178         -- For top level, no need to partition; just make them all recursive
179         -- (And the partition wouldn't work because they'd all end up in floats')
180         (fs, unitFloat dest_lvl
181                  (Rec (floatsToBindPairs (flattenFloats rhs_floats) new_pairs)))  }
182   where
183     (((TB _ dest_lvl), _) : _) = pairs
184
185     do_pair (TB name level, rhs)
186       = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
187         (fs, rhs_floats, (name, rhs')) }
188
189 ---------------
190 floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
191 floatList _ [] = (zeroStats, emptyFloats, [])
192 floatList f (a:as) = case f a            of { (fs_a,  binds_a,  b)  ->
193                      case floatList f as of { (fs_as, binds_as, bs) ->
194                      (fs_a `add_stats` fs_as, binds_a `plusFloats`  binds_as, b:bs) }}
195 \end{code}
196
197
198 %************************************************************************
199
200 \subsection[FloatOut-Expr]{Floating in expressions}
201 %*                                                                      *
202 %************************************************************************
203
204 \begin{code}
205 floatExpr, floatRhs, floatCaseAlt
206          :: Level
207          -> LevelledExpr
208          -> (FloatStats, FloatBinds, CoreExpr)
209
210 floatCaseAlt lvl arg    -- Used rec rhss, and case-alternative rhss
211   = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
212     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
213         -- Dump bindings that aren't going to escape from a lambda;
214         -- in particular, we must dump the ones that are bound by 
215         -- the rec or case alternative
216     (fsa, floats', install heres arg') }}
217
218 -----------------
219 floatRhs lvl arg        -- Used for nested non-rec rhss, and fn args
220                         -- See Note [Floating out of RHS]
221   = floatExpr lvl arg
222
223 -----------------
224 floatExpr _ (Var v)   = (zeroStats, emptyFloats, Var v)
225 floatExpr _ (Type ty) = (zeroStats, emptyFloats, Type ty)
226 floatExpr _ (Lit lit) = (zeroStats, emptyFloats, Lit lit)
227           
228 floatExpr lvl (App e a)
229   = case (floatExpr      lvl e) of { (fse, floats_e, e') ->
230     case (floatRhs lvl a)       of { (fsa, floats_a, a') ->
231     (fse `add_stats` fsa, floats_e `plusFloats` floats_a, App e' a') }}
232
233 floatExpr _ lam@(Lam _ _)
234   = let
235         (bndrs_w_lvls, body) = collectBinders lam
236         bndrs                = [b | TB b _ <- bndrs_w_lvls]
237         lvls                 = [l | TB _ l <- bndrs_w_lvls]
238
239         -- For the all-tyvar case we are prepared to pull 
240         -- the lets out, to implement the float-out-of-big-lambda
241         -- transform; but otherwise we only float bindings that are
242         -- going to escape a value lambda.
243         -- In particular, for one-shot lambdas we don't float things
244         -- out; we get no saving by so doing.
245         partition_fn | all isTyCoVar bndrs = partitionByLevel
246                      | otherwise         = partitionByMajorLevel
247     in
248     case (floatExpr (last lvls) body) of { (fs, floats, body') ->
249
250         -- Dump any bindings which absolutely cannot go any further
251     case (partition_fn (head lvls) floats)      of { (floats', heres) ->
252
253     (add_to_stats fs floats', floats', mkLams bndrs (install heres body'))
254     }}
255
256 floatExpr lvl (Note note@(SCC cc) expr)
257   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
258     let
259         -- Annotate bindings floated outwards past an scc expression
260         -- with the cc.  We mark that cc as "duplicated", though.
261
262         annotated_defns = wrapCostCentre (dupifyCC cc) floating_defns
263     in
264     (fs, annotated_defns, Note note expr') }
265
266 floatExpr lvl (Note note expr)  -- Other than SCCs
267   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
268     (fs, floating_defns, Note note expr') }
269
270 floatExpr lvl (Cast expr co)
271   = case (floatExpr lvl expr)   of { (fs, floating_defns, expr') ->
272     (fs, floating_defns, Cast expr' co) }
273
274 floatExpr lvl (Let (NonRec (TB bndr bndr_lvl) rhs) body)
275   | isUnLiftedType (idType bndr)  -- Treat unlifted lets just like a case
276                                   -- I.e. floatExpr for rhs, floatCaseAlt for body
277   = case floatExpr lvl rhs          of { (_, rhs_floats, rhs') ->
278     case floatCaseAlt bndr_lvl body of { (fs, body_floats, body') ->
279     (fs, rhs_floats `plusFloats` body_floats, Let (NonRec bndr rhs') body') }}
280
281 floatExpr lvl (Let bind body)
282   = case (floatBind bind)     of { (fsb, bind_floats) ->
283     case (floatExpr lvl body) of { (fse, body_floats, body') ->
284     case partitionByMajorLevel lvl (bind_floats `plusFloats` body_floats) 
285                               of { (floats, heres) ->
286         -- See Note [Avoiding unnecessary floating]
287     (add_stats fsb fse, floats, install heres body')  } } }
288
289 floatExpr lvl (Case scrut (TB case_bndr case_lvl) ty alts)
290   = case floatExpr lvl scrut    of { (fse, fde, scrut') ->
291     case floatList float_alt alts       of { (fsa, fda, alts')  ->
292     (add_stats fse fsa, fda `plusFloats` fde, Case scrut' case_bndr ty alts')
293     }}
294   where
295         -- Use floatCaseAlt for the alternatives, so that we
296         -- don't gratuitiously float bindings out of the RHSs
297     float_alt (con, bs, rhs)
298         = case (floatCaseAlt case_lvl rhs)      of { (fs, rhs_floats, rhs') ->
299           (fs, rhs_floats, (con, [b | TB b _ <- bs], rhs')) }
300 \end{code}
301
302 Note [Avoiding unnecessary floating]
303 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
304 In general we want to avoid floating a let unnecessarily, because
305 it might worsen strictness:
306     let 
307        x = ...(let y = e in y+y)....
308 Here y is demanded.  If we float it outside the lazy 'x=..' then
309 we'd have to zap its demand info, and it may never be restored.
310
311 So at a 'let' we leave the binding right where the are unless
312 the binding will escape a value lambda.  That's what the 
313 partitionByMajorLevel does in the floatExpr (Let ...) case.
314
315 Notice, though, that we must take care to drop any bindings
316 from the body of the let that depend on the staying-put bindings.
317
318 We used instead to do the partitionByMajorLevel on the RHS of an '=',
319 in floatRhs.  But that was quite tiresome.  We needed to test for
320 values or trival rhss, because (in particular) we don't want to insert
321 new bindings between the "=" and the "\".  E.g.
322         f = \x -> let <bind> in <body>
323 We do not want
324         f = let <bind> in \x -> <body>
325 (a) The simplifier will immediately float it further out, so we may
326         as well do so right now; in general, keeping rhss as manifest 
327         values is good
328 (b) If a float-in pass follows immediately, it might add yet more
329         bindings just after the '='.  And some of them might (correctly)
330         be strict even though the 'let f' is lazy, because f, being a value,
331         gets its demand-info zapped by the simplifier.
332 And even all that turned out to be very fragile, and broke
333 altogether when profiling got in the way.
334
335 So now we do the partition right at the (Let..) itself.
336
337 %************************************************************************
338 %*                                                                      *
339 \subsection{Utility bits for floating stats}
340 %*                                                                      *
341 %************************************************************************
342
343 I didn't implement this with unboxed numbers.  I don't want to be too
344 strict in this stuff, as it is rarely turned on.  (WDP 95/09)
345
346 \begin{code}
347 data FloatStats
348   = FlS Int  -- Number of top-floats * lambda groups they've been past
349         Int  -- Number of non-top-floats * lambda groups they've been past
350         Int  -- Number of lambda (groups) seen
351
352 get_stats :: FloatStats -> (Int, Int, Int)
353 get_stats (FlS a b c) = (a, b, c)
354
355 zeroStats :: FloatStats
356 zeroStats = FlS 0 0 0
357
358 sum_stats :: [FloatStats] -> FloatStats
359 sum_stats xs = foldr add_stats zeroStats xs
360
361 add_stats :: FloatStats -> FloatStats -> FloatStats
362 add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
363   = FlS (a1 + a2) (b1 + b2) (c1 + c2)
364
365 add_to_stats :: FloatStats -> FloatBinds -> FloatStats
366 add_to_stats (FlS a b c) (FB tops others)
367   = FlS (a + lengthBag tops) (b + lengthBag (flattenMajor others)) (c + 1)
368 \end{code}
369
370
371 %************************************************************************
372 %*                                                                      *
373 \subsection{Utility bits for floating}
374 %*                                                                      *
375 %************************************************************************
376
377 Note [Representation of FloatBinds]
378 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
379 The FloatBinds types is somewhat important.  We can get very large numbers
380 of floating bindings, often all destined for the top level.  A typical example
381 is     x = [4,2,5,2,5, .... ]
382 Then we get lots of small expressions like (fromInteger 4), which all get
383 lifted to top level.  
384
385 The trouble is that  
386   (a) we partition these floating bindings *at every binding site* 
387   (b) SetLevels introduces a new bindings site for every float
388 So we had better not look at each binding at each binding site!
389
390 That is why MajorEnv is represented as a finite map.
391
392 We keep the bindings destined for the *top* level separate, because
393 we float them out even if they don't escape a *value* lambda; see
394 partitionByMajorLevel.
395
396
397 \begin{code}
398 type FloatBind = CoreBind       -- INVARIANT: a FloatBind is always lifted
399
400 data FloatBinds  = FB !(Bag FloatBind)          -- Destined for top level
401                       !MajorEnv                 -- Levels other than top
402      -- See Note [Representation of FloatBinds]
403
404 type MajorEnv = UniqFM MinorEnv                 -- Keyed by major level
405 type MinorEnv = UniqFM (Bag FloatBind)          -- Keyed by minor level
406
407 flattenFloats :: FloatBinds -> Bag FloatBind
408 flattenFloats (FB tops others) = tops `unionBags` flattenMajor others
409
410 flattenMajor :: MajorEnv -> Bag FloatBind
411 flattenMajor = foldUFM (unionBags . flattenMinor) emptyBag
412
413 flattenMinor :: MinorEnv -> Bag FloatBind
414 flattenMinor = foldUFM unionBags emptyBag
415
416 emptyFloats :: FloatBinds
417 emptyFloats = FB emptyBag emptyUFM
418
419 unitFloat :: Level -> FloatBind -> FloatBinds
420 unitFloat lvl@(Level major minor) b 
421   | isTopLvl lvl = FB (unitBag b) emptyUFM
422   | otherwise    = FB emptyBag (unitUFM major (unitUFM minor (unitBag b)))
423
424 plusFloats :: FloatBinds -> FloatBinds -> FloatBinds
425 plusFloats (FB t1 b1) (FB t2 b2) = FB (t1 `unionBags` t2) (b1 `plusMajor` b2)
426
427 plusMajor :: MajorEnv -> MajorEnv -> MajorEnv
428 plusMajor = plusUFM_C plusMinor
429
430 plusMinor :: MinorEnv -> MinorEnv -> MinorEnv
431 plusMinor = plusUFM_C unionBags
432
433 floatsToBindPairs :: Bag FloatBind -> [(Id,CoreExpr)] -> [(Id,CoreExpr)]
434 floatsToBindPairs floats binds = foldrBag add binds floats
435   where
436    add (Rec pairs)         binds = pairs ++ binds
437    add (NonRec binder rhs) binds = (binder,rhs) : binds
438
439 install :: Bag FloatBind -> CoreExpr -> CoreExpr
440 install defn_groups expr
441   = foldrBag install_group expr defn_groups
442   where
443     install_group defns body = Let defns body
444
445 partitionByMajorLevel, partitionByLevel
446         :: Level                -- Partitioning level
447         -> FloatBinds           -- Defns to be divided into 2 piles...
448         -> (FloatBinds,         -- Defns  with level strictly < partition level,
449             Bag FloatBind)      -- The rest
450
451 --       ---- partitionByMajorLevel ----
452 -- Float it if we escape a value lambda, *or* if we get to the top level
453 -- If we can get to the top level, say "yes" anyway. This means that 
454 --      x = f e
455 -- transforms to 
456 --    lvl = e
457 --    x = f lvl
458 -- which is as it should be
459
460 partitionByMajorLevel (Level major _) (FB tops defns)
461   = (FB tops outer, heres `unionBags` flattenMajor inner)
462   where
463     (outer, mb_heres, inner) = splitUFM defns major
464     heres = case mb_heres of 
465                Nothing -> emptyBag
466                Just h  -> flattenMinor h
467
468 partitionByLevel (Level major minor) (FB tops defns)
469   = (FB tops (outer_maj `plusMajor` unitUFM major outer_min),
470      here_min `unionBags` flattenMinor inner_min 
471               `unionBags` flattenMajor inner_maj)
472
473   where
474     (outer_maj, mb_here_maj, inner_maj) = splitUFM defns major
475     (outer_min, mb_here_min, inner_min) = case mb_here_maj of
476                                             Nothing -> (emptyUFM, Nothing, emptyUFM)
477                                             Just min_defns -> splitUFM min_defns minor
478     here_min = mb_here_min `orElse` emptyBag
479
480 wrapCostCentre :: CostCentre -> FloatBinds -> FloatBinds
481 wrapCostCentre cc (FB tops defns)
482   = FB (wrap_defns tops) (mapUFM (mapUFM wrap_defns) defns)
483   where
484     wrap_defns = mapBag wrap_one 
485     wrap_one (NonRec binder rhs) = NonRec binder (mkSCC cc rhs)
486     wrap_one (Rec pairs)         = Rec (mapSnd (mkSCC cc) pairs)
487 \end{code}