9fe6b87481de6c01cd1889e7e6173559f26570d0
[ghc-hetmet.git] / compiler / simplCore / LiberateCase.lhs
1 %
2 % (c) The AQUA Project, Glasgow University, 1994-1998
3 %
4 \section[LiberateCase]{Unroll recursion to allow evals to be lifted from a loop}
5
6 \begin{code}
7 module LiberateCase ( liberateCase ) where
8
9 #include "HsVersions.h"
10
11 import DynFlags
12 import HscTypes
13 import CoreLint         ( showPass, endPass )
14 import CoreSyn
15 import CoreUnfold       ( couldBeSmallEnoughToInline )
16 import Rules            ( RuleBase )
17 import UniqSupply       ( UniqSupply )
18 import SimplMonad       ( SimplCount, zeroSimplCount )
19 import Id
20 import VarEnv
21 import Util             ( notNull )
22 \end{code}
23
24 The liberate-case transformation
25 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
26 This module walks over @Core@, and looks for @case@ on free variables.
27 The criterion is:
28         if there is case on a free on the route to the recursive call,
29         then the recursive call is replaced with an unfolding.
30
31 Example
32
33    f = \ t -> case v of
34                  V a b -> a : f t
35
36 => the inner f is replaced.
37
38    f = \ t -> case v of
39                  V a b -> a : (letrec
40                                 f =  \ t -> case v of
41                                                V a b -> a : f t
42                                in f) t
43 (note the NEED for shadowing)
44
45 => Simplify
46
47   f = \ t -> case v of
48                  V a b -> a : (letrec
49                                 f = \ t -> a : f t
50                                in f t)
51
52 Better code, because 'a' is  free inside the inner letrec, rather
53 than needing projection from v.
54
55 Note that this deals with *free variables*.  SpecConstr deals with
56 *arguments* that are of known form.  E.g.
57
58         last []     = error 
59         last (x:[]) = x
60         last (x:xs) = last xs
61
62         
63 Note [Scrutinee with cast]
64 ~~~~~~~~~~~~~~~~~~~~~~~~~~
65 Consider this:
66     f = \ t -> case (v `cast` co) of
67                  V a b -> a : f t
68
69 Exactly the same optimisation (unrolling one call to f) will work here, 
70 despite the cast.  See mk_alt_env in the Case branch of libCase.
71
72
73 Note [Only functions!]
74 ~~~~~~~~~~~~~~~~~~~~~~
75 Consider the following code
76
77        f = g (case v of V a b -> a : t f)
78
79 where g is expensive. If we aren't careful, liberate case will turn this into
80
81        f = g (case v of
82                V a b -> a : t (letrec f = g (case v of V a b -> a : f t)
83                                 in f)
84              )
85
86 Yikes! We evaluate g twice. This leads to a O(2^n) explosion
87 if g calls back to the same code recursively.
88
89 Solution: make sure that we only do the liberate-case thing on *functions*
90
91 To think about (Apr 94)
92 ~~~~~~~~~~~~~~
93 Main worry: duplicating code excessively.  At the moment we duplicate
94 the entire binding group once at each recursive call.  But there may
95 be a group of recursive calls which share a common set of evaluated
96 free variables, in which case the duplication is a plain waste.
97
98 Another thing we could consider adding is some unfold-threshold thing,
99 so that we'll only duplicate if the size of the group rhss isn't too
100 big.
101
102 Data types
103 ~~~~~~~~~~
104 The ``level'' of a binder tells how many
105 recursive defns lexically enclose the binding
106 A recursive defn "encloses" its RHS, not its
107 scope.  For example:
108 \begin{verbatim}
109         letrec f = let g = ... in ...
110         in
111         let h = ...
112         in ...
113 \end{verbatim}
114 Here, the level of @f@ is zero, the level of @g@ is one,
115 and the level of @h@ is zero (NB not one).
116
117
118 %************************************************************************
119 %*                                                                      *
120          Top-level code
121 %*                                                                      *
122 %************************************************************************
123
124 \begin{code}
125 liberateCase :: HscEnv -> UniqSupply -> RuleBase -> ModGuts
126              -> IO (SimplCount, ModGuts)
127 liberateCase hsc_env _ _ guts
128   = do  { let dflags = hsc_dflags hsc_env
129
130         ; showPass dflags "Liberate case"
131         ; let { env = initEnv dflags
132               ; binds' = do_prog env (mg_binds guts) }
133         ; endPass dflags "Liberate case" Opt_D_verbose_core2core binds'
134                         {- no specific flag for dumping -} 
135         ; return (zeroSimplCount dflags, guts { mg_binds = binds' }) }
136   where
137     do_prog _   [] = []
138     do_prog env (bind:binds) = bind' : do_prog env' binds
139                              where
140                                (env', bind') = libCaseBind env bind
141 \end{code}
142
143
144 %************************************************************************
145 %*                                                                      *
146          Main payload
147 %*                                                                      *
148 %************************************************************************
149
150 Bindings
151 ~~~~~~~~
152 \begin{code}
153 libCaseBind :: LibCaseEnv -> CoreBind -> (LibCaseEnv, CoreBind)
154
155 libCaseBind env (NonRec binder rhs)
156   = (addBinders env [binder], NonRec binder (libCase env rhs))
157
158 libCaseBind env (Rec pairs)
159   = (env_body, Rec pairs')
160   where
161     (binders, _rhss) = unzip pairs
162
163     env_body = addBinders env binders
164
165     pairs' = [(binder, libCase env_rhs rhs) | (binder,rhs) <- pairs]
166
167     env_rhs = if all rhs_small_enough pairs then extended_env else env
168
169         -- We extend the rec-env by binding each Id to its rhs, first
170         -- processing the rhs with an *un-extended* environment, so
171         -- that the same process doesn't occur for ever!
172         --
173     extended_env = addRecBinds env [ (localiseId binder, libCase env_body rhs)
174                                    | (binder, rhs) <- pairs ]
175
176         -- The call to localiseId is needed for two subtle reasons
177         -- (a)  Reset the export flags on the binders so
178         --      that we don't get name clashes on exported things if the 
179         --      local binding floats out to top level.  This is most unlikely
180         --      to happen, since the whole point concerns free variables. 
181         --      But resetting the export flag is right regardless.
182         -- 
183         -- (b)  Make the name an Internal one.  External Names should never be
184         --      nested; if it were floated to the top level, we'd get a name
185         --      clash at code generation time.
186
187     rhs_small_enough (id,rhs)
188         =  idArity id > 0       -- Note [Only functions!]
189         && maybe True (\size -> couldBeSmallEnoughToInline size rhs)
190                       (bombOutSize env)
191 \end{code}
192
193
194 Expressions
195 ~~~~~~~~~~~
196
197 \begin{code}
198 libCase :: LibCaseEnv
199         -> CoreExpr
200         -> CoreExpr
201
202 libCase env (Var v)             = libCaseId env v
203 libCase _   (Lit lit)           = Lit lit
204 libCase _   (Type ty)           = Type ty
205 libCase env (App fun arg)       = App (libCase env fun) (libCase env arg)
206 libCase env (Note note body)    = Note note (libCase env body)
207 libCase env (Cast e co)         = Cast (libCase env e) co
208
209 libCase env (Lam binder body)
210   = Lam binder (libCase (addBinders env [binder]) body)
211
212 libCase env (Let bind body)
213   = Let bind' (libCase env_body body)
214   where
215     (env_body, bind') = libCaseBind env bind
216
217 libCase env (Case scrut bndr ty alts)
218   = Case (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts)
219   where
220     env_alts = addBinders (mk_alt_env scrut) [bndr]
221     mk_alt_env (Var scrut_var) = addScrutedVar env scrut_var
222     mk_alt_env (Cast scrut _)  = mk_alt_env scrut       -- Note [Scrutinee with cast]
223     mk_alt_env _               = env
224
225 libCaseAlt :: LibCaseEnv -> (AltCon, [CoreBndr], CoreExpr)
226                          -> (AltCon, [CoreBndr], CoreExpr)
227 libCaseAlt env (con,args,rhs) = (con, args, libCase (addBinders env args) rhs)
228 \end{code}
229
230
231 Ids
232 ~~~
233 \begin{code}
234 libCaseId :: LibCaseEnv -> Id -> CoreExpr
235 libCaseId env v
236   | Just the_bind <- lookupRecId env v  -- It's a use of a recursive thing
237   , notNull free_scruts                 -- with free vars scrutinised in RHS
238   = Let the_bind (Var v)
239
240   | otherwise
241   = Var v
242
243   where
244     rec_id_level = lookupLevel env v
245     free_scruts  = freeScruts env rec_id_level
246
247 freeScruts :: LibCaseEnv
248            -> LibCaseLevel      -- Level of the recursive Id
249            -> [Id]              -- Ids that are scrutinised between the binding
250                                 -- of the recursive Id and here
251 freeScruts env rec_bind_lvl
252   = [v | (v,scrut_bind_lvl) <- lc_scruts env
253        , scrut_bind_lvl <= rec_bind_lvl]
254         -- Note [When to specialise]
255 \end{code}
256
257 Note [When to specialise]
258 ~~~~~~~~~~~~~~~~~~~~~~~~~
259 Consider
260   f = \x. letrec g = \y. case x of
261                            True  -> ... (f a) ...
262                            False -> ... (g b) ...
263
264 We get the following levels
265           f  0
266           x  1
267           g  1
268           y  2  
269
270 Then 'x' is being scrutinised at a deeper level than its binding, so
271 it's added to lc_sruts:  [(x,1)]  
272
273 We do *not* want to specialise the call to 'f', becuase 'x' is not free 
274 in 'f'.  So here the bind-level of 'x' (=1) is not <= the bind-level of 'f' (=0).
275
276 We *do* want to specialise the call to 'g', because 'x' is free in g.
277 Here the bind-level of 'x' (=1) is <= the bind-level of 'g' (=1).
278
279
280 %************************************************************************
281 %*                                                                      *
282         Utility functions
283 %*                                                                      *
284 %************************************************************************
285
286 \begin{code}
287 addBinders :: LibCaseEnv -> [CoreBndr] -> LibCaseEnv
288 addBinders env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env }) binders
289   = env { lc_lvl_env = lvl_env' }
290   where
291     lvl_env' = extendVarEnvList lvl_env (binders `zip` repeat lvl)
292
293 addRecBinds :: LibCaseEnv -> [(Id,CoreExpr)] -> LibCaseEnv
294 addRecBinds env@(LibCaseEnv {lc_lvl = lvl, lc_lvl_env = lvl_env, 
295                              lc_rec_env = rec_env}) pairs
296   = env { lc_lvl = lvl', lc_lvl_env = lvl_env', lc_rec_env = rec_env' }
297   where
298     lvl'     = lvl + 1
299     lvl_env' = extendVarEnvList lvl_env [(binder,lvl) | (binder,_) <- pairs]
300     rec_env' = extendVarEnvList rec_env [(binder, Rec pairs) | (binder,_) <- pairs]
301
302 addScrutedVar :: LibCaseEnv
303               -> Id             -- This Id is being scrutinised by a case expression
304               -> LibCaseEnv
305
306 addScrutedVar env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env, 
307                                 lc_scruts = scruts }) scrut_var
308   | bind_lvl < lvl
309   = env { lc_scruts = scruts' }
310         -- Add to scruts iff the scrut_var is being scrutinised at
311         -- a deeper level than its defn
312
313   | otherwise = env
314   where
315     scruts'  = (scrut_var, bind_lvl) : scruts
316     bind_lvl = case lookupVarEnv lvl_env scrut_var of
317                  Just lvl -> lvl
318                  Nothing  -> topLevel
319
320 lookupRecId :: LibCaseEnv -> Id -> Maybe CoreBind
321 lookupRecId env id = lookupVarEnv (lc_rec_env env) id
322
323 lookupLevel :: LibCaseEnv -> Id -> LibCaseLevel
324 lookupLevel env id
325   = case lookupVarEnv (lc_lvl_env env) id of
326       Just lvl -> lvl
327       Nothing  -> topLevel
328 \end{code}
329
330 %************************************************************************
331 %*                                                                      *
332          The environment
333 %*                                                                      *
334 %************************************************************************
335
336 \begin{code}
337 type LibCaseLevel = Int
338
339 topLevel :: LibCaseLevel
340 topLevel = 0
341 \end{code}
342
343 \begin{code}
344 data LibCaseEnv
345   = LibCaseEnv {
346         lc_size :: Maybe Int,   -- Bomb-out size for deciding if
347                                 -- potential liberatees are too big.
348                                 -- (passed in from cmd-line args)
349
350         lc_lvl :: LibCaseLevel, -- Current level
351                 -- The level is incremented when (and only when) going
352                 -- inside the RHS of a (sufficiently small) recursive
353                 -- function.
354
355         lc_lvl_env :: IdEnv LibCaseLevel,  
356                 -- Binds all non-top-level in-scope Ids (top-level and
357                 -- imported things have a level of zero)
358
359         lc_rec_env :: IdEnv CoreBind, 
360                 -- Binds *only* recursively defined ids, to their own
361                 -- binding group, and *only* in their own RHSs
362
363         lc_scruts :: [(Id,LibCaseLevel)]
364                 -- Each of these Ids was scrutinised by an enclosing
365                 -- case expression, at a level deeper than its binding
366                 -- level.  The LibCaseLevel recorded here is the *binding
367                 -- level* of the scrutinised Id.
368                 -- 
369                 -- The order is insignificant; it's a bag really
370         }
371
372 initEnv :: DynFlags -> LibCaseEnv
373 initEnv dflags 
374   = LibCaseEnv { lc_size = liberateCaseThreshold dflags,
375                  lc_lvl = 0,
376                  lc_lvl_env = emptyVarEnv, 
377                  lc_rec_env = emptyVarEnv,
378                  lc_scruts = [] }
379
380 bombOutSize :: LibCaseEnv -> Maybe Int
381 bombOutSize = lc_size
382 \end{code}
383
384