[project @ 1999-07-06 16:45:31 by simonpj]
[ghc-hetmet.git] / ghc / compiler / simplCore / FloatOut.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
3 %
4 \section[FloatOut]{Float bindings outwards (towards the top level)}
5
6 ``Long-distance'' floating of bindings towards the top level.
7
8 \begin{code}
9 module FloatOut ( floatOutwards ) where
10
11 #include "HsVersions.h"
12
13 import CoreSyn
14
15 import CmdLineOpts      ( opt_D_verbose_core2core, opt_D_dump_simpl_stats )
16 import ErrUtils         ( dumpIfSet )
17 import CostCentre       ( dupifyCC, CostCentre )
18 import Id               ( Id )
19 import Const            ( isWHNFCon )
20 import VarEnv
21 import CoreLint         ( beginPass, endPass )
22 import PprCore
23 import SetLevels        ( setLevels,
24                           Level(..), tOP_LEVEL, ltMajLvl, ltLvl, isTopLvl
25                         )
26 import BasicTypes       ( Unused )
27 import Var              ( TyVar )
28 import UniqSupply       ( UniqSupply )
29 import List             ( partition )
30 import Outputable
31 \end{code}
32
33 Random comments
34 ~~~~~~~~~~~~~~~
35
36 At the moment we never float a binding out to between two adjacent
37 lambdas.  For example:
38
39 @
40         \x y -> let t = x+x in ...
41 ===>
42         \x -> let t = x+x in \y -> ...
43 @
44 Reason: this is less efficient in the case where the original lambda
45 is never partially applied.
46
47 But there's a case I've seen where this might not be true.  Consider:
48 @
49 elEm2 x ys
50   = elem' x ys
51   where
52     elem' _ []  = False
53     elem' x (y:ys)      = x==y || elem' x ys
54 @
55 It turns out that this generates a subexpression of the form
56 @
57         \deq x ys -> let eq = eqFromEqDict deq in ...
58 @
59 vwhich might usefully be separated to
60 @
61         \deq -> let eq = eqFromEqDict deq in \xy -> ...
62 @
63 Well, maybe.  We don't do this at the moment.
64
65 \begin{code}
66 type LevelledExpr  = TaggedExpr Level
67 type LevelledBind  = TaggedBind Level
68 type FloatBind     = (Level, CoreBind)
69 type FloatBinds    = [FloatBind]
70 \end{code}
71
72 %************************************************************************
73 %*                                                                      *
74 \subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
75 %*                                                                      *
76 %************************************************************************
77
78 \begin{code}
79 floatOutwards :: UniqSupply -> [CoreBind] -> IO [CoreBind]
80
81 floatOutwards us pgm
82   = do {
83         beginPass "Float out";
84
85         let { annotated_w_levels = setLevels pgm us ;
86               (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)
87             } ;
88
89         dumpIfSet opt_D_verbose_core2core "Levels added:"
90                   (vcat (map ppr annotated_w_levels));
91
92         let { (tlets, ntlets, lams) = get_stats (sum_stats fss) };
93
94         dumpIfSet opt_D_dump_simpl_stats "FloatOut stats:"
95                 (hcat [ int tlets,  ptext SLIT(" Lets floated to top level; "),
96                         int ntlets, ptext SLIT(" Lets floated elsewhere; from "),
97                         int lams,   ptext SLIT(" Lambda groups")]);
98
99         endPass "Float out" 
100                 opt_D_verbose_core2core         {- no specific flag for dumping float-out -} 
101                 (concat binds_s')
102     }
103
104 floatTopBind bind@(NonRec _ _)
105   = case (floatBind emptyVarEnv tOP_LEVEL bind) of { (fs, floats, bind', _) ->
106     (fs, floatsToBinds floats ++ [bind'])
107     }
108
109 floatTopBind bind@(Rec _)
110   = case (floatBind emptyVarEnv tOP_LEVEL bind) of { (fs, floats, Rec pairs', _) ->
111         -- Actually floats will be empty
112     --false:ASSERT(null floats)
113     (fs, [Rec (floatsToBindPairs floats ++ pairs')])
114     }
115 \end{code}
116
117 %************************************************************************
118 %*                                                                      *
119 \subsection[FloatOut-Bind]{Floating in a binding (the business end)}
120 %*                                                                      *
121 %************************************************************************
122
123
124 \begin{code}
125 floatBind :: IdEnv Level
126           -> Level
127           -> LevelledBind
128           -> (FloatStats, FloatBinds, CoreBind, IdEnv Level)
129
130 floatBind env lvl (NonRec (name,level) rhs)
131   = case (floatRhs env level rhs) of { (fs, rhs_floats, rhs') ->
132     (fs, rhs_floats,
133      NonRec name rhs',
134      extendVarEnv env name level)
135     }
136
137 floatBind env lvl bind@(Rec pairs)
138   = case (unzip3 (map do_pair pairs)) of { (fss, rhss_floats, new_pairs) ->
139
140     if not (isTopLvl bind_level) then
141         -- Standard case
142         (sum_stats fss, concat rhss_floats, Rec new_pairs, new_env)
143     else
144         {- In a recursive binding, destined for the top level (only),
145            the rhs floats may contain
146            references to the bound things.  For example
147
148                 f = ...(let v = ...f... in b) ...
149
150            might get floated to
151
152                 v = ...f...
153                 f = ... b ...
154
155            and hence we must (pessimistically) make all the floats recursive
156            with the top binding.  Later dependency analysis will unravel it.
157         -}
158
159         (sum_stats fss,
160          [],
161          Rec (new_pairs ++ floatsToBindPairs (concat rhss_floats)),
162          new_env)
163
164     }
165   where
166     new_env = extendVarEnvList env (map fst pairs)
167
168     bind_level = getBindLevel bind
169
170     do_pair ((name, level), rhs)
171       = case (floatRhs new_env level rhs) of { (fs, rhs_floats, rhs') ->
172         (fs, rhs_floats, (name, rhs'))
173         }
174 \end{code}
175
176 %************************************************************************
177
178 \subsection[FloatOut-Expr]{Floating in expressions}
179 %*                                                                      *
180 %************************************************************************
181
182 \begin{code}
183 floatExpr, floatRhs
184          :: IdEnv Level
185          -> Level
186          -> LevelledExpr
187          -> (FloatStats, FloatBinds, CoreExpr)
188
189 floatRhs env lvl arg
190   = case (floatExpr env lvl arg) of { (fsa, floats, arg') ->
191     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
192         -- Dump bindings that aren't going to escape from a lambda
193         -- This is to avoid floating the x binding out of
194         --      f (let x = e in b)
195         -- unnecessarily.  It even causes a bug to do so if we have
196         --      y = writeArr# a n (let x = e in b)
197         -- because the y binding is an expr-ok-for-speculation one.
198     (fsa, floats', install heres arg') }}
199
200 floatExpr env _ (Var v)      = (zeroStats, [], Var v)
201 floatExpr env _ (Type ty)    = (zeroStats, [], Type ty)
202 floatExpr env lvl (Con con as) 
203   = case floatList (floatRhs env lvl) as of { (stats, floats, as') ->
204     (stats, floats, Con con as') }
205           
206 floatExpr env lvl (App e a)
207   = case (floatExpr env lvl e) of { (fse, floats_e, e') ->
208     case (floatRhs env lvl a) of { (fsa, floats_a, a') ->
209     (fse `add_stats` fsa, floats_e ++ floats_a, App e' a') }}
210
211 floatExpr env lvl (Lam (tv,incd_lvl) e)
212   | isTyVar tv
213   = case (floatExpr env incd_lvl e) of { (fs, floats, e') ->
214
215         -- Dump any bindings which absolutely cannot go any further
216     case (partitionByLevel incd_lvl floats)     of { (floats', heres) ->
217
218     (fs, floats', Lam tv (install heres e'))
219     }}
220
221 floatExpr env lvl (Lam (arg,incd_lvl) rhs)
222   = ASSERT( isId arg )
223     let
224         new_env  = extendVarEnv env arg incd_lvl
225     in
226     case (floatExpr new_env incd_lvl rhs) of { (fs, floats, rhs') ->
227
228         -- Dump any bindings which absolutely cannot go any further
229     case (partitionByLevel incd_lvl floats)     of { (floats', heres) ->
230
231     (add_to_stats fs floats',
232      floats',
233      Lam arg (install heres rhs'))
234     }}
235
236 floatExpr env lvl (Note note@(SCC cc) expr)
237   = case (floatExpr env lvl expr)    of { (fs, floating_defns, expr') ->
238     let
239         -- Annotate bindings floated outwards past an scc expression
240         -- with the cc.  We mark that cc as "duplicated", though.
241
242         annotated_defns = annotate (dupifyCC cc) floating_defns
243     in
244     (fs, annotated_defns, Note note expr') }
245   where
246     annotate :: CostCentre -> FloatBinds -> FloatBinds
247
248     annotate dupd_cc defn_groups
249       = [ (level, ann_bind floater) | (level, floater) <- defn_groups ]
250       where
251         ann_bind (NonRec binder rhs)
252           = NonRec binder (ann_rhs rhs)
253
254         ann_bind (Rec pairs)
255           = Rec [(binder, ann_rhs rhs) | (binder, rhs) <- pairs]
256
257         ann_rhs (Lam arg e)     = Lam arg (ann_rhs e)
258         ann_rhs rhs@(Con con _) | isWHNFCon con = rhs   -- no point in scc'ing WHNF data
259         ann_rhs rhs             = Note (SCC dupd_cc) rhs
260
261         -- Note: Nested SCC's are preserved for the benefit of
262         --       cost centre stack profiling (Durham)
263
264 floatExpr env lvl (Note note expr)      -- Other than SCCs
265   = case (floatExpr env lvl expr)    of { (fs, floating_defns, expr') ->
266     (fs, floating_defns, Note note expr') }
267
268 floatExpr env lvl (Let bind body)
269   = case (floatBind env     lvl bind) of { (fsb, rhs_floats, bind', new_env) ->
270     case (floatExpr new_env lvl body) of { (fse, body_floats, body') ->
271     (add_stats fsb fse,
272      rhs_floats ++ [(bind_lvl, bind')] ++ body_floats,
273      body')
274     }}
275   where
276     bind_lvl = getBindLevel bind
277
278 floatExpr env lvl (Case scrut (case_bndr, case_lvl) alts)
279   = case floatExpr env lvl scrut        of { (fse, fde, scrut') ->
280     case floatList float_alt alts       of { (fsa, fda, alts')  ->
281     (add_stats fse fsa, fda ++ fde, Case scrut' case_bndr alts')
282     }}
283   where
284       alts_env = extendVarEnv env case_bndr case_lvl
285
286       partition_fn = partitionByMajorLevel
287
288       float_alt (con, bs, rhs)
289         = let
290               bs' = map fst bs
291               new_env = extendVarEnvList alts_env bs
292           in
293           case (floatExpr new_env case_lvl rhs)         of { (fs, rhs_floats, rhs') ->
294           case (partition_fn case_lvl rhs_floats)       of { (rhs_floats', heres) ->
295           (fs, rhs_floats', (con, bs', install heres rhs')) }}
296
297
298 floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
299 floatList f [] = (zeroStats, [], [])
300 floatList f (a:as) = case f a            of { (fs_a,  binds_a,  b)  ->
301                      case floatList f as of { (fs_as, binds_as, bs) ->
302                      (fs_a `add_stats` fs_as, binds_a ++ binds_as, b:bs) }}
303 \end{code}
304
305 %************************************************************************
306 %*                                                                      *
307 \subsection{Utility bits for floating stats}
308 %*                                                                      *
309 %************************************************************************
310
311 I didn't implement this with unboxed numbers.  I don't want to be too
312 strict in this stuff, as it is rarely turned on.  (WDP 95/09)
313
314 \begin{code}
315 data FloatStats
316   = FlS Int  -- Number of top-floats * lambda groups they've been past
317         Int  -- Number of non-top-floats * lambda groups they've been past
318         Int  -- Number of lambda (groups) seen
319
320 get_stats (FlS a b c) = (a, b, c)
321
322 zeroStats = FlS 0 0 0
323
324 sum_stats xs = foldr add_stats zeroStats xs
325
326 add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
327   = FlS (a1 + a2) (b1 + b2) (c1 + c2)
328
329 add_to_stats (FlS a b c) floats
330   = FlS (a + length top_floats) (b + length other_floats) (c + 1)
331   where
332     (top_floats, other_floats) = partition to_very_top floats
333
334     to_very_top (my_lvl, _) = isTopLvl my_lvl
335 \end{code}
336
337
338 %************************************************************************
339 %*                                                                      *
340 \subsection{Utility bits for floating}
341 %*                                                                      *
342 %************************************************************************
343
344 \begin{code}
345 getBindLevel (NonRec (_, lvl) _)      = lvl
346 getBindLevel (Rec (((_,lvl), _) : _)) = lvl
347 \end{code}
348
349 \begin{code}
350 partitionByMajorLevel, partitionByLevel
351         :: Level                -- Partitioning level
352
353         -> FloatBinds           -- Defns to be divided into 2 piles...
354
355         -> (FloatBinds, -- Defns  with level strictly < partition level,
356             FloatBinds) -- The rest
357
358
359 partitionByMajorLevel ctxt_lvl defns
360   = partition float_further defns
361   where
362     float_further (my_lvl, _) = my_lvl `lt_major` ctxt_lvl
363
364 my_lvl `lt_major`  ctxt_lvl = my_lvl `ltMajLvl` ctxt_lvl ||
365                               isTopLvl my_lvl
366
367 partitionByLevel ctxt_lvl defns
368   = partition float_further defns
369   where
370     float_further (my_lvl, _) = my_lvl `ltLvl` ctxt_lvl
371 \end{code}
372
373 \begin{code}
374 floatsToBinds :: FloatBinds -> [CoreBind]
375 floatsToBinds floats = map snd floats
376
377 floatsToBindPairs :: FloatBinds -> [(Id,CoreExpr)]
378
379 floatsToBindPairs floats = concat (map mk_pairs floats)
380   where
381    mk_pairs (_, Rec pairs)         = pairs
382    mk_pairs (_, NonRec binder rhs) = [(binder,rhs)]
383
384 install :: FloatBinds -> CoreExpr -> CoreExpr
385
386 install defn_groups expr
387   = foldr install_group expr defn_groups
388   where
389     install_group (_, defns) body = Let defns body
390 \end{code}