4b1055bbede2a34b7f87543207a7939a5befe106
[ghc-hetmet.git] / compiler / simplCore / LiberateCase.lhs
1 %
2 % (c) The AQUA Project, Glasgow University, 1994-1998
3 %
4 \section[LiberateCase]{Unroll recursion to allow evals to be lifted from a loop}
5
6 \begin{code}
7 module LiberateCase ( liberateCase ) where
8
9 #include "HsVersions.h"
10
11 import DynFlags
12 import CoreSyn
13 import CoreUnfold       ( couldBeSmallEnoughToInline )
14 import Id
15 import VarEnv
16 import Util             ( notNull )
17 \end{code}
18
19 The liberate-case transformation
20 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
21 This module walks over @Core@, and looks for @case@ on free variables.
22 The criterion is:
23         if there is case on a free on the route to the recursive call,
24         then the recursive call is replaced with an unfolding.
25
26 Example
27
28    f = \ t -> case v of
29                  V a b -> a : f t
30
31 => the inner f is replaced.
32
33    f = \ t -> case v of
34                  V a b -> a : (letrec
35                                 f =  \ t -> case v of
36                                                V a b -> a : f t
37                                in f) t
38 (note the NEED for shadowing)
39
40 => Simplify
41
42   f = \ t -> case v of
43                  V a b -> a : (letrec
44                                 f = \ t -> a : f t
45                                in f t)
46
47 Better code, because 'a' is  free inside the inner letrec, rather
48 than needing projection from v.
49
50 Note that this deals with *free variables*.  SpecConstr deals with
51 *arguments* that are of known form.  E.g.
52
53         last []     = error 
54         last (x:[]) = x
55         last (x:xs) = last xs
56
57         
58 Note [Scrutinee with cast]
59 ~~~~~~~~~~~~~~~~~~~~~~~~~~
60 Consider this:
61     f = \ t -> case (v `cast` co) of
62                  V a b -> a : f t
63
64 Exactly the same optimisation (unrolling one call to f) will work here, 
65 despite the cast.  See mk_alt_env in the Case branch of libCase.
66
67
68 Note [Only functions!]
69 ~~~~~~~~~~~~~~~~~~~~~~
70 Consider the following code
71
72        f = g (case v of V a b -> a : t f)
73
74 where g is expensive. If we aren't careful, liberate case will turn this into
75
76        f = g (case v of
77                V a b -> a : t (letrec f = g (case v of V a b -> a : f t)
78                                 in f)
79              )
80
81 Yikes! We evaluate g twice. This leads to a O(2^n) explosion
82 if g calls back to the same code recursively.
83
84 Solution: make sure that we only do the liberate-case thing on *functions*
85
86 To think about (Apr 94)
87 ~~~~~~~~~~~~~~
88 Main worry: duplicating code excessively.  At the moment we duplicate
89 the entire binding group once at each recursive call.  But there may
90 be a group of recursive calls which share a common set of evaluated
91 free variables, in which case the duplication is a plain waste.
92
93 Another thing we could consider adding is some unfold-threshold thing,
94 so that we'll only duplicate if the size of the group rhss isn't too
95 big.
96
97 Data types
98 ~~~~~~~~~~
99 The ``level'' of a binder tells how many
100 recursive defns lexically enclose the binding
101 A recursive defn "encloses" its RHS, not its
102 scope.  For example:
103 \begin{verbatim}
104         letrec f = let g = ... in ...
105         in
106         let h = ...
107         in ...
108 \end{verbatim}
109 Here, the level of @f@ is zero, the level of @g@ is one,
110 and the level of @h@ is zero (NB not one).
111
112
113 %************************************************************************
114 %*                                                                      *
115          Top-level code
116 %*                                                                      *
117 %************************************************************************
118
119 \begin{code}
120 liberateCase :: DynFlags -> [CoreBind] -> [CoreBind]
121 liberateCase dflags binds = do_prog (initEnv dflags) binds
122   where
123     do_prog _   [] = []
124     do_prog env (bind:binds) = bind' : do_prog env' binds
125                              where
126                                (env', bind') = libCaseBind env bind
127 \end{code}
128
129
130 %************************************************************************
131 %*                                                                      *
132          Main payload
133 %*                                                                      *
134 %************************************************************************
135
136 Bindings
137 ~~~~~~~~
138 \begin{code}
139 libCaseBind :: LibCaseEnv -> CoreBind -> (LibCaseEnv, CoreBind)
140
141 libCaseBind env (NonRec binder rhs)
142   = (addBinders env [binder], NonRec binder (libCase env rhs))
143
144 libCaseBind env (Rec pairs)
145   = (env_body, Rec pairs')
146   where
147     (binders, _rhss) = unzip pairs
148
149     env_body = addBinders env binders
150
151     pairs' = [(binder, libCase env_rhs rhs) | (binder,rhs) <- pairs]
152
153     env_rhs = if all rhs_small_enough pairs then extended_env else env
154
155         -- We extend the rec-env by binding each Id to its rhs, first
156         -- processing the rhs with an *un-extended* environment, so
157         -- that the same process doesn't occur for ever!
158         --
159     extended_env = addRecBinds env [ (localiseId binder, libCase env_body rhs)
160                                    | (binder, rhs) <- pairs ]
161
162         -- The call to localiseId is needed for two subtle reasons
163         -- (a)  Reset the export flags on the binders so
164         --      that we don't get name clashes on exported things if the 
165         --      local binding floats out to top level.  This is most unlikely
166         --      to happen, since the whole point concerns free variables. 
167         --      But resetting the export flag is right regardless.
168         -- 
169         -- (b)  Make the name an Internal one.  External Names should never be
170         --      nested; if it were floated to the top level, we'd get a name
171         --      clash at code generation time.
172
173     rhs_small_enough (id,rhs)
174         =  idArity id > 0       -- Note [Only functions!]
175         && maybe True (\size -> couldBeSmallEnoughToInline size rhs)
176                       (bombOutSize env)
177 \end{code}
178
179
180 Expressions
181 ~~~~~~~~~~~
182
183 \begin{code}
184 libCase :: LibCaseEnv
185         -> CoreExpr
186         -> CoreExpr
187
188 libCase env (Var v)             = libCaseId env v
189 libCase _   (Lit lit)           = Lit lit
190 libCase _   (Type ty)           = Type ty
191 libCase env (App fun arg)       = App (libCase env fun) (libCase env arg)
192 libCase env (Note note body)    = Note note (libCase env body)
193 libCase env (Cast e co)         = Cast (libCase env e) co
194
195 libCase env (Lam binder body)
196   = Lam binder (libCase (addBinders env [binder]) body)
197
198 libCase env (Let bind body)
199   = Let bind' (libCase env_body body)
200   where
201     (env_body, bind') = libCaseBind env bind
202
203 libCase env (Case scrut bndr ty alts)
204   = Case (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts)
205   where
206     env_alts = addBinders (mk_alt_env scrut) [bndr]
207     mk_alt_env (Var scrut_var) = addScrutedVar env scrut_var
208     mk_alt_env (Cast scrut _)  = mk_alt_env scrut       -- Note [Scrutinee with cast]
209     mk_alt_env _               = env
210
211 libCaseAlt :: LibCaseEnv -> (AltCon, [CoreBndr], CoreExpr)
212                          -> (AltCon, [CoreBndr], CoreExpr)
213 libCaseAlt env (con,args,rhs) = (con, args, libCase (addBinders env args) rhs)
214 \end{code}
215
216
217 Ids
218 ~~~
219 \begin{code}
220 libCaseId :: LibCaseEnv -> Id -> CoreExpr
221 libCaseId env v
222   | Just the_bind <- lookupRecId env v  -- It's a use of a recursive thing
223   , notNull free_scruts                 -- with free vars scrutinised in RHS
224   = Let the_bind (Var v)
225
226   | otherwise
227   = Var v
228
229   where
230     rec_id_level = lookupLevel env v
231     free_scruts  = freeScruts env rec_id_level
232
233 freeScruts :: LibCaseEnv
234            -> LibCaseLevel      -- Level of the recursive Id
235            -> [Id]              -- Ids that are scrutinised between the binding
236                                 -- of the recursive Id and here
237 freeScruts env rec_bind_lvl
238   = [v | (v,scrut_bind_lvl) <- lc_scruts env
239        , scrut_bind_lvl <= rec_bind_lvl]
240         -- Note [When to specialise]
241 \end{code}
242
243 Note [When to specialise]
244 ~~~~~~~~~~~~~~~~~~~~~~~~~
245 Consider
246   f = \x. letrec g = \y. case x of
247                            True  -> ... (f a) ...
248                            False -> ... (g b) ...
249
250 We get the following levels
251           f  0
252           x  1
253           g  1
254           y  2  
255
256 Then 'x' is being scrutinised at a deeper level than its binding, so
257 it's added to lc_sruts:  [(x,1)]  
258
259 We do *not* want to specialise the call to 'f', becuase 'x' is not free 
260 in 'f'.  So here the bind-level of 'x' (=1) is not <= the bind-level of 'f' (=0).
261
262 We *do* want to specialise the call to 'g', because 'x' is free in g.
263 Here the bind-level of 'x' (=1) is <= the bind-level of 'g' (=1).
264
265
266 %************************************************************************
267 %*                                                                      *
268         Utility functions
269 %*                                                                      *
270 %************************************************************************
271
272 \begin{code}
273 addBinders :: LibCaseEnv -> [CoreBndr] -> LibCaseEnv
274 addBinders env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env }) binders
275   = env { lc_lvl_env = lvl_env' }
276   where
277     lvl_env' = extendVarEnvList lvl_env (binders `zip` repeat lvl)
278
279 addRecBinds :: LibCaseEnv -> [(Id,CoreExpr)] -> LibCaseEnv
280 addRecBinds env@(LibCaseEnv {lc_lvl = lvl, lc_lvl_env = lvl_env, 
281                              lc_rec_env = rec_env}) pairs
282   = env { lc_lvl = lvl', lc_lvl_env = lvl_env', lc_rec_env = rec_env' }
283   where
284     lvl'     = lvl + 1
285     lvl_env' = extendVarEnvList lvl_env [(binder,lvl) | (binder,_) <- pairs]
286     rec_env' = extendVarEnvList rec_env [(binder, Rec pairs) | (binder,_) <- pairs]
287
288 addScrutedVar :: LibCaseEnv
289               -> Id             -- This Id is being scrutinised by a case expression
290               -> LibCaseEnv
291
292 addScrutedVar env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env, 
293                                 lc_scruts = scruts }) scrut_var
294   | bind_lvl < lvl
295   = env { lc_scruts = scruts' }
296         -- Add to scruts iff the scrut_var is being scrutinised at
297         -- a deeper level than its defn
298
299   | otherwise = env
300   where
301     scruts'  = (scrut_var, bind_lvl) : scruts
302     bind_lvl = case lookupVarEnv lvl_env scrut_var of
303                  Just lvl -> lvl
304                  Nothing  -> topLevel
305
306 lookupRecId :: LibCaseEnv -> Id -> Maybe CoreBind
307 lookupRecId env id = lookupVarEnv (lc_rec_env env) id
308
309 lookupLevel :: LibCaseEnv -> Id -> LibCaseLevel
310 lookupLevel env id
311   = case lookupVarEnv (lc_lvl_env env) id of
312       Just lvl -> lvl
313       Nothing  -> topLevel
314 \end{code}
315
316 %************************************************************************
317 %*                                                                      *
318          The environment
319 %*                                                                      *
320 %************************************************************************
321
322 \begin{code}
323 type LibCaseLevel = Int
324
325 topLevel :: LibCaseLevel
326 topLevel = 0
327 \end{code}
328
329 \begin{code}
330 data LibCaseEnv
331   = LibCaseEnv {
332         lc_size :: Maybe Int,   -- Bomb-out size for deciding if
333                                 -- potential liberatees are too big.
334                                 -- (passed in from cmd-line args)
335
336         lc_lvl :: LibCaseLevel, -- Current level
337                 -- The level is incremented when (and only when) going
338                 -- inside the RHS of a (sufficiently small) recursive
339                 -- function.
340
341         lc_lvl_env :: IdEnv LibCaseLevel,  
342                 -- Binds all non-top-level in-scope Ids (top-level and
343                 -- imported things have a level of zero)
344
345         lc_rec_env :: IdEnv CoreBind, 
346                 -- Binds *only* recursively defined ids, to their own
347                 -- binding group, and *only* in their own RHSs
348
349         lc_scruts :: [(Id,LibCaseLevel)]
350                 -- Each of these Ids was scrutinised by an enclosing
351                 -- case expression, at a level deeper than its binding
352                 -- level.  The LibCaseLevel recorded here is the *binding
353                 -- level* of the scrutinised Id.
354                 -- 
355                 -- The order is insignificant; it's a bag really
356         }
357
358 initEnv :: DynFlags -> LibCaseEnv
359 initEnv dflags 
360   = LibCaseEnv { lc_size = liberateCaseThreshold dflags,
361                  lc_lvl = 0,
362                  lc_lvl_env = emptyVarEnv, 
363                  lc_rec_env = emptyVarEnv,
364                  lc_scruts = [] }
365
366 bombOutSize :: LibCaseEnv -> Maybe Int
367 bombOutSize = lc_size
368 \end{code}
369
370