Fix LiberateCase
[ghc-hetmet.git] / compiler / simplCore / LiberateCase.lhs
1 %
2 % (c) The AQUA Project, Glasgow University, 1994-1998
3 %
4 \section[LiberateCase]{Unroll recursion to allow evals to be lifted from a loop}
5
6 \begin{code}
7 {-# OPTIONS -w #-}
8 -- The above warning supression flag is a temporary kludge.
9 -- While working on this module you are encouraged to remove it and fix
10 -- any warnings in the module. See
11 --     http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings
12 -- for details
13
14 module LiberateCase ( liberateCase ) where
15
16 #include "HsVersions.h"
17
18 import DynFlags
19 import HscTypes
20 import CoreLint         ( showPass, endPass )
21 import CoreSyn
22 import CoreUnfold       ( couldBeSmallEnoughToInline )
23 import Rules            ( RuleBase )
24 import UniqSupply       ( UniqSupply )
25 import SimplMonad       ( SimplCount, zeroSimplCount )
26 import Id
27 import VarEnv
28 import Name             ( localiseName )
29 import Util             ( notNull )
30 \end{code}
31
32 The liberate-case transformation
33 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
34 This module walks over @Core@, and looks for @case@ on free variables.
35 The criterion is:
36         if there is case on a free on the route to the recursive call,
37         then the recursive call is replaced with an unfolding.
38
39 Example
40
41    f = \ t -> case v of
42                  V a b -> a : f t
43
44 => the inner f is replaced.
45
46    f = \ t -> case v of
47                  V a b -> a : (letrec
48                                 f =  \ t -> case v of
49                                                V a b -> a : f t
50                                in f) t
51 (note the NEED for shadowing)
52
53 => Simplify
54
55   f = \ t -> case v of
56                  V a b -> a : (letrec
57                                 f = \ t -> a : f t
58                                in f t)
59
60 Better code, because 'a' is  free inside the inner letrec, rather
61 than needing projection from v.
62
63 Note that this deals with *free variables*.  SpecConstr deals with
64 *arguments* that are of known form.  E.g.
65
66         last []     = error 
67         last (x:[]) = x
68         last (x:xs) = last xs
69
70         
71 Note [Scrutinee with cast]
72 ~~~~~~~~~~~~~~~~~~~~~~~~~~
73 Consider this:
74     f = \ t -> case (v `cast` co) of
75                  V a b -> a : f t
76
77 Exactly the same optimisation (unrolling one call to f) will work here, 
78 despite the cast.  See mk_alt_env in the Case branch of libCase.
79
80
81 Note [Only functions!]
82 ~~~~~~~~~~~~~~~~~~~~~~
83 Consider the following code
84
85        f = g (case v of V a b -> a : t f)
86
87 where g is expensive. If we aren't careful, liberate case will turn this into
88
89        f = g (case v of
90                V a b -> a : t (letrec f = g (case v of V a b -> a : f t)
91                                 in f)
92              )
93
94 Yikes! We evaluate g twice. This leads to a O(2^n) explosion
95 if g calls back to the same code recursively.
96
97 Solution: make sure that we only do the liberate-case thing on *functions*
98
99 To think about (Apr 94)
100 ~~~~~~~~~~~~~~
101 Main worry: duplicating code excessively.  At the moment we duplicate
102 the entire binding group once at each recursive call.  But there may
103 be a group of recursive calls which share a common set of evaluated
104 free variables, in which case the duplication is a plain waste.
105
106 Another thing we could consider adding is some unfold-threshold thing,
107 so that we'll only duplicate if the size of the group rhss isn't too
108 big.
109
110 Data types
111 ~~~~~~~~~~
112 The ``level'' of a binder tells how many
113 recursive defns lexically enclose the binding
114 A recursive defn "encloses" its RHS, not its
115 scope.  For example:
116 \begin{verbatim}
117         letrec f = let g = ... in ...
118         in
119         let h = ...
120         in ...
121 \end{verbatim}
122 Here, the level of @f@ is zero, the level of @g@ is one,
123 and the level of @h@ is zero (NB not one).
124
125
126 %************************************************************************
127 %*                                                                      *
128          Top-level code
129 %*                                                                      *
130 %************************************************************************
131
132 \begin{code}
133 liberateCase :: HscEnv -> UniqSupply -> RuleBase -> ModGuts
134              -> IO (SimplCount, ModGuts)
135 liberateCase hsc_env _ _ guts
136   = do  { let dflags = hsc_dflags hsc_env
137
138         ; showPass dflags "Liberate case"
139         ; let { env = initEnv dflags
140               ; binds' = do_prog env (mg_binds guts) }
141         ; endPass dflags "Liberate case" Opt_D_verbose_core2core binds'
142                         {- no specific flag for dumping -} 
143         ; return (zeroSimplCount dflags, guts { mg_binds = binds' }) }
144   where
145     do_prog env [] = []
146     do_prog env (bind:binds) = bind' : do_prog env' binds
147                              where
148                                (env', bind') = libCaseBind env bind
149 \end{code}
150
151
152 %************************************************************************
153 %*                                                                      *
154          Main payload
155 %*                                                                      *
156 %************************************************************************
157
158 Bindings
159 ~~~~~~~~
160 \begin{code}
161 libCaseBind :: LibCaseEnv -> CoreBind -> (LibCaseEnv, CoreBind)
162
163 libCaseBind env (NonRec binder rhs)
164   = (addBinders env [binder], NonRec binder (libCase env rhs))
165
166 libCaseBind env (Rec pairs)
167   = (env_body, Rec pairs')
168   where
169     (binders, rhss) = unzip pairs
170
171     env_body = addBinders env binders
172
173     pairs' = [(binder, libCase env_rhs rhs) | (binder,rhs) <- pairs]
174
175     env_rhs = if all rhs_small_enough pairs then extended_env else env
176
177         -- We extend the rec-env by binding each Id to its rhs, first
178         -- processing the rhs with an *un-extended* environment, so
179         -- that the same process doesn't occur for ever!
180         --
181     extended_env = addRecBinds env [ (adjust binder, libCase env_body rhs)
182                                    | (binder, rhs) <- pairs ]
183
184         -- Two subtle things: 
185         -- (a)  Reset the export flags on the binders so
186         --      that we don't get name clashes on exported things if the 
187         --      local binding floats out to top level.  This is most unlikely
188         --      to happen, since the whole point concerns free variables. 
189         --      But resetting the export flag is right regardless.
190         -- 
191         -- (b)  Make the name an Internal one.  External Names should never be
192         --      nested; if it were floated to the top level, we'd get a name
193         --      clash at code generation time.
194     adjust bndr = setIdNotExported (setIdName bndr (localiseName (idName bndr)))
195
196     rhs_small_enough (id,rhs)
197         =  idArity id > 0       -- Note [Only functions!]
198         && couldBeSmallEnoughToInline (bombOutSize env) rhs
199 \end{code}
200
201
202 Expressions
203 ~~~~~~~~~~~
204
205 \begin{code}
206 libCase :: LibCaseEnv
207         -> CoreExpr
208         -> CoreExpr
209
210 libCase env (Var v)             = libCaseId env v
211 libCase env (Lit lit)           = Lit lit
212 libCase env (Type ty)           = Type ty
213 libCase env (App fun arg)       = App (libCase env fun) (libCase env arg)
214 libCase env (Note note body)    = Note note (libCase env body)
215 libCase env (Cast e co)         = Cast (libCase env e) co
216
217 libCase env (Lam binder body)
218   = Lam binder (libCase (addBinders env [binder]) body)
219
220 libCase env (Let bind body)
221   = Let bind' (libCase env_body body)
222   where
223     (env_body, bind') = libCaseBind env bind
224
225 libCase env (Case scrut bndr ty alts)
226   = Case (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts)
227   where
228     env_alts = addBinders (mk_alt_env scrut) [bndr]
229     mk_alt_env (Var scrut_var) = addScrutedVar env scrut_var
230     mk_alt_env (Cast scrut _)  = mk_alt_env scrut       -- Note [Scrutinee with cast]
231     mk_alt_env otehr           = env
232
233 libCaseAlt env (con,args,rhs) = (con, args, libCase (addBinders env args) rhs)
234 \end{code}
235
236
237 Ids
238 ~~~
239 \begin{code}
240 libCaseId :: LibCaseEnv -> Id -> CoreExpr
241 libCaseId env v
242   | Just the_bind <- lookupRecId env v  -- It's a use of a recursive thing
243   , notNull free_scruts                 -- with free vars scrutinised in RHS
244   = Let the_bind (Var v)
245
246   | otherwise
247   = Var v
248
249   where
250     rec_id_level = lookupLevel env v
251     free_scruts  = freeScruts env rec_id_level
252
253 freeScruts :: LibCaseEnv
254            -> LibCaseLevel      -- Level of the recursive Id
255            -> [Id]              -- Ids that are scrutinised between the binding
256                                 -- of the recursive Id and here
257 freeScruts env rec_bind_lvl
258   = [v | (v,scrut_bind_lvl) <- lc_scruts env
259        , scrut_bind_lvl <= rec_bind_lvl]
260         -- Note [When to specialise]
261 \end{code}
262
263 Note [When to specialise]
264 ~~~~~~~~~~~~~~~~~~~~~~~~~
265 Consider
266   f = \x. letrec g = \y. case x of
267                            True  -> ... (f a) ...
268                            False -> ... (g b) ...
269
270 We get the following levels
271           f  0
272           x  1
273           g  1
274           y  2  
275
276 Then 'x' is being scrutinised at a deeper level than its binding, so
277 it's added to lc_sruts:  [(x,1)]  
278
279 We do *not* want to specialise the call to 'f', becuase 'x' is not free 
280 in 'f'.  So here the bind-level of 'x' (=1) is not <= the bind-level of 'f' (=0).
281
282 We *do* want to specialise the call to 'g', because 'x' is free in g.
283 Here the bind-level of 'x' (=1) is <= the bind-level of 'g' (=1).
284
285
286 %************************************************************************
287 %*                                                                      *
288         Utility functions
289 %*                                                                      *
290 %************************************************************************
291
292 \begin{code}
293 addBinders :: LibCaseEnv -> [CoreBndr] -> LibCaseEnv
294 addBinders env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env }) binders
295   = env { lc_lvl_env = lvl_env' }
296   where
297     lvl_env' = extendVarEnvList lvl_env (binders `zip` repeat lvl)
298
299 addRecBinds :: LibCaseEnv -> [(Id,CoreExpr)] -> LibCaseEnv
300 addRecBinds env@(LibCaseEnv {lc_lvl = lvl, lc_lvl_env = lvl_env, 
301                              lc_rec_env = rec_env}) pairs
302   = env { lc_lvl = lvl', lc_lvl_env = lvl_env', lc_rec_env = rec_env' }
303   where
304     lvl'     = lvl + 1
305     lvl_env' = extendVarEnvList lvl_env [(binder,lvl) | (binder,_) <- pairs]
306     rec_env' = extendVarEnvList rec_env [(binder, Rec pairs) | (binder,_) <- pairs]
307
308 addScrutedVar :: LibCaseEnv
309               -> Id             -- This Id is being scrutinised by a case expression
310               -> LibCaseEnv
311
312 addScrutedVar env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env, 
313                                 lc_scruts = scruts }) scrut_var
314   | bind_lvl < lvl
315   = env { lc_scruts = scruts' }
316         -- Add to scruts iff the scrut_var is being scrutinised at
317         -- a deeper level than its defn
318
319   | otherwise = env
320   where
321     scruts'  = (scrut_var, bind_lvl) : scruts
322     bind_lvl = case lookupVarEnv lvl_env scrut_var of
323                  Just lvl -> lvl
324                  Nothing  -> topLevel
325
326 lookupRecId :: LibCaseEnv -> Id -> Maybe CoreBind
327 lookupRecId env id = lookupVarEnv (lc_rec_env env) id
328
329 lookupLevel :: LibCaseEnv -> Id -> LibCaseLevel
330 lookupLevel env id
331   = case lookupVarEnv (lc_lvl_env env) id of
332       Just lvl -> lvl
333       Nothing  -> topLevel
334 \end{code}
335
336 %************************************************************************
337 %*                                                                      *
338          The environment
339 %*                                                                      *
340 %************************************************************************
341
342 \begin{code}
343 type LibCaseLevel = Int
344
345 topLevel :: LibCaseLevel
346 topLevel = 0
347 \end{code}
348
349 \begin{code}
350 data LibCaseEnv
351   = LibCaseEnv {
352         lc_size :: Int,         -- Bomb-out size for deciding if
353                                 -- potential liberatees are too big.
354                                 -- (passed in from cmd-line args)
355
356         lc_lvl :: LibCaseLevel, -- Current level
357                 -- The level is incremented when (and only when) going
358                 -- inside the RHS of a (sufficiently small) recursive
359                 -- function.
360
361         lc_lvl_env :: IdEnv LibCaseLevel,  
362                 -- Binds all non-top-level in-scope Ids (top-level and
363                 -- imported things have a level of zero)
364
365         lc_rec_env :: IdEnv CoreBind, 
366                 -- Binds *only* recursively defined ids, to their own
367                 -- binding group, and *only* in their own RHSs
368
369         lc_scruts :: [(Id,LibCaseLevel)]
370                 -- Each of these Ids was scrutinised by an enclosing
371                 -- case expression, at a level deeper than its binding
372                 -- level.  The LibCaseLevel recorded here is the *binding
373                 -- level* of the scrutinised Id.
374                 -- 
375                 -- The order is insignificant; it's a bag really
376         }
377
378 initEnv :: DynFlags -> LibCaseEnv
379 initEnv dflags 
380   = LibCaseEnv { lc_size = specThreshold dflags,
381                  lc_lvl = 0,
382                  lc_lvl_env = emptyVarEnv, 
383                  lc_rec_env = emptyVarEnv,
384                  lc_scruts = [] }
385
386 bombOutSize = lc_size
387 \end{code}
388
389