remove empty dir
[ghc-hetmet.git] / ghc / compiler / simplCore / FloatOut.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
3 %
4 \section[FloatOut]{Float bindings outwards (towards the top level)}
5
6 ``Long-distance'' floating of bindings towards the top level.
7
8 \begin{code}
9 module FloatOut ( floatOutwards ) where
10
11 #include "HsVersions.h"
12
13 import CoreSyn
14 import CoreUtils        ( mkSCC, exprIsHNF, exprIsTrivial )
15
16 import DynFlags ( DynFlags, DynFlag(..), FloatOutSwitches(..) )
17 import ErrUtils         ( dumpIfSet_dyn )
18 import CostCentre       ( dupifyCC, CostCentre )
19 import Id               ( Id, idType )
20 import Type             ( isUnLiftedType )
21 import CoreLint         ( showPass, endPass )
22 import SetLevels        ( Level(..), LevelledExpr, LevelledBind,
23                           setLevels, ltMajLvl, ltLvl, isTopLvl )
24 import UniqSupply       ( UniqSupply )
25 import List             ( partition )
26 import Outputable
27 import Util             ( notNull )
28 \end{code}
29
30         -----------------
31         Overall game plan
32         -----------------
33
34 The Big Main Idea is:
35
36         To float out sub-expressions that can thereby get outside
37         a non-one-shot value lambda, and hence may be shared.
38
39
40 To achieve this we may need to do two thing:
41
42    a) Let-bind the sub-expression:
43
44         f (g x)  ==>  let lvl = f (g x) in lvl
45
46       Now we can float the binding for 'lvl'.  
47
48    b) More than that, we may need to abstract wrt a type variable
49
50         \x -> ... /\a -> let v = ...a... in ....
51
52       Here the binding for v mentions 'a' but not 'x'.  So we
53       abstract wrt 'a', to give this binding for 'v':
54
55             vp = /\a -> ...a...
56             v  = vp a
57
58       Now the binding for vp can float out unimpeded.
59       I can't remember why this case seemed important enough to
60       deal with, but I certainly found cases where important floats
61       didn't happen if we did not abstract wrt tyvars.
62
63 With this in mind we can also achieve another goal: lambda lifting.
64 We can make an arbitrary (function) binding float to top level by
65 abstracting wrt *all* local variables, not just type variables, leaving
66 a binding that can be floated right to top level.  Whether or not this
67 happens is controlled by a flag.
68
69
70 Random comments
71 ~~~~~~~~~~~~~~~
72
73 At the moment we never float a binding out to between two adjacent
74 lambdas.  For example:
75
76 @
77         \x y -> let t = x+x in ...
78 ===>
79         \x -> let t = x+x in \y -> ...
80 @
81 Reason: this is less efficient in the case where the original lambda
82 is never partially applied.
83
84 But there's a case I've seen where this might not be true.  Consider:
85 @
86 elEm2 x ys
87   = elem' x ys
88   where
89     elem' _ []  = False
90     elem' x (y:ys)      = x==y || elem' x ys
91 @
92 It turns out that this generates a subexpression of the form
93 @
94         \deq x ys -> let eq = eqFromEqDict deq in ...
95 @
96 vwhich might usefully be separated to
97 @
98         \deq -> let eq = eqFromEqDict deq in \xy -> ...
99 @
100 Well, maybe.  We don't do this at the moment.
101
102 \begin{code}
103 type FloatBind     = (Level, CoreBind)  -- INVARIANT: a FloatBind is always lifted
104 type FloatBinds    = [FloatBind]        
105 \end{code}
106
107 %************************************************************************
108 %*                                                                      *
109 \subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
110 %*                                                                      *
111 %************************************************************************
112
113 \begin{code}
114 floatOutwards :: FloatOutSwitches
115               -> DynFlags
116               -> UniqSupply 
117               -> [CoreBind] -> IO [CoreBind]
118
119 floatOutwards float_sws dflags us pgm
120   = do {
121         showPass dflags float_msg ;
122
123         let { annotated_w_levels = setLevels float_sws pgm us ;
124               (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)
125             } ;
126
127         dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
128                   (vcat (map ppr annotated_w_levels));
129
130         let { (tlets, ntlets, lams) = get_stats (sum_stats fss) };
131
132         dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:"
133                 (hcat [ int tlets,  ptext SLIT(" Lets floated to top level; "),
134                         int ntlets, ptext SLIT(" Lets floated elsewhere; from "),
135                         int lams,   ptext SLIT(" Lambda groups")]);
136
137         endPass dflags float_msg  Opt_D_verbose_core2core (concat binds_s')
138                         {- no specific flag for dumping float-out -} 
139     }
140   where
141     float_msg = showSDoc (text "Float out" <+> parens (sws float_sws))
142     sws (FloatOutSw lam const) = pp_not lam   <+> text "lambdas" <> comma <+>
143                                  pp_not const <+> text "constants"
144     pp_not True  = empty
145     pp_not False = text "not"
146
147 floatTopBind bind@(NonRec _ _)
148   = case (floatBind bind) of { (fs, floats, bind') ->
149     (fs, floatsToBinds floats ++ [bind'])
150     }
151
152 floatTopBind bind@(Rec _)
153   = case (floatBind bind) of { (fs, floats, Rec pairs') ->
154     WARN( notNull floats, ppr bind $$ ppr floats )
155     (fs, [Rec (floatsToBindPairs floats ++ pairs')]) }
156 \end{code}
157
158 %************************************************************************
159 %*                                                                      *
160 \subsection[FloatOut-Bind]{Floating in a binding (the business end)}
161 %*                                                                      *
162 %************************************************************************
163
164
165 \begin{code}
166 floatBind :: LevelledBind
167           -> (FloatStats, FloatBinds, CoreBind)
168
169 floatBind (NonRec (TB name level) rhs)
170   = case (floatNonRecRhs level rhs) of { (fs, rhs_floats, rhs') ->
171     (fs, rhs_floats, NonRec name rhs') }
172
173 floatBind bind@(Rec pairs)
174   = case (unzip3 (map do_pair pairs)) of { (fss, rhss_floats, new_pairs) ->
175
176     if not (isTopLvl bind_dest_level) then
177         -- Standard case; the floated bindings can't mention the
178         -- binders, because they couldn't be escaping a major level
179         -- if so.
180         (sum_stats fss, concat rhss_floats, Rec new_pairs)
181     else
182         -- In a recursive binding, *destined for* the top level
183         -- (only), the rhs floats may contain references to the 
184         -- bound things.  For example
185         --      f = ...(let v = ...f... in b) ...
186         --  might get floated to
187         --      v = ...f...
188         --      f = ... b ...
189         -- and hence we must (pessimistically) make all the floats recursive
190         -- with the top binding.  Later dependency analysis will unravel it.
191         --
192         -- This can only happen for bindings destined for the top level,
193         -- because only then will partitionByMajorLevel allow through a binding
194         -- that only differs in its minor level
195         (sum_stats fss, [],
196          Rec (new_pairs ++ floatsToBindPairs (concat rhss_floats)))
197     }
198   where
199     bind_dest_level = getBindLevel bind
200
201     do_pair (TB name level, rhs)
202       = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
203         (fs, rhs_floats, (name, rhs'))
204         }
205 \end{code}
206
207 %************************************************************************
208
209 \subsection[FloatOut-Expr]{Floating in expressions}
210 %*                                                                      *
211 %************************************************************************
212
213 \begin{code}
214 floatExpr, floatRhs, floatNonRecRhs
215          :: Level
216          -> LevelledExpr
217          -> (FloatStats, FloatBinds, CoreExpr)
218
219 floatRhs lvl arg        -- Used rec rhss, and case-alternative rhss
220   = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
221     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
222         -- Dump bindings that aren't going to escape from a lambda;
223         -- in particular, we must dump the ones that are bound by 
224         -- the rec or case alternative
225     (fsa, floats', install heres arg') }}
226
227 floatNonRecRhs lvl arg  -- Used for nested non-rec rhss, and fn args
228   = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
229         -- Dump bindings that aren't going to escape from a lambda
230         -- This isn't a scoping issue (the binder isn't in scope in the RHS of a non-rec binding)
231         -- Rather, it is to avoid floating the x binding out of
232         --      f (let x = e in b)
233         -- unnecessarily.  But we first test for values or trival rhss,
234         -- because (in particular) we don't want to insert new bindings between
235         -- the "=" and the "\".  E.g.
236         --      f = \x -> let <bind> in <body>
237         -- We do not want
238         --      f = let <bind> in \x -> <body>
239         -- (a) The simplifier will immediately float it further out, so we may
240         --      as well do so right now; in general, keeping rhss as manifest 
241         --      values is good
242         -- (b) If a float-in pass follows immediately, it might add yet more
243         --      bindings just after the '='.  And some of them might (correctly)
244         --      be strict even though the 'let f' is lazy, because f, being a value,
245         --      gets its demand-info zapped by the simplifier.
246     if exprIsHNF arg' || exprIsTrivial arg' then
247         (fsa, floats, arg')
248     else
249     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
250     (fsa, floats', install heres arg') }}
251
252 floatExpr _ (Var v)   = (zeroStats, [], Var v)
253 floatExpr _ (Type ty) = (zeroStats, [], Type ty)
254 floatExpr _ (Lit lit) = (zeroStats, [], Lit lit)
255           
256 floatExpr lvl (App e a)
257   = case (floatExpr      lvl e) of { (fse, floats_e, e') ->
258     case (floatNonRecRhs lvl a) of { (fsa, floats_a, a') ->
259     (fse `add_stats` fsa, floats_e ++ floats_a, App e' a') }}
260
261 floatExpr lvl lam@(Lam _ _)
262   = let
263         (bndrs_w_lvls, body) = collectBinders lam
264         bndrs                = [b | TB b _ <- bndrs_w_lvls]
265         lvls                 = [l | TB b l <- bndrs_w_lvls]
266
267         -- For the all-tyvar case we are prepared to pull 
268         -- the lets out, to implement the float-out-of-big-lambda
269         -- transform; but otherwise we only float bindings that are
270         -- going to escape a value lambda.
271         -- In particular, for one-shot lambdas we don't float things
272         -- out; we get no saving by so doing.
273         partition_fn | all isTyVar bndrs = partitionByLevel
274                      | otherwise         = partitionByMajorLevel
275     in
276     case (floatExpr (last lvls) body) of { (fs, floats, body') ->
277
278         -- Dump any bindings which absolutely cannot go any further
279     case (partition_fn (head lvls) floats)      of { (floats', heres) ->
280
281     (add_to_stats fs floats', floats', mkLams bndrs (install heres body'))
282     }}
283
284 floatExpr lvl (Note note@(SCC cc) expr)
285   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
286     let
287         -- Annotate bindings floated outwards past an scc expression
288         -- with the cc.  We mark that cc as "duplicated", though.
289
290         annotated_defns = annotate (dupifyCC cc) floating_defns
291     in
292     (fs, annotated_defns, Note note expr') }
293   where
294     annotate :: CostCentre -> FloatBinds -> FloatBinds
295
296     annotate dupd_cc defn_groups
297       = [ (level, ann_bind floater) | (level, floater) <- defn_groups ]
298       where
299         ann_bind (NonRec binder rhs)
300           = NonRec binder (mkSCC dupd_cc rhs)
301
302         ann_bind (Rec pairs)
303           = Rec [(binder, mkSCC dupd_cc rhs) | (binder, rhs) <- pairs]
304
305 floatExpr lvl (Note InlineMe expr)      -- Other than SCCs
306   = case floatExpr InlineCtxt expr of { (fs, floating_defns, expr') ->
307         -- There can be some floating_defns, arising from
308         -- ordinary lets that were there all the time.  It seems
309         -- more efficient to test once here than to avoid putting
310         -- them into floating_defns (which would mean testing for
311         -- inlineCtxt  at every let)
312     (fs, [], Note InlineMe (install floating_defns expr')) }    -- See notes in SetLevels
313
314 floatExpr lvl (Note note expr)  -- Other than SCCs
315   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
316     (fs, floating_defns, Note note expr') }
317
318 floatExpr lvl (Let (NonRec (TB bndr bndr_lvl) rhs) body)
319   | isUnLiftedType (idType bndr)        -- Treat unlifted lets just like a case
320   = case floatExpr lvl rhs      of { (fs, rhs_floats, rhs') ->
321     case floatRhs bndr_lvl body of { (fs, body_floats, body') ->
322     (fs, rhs_floats ++ body_floats, Let (NonRec bndr rhs') body') }}
323
324 floatExpr lvl (Let bind body)
325   = case (floatBind bind)     of { (fsb, rhs_floats,  bind') ->
326     case (floatExpr lvl body) of { (fse, body_floats, body') ->
327     (add_stats fsb fse,
328      rhs_floats ++ [(bind_lvl, bind')] ++ body_floats,
329      body')  }}
330   where
331     bind_lvl = getBindLevel bind
332
333 floatExpr lvl (Case scrut (TB case_bndr case_lvl) ty alts)
334   = case floatExpr lvl scrut    of { (fse, fde, scrut') ->
335     case floatList float_alt alts       of { (fsa, fda, alts')  ->
336     (add_stats fse fsa, fda ++ fde, Case scrut' case_bndr ty alts')
337     }}
338   where
339         -- Use floatRhs for the alternatives, so that we
340         -- don't gratuitiously float bindings out of the RHSs
341     float_alt (con, bs, rhs)
342         = case (floatRhs case_lvl rhs)  of { (fs, rhs_floats, rhs') ->
343           (fs, rhs_floats, (con, [b | TB b _ <- bs], rhs')) }
344
345
346 floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
347 floatList f [] = (zeroStats, [], [])
348 floatList f (a:as) = case f a            of { (fs_a,  binds_a,  b)  ->
349                      case floatList f as of { (fs_as, binds_as, bs) ->
350                      (fs_a `add_stats` fs_as, binds_a ++ binds_as, b:bs) }}
351 \end{code}
352
353 %************************************************************************
354 %*                                                                      *
355 \subsection{Utility bits for floating stats}
356 %*                                                                      *
357 %************************************************************************
358
359 I didn't implement this with unboxed numbers.  I don't want to be too
360 strict in this stuff, as it is rarely turned on.  (WDP 95/09)
361
362 \begin{code}
363 data FloatStats
364   = FlS Int  -- Number of top-floats * lambda groups they've been past
365         Int  -- Number of non-top-floats * lambda groups they've been past
366         Int  -- Number of lambda (groups) seen
367
368 get_stats (FlS a b c) = (a, b, c)
369
370 zeroStats = FlS 0 0 0
371
372 sum_stats xs = foldr add_stats zeroStats xs
373
374 add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
375   = FlS (a1 + a2) (b1 + b2) (c1 + c2)
376
377 add_to_stats (FlS a b c) floats
378   = FlS (a + length top_floats) (b + length other_floats) (c + 1)
379   where
380     (top_floats, other_floats) = partition to_very_top floats
381
382     to_very_top (my_lvl, _) = isTopLvl my_lvl
383 \end{code}
384
385
386 %************************************************************************
387 %*                                                                      *
388 \subsection{Utility bits for floating}
389 %*                                                                      *
390 %************************************************************************
391
392 \begin{code}
393 getBindLevel (NonRec (TB _ lvl) _)      = lvl
394 getBindLevel (Rec (((TB _ lvl), _) : _)) = lvl
395 \end{code}
396
397 \begin{code}
398 partitionByMajorLevel, partitionByLevel
399         :: Level                -- Partitioning level
400
401         -> FloatBinds           -- Defns to be divided into 2 piles...
402
403         -> (FloatBinds, -- Defns  with level strictly < partition level,
404             FloatBinds) -- The rest
405
406
407 partitionByMajorLevel ctxt_lvl defns
408   = partition float_further defns
409   where
410         -- Float it if we escape a value lambda, or if we get to the top level
411     float_further (my_lvl, bind) = my_lvl `ltMajLvl` ctxt_lvl || isTopLvl my_lvl
412         -- The isTopLvl part says that if we can get to the top level, say "yes" anyway
413         -- This means that 
414         --      x = f e
415         -- transforms to 
416         --    lvl = e
417         --    x = f lvl
418         -- which is as it should be
419
420 partitionByLevel ctxt_lvl defns
421   = partition float_further defns
422   where
423     float_further (my_lvl, _) = my_lvl `ltLvl` ctxt_lvl
424 \end{code}
425
426 \begin{code}
427 floatsToBinds :: FloatBinds -> [CoreBind]
428 floatsToBinds floats = map snd floats
429
430 floatsToBindPairs :: FloatBinds -> [(Id,CoreExpr)]
431
432 floatsToBindPairs floats = concat (map mk_pairs floats)
433   where
434    mk_pairs (_, Rec pairs)         = pairs
435    mk_pairs (_, NonRec binder rhs) = [(binder,rhs)]
436
437 install :: FloatBinds -> CoreExpr -> CoreExpr
438
439 install defn_groups expr
440   = foldr install_group expr defn_groups
441   where
442     install_group (_, defns) body = Let defns body
443 \end{code}