White space only
[ghc-hetmet.git] / compiler / simplCore / SAT.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
3 %
4
5 %************************************************************************
6
7                Static Argument Transformation pass
8
9 %************************************************************************
10
11 May be seen as removing invariants from loops:
12 Arguments of recursive functions that do not change in recursive
13 calls are removed from the recursion, which is done locally
14 and only passes the arguments which effectively change.
15
16 Example:
17 map = /\ ab -> \f -> \xs -> case xs of
18                  []       -> []
19                  (a:b) -> f a : map f b
20
21 as map is recursively called with the same argument f (unmodified)
22 we transform it to
23
24 map = /\ ab -> \f -> \xs -> let map' ys = case ys of
25                        []     -> []
26                        (a:b) -> f a : map' b
27                 in map' xs
28
29 Notice that for a compiler that uses lambda lifting this is
30 useless as map' will be transformed back to what map was.
31
32 We could possibly do the same for big lambdas, but we don't as
33 they will eventually be removed in later stages of the compiler,
34 therefore there is no penalty in keeping them.
35
36 We only apply the SAT when the number of static args is > 2. This
37 produces few bad cases.  See
38                 should_transform
39 in saTransform.
40
41 Here are the headline nofib results:
42                   Size    Allocs   Runtime
43 Min             +0.0%    -13.7%    -21.4%
44 Max             +0.1%     +0.0%     +5.4%
45 Geometric Mean  +0.0%     -0.2%     -6.9%
46
47 The previous patch, to fix polymorphic floatout demand signatures, is
48 essential to make this work well!
49
50
51 \begin{code}
52
53 module SAT ( doStaticArgs ) where
54
55 import Var
56 import CoreSyn
57 import CoreUtils
58 import Type
59 import Id
60 import Name
61 import VarEnv
62 import UniqSupply
63 import Util
64 import UniqFM
65 import VarSet
66 import Unique
67 import UniqSet
68 import Outputable
69
70 import Data.List
71 import FastString
72
73 #include "HsVersions.h"
74 \end{code}
75
76 \begin{code}
77 doStaticArgs :: UniqSupply -> [CoreBind] -> [CoreBind]
78 doStaticArgs us binds = snd $ mapAccumL sat_bind_threaded_us us binds
79   where
80     sat_bind_threaded_us us bind =
81         let (us1, us2) = splitUniqSupply us
82         in (us1, fst $ runSAT us2 (satBind bind emptyUniqSet))
83 \end{code}
84 \begin{code}
85 -- We don't bother to SAT recursive groups since it can lead
86 -- to massive code expansion: see Andre Santos' thesis for details.
87 -- This means we only apply the actual SAT to Rec groups of one element,
88 -- but we want to recurse into the others anyway to discover other binds
89 satBind :: CoreBind -> IdSet -> SatM (CoreBind, IdSATInfo)
90 satBind (NonRec binder expr) interesting_ids = do
91     (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
92     return (NonRec binder expr', finalizeApp expr_app sat_info_expr)
93 satBind (Rec [(binder, rhs)]) interesting_ids = do
94     let interesting_ids' = interesting_ids `addOneToUniqSet` binder
95         (rhs_binders, rhs_body) = collectBinders rhs
96     (rhs_body', sat_info_rhs_body) <- satTopLevelExpr rhs_body interesting_ids'
97     let sat_info_rhs_from_args = unitVarEnv binder (bindersToSATInfo rhs_binders)
98         sat_info_rhs' = mergeIdSATInfo sat_info_rhs_from_args sat_info_rhs_body
99         
100         shadowing = binder `elementOfUniqSet` interesting_ids
101         sat_info_rhs'' = if shadowing
102                         then sat_info_rhs' `delFromUFM` binder -- For safety
103                         else sat_info_rhs'
104     
105     bind' <- saTransformMaybe binder (lookupUFM sat_info_rhs' binder) 
106                               rhs_binders rhs_body'
107     return (bind', sat_info_rhs'')
108 satBind (Rec pairs) interesting_ids = do
109     let (binders, rhss) = unzip pairs
110     rhss_SATed <- mapM (\e -> satTopLevelExpr e interesting_ids) rhss
111     let (rhss', sat_info_rhss') = unzip rhss_SATed
112     return (Rec (zipEqual "satBind" binders rhss'), mergeIdSATInfos sat_info_rhss')
113 \end{code}
114 \begin{code}
115 data App = VarApp Id | TypeApp Type
116 data Staticness a = Static a | NotStatic
117
118 type IdAppInfo = (Id, SATInfo)
119
120 type SATInfo = [Staticness App]
121 type IdSATInfo = IdEnv SATInfo
122 emptyIdSATInfo :: IdSATInfo
123 emptyIdSATInfo = emptyUFM
124
125 {-
126 pprIdSATInfo id_sat_info = vcat (map pprIdAndSATInfo (Map.toList id_sat_info))
127   where pprIdAndSATInfo (v, sat_info) = hang (ppr v <> colon) 4 (pprSATInfo sat_info)
128 -}
129
130 pprSATInfo :: SATInfo -> SDoc
131 pprSATInfo staticness = hcat $ map pprStaticness staticness
132
133 pprStaticness :: Staticness App -> SDoc
134 pprStaticness (Static (VarApp _))  = ptext (sLit "SV") 
135 pprStaticness (Static (TypeApp _)) = ptext (sLit "ST") 
136 pprStaticness NotStatic            = ptext (sLit "NS")
137
138
139 mergeSATInfo :: SATInfo -> SATInfo -> SATInfo
140 mergeSATInfo [] _  = []
141 mergeSATInfo _  [] = []
142 mergeSATInfo (NotStatic:statics) (_:apps) = NotStatic : mergeSATInfo statics apps
143 mergeSATInfo (_:statics) (NotStatic:apps) = NotStatic : mergeSATInfo statics apps
144 mergeSATInfo ((Static (VarApp v)):statics)  ((Static (VarApp v')):apps)  = (if v == v' then Static (VarApp v) else NotStatic) : mergeSATInfo statics apps
145 mergeSATInfo ((Static (TypeApp t)):statics) ((Static (TypeApp t')):apps) = (if t `coreEqType` t' then Static (TypeApp t) else NotStatic) : mergeSATInfo statics apps
146 mergeSATInfo l  r  = pprPanic "mergeSATInfo" $ ptext (sLit "Left:") <> pprSATInfo l <> ptext (sLit ", ")
147                                             <> ptext (sLit "Right:") <> pprSATInfo r
148
149 mergeIdSATInfo :: IdSATInfo -> IdSATInfo -> IdSATInfo
150 mergeIdSATInfo = plusUFM_C mergeSATInfo
151
152 mergeIdSATInfos :: [IdSATInfo] -> IdSATInfo
153 mergeIdSATInfos = foldl' mergeIdSATInfo emptyIdSATInfo
154
155 bindersToSATInfo :: [Id] -> SATInfo
156 bindersToSATInfo vs = map (Static . binderToApp) vs
157     where binderToApp v = if isId v
158                           then VarApp v
159                           else TypeApp $ mkTyVarTy v
160
161 finalizeApp :: Maybe IdAppInfo -> IdSATInfo -> IdSATInfo
162 finalizeApp Nothing id_sat_info = id_sat_info
163 finalizeApp (Just (v, sat_info')) id_sat_info = 
164     let sat_info'' = case lookupUFM id_sat_info v of
165                         Nothing -> sat_info'
166                         Just sat_info -> mergeSATInfo sat_info sat_info'
167     in extendVarEnv id_sat_info v sat_info''
168 \end{code}
169 \begin{code}
170 satTopLevelExpr :: CoreExpr -> IdSet -> SatM (CoreExpr, IdSATInfo)
171 satTopLevelExpr expr interesting_ids = do
172     (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
173     return (expr', finalizeApp expr_app sat_info_expr)
174
175 satExpr :: CoreExpr -> IdSet -> SatM (CoreExpr, IdSATInfo, Maybe IdAppInfo)
176 satExpr var@(Var v) interesting_ids = do
177     let app_info = if v `elementOfUniqSet` interesting_ids
178                    then Just (v, [])
179                    else Nothing
180     return (var, emptyIdSATInfo, app_info)
181
182 satExpr lit@(Lit _) _ = do
183     return (lit, emptyIdSATInfo, Nothing)
184
185 satExpr (Lam binders body) interesting_ids = do
186     (body', sat_info, this_app) <- satExpr body interesting_ids
187     return (Lam binders body', finalizeApp this_app sat_info, Nothing)
188
189 satExpr (App fn arg) interesting_ids = do
190     (fn', sat_info_fn, fn_app) <- satExpr fn interesting_ids
191     let satRemainder = boring fn' sat_info_fn
192     case fn_app of
193         Nothing -> satRemainder Nothing
194         Just (fn_id, fn_app_info) ->
195             -- TODO: remove this use of append somehow (use a data structure with O(1) append but a left-to-right kind of interface)
196             let satRemainderWithStaticness arg_staticness = satRemainder $ Just (fn_id, fn_app_info ++ [arg_staticness])
197             in case arg of
198                 Type t -> satRemainderWithStaticness $ Static (TypeApp t)
199                 Var v  -> satRemainderWithStaticness $ Static (VarApp v)
200                 _      -> satRemainderWithStaticness $ NotStatic
201   where
202     boring :: CoreExpr -> IdSATInfo -> Maybe IdAppInfo -> SatM (CoreExpr, IdSATInfo, Maybe IdAppInfo)
203     boring fn' sat_info_fn app_info = 
204         do (arg', sat_info_arg, arg_app) <- satExpr arg interesting_ids
205            let sat_info_arg' = finalizeApp arg_app sat_info_arg
206                sat_info = mergeIdSATInfo sat_info_fn sat_info_arg'
207            return (App fn' arg', sat_info, app_info)
208
209 satExpr (Case expr bndr ty alts) interesting_ids = do
210     (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
211     let sat_info_expr' = finalizeApp expr_app sat_info_expr
212     
213     zipped_alts' <- mapM satAlt alts
214     let (alts', sat_infos_alts) = unzip zipped_alts'
215     return (Case expr' bndr ty alts', mergeIdSATInfo sat_info_expr' (mergeIdSATInfos sat_infos_alts), Nothing)
216   where
217     satAlt (con, bndrs, expr) = do
218         (expr', sat_info_expr) <- satTopLevelExpr expr interesting_ids
219         return ((con, bndrs, expr'), sat_info_expr)
220
221 satExpr (Let bind body) interesting_ids = do
222     (body', sat_info_body, body_app) <- satExpr body interesting_ids
223     (bind', sat_info_bind) <- satBind bind interesting_ids
224     return (Let bind' body', mergeIdSATInfo sat_info_body sat_info_bind, body_app)
225
226 satExpr (Note note expr) interesting_ids = do
227     (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
228     return (Note note expr', sat_info_expr, expr_app)
229
230 satExpr ty@(Type _) _ = do
231     return (ty, emptyIdSATInfo, Nothing)
232
233 satExpr (Cast expr coercion) interesting_ids = do
234     (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
235     return (Cast expr' coercion, sat_info_expr, expr_app)
236 \end{code}
237
238 %************************************************************************
239
240                 Static Argument Transformation Monad
241
242 %************************************************************************
243
244 \begin{code}
245 type SatM result = UniqSM result
246
247 runSAT :: UniqSupply -> SatM a -> a
248 runSAT = initUs_
249
250 newUnique :: SatM Unique
251 newUnique = getUniqueUs
252 \end{code}
253
254
255 %************************************************************************
256
257                 Static Argument Transformation Monad
258
259 %************************************************************************
260
261 To do the transformation, the game plan is to:
262
263 1. Create a small nonrecursive RHS that takes the
264    original arguments to the function but discards
265    the ones that are static and makes a call to the
266    SATed version with the remainder. We intend that
267    this will be inlined later, removing the overhead
268
269 2. Bind this nonrecursive RHS over the original body
270    WITH THE SAME UNIQUE as the original body so that
271    any recursive calls to the original now go via
272    the small wrapper
273
274 3. Rebind the original function to a new one which contains
275    our SATed function and just makes a call to it:
276    we call the thing making this call the local body
277
278 Example: transform this
279
280     map :: forall a b. (a->b) -> [a] -> [b]
281     map = /\ab. \(f:a->b) (as:[a]) -> body[map]
282 to
283     map :: forall a b. (a->b) -> [a] -> [b]
284     map = /\ab. \(f:a->b) (as:[a]) ->
285          letrec map' :: [a] -> [b]
286                     -- The "worker function
287                 map' = \(as:[a]) -> 
288                          let map :: forall a' b'. (a -> b) -> [a] -> [b]
289                                 -- The "shadow function
290                              map = /\a'b'. \(f':(a->b) (as:[a]).
291                                    map' as
292                          in body[map]
293          in map' as
294
295 Note [Shadow binding]
296 ~~~~~~~~~~~~~~~~~~~~~
297 The calls to the inner map inside body[map] should get inlined
298 by the local re-binding of 'map'.  We call this the "shadow binding".
299
300 But we can't use the original binder 'map' unchanged, because
301 it might be exported, in which case the shadow binding won't be
302 discarded as dead code after it is inlined.
303
304 So we use a hack: we make a new SysLocal binder with the *same* unique
305 as binder.  (Another alternative would be to reset the export flag.)
306
307 Note [Binder type capture]
308 ~~~~~~~~~~~~~~~~~~~~~~~~~~
309 Notice that in the inner map (the "shadow function"), the static arguments
310 are discarded -- it's as if they were underscores.  Instead, mentions
311 of these arguments (notably in the types of dynamic arguments) are bound
312 by the *outer* lambdas of the main function.  So we must make up fresh
313 names for the static arguments so that they do not capture variables 
314 mentioned in the types of dynamic args.  
315
316 In the map example, the shadow function must clone the static type
317 argument a,b, giving a',b', to ensure that in the \(as:[a]), the 'a'
318 is bound by the outer forall.  We clone f' too for consistency, but
319 that doesn't matter either way because static Id arguments aren't 
320 mentioned in the shadow binding at all.
321
322 If we don't we get something like this:
323
324 [Exported]
325 [Arity 3]
326 GHC.Base.until =
327   \ (@ a_aiK)
328     (p_a6T :: a_aiK -> GHC.Types.Bool)
329     (f_a6V :: a_aiK -> a_aiK)
330     (x_a6X :: a_aiK) ->
331     letrec {
332       sat_worker_s1aU :: a_aiK -> a_aiK
333       []
334       sat_worker_s1aU =
335         \ (x_a6X :: a_aiK) ->
336           let {
337             sat_shadow_r17 :: forall a_a3O.
338                               (a_a3O -> GHC.Types.Bool) -> (a_a3O -> a_a3O) -> a_a3O -> a_a3O
339             []
340             sat_shadow_r17 =
341               \ (@ a_aiK)
342                 (p_a6T :: a_aiK -> GHC.Types.Bool)
343                 (f_a6V :: a_aiK -> a_aiK)
344                 (x_a6X :: a_aiK) ->
345                 sat_worker_s1aU x_a6X } in
346           case p_a6T x_a6X of wild_X3y [ALWAYS Dead Nothing] {
347             GHC.Types.False -> GHC.Base.until @ a_aiK p_a6T f_a6V (f_a6V x_a6X);
348             GHC.Types.True -> x_a6X
349           }; } in
350     sat_worker_s1aU x_a6X
351     
352 Where sat_shadow has captured the type variables of x_a6X etc as it has a a_aiK 
353 type argument. This is bad because it means the application sat_worker_s1aU x_a6X
354 is not well typed.
355
356 \begin{code}
357 saTransformMaybe :: Id -> Maybe SATInfo -> [Id] -> CoreExpr -> SatM CoreBind
358 saTransformMaybe binder maybe_arg_staticness rhs_binders rhs_body
359   | Just arg_staticness <- maybe_arg_staticness
360   , should_transform arg_staticness
361   = saTransform binder arg_staticness rhs_binders rhs_body
362   | otherwise
363   = return (Rec [(binder, mkLams rhs_binders rhs_body)])
364   where 
365     should_transform staticness = n_static_args > 1 -- THIS IS THE DECISION POINT
366       where
367         n_static_args = length (filter isStaticValue staticness)
368
369 saTransform :: Id -> SATInfo -> [Id] -> CoreExpr -> SatM CoreBind
370 saTransform binder arg_staticness rhs_binders rhs_body
371   = do  { shadow_lam_bndrs <- mapM clone binders_w_staticness
372         ; uniq             <- newUnique
373         ; return (NonRec binder (mk_new_rhs uniq shadow_lam_bndrs)) }
374   where
375     -- Running example: foldr
376     -- foldr \alpha \beta c n xs = e, for some e
377     -- arg_staticness = [Static TypeApp, Static TypeApp, Static VarApp, Static VarApp, NonStatic]
378     -- rhs_binders = [\alpha, \beta, c, n, xs]
379     -- rhs_body = e
380     
381     binders_w_staticness = rhs_binders `zip` (arg_staticness ++ repeat NotStatic)
382                                         -- Any extra args are assumed NotStatic
383
384     non_static_args :: [Var]
385             -- non_static_args = [xs]
386             -- rhs_binders_without_type_capture = [\alpha', \beta', c, n, xs]
387     non_static_args = [v | (v, NotStatic) <- binders_w_staticness]
388
389     clone (bndr, NotStatic) = return bndr
390     clone (bndr, _        ) = do { uniq <- newUnique
391                                  ; return (setVarUnique bndr uniq) }
392
393     -- new_rhs = \alpha beta c n xs -> 
394     --           let sat_worker = \xs -> let sat_shadow = \alpha' beta' c n xs -> 
395     --                                       sat_worker xs 
396     --                                   in e
397     --           in sat_worker xs
398     mk_new_rhs uniq shadow_lam_bndrs 
399         = mkLams rhs_binders $ 
400           Let (Rec [(rec_body_bndr, rec_body)]) 
401           local_body
402         where
403           local_body = mkVarApps (Var rec_body_bndr) non_static_args
404
405           rec_body = mkLams non_static_args $
406                      Let (NonRec shadow_bndr shadow_rhs) rhs_body
407
408             -- See Note [Binder type capture]
409           shadow_rhs = mkLams shadow_lam_bndrs local_body
410             -- nonrec_rhs = \alpha' beta' c n xs -> sat_worker xs
411
412           rec_body_bndr = mkSysLocal (fsLit "sat_worker") uniq (exprType rec_body)
413             -- rec_body_bndr = sat_worker
414     
415             -- See Note [Shadow binding]; make a SysLocal
416           shadow_bndr = mkSysLocal (occNameFS (getOccName binder)) 
417                                    (idUnique binder)
418                                    (exprType shadow_rhs)
419
420 isStaticValue :: Staticness App -> Bool
421 isStaticValue (Static (VarApp _)) = True
422 isStaticValue _                   = False
423
424 \end{code}