From c107a00ccf1e641a2d008939cf477c71caa028d5 Mon Sep 17 00:00:00 2001 From: "simonpj@microsoft.com" Date: Thu, 12 Aug 2010 13:11:33 +0000 Subject: [PATCH] Improve the Specialiser, fixing Trac #4203 Simply fixing #4203 is a tiny fix: in case alternatives we should do dumpUDs *including* the case binder. But I realised that we can do better and wasted far too much time implementing the idea. It's described in Note [Floating dictionaries out of cases] --- compiler/specialise/Specialise.lhs | 150 +++++++++++++++++++++++++++++------- 1 file changed, 124 insertions(+), 26 deletions(-) diff --git a/compiler/specialise/Specialise.lhs b/compiler/specialise/Specialise.lhs index 849b600..f18c8f9 100644 --- a/compiler/specialise/Specialise.lhs +++ b/compiler/specialise/Specialise.lhs @@ -4,12 +4,6 @@ \section[Specialise]{Stamping out overloading, and (optionally) polymorphism} \begin{code} --- The above warning supression flag is a temporary kludge. --- While working on this module you are encouraged to remove it and fix --- any warnings in the module. 
See --- http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings --- for details - module Specialise ( specProgram ) where #include "HsVersions.h" @@ -633,21 +627,12 @@ specExpr subst e@(Lam _ _) = do -- More efficient to collect a group of binders together all at once -- and we don't want to split a lambda group with dumped bindings -specExpr subst (Case scrut case_bndr ty alts) = do - (scrut', uds_scrut) <- specExpr subst scrut - (alts', uds_alts) <- mapAndCombineSM spec_alt alts - return (Case scrut' case_bndr' (CoreSubst.substTy subst ty) alts', - uds_scrut `plusUDs` uds_alts) - where - (subst_alt, case_bndr') = substBndr subst case_bndr - -- No need to clone case binder; it can't float like a let(rec) - - spec_alt (con, args, rhs) = do - (rhs', uds) <- specExpr subst_rhs rhs - let (free_uds, dumped_dbs) = dumpUDs args' uds - return ((con, args', wrapDictBindsE dumped_dbs rhs'), free_uds) - where - (subst_rhs, args') = substBndrs subst_alt args +specExpr subst (Case scrut case_bndr ty alts) + = do { (scrut', scrut_uds) <- specExpr subst scrut + ; (scrut'', case_bndr', alts', alts_uds) + <- specCase subst scrut' case_bndr alts + ; return (Case scrut'' case_bndr' (CoreSubst.substTy subst ty) alts' + , scrut_uds `plusUDs` alts_uds) } ---------------- Finally, let is the interesting case -------------------- specExpr subst (Let bind body) = do @@ -666,8 +651,114 @@ specExpr subst (Let bind body) = do -- Must apply the type substitution to coerceions specNote :: Subst -> Note -> Note specNote _ note = note + + +specCase :: Subst + -> CoreExpr -- Scrutinee, already done + -> Id -> [CoreAlt] + -> SpecM ( CoreExpr -- New scrutinee + , Id + , [CoreAlt] + , UsageDetails) +specCase subst scrut' case_bndr [(con, args, rhs)] + | isDictId case_bndr -- See Note [Floating dictionaries out of cases] + , interestingDict scrut' + , not (isDeadBinder case_bndr && null sc_args') + = do { (case_bndr_flt : sc_args_flt) <- mapM clone_me (case_bndr' : sc_args') + + ; 
let sc_rhss = [ Case (Var case_bndr_flt) case_bndr' (idType sc_arg') + [(con, args', Var sc_arg')] + | sc_arg' <- sc_args' ] + + -- Extend the substitution for RHS to map the *original* binders + -- to their floated versions. Attach an unfolding to these floated + -- binders so they look interesting to interestingDict + mb_sc_flts :: [Maybe DictId] + mb_sc_flts = map (lookupVarEnv clone_env) args' + clone_env = zipVarEnv sc_args' (zipWith add_unf sc_args_flt sc_rhss) + subst_prs = (case_bndr, Var (add_unf case_bndr_flt scrut')) + : [ (arg, Var sc_flt) + | (arg, Just sc_flt) <- args `zip` mb_sc_flts ] + subst_rhs' = extendIdSubstList subst_rhs subst_prs + + ; (rhs', rhs_uds) <- specExpr subst_rhs' rhs + ; let scrut_bind = mkDB (NonRec case_bndr_flt scrut') + case_bndr_set = unitVarSet case_bndr_flt + sc_binds = [(NonRec sc_arg_flt sc_rhs, case_bndr_set) + | (sc_arg_flt, sc_rhs) <- sc_args_flt `zip` sc_rhss ] + flt_binds = scrut_bind : sc_binds + (free_uds, dumped_dbs) = dumpUDs (case_bndr':args') rhs_uds + all_uds = flt_binds `addDictBinds` free_uds + alt' = (con, args', wrapDictBindsE dumped_dbs rhs') + ; return (Var case_bndr_flt, case_bndr', [alt'], all_uds) } + where + (subst_rhs, (case_bndr':args')) = substBndrs subst (case_bndr:args) + sc_args' = filter is_flt_sc_arg args' + + clone_me bndr = do { uniq <- getUniqueM + ; return (mkUserLocal occ uniq ty loc) } + where + name = idName bndr + ty = idType bndr + occ = nameOccName name + loc = getSrcSpan name + + add_unf sc_flt sc_rhs -- Sole purpose: make sc_flt respond True to interestingDict + = setIdUnfolding sc_flt (mkUnfolding False False sc_rhs) + + arg_set = mkVarSet args' + is_flt_sc_arg var = isId var + && not (isDeadBinder var) + && isDictTy var_ty + && not (tyVarsOfType var_ty `intersectsVarSet` arg_set) + where + var_ty = idType var + + +specCase subst scrut case_bndr alts + = do { (alts', uds_alts) <- mapAndCombineSM spec_alt alts + ; return (scrut, case_bndr', alts', uds_alts) } + where + (subst_alt, 
case_bndr') = substBndr subst case_bndr + spec_alt (con, args, rhs) = do + (rhs', uds) <- specExpr subst_rhs rhs + let (free_uds, dumped_dbs) = dumpUDs (case_bndr' : args') uds + return ((con, args', wrapDictBindsE dumped_dbs rhs'), free_uds) + where + (subst_rhs, args') = substBndrs subst_alt args \end{code} +Note [Floating dictionaries out of cases] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Consider + g = \d. case d of { MkD sc ... -> ...(f sc)... } +Naively we can't float d2's binding out of the case expression, +because 'sc' is bound by the case, and that in turn means we can't +specialise f, which seems a pity. + +So we invert the case, by floating out a binding +for 'sc_flt' thus: + sc_flt = case d of { MkD sc ... -> sc } +Now we can float the call instance for 'f'. Indeed this is just +what'll happen if 'sc' was originally bound with a let binding, +but case is more efficient, and necessary with equalities. So it's +good to work with both. + +You might think that this won't make any difference, because the +call instance will only get nuked by the \d. BUT if 'g' itself is +specialised, then transitively we should be able to specialise f. + +In general, given + case e of cb { MkD sc ... -> ...(f sc)... } +we transform to + let cb_flt = e + sc_flt = case cb_flt of { MkD sc ... -> sc } + in + case cb_flt of cb { MkD sc ... -> ....(f sc_flt)... 
} + +The "_flt" things are the floated binds; we use the current substitution +to substitute sc -> sc_flt in the RHS + %************************************************************************ %* * \subsubsection{Dealing with a binding} @@ -782,7 +873,8 @@ specDefn subst body_uds fn rhs -- See Note [Inline specialisation] for why we do not -- switch off specialisation for inline functions - = do { -- Make a specialised version for each call in calls_for_me + = -- pprTrace "specDefn: some" (ppr fn $$ ppr calls_for_me) $ + do { -- Make a specialised version for each call in calls_for_me stuff <- mapM spec_call calls_for_me ; let (spec_defns, spec_uds, spec_rules) = unzip3 (catMaybes stuff) fn' = addIdSpecialisations fn spec_rules @@ -799,6 +891,7 @@ specDefn subst body_uds fn rhs | otherwise -- No calls or RHS doesn't fit our preconceptions = WARN( notNull calls_for_me, ptext (sLit "Missed specialisation opportunity for") <+> ppr fn ) -- Note [Specialisation shape] + -- pprTrace "specDefn: none" (ppr fn $$ ppr calls_for_me) $ return (fn, [], body_uds_without_me) where @@ -1326,7 +1419,7 @@ There really is not much point in specialising f wrt the dictionary d, because the code for the specialised f is not improved at all, because d is lambda-bound. We simply get junk specialisations. -What is "interesting"? Just that it has *some* structure. +What is "interesting"? Just that it has *some* structure. 
\begin{code} interestingDict :: CoreExpr -> Bool @@ -1388,6 +1481,9 @@ snocDictBinds uds dbs consDictBind :: CoreBind -> UsageDetails -> UsageDetails consDictBind bind uds = uds { ud_binds = mkDB bind `consBag` ud_binds uds } +addDictBinds :: [DictBind] -> UsageDetails -> UsageDetails +addDictBinds binds uds = uds { ud_binds = listToBag binds `unionBags` ud_binds uds } + snocDictBind :: UsageDetails -> CoreBind -> UsageDetails snocDictBind uds bind = uds { ud_binds = ud_binds uds `snocBag` mkDB bind } @@ -1408,7 +1504,8 @@ dumpUDs :: [CoreBndr] -> UsageDetails -> (UsageDetails, Bag DictBind) -- Used at a lambda or case binder; just dump anything mentioning the binder dumpUDs bndrs uds@(MkUD { ud_binds = orig_dbs, ud_calls = orig_calls }) | null bndrs = (uds, emptyBag) -- Common in case alternatives - | otherwise = (free_uds, dump_dbs) + | otherwise = -- pprTrace "dumpUDs" (ppr bndrs $$ ppr free_uds $$ ppr dump_dbs) $ + (free_uds, dump_dbs) where free_uds = MkUD { ud_binds = free_dbs, ud_calls = free_calls } bndr_set = mkVarSet bndrs @@ -1420,7 +1517,8 @@ dumpUDs bndrs uds@(MkUD { ud_binds = orig_dbs, ud_calls = orig_calls }) dumpBindUDs :: [CoreBndr] -> UsageDetails -> (UsageDetails, Bag DictBind, Bool) -- Used at a lambda or case binder; just dump anything mentioning the binder dumpBindUDs bndrs (MkUD { ud_binds = orig_dbs, ud_calls = orig_calls }) - = (free_uds, dump_dbs, float_all) + = -- pprTrace "dumpBindUDs" (ppr bndrs $$ ppr free_uds $$ ppr dump_dbs) $ + (free_uds, dump_dbs, float_all) where free_uds = MkUD { ud_binds = free_dbs, ud_calls = free_calls } bndr_set = mkVarSet bndrs @@ -1446,7 +1544,7 @@ callsForMe fn (MkUD { ud_binds = orig_dbs, ud_calls = orig_calls }) dep_set = foldlBag go (unitVarSet fn) orig_dbs go dep_set (db,fvs) | fvs `intersectsVarSet` dep_set = extendVarSetList dep_set (bindersOf db) - | otherwise = fvs + | otherwise = dep_set -- Note [Specialisation of dictionary functions] filter_dfuns | isDFunId fn = filter ok_call -- 1.7.10.4