Add (a) CoreM monad, (b) new Annotations feature
[ghc-hetmet.git] / compiler / simplCore / SAT.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
3 %
4
5 %************************************************************************
6
7                Static Argument Transformation pass
8
9 %************************************************************************
10
11 May be seen as removing invariants from loops:
12 Arguments of recursive functions that do not change in recursive
13 calls are removed from the recursion, which is done locally
14 and only passes the arguments which effectively change.
15
16 Example:
17 map = /\ ab -> \f -> \xs -> case xs of
18                  []       -> []
19                  (a:b) -> f a : map f b
20
21 as map is recursively called with the same argument f (unmodified)
22 we transform it to
23
24 map = /\ ab -> \f -> \xs -> let map' ys = case ys of
25                        []     -> []
26                        (a:b) -> f a : map' b
27                 in map' xs
28
29 Notice that for a compiler that uses lambda lifting this is
30 useless as map' will be transformed back to what map was.
31
32 We could possibly do the same for big lambdas, but we don't as
33 they will eventually be removed in later stages of the compiler,
34 therefore there is no penalty in keeping them.
35
36 We only apply the SAT when the number of static args is > 2. This
37 produces few bad cases.  See
38                 should_transform
39 in saTransform.
40
41 Here are the headline nofib results:
42                   Size    Allocs   Runtime
43 Min             +0.0%    -13.7%    -21.4%
44 Max             +0.1%     +0.0%     +5.4%
45 Geometric Mean  +0.0%     -0.2%     -6.9%
46
47 The previous patch, to fix polymorphic floatout demand signatures, is
48 essential to make this work well!
49
50
51 \begin{code}
52
53 module SAT ( doStaticArgs ) where
54
55 import Var
56 import CoreSyn
57 import CoreUtils
58 import Type
59 import TcType
60 import Id
61 import Name
62 import OccName
63 import VarEnv
64 import UniqSupply
65 import Util
66 import UniqFM
67 import VarSet
68 import Unique
69 import UniqSet
70 import Outputable
71
72 import Data.List
73 import FastString
74
75 #include "HsVersions.h"
76 \end{code}
77
78 \begin{code}
79 doStaticArgs :: UniqSupply -> [CoreBind] -> [CoreBind]
80 doStaticArgs us binds = snd $ mapAccumL sat_bind_threaded_us us binds
81   where
82     sat_bind_threaded_us us bind =
83         let (us1, us2) = splitUniqSupply us
84         in (us1, fst $ runSAT us2 (satBind bind emptyUniqSet))
85 \end{code}
86 \begin{code}
87 -- We don't bother to SAT recursive groups since it can lead
88 -- to massive code expansion: see Andre Santos' thesis for details.
89 -- This means we only apply the actual SAT to Rec groups of one element,
90 -- but we want to recurse into the others anyway to discover other binds
91 satBind :: CoreBind -> IdSet -> SatM (CoreBind, IdSATInfo)
92 satBind (NonRec binder expr) interesting_ids = do
93     (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
94     return (NonRec binder expr', finalizeApp expr_app sat_info_expr)
95 satBind (Rec [(binder, rhs)]) interesting_ids = do
96     let interesting_ids' = interesting_ids `addOneToUniqSet` binder
97         (rhs_binders, rhs_body) = collectBinders rhs
98     (rhs_body', sat_info_rhs_body) <- satTopLevelExpr rhs_body interesting_ids'
99     let sat_info_rhs_from_args = unitVarEnv binder (bindersToSATInfo rhs_binders)
100         sat_info_rhs' = mergeIdSATInfo sat_info_rhs_from_args sat_info_rhs_body
101         
102         shadowing = binder `elementOfUniqSet` interesting_ids
103         sat_info_rhs'' = if shadowing
104                         then sat_info_rhs' `delFromUFM` binder -- For safety
105                         else sat_info_rhs'
106     
107     bind' <- saTransformMaybe binder (lookupUFM sat_info_rhs' binder) 
108                               rhs_binders rhs_body'
109     return (bind', sat_info_rhs'')
110 satBind (Rec pairs) interesting_ids = do
111     let (binders, rhss) = unzip pairs
112     rhss_SATed <- mapM (\e -> satTopLevelExpr e interesting_ids) rhss
113     let (rhss', sat_info_rhss') = unzip rhss_SATed
114     return (Rec (zipEqual "satBind" binders rhss'), mergeIdSATInfos sat_info_rhss')
115 \end{code}
116 \begin{code}
117 data App = VarApp Id | TypeApp Type
118 data Staticness a = Static a | NotStatic
119
120 type IdAppInfo = (Id, SATInfo)
121
122 type SATInfo = [Staticness App]
123 type IdSATInfo = IdEnv SATInfo
124 emptyIdSATInfo :: IdSATInfo
125 emptyIdSATInfo = emptyUFM
126
127 {-
128 pprIdSATInfo id_sat_info = vcat (map pprIdAndSATInfo (fmToList id_sat_info))
129   where pprIdAndSATInfo (v, sat_info) = hang (ppr v <> colon) 4 (pprSATInfo sat_info)
130 -}
131
132 pprSATInfo :: SATInfo -> SDoc
133 pprSATInfo staticness = hcat $ map pprStaticness staticness
134
135 pprStaticness :: Staticness App -> SDoc
136 pprStaticness (Static (VarApp _))  = ptext (sLit "SV") 
137 pprStaticness (Static (TypeApp _)) = ptext (sLit "ST") 
138 pprStaticness NotStatic            = ptext (sLit "NS")
139
140
141 mergeSATInfo :: SATInfo -> SATInfo -> SATInfo
142 mergeSATInfo [] _  = []
143 mergeSATInfo _  [] = []
144 mergeSATInfo (NotStatic:statics) (_:apps) = NotStatic : mergeSATInfo statics apps
145 mergeSATInfo (_:statics) (NotStatic:apps) = NotStatic : mergeSATInfo statics apps
146 mergeSATInfo ((Static (VarApp v)):statics)  ((Static (VarApp v')):apps)  = (if v == v' then Static (VarApp v) else NotStatic) : mergeSATInfo statics apps
147 mergeSATInfo ((Static (TypeApp t)):statics) ((Static (TypeApp t')):apps) = (if t `coreEqType` t' then Static (TypeApp t) else NotStatic) : mergeSATInfo statics apps
148 mergeSATInfo l  r  = pprPanic "mergeSATInfo" $ ptext (sLit "Left:") <> pprSATInfo l <> ptext (sLit ", ")
149                                             <> ptext (sLit "Right:") <> pprSATInfo r
150
151 mergeIdSATInfo :: IdSATInfo -> IdSATInfo -> IdSATInfo
152 mergeIdSATInfo = plusUFM_C mergeSATInfo
153
154 mergeIdSATInfos :: [IdSATInfo] -> IdSATInfo
155 mergeIdSATInfos = foldl' mergeIdSATInfo emptyIdSATInfo
156
157 bindersToSATInfo :: [Id] -> SATInfo
158 bindersToSATInfo vs = map (Static . binderToApp) vs
159     where binderToApp v = if isId v
160                           then VarApp v
161                           else TypeApp $ mkTyVarTy v
162
163 finalizeApp :: Maybe IdAppInfo -> IdSATInfo -> IdSATInfo
164 finalizeApp Nothing id_sat_info = id_sat_info
165 finalizeApp (Just (v, sat_info')) id_sat_info = 
166     let sat_info'' = case lookupUFM id_sat_info v of
167                         Nothing -> sat_info'
168                         Just sat_info -> mergeSATInfo sat_info sat_info'
169     in extendVarEnv id_sat_info v sat_info''
170 \end{code}
171 \begin{code}
172 satTopLevelExpr :: CoreExpr -> IdSet -> SatM (CoreExpr, IdSATInfo)
173 satTopLevelExpr expr interesting_ids = do
174     (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
175     return (expr', finalizeApp expr_app sat_info_expr)
176
177 satExpr :: CoreExpr -> IdSet -> SatM (CoreExpr, IdSATInfo, Maybe IdAppInfo)
178 satExpr var@(Var v) interesting_ids = do
179     let app_info = if v `elementOfUniqSet` interesting_ids
180                    then Just (v, [])
181                    else Nothing
182     return (var, emptyIdSATInfo, app_info)
183
184 satExpr lit@(Lit _) _ = do
185     return (lit, emptyIdSATInfo, Nothing)
186
187 satExpr (Lam binders body) interesting_ids = do
188     (body', sat_info, this_app) <- satExpr body interesting_ids
189     return (Lam binders body', finalizeApp this_app sat_info, Nothing)
190
191 satExpr (App fn arg) interesting_ids = do
192     (fn', sat_info_fn, fn_app) <- satExpr fn interesting_ids
193     let satRemainder = boring fn' sat_info_fn
194     case fn_app of
195         Nothing -> satRemainder Nothing
196         Just (fn_id, fn_app_info) ->
197             -- TODO: remove this use of append somehow (use a data structure with O(1) append but a left-to-right kind of interface)
198             let satRemainderWithStaticness arg_staticness = satRemainder $ Just (fn_id, fn_app_info ++ [arg_staticness])
199             in case arg of
200                 Type t -> satRemainderWithStaticness $ Static (TypeApp t)
201                 Var v  -> satRemainderWithStaticness $ Static (VarApp v)
202                 _      -> satRemainderWithStaticness $ NotStatic
203   where
204     boring :: CoreExpr -> IdSATInfo -> Maybe IdAppInfo -> SatM (CoreExpr, IdSATInfo, Maybe IdAppInfo)
205     boring fn' sat_info_fn app_info = 
206         do (arg', sat_info_arg, arg_app) <- satExpr arg interesting_ids
207            let sat_info_arg' = finalizeApp arg_app sat_info_arg
208                sat_info = mergeIdSATInfo sat_info_fn sat_info_arg'
209            return (App fn' arg', sat_info, app_info)
210
211 satExpr (Case expr bndr ty alts) interesting_ids = do
212     (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
213     let sat_info_expr' = finalizeApp expr_app sat_info_expr
214     
215     zipped_alts' <- mapM satAlt alts
216     let (alts', sat_infos_alts) = unzip zipped_alts'
217     return (Case expr' bndr ty alts', mergeIdSATInfo sat_info_expr' (mergeIdSATInfos sat_infos_alts), Nothing)
218   where
219     satAlt (con, bndrs, expr) = do
220         (expr', sat_info_expr) <- satTopLevelExpr expr interesting_ids
221         return ((con, bndrs, expr'), sat_info_expr)
222
223 satExpr (Let bind body) interesting_ids = do
224     (body', sat_info_body, body_app) <- satExpr body interesting_ids
225     (bind', sat_info_bind) <- satBind bind interesting_ids
226     return (Let bind' body', mergeIdSATInfo sat_info_body sat_info_bind, body_app)
227
228 satExpr (Note note expr) interesting_ids = do
229     (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
230     return (Note note expr', sat_info_expr, expr_app)
231
232 satExpr ty@(Type _) _ = do
233     return (ty, emptyIdSATInfo, Nothing)
234
235 satExpr (Cast expr coercion) interesting_ids = do
236     (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
237     return (Cast expr' coercion, sat_info_expr, expr_app)
238 \end{code}
239
240 %************************************************************************
241
242                 Static Argument Transformation Monad
243
244 %************************************************************************
245
246 \begin{code}
247 type SatM result = UniqSM result
248
249 runSAT :: UniqSupply -> SatM a -> a
250 runSAT = initUs_
251
252 newUnique :: SatM Unique
253 newUnique = getUniqueUs
254 \end{code}
255
256
257 %************************************************************************
258
259                 Static Argument Transformation Monad
260
261 %************************************************************************
262
263 To do the transformation, the game plan is to:
264
265 1. Create a small nonrecursive RHS that takes the
266    original arguments to the function but discards
267    the ones that are static and makes a call to the
268    SATed version with the remainder. We intend that
269    this will be inlined later, removing the overhead
270
271 2. Bind this nonrecursive RHS over the original body
272    WITH THE SAME UNIQUE as the original body so that
273    any recursive calls to the original now go via
274    the small wrapper
275
276 3. Rebind the original function to a new one which contains
277    our SATed function and just makes a call to it:
278    we call the thing making this call the local body
279
280 Example: transform this
281
282     map :: forall a b. (a->b) -> [a] -> [b]
283     map = /\ab. \(f:a->b) (as:[a]) -> body[map]
284 to
285     map :: forall a b. (a->b) -> [a] -> [b]
286     map = /\ab. \(f:a->b) (as:[a]) ->
287          letrec map' :: [a] -> [b]
288                     -- The "worker function
289                 map' = \(as:[a]) -> 
290                          let map :: forall a' b'. (a -> b) -> [a] -> [b]
291                                 -- The "shadow function
292                              map = /\a'b'. \(f':(a->b) (as:[a]).
293                                    map' as
294                          in body[map]
295          in map' as
296
297 Note [Shadow binding]
298 ~~~~~~~~~~~~~~~~~~~~~
299 The calls to the inner map inside body[map] should get inlined
300 by the local re-binding of 'map'.  We call this the "shadow binding".
301
302 But we can't use the original binder 'map' unchanged, because
303 it might be exported, in which case the shadow binding won't be
304 discarded as dead code after it is inlined.
305
306 So we use a hack: we make a new SysLocal binder with the *same* unique
307 as binder.  (Another alternative would be to reset the export flag.)
308
309 Note [Binder type capture]
310 ~~~~~~~~~~~~~~~~~~~~~~~~~~
311 Notice that in the inner map (the "shadow function"), the static arguments
312 are discarded -- it's as if they were underscores.  Instead, mentions
313 of these arguments (notably in the types of dynamic arguments) are bound
314 by the *outer* lambdas of the main function.  So we must make up fresh
315 names for the static arguments so that they do not capture variables 
316 mentioned in the types of dynamic args.  
317
318 In the map example, the shadow function must clone the static type
319 argument a,b, giving a',b', to ensure that in the \(as:[a]), the 'a'
320 is bound by the outer forall.  We clone f' too for consistency, but
321 that doesn't matter either way because static Id arguments aren't 
322 mentioned in the shadow binding at all.
323
324 If we don't we get something like this:
325
326 [Exported]
327 [Arity 3]
328 GHC.Base.until =
329   \ (@ a_aiK)
330     (p_a6T :: a_aiK -> GHC.Bool.Bool)
331     (f_a6V :: a_aiK -> a_aiK)
332     (x_a6X :: a_aiK) ->
333     letrec {
334       sat_worker_s1aU :: a_aiK -> a_aiK
335       []
336       sat_worker_s1aU =
337         \ (x_a6X :: a_aiK) ->
338           let {
339             sat_shadow_r17 :: forall a_a3O.
340                               (a_a3O -> GHC.Bool.Bool) -> (a_a3O -> a_a3O) -> a_a3O -> a_a3O
341             []
342             sat_shadow_r17 =
343               \ (@ a_aiK)
344                 (p_a6T :: a_aiK -> GHC.Bool.Bool)
345                 (f_a6V :: a_aiK -> a_aiK)
346                 (x_a6X :: a_aiK) ->
347                 sat_worker_s1aU x_a6X } in
348           case p_a6T x_a6X of wild_X3y [ALWAYS Dead Nothing] {
349             GHC.Bool.False -> GHC.Base.until @ a_aiK p_a6T f_a6V (f_a6V x_a6X);
350             GHC.Bool.True -> x_a6X
351           }; } in
352     sat_worker_s1aU x_a6X
353     
354 Where sat_shadow has captured the type variables of x_a6X etc as it has a a_aiK 
355 type argument. This is bad because it means the application sat_worker_s1aU x_a6X
356 is not well typed.
357
358 \begin{code}
359 saTransformMaybe :: Id -> Maybe SATInfo -> [Id] -> CoreExpr -> SatM CoreBind
360 saTransformMaybe binder maybe_arg_staticness rhs_binders rhs_body
361   | Just arg_staticness <- maybe_arg_staticness
362   , should_transform arg_staticness
363   = saTransform binder arg_staticness rhs_binders rhs_body
364   | otherwise
365   = return (Rec [(binder, mkLams rhs_binders rhs_body)])
366   where 
367     should_transform staticness = n_static_args > 1 -- THIS IS THE DECISION POINT
368       where
369         n_static_args = length (filter isStaticValue staticness)
370
371 saTransform :: Id -> SATInfo -> [Id] -> CoreExpr -> SatM CoreBind
372 saTransform binder arg_staticness rhs_binders rhs_body
373   = do  { shadow_lam_bndrs <- mapM clone binders_w_staticness
374         ; uniq             <- newUnique
375         ; return (NonRec binder (mk_new_rhs uniq shadow_lam_bndrs)) }
376   where
377     -- Running example: foldr
378     -- foldr \alpha \beta c n xs = e, for some e
379     -- arg_staticness = [Static TypeApp, Static TypeApp, Static VarApp, Static VarApp, NonStatic]
380     -- rhs_binders = [\alpha, \beta, c, n, xs]
381     -- rhs_body = e
382     
383     binders_w_staticness = rhs_binders `zip` (arg_staticness ++ repeat NotStatic)
384                                         -- Any extra args are assumed NotStatic
385
386     non_static_args :: [Var]
387             -- non_static_args = [xs]
388             -- rhs_binders_without_type_capture = [\alpha', \beta', c, n, xs]
389     non_static_args = [v | (v, NotStatic) <- binders_w_staticness]
390
391     clone (bndr, NotStatic) = return bndr
392     clone (bndr, _        ) = do { uniq <- newUnique
393                                  ; return (setVarUnique bndr uniq) }
394
395     -- new_rhs = \alpha beta c n xs -> 
396     --           let sat_worker = \xs -> let sat_shadow = \alpha' beta' c n xs -> 
397     --                                       sat_worker xs 
398     --                                   in e
399     --           in sat_worker xs
400     mk_new_rhs uniq shadow_lam_bndrs 
401         = mkLams rhs_binders $ 
402           Let (Rec [(rec_body_bndr, rec_body)]) 
403           local_body
404         where
405           local_body = mkVarApps (Var rec_body_bndr) non_static_args
406
407           rec_body = mkLams non_static_args $
408                      Let (NonRec shadow_bndr shadow_rhs) rhs_body
409
410             -- See Note [Binder type capture]
411           shadow_rhs = mkLams shadow_lam_bndrs local_body
412             -- nonrec_rhs = \alpha' beta' c n xs -> sat_worker xs
413
414           rec_body_bndr = mkSysLocal (fsLit "sat_worker") uniq (exprType rec_body)
415             -- rec_body_bndr = sat_worker
416     
417             -- See Note [Shadow binding]; make a SysLocal
418           shadow_bndr = mkSysLocal (occNameFS (getOccName binder)) 
419                                    (idUnique binder)
420                                    (exprType shadow_rhs)
421
422 isStaticValue :: Staticness App -> Bool
423 isStaticValue (Static (VarApp _)) = True
424 isStaticValue _                   = False
425
426 \end{code}