Completely new treatment of INLINE pragmas (big patch)
[ghc-hetmet.git] / compiler / simplCore / FloatOut.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
3 %
4 \section[FloatOut]{Float bindings outwards (towards the top level)}
5
6 ``Long-distance'' floating of bindings towards the top level.
7
8 \begin{code}
9 module FloatOut ( floatOutwards ) where
10
11 import CoreSyn
12 import CoreUtils
13
14 import DynFlags ( DynFlags, DynFlag(..), FloatOutSwitches(..) )
15 import ErrUtils         ( dumpIfSet_dyn )
16 import CostCentre       ( dupifyCC, CostCentre )
17 import Id               ( Id, idType )
18 import Type             ( isUnLiftedType )
19 import SetLevels        ( Level(..), LevelledExpr, LevelledBind,
20                           setLevels, ltMajLvl, ltLvl, isTopLvl )
21 import UniqSupply       ( UniqSupply )
22 import List             ( partition )
23 import Outputable
24 import FastString
25 \end{code}
26
27         -----------------
28         Overall game plan
29         -----------------
30
31 The Big Main Idea is:
32
33         To float out sub-expressions that can thereby get outside
34         a non-one-shot value lambda, and hence may be shared.
35
36
37 To achieve this we may need to do two thing:
38
39    a) Let-bind the sub-expression:
40
41         f (g x)  ==>  let lvl = f (g x) in lvl
42
43       Now we can float the binding for 'lvl'.  
44
45    b) More than that, we may need to abstract wrt a type variable
46
47         \x -> ... /\a -> let v = ...a... in ....
48
49       Here the binding for v mentions 'a' but not 'x'.  So we
50       abstract wrt 'a', to give this binding for 'v':
51
52             vp = /\a -> ...a...
53             v  = vp a
54
55       Now the binding for vp can float out unimpeded.
56       I can't remember why this case seemed important enough to
57       deal with, but I certainly found cases where important floats
58       didn't happen if we did not abstract wrt tyvars.
59
60 With this in mind we can also achieve another goal: lambda lifting.
61 We can make an arbitrary (function) binding float to top level by
62 abstracting wrt *all* local variables, not just type variables, leaving
63 a binding that can be floated right to top level.  Whether or not this
64 happens is controlled by a flag.
65
66
67 Random comments
68 ~~~~~~~~~~~~~~~
69
70 At the moment we never float a binding out to between two adjacent
71 lambdas.  For example:
72
73 @
74         \x y -> let t = x+x in ...
75 ===>
76         \x -> let t = x+x in \y -> ...
77 @
78 Reason: this is less efficient in the case where the original lambda
79 is never partially applied.
80
81 But there's a case I've seen where this might not be true.  Consider:
82 @
83 elEm2 x ys
84   = elem' x ys
85   where
86     elem' _ []  = False
87     elem' x (y:ys)      = x==y || elem' x ys
88 @
89 It turns out that this generates a subexpression of the form
90 @
91         \deq x ys -> let eq = eqFromEqDict deq in ...
92 @
93 vwhich might usefully be separated to
94 @
95         \deq -> let eq = eqFromEqDict deq in \xy -> ...
96 @
97 Well, maybe.  We don't do this at the moment.
98
99 \begin{code}
100 type FloatBind     = (Level, CoreBind)  -- INVARIANT: a FloatBind is always lifted
101 type FloatBinds    = [FloatBind]        
102 \end{code}
103
104 %************************************************************************
105 %*                                                                      *
106 \subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
107 %*                                                                      *
108 %************************************************************************
109
110 \begin{code}
111 floatOutwards :: FloatOutSwitches
112               -> DynFlags
113               -> UniqSupply 
114               -> [CoreBind] -> IO [CoreBind]
115
116 floatOutwards float_sws dflags us pgm
117   = do {
118         let { annotated_w_levels = setLevels float_sws pgm us ;
119               (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)
120             } ;
121
122         dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
123                   (vcat (map ppr annotated_w_levels));
124
125         let { (tlets, ntlets, lams) = get_stats (sum_stats fss) };
126
127         dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:"
128                 (hcat [ int tlets,  ptext (sLit " Lets floated to top level; "),
129                         int ntlets, ptext (sLit " Lets floated elsewhere; from "),
130                         int lams,   ptext (sLit " Lambda groups")]);
131
132         return (concat binds_s')
133     }
134
135 floatTopBind :: LevelledBind -> (FloatStats, [CoreBind])
136 floatTopBind bind
137   = case (floatBind bind) of { (fs, floats) ->
138     (fs, floatsToBinds floats)
139     }
140 \end{code}
141
142 %************************************************************************
143 %*                                                                      *
144 \subsection[FloatOut-Bind]{Floating in a binding (the business end)}
145 %*                                                                      *
146 %************************************************************************
147
148
149 \begin{code}
150 floatBind :: LevelledBind -> (FloatStats, FloatBinds)
151
152 floatBind (NonRec (TB name level) rhs)
153   = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
154     (fs, rhs_floats ++ [(level, NonRec name rhs')]) }
155
156 floatBind bind@(Rec pairs)
157   = case (unzip3 (map do_pair pairs)) of { (fss, rhss_floats, new_pairs) ->
158     let rhs_floats = concat rhss_floats in
159
160     if not (isTopLvl bind_dest_lvl) then
161         -- Find which bindings float out at least one lambda beyond this one
162         -- These ones can't mention the binders, because they couldn't 
163         -- be escaping a major level if so.
164         -- The ones that are not going further can join the letrec;
165         -- they may not be mutually recursive but the occurrence analyser will
166         -- find that out.
167         case (partitionByMajorLevel bind_dest_lvl rhs_floats) of { (floats', heres) ->
168         (sum_stats fss, floats' ++ [(bind_dest_lvl, Rec (floatsToBindPairs heres ++ new_pairs))]) }
169     else
170         -- In a recursive binding, *destined for* the top level
171         -- (only), the rhs floats may contain references to the 
172         -- bound things.  For example
173         --      f = ...(let v = ...f... in b) ...
174         --  might get floated to
175         --      v = ...f...
176         --      f = ... b ...
177         -- and hence we must (pessimistically) make all the floats recursive
178         -- with the top binding.  Later dependency analysis will unravel it.
179         --
180         -- This can only happen for bindings destined for the top level,
181         -- because only then will partitionByMajorLevel allow through a binding
182         -- that only differs in its minor level
183         (sum_stats fss, [(bind_dest_lvl, Rec (new_pairs ++ floatsToBindPairs rhs_floats))])
184     }
185   where
186     bind_dest_lvl = getBindLevel bind
187
188     do_pair (TB name level, rhs)
189       = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
190         (fs, rhs_floats, (name, rhs'))
191         }
192 \end{code}
193
194 %************************************************************************
195
196 \subsection[FloatOut-Expr]{Floating in expressions}
197 %*                                                                      *
198 %************************************************************************
199
200 \begin{code}
201 floatExpr, floatRhs, floatCaseAlt
202          :: Level
203          -> LevelledExpr
204          -> (FloatStats, FloatBinds, CoreExpr)
205
206 floatCaseAlt lvl arg    -- Used rec rhss, and case-alternative rhss
207   = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
208     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
209         -- Dump bindings that aren't going to escape from a lambda;
210         -- in particular, we must dump the ones that are bound by 
211         -- the rec or case alternative
212     (fsa, floats', install heres arg') }}
213
214 floatRhs lvl arg        -- Used for nested non-rec rhss, and fn args
215                         -- See Note [Floating out of RHS]
216   = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
217     if exprIsCheap arg' then    
218         (fsa, floats, arg')
219     else
220     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
221     (fsa, floats', install heres arg') }}
222
223 -- Note [Floating out of RHSs]
224 -- ~~~~~~~~~~~~~~~~~~~~~~~~~~~
225 -- Dump bindings that aren't going to escape from a lambda
226 -- This isn't a scoping issue (the binder isn't in scope in the RHS 
227 --      of a non-rec binding)
228 -- Rather, it is to avoid floating the x binding out of
229 --      f (let x = e in b)
230 -- unnecessarily.  But we first test for values or trival rhss,
231 -- because (in particular) we don't want to insert new bindings between
232 -- the "=" and the "\".  E.g.
233 --      f = \x -> let <bind> in <body>
234 -- We do not want
235 --      f = let <bind> in \x -> <body>
236 -- (a) The simplifier will immediately float it further out, so we may
237 --      as well do so right now; in general, keeping rhss as manifest 
238 --      values is good
239 -- (b) If a float-in pass follows immediately, it might add yet more
240 --      bindings just after the '='.  And some of them might (correctly)
241 --      be strict even though the 'let f' is lazy, because f, being a value,
242 --      gets its demand-info zapped by the simplifier.
243 --
244 -- We use exprIsCheap because that is also what's used by the simplifier
245 -- to decide whether to float a let out of a let
246
247 floatExpr _ (Var v)   = (zeroStats, [], Var v)
248 floatExpr _ (Type ty) = (zeroStats, [], Type ty)
249 floatExpr _ (Lit lit) = (zeroStats, [], Lit lit)
250           
251 floatExpr lvl (App e a)
252   = case (floatExpr      lvl e) of { (fse, floats_e, e') ->
253     case (floatRhs lvl a)       of { (fsa, floats_a, a') ->
254     (fse `add_stats` fsa, floats_e ++ floats_a, App e' a') }}
255
256 floatExpr _ lam@(Lam _ _)
257   = let
258         (bndrs_w_lvls, body) = collectBinders lam
259         bndrs                = [b | TB b _ <- bndrs_w_lvls]
260         lvls                 = [l | TB _ l <- bndrs_w_lvls]
261
262         -- For the all-tyvar case we are prepared to pull 
263         -- the lets out, to implement the float-out-of-big-lambda
264         -- transform; but otherwise we only float bindings that are
265         -- going to escape a value lambda.
266         -- In particular, for one-shot lambdas we don't float things
267         -- out; we get no saving by so doing.
268         partition_fn | all isTyVar bndrs = partitionByLevel
269                      | otherwise         = partitionByMajorLevel
270     in
271     case (floatExpr (last lvls) body) of { (fs, floats, body') ->
272
273         -- Dump any bindings which absolutely cannot go any further
274     case (partition_fn (head lvls) floats)      of { (floats', heres) ->
275
276     (add_to_stats fs floats', floats', mkLams bndrs (install heres body'))
277     }}
278
279 floatExpr lvl (Note note@(SCC cc) expr)
280   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
281     let
282         -- Annotate bindings floated outwards past an scc expression
283         -- with the cc.  We mark that cc as "duplicated", though.
284
285         annotated_defns = annotate (dupifyCC cc) floating_defns
286     in
287     (fs, annotated_defns, Note note expr') }
288   where
289     annotate :: CostCentre -> FloatBinds -> FloatBinds
290
291     annotate dupd_cc defn_groups
292       = [ (level, ann_bind floater) | (level, floater) <- defn_groups ]
293       where
294         ann_bind (NonRec binder rhs)
295           = NonRec binder (mkSCC dupd_cc rhs)
296
297         ann_bind (Rec pairs)
298           = Rec [(binder, mkSCC dupd_cc rhs) | (binder, rhs) <- pairs]
299
300 floatExpr lvl (Note note expr)  -- Other than SCCs
301   = case (floatExpr lvl expr)    of { (fs, floating_defns, expr') ->
302     (fs, floating_defns, Note note expr') }
303
304 floatExpr lvl (Cast expr co)
305   = case (floatExpr lvl expr)   of { (fs, floating_defns, expr') ->
306     (fs, floating_defns, Cast expr' co) }
307
308 floatExpr lvl (Let (NonRec (TB bndr bndr_lvl) rhs) body)
309   | isUnLiftedType (idType bndr)        -- Treat unlifted lets just like a case
310                                 -- I.e. floatExpr for rhs, floatCaseAlt for body
311   = case floatExpr lvl rhs          of { (_, rhs_floats, rhs') ->
312     case floatCaseAlt bndr_lvl body of { (fs, body_floats, body') ->
313     (fs, rhs_floats ++ body_floats, Let (NonRec bndr rhs') body') }}
314
315 floatExpr lvl (Let bind body)
316   = case (floatBind bind)     of { (fsb, bind_floats) ->
317     case (floatExpr lvl body) of { (fse, body_floats, body') ->
318     (add_stats fsb fse,
319      bind_floats ++ body_floats,
320      body')  }}
321
322 floatExpr lvl (Case scrut (TB case_bndr case_lvl) ty alts)
323   = case floatExpr lvl scrut    of { (fse, fde, scrut') ->
324     case floatList float_alt alts       of { (fsa, fda, alts')  ->
325     (add_stats fse fsa, fda ++ fde, Case scrut' case_bndr ty alts')
326     }}
327   where
328         -- Use floatCaseAlt for the alternatives, so that we
329         -- don't gratuitiously float bindings out of the RHSs
330     float_alt (con, bs, rhs)
331         = case (floatCaseAlt case_lvl rhs)      of { (fs, rhs_floats, rhs') ->
332           (fs, rhs_floats, (con, [b | TB b _ <- bs], rhs')) }
333
334
335 floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
336 floatList _ [] = (zeroStats, [], [])
337 floatList f (a:as) = case f a            of { (fs_a,  binds_a,  b)  ->
338                      case floatList f as of { (fs_as, binds_as, bs) ->
339                      (fs_a `add_stats` fs_as, binds_a ++ binds_as, b:bs) }}
340 \end{code}
341
342 %************************************************************************
343 %*                                                                      *
344 \subsection{Utility bits for floating stats}
345 %*                                                                      *
346 %************************************************************************
347
348 I didn't implement this with unboxed numbers.  I don't want to be too
349 strict in this stuff, as it is rarely turned on.  (WDP 95/09)
350
351 \begin{code}
352 data FloatStats
353   = FlS Int  -- Number of top-floats * lambda groups they've been past
354         Int  -- Number of non-top-floats * lambda groups they've been past
355         Int  -- Number of lambda (groups) seen
356
357 get_stats :: FloatStats -> (Int, Int, Int)
358 get_stats (FlS a b c) = (a, b, c)
359
360 zeroStats :: FloatStats
361 zeroStats = FlS 0 0 0
362
363 sum_stats :: [FloatStats] -> FloatStats
364 sum_stats xs = foldr add_stats zeroStats xs
365
366 add_stats :: FloatStats -> FloatStats -> FloatStats
367 add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
368   = FlS (a1 + a2) (b1 + b2) (c1 + c2)
369
370 add_to_stats :: FloatStats -> [(Level, Bind CoreBndr)] -> FloatStats
371 add_to_stats (FlS a b c) floats
372   = FlS (a + length top_floats) (b + length other_floats) (c + 1)
373   where
374     (top_floats, other_floats) = partition to_very_top floats
375
376     to_very_top (my_lvl, _) = isTopLvl my_lvl
377 \end{code}
378
379
380 %************************************************************************
381 %*                                                                      *
382 \subsection{Utility bits for floating}
383 %*                                                                      *
384 %************************************************************************
385
386 \begin{code}
387 getBindLevel :: Bind (TaggedBndr Level) -> Level
388 getBindLevel (NonRec (TB _ lvl) _)       = lvl
389 getBindLevel (Rec (((TB _ lvl), _) : _)) = lvl
390 getBindLevel (Rec [])                    = panic "getBindLevel Rec []"
391 \end{code}
392
393 \begin{code}
394 partitionByMajorLevel, partitionByLevel
395         :: Level                -- Partitioning level
396
397         -> FloatBinds           -- Defns to be divided into 2 piles...
398
399         -> (FloatBinds, -- Defns  with level strictly < partition level,
400             FloatBinds) -- The rest
401
402
403 partitionByMajorLevel ctxt_lvl defns
404   = partition float_further defns
405   where
406         -- Float it if we escape a value lambda, or if we get to the top level
407     float_further (my_lvl, _) = my_lvl `ltMajLvl` ctxt_lvl || isTopLvl my_lvl
408         -- The isTopLvl part says that if we can get to the top level, say "yes" anyway
409         -- This means that 
410         --      x = f e
411         -- transforms to 
412         --    lvl = e
413         --    x = f lvl
414         -- which is as it should be
415
416 partitionByLevel ctxt_lvl defns
417   = partition float_further defns
418   where
419     float_further (my_lvl, _) = my_lvl `ltLvl` ctxt_lvl
420 \end{code}
421
422 \begin{code}
423 floatsToBinds :: FloatBinds -> [CoreBind]
424 floatsToBinds floats = map snd floats
425
426 floatsToBindPairs :: FloatBinds -> [(Id,CoreExpr)]
427
428 floatsToBindPairs floats = concat (map mk_pairs floats)
429   where
430    mk_pairs (_, Rec pairs)         = pairs
431    mk_pairs (_, NonRec binder rhs) = [(binder,rhs)]
432
433 install :: FloatBinds -> CoreExpr -> CoreExpr
434
435 install defn_groups expr
436   = foldr install_group expr defn_groups
437   where
438     install_group (_, defns) body = Let defns body
439 \end{code}