swap <[]> and <{}> syntax
[ghc-hetmet.git] / compiler / simplCore / SAT.lhs
1 %
2 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
3 %
4
5 %************************************************************************
6
7                Static Argument Transformation pass
8
9 %************************************************************************
10
11 May be seen as removing invariants from loops:
12 Arguments of recursive functions that do not change in recursive
13 calls are removed from the recursion, which is done locally
14 and only passes the arguments which effectively change.
15
16 Example:
17 map = /\ ab -> \f -> \xs -> case xs of
18                  []       -> []
19                  (a:b) -> f a : map f b
20
21 as map is recursively called with the same argument f (unmodified)
22 we transform it to
23
24 map = /\ ab -> \f -> \xs -> let map' ys = case ys of
25                        []     -> []
26                        (a:b) -> f a : map' b
27                 in map' xs
28
29 Notice that for a compiler that uses lambda lifting this is
30 useless as map' will be transformed back to what map was.
31
32 We could possibly do the same for big lambdas, but we don't as
33 they will eventually be removed in later stages of the compiler,
34 therefore there is no penalty in keeping them.
35
36 We only apply the SAT when the number of static args is > 2. This
37 produces few bad cases.  See
38                 should_transform
39 in saTransform.
40
41 Here are the headline nofib results:
42                   Size    Allocs   Runtime
43 Min             +0.0%    -13.7%    -21.4%
44 Max             +0.1%     +0.0%     +5.4%
45 Geometric Mean  +0.0%     -0.2%     -6.9%
46
47 The previous patch, to fix polymorphic floatout demand signatures, is
48 essential to make this work well!
49
50
51 \begin{code}
52
53 module SAT ( doStaticArgs ) where
54
55 import Var
56 import CoreSyn
57 import CoreUtils
58 import Type
59 import Coercion
60 import Id
61 import Name
62 import VarEnv
63 import UniqSupply
64 import Util
65 import UniqFM
66 import VarSet
67 import Unique
68 import UniqSet
69 import Outputable
70
71 import Data.List
72 import FastString
73
74 #include "HsVersions.h"
75 \end{code}
76
77 \begin{code}
78 doStaticArgs :: UniqSupply -> [CoreBind] -> [CoreBind]
79 doStaticArgs us binds = snd $ mapAccumL sat_bind_threaded_us us binds
80   where
81     sat_bind_threaded_us us bind =
82         let (us1, us2) = splitUniqSupply us
83         in (us1, fst $ runSAT us2 (satBind bind emptyUniqSet))
84 \end{code}
85 \begin{code}
86 -- We don't bother to SAT recursive groups since it can lead
87 -- to massive code expansion: see Andre Santos' thesis for details.
88 -- This means we only apply the actual SAT to Rec groups of one element,
89 -- but we want to recurse into the others anyway to discover other binds
90 satBind :: CoreBind -> IdSet -> SatM (CoreBind, IdSATInfo)
91 satBind (NonRec binder expr) interesting_ids = do
92     (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
93     return (NonRec binder expr', finalizeApp expr_app sat_info_expr)
94 satBind (Rec [(binder, rhs)]) interesting_ids = do
95     let interesting_ids' = interesting_ids `addOneToUniqSet` binder
96         (rhs_binders, rhs_body) = collectBinders rhs
97     (rhs_body', sat_info_rhs_body) <- satTopLevelExpr rhs_body interesting_ids'
98     let sat_info_rhs_from_args = unitVarEnv binder (bindersToSATInfo rhs_binders)
99         sat_info_rhs' = mergeIdSATInfo sat_info_rhs_from_args sat_info_rhs_body
100         
101         shadowing = binder `elementOfUniqSet` interesting_ids
102         sat_info_rhs'' = if shadowing
103                         then sat_info_rhs' `delFromUFM` binder -- For safety
104                         else sat_info_rhs'
105     
106     bind' <- saTransformMaybe binder (lookupUFM sat_info_rhs' binder) 
107                               rhs_binders rhs_body'
108     return (bind', sat_info_rhs'')
109 satBind (Rec pairs) interesting_ids = do
110     let (binders, rhss) = unzip pairs
111     rhss_SATed <- mapM (\e -> satTopLevelExpr e interesting_ids) rhss
112     let (rhss', sat_info_rhss') = unzip rhss_SATed
113     return (Rec (zipEqual "satBind" binders rhss'), mergeIdSATInfos sat_info_rhss')
114 \end{code}
115 \begin{code}
116 data App = VarApp Id | TypeApp Type | CoApp Coercion
117 data Staticness a = Static a | NotStatic
118
119 type IdAppInfo = (Id, SATInfo)
120
121 type SATInfo = [Staticness App]
122 type IdSATInfo = IdEnv SATInfo
123 emptyIdSATInfo :: IdSATInfo
124 emptyIdSATInfo = emptyUFM
125
126 {-
127 pprIdSATInfo id_sat_info = vcat (map pprIdAndSATInfo (Map.toList id_sat_info))
128   where pprIdAndSATInfo (v, sat_info) = hang (ppr v <> colon) 4 (pprSATInfo sat_info)
129 -}
130
131 pprSATInfo :: SATInfo -> SDoc
132 pprSATInfo staticness = hcat $ map pprStaticness staticness
133
134 pprStaticness :: Staticness App -> SDoc
135 pprStaticness (Static (VarApp _))  = ptext (sLit "SV") 
136 pprStaticness (Static (TypeApp _)) = ptext (sLit "ST") 
137 pprStaticness (Static (CoApp _))   = ptext (sLit "SC")
138 pprStaticness NotStatic            = ptext (sLit "NS")
139
140
141 mergeSATInfo :: SATInfo -> SATInfo -> SATInfo
142 mergeSATInfo [] _  = []
143 mergeSATInfo _  [] = []
144 mergeSATInfo (NotStatic:statics) (_:apps) = NotStatic : mergeSATInfo statics apps
145 mergeSATInfo (_:statics) (NotStatic:apps) = NotStatic : mergeSATInfo statics apps
146 mergeSATInfo ((Static (VarApp v)):statics)  ((Static (VarApp v')):apps)  = (if v == v' then Static (VarApp v) else NotStatic) : mergeSATInfo statics apps
147 mergeSATInfo ((Static (TypeApp t)):statics) ((Static (TypeApp t')):apps) = (if t `eqType` t' then Static (TypeApp t) else NotStatic) : mergeSATInfo statics apps
148 mergeSATInfo ((Static (CoApp c)):statics) ((Static (CoApp c')):apps)     = (if c `coreEqCoercion` c' then Static (CoApp c) else NotStatic) : mergeSATInfo statics apps
149 mergeSATInfo l  r  = pprPanic "mergeSATInfo" $ ptext (sLit "Left:") <> pprSATInfo l <> ptext (sLit ", ")
150                                             <> ptext (sLit "Right:") <> pprSATInfo r
151
152 mergeIdSATInfo :: IdSATInfo -> IdSATInfo -> IdSATInfo
153 mergeIdSATInfo = plusUFM_C mergeSATInfo
154
155 mergeIdSATInfos :: [IdSATInfo] -> IdSATInfo
156 mergeIdSATInfos = foldl' mergeIdSATInfo emptyIdSATInfo
157
158 bindersToSATInfo :: [Id] -> SATInfo
159 bindersToSATInfo vs = map (Static . binderToApp) vs
160     where binderToApp v | isId v    = VarApp v
161                         | isTyVar v = TypeApp $ mkTyVarTy v
162                         | otherwise = CoApp $ mkCoVarCo v
163
164 finalizeApp :: Maybe IdAppInfo -> IdSATInfo -> IdSATInfo
165 finalizeApp Nothing id_sat_info = id_sat_info
166 finalizeApp (Just (v, sat_info')) id_sat_info = 
167     let sat_info'' = case lookupUFM id_sat_info v of
168                         Nothing -> sat_info'
169                         Just sat_info -> mergeSATInfo sat_info sat_info'
170     in extendVarEnv id_sat_info v sat_info''
171 \end{code}
172 \begin{code}
173 satTopLevelExpr :: CoreExpr -> IdSet -> SatM (CoreExpr, IdSATInfo)
174 satTopLevelExpr expr interesting_ids = do
175     (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
176     return (expr', finalizeApp expr_app sat_info_expr)
177
178 satExpr :: CoreExpr -> IdSet -> SatM (CoreExpr, IdSATInfo, Maybe IdAppInfo)
179 satExpr var@(Var v) interesting_ids = do
180     let app_info = if v `elementOfUniqSet` interesting_ids
181                    then Just (v, [])
182                    else Nothing
183     return (var, emptyIdSATInfo, app_info)
184
185 satExpr lit@(Lit _) _ = do
186     return (lit, emptyIdSATInfo, Nothing)
187
188 satExpr (Lam binders body) interesting_ids = do
189     (body', sat_info, this_app) <- satExpr body interesting_ids
190     return (Lam binders body', finalizeApp this_app sat_info, Nothing)
191
192 satExpr (App fn arg) interesting_ids = do
193     (fn', sat_info_fn, fn_app) <- satExpr fn interesting_ids
194     let satRemainder = boring fn' sat_info_fn
195     case fn_app of
196         Nothing -> satRemainder Nothing
197         Just (fn_id, fn_app_info) ->
198             -- TODO: remove this use of append somehow (use a data structure with O(1) append but a left-to-right kind of interface)
199             let satRemainderWithStaticness arg_staticness = satRemainder $ Just (fn_id, fn_app_info ++ [arg_staticness])
200             in case arg of
201                 Type t     -> satRemainderWithStaticness $ Static (TypeApp t)
202                 Coercion c -> satRemainderWithStaticness $ Static (CoApp c)
203                 Var v      -> satRemainderWithStaticness $ Static (VarApp v)
204                 _          -> satRemainderWithStaticness $ NotStatic
205   where
206     boring :: CoreExpr -> IdSATInfo -> Maybe IdAppInfo -> SatM (CoreExpr, IdSATInfo, Maybe IdAppInfo)
207     boring fn' sat_info_fn app_info = 
208         do (arg', sat_info_arg, arg_app) <- satExpr arg interesting_ids
209            let sat_info_arg' = finalizeApp arg_app sat_info_arg
210                sat_info = mergeIdSATInfo sat_info_fn sat_info_arg'
211            return (App fn' arg', sat_info, app_info)
212
213 satExpr (Case expr bndr ty alts) interesting_ids = do
214     (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
215     let sat_info_expr' = finalizeApp expr_app sat_info_expr
216     
217     zipped_alts' <- mapM satAlt alts
218     let (alts', sat_infos_alts) = unzip zipped_alts'
219     return (Case expr' bndr ty alts', mergeIdSATInfo sat_info_expr' (mergeIdSATInfos sat_infos_alts), Nothing)
220   where
221     satAlt (con, bndrs, expr) = do
222         (expr', sat_info_expr) <- satTopLevelExpr expr interesting_ids
223         return ((con, bndrs, expr'), sat_info_expr)
224
225 satExpr (Let bind body) interesting_ids = do
226     (body', sat_info_body, body_app) <- satExpr body interesting_ids
227     (bind', sat_info_bind) <- satBind bind interesting_ids
228     return (Let bind' body', mergeIdSATInfo sat_info_body sat_info_bind, body_app)
229
230 satExpr (Note note expr) interesting_ids = do
231     (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
232     return (Note note expr', sat_info_expr, expr_app)
233
234 satExpr ty@(Type _) _ = do
235     return (ty, emptyIdSATInfo, Nothing)
236     
237 satExpr co@(Coercion _) _ = do
238     return (co, emptyIdSATInfo, Nothing)
239
240 satExpr (Cast expr coercion) interesting_ids = do
241     (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
242     return (Cast expr' coercion, sat_info_expr, expr_app)
243 \end{code}
244
245 %************************************************************************
246
247                 Static Argument Transformation Monad
248
249 %************************************************************************
250
251 \begin{code}
252 type SatM result = UniqSM result
253
254 runSAT :: UniqSupply -> SatM a -> a
255 runSAT = initUs_
256
257 newUnique :: SatM Unique
258 newUnique = getUniqueUs
259 \end{code}
260
261
262 %************************************************************************
263
264                 Static Argument Transformation Monad
265
266 %************************************************************************
267
268 To do the transformation, the game plan is to:
269
270 1. Create a small nonrecursive RHS that takes the
271    original arguments to the function but discards
272    the ones that are static and makes a call to the
273    SATed version with the remainder. We intend that
274    this will be inlined later, removing the overhead
275
276 2. Bind this nonrecursive RHS over the original body
277    WITH THE SAME UNIQUE as the original body so that
278    any recursive calls to the original now go via
279    the small wrapper
280
281 3. Rebind the original function to a new one which contains
282    our SATed function and just makes a call to it:
283    we call the thing making this call the local body
284
285 Example: transform this
286
287     map :: forall a b. (a->b) -> [a] -> [b]
288     map = /\ab. \(f:a->b) (as:[a]) -> body[map]
289 to
290     map :: forall a b. (a->b) -> [a] -> [b]
291     map = /\ab. \(f:a->b) (as:[a]) ->
292          letrec map' :: [a] -> [b]
293                     -- The "worker function
294                 map' = \(as:[a]) -> 
295                          let map :: forall a' b'. (a -> b) -> [a] -> [b]
296                                 -- The "shadow function
297                              map = /\a'b'. \(f':(a->b) (as:[a]).
298                                    map' as
299                          in body[map]
300          in map' as
301
302 Note [Shadow binding]
303 ~~~~~~~~~~~~~~~~~~~~~
304 The calls to the inner map inside body[map] should get inlined
305 by the local re-binding of 'map'.  We call this the "shadow binding".
306
307 But we can't use the original binder 'map' unchanged, because
308 it might be exported, in which case the shadow binding won't be
309 discarded as dead code after it is inlined.
310
311 So we use a hack: we make a new SysLocal binder with the *same* unique
312 as binder.  (Another alternative would be to reset the export flag.)
313
314 Note [Binder type capture]
315 ~~~~~~~~~~~~~~~~~~~~~~~~~~
316 Notice that in the inner map (the "shadow function"), the static arguments
317 are discarded -- it's as if they were underscores.  Instead, mentions
318 of these arguments (notably in the types of dynamic arguments) are bound
319 by the *outer* lambdas of the main function.  So we must make up fresh
320 names for the static arguments so that they do not capture variables 
321 mentioned in the types of dynamic args.  
322
323 In the map example, the shadow function must clone the static type
324 argument a,b, giving a',b', to ensure that in the \(as:[a]), the 'a'
325 is bound by the outer forall.  We clone f' too for consistency, but
326 that doesn't matter either way because static Id arguments aren't 
327 mentioned in the shadow binding at all.
328
329 If we don't we get something like this:
330
331 [Exported]
332 [Arity 3]
333 GHC.Base.until =
334   \ (@ a_aiK)
335     (p_a6T :: a_aiK -> GHC.Types.Bool)
336     (f_a6V :: a_aiK -> a_aiK)
337     (x_a6X :: a_aiK) ->
338     letrec {
339       sat_worker_s1aU :: a_aiK -> a_aiK
340       []
341       sat_worker_s1aU =
342         \ (x_a6X :: a_aiK) ->
343           let {
344             sat_shadow_r17 :: forall a_a3O.
345                               (a_a3O -> GHC.Types.Bool) -> (a_a3O -> a_a3O) -> a_a3O -> a_a3O
346             []
347             sat_shadow_r17 =
348               \ (@ a_aiK)
349                 (p_a6T :: a_aiK -> GHC.Types.Bool)
350                 (f_a6V :: a_aiK -> a_aiK)
351                 (x_a6X :: a_aiK) ->
352                 sat_worker_s1aU x_a6X } in
353           case p_a6T x_a6X of wild_X3y [ALWAYS Dead Nothing] {
354             GHC.Types.False -> GHC.Base.until @ a_aiK p_a6T f_a6V (f_a6V x_a6X);
355             GHC.Types.True -> x_a6X
356           }; } in
357     sat_worker_s1aU x_a6X
358     
359 Where sat_shadow has captured the type variables of x_a6X etc as it has a a_aiK 
360 type argument. This is bad because it means the application sat_worker_s1aU x_a6X
361 is not well typed.
362
363 \begin{code}
364 saTransformMaybe :: Id -> Maybe SATInfo -> [Id] -> CoreExpr -> SatM CoreBind
365 saTransformMaybe binder maybe_arg_staticness rhs_binders rhs_body
366   | Just arg_staticness <- maybe_arg_staticness
367   , should_transform arg_staticness
368   = saTransform binder arg_staticness rhs_binders rhs_body
369   | otherwise
370   = return (Rec [(binder, mkLams rhs_binders rhs_body)])
371   where 
372     should_transform staticness = n_static_args > 1 -- THIS IS THE DECISION POINT
373       where
374         n_static_args = length (filter isStaticValue staticness)
375
376 saTransform :: Id -> SATInfo -> [Id] -> CoreExpr -> SatM CoreBind
377 saTransform binder arg_staticness rhs_binders rhs_body
378   = do  { shadow_lam_bndrs <- mapM clone binders_w_staticness
379         ; uniq             <- newUnique
380         ; return (NonRec binder (mk_new_rhs uniq shadow_lam_bndrs)) }
381   where
382     -- Running example: foldr
383     -- foldr \alpha \beta c n xs = e, for some e
384     -- arg_staticness = [Static TypeApp, Static TypeApp, Static VarApp, Static VarApp, NonStatic]
385     -- rhs_binders = [\alpha, \beta, c, n, xs]
386     -- rhs_body = e
387     
388     binders_w_staticness = rhs_binders `zip` (arg_staticness ++ repeat NotStatic)
389                                         -- Any extra args are assumed NotStatic
390
391     non_static_args :: [Var]
392             -- non_static_args = [xs]
393             -- rhs_binders_without_type_capture = [\alpha', \beta', c, n, xs]
394     non_static_args = [v | (v, NotStatic) <- binders_w_staticness]
395
396     clone (bndr, NotStatic) = return bndr
397     clone (bndr, _        ) = do { uniq <- newUnique
398                                  ; return (setVarUnique bndr uniq) }
399
400     -- new_rhs = \alpha beta c n xs -> 
401     --           let sat_worker = \xs -> let sat_shadow = \alpha' beta' c n xs -> 
402     --                                       sat_worker xs 
403     --                                   in e
404     --           in sat_worker xs
405     mk_new_rhs uniq shadow_lam_bndrs 
406         = mkLams rhs_binders $ 
407           Let (Rec [(rec_body_bndr, rec_body)]) 
408           local_body
409         where
410           local_body = mkVarApps (Var rec_body_bndr) non_static_args
411
412           rec_body = mkLams non_static_args $
413                      Let (NonRec shadow_bndr shadow_rhs) rhs_body
414
415             -- See Note [Binder type capture]
416           shadow_rhs = mkLams shadow_lam_bndrs local_body
417             -- nonrec_rhs = \alpha' beta' c n xs -> sat_worker xs
418
419           rec_body_bndr = mkSysLocal (fsLit "sat_worker") uniq (exprType rec_body)
420             -- rec_body_bndr = sat_worker
421     
422             -- See Note [Shadow binding]; make a SysLocal
423           shadow_bndr = mkSysLocal (occNameFS (getOccName binder)) 
424                                    (idUnique binder)
425                                    (exprType shadow_rhs)
426
427 isStaticValue :: Staticness App -> Bool
428 isStaticValue (Static (VarApp _)) = True
429 isStaticValue _                   = False
430
431 \end{code}