[project @ 1998-01-08 18:03:08 by simonm]
[ghc-hetmet.git] / ghc / compiler / simplCore / FloatOut.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1996
3 %
4 \section[FloatOut]{Float bindings outwards (towards the top level)}
5
6 ``Long-distance'' floating of bindings towards the top level.
7
8 \begin{code}
9 module FloatOut ( floatOutwards ) where
10
11 #include "HsVersions.h"
12
13 import CoreSyn
14
15 import CmdLineOpts      ( opt_D_verbose_core2core, opt_D_simplifier_stats )
16 import CostCentre       ( dupifyCC, CostCentre )
17 import Id               ( nullIdEnv, addOneToIdEnv, growIdEnvList, IdEnv,
18                           GenId{-instance Outputable-}, Id
19                         )
20 import PprCore
21 import PprType          ( GenTyVar )
22 import SetLevels        -- all of it
23 import BasicTypes       ( Unused )
24 import TyVar            ( GenTyVar{-instance Eq-}, TyVar )
25 import Unique           ( Unique{-instance Eq-} )
26 import UniqSupply       ( UniqSupply )
27 import List             ( partition )
28 import Outputable
29 \end{code}
30
31 Random comments
32 ~~~~~~~~~~~~~~~
33
34 At the moment we never float a binding out to between two adjacent
35 lambdas.  For example:
36
37 @
38         \x y -> let t = x+x in ...
39 ===>
40         \x -> let t = x+x in \y -> ...
41 @
42 Reason: this is less efficient in the case where the original lambda
43 is never partially applied.
44
45 But there's a case I've seen where this might not be true.  Consider:
46 @
47 elEm2 x ys
48   = elem' x ys
49   where
50     elem' _ []  = False
51     elem' x (y:ys)      = x==y || elem' x ys
52 @
53 It turns out that this generates a subexpression of the form
54 @
55         \deq x ys -> let eq = eqFromEqDict deq in ...
56 @
57 which might usefully be separated to
58 @
59         \deq -> let eq = eqFromEqDict deq in \xy -> ...
60 @
61 Well, maybe.  We don't do this at the moment.
62
63 \begin{code}
64 type LevelledExpr  = GenCoreExpr    (Id, Level) Id Unused
65 type LevelledBind  = GenCoreBinding (Id, Level) Id Unused
66 type FloatingBind  = (Level, Floater)
67 type FloatingBinds = [FloatingBind]
68
69 data Floater
70   = LetFloater  CoreBinding
71   | CaseFloater (CoreExpr -> CoreExpr)
72                 -- A CoreExpr with a hole in it:
73                 -- "Give me a right-hand side of the
74                 -- (usually single) alternative, and
75                 -- I'll build the case..."
76 \end{code}
77
78 %************************************************************************
79 %*                                                                      *
80 \subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
81 %*                                                                      *
82 %************************************************************************
83
84 \begin{code}
85 floatOutwards :: UniqSupply -> [CoreBinding] -> [CoreBinding]
86
87 floatOutwards us pgm
88   = case (setLevels pgm us) of { annotated_w_levels ->
89
90     case (unzip (map floatTopBind annotated_w_levels))
91                 of { (fss, final_toplev_binds_s) ->
92
93     (if opt_D_verbose_core2core
94      then pprTrace "Levels added:\n"
95                    (vcat (map (ppr) annotated_w_levels))
96      else id
97     )
98     ( if not (opt_D_simplifier_stats) then
99          id
100       else
101          let
102             (tlets, ntlets, lams) = get_stats (sum_stats fss)
103          in
104          pprTrace "FloatOut stats: " (hcat [
105                 int tlets,  ptext SLIT(" Lets floated to top level; "),
106                 int ntlets, ptext SLIT(" Lets floated elsewhere; from "),
107                 int lams,   ptext SLIT(" Lambda groups")])
108     )
109     concat final_toplev_binds_s
110     }}
111
112 floatTopBind bind@(NonRec _ _)
113   = case (floatBind nullIdEnv tOP_LEVEL bind) of { (fs, floats, bind', _) ->
114     (fs, floatsToBinds floats ++ [bind'])
115     }
116
117 floatTopBind bind@(Rec _)
118   = case (floatBind nullIdEnv tOP_LEVEL bind) of { (fs, floats, Rec pairs', _) ->
119         -- Actually floats will be empty
120     --false:ASSERT(null floats)
121     (fs, [Rec (floatsToBindPairs floats ++ pairs')])
122     }
123 \end{code}
124
125 %************************************************************************
126 %*                                                                      *
127 \subsection[FloatOut-Bind]{Floating in a binding (the business end)}
128 %*                                                                      *
129 %************************************************************************
130
131
132 \begin{code}
133 floatBind :: IdEnv Level
134           -> Level
135           -> LevelledBind
136           -> (FloatStats, FloatingBinds, CoreBinding, IdEnv Level)
137
138 floatBind env lvl (NonRec (name,level) rhs)
139   = case (floatExpr env level rhs) of { (fs, rhs_floats, rhs') ->
140
141         -- A good dumping point
142     case (partitionByMajorLevel level rhs_floats) of { (rhs_floats', heres) ->
143
144     (fs, rhs_floats',
145      NonRec name (install heres rhs'),
146      addOneToIdEnv env name level)
147     }}
148
149 floatBind env lvl bind@(Rec pairs)
150   = case (unzip3 (map do_pair pairs)) of { (fss, rhss_floats, new_pairs) ->
151
152     if not (isTopLvl bind_level) then
153         -- Standard case
154         (sum_stats fss, concat rhss_floats, Rec new_pairs, new_env)
155     else
156         {- In a recursive binding, destined for the top level (only),
157            the rhs floats may contain
158            references to the bound things.  For example
159
160                 f = ...(let v = ...f... in b) ...
161
162            might get floated to
163
164                 v = ...f...
165                 f = ... b ...
166
167            and hence we must (pessimistically) make all the floats recursive
168            with the top binding.  Later dependency analysis will unravel it.
169         -}
170
171         (sum_stats fss,
172          [],
173          Rec (new_pairs ++ floatsToBindPairs (concat rhss_floats)),
174          new_env)
175
176     }
177   where
178     new_env = growIdEnvList env (map fst pairs)
179
180     bind_level = getBindLevel bind
181
182     do_pair ((name, level), rhs)
183       = case (floatExpr new_env level rhs) of { (fs, rhs_floats, rhs') ->
184
185                 -- A good dumping point
186         case (partitionByMajorLevel level rhs_floats) of { (rhs_floats', heres) ->
187
188         (fs, rhs_floats', (name, install heres rhs'))
189         }}
190 \end{code}
191
192 %************************************************************************
193
194 \subsection[FloatOut-Expr]{Floating in expressions}
195 %*                                                                      *
196 %************************************************************************
197
198 \begin{code}
199 floatExpr :: IdEnv Level
200           -> Level
201           -> LevelledExpr
202           -> (FloatStats, FloatingBinds, CoreExpr)
203
204 floatExpr env _ (Var v)      = (zero_stats, [], Var v)
205 floatExpr env _ (Lit l)      = (zero_stats, [], Lit l)
206 floatExpr env _ (Prim op as) = (zero_stats, [], Prim op as)
207 floatExpr env _ (Con con as) = (zero_stats, [], Con con as)
208           
209 floatExpr env lvl (App e a)
210   = case (floatExpr env lvl e) of { (fs, floating_defns, e') ->
211     (fs, floating_defns, App e' a) }
212
213 floatExpr env lvl (Lam (TyBinder tv) e)
214   = let
215         incd_lvl = incMinorLvl lvl
216     in
217     case (floatExpr env incd_lvl e) of { (fs, floats, e') ->
218
219         -- Dump any bindings which absolutely cannot go any further
220     case (partitionByLevel incd_lvl floats)     of { (floats', heres) ->
221
222     (fs, floats', Lam (TyBinder tv) (install heres e'))
223     }}
224
225 floatExpr env lvl (Lam (ValBinder (arg,incd_lvl)) rhs)
226   = let
227         new_env  = addOneToIdEnv env arg incd_lvl
228     in
229     case (floatExpr new_env incd_lvl rhs) of { (fs, floats, rhs') ->
230
231         -- Dump any bindings which absolutely cannot go any further
232     case (partitionByLevel incd_lvl floats)     of { (floats', heres) ->
233
234     (add_to_stats fs floats',
235      floats',
236      Lam (ValBinder arg) (install heres rhs'))
237     }}
238
239 floatExpr env lvl (SCC cc expr)
240   = case (floatExpr env lvl expr)    of { (fs, floating_defns, expr') ->
241     let
242         -- annotate bindings floated outwards past an scc expression
243         -- with the cc.  We mark that cc as "duplicated", though.
244
245         annotated_defns = annotate (dupifyCC cc) floating_defns
246     in
247     (fs, annotated_defns, SCC cc expr') }
248   where
249     annotate :: CostCentre -> FloatingBinds -> FloatingBinds
250
251     annotate dupd_cc defn_groups
252       = [ (level, ann_bind floater) | (level, floater) <- defn_groups ]
253       where
254         ann_bind (LetFloater (NonRec binder rhs))
255           = LetFloater (NonRec binder (ann_rhs rhs))
256
257         ann_bind (LetFloater (Rec pairs))
258           = LetFloater (Rec [(binder, ann_rhs rhs) | (binder, rhs) <- pairs])
259
260         ann_bind (CaseFloater fn) = CaseFloater ( \ rhs -> SCC dupd_cc (fn rhs) )
261
262         ann_rhs (Lam arg e)   = Lam arg (ann_rhs e)
263         ann_rhs rhs@(Con _ _) = rhs     -- no point in scc'ing WHNF data
264         ann_rhs rhs           = SCC dupd_cc rhs
265
266         -- Note: Nested SCC's are preserved for the benefit of
267         --       cost centre stack profiling (Durham)
268
269 floatExpr env lvl (Coerce c ty expr)
270   = case (floatExpr env lvl expr)    of { (fs, floating_defns, expr') ->
271     (fs, floating_defns, Coerce c ty expr') }
272
273 floatExpr env lvl (Let bind body)
274   = case (floatBind env     lvl bind) of { (fsb, rhs_floats, bind', new_env) ->
275     case (floatExpr new_env lvl body) of { (fse, body_floats, body') ->
276     (add_stats fsb fse,
277      rhs_floats ++ [(bind_lvl, LetFloater bind')] ++ body_floats,
278      body')
279     }}
280   where
281     bind_lvl = getBindLevel bind
282
283 floatExpr env lvl (Case scrut alts)
284   = case (floatExpr env lvl scrut) of { (fse, fde, scrut') ->
285
286     case (scrut', float_alts alts) of
287         (_, (fsa, fda, alts')) ->
288                 (add_stats fse fsa, fda ++ fde, Case scrut' alts')
289     }
290     {-  OLD CASE-FLOATING CODE: DROPPED FOR NOW.  (SLPJ 7/2/94)
291
292         (Var scrut_var, (fda, AlgAlts [(con,bs,rhs')] NoDefault))
293                 | scrut_var_lvl `ltMajLvl` lvl ->
294
295                 -- Candidate for case floater; scrutinising a variable; it can
296                 -- escape outside a lambda; there's only one alternative.
297                 (fda ++ fde ++ [case_floater], rhs')
298
299                 where
300                 case_floater = (scrut_var_lvl, CaseFloater fn)
301                 fn body = Case scrut' (AlgAlts [(con,bs,body)] NoDefault)
302                 scrut_var_lvl = case lookupIdEnv env scrut_var of
303                                   Nothing  -> Level 0 0
304                                   Just lvl -> unTopify lvl
305
306     END OF CASE FLOATING DROPPED -}
307   where
308       incd_lvl = incMinorLvl lvl
309
310       partition_fn = partitionByMajorLevel
311
312 {-      OMITTED
313         We don't want to be too keen about floating lets out of case alternatives
314         because they may benefit from seeing the evaluation done by the case.
315
316         The main reason for doing this is to allocate in fewer larger blocks
317         but that's really an STG-level issue.
318
319                         case alts of
320                                 -- Just one alternative, then dump only
321                                 -- what *has* to be dumped
322                         AlgAlts  [_] NoDefault     -> partitionByLevel
323                         AlgAlts  []  (BindDefault _ _) -> partitionByLevel
324                         PrimAlts [_] NoDefault     -> partitionByLevel
325                         PrimAlts []  (BindDefault _ _) -> partitionByLevel
326
327                                 -- If there's more than one alternative, then
328                                 -- this is a dumping point
329                         other                              -> partitionByMajorLevel
330 -}
331
332       float_alts (AlgAlts alts deflt)
333         = case (float_deflt  deflt)              of { (fsd,  fdd,  deflt') ->
334           case (unzip3 (map float_alg_alt alts)) of { (fsas, fdas, alts') ->
335           (foldr add_stats fsd fsas,
336            concat fdas ++ fdd,
337            AlgAlts alts' deflt') }}
338
339       float_alts (PrimAlts alts deflt)
340         = case (float_deflt deflt)                of { (fsd,   fdd, deflt') ->
341           case (unzip3 (map float_prim_alt alts)) of { (fsas, fdas, alts') ->
342           (foldr add_stats fsd fsas,
343            concat fdas ++ fdd,
344            PrimAlts alts' deflt') }}
345
346       -------------
347       float_alg_alt (con, bs, rhs)
348         = let
349               bs' = map fst bs
350               new_env = growIdEnvList env bs
351           in
352           case (floatExpr new_env incd_lvl rhs) of { (fs, rhs_floats, rhs') ->
353           case (partition_fn incd_lvl rhs_floats)       of { (rhs_floats', heres) ->
354           (fs, rhs_floats', (con, bs', install heres rhs')) }}
355
356       --------------
357       float_prim_alt (lit, rhs)
358         = case (floatExpr env incd_lvl rhs)             of { (fs, rhs_floats, rhs') ->
359           case (partition_fn incd_lvl rhs_floats)       of { (rhs_floats', heres) ->
360           (fs, rhs_floats', (lit, install heres rhs')) }}
361
362       --------------
363       float_deflt NoDefault = (zero_stats, [], NoDefault)
364
365       float_deflt (BindDefault (b,lvl) rhs)
366         = case (floatExpr new_env lvl rhs)              of { (fs, rhs_floats, rhs') ->
367           case (partition_fn incd_lvl rhs_floats)       of { (rhs_floats', heres) ->
368           (fs, rhs_floats', BindDefault b (install heres rhs')) }}
369         where
370           new_env = addOneToIdEnv env b lvl
371 \end{code}
372
373 %************************************************************************
374 %*                                                                      *
375 \subsection{Utility bits for floating stats}
376 %*                                                                      *
377 %************************************************************************
378
379 I didn't implement this with unboxed numbers.  I don't want to be too
380 strict in this stuff, as it is rarely turned on.  (WDP 95/09)
381
382 \begin{code}
383 data FloatStats
384   = FlS Int  -- Number of top-floats * lambda groups they've been past
385         Int  -- Number of non-top-floats * lambda groups they've been past
386         Int  -- Number of lambda (groups) seen
387
388 get_stats (FlS a b c) = (a, b, c)
389
390 zero_stats = FlS 0 0 0
391
392 sum_stats xs = foldr add_stats zero_stats xs
393
394 add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
395   = FlS (a1 + a2) (b1 + b2) (c1 + c2)
396
397 add_to_stats (FlS a b c) floats
398   = FlS (a + length top_floats) (b + length other_floats) (c + 1)
399   where
400     (top_floats, other_floats) = partition to_very_top floats
401
402     to_very_top (my_lvl, _) = isTopLvl my_lvl
403 \end{code}
404
405 %************************************************************************
406 %*                                                                      *
407 \subsection{Utility bits for floating}
408 %*                                                                      *
409 %************************************************************************
410
411 \begin{code}
412 getBindLevel (NonRec (_, lvl) _)      = lvl
413 getBindLevel (Rec (((_,lvl), _) : _)) = lvl
414 \end{code}
415
416 \begin{code}
417 partitionByMajorLevel, partitionByLevel
418         :: Level                -- Partitioning level
419
420         -> FloatingBinds        -- Defns to be divided into 2 piles...
421
422         -> (FloatingBinds,      -- Defns  with level strictly < partition level,
423             FloatingBinds)      -- The rest
424
425
426 partitionByMajorLevel ctxt_lvl defns
427   = partition float_further defns
428   where
429     float_further (my_lvl, _) = my_lvl `ltMajLvl` ctxt_lvl ||
430                                 isTopLvl my_lvl
431
432 partitionByLevel ctxt_lvl defns
433   = partition float_further defns
434   where
435     float_further (my_lvl, _) = my_lvl `ltLvl` ctxt_lvl
436 \end{code}
437
438 \begin{code}
439 floatsToBinds :: FloatingBinds -> [CoreBinding]
440 floatsToBinds floats = map get_bind floats
441                      where
442                        get_bind (_, LetFloater bind) = bind
443                        get_bind (_, CaseFloater _)   = panic "floatsToBinds"
444
445 floatsToBindPairs :: FloatingBinds -> [(Id,CoreExpr)]
446
447 floatsToBindPairs floats = concat (map mk_pairs floats)
448   where
449    mk_pairs (_, LetFloater (Rec pairs))         = pairs
450    mk_pairs (_, LetFloater (NonRec binder rhs)) = [(binder,rhs)]
451    mk_pairs (_, CaseFloater _)                    = panic "floatsToBindPairs"
452
453 install :: FloatingBinds -> CoreExpr -> CoreExpr
454
455 install defn_groups expr
456   = foldr install_group expr defn_groups
457   where
458     install_group (_, LetFloater defns) body = Let defns body
459     install_group (_, CaseFloater fn)   body = fn body
460 \end{code}