From 7aa3f5247ae454b10b61e2f28a9431f0889a8cff Mon Sep 17 00:00:00 2001 From: "simonpj@microsoft.com" Date: Thu, 11 Jan 2007 09:15:33 +0000 Subject: [PATCH] Make the LiberateCase transformation understand associated types Consider this FC program: data family AT a :: * data instance AT Int = T1 Int Int f :: AT Int -> Int f t = case t of DEFAULT -> <body> We'd like to replace the DEFAULT by a use of T1, so that if we scrutinise t inside <body> we share the evaluation: f t = case (t `cast` co) of T1 x y -> <body> I decided to do this as part of the liberate-case transformation, which is already trying to avoid redundant evals. The new transformation requires knowledge of the family instance environment, so I had to extend ModGuts to carry the fam_inst_env, and put that envt into the liberate-case environment. Otherwise it's all pretty straightforward. --- compiler/deSugar/Desugar.lhs | 70 +++++----- compiler/main/HscTypes.lhs | 6 +- compiler/simplCore/LiberateCase.lhs | 247 +++++++++++++++++++++++------------ compiler/simplCore/SimplCore.lhs | 4 +- compiler/typecheck/TcRnDriver.lhs | 1 + 5 files changed, 212 insertions(+), 116 deletions(-) diff --git a/compiler/deSugar/Desugar.lhs b/compiler/deSugar/Desugar.lhs index 970bd20..b4ff273 100644 --- a/compiler/deSugar/Desugar.lhs +++ b/compiler/deSugar/Desugar.lhs @@ -60,23 +60,24 @@ deSugar :: HscEnv -> ModLocation -> TcGblEnv -> IO (Maybe ModGuts) deSugar hsc_env mod_loc - tcg_env@(TcGblEnv { tcg_mod = mod, - tcg_src = hsc_src, - tcg_type_env = type_env, - tcg_imports = imports, - tcg_exports = exports, - tcg_dus = dus, - tcg_inst_uses = dfun_uses_var, - tcg_th_used = th_var, - tcg_keep = keep_var, - tcg_rdr_env = rdr_env, - tcg_fix_env = fix_env, - tcg_deprecs = deprecs, - tcg_binds = binds, - tcg_fords = fords, - tcg_rules = rules, - tcg_insts = insts, - tcg_fam_insts = fam_insts }) + tcg_env@(TcGblEnv { tcg_mod = mod, + tcg_src = hsc_src, + tcg_type_env = type_env, + tcg_imports = imports, + tcg_exports = exports, + tcg_dus = dus, +
tcg_inst_uses = dfun_uses_var, + tcg_th_used = th_var, + tcg_keep = keep_var, + tcg_rdr_env = rdr_env, + tcg_fix_env = fix_env, + tcg_fam_inst_env = fam_inst_env, + tcg_deprecs = deprecs, + tcg_binds = binds, + tcg_fords = fords, + tcg_rules = rules, + tcg_insts = insts, + tcg_fam_insts = fam_insts }) = do { showPass dflags "Desugar" -- Desugar the program @@ -156,23 +157,24 @@ deSugar hsc_env -- sort to get into canonical order mod_guts = ModGuts { - mg_module = mod, - mg_boot = isHsBoot hsc_src, - mg_exports = exports, - mg_deps = deps, - mg_usages = usages, - mg_dir_imps = [m | (m,_,_) <- moduleEnvElts dir_imp_mods], - mg_rdr_env = rdr_env, - mg_fix_env = fix_env, - mg_deprecs = deprecs, - mg_types = type_env, - mg_insts = insts, - mg_fam_insts = fam_insts, - mg_rules = ds_rules, - mg_binds = ds_binds, - mg_foreign = ds_fords, - mg_hpc_info = ds_hpc_info, - mg_dbg_sites = dbgSites } + mg_module = mod, + mg_boot = isHsBoot hsc_src, + mg_exports = exports, + mg_deps = deps, + mg_usages = usages, + mg_dir_imps = [m | (m,_,_) <- moduleEnvElts dir_imp_mods], + mg_rdr_env = rdr_env, + mg_fix_env = fix_env, + mg_deprecs = deprecs, + mg_types = type_env, + mg_insts = insts, + mg_fam_insts = fam_insts, + mg_fam_inst_env = fam_inst_env, + mg_rules = ds_rules, + mg_binds = ds_binds, + mg_foreign = ds_fords, + mg_hpc_info = ds_hpc_info, + mg_dbg_sites = dbgSites } ; return (Just mod_guts) }}} diff --git a/compiler/main/HscTypes.lhs b/compiler/main/HscTypes.lhs index 4155807..2b8f8f7 100644 --- a/compiler/main/HscTypes.lhs +++ b/compiler/main/HscTypes.lhs @@ -485,7 +485,10 @@ data ModGuts mg_rdr_env :: !GlobalRdrEnv, -- Top-level lexical environment mg_fix_env :: !FixityEnv, -- Fixity env, for things declared in -- this module - mg_deprecs :: !Deprecations, -- Deprecations declared in the module + + mg_fam_inst_env :: FamInstEnv, -- Type-family instance environment + -- for *home-package* modules (including + -- this one). c.f.
tcg_fam_inst_env mg_types :: !TypeEnv, mg_insts :: ![Instance], -- Instances @@ -493,6 +496,7 @@ data ModGuts mg_rules :: ![CoreRule], -- Rules from this module mg_binds :: ![CoreBind], -- Bindings for this module mg_foreign :: !ForeignStubs, + mg_deprecs :: !Deprecations, -- Deprecations declared in the module mg_hpc_info :: !HpcInfo, -- info about coverage tick boxes mg_dbg_sites :: ![(SiteNumber, Coord)] -- Bkpts inserted by the renamer } diff --git a/compiler/simplCore/LiberateCase.lhs b/compiler/simplCore/LiberateCase.lhs index 67d2e5c..31063d3 100644 --- a/compiler/simplCore/LiberateCase.lhs +++ b/compiler/simplCore/LiberateCase.lhs @@ -8,18 +8,28 @@ module LiberateCase ( liberateCase ) where #include "HsVersions.h" -import DynFlags ( DynFlags, DynFlag(..) ) -import StaticFlags ( opt_LiberateCaseThreshold ) +import DynFlags +import HscTypes import CoreLint ( showPass, endPass ) import CoreSyn import CoreUnfold ( couldBeSmallEnoughToInline ) -import Id ( Id, setIdName, idName, setIdNotExported ) +import Rules ( RuleBase ) +import UniqSupply ( UniqSupply ) +import SimplMonad ( SimplCount, zeroSimplCount ) +import Id +import FamInstEnv +import Type +import Coercion +import TyCon import VarEnv import Name ( localiseName ) import Outputable import Util ( notNull ) +import Data.IORef ( readIORef ) \end{code} +The liberate-case transformation +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ This module walks over @Core@, and looks for @case@ on free variables. The criterion is: if there is case on a free on the route to the recursive call, @@ -27,30 +37,24 @@ The criterion is: Example -\begin{verbatim} -f = \ t -> case v of - V a b -> a : f t -\end{verbatim} + f = \ t -> case v of + V a b -> a : f t => the inner f is replaced. 
-\begin{verbatim} -f = \ t -> case v of - V a b -> a : (letrec + f = \ t -> case v of + V a b -> a : (letrec f = \ t -> case v of V a b -> a : f t - in f) t -\end{verbatim} + in f) t (note the NEED for shadowing) => Simplify -\begin{verbatim} -f = \ t -> case v of - V a b -> a : (letrec + f = \ t -> case v of + V a b -> a : (letrec f = \ t -> a : f t - in f t) -\begin{verbatim} + in f t) Better code, because 'a' is free inside the inner letrec, rather than needing projection from v. @@ -72,7 +76,6 @@ We'd like to avoid the redundant pattern match, transforming to (is this necessarily an improvement) - Similarly drop: drop n [] = [] @@ -119,66 +122,64 @@ scope. For example: Here, the level of @f@ is zero, the level of @g@ is one, and the level of @h@ is zero (NB not one). -\begin{code} -type LibCaseLevel = Int - -topLevel :: LibCaseLevel -topLevel = 0 -\end{code} - -\begin{code} -data LibCaseEnv - = LibCaseEnv { - lc_size :: Int, -- Bomb-out size for deciding if - -- potential liberatees are too big. - -- (passed in from cmd-line args) - - lc_lvl :: LibCaseLevel, -- Current level - - lc_lvl_env :: IdEnv LibCaseLevel, - -- Binds all non-top-level in-scope Ids - -- (top-level and imported things have - -- a level of zero) - - lc_rec_env :: IdEnv CoreBind, - -- Binds *only* recursively defined ids, - -- to their own binding group, - -- and *only* in their own RHSs +Note [Indexed data types] +~~~~~~~~~~~~~~~~~~~~~~~~~ +Consider + data family T :: * -> * + data T Int = TI Int + + f :: T Int -> Bool + f x = case x of { DEFAULT -> <body> } + +We would like to change this to + f x = case x `cast` co of { TI p -> <body> } + +so that <body> can make use of the fact that x is already evaluated to +a TI; and a case on a known data type may be more efficient than a +polymorphic one (not sure this is true any longer). Anyway the former +showed up in Roman's experiments.
Example: + foo :: FooT Int -> Int -> Int + foo t n = t `seq` bar n + where + bar 0 = 0 + bar n = bar (n - case t of TI i -> i) +Here we'd like to avoid repeatedly evaluating t inside the loop, by +taking advantage of the `seq`. + +We implement this as part of the liberate-case transformation by +spotting + case <scrut> of (x::T) tys { DEFAULT -> <body> } +where x :: T tys, and T is an indexed family tycon. Find the +representation type (T77 tys'), and coercion co, and transform to + case <scrut> `cast` co of (y::T77 tys') + DEFAULT -> let x = y `cast` sym co in <body> + +The "find the representation type" part is done by looking up in the +family-instance environment. + +NB: in fact we re-use x (changing its type) to avoid making a fresh y; +this entails shadowing, but that's ok. + +%************************************************************************ +%* * + Top-level code +%* * +%************************************************************************ - lc_scruts :: [(Id,LibCaseLevel)] - -- Each of these Ids was scrutinised by an - -- enclosing case expression, with the - -- specified number of enclosing - -- recursive bindings; furthermore, - -- the Id is bound at a lower level - -- than the case expression.
The order is - -- insignificant; it's a bag really - --- lc_fams :: FamInstEnvs - -- Instance env for indexed data types - } - -initEnv :: Int -> LibCaseEnv -initEnv bomb_size - = LibCaseEnv { lc_size = bomb_size, lc_lvl = 0, - lc_lvl_env = emptyVarEnv, lc_rec_env = emptyVarEnv, - lc_scruts = [] } - -bombOutSize = lc_size -\end{code} - - -Programs -~~~~~~~~ \begin{code} -liberateCase :: DynFlags -> [CoreBind] -> IO [CoreBind] -liberateCase dflags binds - = do { - showPass dflags "Liberate case" ; - let { binds' = do_prog (initEnv opt_LiberateCaseThreshold) binds } ; - endPass dflags "Liberate case" Opt_D_verbose_core2core binds' - {- no specific flag for dumping -} - } +liberateCase :: HscEnv -> UniqSupply -> RuleBase -> ModGuts + -> IO (SimplCount, ModGuts) +liberateCase hsc_env _ _ guts + = do { let dflags = hsc_dflags hsc_env + ; eps <- readIORef (hsc_EPS hsc_env) + ; let fam_envs = (eps_fam_inst_env eps, mg_fam_inst_env guts) + + ; showPass dflags "Liberate case" + ; let { env = initEnv dflags fam_envs + ; binds' = do_prog env (mg_binds guts) } + ; endPass dflags "Liberate case" Opt_D_verbose_core2core binds' + {- no specific flag for dumping -} + ; return (zeroSimplCount dflags, guts { mg_binds = binds' }) } where do_prog env [] = [] do_prog env (bind:binds) = bind' : do_prog env' binds @@ -186,9 +187,15 @@ liberateCase dflags binds (env', bind') = libCaseBind env bind \end{code} + +%************************************************************************ +%* * + Main payload +%* * +%************************************************************************ + Bindings ~~~~~~~~ - \begin{code} libCaseBind :: LibCaseEnv -> CoreBind -> (LibCaseEnv, CoreBind) @@ -254,7 +261,7 @@ libCase env (Let bind body) (env_body, bind') = libCaseBind env bind libCase env (Case scrut bndr ty alts) - = Case (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts) + = mkCase env (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts) where env_alts = addBinders (mk_alt_env 
scrut) [bndr] mk_alt_env (Var scrut_var) = addScrutedVar env scrut_var @@ -264,6 +271,24 @@ libCase env (Case scrut bndr ty alts) libCaseAlt env (con,args,rhs) = (con, args, libCase (addBinders env args) rhs) \end{code} +\begin{code} +mkCase :: LibCaseEnv -> CoreExpr -> Id -> Type -> [CoreAlt] -> CoreExpr +-- See Note [Indexed data types] +mkCase env scrut bndr ty [(DEFAULT,_,rhs)] + | Just (tycon, tys) <- splitTyConApp_maybe (idType bndr) + , [(subst, fam_inst)] <- lookupFamInstEnv (lc_fams env) tycon tys + = let + rep_tc = famInstTyCon fam_inst + rep_tys = map (substTyVar subst) (tyConTyVars rep_tc) + bndr' = setIdType bndr (mkTyConApp rep_tc rep_tys) + Just co_tc = tyConFamilyCoercion_maybe rep_tc + co = mkTyConApp co_tc rep_tys + bind = NonRec bndr (Cast (Var bndr') (mkSymCoercion co)) + in mkCase env (Cast scrut co) bndr' ty [(DEFAULT,[],Let bind rhs)] +mkCase env scrut bndr ty alts + = Case scrut bndr ty alts +\end{code} + Ids ~~~ \begin{code} @@ -282,9 +307,12 @@ libCaseId env v \end{code} +%************************************************************************ +%* * + Utility functions +%* * +%************************************************************************ -Utility functions -~~~~~~~~~~~~~~~~~ \begin{code} addBinders :: LibCaseEnv -> [CoreBndr] -> LibCaseEnv addBinders env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env }) binders @@ -335,3 +363,62 @@ freeScruts :: LibCaseEnv freeScruts env rec_bind_lvl = [v | (v,scrut_lvl) <- lc_scruts env, scrut_lvl > rec_bind_lvl] \end{code} + +%************************************************************************ +%* * + The environment +%* * +%************************************************************************ + +\begin{code} +type LibCaseLevel = Int + +topLevel :: LibCaseLevel +topLevel = 0 +\end{code} + +\begin{code} +data LibCaseEnv + = LibCaseEnv { + lc_size :: Int, -- Bomb-out size for deciding if + -- potential liberatees are too big. 
+ -- (passed in from cmd-line args) + + lc_lvl :: LibCaseLevel, -- Current level + + lc_lvl_env :: IdEnv LibCaseLevel, + -- Binds all non-top-level in-scope Ids + -- (top-level and imported things have + -- a level of zero) + + lc_rec_env :: IdEnv CoreBind, + -- Binds *only* recursively defined ids, + -- to their own binding group, + -- and *only* in their own RHSs + + lc_scruts :: [(Id,LibCaseLevel)], + -- Each of these Ids was scrutinised by an + -- enclosing case expression, with the + -- specified number of enclosing + -- recursive bindings; furthermore, + -- the Id is bound at a lower level + -- than the case expression. The order is + -- insignificant; it's a bag really + + lc_fams :: FamInstEnvs + -- Instance env for indexed data types + } + +initEnv :: DynFlags -> FamInstEnvs -> LibCaseEnv +initEnv dflags fams + = LibCaseEnv { lc_size = libCaseThreshold dflags, + lc_lvl = 0, + lc_lvl_env = emptyVarEnv, + lc_rec_env = emptyVarEnv, + lc_scruts = [], + lc_fams = fams } + +bombOutSize = lc_size +\end{code} + + diff --git a/compiler/simplCore/SimplCore.lhs b/compiler/simplCore/SimplCore.lhs index 2fd1026..41e0922 100644 --- a/compiler/simplCore/SimplCore.lhs +++ b/compiler/simplCore/SimplCore.lhs @@ -134,7 +134,7 @@ doCorePasses hsc_env rb us stats guts (to_do : to_dos) doCorePass (CoreDoSimplify mode sws) = _scc_ "Simplify" simplifyPgm mode sws doCorePass CoreCSE = _scc_ "CommonSubExpr" trBinds cseProgram -doCorePass CoreLiberateCase = _scc_ "LiberateCase" trBinds liberateCase +doCorePass CoreLiberateCase = _scc_ "LiberateCase" liberateCase doCorePass CoreDoFloatInwards = _scc_ "FloatInwards" trBinds floatInwards doCorePass (CoreDoFloatOutwards f) = _scc_ "FloatOutwards" trBindsU (floatOutwards f) doCorePass CoreDoStaticArgs = _scc_ "StaticArgs" trBinds doStaticArgs @@ -148,6 +148,8 @@ doCorePass (CoreDoRuleCheck phase pat) = observe (ruleCheck phase pat) doCorePass CoreDoNothing = observe (\ _ _ -> return ()) #ifdef OLD_STRICTNESS doCorePass 
CoreDoOldStrictness = _scc_ "OldStrictness" trBinds doOldStrictness +#else +doCorePass CoreDoOldStrictness = panic "CoreDoOldStrictness" #endif #ifdef OLD_STRICTNESS diff --git a/compiler/typecheck/TcRnDriver.lhs b/compiler/typecheck/TcRnDriver.lhs index aee72c8..eabd3bc 100644 --- a/compiler/typecheck/TcRnDriver.lhs +++ b/compiler/typecheck/TcRnDriver.lhs @@ -302,6 +302,7 @@ tcRnExtCore hsc_env (HsExtCore this_mod decls src_binds) mg_types = final_type_env, mg_insts = tcg_insts tcg_env, mg_fam_insts = tcg_fam_insts tcg_env, + mg_fam_inst_env = tcg_fam_inst_env tcg_env, mg_rules = [], mg_binds = core_binds, -- 1.7.10.4