Correct an egregious typo in LiberateCase that emasculated it
[ghc-hetmet.git] / compiler / simplCore / LiberateCase.lhs
1 %
2 % (c) The AQUA Project, Glasgow University, 1994-1998
3 %
4 \section[LiberateCase]{Unroll recursion to allow evals to be lifted from a loop}
5
6 \begin{code}
7 module LiberateCase ( liberateCase ) where
8
9 #include "HsVersions.h"
10
11 import DynFlags
12 import HscTypes
13 import CoreLint         ( showPass, endPass )
14 import CoreSyn
15 import CoreUnfold       ( couldBeSmallEnoughToInline )
16 import Rules            ( RuleBase )
17 import UniqSupply       ( UniqSupply )
18 import SimplMonad       ( SimplCount, zeroSimplCount )
19 import Id
20 import FamInstEnv
21 import Type
22 import Coercion
23 import TyCon
24 import VarEnv
25 import Name             ( localiseName )
26 import Util             ( notNull )
27 import Data.IORef       ( readIORef )
28 \end{code}
29
30 The liberate-case transformation
31 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
32 This module walks over @Core@, and looks for @case@ on free variables.
33 The criterion is:
34         if there is case on a free on the route to the recursive call,
35         then the recursive call is replaced with an unfolding.
36
37 Example
38
39    f = \ t -> case v of
40                  V a b -> a : f t
41
42 => the inner f is replaced.
43
44    f = \ t -> case v of
45                  V a b -> a : (letrec
46                                 f =  \ t -> case v of
47                                                V a b -> a : f t
48                                in f) t
49 (note the NEED for shadowing)
50
51 => Simplify
52
53   f = \ t -> case v of
54                  V a b -> a : (letrec
55                                 f = \ t -> a : f t
56                                in f t)
57
58 Better code, because 'a' is  free inside the inner letrec, rather
59 than needing projection from v.
60
61 Other examples we'd like to catch with this kind of transformation
62
63         last []     = error 
64         last (x:[]) = x
65         last (x:xs) = last xs
66
67 We'd like to avoid the redundant pattern match, transforming to
68
69         last [] = error
70         last (x:[]) = x
71         last (x:(y:ys)) = last' y ys
72                 where
73                   last' y []     = y
74                   last' _ (y:ys) = last' y ys
75
76         (is this necessarily an improvement)
77
78 Similarly drop:
79
80         drop n [] = []
81         drop 0 xs = xs
82         drop n (x:xs) = drop (n-1) xs
83
84 Would like to pass n along unboxed.
85         
86 Note [Scrutinee with cast]
87 ~~~~~~~~~~~~~~~~~~~~~~~~~~
88 Consider this:
89     f = \ t -> case (v `cast` co) of
90                  V a b -> a : f t
91
92 Exactly the same optimisation (unrolling one call to f) will work here, 
93 despite the cast.  See mk_alt_env in the Case branch of libCase.
94
95
96 To think about (Apr 94)
97 ~~~~~~~~~~~~~~
98
99 Main worry: duplicating code excessively.  At the moment we duplicate
100 the entire binding group once at each recursive call.  But there may
101 be a group of recursive calls which share a common set of evaluated
102 free variables, in which case the duplication is a plain waste.
103
104 Another thing we could consider adding is some unfold-threshold thing,
105 so that we'll only duplicate if the size of the group rhss isn't too
106 big.
107
108 Data types
109 ~~~~~~~~~~
110 The ``level'' of a binder tells how many
111 recursive defns lexically enclose the binding
112 A recursive defn "encloses" its RHS, not its
113 scope.  For example:
114 \begin{verbatim}
115         letrec f = let g = ... in ...
116         in
117         let h = ...
118         in ...
119 \end{verbatim}
120 Here, the level of @f@ is zero, the level of @g@ is one,
121 and the level of @h@ is zero (NB not one).
122
123 Note [Indexed data types]
124 ~~~~~~~~~~~~~~~~~~~~~~~~~
125 Consider
126         data family T :: * -> *
127         data T Int = TI Int
128
129         f :: T Int -> Bool
130         f x = case x of { DEFAULT -> <body> }
131
132 We would like to change this to
133         f x = case x `cast` co of { TI p -> <body> }
134
135 so that <body> can make use of the fact that x is already evaluated to
136 a TI; and a case on a known data type may be more efficient than a
137 polymorphic one (not sure this is true any longer).  Anyway the former
138 showed up in Roman's experiments.  Example:
139   foo :: FooT Int -> Int -> Int
140   foo t n = t `seq` bar n
141      where
142        bar 0 = 0
143        bar n = bar (n - case t of TI i -> i)
144 Here we'd like to avoid repeated evaluating t inside the loop, by 
145 taking advantage of the `seq`.
146
147 We implement this as part of the liberate-case transformation by 
148 spotting
149         case <scrut> of (x::T) tys { DEFAULT ->  <body> }
150 where x :: T tys, and T is a indexed family tycon.  Find the
151 representation type (T77 tys'), and coercion co, and transform to
152         case <scrut> `cast` co of (y::T77 tys')
153             DEFAULT -> let x = y `cast` sym co in <body>
154
155 The "find the representation type" part is done by looking up in the
156 family-instance environment.
157
158 NB: in fact we re-use x (changing its type) to avoid making a fresh y;
159 this entails shadowing, but that's ok.
160
161 %************************************************************************
162 %*                                                                      *
163          Top-level code
164 %*                                                                      *
165 %************************************************************************
166
167 \begin{code}
168 liberateCase :: HscEnv -> UniqSupply -> RuleBase -> ModGuts
169              -> IO (SimplCount, ModGuts)
170 liberateCase hsc_env _ _ guts
171   = do  { let dflags = hsc_dflags hsc_env
172         ; eps <- readIORef (hsc_EPS hsc_env)
173         ; let fam_envs = (eps_fam_inst_env eps, mg_fam_inst_env guts)
174
175         ; showPass dflags "Liberate case"
176         ; let { env = initEnv dflags fam_envs
177               ; binds' = do_prog env (mg_binds guts) }
178         ; endPass dflags "Liberate case" Opt_D_verbose_core2core binds'
179                         {- no specific flag for dumping -} 
180         ; return (zeroSimplCount dflags, guts { mg_binds = binds' }) }
181   where
182     do_prog env [] = []
183     do_prog env (bind:binds) = bind' : do_prog env' binds
184                              where
185                                (env', bind') = libCaseBind env bind
186 \end{code}
187
188
189 %************************************************************************
190 %*                                                                      *
191          Main payload
192 %*                                                                      *
193 %************************************************************************
194
195 Bindings
196 ~~~~~~~~
197 \begin{code}
198 libCaseBind :: LibCaseEnv -> CoreBind -> (LibCaseEnv, CoreBind)
199
200 libCaseBind env (NonRec binder rhs)
201   = (addBinders env [binder], NonRec binder (libCase env rhs))
202
203 libCaseBind env (Rec pairs)
204   = (env_body, Rec pairs')
205   where
206     (binders, rhss) = unzip pairs
207
208     env_body = addBinders env binders
209
210     pairs' = [(binder, libCase env_rhs rhs) | (binder,rhs) <- pairs]
211
212     env_rhs = if all rhs_small_enough rhss then extended_env else env
213
214         -- We extend the rec-env by binding each Id to its rhs, first
215         -- processing the rhs with an *un-extended* environment, so
216         -- that the same process doesn't occur for ever!
217         --
218     extended_env = addRecBinds env [ (adjust binder, libCase env_body rhs)
219                                    | (binder, rhs) <- pairs ]
220
221         -- Two subtle things: 
222         -- (a)  Reset the export flags on the binders so
223         --      that we don't get name clashes on exported things if the 
224         --      local binding floats out to top level.  This is most unlikely
225         --      to happen, since the whole point concerns free variables. 
226         --      But resetting the export flag is right regardless.
227         -- 
228         -- (b)  Make the name an Internal one.  External Names should never be
229         --      nested; if it were floated to the top level, we'd get a name
230         --      clash at code generation time.
231     adjust bndr = setIdNotExported (setIdName bndr (localiseName (idName bndr)))
232
233     rhs_small_enough rhs = couldBeSmallEnoughToInline lIBERATE_BOMB_SIZE rhs
234     lIBERATE_BOMB_SIZE   = bombOutSize env
235 \end{code}
236
237
238 Expressions
239 ~~~~~~~~~~~
240
241 \begin{code}
242 libCase :: LibCaseEnv
243         -> CoreExpr
244         -> CoreExpr
245
246 libCase env (Var v)             = libCaseId env v
247 libCase env (Lit lit)           = Lit lit
248 libCase env (Type ty)           = Type ty
249 libCase env (App fun arg)       = App (libCase env fun) (libCase env arg)
250 libCase env (Note note body)    = Note note (libCase env body)
251 libCase env (Cast e co)         = Cast (libCase env e) co
252
253 libCase env (Lam binder body)
254   = Lam binder (libCase (addBinders env [binder]) body)
255
256 libCase env (Let bind body)
257   = Let bind' (libCase env_body body)
258   where
259     (env_body, bind') = libCaseBind env bind
260
261 libCase env (Case scrut bndr ty alts)
262   = mkCase env (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts)
263   where
264     env_alts = addBinders (mk_alt_env scrut) [bndr]
265     mk_alt_env (Var scrut_var) = addScrutedVar env scrut_var
266     mk_alt_env (Cast scrut _)  = mk_alt_env scrut       -- Note [Scrutinee with cast]
267     mk_alt_env otehr           = env
268
269 libCaseAlt env (con,args,rhs) = (con, args, libCase (addBinders env args) rhs)
270 \end{code}
271
272 \begin{code}
273 mkCase :: LibCaseEnv -> CoreExpr -> Id -> Type -> [CoreAlt] -> CoreExpr
274 -- See Note [Indexed data types]
275 mkCase env scrut bndr ty [(DEFAULT,_,rhs)]
276   | Just (tycon, tys)   <- splitTyConApp_maybe (idType bndr)
277   , [(subst, fam_inst)] <- lookupFamInstEnv (lc_fams env) tycon tys
278   = let 
279         rep_tc     = famInstTyCon fam_inst
280         rep_tys    = map (substTyVar subst) (tyConTyVars rep_tc)
281         bndr'      = setIdType bndr (mkTyConApp rep_tc rep_tys)
282         Just co_tc = tyConFamilyCoercion_maybe rep_tc
283         co         = mkTyConApp co_tc rep_tys
284         bind       = NonRec bndr (Cast (Var bndr') (mkSymCoercion co))
285     in mkCase env (Cast scrut co) bndr' ty [(DEFAULT,[],Let bind rhs)]
286 mkCase env scrut bndr ty alts
287   = Case scrut bndr ty alts
288 \end{code}
289
290 Ids
291 ~~~
292 \begin{code}
293 libCaseId :: LibCaseEnv -> Id -> CoreExpr
294 libCaseId env v
295   | Just the_bind <- lookupRecId env v  -- It's a use of a recursive thing
296   , notNull free_scruts                 -- with free vars scrutinised in RHS
297   = Let the_bind (Var v)
298
299   | otherwise
300   = Var v
301
302   where
303     rec_id_level = lookupLevel env v
304     free_scruts  = freeScruts env rec_id_level
305 \end{code}
306
307
308 %************************************************************************
309 %*                                                                      *
310         Utility functions
311 %*                                                                      *
312 %************************************************************************
313
314 \begin{code}
315 addBinders :: LibCaseEnv -> [CoreBndr] -> LibCaseEnv
316 addBinders env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env }) binders
317   = env { lc_lvl_env = lvl_env' }
318   where
319     lvl_env' = extendVarEnvList lvl_env (binders `zip` repeat lvl)
320
321 addRecBinds :: LibCaseEnv -> [(Id,CoreExpr)] -> LibCaseEnv
322 addRecBinds env@(LibCaseEnv {lc_lvl = lvl, lc_lvl_env = lvl_env, 
323                              lc_rec_env = rec_env}) pairs
324   = env { lc_lvl = lvl', lc_lvl_env = lvl_env', lc_rec_env = rec_env' }
325   where
326     lvl'     = lvl + 1
327     lvl_env' = extendVarEnvList lvl_env [(binder,lvl) | (binder,_) <- pairs]
328     rec_env' = extendVarEnvList rec_env [(binder, Rec pairs) | (binder,_) <- pairs]
329
330 addScrutedVar :: LibCaseEnv
331               -> Id             -- This Id is being scrutinised by a case expression
332               -> LibCaseEnv
333
334 addScrutedVar env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env, 
335                                 lc_scruts = scruts }) scrut_var
336   | bind_lvl < lvl
337   = env { lc_scruts = scruts' }
338         -- Add to scruts iff the scrut_var is being scrutinised at
339         -- a deeper level than its defn
340
341   | otherwise = env
342   where
343     scruts'  = (scrut_var, lvl) : scruts
344     bind_lvl = case lookupVarEnv lvl_env scrut_var of
345                  Just lvl -> lvl
346                  Nothing  -> topLevel
347
348 lookupRecId :: LibCaseEnv -> Id -> Maybe CoreBind
349 lookupRecId env id = lookupVarEnv (lc_rec_env env) id
350
351 lookupLevel :: LibCaseEnv -> Id -> LibCaseLevel
352 lookupLevel env id
353   = case lookupVarEnv (lc_lvl_env env) id of
354       Just lvl -> lvl
355       Nothing  -> topLevel
356
357 freeScruts :: LibCaseEnv
358            -> LibCaseLevel      -- Level of the recursive Id
359            -> [Id]              -- Ids that are scrutinised between the binding
360                                 -- of the recursive Id and here
361 freeScruts env rec_bind_lvl
362   = [v | (v,scrut_lvl) <- lc_scruts env, scrut_lvl > rec_bind_lvl]
363 \end{code}
364
365 %************************************************************************
366 %*                                                                      *
367          The environment
368 %*                                                                      *
369 %************************************************************************
370
371 \begin{code}
372 type LibCaseLevel = Int
373
374 topLevel :: LibCaseLevel
375 topLevel = 0
376 \end{code}
377
378 \begin{code}
379 data LibCaseEnv
380   = LibCaseEnv {
381         lc_size :: Int,         -- Bomb-out size for deciding if
382                                 -- potential liberatees are too big.
383                                 -- (passed in from cmd-line args)
384
385         lc_lvl :: LibCaseLevel, -- Current level
386
387         lc_lvl_env :: IdEnv LibCaseLevel,  
388                         -- Binds all non-top-level in-scope Ids
389                         -- (top-level and imported things have
390                         -- a level of zero)
391
392         lc_rec_env :: IdEnv CoreBind, 
393                         -- Binds *only* recursively defined ids, 
394                         -- to their own binding group,
395                         -- and *only* in their own RHSs
396
397         lc_scruts :: [(Id,LibCaseLevel)],
398                         -- Each of these Ids was scrutinised by an
399                         -- enclosing case expression, with the
400                         -- specified number of enclosing
401                         -- recursive bindings; furthermore,
402                         -- the Id is bound at a lower level
403                         -- than the case expression.  The order is
404                         -- insignificant; it's a bag really
405
406         lc_fams :: FamInstEnvs
407                         -- Instance env for indexed data types 
408         }
409
410 initEnv :: DynFlags -> FamInstEnvs -> LibCaseEnv
411 initEnv dflags fams
412   = LibCaseEnv { lc_size = specThreshold dflags,
413                  lc_lvl = 0,
414                  lc_lvl_env = emptyVarEnv, 
415                  lc_rec_env = emptyVarEnv,
416                  lc_scruts = [],
417                  lc_fams = fams }
418
419 bombOutSize = lc_size
420 \end{code}
421
422