X-Git-Url: http://git.megacz.com/?p=ghc-hetmet.git;a=blobdiff_plain;f=compiler%2Fspecialise%2FSpecialise.lhs;h=f6f85a114099ce6c91ae90cdc4b5aee8286a51ac;hp=2d0b383c1aeb1790e5dc01bcfbf56de8e3ff128e;hb=707ea5881703d680155aab268bdbf7edc113e3b1;hpb=e95ee1f718c6915c478005aad8af81705357d6ab diff --git a/compiler/specialise/Specialise.lhs b/compiler/specialise/Specialise.lhs index 2d0b383..f6f85a1 100644 --- a/compiler/specialise/Specialise.lhs +++ b/compiler/specialise/Specialise.lhs @@ -10,19 +10,21 @@ module Specialise ( specProgram ) where import Id import TcType +import CoreMonad import CoreSubst -import CoreUnfold ( mkUnfolding, mkInlineRule ) +import CoreUnfold import VarSet import VarEnv import CoreSyn import Rules import CoreUtils ( exprIsTrivial, applyTypeToArgs, mkPiTypes ) import CoreFVs ( exprFreeVars, exprsFreeVars, idFreeVars ) -import UniqSupply ( UniqSupply, UniqSM, initUs_, MonadUnique(..) ) +import UniqSupply ( UniqSM, initUs_, MonadUnique(..) ) import Name import MkId ( voidArgId, realWorldPrimId ) import Maybes ( catMaybes, isJust ) -import BasicTypes ( isNeverActive, inlinePragmaActivation ) +import BasicTypes +import HscTypes import Bag import Util import Outputable @@ -558,24 +560,98 @@ Hence, the invariant is this: %************************************************************************ \begin{code} -specProgram :: UniqSupply -> [CoreBind] -> [CoreBind] -specProgram us binds = initSM us $ - do { (binds', uds') <- go binds - ; return (wrapDictBinds (ud_binds uds') binds') } +specProgram :: ModGuts -> CoreM ModGuts +specProgram guts + = do { hpt_rules <- getRuleBase + ; let local_rules = mg_rules guts + rule_base = extendRuleBaseList hpt_rules (mg_rules guts) + + -- Specialise the bindings of this module + ; (binds', uds) <- runSpecM (go (mg_binds guts)) + + -- Specialise imported functions + ; (new_rules, spec_binds) <- specImports emptyVarSet rule_base uds + + ; return (guts { mg_binds = spec_binds ++ binds' + , mg_rules = local_rules ++ new_rules 
}) } where -- We need to start with a Subst that knows all the things -- that are in scope, so that the substitution engine doesn't -- accidentally re-use a unique that's already in use -- Easiest thing is to do it all at once, as if all the top-level -- decls were mutually recursive - top_subst = mkEmptySubst (mkInScopeSet (mkVarSet (bindersOfBinds binds))) + top_subst = mkEmptySubst $ mkInScopeSet $ mkVarSet $ + bindersOfBinds $ mg_binds guts go [] = return ([], emptyUDs) go (bind:binds) = do (binds', uds) <- go binds (bind', uds') <- specBind top_subst bind uds return (bind' ++ binds', uds') + +specImports :: VarSet -- Don't specialise these ones + -- See Note [Avoiding recursive specialisation] + -> RuleBase -- Rules from this module and the home package + -- (but not external packages, which can change) + -> UsageDetails -- Calls for imported things, and floating bindings + -> CoreM ( [CoreRule] -- New rules + , [CoreBind] ) -- Specialised bindings and floating bindings +specImports done rb uds + = do { let import_calls = varEnvElts (ud_calls uds) + ; (rules, spec_binds) <- go rb import_calls + ; return (rules, wrapDictBinds (ud_binds uds) spec_binds) } + where + go _ [] = return ([], []) + go rb (CIS fn calls_for_fn : other_calls) + = do { (rules1, spec_binds1) <- specImport done rb fn (Map.toList calls_for_fn) + ; (rules2, spec_binds2) <- go (extendRuleBaseList rb rules1) other_calls + ; return (rules1 ++ rules2, spec_binds1 ++ spec_binds2) } + +specImport :: VarSet -- Don't specialise these + -- See Note [Avoiding recursive specialisation] + -> RuleBase -- Rules from this module + -> Id -> [CallInfo] -- Imported function and calls for it + -> CoreM ( [CoreRule] -- New rules + , [CoreBind] ) -- Specialised bindings +specImport done rb fn calls_for_fn + | not (fn `elemVarSet` done) + , isInlinablePragma (idInlinePragma fn) + , Just rhs <- maybeUnfoldingTemplate (realIdUnfolding fn) + = do { -- Get rules from the external package state + -- We keep doing this 
in case we "page-fault in" + -- more rules as we go along + ; hsc_env <- getHscEnv + ; eps <- liftIO $ hscEPS hsc_env + ; let full_rb = unionRuleBase rb (eps_rule_base eps) + rules_for_fn = getRules full_rb fn + + ; (rules1, spec_pairs, uds) <- runSpecM $ + specCalls emptySubst rules_for_fn calls_for_fn fn rhs + ; let spec_binds1 = [NonRec b r | (b,r) <- spec_pairs] + -- After the rules kick in we may get recursion, but + -- we rely on a global GlomBinds to sort that out later + + -- Now specialise any cascaded calls + ; (rules2, spec_binds2) <- specImports (extendVarSet done fn) + (extendRuleBaseList rb rules1) + uds + + ; return (rules2 ++ rules1, spec_binds2 ++ spec_binds1) } + + | otherwise + = WARN( True, ptext (sLit "specImport discard") <+> ppr fn <+> ppr calls_for_fn ) + return ([], []) \end{code} +Avoiding recursive specialisation +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +When we specialise 'f' we may find new overloaded calls to 'g', 'h' in +'f's RHS. So we want to specialise g,h. But we don't want to +specialise f any more! It's possible that f's RHS might have a +recursive yet-more-specialised call, so we'd diverge in that case. +And if the call is to the same type, one specialisation is enough. +Avoiding this recursive specialisation loop is the reason for the +'done' VarSet passed to specImports and specImport. 
+ %************************************************************************ %* * \subsubsection{@specExpr@: the main function} @@ -706,7 +782,7 @@ specCase subst scrut' case_bndr [(con, args, rhs)] loc = getSrcSpan name add_unf sc_flt sc_rhs -- Sole purpose: make sc_flt respond True to interestingDictId - = setIdUnfolding sc_flt (mkUnfolding False False sc_rhs) + = setIdUnfolding sc_flt (mkSimpleUnfolding sc_rhs) arg_set = mkVarSet args' is_flt_sc_arg var = isId var @@ -763,7 +839,7 @@ to substitute sc -> sc_flt in the RHS %************************************************************************ %* * -\subsubsection{Dealing with a binding} + Dealing with a binding %* * %************************************************************************ @@ -863,6 +939,34 @@ specDefn :: Subst UsageDetails) -- Stuff to fling upwards from the specialised versions specDefn subst body_uds fn rhs + = do { let (body_uds_without_me, calls_for_me) = callsForMe fn body_uds + rules_for_me = idCoreRules fn + ; (rules, spec_defns, spec_uds) <- specCalls subst rules_for_me + calls_for_me fn rhs + ; return ( fn `addIdSpecialisations` rules + , spec_defns + , body_uds_without_me `plusUDs` spec_uds) } + -- It's important that the `plusUDs` is this way + -- round, because body_uds_without_me may bind + -- dictionaries that are used in calls_for_me passed + -- to specDefn. So the dictionary bindings in + -- spec_uds may mention dictionaries bound in + -- body_uds_without_me + +--------------------------- +specCalls :: Subst + -> [CoreRule] -- Existing RULES for the fn + -> [CallInfo] + -> Id -> CoreExpr + -> SpecM ([CoreRule], -- New RULES for the fn + [(Id,CoreExpr)], -- Extra, specialised bindings + UsageDetails) -- New usage details from the specialised RHSs + +-- This function checks existing rules, and does not create +-- duplicate ones. So the caller does not need to do this filtering. 
+-- See 'already_covered' + +specCalls subst rules_for_me calls_for_me fn rhs -- The first case is the interesting one | rhs_tyvars `lengthIs` n_tyvars -- Rhs of fn's defn has right number of big lambdas && rhs_ids `lengthAtLeast` n_dicts -- and enough dict args @@ -875,26 +979,16 @@ specDefn subst body_uds fn rhs -- See Note [Inline specialisation] for why we do not -- switch off specialisation for inline functions - = -- pprTrace "specDefn: some" (ppr fn $$ ppr calls_for_me) $ - do { -- Make a specialised version for each call in calls_for_me - stuff <- mapM spec_call calls_for_me + = -- pprTrace "specDefn: some" (ppr fn $$ ppr calls_for_me $$ ppr rules_for_me) $ + do { stuff <- mapM spec_call calls_for_me ; let (spec_defns, spec_uds, spec_rules) = unzip3 (catMaybes stuff) - fn' = addIdSpecialisations fn spec_rules - final_uds = body_uds_without_me `plusUDs` plusUDList spec_uds - -- It's important that the `plusUDs` is this way - -- round, because body_uds_without_me may bind - -- dictionaries that are used in calls_for_me passed - -- to specDefn. 
So the dictionary bindings in - -- spec_uds may mention dictionaries bound in - -- body_uds_without_me - - ; return (fn', spec_defns, final_uds) } + ; return (spec_rules, spec_defns, plusUDList spec_uds) } | otherwise -- No calls or RHS doesn't fit our preconceptions = WARN( notNull calls_for_me, ptext (sLit "Missed specialisation opportunity for") <+> ppr fn ) -- Note [Specialisation shape] -- pprTrace "specDefn: none" (ppr fn $$ ppr calls_for_me) $ - return (fn, [], body_uds_without_me) + return ([], [], emptyUDs) where fn_type = idType fn @@ -903,21 +997,17 @@ specDefn subst body_uds fn rhs (tyvars, theta, _) = tcSplitSigmaTy fn_type n_tyvars = length tyvars n_dicts = length theta - inl_act = inlinePragmaActivation (idInlinePragma fn) + inl_prag = idInlinePragma fn + inl_act = inlinePragmaActivation inl_prag + is_local = isLocalId fn -- Figure out whether the function has an INLINE pragma -- See Note [Inline specialisations] - fn_has_inline_rule :: Maybe Bool -- Derive sat-flag from existing thing - fn_has_inline_rule = case isInlineRule_maybe fn_unf of - Just (_,sat) -> Just sat - Nothing -> Nothing spec_arity = unfoldingArity fn_unf - n_dicts -- Arity of the *specialised* inline rule (rhs_tyvars, rhs_ids, rhs_body) = collectTyAndValBinders rhs - (body_uds_without_me, calls_for_me) = callsForMe fn body_uds - rhs_dict_ids = take n_dicts rhs_ids body = mkLams (drop n_dicts rhs_ids) rhs_body -- Glue back on the non-dict lambdas @@ -926,7 +1016,7 @@ specDefn subst body_uds fn rhs already_covered args -- Note [Specialisations already covered] = isJust (lookupRule (const True) realIdUnfolding (substInScope subst) - fn args (idCoreRules fn)) + fn args rules_for_me) mk_ty_args :: [Maybe Type] -> [CoreExpr] mk_ty_args call_ts = zipWithEqual "spec_call" mk_ty_arg rhs_tyvars call_ts @@ -990,8 +1080,8 @@ specDefn subst body_uds fn rhs -- The rule to put in the function's specialisation is: -- forall b, d1',d2'. 
f t1 b t3 d1' d2' = f1 b rule_name = mkFastString ("SPEC " ++ showSDoc (ppr fn <+> ppr spec_ty_args)) - spec_env_rule = mkLocalRule - rule_name + spec_env_rule = mkRule True {- Auto generated -} is_local + rule_name inl_act -- Note [Auto-specialisation and RULES] (idName fn) (poly_tyvars ++ inst_dict_ids) @@ -1001,25 +1091,23 @@ specDefn subst body_uds fn rhs -- Add the { d1' = dx1; d2' = dx2 } usage stuff final_uds = foldr consDictBind rhs_uds dx_binds + -- Add an InlineRule if the parent has one + -- See Note [Inline specialisations] + spec_unf + = case inlinePragmaSpec inl_prag of + Inline -> mkInlineUnfolding (Just spec_arity) spec_rhs + Inlinable -> mkInlinableUnfolding spec_rhs + _ -> NoUnfolding + -- Adding arity information just propagates it a bit faster -- See Note [Arity decrease] in Simplify -- Copy InlinePragma information from the parent Id. -- So if f has INLINE[1] so does spec_f spec_f_w_arity = spec_f `setIdArity` max 0 (fn_arity - n_dicts) - `setInlineActivation` inl_act + `setInlinePragma` inl_prag + `setIdUnfolding` spec_unf - -- Add an InlineRule if the parent has one - -- See Note [Inline specialisations] - final_spec_f - | Just sat <- fn_has_inline_rule - = let - mb_spec_arity = if sat then Just spec_arity else Nothing - in - spec_f_w_arity `setIdUnfolding` mkInlineRule spec_rhs mb_spec_arity - | otherwise - = spec_f_w_arity - - ; return (Just ((final_spec_f, spec_rhs), final_uds, spec_env_rule)) } } + ; return (Just ((spec_f_w_arity, spec_rhs), final_uds, spec_env_rule)) } } where my_zipEqual xs ys zs | debugIsOn && not (equalLength xs ys && equalLength ys zs) @@ -1048,7 +1136,7 @@ bindAuxiliaryDicts subst triples = go subst [] triples | otherwise = go subst_w_unf (NonRec dx_id dx : binds) pairs where - dx_id1 = dx_id `setIdUnfolding` mkUnfolding False False dx + dx_id1 = dx_id `setIdUnfolding` mkSimpleUnfolding dx subst_w_unf = extendIdSubst subst d (Var dx_id1) -- Important! 
We're going to substitute dx_id1 for d -- and we want it to look "interesting", else we won't gather *any* @@ -1149,7 +1237,7 @@ group. (In this case it'll unravel a short moment later.) Conclusion: we catch the nasty case using filter_dfuns in -callsForMe To be honest I'm not 100% certain that this is 100% +callsForMe. To be honest I'm not 100% certain that this is 100% right, but it works. Sigh. @@ -1328,13 +1416,17 @@ newtype CallKey = CallKey [Maybe Type] -- Nothing => unconstrained type argu -- -- The list of types and dictionaries is guaranteed to -- match the type of f -type CallInfoSet = Map CallKey ([DictExpr], VarSet) +data CallInfoSet = CIS Id (Map CallKey ([DictExpr], VarSet)) -- Range is dict args and the vars of the whole -- call (including tyvars) -- [*not* include the main id itself, of course] type CallInfo = (CallKey, ([DictExpr], VarSet)) +instance Outputable CallInfoSet where + ppr (CIS fn map) = hang (ptext (sLit "CIS") <+> ppr fn) + 2 (ppr map) + instance Outputable CallKey where ppr (CallKey ts) = ppr ts @@ -1352,22 +1444,23 @@ instance Ord CallKey where cmp (Just t1) (Just t2) = tcCmpType t1 t2 unionCalls :: CallDetails -> CallDetails -> CallDetails -unionCalls c1 c2 = plusVarEnv_C Map.union c1 c2 +unionCalls c1 c2 = plusVarEnv_C unionCallInfoSet c1 c2 --- plusCalls :: UsageDetails -> CallDetails -> UsageDetails --- plusCalls uds call_ds = uds { ud_calls = ud_calls uds `unionCalls` call_ds } +unionCallInfoSet :: CallInfoSet -> CallInfoSet -> CallInfoSet +unionCallInfoSet (CIS f calls1) (CIS _ calls2) = CIS f (calls1 `Map.union` calls2) callDetailsFVs :: CallDetails -> VarSet callDetailsFVs calls = foldVarEnv (unionVarSet . 
callInfoFVs) emptyVarSet calls callInfoFVs :: CallInfoSet -> VarSet -callInfoFVs call_info = Map.foldRightWithKey (\_ (_,fv) vs -> unionVarSet fv vs) emptyVarSet call_info +callInfoFVs (CIS _ call_info) = Map.foldRight (\(_,fv) vs -> unionVarSet fv vs) emptyVarSet call_info ------------------------------------------------------------ singleCall :: Id -> [Maybe Type] -> [DictExpr] -> UsageDetails singleCall id tys dicts = MkUD {ud_binds = emptyBag, - ud_calls = unitVarEnv id (Map.singleton (CallKey tys) (dicts, call_fvs)) } + ud_calls = unitVarEnv id $ CIS id $ + Map.singleton (CallKey tys) (dicts, call_fvs) } where call_fvs = exprsFreeVars dicts `unionVarSet` tys_fvs tys_fvs = tyVarsOfTypes (catMaybes tys) @@ -1383,8 +1476,8 @@ singleCall id tys dicts mkCallUDs :: Id -> [CoreExpr] -> UsageDetails mkCallUDs f args - | not (isLocalId f) -- Imported from elsewhere - || null theta -- Not overloaded + | not (want_calls_for f) -- Imported from elsewhere + || null theta -- Not overloaded || not (all isClassPred theta) -- Only specialise if all overloading is on class params. 
-- In ptic, with implicit params, the type args @@ -1411,6 +1504,8 @@ mkCallUDs f args mk_spec_ty tyvar ty | tyvar `elemVarSet` constrained_tyvars = Just ty | otherwise = Nothing + + want_calls_for f = isLocalId f || isInlinablePragma (idInlinePragma f) \end{code} Note [Interesting dictionary arguments] @@ -1541,7 +1636,7 @@ callsForMe fn (MkUD { ud_binds = orig_dbs, ud_calls = orig_calls }) uds_without_me = MkUD { ud_binds = orig_dbs, ud_calls = delVarEnv orig_calls fn } calls_for_me = case lookupVarEnv orig_calls fn of Nothing -> [] - Just cs -> filter_dfuns (Map.toList cs) + Just (CIS _ calls) -> filter_dfuns (Map.toList calls) dep_set = foldlBag go (unitVarSet fn) orig_dbs go dep_set (db,fvs) | fvs `intersectsVarSet` dep_set @@ -1578,7 +1673,8 @@ deleteCallsMentioning bs calls = mapVarEnv filter_calls calls where filter_calls :: CallInfoSet -> CallInfoSet - filter_calls = Map.filterWithKey (\_ (_, fvs) -> not (fvs `intersectsVarSet` bs)) + filter_calls (CIS f calls) = CIS f (Map.filter keep_call calls) + keep_call (_, fvs) = not (fvs `intersectsVarSet` bs) deleteCallsFor :: [Id] -> CallDetails -> CallDetails -- Remove calls *for* bs @@ -1595,8 +1691,9 @@ deleteCallsFor bs calls = delVarEnvList calls bs \begin{code} type SpecM a = UniqSM a -initSM :: UniqSupply -> SpecM a -> a -initSM = initUs_ +runSpecM:: SpecM a -> CoreM a +runSpecM spec = do { us <- getUniqueSupplyM + ; return (initUs_ us spec) } mapAndCombineSM :: (a -> SpecM (b, UsageDetails)) -> [a] -> SpecM ([b], UsageDetails) mapAndCombineSM _ [] = return ([], emptyUDs)