5e8282e0dca66fc556b38da63170a65e13151ab1
[ghc-hetmet.git] / ghc / compiler / simplCore / FloatOut.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
3 %
4 \section[FloatOut]{Float bindings outwards (towards the top level)}
5
6 ``Long-distance'' floating of bindings towards the top level.
7
8 \begin{code}
9 module FloatOut ( floatOutwards ) where
10
11 #include "HsVersions.h"
12
13 import CoreSyn
14 import CoreUtils        ( mkSCC )
15
16 import CmdLineOpts      ( DynFlags, DynFlag(..), FloatOutSwitches(..) )
17 import ErrUtils         ( dumpIfSet_dyn )
18 import CostCentre       ( dupifyCC, CostCentre )
19 import Id               ( Id )
20 import CoreLint         ( showPass, endPass )
21 import SetLevels        ( setLevels, Level(..), ltMajLvl, ltLvl, isTopLvl )
22 import UniqSupply       ( UniqSupply )
23 import List             ( partition )
24 import Outputable
25 \end{code}
26
27         -----------------
28         Overall game plan
29         -----------------
30
31 The Big Main Idea is:
32
33         To float out sub-expressions that can thereby get outside
34         a non-one-shot value lambda, and hence may be shared.
35
36
37 To achieve this we may need to do two thing:
38
39    a) Let-bind the sub-expression:
40
41         f (g x)  ==>  let lvl = f (g x) in lvl
42
43       Now we can float the binding for 'lvl'.  
44
45    b) More than that, we may need to abstract wrt a type variable
46
47         \x -> ... /\a -> let v = ...a... in ....
48
49       Here the binding for v mentions 'a' but not 'x'.  So we
50       abstract wrt 'a', to give this binding for 'v':
51
52             vp = /\a -> ...a...
53             v  = vp a
54
55       Now the binding for vp can float out unimpeded.
56       I can't remember why this case seemed important enough to
57       deal with, but I certainly found cases where important floats
58       didn't happen if we did not abstract wrt tyvars.
59
60 With this in mind we can also achieve another goal: lambda lifting.
61 We can make an arbitrary (function) binding float to top level by
62 abstracting wrt *all* local variables, not just type variables, leaving
63 a binding that can be floated right to top level.  Whether or not this
64 happens is controlled by a flag.
65
66
67 Random comments
68 ~~~~~~~~~~~~~~~
69
70 At the moment we never float a binding out to between two adjacent
71 lambdas.  For example:
72
73 @
74         \x y -> let t = x+x in ...
75 ===>
76         \x -> let t = x+x in \y -> ...
77 @
78 Reason: this is less efficient in the case where the original lambda
79 is never partially applied.
80
81 But there's a case I've seen where this might not be true.  Consider:
82 @
83 elEm2 x ys
84   = elem' x ys
85   where
86     elem' _ []  = False
87     elem' x (y:ys)      = x==y || elem' x ys
88 @
89 It turns out that this generates a subexpression of the form
90 @
91         \deq x ys -> let eq = eqFromEqDict deq in ...
92 @
93 vwhich might usefully be separated to
94 @
95         \deq -> let eq = eqFromEqDict deq in \xy -> ...
96 @
97 Well, maybe.  We don't do this at the moment.
98
99 \begin{code}
100 type LevelledExpr  = TaggedExpr Level
101 type LevelledBind  = TaggedBind Level
102 type FloatBind     = (Level, CoreBind)
103 type FloatBinds    = [FloatBind]
104 \end{code}
105
106 %************************************************************************
107 %*                                                                      *
108 \subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
109 %*                                                                      *
110 %************************************************************************
111
112 \begin{code}
113 floatOutwards :: DynFlags
114               -> FloatOutSwitches
115               -> UniqSupply 
116               -> [CoreBind] -> IO [CoreBind]
117
118 floatOutwards dflags float_sws us pgm
119   = do {
120         showPass dflags float_msg ;
121
122         let { annotated_w_levels = setLevels float_sws pgm us ;
123               (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)
124             } ;
125
126         dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
127                   (vcat (map ppr annotated_w_levels));
128
129         let { (tlets, ntlets, lams) = get_stats (sum_stats fss) };
130
131         dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:"
132                 (hcat [ int tlets,  ptext SLIT(" Lets floated to top level; "),
133                         int ntlets, ptext SLIT(" Lets floated elsewhere; from "),
134                         int lams,   ptext SLIT(" Lambda groups")]);
135
136         endPass dflags float_msg  Opt_D_verbose_core2core (concat binds_s')
137                         {- no specific flag for dumping float-out -} 
138     }
139   where
140     float_msg = showSDoc (text "Float out" <+> parens (sws float_sws))
141     sws (FloatOutSw lam const) = pp_not lam   <+> text "lambdas" <> comma <+>
142                                  pp_not const <+> text "constants"
143     pp_not True  = empty
144     pp_not False = text "not"
145
146 floatTopBind bind@(NonRec _ _)
147   = case (floatBind bind) of { (fs, floats, bind') ->
148     (fs, floatsToBinds floats ++ [bind'])
149     }
150
151 floatTopBind bind@(Rec _)
152   = case (floatBind bind) of { (fs, floats, Rec pairs') ->
153     WARN( not (null floats), ppr bind $$ ppr floats )
154     (fs, [Rec (floatsToBindPairs floats ++ pairs')]) }
155 \end{code}
156
157 %************************************************************************
158 %*                                                                      *
159 \subsection[FloatOut-Bind]{Floating in a binding (the business end)}
160 %*                                                                      *
161 %************************************************************************
162
163
164 \begin{code}
165 floatBind :: LevelledBind
166           -> (FloatStats, FloatBinds, CoreBind)
167
168 floatBind (NonRec (name,level) rhs)
169   = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
170     (fs, rhs_floats, NonRec name rhs') }
171
172 floatBind bind@(Rec pairs)
173   = case (unzip3 (map do_pair pairs)) of { (fss, rhss_floats, new_pairs) ->
174
175     if not (isTopLvl bind_level) then
176         -- Standard case
177         (sum_stats fss, concat rhss_floats, Rec new_pairs)
178     else
179         -- In a recursive binding, *destined for* the top level
180         -- (only), the rhs floats may contain references to the 
181         -- bound things.  For example
182         --
183         --      f = ...(let v = ...f... in b) ...
184         --
185         --  might get floated to
186         --
187         --      v = ...f...
188         --      f = ... b ...
189         --
190         -- and hence we must (pessimistically) make all the floats recursive
191         -- with the top binding.  Later dependency analysis will unravel it.
192         --
193         -- Can't happen on nested bindings because floatRhs will dump
194         -- the bindings in the RHS (partitionByMajorLevel treats top specially)
195         (sum_stats fss, [],
196          Rec (new_pairs ++ floatsToBindPairs (concat rhss_floats)))
197     }
198   where
199     bind_level = getBindLevel bind
200
201     do_pair ((name, level), rhs)
202       = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
203         (fs, rhs_floats, (name, rhs'))
204         }
205 \end{code}
206
207 %************************************************************************
208
209 \subsection[FloatOut-Expr]{Floating in expressions}
210 %*                                                                      *
211 %************************************************************************
212
213 \begin{code}
214 floatExpr, floatRhs
215          :: Level
216          -> LevelledExpr
217          -> (FloatStats, FloatBinds, CoreExpr)
218
219 floatRhs lvl arg
220   = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
221     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
222         -- Dump bindings that aren't going to escape from a lambda
223         -- This is to avoid floating the x binding out of
224         --      f (let x = e in b)
225         -- unnecessarily.  It even causes a bug to do so if we have
226         --      y = writeArr# a n (let x = e in b)
227         -- because the y binding is an expr-ok-for-speculation one.
228         -- [SLPJ Dec 01: I don't understand this last comment; 
229         --               writeArr# is not ok-for-spec because of its side effect]
230     (fsa, floats', install heres arg') }}
231
232 floatExpr _ (Var v)   = (zeroStats, [], Var v)
233 floatExpr _ (Type ty) = (zeroStats, [], Type ty)
234 floatExpr _ (Lit lit) = (zeroStats, [], Lit lit)
235           
236 floatExpr lvl (App e a)
237   = case (floatExpr lvl e) of { (fse, floats_e, e') ->
238     case (floatRhs  lvl a) of { (fsa, floats_a, a') ->
239     (fse `add_stats` fsa, floats_e ++ floats_a, App e' a') }}
240
241 floatExpr lvl lam@(Lam _ _)
242   = let
243         (bndrs_w_lvls, body) = collectBinders lam
244         (bndrs, lvls)        = unzip bndrs_w_lvls
245
246         -- For the all-tyvar case we are prepared to pull 
247         -- the lets out, to implement the float-out-of-big-lambda
248         -- transform; but otherwise we only float bindings that are
249         -- going to escape a value lambda.
250         -- In particular, for one-shot lambdas we don't float things
251         -- out; we get no saving by so doing.
252         partition_fn | all isTyVar bndrs = partitionByLevel
253                      | otherwise         = partitionByMajorLevel
254     in
255     case (floatExpr (last lvls) body) of { (fs, floats, body') ->
256
257         -- Dump any bindings which absolutely cannot go any further
258     case (partition_fn (head lvls) floats)      of { (floats', heres) ->
259
260     (add_to_stats fs floats', floats', mkLams bndrs (install heres body'))
261     }}
262
263 floatExpr lvl (Note note@(SCC cc) expr)
264   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
265     let
266         -- Annotate bindings floated outwards past an scc expression
267         -- with the cc.  We mark that cc as "duplicated", though.
268
269         annotated_defns = annotate (dupifyCC cc) floating_defns
270     in
271     (fs, annotated_defns, Note note expr') }
272   where
273     annotate :: CostCentre -> FloatBinds -> FloatBinds
274
275     annotate dupd_cc defn_groups
276       = [ (level, ann_bind floater) | (level, floater) <- defn_groups ]
277       where
278         ann_bind (NonRec binder rhs)
279           = NonRec binder (mkSCC dupd_cc rhs)
280
281         ann_bind (Rec pairs)
282           = Rec [(binder, mkSCC dupd_cc rhs) | (binder, rhs) <- pairs]
283
284 floatExpr lvl (Note InlineMe expr)      -- Other than SCCs
285   = case floatExpr InlineCtxt expr of { (fs, floating_defns, expr') ->
286         -- There can be some floating_defns, arising from
287         -- ordinary lets that were there all the time.  It seems
288         -- more efficient to test once here than to avoid putting
289         -- them into floating_defns (which would mean testing for
290         -- inlineCtxt  at every let)
291     (fs, [], Note InlineMe (install floating_defns expr')) }    -- See notes in SetLevels
292
293 floatExpr lvl (Note note expr)  -- Other than SCCs
294   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
295     (fs, floating_defns, Note note expr') }
296
297 floatExpr lvl (Let bind body)
298   = case (floatBind bind)     of { (fsb, rhs_floats,  bind') ->
299     case (floatExpr lvl body) of { (fse, body_floats, body') ->
300 --    if isInlineCtxt lvl then  -- No floating inside an InlineMe
301 --      ASSERT( null rhs_floats && null body_floats )
302 --      (add_stats fsb fse, [], Let bind' body')
303 --    else
304         (add_stats fsb fse,
305          rhs_floats ++ [(bind_lvl, bind')] ++ body_floats,
306          body')
307     }}
308   where
309     bind_lvl = getBindLevel bind
310
311 floatExpr lvl (Case scrut (case_bndr, case_lvl) alts)
312   = case floatExpr lvl scrut    of { (fse, fde, scrut') ->
313     case floatList float_alt alts       of { (fsa, fda, alts')  ->
314     (add_stats fse fsa, fda ++ fde, Case scrut' case_bndr alts')
315     }}
316   where
317         -- Use floatRhs for the alternatives, so that we
318         -- don't gratuitiously float bindings out of the RHSs
319     float_alt (con, bs, rhs)
320         = case (floatRhs case_lvl rhs)  of { (fs, rhs_floats, rhs') ->
321           (fs, rhs_floats, (con, map fst bs, rhs')) }
322
323
324 floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
325 floatList f [] = (zeroStats, [], [])
326 floatList f (a:as) = case f a            of { (fs_a,  binds_a,  b)  ->
327                      case floatList f as of { (fs_as, binds_as, bs) ->
328                      (fs_a `add_stats` fs_as, binds_a ++ binds_as, b:bs) }}
329 \end{code}
330
331 %************************************************************************
332 %*                                                                      *
333 \subsection{Utility bits for floating stats}
334 %*                                                                      *
335 %************************************************************************
336
337 I didn't implement this with unboxed numbers.  I don't want to be too
338 strict in this stuff, as it is rarely turned on.  (WDP 95/09)
339
340 \begin{code}
341 data FloatStats
342   = FlS Int  -- Number of top-floats * lambda groups they've been past
343         Int  -- Number of non-top-floats * lambda groups they've been past
344         Int  -- Number of lambda (groups) seen
345
346 get_stats (FlS a b c) = (a, b, c)
347
348 zeroStats = FlS 0 0 0
349
350 sum_stats xs = foldr add_stats zeroStats xs
351
352 add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
353   = FlS (a1 + a2) (b1 + b2) (c1 + c2)
354
355 add_to_stats (FlS a b c) floats
356   = FlS (a + length top_floats) (b + length other_floats) (c + 1)
357   where
358     (top_floats, other_floats) = partition to_very_top floats
359
360     to_very_top (my_lvl, _) = isTopLvl my_lvl
361 \end{code}
362
363
364 %************************************************************************
365 %*                                                                      *
366 \subsection{Utility bits for floating}
367 %*                                                                      *
368 %************************************************************************
369
370 \begin{code}
371 getBindLevel (NonRec (_, lvl) _)      = lvl
372 getBindLevel (Rec (((_,lvl), _) : _)) = lvl
373 \end{code}
374
375 \begin{code}
376 partitionByMajorLevel, partitionByLevel
377         :: Level                -- Partitioning level
378
379         -> FloatBinds           -- Defns to be divided into 2 piles...
380
381         -> (FloatBinds, -- Defns  with level strictly < partition level,
382             FloatBinds) -- The rest
383
384
385 partitionByMajorLevel ctxt_lvl defns
386   = partition float_further defns
387   where
388         -- Float it if we escape a value lambda, or if we get to the top level
389     float_further (my_lvl, bind) = my_lvl `ltMajLvl` ctxt_lvl || isTopLvl my_lvl
390         -- The isTopLvl part says that if we can get to the top level, say "yes" anyway
391         -- This means that 
392         --      x = f e
393         -- transforms to 
394         --    lvl = e
395         --    x = f lvl
396         -- which is as it should be
397
398 partitionByLevel ctxt_lvl defns
399   = partition float_further defns
400   where
401     float_further (my_lvl, _) = my_lvl `ltLvl` ctxt_lvl
402 \end{code}
403
404 \begin{code}
405 floatsToBinds :: FloatBinds -> [CoreBind]
406 floatsToBinds floats = map snd floats
407
408 floatsToBindPairs :: FloatBinds -> [(Id,CoreExpr)]
409
410 floatsToBindPairs floats = concat (map mk_pairs floats)
411   where
412    mk_pairs (_, Rec pairs)         = pairs
413    mk_pairs (_, NonRec binder rhs) = [(binder,rhs)]
414
415 install :: FloatBinds -> CoreExpr -> CoreExpr
416
417 install defn_groups expr
418   = foldr install_group expr defn_groups
419   where
420     install_group (_, defns) body = Let defns body
421 \end{code}