[project @ 2005-03-07 17:46:24 by simonpj]
[ghc-hetmet.git] / ghc / compiler / simplCore / LiberateCase.lhs
1 %
2 % (c) The AQUA Project, Glasgow University, 1994-1998
3 %
4 \section[LiberateCase]{Unroll recursion to allow evals to be lifted from a loop}
5
6 \begin{code}
7 module LiberateCase ( liberateCase ) where
8
9 #include "HsVersions.h"
10
11 import CmdLineOpts      ( DynFlags, DynFlag(..), opt_LiberateCaseThreshold )
12 import CoreLint         ( showPass, endPass )
13 import CoreSyn
14 import CoreUnfold       ( couldBeSmallEnoughToInline )
15 import Id               ( Id, setIdName, idName, setIdNotExported )
16 import VarEnv
17 import Name             ( localiseName )
18 import Outputable
19 import Util             ( notNull )
20 \end{code}
21
22 This module walks over @Core@, and looks for @case@ on free variables.
23 The criterion is:
24         if there is case on a free on the route to the recursive call,
25         then the recursive call is replaced with an unfolding.
26
27 Example
28
29 \begin{verbatim}
30 f = \ t -> case v of
31                V a b -> a : f t
32 \end{verbatim}
33
34 => the inner f is replaced.
35
36 \begin{verbatim}
37 f = \ t -> case v of
38                V a b -> a : (letrec
39                                 f =  \ t -> case v of
40                                                V a b -> a : f t
41                              in f) t
42 \end{verbatim}
43 (note the NEED for shadowing)
44
45 => Simplify
46
47 \begin{verbatim}
48 f = \ t -> case v of
49                V a b -> a : (letrec
50                                 f = \ t -> a : f t
51                              in f t)
52 \begin{verbatim}
53
54 Better code, because 'a' is  free inside the inner letrec, rather
55 than needing projection from v.
56
57 Other examples we'd like to catch with this kind of transformation
58
59         last []     = error 
60         last (x:[]) = x
61         last (x:xs) = last xs
62
63 We'd like to avoid the redundant pattern match, transforming to
64
65         last [] = error
66         last (x:[]) = x
67         last (x:(y:ys)) = last' y ys
68                 where
69                   last' y []     = y
70                   last' _ (y:ys) = last' y ys
71
72         (is this necessarily an improvement)
73
74
75 Similarly drop:
76
77         drop n [] = []
78         drop 0 xs = xs
79         drop n (x:xs) = drop (n-1) xs
80
81 Would like to pass n along unboxed.
82         
83
84 To think about (Apr 94)
85 ~~~~~~~~~~~~~~
86
87 Main worry: duplicating code excessively.  At the moment we duplicate
88 the entire binding group once at each recursive call.  But there may
89 be a group of recursive calls which share a common set of evaluated
90 free variables, in which case the duplication is a plain waste.
91
92 Another thing we could consider adding is some unfold-threshold thing,
93 so that we'll only duplicate if the size of the group rhss isn't too
94 big.
95
96 Data types
97 ~~~~~~~~~~
98
99 The ``level'' of a binder tells how many
100 recursive defns lexically enclose the binding
101 A recursive defn "encloses" its RHS, not its
102 scope.  For example:
103 \begin{verbatim}
104         letrec f = let g = ... in ...
105         in
106         let h = ...
107         in ...
108 \end{verbatim}
109 Here, the level of @f@ is zero, the level of @g@ is one,
110 and the level of @h@ is zero (NB not one).
111
112 \begin{code}
113 type LibCaseLevel = Int
114
115 topLevel :: LibCaseLevel
116 topLevel = 0
117 \end{code}
118
119 \begin{code}
120 data LibCaseEnv
121   = LibCaseEnv
122         Int                     -- Bomb-out size for deciding if
123                                 -- potential liberatees are too big.
124                                 -- (passed in from cmd-line args)
125
126         LibCaseLevel            -- Current level
127
128         (IdEnv LibCaseLevel)    -- Binds all non-top-level in-scope Ids
129                                 -- (top-level and imported things have
130                                 -- a level of zero)
131
132         (IdEnv CoreBind)        -- Binds *only* recursively defined
133                                 -- Ids, to their own binding group,
134                                 -- and *only* in their own RHSs
135
136         [(Id,LibCaseLevel)]     -- Each of these Ids was scrutinised by an
137                                 -- enclosing case expression, with the
138                                 -- specified number of enclosing
139                                 -- recursive bindings; furthermore,
140                                 -- the Id is bound at a lower level
141                                 -- than the case expression.  The
142                                 -- order is insignificant; it's a bag
143                                 -- really
144
145 initEnv :: Int -> LibCaseEnv
146 initEnv bomb_size = LibCaseEnv bomb_size 0 emptyVarEnv emptyVarEnv []
147
148 bombOutSize (LibCaseEnv bomb_size _ _ _ _) = bomb_size
149 \end{code}
150
151
152 Programs
153 ~~~~~~~~
154 \begin{code}
155 liberateCase :: DynFlags -> [CoreBind] -> IO [CoreBind]
156 liberateCase dflags binds
157   = do {
158         showPass dflags "Liberate case" ;
159         let { binds' = do_prog (initEnv opt_LiberateCaseThreshold) binds } ;
160         endPass dflags "Liberate case" Opt_D_verbose_core2core binds'
161                                 {- no specific flag for dumping -} 
162     }
163   where
164     do_prog env [] = []
165     do_prog env (bind:binds) = bind' : do_prog env' binds
166                              where
167                                (env', bind') = libCaseBind env bind
168 \end{code}
169
170 Bindings
171 ~~~~~~~~
172
173 \begin{code}
174 libCaseBind :: LibCaseEnv -> CoreBind -> (LibCaseEnv, CoreBind)
175
176 libCaseBind env (NonRec binder rhs)
177   = (addBinders env [binder], NonRec binder (libCase env rhs))
178
179 libCaseBind env (Rec pairs)
180   = (env_body, Rec pairs')
181   where
182     (binders, rhss) = unzip pairs
183
184     env_body = addBinders env binders
185
186     pairs' = [(binder, libCase env_rhs rhs) | (binder,rhs) <- pairs]
187
188     env_rhs = if all rhs_small_enough rhss then extended_env else env
189
190         -- We extend the rec-env by binding each Id to its rhs, first
191         -- processing the rhs with an *un-extended* environment, so
192         -- that the same process doesn't occur for ever!
193         --
194     extended_env = addRecBinds env [ (setIdNotExported binder, libCase env_body rhs)
195                                    | (binder, rhs) <- pairs ]
196
197         -- Two subtle things: 
198         -- (a)  Reset the export flags on the binders so
199         --      that we don't get name clashes on exported things if the 
200         --      local binding floats out to top level.  This is most unlikely
201         --      to happen, since the whole point concerns free variables. 
202         --      But resetting the export flag is right regardless.
203         -- 
204         -- (b)  Make the name an Internal one.  External Names should never be
205         --      nested; if it were floated to the top level, we'd get a name
206         --      clash at code generation time.
207     adjust bndr = setIdNotExported (setIdName bndr (localiseName (idName bndr)))
208
209     rhs_small_enough rhs = couldBeSmallEnoughToInline lIBERATE_BOMB_SIZE rhs
210     lIBERATE_BOMB_SIZE   = bombOutSize env
211 \end{code}
212
213
214 Expressions
215 ~~~~~~~~~~~
216
217 \begin{code}
218 libCase :: LibCaseEnv
219         -> CoreExpr
220         -> CoreExpr
221
222 libCase env (Var v)             = libCaseId env v
223 libCase env (Lit lit)           = Lit lit
224 libCase env (Type ty)           = Type ty
225 libCase env (App fun arg)       = App (libCase env fun) (libCase env arg)
226 libCase env (Note note body)    = Note note (libCase env body)
227
228 libCase env (Lam binder body)
229   = Lam binder (libCase (addBinders env [binder]) body)
230
231 libCase env (Let bind body)
232   = Let bind' (libCase env_body body)
233   where
234     (env_body, bind') = libCaseBind env bind
235
236 libCase env (Case scrut bndr ty alts)
237   = Case (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts)
238   where
239     env_alts = addBinders env_with_scrut [bndr]
240     env_with_scrut = case scrut of
241                         Var scrut_var -> addScrutedVar env scrut_var
242                         other         -> env
243
244 libCaseAlt env (con,args,rhs) = (con, args, libCase (addBinders env args) rhs)
245 \end{code}
246
247 Ids
248 ~~~
249 \begin{code}
250 libCaseId :: LibCaseEnv -> Id -> CoreExpr
251 libCaseId env v
252   | Just the_bind <- lookupRecId env v  -- It's a use of a recursive thing
253   , notNull free_scruts                 -- with free vars scrutinised in RHS
254   = Let the_bind (Var v)
255
256   | otherwise
257   = Var v
258
259   where
260     rec_id_level = lookupLevel env v
261     free_scruts  = freeScruts env rec_id_level
262 \end{code}
263
264
265
266 Utility functions
267 ~~~~~~~~~~~~~~~~~
268 \begin{code}
269 addBinders :: LibCaseEnv -> [CoreBndr] -> LibCaseEnv
270 addBinders (LibCaseEnv bomb lvl lvl_env rec_env scruts) binders
271   = LibCaseEnv bomb lvl lvl_env' rec_env scruts
272   where
273     lvl_env' = extendVarEnvList lvl_env (binders `zip` repeat lvl)
274
275 addRecBinds :: LibCaseEnv -> [(Id,CoreExpr)] -> LibCaseEnv
276 addRecBinds (LibCaseEnv bomb lvl lvl_env rec_env scruts) pairs
277   = LibCaseEnv bomb lvl' lvl_env' rec_env' scruts
278   where
279     lvl'     = lvl + 1
280     lvl_env' = extendVarEnvList lvl_env [(binder,lvl) | (binder,_) <- pairs]
281     rec_env' = extendVarEnvList rec_env [(binder, Rec pairs) | (binder,_) <- pairs]
282
283 addScrutedVar :: LibCaseEnv
284               -> Id             -- This Id is being scrutinised by a case expression
285               -> LibCaseEnv
286
287 addScrutedVar env@(LibCaseEnv bomb lvl lvl_env rec_env scruts) scrut_var
288   | bind_lvl < lvl
289   = LibCaseEnv bomb lvl lvl_env rec_env scruts'
290         -- Add to scruts iff the scrut_var is being scrutinised at
291         -- a deeper level than its defn
292
293   | otherwise = env
294   where
295     scruts'  = (scrut_var, lvl) : scruts
296     bind_lvl = case lookupVarEnv lvl_env scrut_var of
297                  Just lvl -> lvl
298                  Nothing  -> topLevel
299
300 lookupRecId :: LibCaseEnv -> Id -> Maybe CoreBind
301 lookupRecId (LibCaseEnv bomb lvl lvl_env rec_env scruts) id
302   = lookupVarEnv rec_env id
303
304 lookupLevel :: LibCaseEnv -> Id -> LibCaseLevel
305 lookupLevel (LibCaseEnv bomb lvl lvl_env rec_env scruts) id
306   = case lookupVarEnv lvl_env id of
307       Just lvl -> lvl
308       Nothing  -> topLevel
309
310 freeScruts :: LibCaseEnv
311            -> LibCaseLevel      -- Level of the recursive Id
312            -> [Id]              -- Ids that are scrutinised between the binding
313                                 -- of the recursive Id and here
314 freeScruts (LibCaseEnv bomb lvl lvl_env rec_env scruts) rec_bind_lvl
315   = [v | (v,scrut_lvl) <- scruts, scrut_lvl > rec_bind_lvl]
316 \end{code}