Don't dump floated bindings just outside a lambda
[ghc-hetmet.git] / compiler / simplCore / FloatOut.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
3 %
4 \section[FloatOut]{Float bindings outwards (towards the top level)}
5
6 ``Long-distance'' floating of bindings towards the top level.
7
8 \begin{code}
9 module FloatOut ( floatOutwards ) where
10
11 #include "HsVersions.h"
12
13 import CoreSyn
14 import CoreUtils        ( mkSCC, exprIsHNF, exprIsTrivial )
15
16 import DynFlags ( DynFlags, DynFlag(..), FloatOutSwitches(..) )
17 import ErrUtils         ( dumpIfSet_dyn )
18 import CostCentre       ( dupifyCC, CostCentre )
19 import Id               ( Id, idType )
20 import Type             ( isUnLiftedType )
21 import CoreLint         ( showPass, endPass )
22 import SetLevels        ( Level(..), LevelledExpr, LevelledBind,
23                           setLevels, ltMajLvl, ltLvl, isTopLvl )
24 import UniqSupply       ( UniqSupply )
25 import List             ( partition )
26 import Outputable
27 \end{code}
28
29         -----------------
30         Overall game plan
31         -----------------
32
33 The Big Main Idea is:
34
35         To float out sub-expressions that can thereby get outside
36         a non-one-shot value lambda, and hence may be shared.
37
38
39 To achieve this we may need to do two thing:
40
41    a) Let-bind the sub-expression:
42
43         f (g x)  ==>  let lvl = f (g x) in lvl
44
45       Now we can float the binding for 'lvl'.  
46
47    b) More than that, we may need to abstract wrt a type variable
48
49         \x -> ... /\a -> let v = ...a... in ....
50
51       Here the binding for v mentions 'a' but not 'x'.  So we
52       abstract wrt 'a', to give this binding for 'v':
53
54             vp = /\a -> ...a...
55             v  = vp a
56
57       Now the binding for vp can float out unimpeded.
58       I can't remember why this case seemed important enough to
59       deal with, but I certainly found cases where important floats
60       didn't happen if we did not abstract wrt tyvars.
61
62 With this in mind we can also achieve another goal: lambda lifting.
63 We can make an arbitrary (function) binding float to top level by
64 abstracting wrt *all* local variables, not just type variables, leaving
65 a binding that can be floated right to top level.  Whether or not this
66 happens is controlled by a flag.
67
68
69 Random comments
70 ~~~~~~~~~~~~~~~
71
72 At the moment we never float a binding out to between two adjacent
73 lambdas.  For example:
74
75 @
76         \x y -> let t = x+x in ...
77 ===>
78         \x -> let t = x+x in \y -> ...
79 @
80 Reason: this is less efficient in the case where the original lambda
81 is never partially applied.
82
83 But there's a case I've seen where this might not be true.  Consider:
84 @
85 elEm2 x ys
86   = elem' x ys
87   where
88     elem' _ []  = False
89     elem' x (y:ys)      = x==y || elem' x ys
90 @
91 It turns out that this generates a subexpression of the form
92 @
93         \deq x ys -> let eq = eqFromEqDict deq in ...
94 @
95 vwhich might usefully be separated to
96 @
97         \deq -> let eq = eqFromEqDict deq in \xy -> ...
98 @
99 Well, maybe.  We don't do this at the moment.
100
101 \begin{code}
102 type FloatBind     = (Level, CoreBind)  -- INVARIANT: a FloatBind is always lifted
103 type FloatBinds    = [FloatBind]        
104 \end{code}
105
106 %************************************************************************
107 %*                                                                      *
108 \subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
109 %*                                                                      *
110 %************************************************************************
111
112 \begin{code}
113 floatOutwards :: FloatOutSwitches
114               -> DynFlags
115               -> UniqSupply 
116               -> [CoreBind] -> IO [CoreBind]
117
118 floatOutwards float_sws dflags us pgm
119   = do {
120         showPass dflags float_msg ;
121
122         let { annotated_w_levels = setLevels float_sws pgm us ;
123               (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)
124             } ;
125
126         dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
127                   (vcat (map ppr annotated_w_levels));
128
129         let { (tlets, ntlets, lams) = get_stats (sum_stats fss) };
130
131         dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:"
132                 (hcat [ int tlets,  ptext SLIT(" Lets floated to top level; "),
133                         int ntlets, ptext SLIT(" Lets floated elsewhere; from "),
134                         int lams,   ptext SLIT(" Lambda groups")]);
135
136         endPass dflags float_msg  Opt_D_verbose_core2core (concat binds_s')
137                         {- no specific flag for dumping float-out -} 
138     }
139   where
140     float_msg = showSDoc (text "Float out" <+> parens (sws float_sws))
141     sws (FloatOutSw lam const) = pp_not lam   <+> text "lambdas" <> comma <+>
142                                  pp_not const <+> text "constants"
143     pp_not True  = empty
144     pp_not False = text "not"
145
146 floatTopBind bind
147   = case (floatBind bind) of { (fs, floats) ->
148     (fs, floatsToBinds floats)
149     }
150 \end{code}
151
152 %************************************************************************
153 %*                                                                      *
154 \subsection[FloatOut-Bind]{Floating in a binding (the business end)}
155 %*                                                                      *
156 %************************************************************************
157
158
159 \begin{code}
160 floatBind :: LevelledBind -> (FloatStats, FloatBinds)
161
162 floatBind (NonRec (TB name level) rhs)
163   = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
164     (fs, rhs_floats ++ [(level, NonRec name rhs')]) }
165
166 floatBind bind@(Rec pairs)
167   = case (unzip3 (map do_pair pairs)) of { (fss, rhss_floats, new_pairs) ->
168     let rhs_floats = concat rhss_floats in
169
170     if not (isTopLvl bind_dest_lvl) then
171         -- Find which bindings float out at least one lambda beyond this one
172         -- These ones can't mention the binders, because they couldn't 
173         -- be escaping a major level if so.
174         -- The ones that are not going further can join the letrec;
175         -- they may not be mutually recursive but the occurrence analyser will
176         -- find that out.
177         case (partitionByMajorLevel bind_dest_lvl rhs_floats) of { (floats', heres) ->
178         (sum_stats fss, floats' ++ [(bind_dest_lvl, Rec (floatsToBindPairs heres ++ new_pairs))]) }
179     else
180         -- In a recursive binding, *destined for* the top level
181         -- (only), the rhs floats may contain references to the 
182         -- bound things.  For example
183         --      f = ...(let v = ...f... in b) ...
184         --  might get floated to
185         --      v = ...f...
186         --      f = ... b ...
187         -- and hence we must (pessimistically) make all the floats recursive
188         -- with the top binding.  Later dependency analysis will unravel it.
189         --
190         -- This can only happen for bindings destined for the top level,
191         -- because only then will partitionByMajorLevel allow through a binding
192         -- that only differs in its minor level
193         (sum_stats fss, [(bind_dest_lvl, Rec (new_pairs ++ floatsToBindPairs rhs_floats))])
194     }
195   where
196     bind_dest_lvl = getBindLevel bind
197
198     do_pair (TB name level, rhs)
199       = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
200         (fs, rhs_floats, (name, rhs'))
201         }
202 \end{code}
203
204 %************************************************************************
205
206 \subsection[FloatOut-Expr]{Floating in expressions}
207 %*                                                                      *
208 %************************************************************************
209
210 \begin{code}
211 floatExpr, floatRhs, floatCaseAlt
212          :: Level
213          -> LevelledExpr
214          -> (FloatStats, FloatBinds, CoreExpr)
215
216 floatCaseAlt lvl arg    -- Used rec rhss, and case-alternative rhss
217   = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
218     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
219         -- Dump bindings that aren't going to escape from a lambda;
220         -- in particular, we must dump the ones that are bound by 
221         -- the rec or case alternative
222     (fsa, floats', install heres arg') }}
223
224 floatRhs lvl arg        -- Used for nested non-rec rhss, and fn args
225                         -- See Note [Floating out of RHS]
226   = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
227     if exprIsHNF arg' || exprIsTrivial arg' then
228         (fsa, floats, arg')
229     else
230     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
231     (fsa, floats', install heres arg') }}
232
233 -- Note [Floating out of RHSs]
234 -- ~~~~~~~~~~~~~~~~~~~~~~~~~~~
235 -- Dump bindings that aren't going to escape from a lambda
236 -- This isn't a scoping issue (the binder isn't in scope in the RHS 
237 --      of a non-rec binding)
238 -- Rather, it is to avoid floating the x binding out of
239 --      f (let x = e in b)
240 -- unnecessarily.  But we first test for values or trival rhss,
241 -- because (in particular) we don't want to insert new bindings between
242 -- the "=" and the "\".  E.g.
243 --      f = \x -> let <bind> in <body>
244 -- We do not want
245 --      f = let <bind> in \x -> <body>
246 -- (a) The simplifier will immediately float it further out, so we may
247 --      as well do so right now; in general, keeping rhss as manifest 
248 --      values is good
249 -- (b) If a float-in pass follows immediately, it might add yet more
250 --      bindings just after the '='.  And some of them might (correctly)
251 --      be strict even though the 'let f' is lazy, because f, being a value,
252 --      gets its demand-info zapped by the simplifier.
253
254 floatExpr _ (Var v)   = (zeroStats, [], Var v)
255 floatExpr _ (Type ty) = (zeroStats, [], Type ty)
256 floatExpr _ (Lit lit) = (zeroStats, [], Lit lit)
257           
258 floatExpr lvl (App e a)
259   = case (floatExpr      lvl e) of { (fse, floats_e, e') ->
260     case (floatRhs lvl a)       of { (fsa, floats_a, a') ->
261     (fse `add_stats` fsa, floats_e ++ floats_a, App e' a') }}
262
263 floatExpr lvl lam@(Lam _ _)
264   = let
265         (bndrs_w_lvls, body) = collectBinders lam
266         bndrs                = [b | TB b _ <- bndrs_w_lvls]
267         lvls                 = [l | TB b l <- bndrs_w_lvls]
268
269         -- For the all-tyvar case we are prepared to pull 
270         -- the lets out, to implement the float-out-of-big-lambda
271         -- transform; but otherwise we only float bindings that are
272         -- going to escape a value lambda.
273         -- In particular, for one-shot lambdas we don't float things
274         -- out; we get no saving by so doing.
275         partition_fn | all isTyVar bndrs = partitionByLevel
276                      | otherwise         = partitionByMajorLevel
277     in
278     case (floatExpr (last lvls) body) of { (fs, floats, body') ->
279
280         -- Dump any bindings which absolutely cannot go any further
281     case (partition_fn (head lvls) floats)      of { (floats', heres) ->
282
283     (add_to_stats fs floats', floats', mkLams bndrs (install heres body'))
284     }}
285
286 floatExpr lvl (Note note@(SCC cc) expr)
287   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
288     let
289         -- Annotate bindings floated outwards past an scc expression
290         -- with the cc.  We mark that cc as "duplicated", though.
291
292         annotated_defns = annotate (dupifyCC cc) floating_defns
293     in
294     (fs, annotated_defns, Note note expr') }
295   where
296     annotate :: CostCentre -> FloatBinds -> FloatBinds
297
298     annotate dupd_cc defn_groups
299       = [ (level, ann_bind floater) | (level, floater) <- defn_groups ]
300       where
301         ann_bind (NonRec binder rhs)
302           = NonRec binder (mkSCC dupd_cc rhs)
303
304         ann_bind (Rec pairs)
305           = Rec [(binder, mkSCC dupd_cc rhs) | (binder, rhs) <- pairs]
306
307 floatExpr lvl (Note InlineMe expr)      -- Other than SCCs
308   = case floatExpr InlineCtxt expr of { (fs, floating_defns, expr') ->
309         -- There can be some floating_defns, arising from
310         -- ordinary lets that were there all the time.  It seems
311         -- more efficient to test once here than to avoid putting
312         -- them into floating_defns (which would mean testing for
313         -- inlineCtxt  at every let)
314     (fs, [], Note InlineMe (install floating_defns expr')) }    -- See notes in SetLevels
315
316 floatExpr lvl (Note note expr)  -- Other than SCCs
317   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
318     (fs, floating_defns, Note note expr') }
319
320 floatExpr lvl (Cast expr co)
321   = case (floatExpr lvl expr)   of { (fs, floating_defns, expr') ->
322     (fs, floating_defns, Cast expr' co) }
323
324 floatExpr lvl (Let (NonRec (TB bndr bndr_lvl) rhs) body)
325   | isUnLiftedType (idType bndr)        -- Treat unlifted lets just like a case
326   = case floatExpr lvl rhs      of { (fs, rhs_floats, rhs') ->
327     case floatRhs bndr_lvl body of { (fs, body_floats, body') ->
328     (fs, rhs_floats ++ body_floats, Let (NonRec bndr rhs') body') }}
329
330 floatExpr lvl (Let bind body)
331   = case (floatBind bind)     of { (fsb, bind_floats) ->
332     case (floatExpr lvl body) of { (fse, body_floats, body') ->
333     (add_stats fsb fse,
334      bind_floats ++ body_floats,
335      body')  }}
336
337 floatExpr lvl (Case scrut (TB case_bndr case_lvl) ty alts)
338   = case floatExpr lvl scrut    of { (fse, fde, scrut') ->
339     case floatList float_alt alts       of { (fsa, fda, alts')  ->
340     (add_stats fse fsa, fda ++ fde, Case scrut' case_bndr ty alts')
341     }}
342   where
343         -- Use floatCaseAlt for the alternatives, so that we
344         -- don't gratuitiously float bindings out of the RHSs
345     float_alt (con, bs, rhs)
346         = case (floatCaseAlt case_lvl rhs)      of { (fs, rhs_floats, rhs') ->
347           (fs, rhs_floats, (con, [b | TB b _ <- bs], rhs')) }
348
349
350 floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
351 floatList f [] = (zeroStats, [], [])
352 floatList f (a:as) = case f a            of { (fs_a,  binds_a,  b)  ->
353                      case floatList f as of { (fs_as, binds_as, bs) ->
354                      (fs_a `add_stats` fs_as, binds_a ++ binds_as, b:bs) }}
355 \end{code}
356
357 %************************************************************************
358 %*                                                                      *
359 \subsection{Utility bits for floating stats}
360 %*                                                                      *
361 %************************************************************************
362
363 I didn't implement this with unboxed numbers.  I don't want to be too
364 strict in this stuff, as it is rarely turned on.  (WDP 95/09)
365
366 \begin{code}
367 data FloatStats
368   = FlS Int  -- Number of top-floats * lambda groups they've been past
369         Int  -- Number of non-top-floats * lambda groups they've been past
370         Int  -- Number of lambda (groups) seen
371
372 get_stats (FlS a b c) = (a, b, c)
373
374 zeroStats = FlS 0 0 0
375
376 sum_stats xs = foldr add_stats zeroStats xs
377
378 add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
379   = FlS (a1 + a2) (b1 + b2) (c1 + c2)
380
381 add_to_stats (FlS a b c) floats
382   = FlS (a + length top_floats) (b + length other_floats) (c + 1)
383   where
384     (top_floats, other_floats) = partition to_very_top floats
385
386     to_very_top (my_lvl, _) = isTopLvl my_lvl
387 \end{code}
388
389
390 %************************************************************************
391 %*                                                                      *
392 \subsection{Utility bits for floating}
393 %*                                                                      *
394 %************************************************************************
395
396 \begin{code}
397 getBindLevel (NonRec (TB _ lvl) _)      = lvl
398 getBindLevel (Rec (((TB _ lvl), _) : _)) = lvl
399 \end{code}
400
401 \begin{code}
402 partitionByMajorLevel, partitionByLevel
403         :: Level                -- Partitioning level
404
405         -> FloatBinds           -- Defns to be divided into 2 piles...
406
407         -> (FloatBinds, -- Defns  with level strictly < partition level,
408             FloatBinds) -- The rest
409
410
411 partitionByMajorLevel ctxt_lvl defns
412   = partition float_further defns
413   where
414         -- Float it if we escape a value lambda, or if we get to the top level
415     float_further (my_lvl, bind) = my_lvl `ltMajLvl` ctxt_lvl || isTopLvl my_lvl
416         -- The isTopLvl part says that if we can get to the top level, say "yes" anyway
417         -- This means that 
418         --      x = f e
419         -- transforms to 
420         --    lvl = e
421         --    x = f lvl
422         -- which is as it should be
423
424 partitionByLevel ctxt_lvl defns
425   = partition float_further defns
426   where
427     float_further (my_lvl, _) = my_lvl `ltLvl` ctxt_lvl
428 \end{code}
429
430 \begin{code}
431 floatsToBinds :: FloatBinds -> [CoreBind]
432 floatsToBinds floats = map snd floats
433
434 floatsToBindPairs :: FloatBinds -> [(Id,CoreExpr)]
435
436 floatsToBindPairs floats = concat (map mk_pairs floats)
437   where
438    mk_pairs (_, Rec pairs)         = pairs
439    mk_pairs (_, NonRec binder rhs) = [(binder,rhs)]
440
441 install :: FloatBinds -> CoreExpr -> CoreExpr
442
443 install defn_groups expr
444   = foldr install_group expr defn_groups
445   where
446     install_group (_, defns) body = Let defns body
447 \end{code}