\section[Specialise]{Stamping out overloading, and (optionally) polymorphism}
module Specialise ( specProgram ) where
-- More efficient to collect a group of binders together all at once
-- and we don't want to split a lambda group with dumped bindings
-specExpr subst (Case scrut case_bndr ty alts) = do
- (scrut', uds_scrut) <- specExpr subst scrut
- (alts', uds_alts) <- mapAndCombineSM spec_alt alts
- return (Case scrut' case_bndr' (CoreSubst.substTy subst ty) alts',
- uds_scrut `plusUDs` uds_alts)
- where
- (subst_alt, case_bndr') = substBndr subst case_bndr
- -- No need to clone case binder; it can't float like a let(rec)
- spec_alt (con, args, rhs) = do
- (rhs', uds) <- specExpr subst_rhs rhs
- let (free_uds, dumped_dbs) = dumpUDs args' uds
- return ((con, args', wrapDictBindsE dumped_dbs rhs'), free_uds)
- where
- (subst_rhs, args') = substBndrs subst_alt args
+-- Case expressions: specialise the scrutinee, then hand the case binder
+-- and the alternatives to specCase, and combine both sets of usage details.
+specExpr subst (Case scrut case_bndr ty alts)
+  = do { (spec_scrut, uds_scrut) <- specExpr subst scrut
+       ; (new_scrut, new_bndr, new_alts, uds_alts)
+             <- specCase subst spec_scrut case_bndr alts
+       ; let { res_ty   = CoreSubst.substTy subst ty
+             ; new_case = Case new_scrut new_bndr res_ty new_alts }
+       ; return (new_case, uds_scrut `plusUDs` uds_alts) }
---------------- Finally, let is the interesting case --------------------
specExpr subst (Let bind body) = do
-- Must apply the type substitution to coercions.
-- NOTE(review): this version ignores the substitution and hands the
-- note back untouched; confirm that is still the intent.
specNote :: Subst -> Note -> Note
specNote _subst = id
+specCase :: Subst -- Substitution in force at the case expression
+ -> CoreExpr -- Scrutinee, already done (specialised by the caller)
+ -> Id -> [CoreAlt] -- Case binder and alternatives, not yet substituted
+ -> SpecM ( CoreExpr -- New scrutinee
+ , Id -- Case binder after substitution
+ , [CoreAlt] -- Specialised alternatives
+ , UsageDetails) -- Usage details that escape the case
+specCase subst scrut' case_bndr [(con, args, rhs)] -- Single-alternative case on a dictionary
+ | isDictId case_bndr -- See Note [Floating dictionaries out of cases]
+ , interestingDict scrut'
+ , not (isDeadBinder case_bndr && null sc_args') -- i.e. there is something worth floating
+ = do { (case_bndr_flt : sc_args_flt) <- mapM clone_me (case_bndr' : sc_args')
+ ; let sc_rhss = [ Case (Var case_bndr_flt) case_bndr' (idType sc_arg')
+ [(con, args', Var sc_arg')]
+ | sc_arg' <- sc_args' ]
+ -- Extend the substitution for RHS to map the *original* binders
+ -- to their floated versions. Attach an unfolding to these floated
+ -- binders so they look interesting to interestingDict
+ mb_sc_flts :: [Maybe DictId]
+ mb_sc_flts = map (lookupVarEnv clone_env) args'
+ clone_env = zipVarEnv sc_args' (zipWith add_unf sc_args_flt sc_rhss)
+ subst_prs = (case_bndr, Var (add_unf case_bndr_flt scrut'))
+ : [ (arg, Var sc_flt)
+ | (arg, Just sc_flt) <- args `zip` mb_sc_flts ]
+ subst_rhs' = extendIdSubstList subst_rhs subst_prs
+ ; (rhs', rhs_uds) <- specExpr subst_rhs' rhs
+ ; let scrut_bind = mkDB (NonRec case_bndr_flt scrut')
+ case_bndr_set = unitVarSet case_bndr_flt
+ sc_binds = [(NonRec sc_arg_flt sc_rhs, case_bndr_set)
+ | (sc_arg_flt, sc_rhs) <- sc_args_flt `zip` sc_rhss ]
+ flt_binds = scrut_bind : sc_binds
+ (free_uds, dumped_dbs) = dumpUDs (case_bndr':args') rhs_uds
+ all_uds = flt_binds `addDictBinds` free_uds
+ alt' = (con, args', wrapDictBindsE dumped_dbs rhs')
+ ; return (Var case_bndr_flt, case_bndr', [alt'], all_uds) } -- The floated binder becomes the scrutinee
+ where
+ (subst_rhs, (case_bndr':args')) = substBndrs subst (case_bndr:args)
+ sc_args' = filter is_flt_sc_arg args'
+ clone_me bndr = do { uniq <- getUniqueM
+ ; return (mkUserLocal occ uniq ty loc) }
+ where
+ name = idName bndr
+ ty = idType bndr
+ occ = nameOccName name
+ loc = getSrcSpan name
+ add_unf sc_flt sc_rhs -- Sole purpose: make sc_flt respond True to interestingDict
+ = setIdUnfolding sc_flt (mkUnfolding False False sc_rhs)
+ arg_set = mkVarSet args'
+ is_flt_sc_arg var = isId var
+ && not (isDeadBinder var)
+ && isDictTy var_ty
+ && not (tyVarsOfType var_ty `intersectsVarSet` arg_set)
+ where
+ var_ty = idType var
+specCase subst scrut case_bndr alts -- General case: nothing is floated out of the scrutinee
+ = do { (alts', uds_alts) <- mapAndCombineSM spec_alt alts
+ ; return (scrut, case_bndr', alts', uds_alts) }
+ where
+ (subst_alt, case_bndr') = substBndr subst case_bndr
+ spec_alt (con, args, rhs) = do
+ (rhs', uds) <- specExpr subst_rhs rhs
+ let (free_uds, dumped_dbs) = dumpUDs (case_bndr' : args') uds
+ return ((con, args', wrapDictBindsE dumped_dbs rhs'), free_uds)
+ where
+ (subst_rhs, args') = substBndrs subst_alt args
+Note [Floating dictionaries out of cases]
+ g = \d. case d of { MkD sc ... -> ...(f sc)... }
+Naively we can't float the call instance for 'f' out of the case expression,
+because 'sc' is bound by the case, and that in turn means we can't
+specialise f, which seems a pity.
+So we invert the case, by floating out a binding
+for 'sc_flt' thus:
+ sc_flt = case d of { MkD sc ... -> sc }
+Now we can float the call instance for 'f'. Indeed this is just
+what'll happen if 'sc' was originally bound with a let binding,
+but case is more efficient, and necessary with equalities. So it's
+good to work with both.
+You might think that this won't make any difference, because the
+call instance will only get nuked by the \d. BUT if 'g' itself is
+specialised, then transitively we should be able to specialise f.
+In general, given
+ case e of cb { MkD sc ... -> ...(f sc)... }
+we transform to
+ let cb_flt = e
+ sc_flt = case cb_flt of { MkD sc ... -> sc }
+ in
+ case cb_flt of cb { MkD sc ... -> ....(f sc_flt)... }
+The "_flt" things are the floated binds; we use the current substitution
+to substitute cb -> cb_flt and sc -> sc_flt in the RHS.
\subsubsection{Dealing with a binding}
-- See Note [Inline specialisation] for why we do not
-- switch off specialisation for inline functions
- = do { -- Make a specialised version for each call in calls_for_me
+ = -- pprTrace "specDefn: some" (ppr fn $$ ppr calls_for_me) $
+ do { -- Make a specialised version for each call in calls_for_me
stuff <- mapM spec_call calls_for_me
; let (spec_defns, spec_uds, spec_rules) = unzip3 (catMaybes stuff)
fn' = addIdSpecialisations fn spec_rules
| otherwise -- No calls or RHS doesn't fit our preconceptions
= WARN( notNull calls_for_me, ptext (sLit "Missed specialisation opportunity for") <+> ppr fn )
-- Note [Specialisation shape]
+ -- pprTrace "specDefn: none" (ppr fn $$ ppr calls_for_me) $
return (fn, [], body_uds_without_me)
because the code for the specialised f is not improved at all, because
d is lambda-bound. We simply get junk specialisations.
-What is "interesting"? Just that it has *some* structure.
+What is "interesting"? Just that it has *some* structure.
interestingDict :: CoreExpr -> Bool
-- Cons a single binding (wrapped with mkDB) onto the ud_binds bag.
consDictBind :: CoreBind -> UsageDetails -> UsageDetails
consDictBind bind uds = uds { ud_binds = new_binds }
  where
    new_binds = consBag (mkDB bind) (ud_binds uds)
+-- Add a list of dictionary bindings to the usage details; the new
+-- bindings form the left operand of the bag union.
+addDictBinds :: [DictBind] -> UsageDetails -> UsageDetails
+addDictBinds binds uds
+  = uds { ud_binds = unionBags (listToBag binds) (ud_binds uds) }
-- Snoc a single binding (wrapped with mkDB) onto the ud_binds bag.
snocDictBind :: UsageDetails -> CoreBind -> UsageDetails
snocDictBind uds bind = uds { ud_binds = extended }
  where
    extended = snocBag (ud_binds uds) (mkDB bind)
-- Used at a lambda or case binder; just dump anything mentioning the binder
dumpUDs bndrs uds@(MkUD { ud_binds = orig_dbs, ud_calls = orig_calls })
| null bndrs = (uds, emptyBag) -- Common in case alternatives
- | otherwise = (free_uds, dump_dbs)
+ | otherwise = -- pprTrace "dumpUDs" (ppr bndrs $$ ppr free_uds $$ ppr dump_dbs) $
+ (free_uds, dump_dbs)
free_uds = MkUD { ud_binds = free_dbs, ud_calls = free_calls }
bndr_set = mkVarSet bndrs
dumpBindUDs :: [CoreBndr] -> UsageDetails -> (UsageDetails, Bag DictBind, Bool)
-- Used at a lambda or case binder; just dump anything mentioning the binder
dumpBindUDs bndrs (MkUD { ud_binds = orig_dbs, ud_calls = orig_calls })
- = (free_uds, dump_dbs, float_all)
+ = -- pprTrace "dumpBindUDs" (ppr bndrs $$ ppr free_uds $$ ppr dump_dbs) $
+ (free_uds, dump_dbs, float_all)
free_uds = MkUD { ud_binds = free_dbs, ud_calls = free_calls }
bndr_set = mkVarSet bndrs
dep_set = foldlBag go (unitVarSet fn) orig_dbs
go dep_set (db,fvs) | fvs `intersectsVarSet` dep_set
= extendVarSetList dep_set (bindersOf db)
- | otherwise = fvs
+ | otherwise = dep_set
-- Note [Specialisation of dictionary functions]
filter_dfuns | isDFunId fn = filter ok_call