d0a27deaeeb18f7a06f58f709a49cc725f56d2d5
[ghc-hetmet.git] / compiler / simplCore / FloatOut.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
3 %
4 \section[FloatOut]{Float bindings outwards (towards the top level)}
5
6 ``Long-distance'' floating of bindings towards the top level.
7
8 \begin{code}
9 {-# OPTIONS -w #-}
10 -- The above warning supression flag is a temporary kludge.
11 -- While working on this module you are encouraged to remove it and fix
12 -- any warnings in the module. See
13 --     http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings
14 -- for details
15
16 module FloatOut ( floatOutwards ) where
17
18 import CoreSyn
19 import CoreUtils
20
21 import DynFlags ( DynFlags, DynFlag(..), FloatOutSwitches(..) )
22 import ErrUtils         ( dumpIfSet_dyn )
23 import CostCentre       ( dupifyCC, CostCentre )
24 import Id               ( Id, idType )
25 import Type             ( isUnLiftedType )
26 import CoreLint         ( showPass, endPass )
27 import SetLevels        ( Level(..), LevelledExpr, LevelledBind,
28                           setLevels, ltMajLvl, ltLvl, isTopLvl )
29 import UniqSupply       ( UniqSupply )
30 import List             ( partition )
31 import Outputable
32 import FastString
33 \end{code}
34
35         -----------------
36         Overall game plan
37         -----------------
38
39 The Big Main Idea is:
40
41         To float out sub-expressions that can thereby get outside
42         a non-one-shot value lambda, and hence may be shared.
43
44
45 To achieve this we may need to do two thing:
46
47    a) Let-bind the sub-expression:
48
49         f (g x)  ==>  let lvl = f (g x) in lvl
50
51       Now we can float the binding for 'lvl'.  
52
53    b) More than that, we may need to abstract wrt a type variable
54
55         \x -> ... /\a -> let v = ...a... in ....
56
57       Here the binding for v mentions 'a' but not 'x'.  So we
58       abstract wrt 'a', to give this binding for 'v':
59
60             vp = /\a -> ...a...
61             v  = vp a
62
63       Now the binding for vp can float out unimpeded.
64       I can't remember why this case seemed important enough to
65       deal with, but I certainly found cases where important floats
66       didn't happen if we did not abstract wrt tyvars.
67
68 With this in mind we can also achieve another goal: lambda lifting.
69 We can make an arbitrary (function) binding float to top level by
70 abstracting wrt *all* local variables, not just type variables, leaving
71 a binding that can be floated right to top level.  Whether or not this
72 happens is controlled by a flag.
73
74
75 Random comments
76 ~~~~~~~~~~~~~~~
77
78 At the moment we never float a binding out to between two adjacent
79 lambdas.  For example:
80
81 @
82         \x y -> let t = x+x in ...
83 ===>
84         \x -> let t = x+x in \y -> ...
85 @
86 Reason: this is less efficient in the case where the original lambda
87 is never partially applied.
88
89 But there's a case I've seen where this might not be true.  Consider:
90 @
91 elEm2 x ys
92   = elem' x ys
93   where
94     elem' _ []  = False
95     elem' x (y:ys)      = x==y || elem' x ys
96 @
97 It turns out that this generates a subexpression of the form
98 @
99         \deq x ys -> let eq = eqFromEqDict deq in ...
100 @
101 vwhich might usefully be separated to
102 @
103         \deq -> let eq = eqFromEqDict deq in \xy -> ...
104 @
105 Well, maybe.  We don't do this at the moment.
106
107 \begin{code}
108 type FloatBind     = (Level, CoreBind)  -- INVARIANT: a FloatBind is always lifted
109 type FloatBinds    = [FloatBind]        
110 \end{code}
111
112 %************************************************************************
113 %*                                                                      *
114 \subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
115 %*                                                                      *
116 %************************************************************************
117
118 \begin{code}
119 floatOutwards :: FloatOutSwitches
120               -> DynFlags
121               -> UniqSupply 
122               -> [CoreBind] -> IO [CoreBind]
123
124 floatOutwards float_sws dflags us pgm
125   = do {
126         showPass dflags float_msg ;
127
128         let { annotated_w_levels = setLevels float_sws pgm us ;
129               (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)
130             } ;
131
132         dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
133                   (vcat (map ppr annotated_w_levels));
134
135         let { (tlets, ntlets, lams) = get_stats (sum_stats fss) };
136
137         dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:"
138                 (hcat [ int tlets,  ptext (sLit " Lets floated to top level; "),
139                         int ntlets, ptext (sLit " Lets floated elsewhere; from "),
140                         int lams,   ptext (sLit " Lambda groups")]);
141
142         endPass dflags float_msg  Opt_D_verbose_core2core (concat binds_s')
143                         {- no specific flag for dumping float-out -} 
144     }
145   where
146     float_msg = showSDoc (text "Float out" <+> parens (sws float_sws))
147     sws (FloatOutSw lam const) = pp_not lam   <+> text "lambdas" <> comma <+>
148                                  pp_not const <+> text "constants"
149     pp_not True  = empty
150     pp_not False = text "not"
151
152 floatTopBind bind
153   = case (floatBind bind) of { (fs, floats) ->
154     (fs, floatsToBinds floats)
155     }
156 \end{code}
157
158 %************************************************************************
159 %*                                                                      *
160 \subsection[FloatOut-Bind]{Floating in a binding (the business end)}
161 %*                                                                      *
162 %************************************************************************
163
164
165 \begin{code}
166 floatBind :: LevelledBind -> (FloatStats, FloatBinds)
167
168 floatBind (NonRec (TB name level) rhs)
169   = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
170     (fs, rhs_floats ++ [(level, NonRec name rhs')]) }
171
172 floatBind bind@(Rec pairs)
173   = case (unzip3 (map do_pair pairs)) of { (fss, rhss_floats, new_pairs) ->
174     let rhs_floats = concat rhss_floats in
175
176     if not (isTopLvl bind_dest_lvl) then
177         -- Find which bindings float out at least one lambda beyond this one
178         -- These ones can't mention the binders, because they couldn't 
179         -- be escaping a major level if so.
180         -- The ones that are not going further can join the letrec;
181         -- they may not be mutually recursive but the occurrence analyser will
182         -- find that out.
183         case (partitionByMajorLevel bind_dest_lvl rhs_floats) of { (floats', heres) ->
184         (sum_stats fss, floats' ++ [(bind_dest_lvl, Rec (floatsToBindPairs heres ++ new_pairs))]) }
185     else
186         -- In a recursive binding, *destined for* the top level
187         -- (only), the rhs floats may contain references to the 
188         -- bound things.  For example
189         --      f = ...(let v = ...f... in b) ...
190         --  might get floated to
191         --      v = ...f...
192         --      f = ... b ...
193         -- and hence we must (pessimistically) make all the floats recursive
194         -- with the top binding.  Later dependency analysis will unravel it.
195         --
196         -- This can only happen for bindings destined for the top level,
197         -- because only then will partitionByMajorLevel allow through a binding
198         -- that only differs in its minor level
199         (sum_stats fss, [(bind_dest_lvl, Rec (new_pairs ++ floatsToBindPairs rhs_floats))])
200     }
201   where
202     bind_dest_lvl = getBindLevel bind
203
204     do_pair (TB name level, rhs)
205       = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
206         (fs, rhs_floats, (name, rhs'))
207         }
208 \end{code}
209
210 %************************************************************************
211
212 \subsection[FloatOut-Expr]{Floating in expressions}
213 %*                                                                      *
214 %************************************************************************
215
216 \begin{code}
217 floatExpr, floatRhs, floatCaseAlt
218          :: Level
219          -> LevelledExpr
220          -> (FloatStats, FloatBinds, CoreExpr)
221
222 floatCaseAlt lvl arg    -- Used rec rhss, and case-alternative rhss
223   = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
224     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
225         -- Dump bindings that aren't going to escape from a lambda;
226         -- in particular, we must dump the ones that are bound by 
227         -- the rec or case alternative
228     (fsa, floats', install heres arg') }}
229
230 floatRhs lvl arg        -- Used for nested non-rec rhss, and fn args
231                         -- See Note [Floating out of RHS]
232   = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
233     if exprIsCheap arg' then    
234         (fsa, floats, arg')
235     else
236     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
237     (fsa, floats', install heres arg') }}
238
239 -- Note [Floating out of RHSs]
240 -- ~~~~~~~~~~~~~~~~~~~~~~~~~~~
241 -- Dump bindings that aren't going to escape from a lambda
242 -- This isn't a scoping issue (the binder isn't in scope in the RHS 
243 --      of a non-rec binding)
244 -- Rather, it is to avoid floating the x binding out of
245 --      f (let x = e in b)
246 -- unnecessarily.  But we first test for values or trival rhss,
247 -- because (in particular) we don't want to insert new bindings between
248 -- the "=" and the "\".  E.g.
249 --      f = \x -> let <bind> in <body>
250 -- We do not want
251 --      f = let <bind> in \x -> <body>
252 -- (a) The simplifier will immediately float it further out, so we may
253 --      as well do so right now; in general, keeping rhss as manifest 
254 --      values is good
255 -- (b) If a float-in pass follows immediately, it might add yet more
256 --      bindings just after the '='.  And some of them might (correctly)
257 --      be strict even though the 'let f' is lazy, because f, being a value,
258 --      gets its demand-info zapped by the simplifier.
259 --
260 -- We use exprIsCheap because that is also what's used by the simplifier
261 -- to decide whether to float a let out of a let
262
263 floatExpr _ (Var v)   = (zeroStats, [], Var v)
264 floatExpr _ (Type ty) = (zeroStats, [], Type ty)
265 floatExpr _ (Lit lit) = (zeroStats, [], Lit lit)
266           
267 floatExpr lvl (App e a)
268   = case (floatExpr      lvl e) of { (fse, floats_e, e') ->
269     case (floatRhs lvl a)       of { (fsa, floats_a, a') ->
270     (fse `add_stats` fsa, floats_e ++ floats_a, App e' a') }}
271
272 floatExpr lvl lam@(Lam _ _)
273   = let
274         (bndrs_w_lvls, body) = collectBinders lam
275         bndrs                = [b | TB b _ <- bndrs_w_lvls]
276         lvls                 = [l | TB b l <- bndrs_w_lvls]
277
278         -- For the all-tyvar case we are prepared to pull 
279         -- the lets out, to implement the float-out-of-big-lambda
280         -- transform; but otherwise we only float bindings that are
281         -- going to escape a value lambda.
282         -- In particular, for one-shot lambdas we don't float things
283         -- out; we get no saving by so doing.
284         partition_fn | all isTyVar bndrs = partitionByLevel
285                      | otherwise         = partitionByMajorLevel
286     in
287     case (floatExpr (last lvls) body) of { (fs, floats, body') ->
288
289         -- Dump any bindings which absolutely cannot go any further
290     case (partition_fn (head lvls) floats)      of { (floats', heres) ->
291
292     (add_to_stats fs floats', floats', mkLams bndrs (install heres body'))
293     }}
294
295 floatExpr lvl (Note note@(SCC cc) expr)
296   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
297     let
298         -- Annotate bindings floated outwards past an scc expression
299         -- with the cc.  We mark that cc as "duplicated", though.
300
301         annotated_defns = annotate (dupifyCC cc) floating_defns
302     in
303     (fs, annotated_defns, Note note expr') }
304   where
305     annotate :: CostCentre -> FloatBinds -> FloatBinds
306
307     annotate dupd_cc defn_groups
308       = [ (level, ann_bind floater) | (level, floater) <- defn_groups ]
309       where
310         ann_bind (NonRec binder rhs)
311           = NonRec binder (mkSCC dupd_cc rhs)
312
313         ann_bind (Rec pairs)
314           = Rec [(binder, mkSCC dupd_cc rhs) | (binder, rhs) <- pairs]
315
316 floatExpr lvl (Note InlineMe expr)      -- Other than SCCs
317   = (zeroStats, [], Note InlineMe (unTag expr))
318         -- Do no floating at all inside INLINE. [_$_]
319         -- The SetLevels pass did not clone the bindings, so it's
320         -- unsafe to do any floating, even if we dump the results
321         -- inside the Note (which is what we used to do).
322
323 floatExpr lvl (Note note expr)  -- Other than SCCs
324   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
325     (fs, floating_defns, Note note expr') }
326
327 floatExpr lvl (Cast expr co)
328   = case (floatExpr lvl expr)   of { (fs, floating_defns, expr') ->
329     (fs, floating_defns, Cast expr' co) }
330
331 floatExpr lvl (Let (NonRec (TB bndr bndr_lvl) rhs) body)
332   | isUnLiftedType (idType bndr)        -- Treat unlifted lets just like a case
333                                 -- I.e. floatExpr for rhs, floatCaseAlt for body
334   = case floatExpr lvl rhs          of { (fs, rhs_floats, rhs') ->
335     case floatCaseAlt bndr_lvl body of { (fs, body_floats, body') ->
336     (fs, rhs_floats ++ body_floats, Let (NonRec bndr rhs') body') }}
337
338 floatExpr lvl (Let bind body)
339   = case (floatBind bind)     of { (fsb, bind_floats) ->
340     case (floatExpr lvl body) of { (fse, body_floats, body') ->
341     (add_stats fsb fse,
342      bind_floats ++ body_floats,
343      body')  }}
344
345 floatExpr lvl (Case scrut (TB case_bndr case_lvl) ty alts)
346   = case floatExpr lvl scrut    of { (fse, fde, scrut') ->
347     case floatList float_alt alts       of { (fsa, fda, alts')  ->
348     (add_stats fse fsa, fda ++ fde, Case scrut' case_bndr ty alts')
349     }}
350   where
351         -- Use floatCaseAlt for the alternatives, so that we
352         -- don't gratuitiously float bindings out of the RHSs
353     float_alt (con, bs, rhs)
354         = case (floatCaseAlt case_lvl rhs)      of { (fs, rhs_floats, rhs') ->
355           (fs, rhs_floats, (con, [b | TB b _ <- bs], rhs')) }
356
357
358 floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
359 floatList f [] = (zeroStats, [], [])
360 floatList f (a:as) = case f a            of { (fs_a,  binds_a,  b)  ->
361                      case floatList f as of { (fs_as, binds_as, bs) ->
362                      (fs_a `add_stats` fs_as, binds_a ++ binds_as, b:bs) }}
363
364 unTagBndr :: TaggedBndr tag -> CoreBndr
365 unTagBndr (TB b _) = b
366
367 unTag :: TaggedExpr tag -> CoreExpr
368 unTag (Var v)     = Var v
369 unTag (Lit l)     = Lit l
370 unTag (Type ty)   = Type ty
371 unTag (Note n e)  = Note n (unTag e)
372 unTag (App e1 e2) = App (unTag e1) (unTag e2)
373 unTag (Lam b e)   = Lam (unTagBndr b) (unTag e)
374 unTag (Cast e co) = Cast (unTag e) co
375 unTag (Let (Rec prs) e)    = Let (Rec [(unTagBndr b,unTag r) | (b, r) <- prs]) (unTag e)
376 unTag (Let (NonRec b r) e) = Let (NonRec (unTagBndr b) (unTag r)) (unTag e)
377 unTag (Case e b ty alts)   = Case (unTag e) (unTagBndr b) ty
378                                   [(c, map unTagBndr bs, unTag r) | (c,bs,r) <- alts]
379 \end{code}
380
381 %************************************************************************
382 %*                                                                      *
383 \subsection{Utility bits for floating stats}
384 %*                                                                      *
385 %************************************************************************
386
387 I didn't implement this with unboxed numbers.  I don't want to be too
388 strict in this stuff, as it is rarely turned on.  (WDP 95/09)
389
390 \begin{code}
391 data FloatStats
392   = FlS Int  -- Number of top-floats * lambda groups they've been past
393         Int  -- Number of non-top-floats * lambda groups they've been past
394         Int  -- Number of lambda (groups) seen
395
396 get_stats (FlS a b c) = (a, b, c)
397
398 zeroStats = FlS 0 0 0
399
400 sum_stats xs = foldr add_stats zeroStats xs
401
402 add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
403   = FlS (a1 + a2) (b1 + b2) (c1 + c2)
404
405 add_to_stats (FlS a b c) floats
406   = FlS (a + length top_floats) (b + length other_floats) (c + 1)
407   where
408     (top_floats, other_floats) = partition to_very_top floats
409
410     to_very_top (my_lvl, _) = isTopLvl my_lvl
411 \end{code}
412
413
414 %************************************************************************
415 %*                                                                      *
416 \subsection{Utility bits for floating}
417 %*                                                                      *
418 %************************************************************************
419
420 \begin{code}
421 getBindLevel (NonRec (TB _ lvl) _)      = lvl
422 getBindLevel (Rec (((TB _ lvl), _) : _)) = lvl
423 \end{code}
424
425 \begin{code}
426 partitionByMajorLevel, partitionByLevel
427         :: Level                -- Partitioning level
428
429         -> FloatBinds           -- Defns to be divided into 2 piles...
430
431         -> (FloatBinds, -- Defns  with level strictly < partition level,
432             FloatBinds) -- The rest
433
434
435 partitionByMajorLevel ctxt_lvl defns
436   = partition float_further defns
437   where
438         -- Float it if we escape a value lambda, or if we get to the top level
439     float_further (my_lvl, bind) = my_lvl `ltMajLvl` ctxt_lvl || isTopLvl my_lvl
440         -- The isTopLvl part says that if we can get to the top level, say "yes" anyway
441         -- This means that 
442         --      x = f e
443         -- transforms to 
444         --    lvl = e
445         --    x = f lvl
446         -- which is as it should be
447
448 partitionByLevel ctxt_lvl defns
449   = partition float_further defns
450   where
451     float_further (my_lvl, _) = my_lvl `ltLvl` ctxt_lvl
452 \end{code}
453
454 \begin{code}
455 floatsToBinds :: FloatBinds -> [CoreBind]
456 floatsToBinds floats = map snd floats
457
458 floatsToBindPairs :: FloatBinds -> [(Id,CoreExpr)]
459
460 floatsToBindPairs floats = concat (map mk_pairs floats)
461   where
462    mk_pairs (_, Rec pairs)         = pairs
463    mk_pairs (_, NonRec binder rhs) = [(binder,rhs)]
464
465 install :: FloatBinds -> CoreExpr -> CoreExpr
466
467 install defn_groups expr
468   = foldr install_group expr defn_groups
469   where
470     install_group (_, defns) body = Let defns body
471 \end{code}