83e5d5a6f1cab88a4bab1f9783d85c2a7723d56c
[ghc-hetmet.git] / ghc / compiler / simplCore / FloatOut.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
3 %
4 \section[FloatOut]{Float bindings outwards (towards the top level)}
5
6 ``Long-distance'' floating of bindings towards the top level.
7
8 \begin{code}
9 module FloatOut ( floatOutwards ) where
10
11 #include "HsVersions.h"
12
13 import CoreSyn
14
15 import CmdLineOpts      ( opt_D_verbose_core2core, opt_D_dump_simpl_stats )
16 import ErrUtils         ( dumpIfSet )
17 import CostCentre       ( dupifyCC, CostCentre )
18 import Id               ( Id, idType )
19 import Const            ( isWHNFCon )
20 import VarEnv
21 import CoreLint         ( beginPass, endPass )
22 import PprCore
23 import SetLevels        ( setLevels,
24                           Level(..), tOP_LEVEL, ltMajLvl, ltLvl, isTopLvl
25                         )
26 import BasicTypes       ( Unused )
27 import Type             ( isUnLiftedType )
28 import Var              ( TyVar )
29 import UniqSupply       ( UniqSupply )
30 import List             ( partition )
31 import Outputable
32 \end{code}
33
34 Random comments
35 ~~~~~~~~~~~~~~~
36
37 At the moment we never float a binding out to between two adjacent
38 lambdas.  For example:
39
40 @
41         \x y -> let t = x+x in ...
42 ===>
43         \x -> let t = x+x in \y -> ...
44 @
45 Reason: this is less efficient in the case where the original lambda
46 is never partially applied.
47
48 But there's a case I've seen where this might not be true.  Consider:
49 @
50 elEm2 x ys
51   = elem' x ys
52   where
53     elem' _ []  = False
54     elem' x (y:ys)      = x==y || elem' x ys
55 @
56 It turns out that this generates a subexpression of the form
57 @
58         \deq x ys -> let eq = eqFromEqDict deq in ...
59 @
60 vwhich might usefully be separated to
61 @
62         \deq -> let eq = eqFromEqDict deq in \xy -> ...
63 @
64 Well, maybe.  We don't do this at the moment.
65
66 \begin{code}
67 type LevelledExpr  = TaggedExpr Level
68 type LevelledBind  = TaggedBind Level
69 type FloatBind     = (Level, CoreBind)
70 type FloatBinds    = [FloatBind]
71 \end{code}
72
73 %************************************************************************
74 %*                                                                      *
75 \subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
76 %*                                                                      *
77 %************************************************************************
78
79 \begin{code}
80 floatOutwards :: UniqSupply -> [CoreBind] -> IO [CoreBind]
81
82 floatOutwards us pgm
83   = do {
84         beginPass "Float out";
85
86         let { annotated_w_levels = setLevels pgm us ;
87               (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)
88             } ;
89
90         dumpIfSet opt_D_verbose_core2core "Levels added:"
91                   (vcat (map ppr annotated_w_levels));
92
93         let { (tlets, ntlets, lams) = get_stats (sum_stats fss) };
94
95         dumpIfSet opt_D_dump_simpl_stats "FloatOut stats:"
96                 (hcat [ int tlets,  ptext SLIT(" Lets floated to top level; "),
97                         int ntlets, ptext SLIT(" Lets floated elsewhere; from "),
98                         int lams,   ptext SLIT(" Lambda groups")]);
99
100         endPass "Float out" 
101                 opt_D_verbose_core2core         {- no specific flag for dumping float-out -} 
102                 (concat binds_s')
103     }
104
105 floatTopBind bind@(NonRec _ _)
106   = case (floatBind emptyVarEnv tOP_LEVEL bind) of { (fs, floats, bind', _) ->
107     (fs, floatsToBinds floats ++ [bind'])
108     }
109
110 floatTopBind bind@(Rec _)
111   = case (floatBind emptyVarEnv tOP_LEVEL bind) of { (fs, floats, Rec pairs', _) ->
112         -- Actually floats will be empty
113     --false:ASSERT(null floats)
114     (fs, [Rec (floatsToBindPairs floats ++ pairs')])
115     }
116 \end{code}
117
118 %************************************************************************
119 %*                                                                      *
120 \subsection[FloatOut-Bind]{Floating in a binding (the business end)}
121 %*                                                                      *
122 %************************************************************************
123
124
125 \begin{code}
126 floatBind :: IdEnv Level
127           -> Level
128           -> LevelledBind
129           -> (FloatStats, FloatBinds, CoreBind, IdEnv Level)
130
131 floatBind env lvl (NonRec (name,level) rhs)
132   = case (floatRhs env level rhs) of { (fs, rhs_floats, rhs') ->
133     (fs, rhs_floats,
134      NonRec name rhs',
135      extendVarEnv env name level)
136     }
137
138 floatBind env lvl bind@(Rec pairs)
139   = case (unzip3 (map do_pair pairs)) of { (fss, rhss_floats, new_pairs) ->
140
141     if not (isTopLvl bind_level) then
142         -- Standard case
143         (sum_stats fss, concat rhss_floats, Rec new_pairs, new_env)
144     else
145         {- In a recursive binding, destined for the top level (only),
146            the rhs floats may contain
147            references to the bound things.  For example
148
149                 f = ...(let v = ...f... in b) ...
150
151            might get floated to
152
153                 v = ...f...
154                 f = ... b ...
155
156            and hence we must (pessimistically) make all the floats recursive
157            with the top binding.  Later dependency analysis will unravel it.
158         -}
159
160         (sum_stats fss,
161          [],
162          Rec (new_pairs ++ floatsToBindPairs (concat rhss_floats)),
163          new_env)
164
165     }
166   where
167     new_env = extendVarEnvList env (map fst pairs)
168
169     bind_level = getBindLevel bind
170
171     do_pair ((name, level), rhs)
172       = case (floatRhs new_env level rhs) of { (fs, rhs_floats, rhs') ->
173         (fs, rhs_floats, (name, rhs'))
174         }
175 \end{code}
176
177 %************************************************************************
178
179 \subsection[FloatOut-Expr]{Floating in expressions}
180 %*                                                                      *
181 %************************************************************************
182
183 \begin{code}
184 floatExpr, floatRhs
185          :: IdEnv Level
186          -> Level
187          -> LevelledExpr
188          -> (FloatStats, FloatBinds, CoreExpr)
189
190 floatRhs env lvl arg
191   = case (floatExpr env lvl arg) of { (fsa, floats, arg') ->
192     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
193         -- Dump bindings that aren't going to escape from a lambda
194         -- This is to avoid floating the x binding out of
195         --      f (let x = e in b)
196         -- unnecessarily.  It even causes a bug to do so if we have
197         --      y = writeArr# a n (let x = e in b)
198         -- because the y binding is an expr-ok-for-speculation one.
199     (fsa, floats', install heres arg') }}
200
201 floatExpr env _ (Var v)      = (zeroStats, [], Var v)
202 floatExpr env _ (Type ty)    = (zeroStats, [], Type ty)
203 floatExpr env lvl (Con con as) 
204   = case floatList (floatRhs env lvl) as of { (stats, floats, as') ->
205     (stats, floats, Con con as') }
206           
207 floatExpr env lvl (App e a)
208   = case (floatExpr env lvl e) of { (fse, floats_e, e') ->
209     case (floatRhs env lvl a) of { (fsa, floats_a, a') ->
210     (fse `add_stats` fsa, floats_e ++ floats_a, App e' a') }}
211
212 floatExpr env lvl (Lam (tv,incd_lvl) e)
213   | isTyVar tv
214   = case (floatExpr env incd_lvl e) of { (fs, floats, e') ->
215
216         -- Dump any bindings which absolutely cannot go any further
217     case (partitionByLevel incd_lvl floats)     of { (floats', heres) ->
218
219     (fs, floats', Lam tv (install heres e'))
220     }}
221
222 floatExpr env lvl (Lam (arg,incd_lvl) rhs)
223   = ASSERT( isId arg )
224     let
225         new_env  = extendVarEnv env arg incd_lvl
226     in
227     case (floatExpr new_env incd_lvl rhs) of { (fs, floats, rhs') ->
228
229         -- Dump any bindings which absolutely cannot go any further
230     case (partitionByLevel incd_lvl floats)     of { (floats', heres) ->
231
232     (add_to_stats fs floats',
233      floats',
234      Lam arg (install heres rhs'))
235     }}
236
237 floatExpr env lvl (Note note@(SCC cc) expr)
238   = case (floatExpr env lvl expr)    of { (fs, floating_defns, expr') ->
239     let
240         -- Annotate bindings floated outwards past an scc expression
241         -- with the cc.  We mark that cc as "duplicated", though.
242
243         annotated_defns = annotate (dupifyCC cc) floating_defns
244     in
245     (fs, annotated_defns, Note note expr') }
246   where
247     annotate :: CostCentre -> FloatBinds -> FloatBinds
248
249     annotate dupd_cc defn_groups
250       = [ (level, ann_bind floater) | (level, floater) <- defn_groups ]
251       where
252         ann_bind (NonRec binder rhs)
253           = NonRec binder (ann_rhs rhs)
254
255         ann_bind (Rec pairs)
256           = Rec [(binder, ann_rhs rhs) | (binder, rhs) <- pairs]
257
258         ann_rhs (Lam arg e)     = Lam arg (ann_rhs e)
259         ann_rhs rhs@(Con con _) | isWHNFCon con = rhs   -- no point in scc'ing WHNF data
260         ann_rhs rhs             = Note (SCC dupd_cc) rhs
261
262         -- Note: Nested SCC's are preserved for the benefit of
263         --       cost centre stack profiling (Durham)
264
265 -- At one time I tried the effect of not float anything out of an InlineMe,
266 -- but it sometimes works badly.  For example, consider PrelArr.done.  It
267 -- has the form         __inline (\d. e)
268 -- where e doesn't mention d.  If we float this to 
269 --      __inline (let x = e in \d. x)
270 -- things are bad.  The inliner doesn't even inline it because it doesn't look
271 -- like a head-normal form.  So it seems a lesser evil to let things float.
272 -- In SetLevels we do set the context to (Level 0 0) when we get to an InlineMe
273 -- which discourages floating out.
274
275 floatExpr env lvl (Note note expr)      -- Other than SCCs
276   = case (floatExpr env lvl expr)    of { (fs, floating_defns, expr') ->
277     (fs, floating_defns, Note note expr') }
278
279 floatExpr env lvl (Let bind body)
280   = case (floatBind env     lvl bind) of { (fsb, rhs_floats, bind', new_env) ->
281     case (floatExpr new_env lvl body) of { (fse, body_floats, body') ->
282     (add_stats fsb fse,
283      rhs_floats ++ [(bind_lvl, bind')] ++ body_floats,
284      body')
285     }}
286   where
287     bind_lvl = getBindLevel bind
288
289 floatExpr env lvl (Case scrut (case_bndr, case_lvl) alts)
290   = case floatExpr env lvl scrut        of { (fse, fde, scrut') ->
291     case floatList float_alt alts       of { (fsa, fda, alts')  ->
292     (add_stats fse fsa, fda ++ fde, Case scrut' case_bndr alts')
293     }}
294   where
295       alts_env = extendVarEnv env case_bndr case_lvl
296
297       partition_fn = partitionByMajorLevel
298
299       float_alt (con, bs, rhs)
300         = let
301               bs' = map fst bs
302               new_env = extendVarEnvList alts_env bs
303           in
304           case (floatExpr new_env case_lvl rhs)         of { (fs, rhs_floats, rhs') ->
305           case (partition_fn case_lvl rhs_floats)       of { (rhs_floats', heres) ->
306           (fs, rhs_floats', (con, bs', install heres rhs')) }}
307
308
309 floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
310 floatList f [] = (zeroStats, [], [])
311 floatList f (a:as) = case f a            of { (fs_a,  binds_a,  b)  ->
312                      case floatList f as of { (fs_as, binds_as, bs) ->
313                      (fs_a `add_stats` fs_as, binds_a ++ binds_as, b:bs) }}
314 \end{code}
315
316 %************************************************************************
317 %*                                                                      *
318 \subsection{Utility bits for floating stats}
319 %*                                                                      *
320 %************************************************************************
321
322 I didn't implement this with unboxed numbers.  I don't want to be too
323 strict in this stuff, as it is rarely turned on.  (WDP 95/09)
324
325 \begin{code}
326 data FloatStats
327   = FlS Int  -- Number of top-floats * lambda groups they've been past
328         Int  -- Number of non-top-floats * lambda groups they've been past
329         Int  -- Number of lambda (groups) seen
330
331 get_stats (FlS a b c) = (a, b, c)
332
333 zeroStats = FlS 0 0 0
334
335 sum_stats xs = foldr add_stats zeroStats xs
336
337 add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
338   = FlS (a1 + a2) (b1 + b2) (c1 + c2)
339
340 add_to_stats (FlS a b c) floats
341   = FlS (a + length top_floats) (b + length other_floats) (c + 1)
342   where
343     (top_floats, other_floats) = partition to_very_top floats
344
345     to_very_top (my_lvl, _) = isTopLvl my_lvl
346 \end{code}
347
348
349 %************************************************************************
350 %*                                                                      *
351 \subsection{Utility bits for floating}
352 %*                                                                      *
353 %************************************************************************
354
355 \begin{code}
356 getBindLevel (NonRec (_, lvl) _)      = lvl
357 getBindLevel (Rec (((_,lvl), _) : _)) = lvl
358 \end{code}
359
360 \begin{code}
361 partitionByMajorLevel, partitionByLevel
362         :: Level                -- Partitioning level
363
364         -> FloatBinds           -- Defns to be divided into 2 piles...
365
366         -> (FloatBinds, -- Defns  with level strictly < partition level,
367             FloatBinds) -- The rest
368
369
370 partitionByMajorLevel ctxt_lvl defns
371   = partition float_further defns
372   where
373         -- Float it if we escape a value lambda, 
374         -- or if we get to the top level
375     float_further (my_lvl, bind) = my_lvl `ltMajLvl` ctxt_lvl || isTopLvl my_lvl
376         -- The isTopLvl part says that if we can get to the top level, say "yes" anyway
377         -- This means that 
378         --      x = f e
379         -- transforms to 
380         --    lvl = e
381         --    x = f lvl
382         -- which is as it should be
383
384 partitionByLevel ctxt_lvl defns
385   = partition float_further defns
386   where
387     float_further (my_lvl, _) = my_lvl `ltLvl` ctxt_lvl
388 \end{code}
389
390 \begin{code}
391 floatsToBinds :: FloatBinds -> [CoreBind]
392 floatsToBinds floats = map snd floats
393
394 floatsToBindPairs :: FloatBinds -> [(Id,CoreExpr)]
395
396 floatsToBindPairs floats = concat (map mk_pairs floats)
397   where
398    mk_pairs (_, Rec pairs)         = pairs
399    mk_pairs (_, NonRec binder rhs) = [(binder,rhs)]
400
401 install :: FloatBinds -> CoreExpr -> CoreExpr
402
403 install defn_groups expr
404   = foldr install_group expr defn_groups
405   where
406     install_group (_, defns) body = Let defns body
407 \end{code}