import Id
import TcType
+import CoreMonad
import CoreSubst
-import CoreUnfold ( mkUnfolding, mkInlineRule )
+import CoreUnfold
import VarSet
import VarEnv
import CoreSyn
import Rules
import CoreUtils ( exprIsTrivial, applyTypeToArgs, mkPiTypes )
import CoreFVs ( exprFreeVars, exprsFreeVars, idFreeVars )
-import UniqSupply ( UniqSupply, UniqSM, initUs_, MonadUnique(..) )
+import UniqSupply ( UniqSM, initUs_, MonadUnique(..) )
import Name
import MkId ( voidArgId, realWorldPrimId )
import Maybes ( catMaybes, isJust )
-import BasicTypes ( isNeverActive, inlinePragmaActivation )
+import BasicTypes
+import HscTypes
import Bag
import Util
import Outputable
%************************************************************************
\begin{code}
-specProgram :: UniqSupply -> [CoreBind] -> [CoreBind]
-specProgram us binds = initSM us $
- do { (binds', uds') <- go binds
- ; return (wrapDictBinds (ud_binds uds') binds') }
+specProgram :: ModGuts -> CoreM ModGuts
+specProgram guts
+ = do { hpt_rules <- getRuleBase
+ ; let local_rules = mg_rules guts
+ rule_base = extendRuleBaseList hpt_rules (mg_rules guts) -- hpt_rules plus this module's own rules
+
+ -- Specialise the bindings of this module; the uds that comes back
+ -- is handed to specImports, which expects it to hold the calls for
+ -- imported things plus any floating bindings
+ ; (binds', uds) <- runSpecM (go (mg_binds guts))
+
+ -- Specialise imported functions. Nothing has been specialised yet,
+ -- so the set of functions to leave alone starts out empty
+ ; (new_rules, spec_binds) <- specImports emptyVarSet rule_base uds
+
+ ; return (guts { mg_binds = spec_binds ++ binds' -- specImports has already wrapped the floating bindings into spec_binds
+ , mg_rules = local_rules ++ new_rules }) }
where
-- We need to start with a Subst that knows all the things
-- that are in scope, so that the substitution engine doesn't
-- accidentally re-use a unique that's already in use
-- Easiest thing is to do it all at once, as if all the top-level
-- decls were mutually recursive
- top_subst = mkEmptySubst (mkInScopeSet (mkVarSet (bindersOfBinds binds)))
+ top_subst = mkEmptySubst $ mkInScopeSet $ mkVarSet $
+ bindersOfBinds $ mg_binds guts
go [] = return ([], emptyUDs)
go (bind:binds) = do (binds', uds) <- go binds
(bind', uds') <- specBind top_subst bind uds
return (bind' ++ binds', uds')
+
+-- Specialise every imported function for which 'uds' records calls,
+-- and wrap the floating bindings of 'uds' around the result.
+specImports :: VarSet        -- Don't specialise these ones
+                             -- See Note [Avoiding recursive specialisation]
+            -> RuleBase      -- Rules from this module and the home package
+                             -- (but not external packages, which can change)
+            -> UsageDetails  -- Calls for imported things, and floating bindings
+            -> CoreM ( [CoreRule]    -- New rules
+                     , [CoreBind] )  -- Specialised bindings and floating bindings
+specImports done rb uds
+  = do { (rules, spec_binds) <- go rb (varEnvElts (ud_calls uds))
+       ; return (rules, wrapDictBinds (ud_binds uds) spec_binds) }
+  where
+    -- Process the per-function call sets one at a time, threading the
+    -- rule base so that the rules created for one function are already
+    -- in the rule base when the next function is specialised
+    go _ [] = return ([], [])
+    go rule_base (CIS fn calls_for_fn : rest)
+      = do { (rules_here, binds_here) <- specImport done rule_base fn (Map.toList calls_for_fn)
+           ; (rules_rest, binds_rest) <- go (extendRuleBaseList rule_base rules_here) rest
+           ; return (rules_here ++ rules_rest, binds_here ++ binds_rest) }
+
+-- Specialise one imported function for the given calls, then go on to
+-- specialise whatever further calls its specialised RHSs give rise to.
+specImport :: VarSet                -- Don't specialise these
+                                    -- See Note [Avoiding recursive specialisation]
+           -> RuleBase              -- Rules handed down by specImports; the
+                                    -- external-package rules are added below
+           -> Id -> [CallInfo]      -- Imported function and calls for it
+           -> CoreM ( [CoreRule]    -- New rules
+                    , [CoreBind] )  -- Specialised bindings
+specImport done rb fn calls_for_fn
+  | not (fn `elemVarSet` done)
+  , isInlinablePragma (idInlinePragma fn)
+  , Just rhs <- maybeUnfoldingTemplate (realIdUnfolding fn)
+  = do { -- Get rules from the external package state.  We re-read it
+         -- on every call in case we have "page-faulted in" more rules
+         -- since the last time we looked
+         hsc_env <- getHscEnv
+       ; eps <- liftIO (hscEPS hsc_env)
+       ; let rules_for_fn = getRules (unionRuleBase rb (eps_rule_base eps)) fn
+
+       ; (new_rules, spec_pairs, cascaded_uds)
+            <- runSpecM (specCalls emptySubst rules_for_fn calls_for_fn fn rhs)
+
+         -- After the rules kick in we may get recursion, but
+         -- we rely on a global GlomBinds to sort that out later
+       ; let new_binds = [NonRec bndr spec_rhs | (bndr, spec_rhs) <- spec_pairs]
+
+         -- Now specialise any cascaded calls, with fn itself excluded
+       ; (cascaded_rules, cascaded_binds)
+            <- specImports (extendVarSet done fn)
+                           (extendRuleBaseList rb new_rules)
+                           cascaded_uds
+
+       ; return (cascaded_rules ++ new_rules, cascaded_binds ++ new_binds) }
+
+  | otherwise
+  = WARN( True, ptext (sLit "specImport discard") <+> ppr fn <+> ppr calls_for_fn )
+    return ([], [])
\end{code}
+Note [Avoiding recursive specialisation]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+When we specialise 'f' we may find new overloaded calls to 'g', 'h' in
+'f's RHS. So we want to specialise g,h. But we don't want to
+specialise f any more! It's possible that f's RHS might have a
+recursive yet-more-specialised call, so we'd diverge in that case.
+And if the call is to the same type, one specialisation is enough.
+Avoiding this recursive specialisation loop is the reason for the
+'done' VarSet passed to specImports and specImport.
+
%************************************************************************
%* *
\subsubsection{@specExpr@: the main function}
loc = getSrcSpan name
add_unf sc_flt sc_rhs -- Sole purpose: make sc_flt respond True to interestingDictId
- = setIdUnfolding sc_flt (mkUnfolding False False sc_rhs)
+ = setIdUnfolding sc_flt (mkSimpleUnfolding sc_rhs)
arg_set = mkVarSet args'
is_flt_sc_arg var = isId var
%************************************************************************
%* *
-\subsubsection{Dealing with a binding}
+ Dealing with a binding
%* *
%************************************************************************
UsageDetails) -- Stuff to fling upwards from the specialised versions
specDefn subst body_uds fn rhs
+ = do { let (body_uds_without_me, calls_for_me) = callsForMe fn body_uds
+ rules_for_me = idCoreRules fn
+ ; (rules, spec_defns, spec_uds) <- specCalls subst rules_for_me
+ calls_for_me fn rhs
+ ; return ( fn `addIdSpecialisations` rules
+ , spec_defns
+ , body_uds_without_me `plusUDs` spec_uds) }
+ -- It's important that the `plusUDs` is this way
+ -- round, because body_uds_without_me may bind
+ -- dictionaries that are used in calls_for_me passed
+ -- to specDefn. So the dictionary bindings in
+ -- spec_uds may mention dictionaries bound in
+ -- body_uds_without_me
+
+---------------------------
+specCalls :: Subst
+ -> [CoreRule] -- Existing RULES for the fn
+ -> [CallInfo]
+ -> Id -> CoreExpr
+ -> SpecM ([CoreRule], -- New RULES for the fn
+ [(Id,CoreExpr)], -- Extra, specialised bindings
+ UsageDetails) -- New usage details from the specialised RHSs
+
+-- This function checks existing rules, and does not create
+-- duplicate ones. So the caller does not need to do this filtering.
+-- See 'already_covered'
+
+specCalls subst rules_for_me calls_for_me fn rhs
-- The first case is the interesting one
| rhs_tyvars `lengthIs` n_tyvars -- Rhs of fn's defn has right number of big lambdas
&& rhs_ids `lengthAtLeast` n_dicts -- and enough dict args
-- See Note [Inline specialisation] for why we do not
-- switch off specialisation for inline functions
- = -- pprTrace "specDefn: some" (ppr fn $$ ppr calls_for_me) $
- do { -- Make a specialised version for each call in calls_for_me
- stuff <- mapM spec_call calls_for_me
+ = -- pprTrace "specDefn: some" (ppr fn $$ ppr calls_for_me $$ ppr rules_for_me) $
+ do { stuff <- mapM spec_call calls_for_me
; let (spec_defns, spec_uds, spec_rules) = unzip3 (catMaybes stuff)
- fn' = addIdSpecialisations fn spec_rules
- final_uds = body_uds_without_me `plusUDs` plusUDList spec_uds
- -- It's important that the `plusUDs` is this way
- -- round, because body_uds_without_me may bind
- -- dictionaries that are used in calls_for_me passed
- -- to specDefn. So the dictionary bindings in
- -- spec_uds may mention dictionaries bound in
- -- body_uds_without_me
-
- ; return (fn', spec_defns, final_uds) }
+ ; return (spec_rules, spec_defns, plusUDList spec_uds) }
| otherwise -- No calls or RHS doesn't fit our preconceptions
= WARN( notNull calls_for_me, ptext (sLit "Missed specialisation opportunity for") <+> ppr fn )
-- Note [Specialisation shape]
-- pprTrace "specDefn: none" (ppr fn $$ ppr calls_for_me) $
- return (fn, [], body_uds_without_me)
+ return ([], [], emptyUDs)
where
fn_type = idType fn
(tyvars, theta, _) = tcSplitSigmaTy fn_type
n_tyvars = length tyvars
n_dicts = length theta
- inl_act = inlinePragmaActivation (idInlinePragma fn)
+ inl_prag = idInlinePragma fn
+ inl_act = inlinePragmaActivation inl_prag
+ is_local = isLocalId fn
-- Figure out whether the function has an INLINE pragma
-- See Note [Inline specialisations]
- fn_has_inline_rule :: Maybe Bool -- Derive sat-flag from existing thing
- fn_has_inline_rule = case isInlineRule_maybe fn_unf of
- Just (_,sat) -> Just sat
- Nothing -> Nothing
spec_arity = unfoldingArity fn_unf - n_dicts -- Arity of the *specialised* inline rule
(rhs_tyvars, rhs_ids, rhs_body) = collectTyAndValBinders rhs
- (body_uds_without_me, calls_for_me) = callsForMe fn body_uds
-
rhs_dict_ids = take n_dicts rhs_ids
body = mkLams (drop n_dicts rhs_ids) rhs_body
-- Glue back on the non-dict lambdas
already_covered args -- Note [Specialisations already covered]
= isJust (lookupRule (const True) realIdUnfolding
(substInScope subst)
- fn args (idCoreRules fn))
+ fn args rules_for_me)
mk_ty_args :: [Maybe Type] -> [CoreExpr]
mk_ty_args call_ts = zipWithEqual "spec_call" mk_ty_arg rhs_tyvars call_ts
-- The rule to put in the function's specialisation is:
-- forall b, d1',d2'. f t1 b t3 d1' d2' = f1 b
rule_name = mkFastString ("SPEC " ++ showSDoc (ppr fn <+> ppr spec_ty_args))
- spec_env_rule = mkLocalRule
- rule_name
+ spec_env_rule = mkRule True {- Auto generated -} is_local
+ rule_name
inl_act -- Note [Auto-specialisation and RULES]
(idName fn)
(poly_tyvars ++ inst_dict_ids)
-- Add the { d1' = dx1; d2' = dx2 } usage stuff
final_uds = foldr consDictBind rhs_uds dx_binds
+ -- Add an InlineRule if the parent has one
+ -- See Note [Inline specialisations]
+ spec_unf
+ = case inlinePragmaSpec inl_prag of
+ Inline -> mkInlineUnfolding (Just spec_arity) spec_rhs
+ Inlinable -> mkInlinableUnfolding spec_rhs
+ _ -> NoUnfolding
+
-- Adding arity information just propagates it a bit faster
-- See Note [Arity decrease] in Simplify
-- Copy InlinePragma information from the parent Id.
-- So if f has INLINE[1] so does spec_f
spec_f_w_arity = spec_f `setIdArity` max 0 (fn_arity - n_dicts)
- `setInlineActivation` inl_act
+ `setInlinePragma` inl_prag
+ `setIdUnfolding` spec_unf
- -- Add an InlineRule if the parent has one
- -- See Note [Inline specialisations]
- final_spec_f
- | Just sat <- fn_has_inline_rule
- = let
- mb_spec_arity = if sat then Just spec_arity else Nothing
- in
- spec_f_w_arity `setIdUnfolding` mkInlineRule spec_rhs mb_spec_arity
- | otherwise
- = spec_f_w_arity
-
- ; return (Just ((final_spec_f, spec_rhs), final_uds, spec_env_rule)) } }
+ ; return (Just ((spec_f_w_arity, spec_rhs), final_uds, spec_env_rule)) } }
where
my_zipEqual xs ys zs
| debugIsOn && not (equalLength xs ys && equalLength ys zs)
| otherwise = go subst_w_unf (NonRec dx_id dx : binds) pairs
where
- dx_id1 = dx_id `setIdUnfolding` mkUnfolding False False dx
+ dx_id1 = dx_id `setIdUnfolding` mkSimpleUnfolding dx
subst_w_unf = extendIdSubst subst d (Var dx_id1)
-- Important! We're going to substitute dx_id1 for d
-- and we want it to look "interesting", else we won't gather *any*
Conclusion: we catch the nasty case using filter_dfuns in
-callsForMe To be honest I'm not 100% certain that this is 100%
+callsForMe. To be honest I'm not 100% certain that this is 100%
right, but it works. Sigh.
--
-- The list of types and dictionaries is guaranteed to
-- match the type of f
-type CallInfoSet = Map CallKey ([DictExpr], VarSet)
+data CallInfoSet = CIS Id (Map CallKey ([DictExpr], VarSet))
-- Range is dict args and the vars of the whole
-- call (including tyvars)
-- [*not* include the main id itself, of course]
type CallInfo = (CallKey, ([DictExpr], VarSet))
+-- Shows the function the calls are for, then the calls themselves.
+instance Outputable CallInfoSet where
+  ppr (CIS fn calls) = hang header 2 (ppr calls)
+    where
+      header = ptext (sLit "CIS") <+> ppr fn
+
instance Outputable CallKey where
ppr (CallKey ts) = ppr ts
cmp (Just t1) (Just t2) = tcCmpType t1 t2
unionCalls :: CallDetails -> CallDetails -> CallDetails
-unionCalls c1 c2 = plusVarEnv_C Map.union c1 c2
+unionCalls c1 c2 = plusVarEnv_C unionCallInfoSet c1 c2 -- unionCallInfoSet is the combining function (presumably used where both envs have an entry -- confirm against VarEnv)
--- plusCalls :: UsageDetails -> CallDetails -> UsageDetails
--- plusCalls uds call_ds = uds { ud_calls = ud_calls uds `unionCalls` call_ds }
+-- Merge the calls of two call sets.  The result keeps the function
+-- recorded in the first set; the one in the second set is ignored
+-- (presumably both sets are for the same function -- confirm against
+-- the callers).
+unionCallInfoSet :: CallInfoSet -> CallInfoSet -> CallInfoSet
+unionCallInfoSet (CIS f calls1) (CIS _ calls2)
+  = CIS f (Map.union calls1 calls2)
callDetailsFVs :: CallDetails -> VarSet
callDetailsFVs calls = foldVarEnv (unionVarSet . callInfoFVs) emptyVarSet calls
callInfoFVs :: CallInfoSet -> VarSet
-callInfoFVs call_info = Map.foldRightWithKey (\_ (_,fv) vs -> unionVarSet fv vs) emptyVarSet call_info
+callInfoFVs (CIS _ call_info) = Map.foldRight (\(_,fv) vs -> unionVarSet fv vs) emptyVarSet call_info -- union of the free-variable sets stored with each call
------------------------------------------------------------
singleCall :: Id -> [Maybe Type] -> [DictExpr] -> UsageDetails
singleCall id tys dicts
= MkUD {ud_binds = emptyBag,
- ud_calls = unitVarEnv id (Map.singleton (CallKey tys) (dicts, call_fvs)) }
+ ud_calls = unitVarEnv id $ CIS id $ -- the CIS records the called Id alongside its calls
+ Map.singleton (CallKey tys) (dicts, call_fvs) }
where
call_fvs = exprsFreeVars dicts `unionVarSet` tys_fvs
tys_fvs = tyVarsOfTypes (catMaybes tys)
mkCallUDs :: Id -> [CoreExpr] -> UsageDetails
mkCallUDs f args
- | not (isLocalId f) -- Imported from elsewhere
- || null theta -- Not overloaded
+ | not (want_calls_for f) -- Imported from elsewhere, and not INLINABLE
+ || null theta -- Not overloaded
|| not (all isClassPred theta)
-- Only specialise if all overloading is on class params.
-- In ptic, with implicit params, the type args
mk_spec_ty tyvar ty
| tyvar `elemVarSet` constrained_tyvars = Just ty
| otherwise = Nothing
+
+ want_calls_for f = isLocalId f || isInlinablePragma (idInlinePragma f)
\end{code}
Note [Interesting dictionary arguments]
uds_without_me = MkUD { ud_binds = orig_dbs, ud_calls = delVarEnv orig_calls fn }
calls_for_me = case lookupVarEnv orig_calls fn of
Nothing -> []
- Just cs -> filter_dfuns (Map.toList cs)
+ Just (CIS _ calls) -> filter_dfuns (Map.toList calls)
dep_set = foldlBag go (unitVarSet fn) orig_dbs
go dep_set (db,fvs) | fvs `intersectsVarSet` dep_set
= mapVarEnv filter_calls calls
where
filter_calls :: CallInfoSet -> CallInfoSet
- filter_calls = Map.filterWithKey (\_ (_, fvs) -> not (fvs `intersectsVarSet` bs))
+ filter_calls (CIS f calls) = CIS f (Map.filter keep_call calls)
+ keep_call (_, fvs) = not (fvs `intersectsVarSet` bs)
deleteCallsFor :: [Id] -> CallDetails -> CallDetails
-- Remove calls *for* bs
\begin{code}
type SpecM a = UniqSM a
-initSM :: UniqSupply -> SpecM a -> a
-initSM = initUs_
+-- Run a SpecM computation inside CoreM, using a unique supply
+-- obtained from getUniqueSupplyM.
+runSpecM :: SpecM a -> CoreM a
+runSpecM spec
+  = getUniqueSupplyM >>= \us ->
+    return (initUs_ us spec)
mapAndCombineSM :: (a -> SpecM (b, UsageDetails)) -> [a] -> SpecM ([b], UsageDetails)
mapAndCombineSM _ [] = return ([], emptyUDs)