Fix egregious sharing bug in LiberateCase
[ghc-hetmet.git] / compiler / simplCore / LiberateCase.lhs
1 %
2 % (c) The AQUA Project, Glasgow University, 1994-1998
3 %
4 \section[LiberateCase]{Unroll recursion to allow evals to be lifted from a loop}
5
6 \begin{code}
7 module LiberateCase ( liberateCase ) where
8
9 #include "HsVersions.h"
10
11 import DynFlags
12 import HscTypes
13 import CoreLint         ( showPass, endPass )
14 import CoreSyn
15 import CoreUnfold       ( couldBeSmallEnoughToInline )
16 import Rules            ( RuleBase )
17 import UniqSupply       ( UniqSupply )
18 import SimplMonad       ( SimplCount, zeroSimplCount )
19 import Id
20 import VarEnv
21 import Name             ( localiseName )
22 import Util             ( notNull )
23 \end{code}
24
25 The liberate-case transformation
26 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
27 This module walks over @Core@, and looks for @case@ on free variables.
28 The criterion is:
29         if there is case on a free on the route to the recursive call,
30         then the recursive call is replaced with an unfolding.
31
32 Example
33
34    f = \ t -> case v of
35                  V a b -> a : f t
36
37 => the inner f is replaced.
38
39    f = \ t -> case v of
40                  V a b -> a : (letrec
41                                 f =  \ t -> case v of
42                                                V a b -> a : f t
43                                in f) t
44 (note the NEED for shadowing)
45
46 => Simplify
47
48   f = \ t -> case v of
49                  V a b -> a : (letrec
50                                 f = \ t -> a : f t
51                                in f t)
52
53 Better code, because 'a' is  free inside the inner letrec, rather
54 than needing projection from v.
55
56 Other examples we'd like to catch with this kind of transformation
57
58         last []     = error 
59         last (x:[]) = x
60         last (x:xs) = last xs
61
62 We'd like to avoid the redundant pattern match, transforming to
63
64         last [] = error
65         last (x:[]) = x
66         last (x:(y:ys)) = last' y ys
67                 where
68                   last' y []     = y
69                   last' _ (y:ys) = last' y ys
70
71         (is this necessarily an improvement)
72
73 Similarly drop:
74
75         drop n [] = []
76         drop 0 xs = xs
77         drop n (x:xs) = drop (n-1) xs
78
79 Would like to pass n along unboxed.
80         
81 Note [Scrutinee with cast]
82 ~~~~~~~~~~~~~~~~~~~~~~~~~~
83 Consider this:
84     f = \ t -> case (v `cast` co) of
85                  V a b -> a : f t
86
87 Exactly the same optimisation (unrolling one call to f) will work here, 
88 despite the cast.  See mk_alt_env in the Case branch of libCase.
89
90
91 Note [Only functions!]
92 ~~~~~~~~~~~~~~~~~~~~~~
93 Consider the following code
94
95        f = g (case v of V a b -> a : t f)
96
97 where g is expensive. If we aren't careful, liberate case will turn this into
98
99        f = g (case v of
100                V a b -> a : t (letrec f = g (case v of V a b -> a : f t)
101                                 in f)
102              )
103
104 Yikes! We evaluate g twice. This leads to a O(2^n) explosion
105 if g calls back to the same code recursively.
106
107 Solution: make sure that we only do the liberate-case thing on *functions*
108
109 To think about (Apr 94)
110 ~~~~~~~~~~~~~~
111 Main worry: duplicating code excessively.  At the moment we duplicate
112 the entire binding group once at each recursive call.  But there may
113 be a group of recursive calls which share a common set of evaluated
114 free variables, in which case the duplication is a plain waste.
115
116 Another thing we could consider adding is some unfold-threshold thing,
117 so that we'll only duplicate if the size of the group rhss isn't too
118 big.
119
120 Data types
121 ~~~~~~~~~~
122 The ``level'' of a binder tells how many
123 recursive defns lexically enclose the binding
124 A recursive defn "encloses" its RHS, not its
125 scope.  For example:
126 \begin{verbatim}
127         letrec f = let g = ... in ...
128         in
129         let h = ...
130         in ...
131 \end{verbatim}
132 Here, the level of @f@ is zero, the level of @g@ is one,
133 and the level of @h@ is zero (NB not one).
134
135
136 %************************************************************************
137 %*                                                                      *
138          Top-level code
139 %*                                                                      *
140 %************************************************************************
141
142 \begin{code}
143 liberateCase :: HscEnv -> UniqSupply -> RuleBase -> ModGuts
144              -> IO (SimplCount, ModGuts)
145 liberateCase hsc_env _ _ guts
146   = do  { let dflags = hsc_dflags hsc_env
147
148         ; showPass dflags "Liberate case"
149         ; let { env = initEnv dflags
150               ; binds' = do_prog env (mg_binds guts) }
151         ; endPass dflags "Liberate case" Opt_D_verbose_core2core binds'
152                         {- no specific flag for dumping -} 
153         ; return (zeroSimplCount dflags, guts { mg_binds = binds' }) }
154   where
155     do_prog env [] = []
156     do_prog env (bind:binds) = bind' : do_prog env' binds
157                              where
158                                (env', bind') = libCaseBind env bind
159 \end{code}
160
161
162 %************************************************************************
163 %*                                                                      *
164          Main payload
165 %*                                                                      *
166 %************************************************************************
167
168 Bindings
169 ~~~~~~~~
170 \begin{code}
171 libCaseBind :: LibCaseEnv -> CoreBind -> (LibCaseEnv, CoreBind)
172
173 libCaseBind env (NonRec binder rhs)
174   = (addBinders env [binder], NonRec binder (libCase env rhs))
175
176 libCaseBind env (Rec pairs)
177   = (env_body, Rec pairs')
178   where
179     (binders, rhss) = unzip pairs
180
181     env_body = addBinders env binders
182
183     pairs' = [(binder, libCase env_rhs rhs) | (binder,rhs) <- pairs]
184
185     env_rhs = if all rhs_small_enough pairs then extended_env else env
186
187         -- We extend the rec-env by binding each Id to its rhs, first
188         -- processing the rhs with an *un-extended* environment, so
189         -- that the same process doesn't occur for ever!
190         --
191     extended_env = addRecBinds env [ (adjust binder, libCase env_body rhs)
192                                    | (binder, rhs) <- pairs ]
193
194         -- Two subtle things: 
195         -- (a)  Reset the export flags on the binders so
196         --      that we don't get name clashes on exported things if the 
197         --      local binding floats out to top level.  This is most unlikely
198         --      to happen, since the whole point concerns free variables. 
199         --      But resetting the export flag is right regardless.
200         -- 
201         -- (b)  Make the name an Internal one.  External Names should never be
202         --      nested; if it were floated to the top level, we'd get a name
203         --      clash at code generation time.
204     adjust bndr = setIdNotExported (setIdName bndr (localiseName (idName bndr)))
205
206     rhs_small_enough (id,rhs)
207         =  idArity id > 0       -- Note [Only functions!]
208         && couldBeSmallEnoughToInline (bombOutSize env) rhs
209 \end{code}
210
211
212 Expressions
213 ~~~~~~~~~~~
214
215 \begin{code}
216 libCase :: LibCaseEnv
217         -> CoreExpr
218         -> CoreExpr
219
220 libCase env (Var v)             = libCaseId env v
221 libCase env (Lit lit)           = Lit lit
222 libCase env (Type ty)           = Type ty
223 libCase env (App fun arg)       = App (libCase env fun) (libCase env arg)
224 libCase env (Note note body)    = Note note (libCase env body)
225 libCase env (Cast e co)         = Cast (libCase env e) co
226
227 libCase env (Lam binder body)
228   = Lam binder (libCase (addBinders env [binder]) body)
229
230 libCase env (Let bind body)
231   = Let bind' (libCase env_body body)
232   where
233     (env_body, bind') = libCaseBind env bind
234
235 libCase env (Case scrut bndr ty alts)
236   = Case (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts)
237   where
238     env_alts = addBinders (mk_alt_env scrut) [bndr]
239     mk_alt_env (Var scrut_var) = addScrutedVar env scrut_var
240     mk_alt_env (Cast scrut _)  = mk_alt_env scrut       -- Note [Scrutinee with cast]
241     mk_alt_env otehr           = env
242
243 libCaseAlt env (con,args,rhs) = (con, args, libCase (addBinders env args) rhs)
244 \end{code}
245
246
247 Ids
248 ~~~
249 \begin{code}
250 libCaseId :: LibCaseEnv -> Id -> CoreExpr
251 libCaseId env v
252   | Just the_bind <- lookupRecId env v  -- It's a use of a recursive thing
253   , notNull free_scruts                 -- with free vars scrutinised in RHS
254   = Let the_bind (Var v)
255
256   | otherwise
257   = Var v
258
259   where
260     rec_id_level = lookupLevel env v
261     free_scruts  = freeScruts env rec_id_level
262 \end{code}
263
264
265 %************************************************************************
266 %*                                                                      *
267         Utility functions
268 %*                                                                      *
269 %************************************************************************
270
271 \begin{code}
272 addBinders :: LibCaseEnv -> [CoreBndr] -> LibCaseEnv
273 addBinders env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env }) binders
274   = env { lc_lvl_env = lvl_env' }
275   where
276     lvl_env' = extendVarEnvList lvl_env (binders `zip` repeat lvl)
277
278 addRecBinds :: LibCaseEnv -> [(Id,CoreExpr)] -> LibCaseEnv
279 addRecBinds env@(LibCaseEnv {lc_lvl = lvl, lc_lvl_env = lvl_env, 
280                              lc_rec_env = rec_env}) pairs
281   = env { lc_lvl = lvl', lc_lvl_env = lvl_env', lc_rec_env = rec_env' }
282   where
283     lvl'     = lvl + 1
284     lvl_env' = extendVarEnvList lvl_env [(binder,lvl) | (binder,_) <- pairs]
285     rec_env' = extendVarEnvList rec_env [(binder, Rec pairs) | (binder,_) <- pairs]
286
287 addScrutedVar :: LibCaseEnv
288               -> Id             -- This Id is being scrutinised by a case expression
289               -> LibCaseEnv
290
291 addScrutedVar env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env, 
292                                 lc_scruts = scruts }) scrut_var
293   | bind_lvl < lvl
294   = env { lc_scruts = scruts' }
295         -- Add to scruts iff the scrut_var is being scrutinised at
296         -- a deeper level than its defn
297
298   | otherwise = env
299   where
300     scruts'  = (scrut_var, lvl) : scruts
301     bind_lvl = case lookupVarEnv lvl_env scrut_var of
302                  Just lvl -> lvl
303                  Nothing  -> topLevel
304
305 lookupRecId :: LibCaseEnv -> Id -> Maybe CoreBind
306 lookupRecId env id = lookupVarEnv (lc_rec_env env) id
307
308 lookupLevel :: LibCaseEnv -> Id -> LibCaseLevel
309 lookupLevel env id
310   = case lookupVarEnv (lc_lvl_env env) id of
311       Just lvl -> lvl
312       Nothing  -> topLevel
313
314 freeScruts :: LibCaseEnv
315            -> LibCaseLevel      -- Level of the recursive Id
316            -> [Id]              -- Ids that are scrutinised between the binding
317                                 -- of the recursive Id and here
318 freeScruts env rec_bind_lvl
319   = [v | (v,scrut_lvl) <- lc_scruts env, scrut_lvl > rec_bind_lvl]
320 \end{code}
321
322 %************************************************************************
323 %*                                                                      *
324          The environment
325 %*                                                                      *
326 %************************************************************************
327
328 \begin{code}
329 type LibCaseLevel = Int
330
331 topLevel :: LibCaseLevel
332 topLevel = 0
333 \end{code}
334
335 \begin{code}
336 data LibCaseEnv
337   = LibCaseEnv {
338         lc_size :: Int,         -- Bomb-out size for deciding if
339                                 -- potential liberatees are too big.
340                                 -- (passed in from cmd-line args)
341
342         lc_lvl :: LibCaseLevel, -- Current level
343
344         lc_lvl_env :: IdEnv LibCaseLevel,  
345                         -- Binds all non-top-level in-scope Ids
346                         -- (top-level and imported things have
347                         -- a level of zero)
348
349         lc_rec_env :: IdEnv CoreBind, 
350                         -- Binds *only* recursively defined ids, 
351                         -- to their own binding group,
352                         -- and *only* in their own RHSs
353
354         lc_scruts :: [(Id,LibCaseLevel)]
355                         -- Each of these Ids was scrutinised by an
356                         -- enclosing case expression, with the
357                         -- specified number of enclosing
358                         -- recursive bindings; furthermore,
359                         -- the Id is bound at a lower level
360                         -- than the case expression.  The order is
361                         -- insignificant; it's a bag really
362         }
363
364 initEnv :: DynFlags -> LibCaseEnv
365 initEnv dflags 
366   = LibCaseEnv { lc_size = specThreshold dflags,
367                  lc_lvl = 0,
368                  lc_lvl_env = emptyVarEnv, 
369                  lc_rec_env = emptyVarEnv,
370                  lc_scruts = [] }
371
372 bombOutSize = lc_size
373 \end{code}
374
375