From: simonpj@microsoft.com Date: Mon, 28 Apr 2008 15:57:11 +0000 (+0000) Subject: Fix Trac #1969: performance bug in the specialiser X-Git-Url: http://git.megacz.com/?p=ghc-hetmet.git;a=commitdiff_plain;h=6246f5738bc482423e51342eb117a40539be790e Fix Trac #1969: performance bug in the specialiser The specialiser was using a rather brain-dead representation for UsageDetails, with much converting from lists to finite maps and back. This patch does some significant refactoring. It doesn't change the representation altogether, but it does eliminate the to-and-fro nonsense. It validates OK, but it's always possible that I have inadvertently lost specialisation somewhere, so keep an eye out for any run-time performance regressions. Oh, and Specialise is now warning-free too. --- diff --git a/compiler/specialise/Specialise.lhs b/compiler/specialise/Specialise.lhs index 9455f0f..055c858 100644 --- a/compiler/specialise/Specialise.lhs +++ b/compiler/specialise/Specialise.lhs @@ -4,7 +4,6 @@ \section[Specialise]{Stamping out overloading, and (optionally) polymorphism} \begin{code} -{-# OPTIONS -w #-} -- The above warning supression flag is a temporary kludge. -- While working on this module you are encouraged to remove it and fix -- any warnings in the module. 
See @@ -44,9 +43,7 @@ import MkId ( voidArgId, realWorldPrimId ) import FiniteMap import Maybes ( catMaybes, maybeToBool ) import ErrUtils ( dumpIfSet_dyn ) -import BasicTypes ( Activation( AlwaysActive ) ) import Bag -import List ( partition ) import Util import Outputable import FastString @@ -628,7 +625,7 @@ specExpr :: Subst -> CoreExpr -> SpecM (CoreExpr, UsageDetails) ---------------- First the easy cases -------------------- specExpr subst (Type ty) = return (Type (substTy subst ty), emptyUDs) specExpr subst (Var v) = return (specVar subst v, emptyUDs) -specExpr subst (Lit lit) = return (Lit lit, emptyUDs) +specExpr _ (Lit lit) = return (Lit lit, emptyUDs) specExpr subst (Cast e co) = do (e', uds) <- specExpr subst e return ((Cast e' (substTy subst co)), uds) @@ -638,7 +635,7 @@ specExpr subst (Note note body) = do ---------------- Applications might generate a call instance -------------------- -specExpr subst expr@(App fun arg) +specExpr subst expr@(App {}) = go expr [] where go (App fun arg) args = do (arg', uds_arg) <- specExpr subst arg @@ -648,7 +645,7 @@ specExpr subst expr@(App fun arg) go (Var f) args = case specVar subst f of Var f' -> return (Var f', mkCallUDs subst f' args) e' -> return (e', emptyUDs) -- I don't expect this! 
- go other args = specExpr subst other + go other _ = specExpr subst other ---------------- Lambda/case require dumping of usage details -------------------- specExpr subst e@(Lam _ _) = do @@ -691,7 +688,8 @@ specExpr subst (Let bind body) = do return (foldr Let body' binds', uds) -- Must apply the type substitution to coerceions -specNote subst note = note +specNote :: Subst -> Note -> Note +specNote _ note = note \end{code} %************************************************************************ @@ -707,47 +705,52 @@ specBind :: Subst -- Use this for RHSs -> SpecM ([CoreBind], -- New bindings UsageDetails) -- And info to pass upstream -specBind rhs_subst bind body_uds = do - (bind', bind_uds) <- specBindItself rhs_subst bind (calls body_uds) - let - bndrs = bindersOf bind - all_uds = zapCalls bndrs (body_uds `plusUDs` bind_uds) - -- It's important that the `plusUDs` is this way round, +specBind rhs_subst bind body_uds + = do { (bind', bind_uds) <- specBindItself rhs_subst bind (calls body_uds) + ; return (finishSpecBind bind' bind_uds body_uds) } + +finishSpecBind :: CoreBind -> UsageDetails -> UsageDetails -> ([CoreBind], UsageDetails) +finishSpecBind bind + (MkUD { dict_binds = rhs_dbs, calls = rhs_calls, ud_fvs = rhs_fvs }) + (MkUD { dict_binds = body_dbs, calls = body_calls, ud_fvs = body_fvs }) + | not (mkVarSet bndrs `intersectsVarSet` all_fvs) + -- Common case 1: the bound variables are not + -- mentioned in the dictionary bindings + = ([bind], MkUD { dict_binds = body_dbs `unionBags` rhs_dbs + -- It's important that the `unionBags` is this way round, -- because body_uds may bind dictionaries that are -- used in the calls passed to specDefn. So the - -- dictionary bindings in bind_uds may mention + -- dictionary bindings in rhs_uds may mention -- dictionaries bound in body_uds. 
- case splitUDs bndrs all_uds of + , calls = all_calls + , ud_fvs = all_fvs }) + + | case bind of { NonRec {} -> True; Rec {} -> False } + -- Common case 2: no specialisation happened, and binding + -- is non-recursive. But the binding may be + -- mentioned in body_dbs, so we should put it first + = ([], MkUD { dict_binds = rhs_dbs `unionBags` ((bind, b_fvs) `consBag` body_dbs) + , calls = all_calls + , ud_fvs = all_fvs `unionVarSet` b_fvs }) + + | otherwise -- General case: make a huge Rec (sigh) + = ([], MkUD { dict_binds = unitBag (Rec all_db_prs, all_db_fvs) + , calls = all_calls + , ud_fvs = all_fvs `unionVarSet` b_fvs }) + where + all_fvs = rhs_fvs `unionVarSet` body_fvs + all_calls = zapCalls bndrs (rhs_calls `unionCalls` body_calls) - (_, ([],[])) -- This binding doesn't bind anything needed - -- in the UDs, so put the binding here - -- This is the case for most non-dict bindings, except - -- for the few that are mentioned in a dict binding - -- that is floating upwards in body_uds - -> return ([bind'], all_uds) + bndrs = bindersOf bind + b_fvs = bind_fvs bind - (float_uds, (dict_binds, calls)) -- This binding is needed in the UDs, so float it out - -> return ([], float_uds `plusUDs` mkBigUD bind' dict_binds calls) - - --- A truly gruesome function -mkBigUD bind@(NonRec _ _) dbs calls - = -- Common case: non-recursive and no specialisations - -- (if there were any specialistions it would have been made recursive) - MkUD { dict_binds = listToBag (mkDB bind : dbs), - calls = listToCallDetails calls } - -mkBigUD bind dbs calls - = -- General case - MkUD { dict_binds = unitBag (mkDB (Rec (bind_prs bind ++ dbsToPairs dbs))), - -- Make a huge Rec - calls = listToCallDetails calls } - where - bind_prs (NonRec b r) = [(b,r)] - bind_prs (Rec prs) = prs + (all_db_prs, all_db_fvs) = add (bind, b_fvs) $ + foldrBag add ([], emptyVarSet) $ + rhs_dbs `unionBags` body_dbs + add (NonRec b r, b_fvs) (prs, fvs) = ((b,r) : prs, b_fvs `unionVarSet` fvs) + add (Rec b_prs, b_fvs) 
(prs, fvs) = (b_prs ++ prs, b_fvs `unionVarSet` fvs) - dbsToPairs [] = [] - dbsToPairs ((bind,_):dbs) = bind_prs bind ++ dbsToPairs dbs +specBindItself :: Subst -> CoreBind -> CallDetails -> SpecM (CoreBind, UsageDetails) -- specBindItself deals with the RHS, specialising it according -- to the calls found in the body (if any) @@ -822,7 +825,6 @@ specDefn subst calls (fn, rhs) (rhs_tyvars, rhs_ids, rhs_body) = collectTyAndValBinders rhs_inside rhs_dicts = take n_dicts rhs_ids - rhs_bndrs = rhs_tyvars ++ rhs_dicts body = mkLams (drop n_dicts rhs_ids) rhs_body -- Glue back on the non-dict lambdas @@ -836,7 +838,7 @@ specDefn subst calls (fn, rhs) -> SpecM ((Id,CoreExpr), -- Specialised definition UsageDetails, -- Usage details from specialised body CoreRule) -- Info for the Id's SpecEnv - spec_call (CallKey call_ts, (call_ds, call_fvs)) + spec_call (CallKey call_ts, (call_ds, _)) = ASSERT( call_ts `lengthIs` n_tyvars && call_ds `lengthIs` n_dicts ) do -- Calls are only recorded for properly-saturated applications @@ -862,7 +864,7 @@ specDefn subst calls (fn, rhs) ty_args = zipWithEqual "spec_call" mk_ty_arg rhs_tyvars call_ts where mk_ty_arg rhs_tyvar Nothing = Type (mkTyVarTy rhs_tyvar) - mk_ty_arg rhs_tyvar (Just ty) = Type ty + mk_ty_arg _ (Just ty) = Type ty rhs_subst = extendTvSubstList subst (spec_tyvars `zip` [ty | Just ty <- call_ts]) (rhs_subst', rhs_dicts') <- cloneBinders rhs_subst rhs_dicts @@ -1008,20 +1010,27 @@ data UsageDetails -- in ds1 `union` ds2, bindings in ds2 can depend on those in ds1 -- (Remember, Bags preserve order in GHC.) 
- calls :: !CallDetails + calls :: !CallDetails, + + ud_fvs :: !VarSet -- A superset of the variables mentioned in + -- either dict_binds or calls } +instance Outputable UsageDetails where + ppr (MkUD { dict_binds = dbs, calls = calls, ud_fvs = fvs }) + = ptext (sLit "MkUD") <+> braces (sep (punctuate comma + [ptext (sLit "binds") <+> equals <+> ppr dbs, + ptext (sLit "calls") <+> equals <+> ppr calls, + ptext (sLit "fvs") <+> equals <+> ppr fvs])) + type DictBind = (CoreBind, VarSet) -- The set is the free vars of the binding -- both tyvars and dicts type DictExpr = CoreExpr -emptyUDs = MkUD { dict_binds = emptyBag, calls = emptyFM } - -type ProtoUsageDetails = ([DictBind], - [(Id, CallKey, ([DictExpr], VarSet))] - ) +emptyUDs :: UsageDetails +emptyUDs = MkUD { dict_binds = emptyBag, calls = emptyFM, ud_fvs = emptyVarSet } ------------------------------------------------------------ type CallDetails = FiniteMap Id CallInfo @@ -1034,25 +1043,30 @@ type CallInfo = FiniteMap CallKey -- The list of types and dictionaries is guaranteed to -- match the type of f +instance Outputable CallKey where + ppr (CallKey ts) = ppr ts + -- Type isn't an instance of Ord, so that we can control which -- instance we use. That's tiresome here. 
Oh well instance Eq CallKey where - k1 == k2 = case k1 `compare` k2 of { EQ -> True; other -> False } + k1 == k2 = case k1 `compare` k2 of { EQ -> True; _ -> False } instance Ord CallKey where compare (CallKey k1) (CallKey k2) = cmpList cmp k1 k2 where - cmp Nothing Nothing = EQ - cmp Nothing (Just t2) = LT - cmp (Just t1) Nothing = GT + cmp Nothing Nothing = EQ + cmp Nothing (Just _) = LT + cmp (Just _) Nothing = GT cmp (Just t1) (Just t2) = tcCmpType t1 t2 unionCalls :: CallDetails -> CallDetails -> CallDetails unionCalls c1 c2 = plusFM_C plusFM c1 c2 -singleCall :: Id -> [Maybe Type] -> [DictExpr] -> CallDetails +singleCall :: Id -> [Maybe Type] -> [DictExpr] -> UsageDetails singleCall id tys dicts - = unitFM id (unitFM (CallKey tys) (dicts, call_fvs)) + = MkUD {dict_binds = emptyBag, + calls = unitFM id (unitFM (CallKey tys) (dicts, call_fvs)), + ud_fvs = call_fvs } where call_fvs = exprsFreeVars dicts `unionVarSet` tys_fvs tys_fvs = tyVarsOfTypes (catMaybes tys) @@ -1066,17 +1080,7 @@ singleCall id tys dicts -- -- We don't include the 'id' itself. -listToCallDetails calls - = foldr (unionCalls . mk_call) emptyFM calls - where - mk_call (id, tys, dicts_w_fvs) = unitFM id (unitFM tys dicts_w_fvs) - -- NB: the free vars of the call are provided - -callDetailsToList calls = [ (id,tys,dicts) - | (id,fm) <- fmToList calls, - (tys, dicts) <- fmToList fm - ] - +mkCallUDs :: Subst -> Id -> [CoreExpr] -> UsageDetails mkCallUDs subst f args | null theta || not (all isClassPred theta) @@ -1085,7 +1089,7 @@ mkCallUDs subst f args -- *don't* say what the value of the implicit param is! || not (spec_tys `lengthIs` n_tyvars) || not ( dicts `lengthIs` n_dicts) - || maybeToBool (lookupRule (\act -> True) (substInScope subst) emptyRuleBase f args) + || maybeToBool (lookupRule (\_act -> True) (substInScope subst) emptyRuleBase f args) -- There's already a rule covering this call. A typical case -- is where there's an explicit user-provided rule. 
Then -- we don't want to create a specialised version @@ -1093,9 +1097,7 @@ mkCallUDs subst f args = emptyUDs -- Not overloaded, or no specialisation wanted | otherwise - = MkUD {dict_binds = emptyBag, - calls = singleCall f spec_tys dicts - } + = singleCall f spec_tys dicts where (tyvars, theta, _) = tcSplitSigmaTy (idType f) constrained_tyvars = tyVarsOfTheta theta @@ -1111,26 +1113,31 @@ mkCallUDs subst f args ------------------------------------------------------------ plusUDs :: UsageDetails -> UsageDetails -> UsageDetails -plusUDs (MkUD {dict_binds = db1, calls = calls1}) - (MkUD {dict_binds = db2, calls = calls2}) - = MkUD {dict_binds = d, calls = c} +plusUDs (MkUD {dict_binds = db1, calls = calls1, ud_fvs = fvs1}) + (MkUD {dict_binds = db2, calls = calls2, ud_fvs = fvs2}) + = MkUD {dict_binds = d, calls = c, ud_fvs = fvs1 `unionVarSet` fvs2} where d = db1 `unionBags` db2 c = calls1 `unionCalls` calls2 +plusUDList :: [UsageDetails] -> UsageDetails plusUDList = foldr plusUDs emptyUDs -- zapCalls deletes calls to ids from uds -zapCalls ids uds = uds {calls = delListFromFM (calls uds) ids} +zapCalls :: [Id] -> CallDetails -> CallDetails +zapCalls ids calls = delListFromFM calls ids +mkDB :: CoreBind -> DictBind mkDB bind = (bind, bind_fvs bind) +bind_fvs :: CoreBind -> VarSet bind_fvs (NonRec bndr rhs) = pair_fvs (bndr,rhs) bind_fvs (Rec prs) = foldl delVarSet rhs_fvs bndrs where bndrs = map fst prs rhs_fvs = unionVarSets (map pair_fvs prs) +pair_fvs :: (Id, CoreExpr) -> VarSet pair_fvs (bndr, rhs) = exprFreeVars rhs `unionVarSet` idFreeVars bndr -- Don't forget variables mentioned in the -- rules of the bndr. C.f. 
OccAnal.addRuleUsage @@ -1138,8 +1145,14 @@ pair_fvs (bndr, rhs) = exprFreeVars rhs `unionVarSet` idFreeVars bndr -- type T a = Int -- x :: T a = 3 -addDictBind (dict,rhs) uds = uds { dict_binds = mkDB (NonRec dict rhs) `consBag` dict_binds uds } +addDictBind :: (Id,CoreExpr) -> UsageDetails -> UsageDetails +addDictBind (dict,rhs) uds + = uds { dict_binds = db `consBag` dict_binds uds + , ud_fvs = ud_fvs uds `unionVarSet` fvs } + where + db@(_, fvs) = mkDB (NonRec dict rhs) +dumpAllDictBinds :: UsageDetails -> [CoreBind] -> [CoreBind] dumpAllDictBinds (MkUD {dict_binds = dbs}) binds = foldrBag add binds dbs where @@ -1148,44 +1161,23 @@ dumpAllDictBinds (MkUD {dict_binds = dbs}) binds dumpUDs :: [CoreBndr] -> UsageDetails -> CoreExpr -> (UsageDetails, CoreExpr) -dumpUDs bndrs uds body - = (free_uds, foldr add_let body dict_binds) - where - (free_uds, (dict_binds, _)) = splitUDs bndrs uds - add_let (bind,_) body = Let bind body - -splitUDs :: [CoreBndr] - -> UsageDetails - -> (UsageDetails, -- These don't mention the binders - ProtoUsageDetails) -- These do - -splitUDs bndrs uds@(MkUD {dict_binds = orig_dbs, - calls = orig_calls}) - - = if isEmptyBag dump_dbs && null dump_calls then - -- Common case: binder doesn't affect floats - (uds, ([],[])) - - else - -- Binders bind some of the fvs of the floats - (MkUD {dict_binds = free_dbs, - calls = listToCallDetails free_calls}, - (bagToList dump_dbs, dump_calls) - ) - +dumpUDs bndrs (MkUD { dict_binds = orig_dbs + , calls = orig_calls + , ud_fvs = fvs}) body + = (MkUD { dict_binds = free_dbs + , calls = free_calls + , ud_fvs = fvs `minusVarSet` bndr_set}, -- This may delete fewer variables + foldrBag add_let body dump_dbs) -- than in principle possible where bndr_set = mkVarSet bndrs + add_let (bind,_) body = Let bind body - (free_dbs, dump_dbs, dump_idset) - = foldlBag dump_db (emptyBag, emptyBag, bndr_set) orig_dbs + (free_dbs, dump_dbs, dump_set) + = foldlBag dump_db (emptyBag, emptyBag, bndr_set) orig_dbs -- Important 
that it's foldl not foldr; -- we're accumulating the set of dumped ids in dump_set - -- Filter out any calls that mention things that are being dumped - orig_call_list = callDetailsToList orig_calls - (dump_calls, free_calls) = partition captured orig_call_list - captured (id,tys,(dicts, fvs)) = fvs `intersectsVarSet` dump_idset - || id `elemVarSet` dump_idset + free_calls = filterCalls dump_set orig_calls dump_db (free_dbs, dump_dbs, dump_idset) db@(bind, fvs) | dump_idset `intersectsVarSet` fvs -- Dump it @@ -1194,6 +1186,15 @@ splitUDs bndrs uds@(MkUD {dict_binds = orig_dbs, | otherwise -- Don't dump it = (free_dbs `snocBag` db, dump_dbs, dump_idset) + +filterCalls :: VarSet -> CallDetails -> CallDetails +-- Keep only the calls that mention none of the variables in bs (neither as the called Id nor among the call's free variables); filterFM retains the entries satisfying its predicate, hence the 'not's +filterCalls bs calls + = mapFM (\_ cs -> filter_calls cs) $ + filterFM (\k _ -> not (k `elemVarSet` bs)) calls + where + filter_calls :: CallInfo -> CallInfo + filter_calls = filterFM (\_ (_, fvs) -> not (fvs `intersectsVarSet` bs)) \end{code} @@ -1206,9 +1207,11 @@ splitUDs bndrs uds@(MkUD {dict_binds = orig_dbs, \begin{code} type SpecM a = UniqSM a +initSM :: UniqSupply -> SpecM a -> a initSM = initUs_ -mapAndCombineSM f [] = return ([], emptyUDs) +mapAndCombineSM :: (a -> SpecM (b, UsageDetails)) -> [a] -> SpecM ([b], UsageDetails) +mapAndCombineSM _ [] = return ([], emptyUDs) mapAndCombineSM f (x:xs) = do (y, uds1) <- f x (ys, uds2) <- mapAndCombineSM f xs return (y:ys, uds1 `plusUDs` uds2) @@ -1226,10 +1229,12 @@ cloneBindSM subst (Rec pairs) = do let (subst', bndrs') = cloneRecIdBndrs subst us (map fst pairs) return (subst', subst', Rec (bndrs' `zip` map snd pairs)) +cloneBinders :: Subst -> [CoreBndr] -> SpecM (Subst, [CoreBndr]) cloneBinders subst bndrs = do us <- getUniqueSupplyM return (cloneIdBndrs subst us bndrs) +newIdSM :: Id -> Type -> SpecM Id newIdSM old_id new_ty = do uniq <- getUniqueM let