da59e91403bebb68aa7ef15b3d5784b201c3c220
[ghc-hetmet.git] / compiler / simplCore / FloatOut.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
3 %
4 \section[FloatOut]{Float bindings outwards (towards the top level)}
5
6 ``Long-distance'' floating of bindings towards the top level.
7
8 \begin{code}
9 {-# OPTIONS -w #-}
10 -- The above warning supression flag is a temporary kludge.
11 -- While working on this module you are encouraged to remove it and fix
12 -- any warnings in the module. See
13 --     http://hackage.haskell.org/trac/ghc/wiki/CodingStyle#Warnings
14 -- for details
15
16 module FloatOut ( floatOutwards ) where
17
18 #include "HsVersions.h"
19
20 import CoreSyn
21 import CoreUtils
22
23 import DynFlags ( DynFlags, DynFlag(..), FloatOutSwitches(..) )
24 import ErrUtils         ( dumpIfSet_dyn )
25 import CostCentre       ( dupifyCC, CostCentre )
26 import Id               ( Id, idType )
27 import Type             ( isUnLiftedType )
28 import CoreLint         ( showPass, endPass )
29 import SetLevels        ( Level(..), LevelledExpr, LevelledBind,
30                           setLevels, ltMajLvl, ltLvl, isTopLvl )
31 import UniqSupply       ( UniqSupply )
32 import List             ( partition )
33 import Outputable
34 \end{code}
35
36         -----------------
37         Overall game plan
38         -----------------
39
40 The Big Main Idea is:
41
42         To float out sub-expressions that can thereby get outside
43         a non-one-shot value lambda, and hence may be shared.
44
45
46 To achieve this we may need to do two thing:
47
48    a) Let-bind the sub-expression:
49
50         f (g x)  ==>  let lvl = f (g x) in lvl
51
52       Now we can float the binding for 'lvl'.  
53
54    b) More than that, we may need to abstract wrt a type variable
55
56         \x -> ... /\a -> let v = ...a... in ....
57
58       Here the binding for v mentions 'a' but not 'x'.  So we
59       abstract wrt 'a', to give this binding for 'v':
60
61             vp = /\a -> ...a...
62             v  = vp a
63
64       Now the binding for vp can float out unimpeded.
65       I can't remember why this case seemed important enough to
66       deal with, but I certainly found cases where important floats
67       didn't happen if we did not abstract wrt tyvars.
68
69 With this in mind we can also achieve another goal: lambda lifting.
70 We can make an arbitrary (function) binding float to top level by
71 abstracting wrt *all* local variables, not just type variables, leaving
72 a binding that can be floated right to top level.  Whether or not this
73 happens is controlled by a flag.
74
75
76 Random comments
77 ~~~~~~~~~~~~~~~
78
79 At the moment we never float a binding out to between two adjacent
80 lambdas.  For example:
81
82 @
83         \x y -> let t = x+x in ...
84 ===>
85         \x -> let t = x+x in \y -> ...
86 @
87 Reason: this is less efficient in the case where the original lambda
88 is never partially applied.
89
90 But there's a case I've seen where this might not be true.  Consider:
91 @
92 elEm2 x ys
93   = elem' x ys
94   where
95     elem' _ []  = False
96     elem' x (y:ys)      = x==y || elem' x ys
97 @
98 It turns out that this generates a subexpression of the form
99 @
100         \deq x ys -> let eq = eqFromEqDict deq in ...
101 @
102 vwhich might usefully be separated to
103 @
104         \deq -> let eq = eqFromEqDict deq in \xy -> ...
105 @
106 Well, maybe.  We don't do this at the moment.
107
108 \begin{code}
109 type FloatBind     = (Level, CoreBind)  -- INVARIANT: a FloatBind is always lifted
110 type FloatBinds    = [FloatBind]        
111 \end{code}
112
113 %************************************************************************
114 %*                                                                      *
115 \subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
116 %*                                                                      *
117 %************************************************************************
118
119 \begin{code}
120 floatOutwards :: FloatOutSwitches
121               -> DynFlags
122               -> UniqSupply 
123               -> [CoreBind] -> IO [CoreBind]
124
125 floatOutwards float_sws dflags us pgm
126   = do {
127         showPass dflags float_msg ;
128
129         let { annotated_w_levels = setLevels float_sws pgm us ;
130               (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)
131             } ;
132
133         dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
134                   (vcat (map ppr annotated_w_levels));
135
136         let { (tlets, ntlets, lams) = get_stats (sum_stats fss) };
137
138         dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:"
139                 (hcat [ int tlets,  ptext SLIT(" Lets floated to top level; "),
140                         int ntlets, ptext SLIT(" Lets floated elsewhere; from "),
141                         int lams,   ptext SLIT(" Lambda groups")]);
142
143         endPass dflags float_msg  Opt_D_verbose_core2core (concat binds_s')
144                         {- no specific flag for dumping float-out -} 
145     }
146   where
147     float_msg = showSDoc (text "Float out" <+> parens (sws float_sws))
148     sws (FloatOutSw lam const) = pp_not lam   <+> text "lambdas" <> comma <+>
149                                  pp_not const <+> text "constants"
150     pp_not True  = empty
151     pp_not False = text "not"
152
153 floatTopBind bind
154   = case (floatBind bind) of { (fs, floats) ->
155     (fs, floatsToBinds floats)
156     }
157 \end{code}
158
159 %************************************************************************
160 %*                                                                      *
161 \subsection[FloatOut-Bind]{Floating in a binding (the business end)}
162 %*                                                                      *
163 %************************************************************************
164
165
166 \begin{code}
167 floatBind :: LevelledBind -> (FloatStats, FloatBinds)
168
169 floatBind (NonRec (TB name level) rhs)
170   = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
171     (fs, rhs_floats ++ [(level, NonRec name rhs')]) }
172
173 floatBind bind@(Rec pairs)
174   = case (unzip3 (map do_pair pairs)) of { (fss, rhss_floats, new_pairs) ->
175     let rhs_floats = concat rhss_floats in
176
177     if not (isTopLvl bind_dest_lvl) then
178         -- Find which bindings float out at least one lambda beyond this one
179         -- These ones can't mention the binders, because they couldn't 
180         -- be escaping a major level if so.
181         -- The ones that are not going further can join the letrec;
182         -- they may not be mutually recursive but the occurrence analyser will
183         -- find that out.
184         case (partitionByMajorLevel bind_dest_lvl rhs_floats) of { (floats', heres) ->
185         (sum_stats fss, floats' ++ [(bind_dest_lvl, Rec (floatsToBindPairs heres ++ new_pairs))]) }
186     else
187         -- In a recursive binding, *destined for* the top level
188         -- (only), the rhs floats may contain references to the 
189         -- bound things.  For example
190         --      f = ...(let v = ...f... in b) ...
191         --  might get floated to
192         --      v = ...f...
193         --      f = ... b ...
194         -- and hence we must (pessimistically) make all the floats recursive
195         -- with the top binding.  Later dependency analysis will unravel it.
196         --
197         -- This can only happen for bindings destined for the top level,
198         -- because only then will partitionByMajorLevel allow through a binding
199         -- that only differs in its minor level
200         (sum_stats fss, [(bind_dest_lvl, Rec (new_pairs ++ floatsToBindPairs rhs_floats))])
201     }
202   where
203     bind_dest_lvl = getBindLevel bind
204
205     do_pair (TB name level, rhs)
206       = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
207         (fs, rhs_floats, (name, rhs'))
208         }
209 \end{code}
210
211 %************************************************************************
212
213 \subsection[FloatOut-Expr]{Floating in expressions}
214 %*                                                                      *
215 %************************************************************************
216
217 \begin{code}
218 floatExpr, floatRhs, floatCaseAlt
219          :: Level
220          -> LevelledExpr
221          -> (FloatStats, FloatBinds, CoreExpr)
222
223 floatCaseAlt lvl arg    -- Used rec rhss, and case-alternative rhss
224   = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
225     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
226         -- Dump bindings that aren't going to escape from a lambda;
227         -- in particular, we must dump the ones that are bound by 
228         -- the rec or case alternative
229     (fsa, floats', install heres arg') }}
230
231 floatRhs lvl arg        -- Used for nested non-rec rhss, and fn args
232                         -- See Note [Floating out of RHS]
233   = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
234     if exprIsCheap arg' then    
235         (fsa, floats, arg')
236     else
237     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
238     (fsa, floats', install heres arg') }}
239
240 -- Note [Floating out of RHSs]
241 -- ~~~~~~~~~~~~~~~~~~~~~~~~~~~
242 -- Dump bindings that aren't going to escape from a lambda
243 -- This isn't a scoping issue (the binder isn't in scope in the RHS 
244 --      of a non-rec binding)
245 -- Rather, it is to avoid floating the x binding out of
246 --      f (let x = e in b)
247 -- unnecessarily.  But we first test for values or trival rhss,
248 -- because (in particular) we don't want to insert new bindings between
249 -- the "=" and the "\".  E.g.
250 --      f = \x -> let <bind> in <body>
251 -- We do not want
252 --      f = let <bind> in \x -> <body>
253 -- (a) The simplifier will immediately float it further out, so we may
254 --      as well do so right now; in general, keeping rhss as manifest 
255 --      values is good
256 -- (b) If a float-in pass follows immediately, it might add yet more
257 --      bindings just after the '='.  And some of them might (correctly)
258 --      be strict even though the 'let f' is lazy, because f, being a value,
259 --      gets its demand-info zapped by the simplifier.
260 --
261 -- We use exprIsCheap because that is also what's used by the simplifier
262 -- to decide whether to float a let out of a let
263
264 floatExpr _ (Var v)   = (zeroStats, [], Var v)
265 floatExpr _ (Type ty) = (zeroStats, [], Type ty)
266 floatExpr _ (Lit lit) = (zeroStats, [], Lit lit)
267           
268 floatExpr lvl (App e a)
269   = case (floatExpr      lvl e) of { (fse, floats_e, e') ->
270     case (floatRhs lvl a)       of { (fsa, floats_a, a') ->
271     (fse `add_stats` fsa, floats_e ++ floats_a, App e' a') }}
272
273 floatExpr lvl lam@(Lam _ _)
274   = let
275         (bndrs_w_lvls, body) = collectBinders lam
276         bndrs                = [b | TB b _ <- bndrs_w_lvls]
277         lvls                 = [l | TB b l <- bndrs_w_lvls]
278
279         -- For the all-tyvar case we are prepared to pull 
280         -- the lets out, to implement the float-out-of-big-lambda
281         -- transform; but otherwise we only float bindings that are
282         -- going to escape a value lambda.
283         -- In particular, for one-shot lambdas we don't float things
284         -- out; we get no saving by so doing.
285         partition_fn | all isTyVar bndrs = partitionByLevel
286                      | otherwise         = partitionByMajorLevel
287     in
288     case (floatExpr (last lvls) body) of { (fs, floats, body') ->
289
290         -- Dump any bindings which absolutely cannot go any further
291     case (partition_fn (head lvls) floats)      of { (floats', heres) ->
292
293     (add_to_stats fs floats', floats', mkLams bndrs (install heres body'))
294     }}
295
296 floatExpr lvl (Note note@(SCC cc) expr)
297   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
298     let
299         -- Annotate bindings floated outwards past an scc expression
300         -- with the cc.  We mark that cc as "duplicated", though.
301
302         annotated_defns = annotate (dupifyCC cc) floating_defns
303     in
304     (fs, annotated_defns, Note note expr') }
305   where
306     annotate :: CostCentre -> FloatBinds -> FloatBinds
307
308     annotate dupd_cc defn_groups
309       = [ (level, ann_bind floater) | (level, floater) <- defn_groups ]
310       where
311         ann_bind (NonRec binder rhs)
312           = NonRec binder (mkSCC dupd_cc rhs)
313
314         ann_bind (Rec pairs)
315           = Rec [(binder, mkSCC dupd_cc rhs) | (binder, rhs) <- pairs]
316
317 floatExpr lvl (Note InlineMe expr)      -- Other than SCCs
318   = case floatExpr InlineCtxt expr of { (fs, floating_defns, expr') ->
319         -- There can be some floating_defns, arising from
320         -- ordinary lets that were there all the time.  It seems
321         -- more efficient to test once here than to avoid putting
322         -- them into floating_defns (which would mean testing for
323         -- inlineCtxt  at every let)
324     (fs, [], Note InlineMe (install floating_defns expr')) }    -- See notes in SetLevels
325
326 floatExpr lvl (Note note expr)  -- Other than SCCs
327   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
328     (fs, floating_defns, Note note expr') }
329
330 floatExpr lvl (Cast expr co)
331   = case (floatExpr lvl expr)   of { (fs, floating_defns, expr') ->
332     (fs, floating_defns, Cast expr' co) }
333
334 floatExpr lvl (Let (NonRec (TB bndr bndr_lvl) rhs) body)
335   | isUnLiftedType (idType bndr)        -- Treat unlifted lets just like a case
336                                 -- I.e. floatExpr for rhs, floatCaseAlt for body
337   = case floatExpr lvl rhs          of { (fs, rhs_floats, rhs') ->
338     case floatCaseAlt bndr_lvl body of { (fs, body_floats, body') ->
339     (fs, rhs_floats ++ body_floats, Let (NonRec bndr rhs') body') }}
340
341 floatExpr lvl (Let bind body)
342   = case (floatBind bind)     of { (fsb, bind_floats) ->
343     case (floatExpr lvl body) of { (fse, body_floats, body') ->
344     (add_stats fsb fse,
345      bind_floats ++ body_floats,
346      body')  }}
347
348 floatExpr lvl (Case scrut (TB case_bndr case_lvl) ty alts)
349   = case floatExpr lvl scrut    of { (fse, fde, scrut') ->
350     case floatList float_alt alts       of { (fsa, fda, alts')  ->
351     (add_stats fse fsa, fda ++ fde, Case scrut' case_bndr ty alts')
352     }}
353   where
354         -- Use floatCaseAlt for the alternatives, so that we
355         -- don't gratuitiously float bindings out of the RHSs
356     float_alt (con, bs, rhs)
357         = case (floatCaseAlt case_lvl rhs)      of { (fs, rhs_floats, rhs') ->
358           (fs, rhs_floats, (con, [b | TB b _ <- bs], rhs')) }
359
360
361 floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
362 floatList f [] = (zeroStats, [], [])
363 floatList f (a:as) = case f a            of { (fs_a,  binds_a,  b)  ->
364                      case floatList f as of { (fs_as, binds_as, bs) ->
365                      (fs_a `add_stats` fs_as, binds_a ++ binds_as, b:bs) }}
366 \end{code}
367
368 %************************************************************************
369 %*                                                                      *
370 \subsection{Utility bits for floating stats}
371 %*                                                                      *
372 %************************************************************************
373
374 I didn't implement this with unboxed numbers.  I don't want to be too
375 strict in this stuff, as it is rarely turned on.  (WDP 95/09)
376
377 \begin{code}
378 data FloatStats
379   = FlS Int  -- Number of top-floats * lambda groups they've been past
380         Int  -- Number of non-top-floats * lambda groups they've been past
381         Int  -- Number of lambda (groups) seen
382
383 get_stats (FlS a b c) = (a, b, c)
384
385 zeroStats = FlS 0 0 0
386
387 sum_stats xs = foldr add_stats zeroStats xs
388
389 add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
390   = FlS (a1 + a2) (b1 + b2) (c1 + c2)
391
392 add_to_stats (FlS a b c) floats
393   = FlS (a + length top_floats) (b + length other_floats) (c + 1)
394   where
395     (top_floats, other_floats) = partition to_very_top floats
396
397     to_very_top (my_lvl, _) = isTopLvl my_lvl
398 \end{code}
399
400
401 %************************************************************************
402 %*                                                                      *
403 \subsection{Utility bits for floating}
404 %*                                                                      *
405 %************************************************************************
406
407 \begin{code}
408 getBindLevel (NonRec (TB _ lvl) _)      = lvl
409 getBindLevel (Rec (((TB _ lvl), _) : _)) = lvl
410 \end{code}
411
412 \begin{code}
413 partitionByMajorLevel, partitionByLevel
414         :: Level                -- Partitioning level
415
416         -> FloatBinds           -- Defns to be divided into 2 piles...
417
418         -> (FloatBinds, -- Defns  with level strictly < partition level,
419             FloatBinds) -- The rest
420
421
422 partitionByMajorLevel ctxt_lvl defns
423   = partition float_further defns
424   where
425         -- Float it if we escape a value lambda, or if we get to the top level
426     float_further (my_lvl, bind) = my_lvl `ltMajLvl` ctxt_lvl || isTopLvl my_lvl
427         -- The isTopLvl part says that if we can get to the top level, say "yes" anyway
428         -- This means that 
429         --      x = f e
430         -- transforms to 
431         --    lvl = e
432         --    x = f lvl
433         -- which is as it should be
434
435 partitionByLevel ctxt_lvl defns
436   = partition float_further defns
437   where
438     float_further (my_lvl, _) = my_lvl `ltLvl` ctxt_lvl
439 \end{code}
440
441 \begin{code}
442 floatsToBinds :: FloatBinds -> [CoreBind]
443 floatsToBinds floats = map snd floats
444
445 floatsToBindPairs :: FloatBinds -> [(Id,CoreExpr)]
446
447 floatsToBindPairs floats = concat (map mk_pairs floats)
448   where
449    mk_pairs (_, Rec pairs)         = pairs
450    mk_pairs (_, NonRec binder rhs) = [(binder,rhs)]
451
452 install :: FloatBinds -> CoreExpr -> CoreExpr
453
454 install defn_groups expr
455   = foldr install_group expr defn_groups
456   where
457     install_group (_, defns) body = Let defns body
458 \end{code}