Fix a long-standing bug in FloatOut
[ghc-hetmet.git] / compiler / simplCore / FloatOut.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
3 %
4 \section[FloatOut]{Float bindings outwards (towards the top level)}
5
6 ``Long-distance'' floating of bindings towards the top level.
7
8 \begin{code}
9 {-# OPTIONS -w #-}
10 -- The above warning supression flag is a temporary kludge.
11 -- While working on this module you are encouraged to remove it and fix
12 -- any warnings in the module. See
13 --     http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings
14 -- for details
15
16 module FloatOut ( floatOutwards ) where
17
18 #include "HsVersions.h"
19
20 import CoreSyn
21 import CoreUtils
22
23 import DynFlags ( DynFlags, DynFlag(..), FloatOutSwitches(..) )
24 import ErrUtils         ( dumpIfSet_dyn )
25 import CostCentre       ( dupifyCC, CostCentre )
26 import Id               ( Id, idType )
27 import Type             ( isUnLiftedType )
28 import CoreLint         ( showPass, endPass )
29 import SetLevels        ( Level(..), LevelledExpr, LevelledBind,
30                           setLevels, ltMajLvl, ltLvl, isTopLvl )
31 import UniqSupply       ( UniqSupply )
32 import List             ( partition )
33 import Outputable
34 import FastString
35 \end{code}
36
37         -----------------
38         Overall game plan
39         -----------------
40
41 The Big Main Idea is:
42
43         To float out sub-expressions that can thereby get outside
44         a non-one-shot value lambda, and hence may be shared.
45
46
47 To achieve this we may need to do two thing:
48
49    a) Let-bind the sub-expression:
50
51         f (g x)  ==>  let lvl = f (g x) in lvl
52
53       Now we can float the binding for 'lvl'.  
54
55    b) More than that, we may need to abstract wrt a type variable
56
57         \x -> ... /\a -> let v = ...a... in ....
58
59       Here the binding for v mentions 'a' but not 'x'.  So we
60       abstract wrt 'a', to give this binding for 'v':
61
62             vp = /\a -> ...a...
63             v  = vp a
64
65       Now the binding for vp can float out unimpeded.
66       I can't remember why this case seemed important enough to
67       deal with, but I certainly found cases where important floats
68       didn't happen if we did not abstract wrt tyvars.
69
70 With this in mind we can also achieve another goal: lambda lifting.
71 We can make an arbitrary (function) binding float to top level by
72 abstracting wrt *all* local variables, not just type variables, leaving
73 a binding that can be floated right to top level.  Whether or not this
74 happens is controlled by a flag.
75
76
77 Random comments
78 ~~~~~~~~~~~~~~~
79
80 At the moment we never float a binding out to between two adjacent
81 lambdas.  For example:
82
83 @
84         \x y -> let t = x+x in ...
85 ===>
86         \x -> let t = x+x in \y -> ...
87 @
88 Reason: this is less efficient in the case where the original lambda
89 is never partially applied.
90
91 But there's a case I've seen where this might not be true.  Consider:
92 @
93 elEm2 x ys
94   = elem' x ys
95   where
96     elem' _ []  = False
97     elem' x (y:ys)      = x==y || elem' x ys
98 @
99 It turns out that this generates a subexpression of the form
100 @
101         \deq x ys -> let eq = eqFromEqDict deq in ...
102 @
103 vwhich might usefully be separated to
104 @
105         \deq -> let eq = eqFromEqDict deq in \xy -> ...
106 @
107 Well, maybe.  We don't do this at the moment.
108
109 \begin{code}
110 type FloatBind     = (Level, CoreBind)  -- INVARIANT: a FloatBind is always lifted
111 type FloatBinds    = [FloatBind]        
112 \end{code}
113
114 %************************************************************************
115 %*                                                                      *
116 \subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
117 %*                                                                      *
118 %************************************************************************
119
120 \begin{code}
121 floatOutwards :: FloatOutSwitches
122               -> DynFlags
123               -> UniqSupply 
124               -> [CoreBind] -> IO [CoreBind]
125
126 floatOutwards float_sws dflags us pgm
127   = do {
128         showPass dflags float_msg ;
129
130         let { annotated_w_levels = setLevels float_sws pgm us ;
131               (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)
132             } ;
133
134         dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
135                   (vcat (map ppr annotated_w_levels));
136
137         let { (tlets, ntlets, lams) = get_stats (sum_stats fss) };
138
139         dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:"
140                 (hcat [ int tlets,  ptext SLIT(" Lets floated to top level; "),
141                         int ntlets, ptext SLIT(" Lets floated elsewhere; from "),
142                         int lams,   ptext SLIT(" Lambda groups")]);
143
144         endPass dflags float_msg  Opt_D_verbose_core2core (concat binds_s')
145                         {- no specific flag for dumping float-out -} 
146     }
147   where
148     float_msg = showSDoc (text "Float out" <+> parens (sws float_sws))
149     sws (FloatOutSw lam const) = pp_not lam   <+> text "lambdas" <> comma <+>
150                                  pp_not const <+> text "constants"
151     pp_not True  = empty
152     pp_not False = text "not"
153
154 floatTopBind bind
155   = case (floatBind bind) of { (fs, floats) ->
156     (fs, floatsToBinds floats)
157     }
158 \end{code}
159
160 %************************************************************************
161 %*                                                                      *
162 \subsection[FloatOut-Bind]{Floating in a binding (the business end)}
163 %*                                                                      *
164 %************************************************************************
165
166
167 \begin{code}
168 floatBind :: LevelledBind -> (FloatStats, FloatBinds)
169
170 floatBind (NonRec (TB name level) rhs)
171   = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
172     (fs, rhs_floats ++ [(level, NonRec name rhs')]) }
173
174 floatBind bind@(Rec pairs)
175   = case (unzip3 (map do_pair pairs)) of { (fss, rhss_floats, new_pairs) ->
176     let rhs_floats = concat rhss_floats in
177
178     if not (isTopLvl bind_dest_lvl) then
179         -- Find which bindings float out at least one lambda beyond this one
180         -- These ones can't mention the binders, because they couldn't 
181         -- be escaping a major level if so.
182         -- The ones that are not going further can join the letrec;
183         -- they may not be mutually recursive but the occurrence analyser will
184         -- find that out.
185         case (partitionByMajorLevel bind_dest_lvl rhs_floats) of { (floats', heres) ->
186         (sum_stats fss, floats' ++ [(bind_dest_lvl, Rec (floatsToBindPairs heres ++ new_pairs))]) }
187     else
188         -- In a recursive binding, *destined for* the top level
189         -- (only), the rhs floats may contain references to the 
190         -- bound things.  For example
191         --      f = ...(let v = ...f... in b) ...
192         --  might get floated to
193         --      v = ...f...
194         --      f = ... b ...
195         -- and hence we must (pessimistically) make all the floats recursive
196         -- with the top binding.  Later dependency analysis will unravel it.
197         --
198         -- This can only happen for bindings destined for the top level,
199         -- because only then will partitionByMajorLevel allow through a binding
200         -- that only differs in its minor level
201         (sum_stats fss, [(bind_dest_lvl, Rec (new_pairs ++ floatsToBindPairs rhs_floats))])
202     }
203   where
204     bind_dest_lvl = getBindLevel bind
205
206     do_pair (TB name level, rhs)
207       = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
208         (fs, rhs_floats, (name, rhs'))
209         }
210 \end{code}
211
212 %************************************************************************
213
214 \subsection[FloatOut-Expr]{Floating in expressions}
215 %*                                                                      *
216 %************************************************************************
217
218 \begin{code}
219 floatExpr, floatRhs, floatCaseAlt
220          :: Level
221          -> LevelledExpr
222          -> (FloatStats, FloatBinds, CoreExpr)
223
224 floatCaseAlt lvl arg    -- Used rec rhss, and case-alternative rhss
225   = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
226     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
227         -- Dump bindings that aren't going to escape from a lambda;
228         -- in particular, we must dump the ones that are bound by 
229         -- the rec or case alternative
230     (fsa, floats', install heres arg') }}
231
232 floatRhs lvl arg        -- Used for nested non-rec rhss, and fn args
233                         -- See Note [Floating out of RHS]
234   = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
235     if exprIsCheap arg' then    
236         (fsa, floats, arg')
237     else
238     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
239     (fsa, floats', install heres arg') }}
240
241 -- Note [Floating out of RHSs]
242 -- ~~~~~~~~~~~~~~~~~~~~~~~~~~~
243 -- Dump bindings that aren't going to escape from a lambda
244 -- This isn't a scoping issue (the binder isn't in scope in the RHS 
245 --      of a non-rec binding)
246 -- Rather, it is to avoid floating the x binding out of
247 --      f (let x = e in b)
248 -- unnecessarily.  But we first test for values or trival rhss,
249 -- because (in particular) we don't want to insert new bindings between
250 -- the "=" and the "\".  E.g.
251 --      f = \x -> let <bind> in <body>
252 -- We do not want
253 --      f = let <bind> in \x -> <body>
254 -- (a) The simplifier will immediately float it further out, so we may
255 --      as well do so right now; in general, keeping rhss as manifest 
256 --      values is good
257 -- (b) If a float-in pass follows immediately, it might add yet more
258 --      bindings just after the '='.  And some of them might (correctly)
259 --      be strict even though the 'let f' is lazy, because f, being a value,
260 --      gets its demand-info zapped by the simplifier.
261 --
262 -- We use exprIsCheap because that is also what's used by the simplifier
263 -- to decide whether to float a let out of a let
264
265 floatExpr _ (Var v)   = (zeroStats, [], Var v)
266 floatExpr _ (Type ty) = (zeroStats, [], Type ty)
267 floatExpr _ (Lit lit) = (zeroStats, [], Lit lit)
268           
269 floatExpr lvl (App e a)
270   = case (floatExpr      lvl e) of { (fse, floats_e, e') ->
271     case (floatRhs lvl a)       of { (fsa, floats_a, a') ->
272     (fse `add_stats` fsa, floats_e ++ floats_a, App e' a') }}
273
274 floatExpr lvl lam@(Lam _ _)
275   = let
276         (bndrs_w_lvls, body) = collectBinders lam
277         bndrs                = [b | TB b _ <- bndrs_w_lvls]
278         lvls                 = [l | TB b l <- bndrs_w_lvls]
279
280         -- For the all-tyvar case we are prepared to pull 
281         -- the lets out, to implement the float-out-of-big-lambda
282         -- transform; but otherwise we only float bindings that are
283         -- going to escape a value lambda.
284         -- In particular, for one-shot lambdas we don't float things
285         -- out; we get no saving by so doing.
286         partition_fn | all isTyVar bndrs = partitionByLevel
287                      | otherwise         = partitionByMajorLevel
288     in
289     case (floatExpr (last lvls) body) of { (fs, floats, body') ->
290
291         -- Dump any bindings which absolutely cannot go any further
292     case (partition_fn (head lvls) floats)      of { (floats', heres) ->
293
294     (add_to_stats fs floats', floats', mkLams bndrs (install heres body'))
295     }}
296
297 floatExpr lvl (Note note@(SCC cc) expr)
298   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
299     let
300         -- Annotate bindings floated outwards past an scc expression
301         -- with the cc.  We mark that cc as "duplicated", though.
302
303         annotated_defns = annotate (dupifyCC cc) floating_defns
304     in
305     (fs, annotated_defns, Note note expr') }
306   where
307     annotate :: CostCentre -> FloatBinds -> FloatBinds
308
309     annotate dupd_cc defn_groups
310       = [ (level, ann_bind floater) | (level, floater) <- defn_groups ]
311       where
312         ann_bind (NonRec binder rhs)
313           = NonRec binder (mkSCC dupd_cc rhs)
314
315         ann_bind (Rec pairs)
316           = Rec [(binder, mkSCC dupd_cc rhs) | (binder, rhs) <- pairs]
317
318 floatExpr lvl (Note InlineMe expr)      -- Other than SCCs
319   = (zeroStats, [], Note InlineMe (unTag expr))
320         -- Do no floating at all inside INLINE. [_$_]
321         -- The SetLevels pass did not clone the bindings, so it's
322         -- unsafe to do any floating, even if we dump the results
323         -- inside the Note (which is what we used to do).
324
325 floatExpr lvl (Note note expr)  -- Other than SCCs
326   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
327     (fs, floating_defns, Note note expr') }
328
329 floatExpr lvl (Cast expr co)
330   = case (floatExpr lvl expr)   of { (fs, floating_defns, expr') ->
331     (fs, floating_defns, Cast expr' co) }
332
333 floatExpr lvl (Let (NonRec (TB bndr bndr_lvl) rhs) body)
334   | isUnLiftedType (idType bndr)        -- Treat unlifted lets just like a case
335                                 -- I.e. floatExpr for rhs, floatCaseAlt for body
336   = case floatExpr lvl rhs          of { (fs, rhs_floats, rhs') ->
337     case floatCaseAlt bndr_lvl body of { (fs, body_floats, body') ->
338     (fs, rhs_floats ++ body_floats, Let (NonRec bndr rhs') body') }}
339
340 floatExpr lvl (Let bind body)
341   = case (floatBind bind)     of { (fsb, bind_floats) ->
342     case (floatExpr lvl body) of { (fse, body_floats, body') ->
343     (add_stats fsb fse,
344      bind_floats ++ body_floats,
345      body')  }}
346
347 floatExpr lvl (Case scrut (TB case_bndr case_lvl) ty alts)
348   = case floatExpr lvl scrut    of { (fse, fde, scrut') ->
349     case floatList float_alt alts       of { (fsa, fda, alts')  ->
350     (add_stats fse fsa, fda ++ fde, Case scrut' case_bndr ty alts')
351     }}
352   where
353         -- Use floatCaseAlt for the alternatives, so that we
354         -- don't gratuitiously float bindings out of the RHSs
355     float_alt (con, bs, rhs)
356         = case (floatCaseAlt case_lvl rhs)      of { (fs, rhs_floats, rhs') ->
357           (fs, rhs_floats, (con, [b | TB b _ <- bs], rhs')) }
358
359
360 floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
361 floatList f [] = (zeroStats, [], [])
362 floatList f (a:as) = case f a            of { (fs_a,  binds_a,  b)  ->
363                      case floatList f as of { (fs_as, binds_as, bs) ->
364                      (fs_a `add_stats` fs_as, binds_a ++ binds_as, b:bs) }}
365
366 unTagBndr :: TaggedBndr tag -> CoreBndr
367 unTagBndr (TB b _) = b
368
369 unTag :: TaggedExpr tag -> CoreExpr
370 unTag (Var v)     = Var v
371 unTag (Lit l)     = Lit l
372 unTag (Type ty)   = Type ty
373 unTag (Note n e)  = Note n (unTag e)
374 unTag (App e1 e2) = App (unTag e1) (unTag e2)
375 unTag (Lam b e)   = Lam (unTagBndr b) (unTag e)
376 unTag (Cast e co) = Cast (unTag e) co
377 unTag (Let (Rec prs) e)    = Let (Rec [(unTagBndr b,unTag r) | (b, r) <- prs]) (unTag e)
378 unTag (Let (NonRec b r) e) = Let (NonRec (unTagBndr b) (unTag r)) (unTag e)
379 unTag (Case e b ty alts)   = Case (unTag e) (unTagBndr b) ty
380                                   [(c, map unTagBndr bs, unTag r) | (c,bs,r) <- alts]
381 \end{code}
382
383 %************************************************************************
384 %*                                                                      *
385 \subsection{Utility bits for floating stats}
386 %*                                                                      *
387 %************************************************************************
388
389 I didn't implement this with unboxed numbers.  I don't want to be too
390 strict in this stuff, as it is rarely turned on.  (WDP 95/09)
391
392 \begin{code}
393 data FloatStats
394   = FlS Int  -- Number of top-floats * lambda groups they've been past
395         Int  -- Number of non-top-floats * lambda groups they've been past
396         Int  -- Number of lambda (groups) seen
397
398 get_stats (FlS a b c) = (a, b, c)
399
400 zeroStats = FlS 0 0 0
401
402 sum_stats xs = foldr add_stats zeroStats xs
403
404 add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
405   = FlS (a1 + a2) (b1 + b2) (c1 + c2)
406
407 add_to_stats (FlS a b c) floats
408   = FlS (a + length top_floats) (b + length other_floats) (c + 1)
409   where
410     (top_floats, other_floats) = partition to_very_top floats
411
412     to_very_top (my_lvl, _) = isTopLvl my_lvl
413 \end{code}
414
415
416 %************************************************************************
417 %*                                                                      *
418 \subsection{Utility bits for floating}
419 %*                                                                      *
420 %************************************************************************
421
422 \begin{code}
423 getBindLevel (NonRec (TB _ lvl) _)      = lvl
424 getBindLevel (Rec (((TB _ lvl), _) : _)) = lvl
425 \end{code}
426
427 \begin{code}
428 partitionByMajorLevel, partitionByLevel
429         :: Level                -- Partitioning level
430
431         -> FloatBinds           -- Defns to be divided into 2 piles...
432
433         -> (FloatBinds, -- Defns  with level strictly < partition level,
434             FloatBinds) -- The rest
435
436
437 partitionByMajorLevel ctxt_lvl defns
438   = partition float_further defns
439   where
440         -- Float it if we escape a value lambda, or if we get to the top level
441     float_further (my_lvl, bind) = my_lvl `ltMajLvl` ctxt_lvl || isTopLvl my_lvl
442         -- The isTopLvl part says that if we can get to the top level, say "yes" anyway
443         -- This means that 
444         --      x = f e
445         -- transforms to 
446         --    lvl = e
447         --    x = f lvl
448         -- which is as it should be
449
450 partitionByLevel ctxt_lvl defns
451   = partition float_further defns
452   where
453     float_further (my_lvl, _) = my_lvl `ltLvl` ctxt_lvl
454 \end{code}
455
456 \begin{code}
457 floatsToBinds :: FloatBinds -> [CoreBind]
458 floatsToBinds floats = map snd floats
459
460 floatsToBindPairs :: FloatBinds -> [(Id,CoreExpr)]
461
462 floatsToBindPairs floats = concat (map mk_pairs floats)
463   where
464    mk_pairs (_, Rec pairs)         = pairs
465    mk_pairs (_, NonRec binder rhs) = [(binder,rhs)]
466
467 install :: FloatBinds -> CoreExpr -> CoreExpr
468
469 install defn_groups expr
470   = foldr install_group expr defn_groups
471   where
472     install_group (_, defns) body = Let defns body
473 \end{code}