ab7923947a97123ff75241e58ebcb95ab2955c79
[ghc-hetmet.git] / compiler / simplCore / LiberateCase.lhs
1 %
2 % (c) The AQUA Project, Glasgow University, 1994-1998
3 %
4 \section[LiberateCase]{Unroll recursion to allow evals to be lifted from a loop}
5
6 \begin{code}
7 module LiberateCase ( liberateCase ) where
8
9 #include "HsVersions.h"
10
11 import DynFlags
12 import HscTypes
13 import CoreLint         ( showPass, endPass )
14 import CoreSyn
15 import CoreUnfold       ( couldBeSmallEnoughToInline )
16 import Rules            ( RuleBase )
17 import UniqSupply       ( UniqSupply )
18 import SimplMonad       ( SimplCount, zeroSimplCount )
19 import Id
20 import VarEnv
21 import Name             ( localiseName )
22 import Util             ( notNull )
23 \end{code}
24
25 The liberate-case transformation
26 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
27 This module walks over @Core@, and looks for @case@ on free variables.
28 The criterion is:
29         if there is case on a free on the route to the recursive call,
30         then the recursive call is replaced with an unfolding.
31
32 Example
33
34    f = \ t -> case v of
35                  V a b -> a : f t
36
37 => the inner f is replaced.
38
39    f = \ t -> case v of
40                  V a b -> a : (letrec
41                                 f =  \ t -> case v of
42                                                V a b -> a : f t
43                                in f) t
44 (note the NEED for shadowing)
45
46 => Simplify
47
48   f = \ t -> case v of
49                  V a b -> a : (letrec
50                                 f = \ t -> a : f t
51                                in f t)
52
53 Better code, because 'a' is  free inside the inner letrec, rather
54 than needing projection from v.
55
56 Note that this deals with *free variables*.  SpecConstr deals with
57 *arguments* that are of known form.  E.g.
58
59         last []     = error 
60         last (x:[]) = x
61         last (x:xs) = last xs
62
63         
64 Note [Scrutinee with cast]
65 ~~~~~~~~~~~~~~~~~~~~~~~~~~
66 Consider this:
67     f = \ t -> case (v `cast` co) of
68                  V a b -> a : f t
69
70 Exactly the same optimisation (unrolling one call to f) will work here, 
71 despite the cast.  See mk_alt_env in the Case branch of libCase.
72
73
74 Note [Only functions!]
75 ~~~~~~~~~~~~~~~~~~~~~~
76 Consider the following code
77
78        f = g (case v of V a b -> a : t f)
79
80 where g is expensive. If we aren't careful, liberate case will turn this into
81
82        f = g (case v of
83                V a b -> a : t (letrec f = g (case v of V a b -> a : f t)
84                                 in f)
85              )
86
87 Yikes! We evaluate g twice. This leads to a O(2^n) explosion
88 if g calls back to the same code recursively.
89
90 Solution: make sure that we only do the liberate-case thing on *functions*
91
92 To think about (Apr 94)
93 ~~~~~~~~~~~~~~
94 Main worry: duplicating code excessively.  At the moment we duplicate
95 the entire binding group once at each recursive call.  But there may
96 be a group of recursive calls which share a common set of evaluated
97 free variables, in which case the duplication is a plain waste.
98
99 Another thing we could consider adding is some unfold-threshold thing,
100 so that we'll only duplicate if the size of the group rhss isn't too
101 big.
102
103 Data types
104 ~~~~~~~~~~
105 The ``level'' of a binder tells how many
106 recursive defns lexically enclose the binding
107 A recursive defn "encloses" its RHS, not its
108 scope.  For example:
109 \begin{verbatim}
110         letrec f = let g = ... in ...
111         in
112         let h = ...
113         in ...
114 \end{verbatim}
115 Here, the level of @f@ is zero, the level of @g@ is one,
116 and the level of @h@ is zero (NB not one).
117
118
119 %************************************************************************
120 %*                                                                      *
121          Top-level code
122 %*                                                                      *
123 %************************************************************************
124
125 \begin{code}
126 liberateCase :: HscEnv -> UniqSupply -> RuleBase -> ModGuts
127              -> IO (SimplCount, ModGuts)
128 liberateCase hsc_env _ _ guts
129   = do  { let dflags = hsc_dflags hsc_env
130
131         ; showPass dflags "Liberate case"
132         ; let { env = initEnv dflags
133               ; binds' = do_prog env (mg_binds guts) }
134         ; endPass dflags "Liberate case" Opt_D_verbose_core2core binds'
135                         {- no specific flag for dumping -} 
136         ; return (zeroSimplCount dflags, guts { mg_binds = binds' }) }
137   where
138     do_prog _   [] = []
139     do_prog env (bind:binds) = bind' : do_prog env' binds
140                              where
141                                (env', bind') = libCaseBind env bind
142 \end{code}
143
144
145 %************************************************************************
146 %*                                                                      *
147          Main payload
148 %*                                                                      *
149 %************************************************************************
150
151 Bindings
152 ~~~~~~~~
153 \begin{code}
154 libCaseBind :: LibCaseEnv -> CoreBind -> (LibCaseEnv, CoreBind)
155
156 libCaseBind env (NonRec binder rhs)
157   = (addBinders env [binder], NonRec binder (libCase env rhs))
158
159 libCaseBind env (Rec pairs)
160   = (env_body, Rec pairs')
161   where
162     (binders, _rhss) = unzip pairs
163
164     env_body = addBinders env binders
165
166     pairs' = [(binder, libCase env_rhs rhs) | (binder,rhs) <- pairs]
167
168     env_rhs = if all rhs_small_enough pairs then extended_env else env
169
170         -- We extend the rec-env by binding each Id to its rhs, first
171         -- processing the rhs with an *un-extended* environment, so
172         -- that the same process doesn't occur for ever!
173         --
174     extended_env = addRecBinds env [ (adjust binder, libCase env_body rhs)
175                                    | (binder, rhs) <- pairs ]
176
177         -- Two subtle things: 
178         -- (a)  Reset the export flags on the binders so
179         --      that we don't get name clashes on exported things if the 
180         --      local binding floats out to top level.  This is most unlikely
181         --      to happen, since the whole point concerns free variables. 
182         --      But resetting the export flag is right regardless.
183         -- 
184         -- (b)  Make the name an Internal one.  External Names should never be
185         --      nested; if it were floated to the top level, we'd get a name
186         --      clash at code generation time.
187     adjust bndr = setIdNotExported (setIdName bndr (localiseName (idName bndr)))
188
189     rhs_small_enough (id,rhs)
190         =  idArity id > 0       -- Note [Only functions!]
191         && maybe True (\size -> couldBeSmallEnoughToInline size rhs)
192                       (bombOutSize env)
193 \end{code}
194
195
196 Expressions
197 ~~~~~~~~~~~
198
199 \begin{code}
200 libCase :: LibCaseEnv
201         -> CoreExpr
202         -> CoreExpr
203
204 libCase env (Var v)             = libCaseId env v
205 libCase _   (Lit lit)           = Lit lit
206 libCase _   (Type ty)           = Type ty
207 libCase env (App fun arg)       = App (libCase env fun) (libCase env arg)
208 libCase env (Note note body)    = Note note (libCase env body)
209 libCase env (Cast e co)         = Cast (libCase env e) co
210
211 libCase env (Lam binder body)
212   = Lam binder (libCase (addBinders env [binder]) body)
213
214 libCase env (Let bind body)
215   = Let bind' (libCase env_body body)
216   where
217     (env_body, bind') = libCaseBind env bind
218
219 libCase env (Case scrut bndr ty alts)
220   = Case (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts)
221   where
222     env_alts = addBinders (mk_alt_env scrut) [bndr]
223     mk_alt_env (Var scrut_var) = addScrutedVar env scrut_var
224     mk_alt_env (Cast scrut _)  = mk_alt_env scrut       -- Note [Scrutinee with cast]
225     mk_alt_env _               = env
226
227 libCaseAlt :: LibCaseEnv -> (AltCon, [CoreBndr], CoreExpr)
228                          -> (AltCon, [CoreBndr], CoreExpr)
229 libCaseAlt env (con,args,rhs) = (con, args, libCase (addBinders env args) rhs)
230 \end{code}
231
232
233 Ids
234 ~~~
235 \begin{code}
236 libCaseId :: LibCaseEnv -> Id -> CoreExpr
237 libCaseId env v
238   | Just the_bind <- lookupRecId env v  -- It's a use of a recursive thing
239   , notNull free_scruts                 -- with free vars scrutinised in RHS
240   = Let the_bind (Var v)
241
242   | otherwise
243   = Var v
244
245   where
246     rec_id_level = lookupLevel env v
247     free_scruts  = freeScruts env rec_id_level
248
249 freeScruts :: LibCaseEnv
250            -> LibCaseLevel      -- Level of the recursive Id
251            -> [Id]              -- Ids that are scrutinised between the binding
252                                 -- of the recursive Id and here
253 freeScruts env rec_bind_lvl
254   = [v | (v,scrut_bind_lvl) <- lc_scruts env
255        , scrut_bind_lvl <= rec_bind_lvl]
256         -- Note [When to specialise]
257 \end{code}
258
259 Note [When to specialise]
260 ~~~~~~~~~~~~~~~~~~~~~~~~~
261 Consider
262   f = \x. letrec g = \y. case x of
263                            True  -> ... (f a) ...
264                            False -> ... (g b) ...
265
266 We get the following levels
267           f  0
268           x  1
269           g  1
270           y  2  
271
272 Then 'x' is being scrutinised at a deeper level than its binding, so
273 it's added to lc_sruts:  [(x,1)]  
274
275 We do *not* want to specialise the call to 'f', becuase 'x' is not free 
276 in 'f'.  So here the bind-level of 'x' (=1) is not <= the bind-level of 'f' (=0).
277
278 We *do* want to specialise the call to 'g', because 'x' is free in g.
279 Here the bind-level of 'x' (=1) is <= the bind-level of 'g' (=1).
280
281
282 %************************************************************************
283 %*                                                                      *
284         Utility functions
285 %*                                                                      *
286 %************************************************************************
287
288 \begin{code}
289 addBinders :: LibCaseEnv -> [CoreBndr] -> LibCaseEnv
290 addBinders env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env }) binders
291   = env { lc_lvl_env = lvl_env' }
292   where
293     lvl_env' = extendVarEnvList lvl_env (binders `zip` repeat lvl)
294
295 addRecBinds :: LibCaseEnv -> [(Id,CoreExpr)] -> LibCaseEnv
296 addRecBinds env@(LibCaseEnv {lc_lvl = lvl, lc_lvl_env = lvl_env, 
297                              lc_rec_env = rec_env}) pairs
298   = env { lc_lvl = lvl', lc_lvl_env = lvl_env', lc_rec_env = rec_env' }
299   where
300     lvl'     = lvl + 1
301     lvl_env' = extendVarEnvList lvl_env [(binder,lvl) | (binder,_) <- pairs]
302     rec_env' = extendVarEnvList rec_env [(binder, Rec pairs) | (binder,_) <- pairs]
303
304 addScrutedVar :: LibCaseEnv
305               -> Id             -- This Id is being scrutinised by a case expression
306               -> LibCaseEnv
307
308 addScrutedVar env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env, 
309                                 lc_scruts = scruts }) scrut_var
310   | bind_lvl < lvl
311   = env { lc_scruts = scruts' }
312         -- Add to scruts iff the scrut_var is being scrutinised at
313         -- a deeper level than its defn
314
315   | otherwise = env
316   where
317     scruts'  = (scrut_var, bind_lvl) : scruts
318     bind_lvl = case lookupVarEnv lvl_env scrut_var of
319                  Just lvl -> lvl
320                  Nothing  -> topLevel
321
322 lookupRecId :: LibCaseEnv -> Id -> Maybe CoreBind
323 lookupRecId env id = lookupVarEnv (lc_rec_env env) id
324
325 lookupLevel :: LibCaseEnv -> Id -> LibCaseLevel
326 lookupLevel env id
327   = case lookupVarEnv (lc_lvl_env env) id of
328       Just lvl -> lvl
329       Nothing  -> topLevel
330 \end{code}
331
332 %************************************************************************
333 %*                                                                      *
334          The environment
335 %*                                                                      *
336 %************************************************************************
337
338 \begin{code}
339 type LibCaseLevel = Int
340
341 topLevel :: LibCaseLevel
342 topLevel = 0
343 \end{code}
344
345 \begin{code}
346 data LibCaseEnv
347   = LibCaseEnv {
348         lc_size :: Maybe Int,   -- Bomb-out size for deciding if
349                                 -- potential liberatees are too big.
350                                 -- (passed in from cmd-line args)
351
352         lc_lvl :: LibCaseLevel, -- Current level
353                 -- The level is incremented when (and only when) going
354                 -- inside the RHS of a (sufficiently small) recursive
355                 -- function.
356
357         lc_lvl_env :: IdEnv LibCaseLevel,  
358                 -- Binds all non-top-level in-scope Ids (top-level and
359                 -- imported things have a level of zero)
360
361         lc_rec_env :: IdEnv CoreBind, 
362                 -- Binds *only* recursively defined ids, to their own
363                 -- binding group, and *only* in their own RHSs
364
365         lc_scruts :: [(Id,LibCaseLevel)]
366                 -- Each of these Ids was scrutinised by an enclosing
367                 -- case expression, at a level deeper than its binding
368                 -- level.  The LibCaseLevel recorded here is the *binding
369                 -- level* of the scrutinised Id.
370                 -- 
371                 -- The order is insignificant; it's a bag really
372         }
373
374 initEnv :: DynFlags -> LibCaseEnv
375 initEnv dflags 
376   = LibCaseEnv { lc_size = liberateCaseThreshold dflags,
377                  lc_lvl = 0,
378                  lc_lvl_env = emptyVarEnv, 
379                  lc_rec_env = emptyVarEnv,
380                  lc_scruts = [] }
381
382 bombOutSize :: LibCaseEnv -> Maybe Int
383 bombOutSize = lc_size
384 \end{code}
385
386