c929be33705eedd277fb2fe0e045e72e23243a0b
[ghc-hetmet.git] / ghc / compiler / simplCore / FloatOut.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
3 %
4 \section[FloatOut]{Float bindings outwards (towards the top level)}
5
6 ``Long-distance'' floating of bindings towards the top level.
7
8 \begin{code}
9 module FloatOut ( floatOutwards ) where
10
11 #include "HsVersions.h"
12
13 import CoreSyn
14 import CoreUtils        ( mkSCC )
15
16 import CmdLineOpts      ( opt_D_verbose_core2core, opt_D_dump_simpl_stats )
17 import ErrUtils         ( dumpIfSet )
18 import CostCentre       ( dupifyCC, CostCentre )
19 import Id               ( Id, idType )
20 import VarEnv
21 import CoreLint         ( beginPass, endPass )
22 import PprCore
23 import SetLevels        ( setLevels,
24                           Level(..), tOP_LEVEL, ltMajLvl, ltLvl, isTopLvl
25                         )
26 import BasicTypes       ( Unused )
27 import Type             ( isUnLiftedType )
28 import Var              ( TyVar )
29 import UniqSupply       ( UniqSupply )
30 import List             ( partition )
31 import Outputable
32 \end{code}
33
34 Random comments
35 ~~~~~~~~~~~~~~~
36
37 At the moment we never float a binding out to between two adjacent
38 lambdas.  For example:
39
40 @
41         \x y -> let t = x+x in ...
42 ===>
43         \x -> let t = x+x in \y -> ...
44 @
45 Reason: this is less efficient in the case where the original lambda
46 is never partially applied.
47
48 But there's a case I've seen where this might not be true.  Consider:
49 @
50 elEm2 x ys
51   = elem' x ys
52   where
53     elem' _ []  = False
54     elem' x (y:ys)      = x==y || elem' x ys
55 @
56 It turns out that this generates a subexpression of the form
57 @
58         \deq x ys -> let eq = eqFromEqDict deq in ...
59 @
60 vwhich might usefully be separated to
61 @
62         \deq -> let eq = eqFromEqDict deq in \xy -> ...
63 @
64 Well, maybe.  We don't do this at the moment.
65
66 \begin{code}
67 type LevelledExpr  = TaggedExpr Level
68 type LevelledBind  = TaggedBind Level
69 type FloatBind     = (Level, CoreBind)
70 type FloatBinds    = [FloatBind]
71 \end{code}
72
73 %************************************************************************
74 %*                                                                      *
75 \subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
76 %*                                                                      *
77 %************************************************************************
78
79 \begin{code}
80 floatOutwards :: Bool           -- True <=> float lambdas to top level
81               -> UniqSupply 
82               -> [CoreBind] -> IO [CoreBind]
83
84 floatOutwards float_lams us pgm
85   = do {
86         beginPass float_msg ;
87
88         let { annotated_w_levels = setLevels float_lams pgm us ;
89               (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)
90             } ;
91
92         dumpIfSet opt_D_verbose_core2core "Levels added:"
93                   (vcat (map ppr annotated_w_levels));
94
95         let { (tlets, ntlets, lams) = get_stats (sum_stats fss) };
96
97         dumpIfSet opt_D_dump_simpl_stats "FloatOut stats:"
98                 (hcat [ int tlets,  ptext SLIT(" Lets floated to top level; "),
99                         int ntlets, ptext SLIT(" Lets floated elsewhere; from "),
100                         int lams,   ptext SLIT(" Lambda groups")]);
101
102         endPass float_msg
103                 opt_D_verbose_core2core         {- no specific flag for dumping float-out -} 
104                 (concat binds_s')
105     }
106   where
107     float_msg | float_lams = "Float out (floating lambdas too)"
108               | otherwise  = "Float out (not floating lambdas)"
109
110 floatTopBind bind@(NonRec _ _)
111   = case (floatBind emptyVarEnv tOP_LEVEL bind) of { (fs, floats, bind', _) ->
112     (fs, floatsToBinds floats ++ [bind'])
113     }
114
115 floatTopBind bind@(Rec _)
116   = case (floatBind emptyVarEnv tOP_LEVEL bind) of { (fs, floats, Rec pairs', _) ->
117         -- Actually floats will be empty
118     --false:ASSERT(null floats)
119     (fs, [Rec (floatsToBindPairs floats ++ pairs')])
120     }
121 \end{code}
122
123 %************************************************************************
124 %*                                                                      *
125 \subsection[FloatOut-Bind]{Floating in a binding (the business end)}
126 %*                                                                      *
127 %************************************************************************
128
129
130 \begin{code}
131 floatBind :: IdEnv Level
132           -> Level
133           -> LevelledBind
134           -> (FloatStats, FloatBinds, CoreBind, IdEnv Level)
135
136 floatBind env lvl (NonRec (name,level) rhs)
137   = case (floatRhs env level rhs) of { (fs, rhs_floats, rhs') ->
138     (fs, rhs_floats,
139      NonRec name rhs',
140      extendVarEnv env name level)
141     }
142
143 floatBind env lvl bind@(Rec pairs)
144   = case (unzip3 (map do_pair pairs)) of { (fss, rhss_floats, new_pairs) ->
145
146     if not (isTopLvl bind_level) then
147         -- Standard case
148         (sum_stats fss, concat rhss_floats, Rec new_pairs, new_env)
149     else
150         {- In a recursive binding, destined for the top level (only),
151            the rhs floats may contain
152            references to the bound things.  For example
153
154                 f = ...(let v = ...f... in b) ...
155
156            might get floated to
157
158                 v = ...f...
159                 f = ... b ...
160
161            and hence we must (pessimistically) make all the floats recursive
162            with the top binding.  Later dependency analysis will unravel it.
163         -}
164
165         (sum_stats fss,
166          [],
167          Rec (new_pairs ++ floatsToBindPairs (concat rhss_floats)),
168          new_env)
169
170     }
171   where
172     new_env = extendVarEnvList env (map fst pairs)
173
174     bind_level = getBindLevel bind
175
176     do_pair ((name, level), rhs)
177       = case (floatRhs new_env level rhs) of { (fs, rhs_floats, rhs') ->
178         (fs, rhs_floats, (name, rhs'))
179         }
180 \end{code}
181
182 %************************************************************************
183
184 \subsection[FloatOut-Expr]{Floating in expressions}
185 %*                                                                      *
186 %************************************************************************
187
188 \begin{code}
189 floatExpr, floatRhs
190          :: IdEnv Level
191          -> Level
192          -> LevelledExpr
193          -> (FloatStats, FloatBinds, CoreExpr)
194
195 floatRhs env lvl arg
196   = case (floatExpr env lvl arg) of { (fsa, floats, arg') ->
197     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
198         -- Dump bindings that aren't going to escape from a lambda
199         -- This is to avoid floating the x binding out of
200         --      f (let x = e in b)
201         -- unnecessarily.  It even causes a bug to do so if we have
202         --      y = writeArr# a n (let x = e in b)
203         -- because the y binding is an expr-ok-for-speculation one.
204     (fsa, floats', install heres arg') }}
205
206 floatExpr env _ (Var v)      = (zeroStats, [], Var v)
207 floatExpr env _ (Type ty)    = (zeroStats, [], Type ty)
208 floatExpr env _ (Lit lit)    = (zeroStats, [], Lit lit)
209           
210 floatExpr env lvl (App e a)
211   = case (floatExpr env lvl e) of { (fse, floats_e, e') ->
212     case (floatRhs env lvl a) of { (fsa, floats_a, a') ->
213     (fse `add_stats` fsa, floats_e ++ floats_a, App e' a') }}
214
215 floatExpr env lvl (Lam (tv,incd_lvl) e)
216   | isTyVar tv
217   = case (floatExpr env incd_lvl e) of { (fs, floats, e') ->
218
219         -- Dump any bindings which absolutely cannot go any further
220     case (partitionByLevel incd_lvl floats)     of { (floats', heres) ->
221
222     (fs, floats', Lam tv (install heres e'))
223     }}
224
225 floatExpr env lvl (Lam (arg,incd_lvl) rhs)
226   = ASSERT( isId arg )
227     let
228         new_env  = extendVarEnv env arg incd_lvl
229     in
230     case (floatExpr new_env incd_lvl rhs) of { (fs, floats, rhs') ->
231
232         -- Dump any bindings which absolutely cannot go any further
233     case (partitionByLevel incd_lvl floats)     of { (floats', heres) ->
234
235     (add_to_stats fs floats',
236      floats',
237      Lam arg (install heres rhs'))
238     }}
239
240 floatExpr env lvl (Note note@(SCC cc) expr)
241   = case (floatExpr env lvl expr)    of { (fs, floating_defns, expr') ->
242     let
243         -- Annotate bindings floated outwards past an scc expression
244         -- with the cc.  We mark that cc as "duplicated", though.
245
246         annotated_defns = annotate (dupifyCC cc) floating_defns
247     in
248     (fs, annotated_defns, Note note expr') }
249   where
250     annotate :: CostCentre -> FloatBinds -> FloatBinds
251
252     annotate dupd_cc defn_groups
253       = [ (level, ann_bind floater) | (level, floater) <- defn_groups ]
254       where
255         ann_bind (NonRec binder rhs)
256           = NonRec binder (mkSCC dupd_cc rhs)
257
258         ann_bind (Rec pairs)
259           = Rec [(binder, mkSCC dupd_cc rhs) | (binder, rhs) <- pairs]
260
261 -- At one time I tried the effect of not float anything out of an InlineMe,
262 -- but it sometimes works badly.  For example, consider PrelArr.done.  It
263 -- has the form         __inline (\d. e)
264 -- where e doesn't mention d.  If we float this to 
265 --      __inline (let x = e in \d. x)
266 -- things are bad.  The inliner doesn't even inline it because it doesn't look
267 -- like a head-normal form.  So it seems a lesser evil to let things float.
268 -- In SetLevels we do set the context to (Level 0 0) when we get to an InlineMe
269 -- which discourages floating out.
270
271 floatExpr env lvl (Note note expr)      -- Other than SCCs
272   = case (floatExpr env lvl expr)    of { (fs, floating_defns, expr') ->
273     (fs, floating_defns, Note note expr') }
274
275 floatExpr env lvl (Let bind body)
276   = case (floatBind env     lvl bind) of { (fsb, rhs_floats, bind', new_env) ->
277     case (floatExpr new_env lvl body) of { (fse, body_floats, body') ->
278     (add_stats fsb fse,
279      rhs_floats ++ [(bind_lvl, bind')] ++ body_floats,
280      body')
281     }}
282   where
283     bind_lvl = getBindLevel bind
284
285 floatExpr env lvl (Case scrut (case_bndr, case_lvl) alts)
286   = case floatExpr env lvl scrut        of { (fse, fde, scrut') ->
287     case floatList float_alt alts       of { (fsa, fda, alts')  ->
288     (add_stats fse fsa, fda ++ fde, Case scrut' case_bndr alts')
289     }}
290   where
291       alts_env = extendVarEnv env case_bndr case_lvl
292
293       partition_fn = partitionByMajorLevel
294
295       float_alt (con, bs, rhs)
296         = let
297               bs' = map fst bs
298               new_env = extendVarEnvList alts_env bs
299           in
300           case (floatExpr new_env case_lvl rhs)         of { (fs, rhs_floats, rhs') ->
301           case (partition_fn case_lvl rhs_floats)       of { (rhs_floats', heres) ->
302           (fs, rhs_floats', (con, bs', install heres rhs')) }}
303
304
305 floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
306 floatList f [] = (zeroStats, [], [])
307 floatList f (a:as) = case f a            of { (fs_a,  binds_a,  b)  ->
308                      case floatList f as of { (fs_as, binds_as, bs) ->
309                      (fs_a `add_stats` fs_as, binds_a ++ binds_as, b:bs) }}
310 \end{code}
311
312 %************************************************************************
313 %*                                                                      *
314 \subsection{Utility bits for floating stats}
315 %*                                                                      *
316 %************************************************************************
317
318 I didn't implement this with unboxed numbers.  I don't want to be too
319 strict in this stuff, as it is rarely turned on.  (WDP 95/09)
320
321 \begin{code}
322 data FloatStats
323   = FlS Int  -- Number of top-floats * lambda groups they've been past
324         Int  -- Number of non-top-floats * lambda groups they've been past
325         Int  -- Number of lambda (groups) seen
326
327 get_stats (FlS a b c) = (a, b, c)
328
329 zeroStats = FlS 0 0 0
330
331 sum_stats xs = foldr add_stats zeroStats xs
332
333 add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
334   = FlS (a1 + a2) (b1 + b2) (c1 + c2)
335
336 add_to_stats (FlS a b c) floats
337   = FlS (a + length top_floats) (b + length other_floats) (c + 1)
338   where
339     (top_floats, other_floats) = partition to_very_top floats
340
341     to_very_top (my_lvl, _) = isTopLvl my_lvl
342 \end{code}
343
344
345 %************************************************************************
346 %*                                                                      *
347 \subsection{Utility bits for floating}
348 %*                                                                      *
349 %************************************************************************
350
351 \begin{code}
352 getBindLevel (NonRec (_, lvl) _)      = lvl
353 getBindLevel (Rec (((_,lvl), _) : _)) = lvl
354 \end{code}
355
356 \begin{code}
357 partitionByMajorLevel, partitionByLevel
358         :: Level                -- Partitioning level
359
360         -> FloatBinds           -- Defns to be divided into 2 piles...
361
362         -> (FloatBinds, -- Defns  with level strictly < partition level,
363             FloatBinds) -- The rest
364
365
366 partitionByMajorLevel ctxt_lvl defns
367   = partition float_further defns
368   where
369         -- Float it if we escape a value lambda, 
370         -- or if we get to the top level
371     float_further (my_lvl, bind) = my_lvl `ltMajLvl` ctxt_lvl || isTopLvl my_lvl
372         -- The isTopLvl part says that if we can get to the top level, say "yes" anyway
373         -- This means that 
374         --      x = f e
375         -- transforms to 
376         --    lvl = e
377         --    x = f lvl
378         -- which is as it should be
379
380 partitionByLevel ctxt_lvl defns
381   = partition float_further defns
382   where
383     float_further (my_lvl, _) = my_lvl `ltLvl` ctxt_lvl
384 \end{code}
385
386 \begin{code}
387 floatsToBinds :: FloatBinds -> [CoreBind]
388 floatsToBinds floats = map snd floats
389
390 floatsToBindPairs :: FloatBinds -> [(Id,CoreExpr)]
391
392 floatsToBindPairs floats = concat (map mk_pairs floats)
393   where
394    mk_pairs (_, Rec pairs)         = pairs
395    mk_pairs (_, NonRec binder rhs) = [(binder,rhs)]
396
397 install :: FloatBinds -> CoreExpr -> CoreExpr
398
399 install defn_groups expr
400   = foldr install_group expr defn_groups
401   where
402     install_group (_, defns) body = Let defns body
403 \end{code}