X-Git-Url: http://git.megacz.com/?a=blobdiff_plain;f=compiler%2Fspecialise%2FSpecialise.lhs;h=055c85821a0eaceb7807ddd9de67539612fc62e1;hb=6246f5738bc482423e51342eb117a40539be790e;hp=67dc39cb23bdd8f63627a3b1f120814131e34578;hpb=fb236fbbea7f12293b030892c6dc866a96566200;p=ghc-hetmet.git diff --git a/compiler/specialise/Specialise.lhs b/compiler/specialise/Specialise.lhs index 67dc39c..055c858 100644 --- a/compiler/specialise/Specialise.lhs +++ b/compiler/specialise/Specialise.lhs @@ -4,7 +4,6 @@ \section[Specialise]{Stamping out overloading, and (optionally) polymorphism} \begin{code} -{-# OPTIONS -w #-} -- The above warning supression flag is a temporary kludge. -- While working on this module you are encouraged to remove it and fix -- any warnings in the module. See @@ -44,11 +43,8 @@ import MkId ( voidArgId, realWorldPrimId ) import FiniteMap import Maybes ( catMaybes, maybeToBool ) import ErrUtils ( dumpIfSet_dyn ) -import BasicTypes ( Activation( AlwaysActive ) ) import Bag -import List ( partition ) -import Util ( zipEqual, zipWithEqual, cmpList, lengthIs, - equalLength, lengthAtLeast, notNull ) +import Util import Outputable import FastString @@ -629,7 +625,7 @@ specExpr :: Subst -> CoreExpr -> SpecM (CoreExpr, UsageDetails) ---------------- First the easy cases -------------------- specExpr subst (Type ty) = return (Type (substTy subst ty), emptyUDs) specExpr subst (Var v) = return (specVar subst v, emptyUDs) -specExpr subst (Lit lit) = return (Lit lit, emptyUDs) +specExpr _ (Lit lit) = return (Lit lit, emptyUDs) specExpr subst (Cast e co) = do (e', uds) <- specExpr subst e return ((Cast e' (substTy subst co)), uds) @@ -639,7 +635,7 @@ specExpr subst (Note note body) = do ---------------- Applications might generate a call instance -------------------- -specExpr subst expr@(App fun arg) +specExpr subst expr@(App {}) = go expr [] where go (App fun arg) args = do (arg', uds_arg) <- specExpr subst arg @@ -649,7 +645,7 @@ specExpr subst expr@(App fun arg) go (Var f) args = case specVar subst f of Var f' -> return (Var f', mkCallUDs subst f' args) e' -> return (e', emptyUDs) -- I don't expect this! - go other args = specExpr subst other + go other _ = specExpr subst other ---------------- Lambda/case require dumping of usage details -------------------- specExpr subst e@(Lam _ _) = do @@ -672,7 +668,7 @@ specExpr subst (Case scrut case_bndr ty alts) = do spec_alt (con, args, rhs) = do (rhs', uds) <- specExpr subst_rhs rhs - let (uds', rhs'') = do dumpUDs args uds rhs' + let (uds', rhs'') = dumpUDs args uds rhs' return ((con, args', rhs''), uds') where (subst_rhs, args') = substBndrs subst_alt args @@ -692,7 +688,8 @@ specExpr subst (Let bind body) = do return (foldr Let body' binds', uds) -- Must apply the type substitution to coerceions -specNote subst note = note +specNote :: Subst -> Note -> Note +specNote _ note = note \end{code} %************************************************************************ @@ -708,47 +705,52 @@ specBind :: Subst -- Use this for RHSs -> SpecM ([CoreBind], -- New bindings UsageDetails) -- And info to pass upstream -specBind rhs_subst bind body_uds = do - (bind', bind_uds) <- specBindItself rhs_subst bind (calls body_uds) - let - bndrs = bindersOf bind - all_uds = zapCalls bndrs (body_uds `plusUDs` bind_uds) - -- It's important that the `plusUDs` is this way round, +specBind rhs_subst bind body_uds + = do { (bind', bind_uds) <- specBindItself rhs_subst bind (calls body_uds) + ; return (finishSpecBind bind' bind_uds body_uds) } + +finishSpecBind :: CoreBind -> UsageDetails -> UsageDetails -> ([CoreBind], UsageDetails) +finishSpecBind bind + (MkUD { dict_binds = rhs_dbs, calls = rhs_calls, ud_fvs = rhs_fvs }) + (MkUD { dict_binds = body_dbs, calls = body_calls, ud_fvs = body_fvs }) + | not (mkVarSet bndrs `intersectsVarSet` all_fvs) + -- Common case 1: the bound variables are not + -- mentioned in the dictionary bindings + = ([bind], MkUD { dict_binds = body_dbs `unionBags` rhs_dbs + -- It's important that the `unionBags` is this way round, -- because body_uds may bind dictionaries that are -- used in the calls passed to specDefn. So the - -- dictionary bindings in bind_uds may mention + -- dictionary bindings in rhs_uds may mention -- dictionaries bound in body_uds. - case splitUDs bndrs all_uds of + , calls = all_calls + , ud_fvs = all_fvs }) + + | case bind of { NonRec {} -> True; Rec {} -> False } + -- Common case 2: no specialisation happened, and binding + -- is non-recursive. But the binding may be + -- mentioned in body_dbs, so we should put it first + = ([], MkUD { dict_binds = rhs_dbs `unionBags` ((bind, b_fvs) `consBag` body_dbs) + , calls = all_calls + , ud_fvs = all_fvs `unionVarSet` b_fvs }) + + | otherwise -- General case: make a huge Rec (sigh) + = ([], MkUD { dict_binds = unitBag (Rec all_db_prs, all_db_fvs) + , calls = all_calls + , ud_fvs = all_fvs `unionVarSet` b_fvs }) + where + all_fvs = rhs_fvs `unionVarSet` body_fvs + all_calls = zapCalls bndrs (rhs_calls `unionCalls` body_calls) - (_, ([],[])) -- This binding doesn't bind anything needed - -- in the UDs, so put the binding here - -- This is the case for most non-dict bindings, except - -- for the few that are mentioned in a dict binding - -- that is floating upwards in body_uds - -> return ([bind'], all_uds) + bndrs = bindersOf bind + b_fvs = bind_fvs bind - (float_uds, (dict_binds, calls)) -- This binding is needed in the UDs, so float it out - -> return ([], float_uds `plusUDs` mkBigUD bind' dict_binds calls) - - --- A truly gruesome function -mkBigUD bind@(NonRec _ _) dbs calls - = -- Common case: non-recursive and no specialisations - -- (if there were any specialistions it would have been made recursive) - MkUD { dict_binds = listToBag (mkDB bind : dbs), - calls = listToCallDetails calls } - -mkBigUD bind dbs calls - = -- General case - MkUD { dict_binds = unitBag (mkDB (Rec (bind_prs bind ++ dbsToPairs dbs))), - -- Make a huge Rec - calls = listToCallDetails calls } - where - bind_prs (NonRec b r) = [(b,r)] - bind_prs (Rec prs) = prs + (all_db_prs, all_db_fvs) = add (bind, b_fvs) $ + foldrBag add ([], emptyVarSet) $ + rhs_dbs `unionBags` body_dbs + add (NonRec b r, b_fvs) (prs, fvs) = ((b,r) : prs, b_fvs `unionVarSet` fvs) + add (Rec b_prs, b_fvs) (prs, fvs) = (b_prs ++ prs, b_fvs `unionVarSet` fvs) - dbsToPairs [] = [] - dbsToPairs ((bind,_):dbs) = bind_prs bind ++ dbsToPairs dbs +specBindItself :: Subst -> CoreBind -> CallDetails -> SpecM (CoreBind, UsageDetails) -- specBindItself deals with the RHS, specialising it according -- to the calls found in the body (if any) @@ -805,7 +807,7 @@ specDefn subst calls (fn, rhs) rhs_uds `plusUDs` plusUDList spec_uds) | otherwise -- No calls or RHS doesn't fit our preconceptions - = WARN( notNull calls_for_me, ptext SLIT("Missed specialisation opportunity for") <+> ppr fn ) + = WARN( notNull calls_for_me, ptext (sLit "Missed specialisation opportunity for") <+> ppr fn ) -- Note [Specialisation shape] (do { (rhs', rhs_uds) <- specExpr subst rhs ; return ((fn, rhs'), [], rhs_uds) }) @@ -823,7 +825,6 @@ specDefn subst calls (fn, rhs) (rhs_tyvars, rhs_ids, rhs_body) = collectTyAndValBinders rhs_inside rhs_dicts = take n_dicts rhs_ids - rhs_bndrs = rhs_tyvars ++ rhs_dicts body = mkLams (drop n_dicts rhs_ids) rhs_body -- Glue back on the non-dict lambdas @@ -837,7 +838,7 @@ specDefn subst calls (fn, rhs) -> SpecM ((Id,CoreExpr), -- Specialised definition UsageDetails, -- Usage details from specialised body CoreRule) -- Info for the Id's SpecEnv - spec_call (CallKey call_ts, (call_ds, call_fvs)) + spec_call (CallKey call_ts, (call_ds, _)) = ASSERT( call_ts `lengthIs` n_tyvars && call_ds `lengthIs` n_dicts ) do -- Calls are only recorded for properly-saturated applications @@ -863,7 +864,7 @@ specDefn subst calls (fn, rhs) ty_args = zipWithEqual "spec_call" mk_ty_arg rhs_tyvars call_ts where mk_ty_arg rhs_tyvar Nothing = Type (mkTyVarTy rhs_tyvar) - mk_ty_arg rhs_tyvar (Just ty) = Type ty + mk_ty_arg _ (Just ty) = Type ty rhs_subst = extendTvSubstList subst (spec_tyvars `zip` [ty | Just ty <- call_ts]) (rhs_subst', rhs_dicts') <- cloneBinders rhs_subst rhs_dicts @@ -900,14 +901,13 @@ specDefn subst calls (fn, rhs) where my_zipEqual doc xs ys -#ifdef DEBUG - | not (equalLength xs ys) = pprPanic "my_zipEqual" (vcat + | debugIsOn && not (equalLength xs ys) + = pprPanic "my_zipEqual" (vcat [ ppr xs, ppr ys , ppr fn <+> ppr call_ts , ppr (idType fn), ppr theta , ppr n_dicts, ppr rhs_dicts , ppr rhs]) -#endif | otherwise = zipEqual doc xs ys \end{code} @@ -1010,20 +1010,27 @@ data UsageDetails -- in ds1 `union` ds2, bindings in ds2 can depend on those in ds1 -- (Remember, Bags preserve order in GHC.) - calls :: !CallDetails + calls :: !CallDetails, + + ud_fvs :: !VarSet -- A superset of the variables mentioned in + -- either dict_binds or calls } +instance Outputable UsageDetails where + ppr (MkUD { dict_binds = dbs, calls = calls, ud_fvs = fvs }) + = ptext (sLit "MkUD") <+> braces (sep (punctuate comma + [ptext (sLit "binds") <+> equals <+> ppr dbs, + ptext (sLit "calls") <+> equals <+> ppr calls, + ptext (sLit "fvs") <+> equals <+> ppr fvs])) + type DictBind = (CoreBind, VarSet) -- The set is the free vars of the binding -- both tyvars and dicts type DictExpr = CoreExpr -emptyUDs = MkUD { dict_binds = emptyBag, calls = emptyFM } - -type ProtoUsageDetails = ([DictBind], - [(Id, CallKey, ([DictExpr], VarSet))] - ) +emptyUDs :: UsageDetails +emptyUDs = MkUD { dict_binds = emptyBag, calls = emptyFM, ud_fvs = emptyVarSet } ------------------------------------------------------------ type CallDetails = FiniteMap Id CallInfo @@ -1036,25 +1043,30 @@ type CallInfo = FiniteMap CallKey -- The list of types and dictionaries is guaranteed to -- match the type of f +instance Outputable CallKey where + ppr (CallKey ts) = ppr ts + -- Type isn't an instance of Ord, so that we can control which -- instance we use. That's tiresome here. Oh well instance Eq CallKey where - k1 == k2 = case k1 `compare` k2 of { EQ -> True; other -> False } + k1 == k2 = case k1 `compare` k2 of { EQ -> True; _ -> False } instance Ord CallKey where compare (CallKey k1) (CallKey k2) = cmpList cmp k1 k2 where - cmp Nothing Nothing = EQ - cmp Nothing (Just t2) = LT - cmp (Just t1) Nothing = GT + cmp Nothing Nothing = EQ + cmp Nothing (Just _) = LT + cmp (Just _) Nothing = GT cmp (Just t1) (Just t2) = tcCmpType t1 t2 unionCalls :: CallDetails -> CallDetails -> CallDetails unionCalls c1 c2 = plusFM_C plusFM c1 c2 -singleCall :: Id -> [Maybe Type] -> [DictExpr] -> CallDetails +singleCall :: Id -> [Maybe Type] -> [DictExpr] -> UsageDetails singleCall id tys dicts - = unitFM id (unitFM (CallKey tys) (dicts, call_fvs)) + = MkUD {dict_binds = emptyBag, + calls = unitFM id (unitFM (CallKey tys) (dicts, call_fvs)), + ud_fvs = call_fvs } where call_fvs = exprsFreeVars dicts `unionVarSet` tys_fvs tys_fvs = tyVarsOfTypes (catMaybes tys) @@ -1068,17 +1080,7 @@ singleCall id tys dicts -- -- We don't include the 'id' itself. -listToCallDetails calls - = foldr (unionCalls . mk_call) emptyFM calls - where - mk_call (id, tys, dicts_w_fvs) = unitFM id (unitFM tys dicts_w_fvs) - -- NB: the free vars of the call are provided - -callDetailsToList calls = [ (id,tys,dicts) - | (id,fm) <- fmToList calls, - (tys, dicts) <- fmToList fm - ] - +mkCallUDs :: Subst -> Id -> [CoreExpr] -> UsageDetails mkCallUDs subst f args | null theta || not (all isClassPred theta) @@ -1087,7 +1089,7 @@ mkCallUDs subst f args -- *don't* say what the value of the implicit param is! || not (spec_tys `lengthIs` n_tyvars) || not ( dicts `lengthIs` n_dicts) - || maybeToBool (lookupRule (\act -> True) (substInScope subst) emptyRuleBase f args) + || maybeToBool (lookupRule (\_act -> True) (substInScope subst) emptyRuleBase f args) -- There's already a rule covering this call. A typical case -- is where there's an explicit user-provided rule. Then -- we don't want to create a specialised version @@ -1095,9 +1097,7 @@ mkCallUDs subst f args = emptyUDs -- Not overloaded, or no specialisation wanted | otherwise - = MkUD {dict_binds = emptyBag, - calls = singleCall f spec_tys dicts - } + = singleCall f spec_tys dicts where (tyvars, theta, _) = tcSplitSigmaTy (idType f) constrained_tyvars = tyVarsOfTheta theta @@ -1113,26 +1113,31 @@ mkCallUDs subst f args ------------------------------------------------------------ plusUDs :: UsageDetails -> UsageDetails -> UsageDetails -plusUDs (MkUD {dict_binds = db1, calls = calls1}) - (MkUD {dict_binds = db2, calls = calls2}) - = MkUD {dict_binds = d, calls = c} +plusUDs (MkUD {dict_binds = db1, calls = calls1, ud_fvs = fvs1}) + (MkUD {dict_binds = db2, calls = calls2, ud_fvs = fvs2}) + = MkUD {dict_binds = d, calls = c, ud_fvs = fvs1 `unionVarSet` fvs2} where d = db1 `unionBags` db2 c = calls1 `unionCalls` calls2 +plusUDList :: [UsageDetails] -> UsageDetails plusUDList = foldr plusUDs emptyUDs -- zapCalls deletes calls to ids from uds -zapCalls ids uds = uds {calls = delListFromFM (calls uds) ids} +zapCalls :: [Id] -> CallDetails -> CallDetails +zapCalls ids calls = delListFromFM calls ids +mkDB :: CoreBind -> DictBind mkDB bind = (bind, bind_fvs bind) +bind_fvs :: CoreBind -> VarSet bind_fvs (NonRec bndr rhs) = pair_fvs (bndr,rhs) bind_fvs (Rec prs) = foldl delVarSet rhs_fvs bndrs where bndrs = map fst prs rhs_fvs = unionVarSets (map pair_fvs prs) +pair_fvs :: (Id, CoreExpr) -> VarSet pair_fvs (bndr, rhs) = exprFreeVars rhs `unionVarSet` idFreeVars bndr -- Don't forget variables mentioned in the -- rules of the bndr. C.f. OccAnal.addRuleUsage @@ -1140,8 +1145,14 @@ pair_fvs (bndr, rhs) = exprFreeVars rhs `unionVarSet` idFreeVars bndr -- type T a = Int -- x :: T a = 3 -addDictBind (dict,rhs) uds = uds { dict_binds = mkDB (NonRec dict rhs) `consBag` dict_binds uds } +addDictBind :: (Id,CoreExpr) -> UsageDetails -> UsageDetails +addDictBind (dict,rhs) uds + = uds { dict_binds = db `consBag` dict_binds uds + , ud_fvs = ud_fvs uds `unionVarSet` fvs } + where + db@(_, fvs) = mkDB (NonRec dict rhs) +dumpAllDictBinds :: UsageDetails -> [CoreBind] -> [CoreBind] dumpAllDictBinds (MkUD {dict_binds = dbs}) binds = foldrBag add binds dbs where @@ -1150,44 +1161,23 @@ dumpAllDictBinds (MkUD {dict_binds = dbs}) binds dumpUDs :: [CoreBndr] -> UsageDetails -> CoreExpr -> (UsageDetails, CoreExpr) -dumpUDs bndrs uds body - = (free_uds, foldr add_let body dict_binds) - where - (free_uds, (dict_binds, _)) = splitUDs bndrs uds - add_let (bind,_) body = Let bind body - -splitUDs :: [CoreBndr] - -> UsageDetails - -> (UsageDetails, -- These don't mention the binders - ProtoUsageDetails) -- These do - -splitUDs bndrs uds@(MkUD {dict_binds = orig_dbs, - calls = orig_calls}) - - = if isEmptyBag dump_dbs && null dump_calls then - -- Common case: binder doesn't affect floats - (uds, ([],[])) - - else - -- Binders bind some of the fvs of the floats - (MkUD {dict_binds = free_dbs, - calls = listToCallDetails free_calls}, - (bagToList dump_dbs, dump_calls) - ) - +dumpUDs bndrs (MkUD { dict_binds = orig_dbs + , calls = orig_calls + , ud_fvs = fvs}) body + = (MkUD { dict_binds = free_dbs + , calls = free_calls + , ud_fvs = fvs `minusVarSet` bndr_set}, -- This may delete fewer variables + foldrBag add_let body dump_dbs) -- than in priciple possible where bndr_set = mkVarSet bndrs + add_let (bind,_) body = Let bind body - (free_dbs, dump_dbs, dump_idset) - = foldlBag dump_db (emptyBag, emptyBag, bndr_set) orig_dbs + (free_dbs, dump_dbs, dump_set) + = foldlBag dump_db (emptyBag, emptyBag, bndr_set) orig_dbs -- Important that it's foldl not foldr; -- we're accumulating the set of dumped ids in dump_set - -- Filter out any calls that mention things that are being dumped - orig_call_list = callDetailsToList orig_calls - (dump_calls, free_calls) = partition captured orig_call_list - captured (id,tys,(dicts, fvs)) = fvs `intersectsVarSet` dump_idset - || id `elemVarSet` dump_idset + free_calls = filterCalls dump_set orig_calls dump_db (free_dbs, dump_dbs, dump_idset) db@(bind, fvs) | dump_idset `intersectsVarSet` fvs -- Dump it @@ -1196,6 +1186,15 @@ splitUDs bndrs uds@(MkUD {dict_binds = orig_dbs, | otherwise -- Don't dump it = (free_dbs `snocBag` db, dump_dbs, dump_idset) + +filterCalls :: VarSet -> CallDetails -> CallDetails +-- Remove any calls that mention the variables +filterCalls bs calls + = mapFM (\_ cs -> filter_calls cs) $ + filterFM (\k _ -> k `elemVarSet` bs) calls + where + filter_calls :: CallInfo -> CallInfo + filter_calls = filterFM (\_ (_, fvs) -> fvs `intersectsVarSet` bs) \end{code} @@ -1208,9 +1207,11 @@ splitUDs bndrs uds@(MkUD {dict_binds = orig_dbs, \begin{code} type SpecM a = UniqSM a +initSM :: UniqSupply -> SpecM a -> a initSM = initUs_ -mapAndCombineSM f [] = return ([], emptyUDs) +mapAndCombineSM :: (a -> SpecM (b, UsageDetails)) -> [a] -> SpecM ([b], UsageDetails) +mapAndCombineSM _ [] = return ([], emptyUDs) mapAndCombineSM f (x:xs) = do (y, uds1) <- f x (ys, uds2) <- mapAndCombineSM f xs return (y:ys, uds1 `plusUDs` uds2) @@ -1220,7 +1221,7 @@ cloneBindSM :: Subst -> CoreBind -> SpecM (Subst, Subst, CoreBind) -- Return the substitution to use for RHSs, and the one to use for the body cloneBindSM subst (NonRec bndr rhs) = do us <- getUniqueSupplyM - let (subst', bndr') = do cloneIdBndr subst us bndr + let (subst', bndr') = cloneIdBndr subst us bndr return (subst, subst', NonRec bndr' rhs) cloneBindSM subst (Rec pairs) = do @@ -1228,10 +1229,12 @@ cloneBindSM subst (Rec pairs) = do let (subst', bndrs') = cloneRecIdBndrs subst us (map fst pairs) return (subst', subst', Rec (bndrs' `zip` map snd pairs)) +cloneBinders :: Subst -> [CoreBndr] -> SpecM (Subst, [CoreBndr]) cloneBinders subst bndrs = do us <- getUniqueSupplyM return (cloneIdBndrs subst us bndrs) +newIdSM :: Id -> Type -> SpecM Id newIdSM old_id new_ty = do uniq <- getUniqueM let