[project @ 2001-09-26 16:19:28 by simonpj]
[ghc-hetmet.git] / ghc / compiler / simplCore / FloatOut.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
3 %
4 \section[FloatOut]{Float bindings outwards (towards the top level)}
5
6 ``Long-distance'' floating of bindings towards the top level.
7
8 \begin{code}
9 module FloatOut ( floatOutwards ) where
10
11 #include "HsVersions.h"
12
13 import CoreSyn
14 import CoreUtils        ( mkSCC )
15
16 import CmdLineOpts      ( DynFlags, DynFlag(..) )
17 import ErrUtils         ( dumpIfSet_dyn )
18 import CostCentre       ( dupifyCC, CostCentre )
19 import Id               ( Id )
20 import VarEnv
21 import CoreLint         ( showPass, endPass )
22 import SetLevels        ( setLevels,
23                           Level(..), tOP_LEVEL, ltMajLvl, ltLvl, isTopLvl
24                         )
25 import UniqSupply       ( UniqSupply )
26 import List             ( partition )
27 import Outputable
28 \end{code}
29
30 Random comments
31 ~~~~~~~~~~~~~~~
32
33 At the moment we never float a binding out to between two adjacent
34 lambdas.  For example:
35
36 @
37         \x y -> let t = x+x in ...
38 ===>
39         \x -> let t = x+x in \y -> ...
40 @
41 Reason: this is less efficient in the case where the original lambda
42 is never partially applied.
43
44 But there's a case I've seen where this might not be true.  Consider:
45 @
46 elEm2 x ys
47   = elem' x ys
48   where
49     elem' _ []  = False
50     elem' x (y:ys)      = x==y || elem' x ys
51 @
52 It turns out that this generates a subexpression of the form
53 @
54         \deq x ys -> let eq = eqFromEqDict deq in ...
55 @
56 vwhich might usefully be separated to
57 @
58         \deq -> let eq = eqFromEqDict deq in \xy -> ...
59 @
60 Well, maybe.  We don't do this at the moment.
61
62 \begin{code}
63 type LevelledExpr  = TaggedExpr Level
64 type LevelledBind  = TaggedBind Level
65 type FloatBind     = (Level, CoreBind)
66 type FloatBinds    = [FloatBind]
67 \end{code}
68
69 %************************************************************************
70 %*                                                                      *
71 \subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
72 %*                                                                      *
73 %************************************************************************
74
75 \begin{code}
76 floatOutwards :: DynFlags
77               -> Bool           -- True <=> float lambdas to top level
78               -> UniqSupply 
79               -> [CoreBind] -> IO [CoreBind]
80
81 floatOutwards dflags float_lams us pgm
82   = do {
83         showPass dflags float_msg ;
84
85         let { annotated_w_levels = setLevels float_lams pgm us ;
86               (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)
87             } ;
88
89         dumpIfSet_dyn dflags Opt_D_verbose_core2core "Levels added:"
90                   (vcat (map ppr annotated_w_levels));
91
92         let { (tlets, ntlets, lams) = get_stats (sum_stats fss) };
93
94         dumpIfSet_dyn dflags Opt_D_dump_simpl_stats "FloatOut stats:"
95                 (hcat [ int tlets,  ptext SLIT(" Lets floated to top level; "),
96                         int ntlets, ptext SLIT(" Lets floated elsewhere; from "),
97                         int lams,   ptext SLIT(" Lambda groups")]);
98
99         endPass dflags float_msg  Opt_D_verbose_core2core (concat binds_s')
100                         {- no specific flag for dumping float-out -} 
101     }
102   where
103     float_msg | float_lams = "Float out (floating lambdas too)"
104               | otherwise  = "Float out (not floating lambdas)"
105
106 floatTopBind bind@(NonRec _ _)
107   = case (floatBind emptyVarEnv tOP_LEVEL bind) of { (fs, floats, bind', _) ->
108     (fs, floatsToBinds floats ++ [bind'])
109     }
110
111 floatTopBind bind@(Rec _)
112   = case (floatBind emptyVarEnv tOP_LEVEL bind) of { (fs, floats, Rec pairs', _) ->
113         -- Actually floats will be empty
114     --false:ASSERT(null floats)
115     (fs, [Rec (floatsToBindPairs floats ++ pairs')])
116     }
117 \end{code}
118
119 %************************************************************************
120 %*                                                                      *
121 \subsection[FloatOut-Bind]{Floating in a binding (the business end)}
122 %*                                                                      *
123 %************************************************************************
124
125
126 \begin{code}
127 floatBind :: IdEnv Level
128           -> Level
129           -> LevelledBind
130           -> (FloatStats, FloatBinds, CoreBind, IdEnv Level)
131
132 floatBind env lvl (NonRec (name,level) rhs)
133   = case (floatRhs env level rhs) of { (fs, rhs_floats, rhs') ->
134     (fs, rhs_floats,
135      NonRec name rhs',
136      extendVarEnv env name level)
137     }
138
139 floatBind env lvl bind@(Rec pairs)
140   = case (unzip3 (map do_pair pairs)) of { (fss, rhss_floats, new_pairs) ->
141
142     if not (isTopLvl bind_level) then
143         -- Standard case
144         (sum_stats fss, concat rhss_floats, Rec new_pairs, new_env)
145     else
146         {- In a recursive binding, destined for the top level (only),
147            the rhs floats may contain
148            references to the bound things.  For example
149
150                 f = ...(let v = ...f... in b) ...
151
152            might get floated to
153
154                 v = ...f...
155                 f = ... b ...
156
157            and hence we must (pessimistically) make all the floats recursive
158            with the top binding.  Later dependency analysis will unravel it.
159         -}
160
161         (sum_stats fss,
162          [],
163          Rec (new_pairs ++ floatsToBindPairs (concat rhss_floats)),
164          new_env)
165
166     }
167   where
168     new_env = extendVarEnvList env (map fst pairs)
169
170     bind_level = getBindLevel bind
171
172     do_pair ((name, level), rhs)
173       = case (floatRhs new_env level rhs) of { (fs, rhs_floats, rhs') ->
174         (fs, rhs_floats, (name, rhs'))
175         }
176 \end{code}
177
178 %************************************************************************
179
180 \subsection[FloatOut-Expr]{Floating in expressions}
181 %*                                                                      *
182 %************************************************************************
183
184 \begin{code}
185 floatExpr, floatRhs
186          :: IdEnv Level
187          -> Level
188          -> LevelledExpr
189          -> (FloatStats, FloatBinds, CoreExpr)
190
191 floatRhs env lvl arg
192   = case (floatExpr env lvl arg) of { (fsa, floats, arg') ->
193     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
194         -- Dump bindings that aren't going to escape from a lambda
195         -- This is to avoid floating the x binding out of
196         --      f (let x = e in b)
197         -- unnecessarily.  It even causes a bug to do so if we have
198         --      y = writeArr# a n (let x = e in b)
199         -- because the y binding is an expr-ok-for-speculation one.
200     (fsa, floats', install heres arg') }}
201
202 floatExpr env _ (Var v)      = (zeroStats, [], Var v)
203 floatExpr env _ (Type ty)    = (zeroStats, [], Type ty)
204 floatExpr env _ (Lit lit)    = (zeroStats, [], Lit lit)
205           
206 floatExpr env lvl (App e a)
207   = case (floatExpr env lvl e) of { (fse, floats_e, e') ->
208     case (floatRhs env lvl a) of { (fsa, floats_a, a') ->
209     (fse `add_stats` fsa, floats_e ++ floats_a, App e' a') }}
210
211 floatExpr env lvl (Lam (tv,incd_lvl) e)
212   | isTyVar tv
213   = case (floatExpr env incd_lvl e) of { (fs, floats, e') ->
214
215         -- Dump any bindings which absolutely cannot go any further
216     case (partitionByLevel incd_lvl floats)     of { (floats', heres) ->
217
218     (fs, floats', Lam tv (install heres e'))
219     }}
220
221 floatExpr env lvl (Lam (arg,incd_lvl) rhs)
222   = ASSERT( isId arg )
223     let
224         new_env  = extendVarEnv env arg incd_lvl
225     in
226     case (floatExpr new_env incd_lvl rhs) of { (fs, floats, rhs') ->
227
228         -- Dump any bindings which absolutely cannot go any further
229     case (partitionByLevel incd_lvl floats)     of { (floats', heres) ->
230
231     (add_to_stats fs floats',
232      floats',
233      Lam arg (install heres rhs'))
234     }}
235
236 floatExpr env lvl (Note note@(SCC cc) expr)
237   = case (floatExpr env lvl expr)    of { (fs, floating_defns, expr') ->
238     let
239         -- Annotate bindings floated outwards past an scc expression
240         -- with the cc.  We mark that cc as "duplicated", though.
241
242         annotated_defns = annotate (dupifyCC cc) floating_defns
243     in
244     (fs, annotated_defns, Note note expr') }
245   where
246     annotate :: CostCentre -> FloatBinds -> FloatBinds
247
248     annotate dupd_cc defn_groups
249       = [ (level, ann_bind floater) | (level, floater) <- defn_groups ]
250       where
251         ann_bind (NonRec binder rhs)
252           = NonRec binder (mkSCC dupd_cc rhs)
253
254         ann_bind (Rec pairs)
255           = Rec [(binder, mkSCC dupd_cc rhs) | (binder, rhs) <- pairs]
256
257 -- At one time I tried the effect of not float anything out of an InlineMe,
258 -- but it sometimes works badly.  For example, consider PrelArr.done.  It
259 -- has the form         __inline (\d. e)
260 -- where e doesn't mention d.  If we float this to 
261 --      __inline (let x = e in \d. x)
262 -- things are bad.  The inliner doesn't even inline it because it doesn't look
263 -- like a head-normal form.  So it seems a lesser evil to let things float.
264 -- In SetLevels we do set the context to (Level 0 0) when we get to an InlineMe
265 -- which discourages floating out.
266
267 floatExpr env lvl (Note note expr)      -- Other than SCCs
268   = case (floatExpr env lvl expr)    of { (fs, floating_defns, expr') ->
269     (fs, floating_defns, Note note expr') }
270
271 floatExpr env lvl (Let bind body)
272   = case (floatBind env     lvl bind) of { (fsb, rhs_floats, bind', new_env) ->
273     case (floatExpr new_env lvl body) of { (fse, body_floats, body') ->
274     (add_stats fsb fse,
275      rhs_floats ++ [(bind_lvl, bind')] ++ body_floats,
276      body')
277     }}
278   where
279     bind_lvl = getBindLevel bind
280
281 floatExpr env lvl (Case scrut (case_bndr, case_lvl) alts)
282   = case floatExpr env lvl scrut        of { (fse, fde, scrut') ->
283     case floatList float_alt alts       of { (fsa, fda, alts')  ->
284     (add_stats fse fsa, fda ++ fde, Case scrut' case_bndr alts')
285     }}
286   where
287       alts_env = extendVarEnv env case_bndr case_lvl
288
289       partition_fn = partitionByMajorLevel
290
291       float_alt (con, bs, rhs)
292         = let
293               bs' = map fst bs
294               new_env = extendVarEnvList alts_env bs
295           in
296           case (floatExpr new_env case_lvl rhs)         of { (fs, rhs_floats, rhs') ->
297           case (partition_fn case_lvl rhs_floats)       of { (rhs_floats', heres) ->
298           (fs, rhs_floats', (con, bs', install heres rhs')) }}
299
300
301 floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
302 floatList f [] = (zeroStats, [], [])
303 floatList f (a:as) = case f a            of { (fs_a,  binds_a,  b)  ->
304                      case floatList f as of { (fs_as, binds_as, bs) ->
305                      (fs_a `add_stats` fs_as, binds_a ++ binds_as, b:bs) }}
306 \end{code}
307
308 %************************************************************************
309 %*                                                                      *
310 \subsection{Utility bits for floating stats}
311 %*                                                                      *
312 %************************************************************************
313
314 I didn't implement this with unboxed numbers.  I don't want to be too
315 strict in this stuff, as it is rarely turned on.  (WDP 95/09)
316
317 \begin{code}
318 data FloatStats
319   = FlS Int  -- Number of top-floats * lambda groups they've been past
320         Int  -- Number of non-top-floats * lambda groups they've been past
321         Int  -- Number of lambda (groups) seen
322
323 get_stats (FlS a b c) = (a, b, c)
324
325 zeroStats = FlS 0 0 0
326
327 sum_stats xs = foldr add_stats zeroStats xs
328
329 add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
330   = FlS (a1 + a2) (b1 + b2) (c1 + c2)
331
332 add_to_stats (FlS a b c) floats
333   = FlS (a + length top_floats) (b + length other_floats) (c + 1)
334   where
335     (top_floats, other_floats) = partition to_very_top floats
336
337     to_very_top (my_lvl, _) = isTopLvl my_lvl
338 \end{code}
339
340
341 %************************************************************************
342 %*                                                                      *
343 \subsection{Utility bits for floating}
344 %*                                                                      *
345 %************************************************************************
346
347 \begin{code}
348 getBindLevel (NonRec (_, lvl) _)      = lvl
349 getBindLevel (Rec (((_,lvl), _) : _)) = lvl
350 \end{code}
351
352 \begin{code}
353 partitionByMajorLevel, partitionByLevel
354         :: Level                -- Partitioning level
355
356         -> FloatBinds           -- Defns to be divided into 2 piles...
357
358         -> (FloatBinds, -- Defns  with level strictly < partition level,
359             FloatBinds) -- The rest
360
361
362 partitionByMajorLevel ctxt_lvl defns
363   = partition float_further defns
364   where
365         -- Float it if we escape a value lambda, 
366         -- or if we get to the top level
367     float_further (my_lvl, bind) = my_lvl `ltMajLvl` ctxt_lvl || isTopLvl my_lvl
368         -- The isTopLvl part says that if we can get to the top level, say "yes" anyway
369         -- This means that 
370         --      x = f e
371         -- transforms to 
372         --    lvl = e
373         --    x = f lvl
374         -- which is as it should be
375
376 partitionByLevel ctxt_lvl defns
377   = partition float_further defns
378   where
379     float_further (my_lvl, _) = my_lvl `ltLvl` ctxt_lvl
380 \end{code}
381
382 \begin{code}
383 floatsToBinds :: FloatBinds -> [CoreBind]
384 floatsToBinds floats = map snd floats
385
386 floatsToBindPairs :: FloatBinds -> [(Id,CoreExpr)]
387
388 floatsToBindPairs floats = concat (map mk_pairs floats)
389   where
390    mk_pairs (_, Rec pairs)         = pairs
391    mk_pairs (_, NonRec binder rhs) = [(binder,rhs)]
392
393 install :: FloatBinds -> CoreExpr -> CoreExpr
394
395 install defn_groups expr
396   = foldr install_group expr defn_groups
397   where
398     install_group (_, defns) body = Let defns body
399 \end{code}