[project @ 2000-10-25 13:51:50 by simonpj]
[ghc-hetmet.git] / ghc / compiler / simplCore / FloatOut.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
3 %
4 \section[FloatOut]{Float bindings outwards (towards the top level)}
5
6 ``Long-distance'' floating of bindings towards the top level.
7
8 \begin{code}
9 module FloatOut ( floatOutwards ) where
10
11 #include "HsVersions.h"
12
13 import CoreSyn
14 import CoreUtils        ( mkSCC )
15
16 import CmdLineOpts      ( DynFlags, DynFlag(..), dopt )
17 import ErrUtils         ( dumpIfSet_dyn )
18 import CostCentre       ( dupifyCC, CostCentre )
19 import Id               ( Id )
20 import VarEnv
21 import CoreLint         ( beginPass, endPass )
22 import SetLevels        ( setLevels,
23                           Level(..), tOP_LEVEL, ltMajLvl, ltLvl, isTopLvl
24                         )
25 import UniqSupply       ( UniqSupply )
26 import List             ( partition )
27 import Outputable
28 \end{code}
29
30 Random comments
31 ~~~~~~~~~~~~~~~
32
33 At the moment we never float a binding out to between two adjacent
34 lambdas.  For example:
35
36 @
37         \x y -> let t = x+x in ...
38 ===>
39         \x -> let t = x+x in \y -> ...
40 @
41 Reason: this is less efficient in the case where the original lambda
42 is never partially applied.
43
44 But there's a case I've seen where this might not be true.  Consider:
45 @
46 elEm2 x ys
47   = elem' x ys
48   where
49     elem' _ []  = False
50     elem' x (y:ys)      = x==y || elem' x ys
51 @
52 It turns out that this generates a subexpression of the form
53 @
54         \deq x ys -> let eq = eqFromEqDict deq in ...
55 @
56 vwhich might usefully be separated to
57 @
58         \deq -> let eq = eqFromEqDict deq in \xy -> ...
59 @
60 Well, maybe.  We don't do this at the moment.
61
62 \begin{code}
63 type LevelledExpr  = TaggedExpr Level
64 type LevelledBind  = TaggedBind Level
65 type FloatBind     = (Level, CoreBind)
66 type FloatBinds    = [FloatBind]
67 \end{code}
68
69 %************************************************************************
70 %*                                                                      *
71 \subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
72 %*                                                                      *
73 %************************************************************************
74
75 \begin{code}
76 floatOutwards :: DynFlags
77               -> Bool           -- True <=> float lambdas to top level
78               -> UniqSupply 
79               -> [CoreBind] -> IO [CoreBind]
80
81 floatOutwards dflags float_lams us pgm
82   = do {
83         beginPass dflags float_msg ;
84
85         let { annotated_w_levels = setLevels float_lams pgm us ;
86               (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)
87             } ;
88
89         dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
90                   (vcat (map ppr annotated_w_levels));
91
92         let { (tlets, ntlets, lams) = get_stats (sum_stats fss) };
93
94         dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:"
95                 (hcat [ int tlets,  ptext SLIT(" Lets floated to top level; "),
96                         int ntlets, ptext SLIT(" Lets floated elsewhere; from "),
97                         int lams,   ptext SLIT(" Lambda groups")]);
98
99         endPass dflags float_msg
100                 (dopt Opt_D_verbose_core2core dflags)
101                         {- no specific flag for dumping float-out -} 
102                 (concat binds_s')
103     }
104   where
105     float_msg | float_lams = "Float out (floating lambdas too)"
106               | otherwise  = "Float out (not floating lambdas)"
107
108 floatTopBind bind@(NonRec _ _)
109   = case (floatBind emptyVarEnv tOP_LEVEL bind) of { (fs, floats, bind', _) ->
110     (fs, floatsToBinds floats ++ [bind'])
111     }
112
113 floatTopBind bind@(Rec _)
114   = case (floatBind emptyVarEnv tOP_LEVEL bind) of { (fs, floats, Rec pairs', _) ->
115         -- Actually floats will be empty
116     --false:ASSERT(null floats)
117     (fs, [Rec (floatsToBindPairs floats ++ pairs')])
118     }
119 \end{code}
120
121 %************************************************************************
122 %*                                                                      *
123 \subsection[FloatOut-Bind]{Floating in a binding (the business end)}
124 %*                                                                      *
125 %************************************************************************
126
127
128 \begin{code}
129 floatBind :: IdEnv Level
130           -> Level
131           -> LevelledBind
132           -> (FloatStats, FloatBinds, CoreBind, IdEnv Level)
133
134 floatBind env lvl (NonRec (name,level) rhs)
135   = case (floatRhs env level rhs) of { (fs, rhs_floats, rhs') ->
136     (fs, rhs_floats,
137      NonRec name rhs',
138      extendVarEnv env name level)
139     }
140
141 floatBind env lvl bind@(Rec pairs)
142   = case (unzip3 (map do_pair pairs)) of { (fss, rhss_floats, new_pairs) ->
143
144     if not (isTopLvl bind_level) then
145         -- Standard case
146         (sum_stats fss, concat rhss_floats, Rec new_pairs, new_env)
147     else
148         {- In a recursive binding, destined for the top level (only),
149            the rhs floats may contain
150            references to the bound things.  For example
151
152                 f = ...(let v = ...f... in b) ...
153
154            might get floated to
155
156                 v = ...f...
157                 f = ... b ...
158
159            and hence we must (pessimistically) make all the floats recursive
160            with the top binding.  Later dependency analysis will unravel it.
161         -}
162
163         (sum_stats fss,
164          [],
165          Rec (new_pairs ++ floatsToBindPairs (concat rhss_floats)),
166          new_env)
167
168     }
169   where
170     new_env = extendVarEnvList env (map fst pairs)
171
172     bind_level = getBindLevel bind
173
174     do_pair ((name, level), rhs)
175       = case (floatRhs new_env level rhs) of { (fs, rhs_floats, rhs') ->
176         (fs, rhs_floats, (name, rhs'))
177         }
178 \end{code}
179
180 %************************************************************************
181
182 \subsection[FloatOut-Expr]{Floating in expressions}
183 %*                                                                      *
184 %************************************************************************
185
186 \begin{code}
187 floatExpr, floatRhs
188          :: IdEnv Level
189          -> Level
190          -> LevelledExpr
191          -> (FloatStats, FloatBinds, CoreExpr)
192
193 floatRhs env lvl arg
194   = case (floatExpr env lvl arg) of { (fsa, floats, arg') ->
195     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
196         -- Dump bindings that aren't going to escape from a lambda
197         -- This is to avoid floating the x binding out of
198         --      f (let x = e in b)
199         -- unnecessarily.  It even causes a bug to do so if we have
200         --      y = writeArr# a n (let x = e in b)
201         -- because the y binding is an expr-ok-for-speculation one.
202     (fsa, floats', install heres arg') }}
203
204 floatExpr env _ (Var v)      = (zeroStats, [], Var v)
205 floatExpr env _ (Type ty)    = (zeroStats, [], Type ty)
206 floatExpr env _ (Lit lit)    = (zeroStats, [], Lit lit)
207           
208 floatExpr env lvl (App e a)
209   = case (floatExpr env lvl e) of { (fse, floats_e, e') ->
210     case (floatRhs env lvl a) of { (fsa, floats_a, a') ->
211     (fse `add_stats` fsa, floats_e ++ floats_a, App e' a') }}
212
213 floatExpr env lvl (Lam (tv,incd_lvl) e)
214   | isTyVar tv
215   = case (floatExpr env incd_lvl e) of { (fs, floats, e') ->
216
217         -- Dump any bindings which absolutely cannot go any further
218     case (partitionByLevel incd_lvl floats)     of { (floats', heres) ->
219
220     (fs, floats', Lam tv (install heres e'))
221     }}
222
223 floatExpr env lvl (Lam (arg,incd_lvl) rhs)
224   = ASSERT( isId arg )
225     let
226         new_env  = extendVarEnv env arg incd_lvl
227     in
228     case (floatExpr new_env incd_lvl rhs) of { (fs, floats, rhs') ->
229
230         -- Dump any bindings which absolutely cannot go any further
231     case (partitionByLevel incd_lvl floats)     of { (floats', heres) ->
232
233     (add_to_stats fs floats',
234      floats',
235      Lam arg (install heres rhs'))
236     }}
237
238 floatExpr env lvl (Note note@(SCC cc) expr)
239   = case (floatExpr env lvl expr)    of { (fs, floating_defns, expr') ->
240     let
241         -- Annotate bindings floated outwards past an scc expression
242         -- with the cc.  We mark that cc as "duplicated", though.
243
244         annotated_defns = annotate (dupifyCC cc) floating_defns
245     in
246     (fs, annotated_defns, Note note expr') }
247   where
248     annotate :: CostCentre -> FloatBinds -> FloatBinds
249
250     annotate dupd_cc defn_groups
251       = [ (level, ann_bind floater) | (level, floater) <- defn_groups ]
252       where
253         ann_bind (NonRec binder rhs)
254           = NonRec binder (mkSCC dupd_cc rhs)
255
256         ann_bind (Rec pairs)
257           = Rec [(binder, mkSCC dupd_cc rhs) | (binder, rhs) <- pairs]
258
259 -- At one time I tried the effect of not float anything out of an InlineMe,
260 -- but it sometimes works badly.  For example, consider PrelArr.done.  It
261 -- has the form         __inline (\d. e)
262 -- where e doesn't mention d.  If we float this to 
263 --      __inline (let x = e in \d. x)
264 -- things are bad.  The inliner doesn't even inline it because it doesn't look
265 -- like a head-normal form.  So it seems a lesser evil to let things float.
266 -- In SetLevels we do set the context to (Level 0 0) when we get to an InlineMe
267 -- which discourages floating out.
268
269 floatExpr env lvl (Note note expr)      -- Other than SCCs
270   = case (floatExpr env lvl expr)    of { (fs, floating_defns, expr') ->
271     (fs, floating_defns, Note note expr') }
272
273 floatExpr env lvl (Let bind body)
274   = case (floatBind env     lvl bind) of { (fsb, rhs_floats, bind', new_env) ->
275     case (floatExpr new_env lvl body) of { (fse, body_floats, body') ->
276     (add_stats fsb fse,
277      rhs_floats ++ [(bind_lvl, bind')] ++ body_floats,
278      body')
279     }}
280   where
281     bind_lvl = getBindLevel bind
282
283 floatExpr env lvl (Case scrut (case_bndr, case_lvl) alts)
284   = case floatExpr env lvl scrut        of { (fse, fde, scrut') ->
285     case floatList float_alt alts       of { (fsa, fda, alts')  ->
286     (add_stats fse fsa, fda ++ fde, Case scrut' case_bndr alts')
287     }}
288   where
289       alts_env = extendVarEnv env case_bndr case_lvl
290
291       partition_fn = partitionByMajorLevel
292
293       float_alt (con, bs, rhs)
294         = let
295               bs' = map fst bs
296               new_env = extendVarEnvList alts_env bs
297           in
298           case (floatExpr new_env case_lvl rhs)         of { (fs, rhs_floats, rhs') ->
299           case (partition_fn case_lvl rhs_floats)       of { (rhs_floats', heres) ->
300           (fs, rhs_floats', (con, bs', install heres rhs')) }}
301
302
303 floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
304 floatList f [] = (zeroStats, [], [])
305 floatList f (a:as) = case f a            of { (fs_a,  binds_a,  b)  ->
306                      case floatList f as of { (fs_as, binds_as, bs) ->
307                      (fs_a `add_stats` fs_as, binds_a ++ binds_as, b:bs) }}
308 \end{code}
309
310 %************************************************************************
311 %*                                                                      *
312 \subsection{Utility bits for floating stats}
313 %*                                                                      *
314 %************************************************************************
315
316 I didn't implement this with unboxed numbers.  I don't want to be too
317 strict in this stuff, as it is rarely turned on.  (WDP 95/09)
318
319 \begin{code}
320 data FloatStats
321   = FlS Int  -- Number of top-floats * lambda groups they've been past
322         Int  -- Number of non-top-floats * lambda groups they've been past
323         Int  -- Number of lambda (groups) seen
324
325 get_stats (FlS a b c) = (a, b, c)
326
327 zeroStats = FlS 0 0 0
328
329 sum_stats xs = foldr add_stats zeroStats xs
330
331 add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
332   = FlS (a1 + a2) (b1 + b2) (c1 + c2)
333
334 add_to_stats (FlS a b c) floats
335   = FlS (a + length top_floats) (b + length other_floats) (c + 1)
336   where
337     (top_floats, other_floats) = partition to_very_top floats
338
339     to_very_top (my_lvl, _) = isTopLvl my_lvl
340 \end{code}
341
342
343 %************************************************************************
344 %*                                                                      *
345 \subsection{Utility bits for floating}
346 %*                                                                      *
347 %************************************************************************
348
349 \begin{code}
350 getBindLevel (NonRec (_, lvl) _)      = lvl
351 getBindLevel (Rec (((_,lvl), _) : _)) = lvl
352 \end{code}
353
354 \begin{code}
355 partitionByMajorLevel, partitionByLevel
356         :: Level                -- Partitioning level
357
358         -> FloatBinds           -- Defns to be divided into 2 piles...
359
360         -> (FloatBinds, -- Defns  with level strictly < partition level,
361             FloatBinds) -- The rest
362
363
364 partitionByMajorLevel ctxt_lvl defns
365   = partition float_further defns
366   where
367         -- Float it if we escape a value lambda, 
368         -- or if we get to the top level
369     float_further (my_lvl, bind) = my_lvl `ltMajLvl` ctxt_lvl || isTopLvl my_lvl
370         -- The isTopLvl part says that if we can get to the top level, say "yes" anyway
371         -- This means that 
372         --      x = f e
373         -- transforms to 
374         --    lvl = e
375         --    x = f lvl
376         -- which is as it should be
377
378 partitionByLevel ctxt_lvl defns
379   = partition float_further defns
380   where
381     float_further (my_lvl, _) = my_lvl `ltLvl` ctxt_lvl
382 \end{code}
383
384 \begin{code}
385 floatsToBinds :: FloatBinds -> [CoreBind]
386 floatsToBinds floats = map snd floats
387
388 floatsToBindPairs :: FloatBinds -> [(Id,CoreExpr)]
389
390 floatsToBindPairs floats = concat (map mk_pairs floats)
391   where
392    mk_pairs (_, Rec pairs)         = pairs
393    mk_pairs (_, NonRec binder rhs) = [(binder,rhs)]
394
395 install :: FloatBinds -> CoreExpr -> CoreExpr
396
397 install defn_groups expr
398   = foldr install_group expr defn_groups
399   where
400     install_group (_, defns) body = Let defns body
401 \end{code}