e6e5ff1fb2a7e282137d88734b2747a52f626c27
[ghc-hetmet.git] / compiler / simplCore / SAT.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
3 %
4
5 %************************************************************************
6
7                Static Argument Transformation pass
8
9 %************************************************************************
10
11 May be seen as removing invariants from loops:
12 Arguments of recursive functions that do not change in recursive
13 calls are removed from the recursion, which is done locally
14 and only passes the arguments which effectively change.
15
16 Example:
17 map = /\ ab -> \f -> \xs -> case xs of
18                  []       -> []
19                  (a:b) -> f a : map f b
20
21 as map is recursively called with the same argument f (unmodified)
22 we transform it to
23
24 map = /\ ab -> \f -> \xs -> let map' ys = case ys of
25                        []     -> []
26                        (a:b) -> f a : map' b
27                 in map' xs
28
29 Notice that for a compiler that uses lambda lifting this is
30 useless as map' will be transformed back to what map was.
31
32 We could possibly do the same for big lambdas, but we don't as
33 they will eventually be removed in later stages of the compiler,
34 therefore there is no penalty in keeping them.
35
36 We only apply the SAT when the number of static args is > 2. This
37 produces few bad cases.  See
38                 should_transform
39 in saTransform.
40
41 Here are the headline nofib results:
42                   Size    Allocs   Runtime
43 Min             +0.0%    -13.7%    -21.4%
44 Max             +0.1%     +0.0%     +5.4%
45 Geometric Mean  +0.0%     -0.2%     -6.9%
46
47 The previous patch, to fix polymorphic floatout demand signatures, is
48 essential to make this work well!
49
50
51 \begin{code}
52
53 module SAT ( doStaticArgs ) where
54
55 import DynFlags
56 import Var hiding (mkLocalId)
57 import CoreSyn
58 import CoreLint
59 import CoreUtils
60 import Type
61 import TcType
62 import Id
63 import Name
64 import OccName
65 import VarEnv
66 import UniqSupply
67 import Util
68 import UniqFM
69 import VarSet
70 import Unique
71 import UniqSet
72 import Outputable
73
74 import Data.List
75 import FastString
76
77 #include "HsVersions.h"
78 \end{code}
79
80 \begin{code}
81 doStaticArgs :: DynFlags -> UniqSupply -> [CoreBind] -> IO [CoreBind]
82 doStaticArgs dflags us binds = do
83     showPass dflags "Static argument"
84     let binds' = snd $ mapAccumL sat_bind_threaded_us us binds
85     endPass dflags "Static argument" Opt_D_verbose_core2core binds'
86   where
87     sat_bind_threaded_us us bind =
88         let (us1, us2) = splitUniqSupply us
89         in (us1, fst $ runSAT us2 (satBind bind emptyUniqSet))
90 \end{code}
91 \begin{code}
92 -- We don't bother to SAT recursive groups since it can lead
93 -- to massive code expansion: see Andre Santos' thesis for details.
94 -- This means we only apply the actual SAT to Rec groups of one element,
95 -- but we want to recurse into the others anyway to discover other binds
96 satBind :: CoreBind -> IdSet -> SatM (CoreBind, IdSATInfo)
97 satBind (NonRec binder expr) interesting_ids = do
98     (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
99     return (NonRec binder expr', finalizeApp expr_app sat_info_expr)
100 satBind (Rec [(binder, rhs)]) interesting_ids = do
101     let interesting_ids' = interesting_ids `addOneToUniqSet` binder
102         (rhs_binders, rhs_body) = collectBinders rhs
103     (rhs_body', sat_info_rhs_body) <- satTopLevelExpr rhs_body interesting_ids'
104     let sat_info_rhs_from_args = unitVarEnv binder (bindersToSATInfo rhs_binders)
105         sat_info_rhs' = mergeIdSATInfo sat_info_rhs_from_args sat_info_rhs_body
106         
107         shadowing = binder `elementOfUniqSet` interesting_ids
108         sat_info_rhs'' = if shadowing
109                         then sat_info_rhs' `delFromUFM` binder -- For safety
110                         else sat_info_rhs'
111     
112     bind' <- saTransformMaybe binder (lookupUFM sat_info_rhs' binder) 
113                               rhs_binders rhs_body'
114     return (bind', sat_info_rhs'')
115 satBind (Rec pairs) interesting_ids = do
116     let (binders, rhss) = unzip pairs
117     rhss_SATed <- mapM (\e -> satTopLevelExpr e interesting_ids) rhss
118     let (rhss', sat_info_rhss') = unzip rhss_SATed
119     return (Rec (zipEqual "satBind" binders rhss'), mergeIdSATInfos sat_info_rhss')
120 \end{code}
121 \begin{code}
122 data App = VarApp Id | TypeApp Type
123 data Staticness a = Static a | NotStatic
124
125 type IdAppInfo = (Id, SATInfo)
126
127 type SATInfo = [Staticness App]
128 type IdSATInfo = IdEnv SATInfo
129 emptyIdSATInfo :: IdSATInfo
130 emptyIdSATInfo = emptyUFM
131
132 {-
133 pprIdSATInfo id_sat_info = vcat (map pprIdAndSATInfo (fmToList id_sat_info))
134   where pprIdAndSATInfo (v, sat_info) = hang (ppr v <> colon) 4 (pprSATInfo sat_info)
135 -}
136
137 pprSATInfo :: SATInfo -> SDoc
138 pprSATInfo staticness = hcat $ map pprStaticness staticness
139
140 pprStaticness :: Staticness App -> SDoc
141 pprStaticness (Static (VarApp _))  = ptext (sLit "SV") 
142 pprStaticness (Static (TypeApp _)) = ptext (sLit "ST") 
143 pprStaticness NotStatic            = ptext (sLit "NS")
144
145
146 mergeSATInfo :: SATInfo -> SATInfo -> SATInfo
147 mergeSATInfo [] _  = []
148 mergeSATInfo _  [] = []
149 mergeSATInfo (NotStatic:statics) (_:apps) = NotStatic : mergeSATInfo statics apps
150 mergeSATInfo (_:statics) (NotStatic:apps) = NotStatic : mergeSATInfo statics apps
151 mergeSATInfo ((Static (VarApp v)):statics)  ((Static (VarApp v')):apps)  = (if v == v' then Static (VarApp v) else NotStatic) : mergeSATInfo statics apps
152 mergeSATInfo ((Static (TypeApp t)):statics) ((Static (TypeApp t')):apps) = (if t `coreEqType` t' then Static (TypeApp t) else NotStatic) : mergeSATInfo statics apps
153 mergeSATInfo l  r  = pprPanic "mergeSATInfo" $ ptext (sLit "Left:") <> pprSATInfo l <> ptext (sLit ", ")
154                                             <> ptext (sLit "Right:") <> pprSATInfo r
155
156 mergeIdSATInfo :: IdSATInfo -> IdSATInfo -> IdSATInfo
157 mergeIdSATInfo = plusUFM_C mergeSATInfo
158
159 mergeIdSATInfos :: [IdSATInfo] -> IdSATInfo
160 mergeIdSATInfos = foldl' mergeIdSATInfo emptyIdSATInfo
161
162 bindersToSATInfo :: [Id] -> SATInfo
163 bindersToSATInfo vs = map (Static . binderToApp) vs
164     where binderToApp v = if isId v
165                           then VarApp v
166                           else TypeApp $ mkTyVarTy v
167
168 finalizeApp :: Maybe IdAppInfo -> IdSATInfo -> IdSATInfo
169 finalizeApp Nothing id_sat_info = id_sat_info
170 finalizeApp (Just (v, sat_info')) id_sat_info = 
171     let sat_info'' = case lookupUFM id_sat_info v of
172                         Nothing -> sat_info'
173                         Just sat_info -> mergeSATInfo sat_info sat_info'
174     in extendVarEnv id_sat_info v sat_info''
175 \end{code}
176 \begin{code}
177 satTopLevelExpr :: CoreExpr -> IdSet -> SatM (CoreExpr, IdSATInfo)
178 satTopLevelExpr expr interesting_ids = do
179     (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
180     return (expr', finalizeApp expr_app sat_info_expr)
181
182 satExpr :: CoreExpr -> IdSet -> SatM (CoreExpr, IdSATInfo, Maybe IdAppInfo)
183 satExpr var@(Var v) interesting_ids = do
184     let app_info = if v `elementOfUniqSet` interesting_ids
185                    then Just (v, [])
186                    else Nothing
187     return (var, emptyIdSATInfo, app_info)
188
189 satExpr lit@(Lit _) _ = do
190     return (lit, emptyIdSATInfo, Nothing)
191
192 satExpr (Lam binders body) interesting_ids = do
193     (body', sat_info, this_app) <- satExpr body interesting_ids
194     return (Lam binders body', finalizeApp this_app sat_info, Nothing)
195
196 satExpr (App fn arg) interesting_ids = do
197     (fn', sat_info_fn, fn_app) <- satExpr fn interesting_ids
198     let satRemainder = boring fn' sat_info_fn
199     case fn_app of
200         Nothing -> satRemainder Nothing
201         Just (fn_id, fn_app_info) ->
202             -- TODO: remove this use of append somehow (use a data structure with O(1) append but a left-to-right kind of interface)
203             let satRemainderWithStaticness arg_staticness = satRemainder $ Just (fn_id, fn_app_info ++ [arg_staticness])
204             in case arg of
205                 Type t -> satRemainderWithStaticness $ Static (TypeApp t)
206                 Var v  -> satRemainderWithStaticness $ Static (VarApp v)
207                 _      -> satRemainderWithStaticness $ NotStatic
208   where
209     boring :: CoreExpr -> IdSATInfo -> Maybe IdAppInfo -> SatM (CoreExpr, IdSATInfo, Maybe IdAppInfo)
210     boring fn' sat_info_fn app_info = 
211         do (arg', sat_info_arg, arg_app) <- satExpr arg interesting_ids
212            let sat_info_arg' = finalizeApp arg_app sat_info_arg
213                sat_info = mergeIdSATInfo sat_info_fn sat_info_arg'
214            return (App fn' arg', sat_info, app_info)
215
216 satExpr (Case expr bndr ty alts) interesting_ids = do
217     (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
218     let sat_info_expr' = finalizeApp expr_app sat_info_expr
219     
220     zipped_alts' <- mapM satAlt alts
221     let (alts', sat_infos_alts) = unzip zipped_alts'
222     return (Case expr' bndr ty alts', mergeIdSATInfo sat_info_expr' (mergeIdSATInfos sat_infos_alts), Nothing)
223   where
224     satAlt (con, bndrs, expr) = do
225         (expr', sat_info_expr) <- satTopLevelExpr expr interesting_ids
226         return ((con, bndrs, expr'), sat_info_expr)
227
228 satExpr (Let bind body) interesting_ids = do
229     (body', sat_info_body, body_app) <- satExpr body interesting_ids
230     (bind', sat_info_bind) <- satBind bind interesting_ids
231     return (Let bind' body', mergeIdSATInfo sat_info_body sat_info_bind, body_app)
232
233 satExpr (Note note expr) interesting_ids = do
234     (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
235     return (Note note expr', sat_info_expr, expr_app)
236
237 satExpr ty@(Type _) _ = do
238     return (ty, emptyIdSATInfo, Nothing)
239
240 satExpr (Cast expr coercion) interesting_ids = do
241     (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
242     return (Cast expr' coercion, sat_info_expr, expr_app)
243 \end{code}
244
245 %************************************************************************
246
247                 Static Argument Transformation Monad
248
249 %************************************************************************
250
251 \begin{code}
252 type SatM result = UniqSM result
253
254 runSAT :: UniqSupply -> SatM a -> a
255 runSAT = initUs_
256
257 newUnique :: SatM Unique
258 newUnique = getUniqueUs
259 \end{code}
260
261
262 %************************************************************************
263
264                 Static Argument Transformation Monad
265
266 %************************************************************************
267
268 To do the transformation, the game plan is to:
269
270 1. Create a small nonrecursive RHS that takes the
271    original arguments to the function but discards
272    the ones that are static and makes a call to the
273    SATed version with the remainder. We intend that
274    this will be inlined later, removing the overhead
275
276 2. Bind this nonrecursive RHS over the original body
277    WITH THE SAME UNIQUE as the original body so that
278    any recursive calls to the original now go via
279    the small wrapper
280
281 3. Rebind the original function to a new one which contains
282    our SATed function and just makes a call to it:
283    we call the thing making this call the local body
284
285 Example: transform this
286
287     map :: forall a b. (a->b) -> [a] -> [b]
288     map = /\ab. \(f:a->b) (as:[a]) -> body[map]
289 to
290     map :: forall a b. (a->b) -> [a] -> [b]
291     map = /\ab. \(f:a->b) (as:[a]) ->
292          letrec map' :: [a] -> [b]
293                     -- The "worker function
294                 map' = \(as:[a]) -> 
295                          let map :: forall a' b'. (a -> b) -> [a] -> [b]
296                                 -- The "shadow function
297                              map = /\a'b'. \(f':(a->b) (as:[a]).
298                                    map' as
299                          in body[map]
300          in map' as
301
302 Note [Shadow binding]
303 ~~~~~~~~~~~~~~~~~~~~~
304 The calls to the inner map inside body[map] should get inlined
305 by the local re-binding of 'map'.  We call this the "shadow binding".
306
307 But we can't use the original binder 'map' unchanged, because
308 it might be exported, in which case the shadow binding won't be
309 discarded as dead code after it is inlined.
310
311 So we use a hack: we make a new SysLocal binder with the *same* unique
312 as binder.  (Another alternative would be to reset the export flag.)
313
314 Note [Binder type capture]
315 ~~~~~~~~~~~~~~~~~~~~~~~~~~
316 Notice that in the inner map (the "shadow function"), the static arguments
317 are discarded -- it's as if they were underscores.  Instead, mentions
318 of these arguments (notably in the types of dynamic arguments) are bound
319 by the *outer* lambdas of the main function.  So we must make up fresh
320 names for the static arguments so that they do not capture variables 
321 mentioned in the types of dynamic args.  
322
323 In the map example, the shadow function must clone the static type
324 argument a,b, giving a',b', to ensure that in the \(as:[a]), the 'a'
325 is bound by the outer forall.  We clone f' too for consistency, but
326 that doesn't matter either way because static Id arguments aren't 
327 mentioned in the shadow binding at all.
328
329 If we don't we get something like this:
330
331 [Exported]
332 [Arity 3]
333 GHC.Base.until =
334   \ (@ a_aiK)
335     (p_a6T :: a_aiK -> GHC.Bool.Bool)
336     (f_a6V :: a_aiK -> a_aiK)
337     (x_a6X :: a_aiK) ->
338     letrec {
339       sat_worker_s1aU :: a_aiK -> a_aiK
340       []
341       sat_worker_s1aU =
342         \ (x_a6X :: a_aiK) ->
343           let {
344             sat_shadow_r17 :: forall a_a3O.
345                               (a_a3O -> GHC.Bool.Bool) -> (a_a3O -> a_a3O) -> a_a3O -> a_a3O
346             []
347             sat_shadow_r17 =
348               \ (@ a_aiK)
349                 (p_a6T :: a_aiK -> GHC.Bool.Bool)
350                 (f_a6V :: a_aiK -> a_aiK)
351                 (x_a6X :: a_aiK) ->
352                 sat_worker_s1aU x_a6X } in
353           case p_a6T x_a6X of wild_X3y [ALWAYS Dead Nothing] {
354             GHC.Bool.False -> GHC.Base.until @ a_aiK p_a6T f_a6V (f_a6V x_a6X);
355             GHC.Bool.True -> x_a6X
356           }; } in
357     sat_worker_s1aU x_a6X
358     
359 Where sat_shadow has captured the type variables of x_a6X etc as it has a a_aiK 
360 type argument. This is bad because it means the application sat_worker_s1aU x_a6X
361 is not well typed.
362
363 \begin{code}
364 saTransformMaybe :: Id -> Maybe SATInfo -> [Id] -> CoreExpr -> SatM CoreBind
365 saTransformMaybe binder maybe_arg_staticness rhs_binders rhs_body
366   | Just arg_staticness <- maybe_arg_staticness
367   , should_transform arg_staticness
368   = saTransform binder arg_staticness rhs_binders rhs_body
369   | otherwise
370   = return (Rec [(binder, mkLams rhs_binders rhs_body)])
371   where 
372     should_transform staticness = n_static_args > 1 -- THIS IS THE DECISION POINT
373       where
374         n_static_args = length (filter isStaticValue staticness)
375
376 saTransform :: Id -> SATInfo -> [Id] -> CoreExpr -> SatM CoreBind
377 saTransform binder arg_staticness rhs_binders rhs_body
378   = do  { shadow_lam_bndrs <- mapM clone binders_w_staticness
379         ; uniq             <- newUnique
380         ; return (NonRec binder (mk_new_rhs uniq shadow_lam_bndrs)) }
381   where
382     -- Running example: foldr
383     -- foldr \alpha \beta c n xs = e, for some e
384     -- arg_staticness = [Static TypeApp, Static TypeApp, Static VarApp, Static VarApp, NonStatic]
385     -- rhs_binders = [\alpha, \beta, c, n, xs]
386     -- rhs_body = e
387     
388     binders_w_staticness = rhs_binders `zip` (arg_staticness ++ repeat NotStatic)
389                                         -- Any extra args are assumed NotStatic
390
391     non_static_args :: [Var]
392             -- non_static_args = [xs]
393             -- rhs_binders_without_type_capture = [\alpha', \beta', c, n, xs]
394     non_static_args = [v | (v, NotStatic) <- binders_w_staticness]
395
396     clone (bndr, NotStatic) = return bndr
397     clone (bndr, _        ) = do { uniq <- newUnique
398                                  ; return (setVarUnique bndr uniq) }
399
400     -- new_rhs = \alpha beta c n xs -> 
401     --           let sat_worker = \xs -> let sat_shadow = \alpha' beta' c n xs -> 
402     --                                       sat_worker xs 
403     --                                   in e
404     --           in sat_worker xs
405     mk_new_rhs uniq shadow_lam_bndrs 
406         = mkLams rhs_binders $ 
407           Let (Rec [(rec_body_bndr, rec_body)]) 
408           local_body
409         where
410           local_body = mkVarApps (Var rec_body_bndr) non_static_args
411
412           rec_body = mkLams non_static_args $
413                      Let (NonRec shadow_bndr shadow_rhs) rhs_body
414
415             -- See Note [Binder type capture]
416           shadow_rhs = mkLams shadow_lam_bndrs local_body
417             -- nonrec_rhs = \alpha' beta' c n xs -> sat_worker xs
418
419           rec_body_bndr = mkSysLocal (fsLit "sat_worker") uniq (exprType rec_body)
420             -- rec_body_bndr = sat_worker
421     
422             -- See Note [Shadow binding]; make a SysLocal
423           shadow_bndr = mkSysLocal (occNameFS (getOccName binder)) 
424                                    (idUnique binder)
425                                    (exprType shadow_rhs)
426
427 isStaticValue :: Staticness App -> Bool
428 isStaticValue (Static (VarApp _)) = True
429 isStaticValue _                   = False
430
431 \end{code}