Improve liberate-case to take account of coercions
[ghc-hetmet.git] / compiler / simplCore / LiberateCase.lhs
1 %
2 % (c) The AQUA Project, Glasgow University, 1994-1998
3 %
4 \section[LiberateCase]{Unroll recursion to allow evals to be lifted from a loop}
5
6 \begin{code}
7 module LiberateCase ( liberateCase ) where
8
9 #include "HsVersions.h"
10
11 import DynFlags         ( DynFlags, DynFlag(..) )
12 import StaticFlags      ( opt_LiberateCaseThreshold )
13 import CoreLint         ( showPass, endPass )
14 import CoreSyn
15 import CoreUnfold       ( couldBeSmallEnoughToInline )
16 import Id               ( Id, setIdName, idName, setIdNotExported )
17 import VarEnv
18 import Name             ( localiseName )
19 import Outputable
20 import Util             ( notNull )
21 \end{code}
22
23 This module walks over @Core@, and looks for @case@ on free variables.
24 The criterion is:
25         if there is case on a free on the route to the recursive call,
26         then the recursive call is replaced with an unfolding.
27
28 Example
29
30 \begin{verbatim}
31 f = \ t -> case v of
32                V a b -> a : f t
33 \end{verbatim}
34
35 => the inner f is replaced.
36
37 \begin{verbatim}
38 f = \ t -> case v of
39                V a b -> a : (letrec
40                                 f =  \ t -> case v of
41                                                V a b -> a : f t
42                              in f) t
43 \end{verbatim}
44 (note the NEED for shadowing)
45
46 => Simplify
47
48 \begin{verbatim}
49 f = \ t -> case v of
50                V a b -> a : (letrec
51                                 f = \ t -> a : f t
52                              in f t)
53 \begin{verbatim}
54
55 Better code, because 'a' is  free inside the inner letrec, rather
56 than needing projection from v.
57
58 Other examples we'd like to catch with this kind of transformation
59
60         last []     = error 
61         last (x:[]) = x
62         last (x:xs) = last xs
63
64 We'd like to avoid the redundant pattern match, transforming to
65
66         last [] = error
67         last (x:[]) = x
68         last (x:(y:ys)) = last' y ys
69                 where
70                   last' y []     = y
71                   last' _ (y:ys) = last' y ys
72
73         (is this necessarily an improvement)
74
75
76 Similarly drop:
77
78         drop n [] = []
79         drop 0 xs = xs
80         drop n (x:xs) = drop (n-1) xs
81
82 Would like to pass n along unboxed.
83         
84 Note [Scrutinee with cast]
85 ~~~~~~~~~~~~~~~~~~~~~~~~~~
86 Consider this:
87     f = \ t -> case (v `cast` co) of
88                  V a b -> a : f t
89
90 Exactly the same optimistaion (unrolling one call to f) will work here, 
91 despite the cast.  See mk_alt_env in the Case branch of libCase.
92
93
94 To think about (Apr 94)
95 ~~~~~~~~~~~~~~
96
97 Main worry: duplicating code excessively.  At the moment we duplicate
98 the entire binding group once at each recursive call.  But there may
99 be a group of recursive calls which share a common set of evaluated
100 free variables, in which case the duplication is a plain waste.
101
102 Another thing we could consider adding is some unfold-threshold thing,
103 so that we'll only duplicate if the size of the group rhss isn't too
104 big.
105
106 Data types
107 ~~~~~~~~~~
108
109 The ``level'' of a binder tells how many
110 recursive defns lexically enclose the binding
111 A recursive defn "encloses" its RHS, not its
112 scope.  For example:
113 \begin{verbatim}
114         letrec f = let g = ... in ...
115         in
116         let h = ...
117         in ...
118 \end{verbatim}
119 Here, the level of @f@ is zero, the level of @g@ is one,
120 and the level of @h@ is zero (NB not one).
121
122 \begin{code}
123 type LibCaseLevel = Int
124
125 topLevel :: LibCaseLevel
126 topLevel = 0
127 \end{code}
128
129 \begin{code}
130 data LibCaseEnv
131   = LibCaseEnv
132         Int                     -- Bomb-out size for deciding if
133                                 -- potential liberatees are too big.
134                                 -- (passed in from cmd-line args)
135
136         LibCaseLevel            -- Current level
137
138         (IdEnv LibCaseLevel)    -- Binds all non-top-level in-scope Ids
139                                 -- (top-level and imported things have
140                                 -- a level of zero)
141
142         (IdEnv CoreBind)        -- Binds *only* recursively defined
143                                 -- Ids, to their own binding group,
144                                 -- and *only* in their own RHSs
145
146         [(Id,LibCaseLevel)]     -- Each of these Ids was scrutinised by an
147                                 -- enclosing case expression, with the
148                                 -- specified number of enclosing
149                                 -- recursive bindings; furthermore,
150                                 -- the Id is bound at a lower level
151                                 -- than the case expression.  The
152                                 -- order is insignificant; it's a bag
153                                 -- really
154
155 initEnv :: Int -> LibCaseEnv
156 initEnv bomb_size = LibCaseEnv bomb_size 0 emptyVarEnv emptyVarEnv []
157
158 bombOutSize (LibCaseEnv bomb_size _ _ _ _) = bomb_size
159 \end{code}
160
161
162 Programs
163 ~~~~~~~~
164 \begin{code}
165 liberateCase :: DynFlags -> [CoreBind] -> IO [CoreBind]
166 liberateCase dflags binds
167   = do {
168         showPass dflags "Liberate case" ;
169         let { binds' = do_prog (initEnv opt_LiberateCaseThreshold) binds } ;
170         endPass dflags "Liberate case" Opt_D_verbose_core2core binds'
171                                 {- no specific flag for dumping -} 
172     }
173   where
174     do_prog env [] = []
175     do_prog env (bind:binds) = bind' : do_prog env' binds
176                              where
177                                (env', bind') = libCaseBind env bind
178 \end{code}
179
180 Bindings
181 ~~~~~~~~
182
183 \begin{code}
184 libCaseBind :: LibCaseEnv -> CoreBind -> (LibCaseEnv, CoreBind)
185
186 libCaseBind env (NonRec binder rhs)
187   = (addBinders env [binder], NonRec binder (libCase env rhs))
188
189 libCaseBind env (Rec pairs)
190   = (env_body, Rec pairs')
191   where
192     (binders, rhss) = unzip pairs
193
194     env_body = addBinders env binders
195
196     pairs' = [(binder, libCase env_rhs rhs) | (binder,rhs) <- pairs]
197
198     env_rhs = if all rhs_small_enough rhss then extended_env else env
199
200         -- We extend the rec-env by binding each Id to its rhs, first
201         -- processing the rhs with an *un-extended* environment, so
202         -- that the same process doesn't occur for ever!
203         --
204     extended_env = addRecBinds env [ (adjust binder, libCase env_body rhs)
205                                    | (binder, rhs) <- pairs ]
206
207         -- Two subtle things: 
208         -- (a)  Reset the export flags on the binders so
209         --      that we don't get name clashes on exported things if the 
210         --      local binding floats out to top level.  This is most unlikely
211         --      to happen, since the whole point concerns free variables. 
212         --      But resetting the export flag is right regardless.
213         -- 
214         -- (b)  Make the name an Internal one.  External Names should never be
215         --      nested; if it were floated to the top level, we'd get a name
216         --      clash at code generation time.
217     adjust bndr = setIdNotExported (setIdName bndr (localiseName (idName bndr)))
218
219     rhs_small_enough rhs = couldBeSmallEnoughToInline lIBERATE_BOMB_SIZE rhs
220     lIBERATE_BOMB_SIZE   = bombOutSize env
221 \end{code}
222
223
224 Expressions
225 ~~~~~~~~~~~
226
227 \begin{code}
228 libCase :: LibCaseEnv
229         -> CoreExpr
230         -> CoreExpr
231
232 libCase env (Var v)             = libCaseId env v
233 libCase env (Lit lit)           = Lit lit
234 libCase env (Type ty)           = Type ty
235 libCase env (App fun arg)       = App (libCase env fun) (libCase env arg)
236 libCase env (Note note body)    = Note note (libCase env body)
237 libCase env (Cast e co)         = Cast (libCase env e) co
238
239 libCase env (Lam binder body)
240   = Lam binder (libCase (addBinders env [binder]) body)
241
242 libCase env (Let bind body)
243   = Let bind' (libCase env_body body)
244   where
245     (env_body, bind') = libCaseBind env bind
246
247 libCase env (Case scrut bndr ty alts)
248   = Case (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts)
249   where
250     env_alts = addBinders (mk_alt_env scrut) [bndr]
251     mk_alt_env (Var scrut_var) = addScrutedVar env scrut_var
252     mk_alt_env (Cast scrut _)  = mk_alt_env scrut       -- Note [Scrutinee with cast]
253     mk_alt_env otehr           = env
254
255 libCaseAlt env (con,args,rhs) = (con, args, libCase (addBinders env args) rhs)
256 \end{code}
257
258 Ids
259 ~~~
260 \begin{code}
261 libCaseId :: LibCaseEnv -> Id -> CoreExpr
262 libCaseId env v
263   | Just the_bind <- lookupRecId env v  -- It's a use of a recursive thing
264   , notNull free_scruts                 -- with free vars scrutinised in RHS
265   = Let the_bind (Var v)
266
267   | otherwise
268   = Var v
269
270   where
271     rec_id_level = lookupLevel env v
272     free_scruts  = freeScruts env rec_id_level
273 \end{code}
274
275
276
277 Utility functions
278 ~~~~~~~~~~~~~~~~~
279 \begin{code}
280 addBinders :: LibCaseEnv -> [CoreBndr] -> LibCaseEnv
281 addBinders (LibCaseEnv bomb lvl lvl_env rec_env scruts) binders
282   = LibCaseEnv bomb lvl lvl_env' rec_env scruts
283   where
284     lvl_env' = extendVarEnvList lvl_env (binders `zip` repeat lvl)
285
286 addRecBinds :: LibCaseEnv -> [(Id,CoreExpr)] -> LibCaseEnv
287 addRecBinds (LibCaseEnv bomb lvl lvl_env rec_env scruts) pairs
288   = LibCaseEnv bomb lvl' lvl_env' rec_env' scruts
289   where
290     lvl'     = lvl + 1
291     lvl_env' = extendVarEnvList lvl_env [(binder,lvl) | (binder,_) <- pairs]
292     rec_env' = extendVarEnvList rec_env [(binder, Rec pairs) | (binder,_) <- pairs]
293
294 addScrutedVar :: LibCaseEnv
295               -> Id             -- This Id is being scrutinised by a case expression
296               -> LibCaseEnv
297
298 addScrutedVar env@(LibCaseEnv bomb lvl lvl_env rec_env scruts) scrut_var
299   | bind_lvl < lvl
300   = LibCaseEnv bomb lvl lvl_env rec_env scruts'
301         -- Add to scruts iff the scrut_var is being scrutinised at
302         -- a deeper level than its defn
303
304   | otherwise = env
305   where
306     scruts'  = (scrut_var, lvl) : scruts
307     bind_lvl = case lookupVarEnv lvl_env scrut_var of
308                  Just lvl -> lvl
309                  Nothing  -> topLevel
310
311 lookupRecId :: LibCaseEnv -> Id -> Maybe CoreBind
312 lookupRecId (LibCaseEnv bomb lvl lvl_env rec_env scruts) id
313   = lookupVarEnv rec_env id
314
315 lookupLevel :: LibCaseEnv -> Id -> LibCaseLevel
316 lookupLevel (LibCaseEnv bomb lvl lvl_env rec_env scruts) id
317   = case lookupVarEnv lvl_env id of
318       Just lvl -> lvl
319       Nothing  -> topLevel
320
321 freeScruts :: LibCaseEnv
322            -> LibCaseLevel      -- Level of the recursive Id
323            -> [Id]              -- Ids that are scrutinised between the binding
324                                 -- of the recursive Id and here
325 freeScruts (LibCaseEnv bomb lvl lvl_env rec_env scruts) rec_bind_lvl
326   = [v | (v,scrut_lvl) <- scruts, scrut_lvl > rec_bind_lvl]
327 \end{code}