Make the LiberateCase transformation understand associated types
[ghc-hetmet.git] / compiler / simplCore / LiberateCase.lhs
1 %
2 % (c) The AQUA Project, Glasgow University, 1994-1998
3 %
4 \section[LiberateCase]{Unroll recursion to allow evals to be lifted from a loop}
5
6 \begin{code}
7 module LiberateCase ( liberateCase ) where
8
9 #include "HsVersions.h"
10
11 import DynFlags
12 import HscTypes
13 import CoreLint         ( showPass, endPass )
14 import CoreSyn
15 import CoreUnfold       ( couldBeSmallEnoughToInline )
16 import Rules            ( RuleBase )
17 import UniqSupply       ( UniqSupply )
18 import SimplMonad       ( SimplCount, zeroSimplCount )
19 import Id
20 import FamInstEnv
21 import Type
22 import Coercion
23 import TyCon
24 import VarEnv
25 import Name             ( localiseName )
26 import Outputable
27 import Util             ( notNull )
28 import Data.IORef       ( readIORef )
29 \end{code}
30
31 The liberate-case transformation
32 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
33 This module walks over @Core@, and looks for @case@ on free variables.
34 The criterion is:
35         if there is case on a free on the route to the recursive call,
36         then the recursive call is replaced with an unfolding.
37
38 Example
39
40    f = \ t -> case v of
41                  V a b -> a : f t
42
43 => the inner f is replaced.
44
45    f = \ t -> case v of
46                  V a b -> a : (letrec
47                                 f =  \ t -> case v of
48                                                V a b -> a : f t
49                                in f) t
50 (note the NEED for shadowing)
51
52 => Simplify
53
54   f = \ t -> case v of
55                  V a b -> a : (letrec
56                                 f = \ t -> a : f t
57                                in f t)
58
59 Better code, because 'a' is  free inside the inner letrec, rather
60 than needing projection from v.
61
62 Other examples we'd like to catch with this kind of transformation
63
64         last []     = error 
65         last (x:[]) = x
66         last (x:xs) = last xs
67
68 We'd like to avoid the redundant pattern match, transforming to
69
70         last [] = error
71         last (x:[]) = x
72         last (x:(y:ys)) = last' y ys
73                 where
74                   last' y []     = y
75                   last' _ (y:ys) = last' y ys
76
77         (is this necessarily an improvement)
78
79 Similarly drop:
80
81         drop n [] = []
82         drop 0 xs = xs
83         drop n (x:xs) = drop (n-1) xs
84
85 Would like to pass n along unboxed.
86         
87 Note [Scrutinee with cast]
88 ~~~~~~~~~~~~~~~~~~~~~~~~~~
89 Consider this:
90     f = \ t -> case (v `cast` co) of
91                  V a b -> a : f t
92
93 Exactly the same optimistaion (unrolling one call to f) will work here, 
94 despite the cast.  See mk_alt_env in the Case branch of libCase.
95
96
97 To think about (Apr 94)
98 ~~~~~~~~~~~~~~
99
100 Main worry: duplicating code excessively.  At the moment we duplicate
101 the entire binding group once at each recursive call.  But there may
102 be a group of recursive calls which share a common set of evaluated
103 free variables, in which case the duplication is a plain waste.
104
105 Another thing we could consider adding is some unfold-threshold thing,
106 so that we'll only duplicate if the size of the group rhss isn't too
107 big.
108
109 Data types
110 ~~~~~~~~~~
111
112 The ``level'' of a binder tells how many
113 recursive defns lexically enclose the binding
114 A recursive defn "encloses" its RHS, not its
115 scope.  For example:
116 \begin{verbatim}
117         letrec f = let g = ... in ...
118         in
119         let h = ...
120         in ...
121 \end{verbatim}
122 Here, the level of @f@ is zero, the level of @g@ is one,
123 and the level of @h@ is zero (NB not one).
124
125 Note [Indexed data types]
126 ~~~~~~~~~~~~~~~~~~~~~~~~~
127 Consider
128         data family T :: * -> *
129         data T Int = TI Int
130
131         f :: T Int -> Bool
132         f x = case x of { DEFAULT -> <body> }
133
134 We would like to change this to
135         f x = case x `cast` co of { TI p -> <body> }
136
137 so that <body> can make use of the fact that x is already evaluated to
138 a TI; and a case on a known data type may be more efficient than a
139 polymorphic one (not sure this is true any longer).  Anyway the former
140 showed up in Roman's experiments.  Example:
141   foo :: FooT Int -> Int -> Int
142   foo t n = t `seq` bar n
143      where
144        bar 0 = 0
145        bar n = bar (n - case t of TI i -> i)
146 Here we'd like to avoid repeated evaluating t inside the loop, by 
147 taking advantage of the `seq`.
148
149 We implement this as part of the liberate-case transformation by 
150 spotting
151         case <scrut> of (x::T) tys { DEFAULT ->  <body> }
152 where x :: T tys, and T is a indexed family tycon.  Find the
153 representation type (T77 tys'), and coercion co, and transform to
154         case <scrut> `cast` co of (y::T77 tys')
155             DEFAULT -> let x = y `cast` sym co in <body>
156
157 The "find the representation type" part is done by looking up in the
158 family-instance environment.
159
160 NB: in fact we re-use x (changing its type) to avoid making a fresh y;
161 this entails shadowing, but that's ok.
162
163 %************************************************************************
164 %*                                                                      *
165          Top-level code
166 %*                                                                      *
167 %************************************************************************
168
169 \begin{code}
170 liberateCase :: HscEnv -> UniqSupply -> RuleBase -> ModGuts
171              -> IO (SimplCount, ModGuts)
172 liberateCase hsc_env _ _ guts
173   = do  { let dflags = hsc_dflags hsc_env
174         ; eps <- readIORef (hsc_EPS hsc_env)
175         ; let fam_envs = (eps_fam_inst_env eps, mg_fam_inst_env guts)
176
177         ; showPass dflags "Liberate case"
178         ; let { env = initEnv dflags fam_envs
179               ; binds' = do_prog env (mg_binds guts) }
180         ; endPass dflags "Liberate case" Opt_D_verbose_core2core binds'
181                         {- no specific flag for dumping -} 
182         ; return (zeroSimplCount dflags, guts { mg_binds = binds' }) }
183   where
184     do_prog env [] = []
185     do_prog env (bind:binds) = bind' : do_prog env' binds
186                              where
187                                (env', bind') = libCaseBind env bind
188 \end{code}
189
190
191 %************************************************************************
192 %*                                                                      *
193          Main payload
194 %*                                                                      *
195 %************************************************************************
196
197 Bindings
198 ~~~~~~~~
199 \begin{code}
200 libCaseBind :: LibCaseEnv -> CoreBind -> (LibCaseEnv, CoreBind)
201
202 libCaseBind env (NonRec binder rhs)
203   = (addBinders env [binder], NonRec binder (libCase env rhs))
204
205 libCaseBind env (Rec pairs)
206   = (env_body, Rec pairs')
207   where
208     (binders, rhss) = unzip pairs
209
210     env_body = addBinders env binders
211
212     pairs' = [(binder, libCase env_rhs rhs) | (binder,rhs) <- pairs]
213
214     env_rhs = if all rhs_small_enough rhss then extended_env else env
215
216         -- We extend the rec-env by binding each Id to its rhs, first
217         -- processing the rhs with an *un-extended* environment, so
218         -- that the same process doesn't occur for ever!
219         --
220     extended_env = addRecBinds env [ (adjust binder, libCase env_body rhs)
221                                    | (binder, rhs) <- pairs ]
222
223         -- Two subtle things: 
224         -- (a)  Reset the export flags on the binders so
225         --      that we don't get name clashes on exported things if the 
226         --      local binding floats out to top level.  This is most unlikely
227         --      to happen, since the whole point concerns free variables. 
228         --      But resetting the export flag is right regardless.
229         -- 
230         -- (b)  Make the name an Internal one.  External Names should never be
231         --      nested; if it were floated to the top level, we'd get a name
232         --      clash at code generation time.
233     adjust bndr = setIdNotExported (setIdName bndr (localiseName (idName bndr)))
234
235     rhs_small_enough rhs = couldBeSmallEnoughToInline lIBERATE_BOMB_SIZE rhs
236     lIBERATE_BOMB_SIZE   = bombOutSize env
237 \end{code}
238
239
240 Expressions
241 ~~~~~~~~~~~
242
243 \begin{code}
244 libCase :: LibCaseEnv
245         -> CoreExpr
246         -> CoreExpr
247
248 libCase env (Var v)             = libCaseId env v
249 libCase env (Lit lit)           = Lit lit
250 libCase env (Type ty)           = Type ty
251 libCase env (App fun arg)       = App (libCase env fun) (libCase env arg)
252 libCase env (Note note body)    = Note note (libCase env body)
253 libCase env (Cast e co)         = Cast (libCase env e) co
254
255 libCase env (Lam binder body)
256   = Lam binder (libCase (addBinders env [binder]) body)
257
258 libCase env (Let bind body)
259   = Let bind' (libCase env_body body)
260   where
261     (env_body, bind') = libCaseBind env bind
262
263 libCase env (Case scrut bndr ty alts)
264   = mkCase env (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts)
265   where
266     env_alts = addBinders (mk_alt_env scrut) [bndr]
267     mk_alt_env (Var scrut_var) = addScrutedVar env scrut_var
268     mk_alt_env (Cast scrut _)  = mk_alt_env scrut       -- Note [Scrutinee with cast]
269     mk_alt_env otehr           = env
270
271 libCaseAlt env (con,args,rhs) = (con, args, libCase (addBinders env args) rhs)
272 \end{code}
273
274 \begin{code}
275 mkCase :: LibCaseEnv -> CoreExpr -> Id -> Type -> [CoreAlt] -> CoreExpr
276 -- See Note [Indexed data types]
277 mkCase env scrut bndr ty [(DEFAULT,_,rhs)]
278   | Just (tycon, tys)   <- splitTyConApp_maybe (idType bndr)
279   , [(subst, fam_inst)] <- lookupFamInstEnv (lc_fams env) tycon tys
280   = let 
281         rep_tc     = famInstTyCon fam_inst
282         rep_tys    = map (substTyVar subst) (tyConTyVars rep_tc)
283         bndr'      = setIdType bndr (mkTyConApp rep_tc rep_tys)
284         Just co_tc = tyConFamilyCoercion_maybe rep_tc
285         co         = mkTyConApp co_tc rep_tys
286         bind       = NonRec bndr (Cast (Var bndr') (mkSymCoercion co))
287     in mkCase env (Cast scrut co) bndr' ty [(DEFAULT,[],Let bind rhs)]
288 mkCase env scrut bndr ty alts
289   = Case scrut bndr ty alts
290 \end{code}
291
292 Ids
293 ~~~
294 \begin{code}
295 libCaseId :: LibCaseEnv -> Id -> CoreExpr
296 libCaseId env v
297   | Just the_bind <- lookupRecId env v  -- It's a use of a recursive thing
298   , notNull free_scruts                 -- with free vars scrutinised in RHS
299   = Let the_bind (Var v)
300
301   | otherwise
302   = Var v
303
304   where
305     rec_id_level = lookupLevel env v
306     free_scruts  = freeScruts env rec_id_level
307 \end{code}
308
309
310 %************************************************************************
311 %*                                                                      *
312         Utility functions
313 %*                                                                      *
314 %************************************************************************
315
316 \begin{code}
317 addBinders :: LibCaseEnv -> [CoreBndr] -> LibCaseEnv
318 addBinders env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env }) binders
319   = env { lc_lvl_env = lvl_env' }
320   where
321     lvl_env' = extendVarEnvList lvl_env (binders `zip` repeat lvl)
322
323 addRecBinds :: LibCaseEnv -> [(Id,CoreExpr)] -> LibCaseEnv
324 addRecBinds env@(LibCaseEnv {lc_lvl = lvl, lc_lvl_env = lvl_env, 
325                              lc_rec_env = rec_env}) pairs
326   = env { lc_lvl = lvl', lc_lvl_env = lvl_env', lc_rec_env = rec_env' }
327   where
328     lvl'     = lvl + 1
329     lvl_env' = extendVarEnvList lvl_env [(binder,lvl) | (binder,_) <- pairs]
330     rec_env' = extendVarEnvList rec_env [(binder, Rec pairs) | (binder,_) <- pairs]
331
332 addScrutedVar :: LibCaseEnv
333               -> Id             -- This Id is being scrutinised by a case expression
334               -> LibCaseEnv
335
336 addScrutedVar env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env, 
337                                 lc_scruts = scruts }) scrut_var
338   | bind_lvl < lvl
339   = env { lc_scruts = scruts' }
340         -- Add to scruts iff the scrut_var is being scrutinised at
341         -- a deeper level than its defn
342
343   | otherwise = env
344   where
345     scruts'  = (scrut_var, lvl) : scruts
346     bind_lvl = case lookupVarEnv lvl_env scrut_var of
347                  Just lvl -> lvl
348                  Nothing  -> topLevel
349
350 lookupRecId :: LibCaseEnv -> Id -> Maybe CoreBind
351 lookupRecId env id = lookupVarEnv (lc_rec_env env) id
352
353 lookupLevel :: LibCaseEnv -> Id -> LibCaseLevel
354 lookupLevel env id
355   = case lookupVarEnv (lc_lvl_env env) id of
356       Just lvl -> lc_lvl env
357       Nothing  -> topLevel
358
359 freeScruts :: LibCaseEnv
360            -> LibCaseLevel      -- Level of the recursive Id
361            -> [Id]              -- Ids that are scrutinised between the binding
362                                 -- of the recursive Id and here
363 freeScruts env rec_bind_lvl
364   = [v | (v,scrut_lvl) <- lc_scruts env, scrut_lvl > rec_bind_lvl]
365 \end{code}
366
367 %************************************************************************
368 %*                                                                      *
369          The environment
370 %*                                                                      *
371 %************************************************************************
372
373 \begin{code}
374 type LibCaseLevel = Int
375
376 topLevel :: LibCaseLevel
377 topLevel = 0
378 \end{code}
379
380 \begin{code}
381 data LibCaseEnv
382   = LibCaseEnv {
383         lc_size :: Int,         -- Bomb-out size for deciding if
384                                 -- potential liberatees are too big.
385                                 -- (passed in from cmd-line args)
386
387         lc_lvl :: LibCaseLevel, -- Current level
388
389         lc_lvl_env :: IdEnv LibCaseLevel,  
390                         -- Binds all non-top-level in-scope Ids
391                         -- (top-level and imported things have
392                         -- a level of zero)
393
394         lc_rec_env :: IdEnv CoreBind, 
395                         -- Binds *only* recursively defined ids, 
396                         -- to their own binding group,
397                         -- and *only* in their own RHSs
398
399         lc_scruts :: [(Id,LibCaseLevel)],
400                         -- Each of these Ids was scrutinised by an
401                         -- enclosing case expression, with the
402                         -- specified number of enclosing
403                         -- recursive bindings; furthermore,
404                         -- the Id is bound at a lower level
405                         -- than the case expression.  The order is
406                         -- insignificant; it's a bag really
407
408         lc_fams :: FamInstEnvs
409                         -- Instance env for indexed data types 
410         }
411
412 initEnv :: DynFlags -> FamInstEnvs -> LibCaseEnv
413 initEnv dflags fams
414   = LibCaseEnv { lc_size = libCaseThreshold dflags,
415                  lc_lvl = 0,
416                  lc_lvl_env = emptyVarEnv, 
417                  lc_rec_env = emptyVarEnv,
418                  lc_scruts = [],
419                  lc_fams = fams }
420
421 bombOutSize = lc_size
422 \end{code}
423
424