[project @ 2001-10-03 13:58:50 by simonpj]
[ghc-hetmet.git] / ghc / compiler / simplCore / LiberateCase.lhs
1 %
2 % (c) The AQUA Project, Glasgow University, 1994-1998
3 %
4 \section[LiberateCase]{Unroll recursion to allow evals to be lifted from a loop}
5
6 \begin{code}
7 module LiberateCase ( liberateCase ) where
8
9 #include "HsVersions.h"
10
11 import CmdLineOpts      ( DynFlags, DynFlag(..), opt_LiberateCaseThreshold )
12 import CoreLint         ( showPass, endPass )
13 import CoreSyn
14 import CoreUnfold       ( couldBeSmallEnoughToInline )
15 import Var              ( Id )
16 import VarEnv
17 import Outputable
18 \end{code}
19
20 This module walks over @Core@, and looks for @case@ on free variables.
21 The criterion is:
22         if there is case on a free on the route to the recursive call,
23         then the recursive call is replaced with an unfolding.
24
25 Example
26
27 \begin{verbatim}
28 f = \ t -> case v of
29                V a b -> a : f t
30 \end{verbatim}
31
32 => the inner f is replaced.
33
34 \begin{verbatim}
35 f = \ t -> case v of
36                V a b -> a : (letrec
37                                 f =  \ t -> case v of
38                                                V a b -> a : f t
39                              in f) t
40 \end{verbatim}
41 (note the NEED for shadowing)
42
43 => Simplify
44
45 \begin{verbatim}
46 f = \ t -> case v of
47                V a b -> a : (letrec
48                                 f = \ t -> a : f t
49                              in f t)
50 \begin{verbatim}
51
52 Better code, because 'a' is  free inside the inner letrec, rather
53 than needing projection from v.
54
55 Other examples we'd like to catch with this kind of transformation
56
57         last []     = error 
58         last (x:[]) = x
59         last (x:xs) = last xs
60
61 We'd like to avoid the redundant pattern match, transforming to
62
63         last [] = error
64         last (x:[]) = x
65         last (x:(y:ys)) = last' y ys
66                 where
67                   last' y []     = y
68                   last' _ (y:ys) = last' y ys
69
70         (is this necessarily an improvement)
71
72
73 Similarly drop:
74
75         drop n [] = []
76         drop 0 xs = xs
77         drop n (x:xs) = drop (n-1) xs
78
79 Would like to pass n along unboxed.
80         
81
82 To think about (Apr 94)
83 ~~~~~~~~~~~~~~
84
85 Main worry: duplicating code excessively.  At the moment we duplicate
86 the entire binding group once at each recursive call.  But there may
87 be a group of recursive calls which share a common set of evaluated
88 free variables, in which case the duplication is a plain waste.
89
90 Another thing we could consider adding is some unfold-threshold thing,
91 so that we'll only duplicate if the size of the group rhss isn't too
92 big.
93
94 Data types
95 ~~~~~~~~~~
96
97 The ``level'' of a binder tells how many
98 recursive defns lexically enclose the binding
99 A recursive defn "encloses" its RHS, not its
100 scope.  For example:
101 \begin{verbatim}
102         letrec f = let g = ... in ...
103         in
104         let h = ...
105         in ...
106 \end{verbatim}
107 Here, the level of @f@ is zero, the level of @g@ is one,
108 and the level of @h@ is zero (NB not one).
109
110 \begin{code}
111 type LibCaseLevel = Int
112
113 topLevel :: LibCaseLevel
114 topLevel = 0
115 \end{code}
116
117 \begin{code}
118 data LibCaseEnv
119   = LibCaseEnv
120         Int                     -- Bomb-out size for deciding if
121                                 -- potential liberatees are too big.
122                                 -- (passed in from cmd-line args)
123
124         LibCaseLevel            -- Current level
125
126         (IdEnv LibCaseLevel)    -- Binds all non-top-level in-scope Ids
127                                 -- (top-level and imported things have
128                                 -- a level of zero)
129
130         (IdEnv CoreBind)        -- Binds *only* recursively defined
131                                 -- Ids, to their own binding group,
132                                 -- and *only* in their own RHSs
133
134         [(Id,LibCaseLevel)]     -- Each of these Ids was scrutinised by an
135                                 -- enclosing case expression, with the
136                                 -- specified number of enclosing
137                                 -- recursive bindings; furthermore,
138                                 -- the Id is bound at a lower level
139                                 -- than the case expression.  The
140                                 -- order is insignificant; it's a bag
141                                 -- really
142
143 initEnv :: Int -> LibCaseEnv
144 initEnv bomb_size = LibCaseEnv bomb_size 0 emptyVarEnv emptyVarEnv []
145
146 bombOutSize (LibCaseEnv bomb_size _ _ _ _) = bomb_size
147 \end{code}
148
149
150 Programs
151 ~~~~~~~~
152 \begin{code}
153 liberateCase :: DynFlags -> [CoreBind] -> IO [CoreBind]
154 liberateCase dflags binds
155   = do {
156         showPass dflags "Liberate case" ;
157         let { binds' = do_prog (initEnv opt_LiberateCaseThreshold) binds } ;
158         endPass dflags "Liberate case" Opt_D_verbose_core2core binds'
159                                 {- no specific flag for dumping -} 
160     }
161   where
162     do_prog env [] = []
163     do_prog env (bind:binds) = bind' : do_prog env' binds
164                              where
165                                (env', bind') = libCaseBind env bind
166 \end{code}
167
168 Bindings
169 ~~~~~~~~
170
171 \begin{code}
172 libCaseBind :: LibCaseEnv -> CoreBind -> (LibCaseEnv, CoreBind)
173
174 libCaseBind env (NonRec binder rhs)
175   = (addBinders env [binder], NonRec binder (libCase env rhs))
176
177 libCaseBind env (Rec pairs)
178   = (env_body, Rec pairs')
179   where
180     (binders, rhss) = unzip pairs
181
182     env_body = addBinders env binders
183
184     pairs' = [(binder, libCase env_rhs rhs) | (binder,rhs) <- pairs]
185
186     env_rhs = if all rhs_small_enough rhss then extended_env else env
187
188         -- We extend the rec-env by binding each Id to its rhs, first
189         -- processing the rhs with an *un-extended* environment, so
190         -- that the same process doesn't occur for ever!
191
192     extended_env = addRecBinds env [ (binder, libCase env_body rhs)
193                                    | (binder, rhs) <- pairs ]
194
195     rhs_small_enough rhs = couldBeSmallEnoughToInline lIBERATE_BOMB_SIZE rhs
196     lIBERATE_BOMB_SIZE   = bombOutSize env
197 \end{code}
198
199
200 Expressions
201 ~~~~~~~~~~~
202
203 \begin{code}
204 libCase :: LibCaseEnv
205         -> CoreExpr
206         -> CoreExpr
207
208 libCase env (Var v)             = libCaseId env v
209 libCase env (Lit lit)           = Lit lit
210 libCase env (Type ty)           = Type ty
211 libCase env (App fun arg)       = App (libCase env fun) (libCase env arg)
212 libCase env (Note note body)    = Note note (libCase env body)
213
214 libCase env (Lam binder body)
215   = Lam binder (libCase (addBinders env [binder]) body)
216
217 libCase env (Let bind body)
218   = Let bind' (libCase env_body body)
219   where
220     (env_body, bind') = libCaseBind env bind
221
222 libCase env (Case scrut bndr alts)
223   = Case (libCase env scrut) bndr (map (libCaseAlt env_alts) alts)
224   where
225     env_alts = addBinders env_with_scrut [bndr]
226     env_with_scrut = case scrut of
227                         Var scrut_var -> addScrutedVar env scrut_var
228                         other         -> env
229
230 libCaseAlt env (con,args,rhs) = (con, args, libCase (addBinders env args) rhs)
231 \end{code}
232
233 Ids
234 ~~~
235 \begin{code}
236 libCaseId :: LibCaseEnv -> Id -> CoreExpr
237 libCaseId env v
238   | Just the_bind <- lookupRecId env v  -- It's a use of a recursive thing
239   , not (null free_scruts)              -- with free vars scrutinised in RHS
240   = Let the_bind (Var v)
241
242   | otherwise
243   = Var v
244
245   where
246     rec_id_level = lookupLevel env v
247     free_scruts  = freeScruts env rec_id_level
248 \end{code}
249
250
251
252 Utility functions
253 ~~~~~~~~~~~~~~~~~
254 \begin{code}
255 addBinders :: LibCaseEnv -> [CoreBndr] -> LibCaseEnv
256 addBinders (LibCaseEnv bomb lvl lvl_env rec_env scruts) binders
257   = LibCaseEnv bomb lvl lvl_env' rec_env scruts
258   where
259     lvl_env' = extendVarEnvList lvl_env (binders `zip` repeat lvl)
260
261 addRecBinds :: LibCaseEnv -> [(Id,CoreExpr)] -> LibCaseEnv
262 addRecBinds (LibCaseEnv bomb lvl lvl_env rec_env scruts) pairs
263   = LibCaseEnv bomb lvl' lvl_env' rec_env' scruts
264   where
265     lvl'     = lvl + 1
266     lvl_env' = extendVarEnvList lvl_env [(binder,lvl) | (binder,_) <- pairs]
267     rec_env' = extendVarEnvList rec_env [(binder, Rec pairs) | (binder,_) <- pairs]
268
269 addScrutedVar :: LibCaseEnv
270               -> Id             -- This Id is being scrutinised by a case expression
271               -> LibCaseEnv
272
273 addScrutedVar env@(LibCaseEnv bomb lvl lvl_env rec_env scruts) scrut_var
274   | bind_lvl < lvl
275   = LibCaseEnv bomb lvl lvl_env rec_env scruts'
276         -- Add to scruts iff the scrut_var is being scrutinised at
277         -- a deeper level than its defn
278
279   | otherwise = env
280   where
281     scruts'  = (scrut_var, lvl) : scruts
282     bind_lvl = case lookupVarEnv lvl_env scrut_var of
283                  Just lvl -> lvl
284                  Nothing  -> topLevel
285
286 lookupRecId :: LibCaseEnv -> Id -> Maybe CoreBind
287 lookupRecId (LibCaseEnv bomb lvl lvl_env rec_env scruts) id
288   = lookupVarEnv rec_env id
289
290 lookupLevel :: LibCaseEnv -> Id -> LibCaseLevel
291 lookupLevel (LibCaseEnv bomb lvl lvl_env rec_env scruts) id
292   = case lookupVarEnv lvl_env id of
293       Just lvl -> lvl
294       Nothing  -> topLevel
295
296 freeScruts :: LibCaseEnv
297            -> LibCaseLevel      -- Level of the recursive Id
298            -> [Id]              -- Ids that are scrutinised between the binding
299                                 -- of the recursive Id and here
300 freeScruts (LibCaseEnv bomb lvl lvl_env rec_env scruts) rec_bind_lvl
301   = [v | (v,scrut_lvl) <- scruts, scrut_lvl > rec_bind_lvl]
302 \end{code}