Make the LiberateCase transformation understand associated types
authorsimonpj@microsoft.com <unknown>
Thu, 11 Jan 2007 09:15:33 +0000 (09:15 +0000)
committersimonpj@microsoft.com <unknown>
Thu, 11 Jan 2007 09:15:33 +0000 (09:15 +0000)
Consider this FC program:
data family AT a :: *
data instance AT Int = T1 Int Int

f :: AT Int -> Int
f t = case t of DEFAULT -> <body>

We'd like to replace the DEFAULT by a use of T1, so that if
we scrutinise t inside <body> we share the evaluation:

f t = case (t `cast` co) of T1 x y -> <body>

I decided to do this as part of the liberate-case transformation,
which is already trying to avoid redundant evals.

The new transformation requires knowledge of the family instance
environment, so I had to extend ModGuts to carry the fam_inst_env,
and put that envt into the liberate-case environment.

Otherwise it's all pretty straightforward.

compiler/deSugar/Desugar.lhs
compiler/main/HscTypes.lhs
compiler/simplCore/LiberateCase.lhs
compiler/simplCore/SimplCore.lhs
compiler/typecheck/TcRnDriver.lhs

index 970bd20..b4ff273 100644 (file)
@@ -60,23 +60,24 @@ deSugar :: HscEnv -> ModLocation -> TcGblEnv -> IO (Maybe ModGuts)
 
 deSugar hsc_env 
         mod_loc
 
 deSugar hsc_env 
         mod_loc
-        tcg_env@(TcGblEnv { tcg_mod       = mod,
-                           tcg_src       = hsc_src,
-                           tcg_type_env  = type_env,
-                           tcg_imports   = imports,
-                           tcg_exports   = exports,
-                           tcg_dus       = dus, 
-                           tcg_inst_uses = dfun_uses_var,
-                           tcg_th_used   = th_var,
-                           tcg_keep      = keep_var,
-                           tcg_rdr_env   = rdr_env,
-                           tcg_fix_env   = fix_env,
-                           tcg_deprecs   = deprecs,
-                           tcg_binds     = binds,
-                           tcg_fords     = fords,
-                           tcg_rules     = rules,
-                           tcg_insts     = insts,
-                           tcg_fam_insts = fam_insts })
+        tcg_env@(TcGblEnv { tcg_mod          = mod,
+                           tcg_src          = hsc_src,
+                           tcg_type_env     = type_env,
+                           tcg_imports      = imports,
+                           tcg_exports      = exports,
+                           tcg_dus          = dus, 
+                           tcg_inst_uses    = dfun_uses_var,
+                           tcg_th_used      = th_var,
+                           tcg_keep         = keep_var,
+                           tcg_rdr_env      = rdr_env,
+                           tcg_fix_env      = fix_env,
+                           tcg_fam_inst_env = fam_inst_env,
+                           tcg_deprecs      = deprecs,
+                           tcg_binds        = binds,
+                           tcg_fords        = fords,
+                           tcg_rules        = rules,
+                           tcg_insts        = insts,
+                           tcg_fam_insts    = fam_insts })
   = do { showPass dflags "Desugar"
 
        -- Desugar the program
   = do { showPass dflags "Desugar"
 
        -- Desugar the program
@@ -156,23 +157,24 @@ deSugar hsc_env
                -- sort to get into canonical order
 
             mod_guts = ModGuts {       
                -- sort to get into canonical order
 
             mod_guts = ModGuts {       
-               mg_module    = mod,
-               mg_boot      = isHsBoot hsc_src,
-               mg_exports   = exports,
-               mg_deps      = deps,
-               mg_usages    = usages,
-               mg_dir_imps  = [m | (m,_,_) <- moduleEnvElts dir_imp_mods],
-               mg_rdr_env   = rdr_env,
-               mg_fix_env   = fix_env,
-               mg_deprecs   = deprecs,
-               mg_types     = type_env,
-               mg_insts     = insts,
-               mg_fam_insts = fam_insts,
-               mg_rules     = ds_rules,
-               mg_binds     = ds_binds,
-               mg_foreign   = ds_fords,
-               mg_hpc_info  = ds_hpc_info,
-                mg_dbg_sites = dbgSites }
+               mg_module       = mod,
+               mg_boot         = isHsBoot hsc_src,
+               mg_exports      = exports,
+               mg_deps         = deps,
+               mg_usages       = usages,
+               mg_dir_imps     = [m | (m,_,_) <- moduleEnvElts dir_imp_mods],
+               mg_rdr_env      = rdr_env,
+               mg_fix_env      = fix_env,
+               mg_deprecs      = deprecs,
+               mg_types        = type_env,
+               mg_insts        = insts,
+               mg_fam_insts    = fam_insts,
+               mg_fam_inst_env = fam_inst_env,
+               mg_rules        = ds_rules,
+               mg_binds        = ds_binds,
+               mg_foreign      = ds_fords,
+               mg_hpc_info     = ds_hpc_info,
+                mg_dbg_sites   = dbgSites }
         ; return (Just mod_guts)
        }}}
 
         ; return (Just mod_guts)
        }}}
 
index 4155807..2b8f8f7 100644 (file)
@@ -485,7 +485,10 @@ data ModGuts
         mg_rdr_env   :: !GlobalRdrEnv,  -- Top-level lexical environment
        mg_fix_env   :: !FixityEnv,      -- Fixity env, for things declared in
                                         --   this module 
         mg_rdr_env   :: !GlobalRdrEnv,  -- Top-level lexical environment
        mg_fix_env   :: !FixityEnv,      -- Fixity env, for things declared in
                                         --   this module 
-       mg_deprecs   :: !Deprecations,   -- Deprecations declared in the module
+
+       mg_fam_inst_env :: FamInstEnv,   -- Type-family instance enviroment
+                                        -- for *home-package* modules (including
+                                        -- this one).  c.f. tcg_fam_inst_env
 
        mg_types     :: !TypeEnv,
        mg_insts     :: ![Instance],     -- Instances 
 
        mg_types     :: !TypeEnv,
        mg_insts     :: ![Instance],     -- Instances 
@@ -493,6 +496,7 @@ data ModGuts
         mg_rules     :: ![CoreRule],    -- Rules from this module
        mg_binds     :: ![CoreBind],     -- Bindings for this module
        mg_foreign   :: !ForeignStubs,
         mg_rules     :: ![CoreRule],    -- Rules from this module
        mg_binds     :: ![CoreBind],     -- Bindings for this module
        mg_foreign   :: !ForeignStubs,
+       mg_deprecs   :: !Deprecations,   -- Deprecations declared in the module
        mg_hpc_info  :: !HpcInfo,        -- info about coverage tick boxes
         mg_dbg_sites :: ![(SiteNumber, Coord)]     -- Bkpts inserted by the renamer
     }
        mg_hpc_info  :: !HpcInfo,        -- info about coverage tick boxes
         mg_dbg_sites :: ![(SiteNumber, Coord)]     -- Bkpts inserted by the renamer
     }
index 67d2e5c..31063d3 100644 (file)
@@ -8,18 +8,28 @@ module LiberateCase ( liberateCase ) where
 
 #include "HsVersions.h"
 
 
 #include "HsVersions.h"
 
-import DynFlags                ( DynFlags, DynFlag(..) )
-import StaticFlags     ( opt_LiberateCaseThreshold )
+import DynFlags
+import HscTypes
 import CoreLint                ( showPass, endPass )
 import CoreSyn
 import CoreUnfold      ( couldBeSmallEnoughToInline )
 import CoreLint                ( showPass, endPass )
 import CoreSyn
 import CoreUnfold      ( couldBeSmallEnoughToInline )
-import Id              ( Id, setIdName, idName, setIdNotExported )
+import Rules           ( RuleBase )
+import UniqSupply      ( UniqSupply )
+import SimplMonad      ( SimplCount, zeroSimplCount )
+import Id
+import FamInstEnv
+import Type
+import Coercion
+import TyCon
 import VarEnv
 import Name            ( localiseName )
 import Outputable
 import Util             ( notNull )
 import VarEnv
 import Name            ( localiseName )
 import Outputable
 import Util             ( notNull )
+import Data.IORef      ( readIORef )
 \end{code}
 
 \end{code}
 
+The liberate-case transformation
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 This module walks over @Core@, and looks for @case@ on free variables.
 The criterion is:
        if there is case on a free on the route to the recursive call,
 This module walks over @Core@, and looks for @case@ on free variables.
 The criterion is:
        if there is case on a free on the route to the recursive call,
@@ -27,30 +37,24 @@ The criterion is:
 
 Example
 
 
 Example
 
-\begin{verbatim}
-f = \ t -> case v of
-              V a b -> a : f t
-\end{verbatim}
+   f = \ t -> case v of
+                V a b -> a : f t
 
 => the inner f is replaced.
 
 
 => the inner f is replaced.
 
-\begin{verbatim}
-f = \ t -> case v of
-              V a b -> a : (letrec
+   f = \ t -> case v of
+                V a b -> a : (letrec
                                f =  \ t -> case v of
                                               V a b -> a : f t
                                f =  \ t -> case v of
                                               V a b -> a : f t
-                            in f) t
-\end{verbatim}
+                              in f) t
 (note the NEED for shadowing)
 
 => Simplify
 
 (note the NEED for shadowing)
 
 => Simplify
 
-\begin{verbatim}
-f = \ t -> case v of
-              V a b -> a : (letrec
+  f = \ t -> case v of
+                V a b -> a : (letrec
                                f = \ t -> a : f t
                                f = \ t -> a : f t
-                            in f t)
-\begin{verbatim}
+                              in f t)
 
 Better code, because 'a' is  free inside the inner letrec, rather
 than needing projection from v.
 
 Better code, because 'a' is  free inside the inner letrec, rather
 than needing projection from v.
@@ -72,7 +76,6 @@ We'd like to avoid the redundant pattern match, transforming to
 
        (is this necessarily an improvement)
 
 
        (is this necessarily an improvement)
 
-
 Similarly drop:
 
        drop n [] = []
 Similarly drop:
 
        drop n [] = []
@@ -119,66 +122,64 @@ scope.  For example:
 Here, the level of @f@ is zero, the level of @g@ is one,
 and the level of @h@ is zero (NB not one).
 
 Here, the level of @f@ is zero, the level of @g@ is one,
 and the level of @h@ is zero (NB not one).
 
-\begin{code}
-type LibCaseLevel = Int
-
-topLevel :: LibCaseLevel
-topLevel = 0
-\end{code}
-
-\begin{code}
-data LibCaseEnv
-  = LibCaseEnv {
-       lc_size :: Int,         -- Bomb-out size for deciding if
-                               -- potential liberatees are too big.
-                               -- (passed in from cmd-line args)
-
-       lc_lvl :: LibCaseLevel, -- Current level
-
-       lc_lvl_env :: IdEnv LibCaseLevel,  
-                       -- Binds all non-top-level in-scope Ids
-                       -- (top-level and imported things have
-                       -- a level of zero)
-
-       lc_rec_env :: IdEnv CoreBind, 
-                       -- Binds *only* recursively defined ids, 
-                       -- to their own binding group,
-                       -- and *only* in their own RHSs
+Note [Indexed data types]
+~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider
+       data family T :: * -> *
+       data T Int = TI Int
+
+       f :: T Int -> Bool
+       f x = case x of { DEFAULT -> <body> }
+
+We would like to change this to
+       f x = case x `cast` co of { TI p -> <body> }
+
+so that <body> can make use of the fact that x is already evaluated to
+a TI; and a case on a known data type may be more efficient than a
+polymorphic one (not sure this is true any longer).  Anyway the former
+showed up in Roman's experiments.  Example:
+  foo :: FooT Int -> Int -> Int
+  foo t n = t `seq` bar n
+     where
+       bar 0 = 0
+       bar n = bar (n - case t of TI i -> i)
+Here we'd like to avoid repeated evaluating t inside the loop, by 
+taking advantage of the `seq`.
+
+We implement this as part of the liberate-case transformation by 
+spotting
+       case <scrut> of (x::T) tys { DEFAULT ->  <body> }
+where x :: T tys, and T is a indexed family tycon.  Find the
+representation type (T77 tys'), and coercion co, and transform to
+       case <scrut> `cast` co of (y::T77 tys')
+           DEFAULT -> let x = y `cast` sym co in <body>
+
+The "find the representation type" part is done by looking up in the
+family-instance environment.
+
+NB: in fact we re-use x (changing its type) to avoid making a fresh y;
+this entails shadowing, but that's ok.
+
+%************************************************************************
+%*                                                                     *
+        Top-level code
+%*                                                                     *
+%************************************************************************
 
 
-       lc_scruts :: [(Id,LibCaseLevel)]
-                       -- Each of these Ids was scrutinised by an
-                       -- enclosing case expression, with the
-                       -- specified number of enclosing
-                       -- recursive bindings; furthermore,
-                       -- the Id is bound at a lower level
-                       -- than the case expression.  The order is
-                       -- insignificant; it's a bag really
-
---     lc_fams :: FamInstEnvs
-                       -- Instance env for indexed data types 
-       }
-
-initEnv :: Int -> LibCaseEnv
-initEnv bomb_size
-  = LibCaseEnv { lc_size = bomb_size, lc_lvl = 0,
-                lc_lvl_env = emptyVarEnv, lc_rec_env = emptyVarEnv,
-                lc_scruts = [] }
-
-bombOutSize = lc_size
-\end{code}
-
-
-Programs
-~~~~~~~~
 \begin{code}
 \begin{code}
-liberateCase :: DynFlags -> [CoreBind] -> IO [CoreBind]
-liberateCase dflags binds
-  = do {
-       showPass dflags "Liberate case" ;
-       let { binds' = do_prog (initEnv opt_LiberateCaseThreshold) binds } ;
-       endPass dflags "Liberate case" Opt_D_verbose_core2core binds'
-                               {- no specific flag for dumping -} 
-    }
+liberateCase :: HscEnv -> UniqSupply -> RuleBase -> ModGuts
+            -> IO (SimplCount, ModGuts)
+liberateCase hsc_env _ _ guts
+  = do { let dflags = hsc_dflags hsc_env
+       ; eps <- readIORef (hsc_EPS hsc_env)
+       ; let fam_envs = (eps_fam_inst_env eps, mg_fam_inst_env guts)
+
+       ; showPass dflags "Liberate case"
+       ; let { env = initEnv dflags fam_envs
+             ; binds' = do_prog env (mg_binds guts) }
+       ; endPass dflags "Liberate case" Opt_D_verbose_core2core binds'
+                       {- no specific flag for dumping -} 
+       ; return (zeroSimplCount dflags, guts { mg_binds = binds' }) }
   where
     do_prog env [] = []
     do_prog env (bind:binds) = bind' : do_prog env' binds
   where
     do_prog env [] = []
     do_prog env (bind:binds) = bind' : do_prog env' binds
@@ -186,9 +187,15 @@ liberateCase dflags binds
                               (env', bind') = libCaseBind env bind
 \end{code}
 
                               (env', bind') = libCaseBind env bind
 \end{code}
 
+
+%************************************************************************
+%*                                                                     *
+        Main payload
+%*                                                                     *
+%************************************************************************
+
 Bindings
 ~~~~~~~~
 Bindings
 ~~~~~~~~
-
 \begin{code}
 libCaseBind :: LibCaseEnv -> CoreBind -> (LibCaseEnv, CoreBind)
 
 \begin{code}
 libCaseBind :: LibCaseEnv -> CoreBind -> (LibCaseEnv, CoreBind)
 
@@ -254,7 +261,7 @@ libCase env (Let bind body)
     (env_body, bind') = libCaseBind env bind
 
 libCase env (Case scrut bndr ty alts)
     (env_body, bind') = libCaseBind env bind
 
 libCase env (Case scrut bndr ty alts)
-  = Case (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts)
+  = mkCase env (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts)
   where
     env_alts = addBinders (mk_alt_env scrut) [bndr]
     mk_alt_env (Var scrut_var) = addScrutedVar env scrut_var
   where
     env_alts = addBinders (mk_alt_env scrut) [bndr]
     mk_alt_env (Var scrut_var) = addScrutedVar env scrut_var
@@ -264,6 +271,24 @@ libCase env (Case scrut bndr ty alts)
 libCaseAlt env (con,args,rhs) = (con, args, libCase (addBinders env args) rhs)
 \end{code}
 
 libCaseAlt env (con,args,rhs) = (con, args, libCase (addBinders env args) rhs)
 \end{code}
 
+\begin{code}
+mkCase :: LibCaseEnv -> CoreExpr -> Id -> Type -> [CoreAlt] -> CoreExpr
+-- See Note [Indexed data types]
+mkCase env scrut bndr ty [(DEFAULT,_,rhs)]
+  | Just (tycon, tys)   <- splitTyConApp_maybe (idType bndr)
+  , [(subst, fam_inst)] <- lookupFamInstEnv (lc_fams env) tycon tys
+  = let 
+       rep_tc     = famInstTyCon fam_inst
+       rep_tys    = map (substTyVar subst) (tyConTyVars rep_tc)
+       bndr'      = setIdType bndr (mkTyConApp rep_tc rep_tys)
+       Just co_tc = tyConFamilyCoercion_maybe rep_tc
+       co         = mkTyConApp co_tc rep_tys
+       bind       = NonRec bndr (Cast (Var bndr') (mkSymCoercion co))
+    in mkCase env (Cast scrut co) bndr' ty [(DEFAULT,[],Let bind rhs)]
+mkCase env scrut bndr ty alts
+  = Case scrut bndr ty alts
+\end{code}
+
 Ids
 ~~~
 \begin{code}
 Ids
 ~~~
 \begin{code}
@@ -282,9 +307,12 @@ libCaseId env v
 \end{code}
 
 
 \end{code}
 
 
+%************************************************************************
+%*                                                                     *
+       Utility functions
+%*                                                                     *
+%************************************************************************
 
 
-Utility functions
-~~~~~~~~~~~~~~~~~
 \begin{code}
 addBinders :: LibCaseEnv -> [CoreBndr] -> LibCaseEnv
 addBinders env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env }) binders
 \begin{code}
 addBinders :: LibCaseEnv -> [CoreBndr] -> LibCaseEnv
 addBinders env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env }) binders
@@ -335,3 +363,62 @@ freeScruts :: LibCaseEnv
 freeScruts env rec_bind_lvl
   = [v | (v,scrut_lvl) <- lc_scruts env, scrut_lvl > rec_bind_lvl]
 \end{code}
 freeScruts env rec_bind_lvl
   = [v | (v,scrut_lvl) <- lc_scruts env, scrut_lvl > rec_bind_lvl]
 \end{code}
+
+%************************************************************************
+%*                                                                     *
+        The environment
+%*                                                                     *
+%************************************************************************
+
+\begin{code}
+type LibCaseLevel = Int
+
+topLevel :: LibCaseLevel
+topLevel = 0
+\end{code}
+
+\begin{code}
+data LibCaseEnv
+  = LibCaseEnv {
+       lc_size :: Int,         -- Bomb-out size for deciding if
+                               -- potential liberatees are too big.
+                               -- (passed in from cmd-line args)
+
+       lc_lvl :: LibCaseLevel, -- Current level
+
+       lc_lvl_env :: IdEnv LibCaseLevel,  
+                       -- Binds all non-top-level in-scope Ids
+                       -- (top-level and imported things have
+                       -- a level of zero)
+
+       lc_rec_env :: IdEnv CoreBind, 
+                       -- Binds *only* recursively defined ids, 
+                       -- to their own binding group,
+                       -- and *only* in their own RHSs
+
+       lc_scruts :: [(Id,LibCaseLevel)],
+                       -- Each of these Ids was scrutinised by an
+                       -- enclosing case expression, with the
+                       -- specified number of enclosing
+                       -- recursive bindings; furthermore,
+                       -- the Id is bound at a lower level
+                       -- than the case expression.  The order is
+                       -- insignificant; it's a bag really
+
+       lc_fams :: FamInstEnvs
+                       -- Instance env for indexed data types 
+       }
+
+initEnv :: DynFlags -> FamInstEnvs -> LibCaseEnv
+initEnv dflags fams
+  = LibCaseEnv { lc_size = libCaseThreshold dflags,
+                lc_lvl = 0,
+                lc_lvl_env = emptyVarEnv, 
+                lc_rec_env = emptyVarEnv,
+                lc_scruts = [],
+                lc_fams = fams }
+
+bombOutSize = lc_size
+\end{code}
+
+
index 2fd1026..41e0922 100644 (file)
@@ -134,7 +134,7 @@ doCorePasses hsc_env rb us stats guts (to_do : to_dos)
 
 doCorePass (CoreDoSimplify mode sws)   = _scc_ "Simplify"      simplifyPgm mode sws
 doCorePass CoreCSE                    = _scc_ "CommonSubExpr" trBinds  cseProgram
 
 doCorePass (CoreDoSimplify mode sws)   = _scc_ "Simplify"      simplifyPgm mode sws
 doCorePass CoreCSE                    = _scc_ "CommonSubExpr" trBinds  cseProgram
-doCorePass CoreLiberateCase           = _scc_ "LiberateCase"  trBinds  liberateCase
+doCorePass CoreLiberateCase           = _scc_ "LiberateCase"  liberateCase
 doCorePass CoreDoFloatInwards          = _scc_ "FloatInwards"  trBinds  floatInwards
 doCorePass (CoreDoFloatOutwards f)     = _scc_ "FloatOutwards" trBindsU (floatOutwards f)
 doCorePass CoreDoStaticArgs           = _scc_ "StaticArgs"    trBinds  doStaticArgs
 doCorePass CoreDoFloatInwards          = _scc_ "FloatInwards"  trBinds  floatInwards
 doCorePass (CoreDoFloatOutwards f)     = _scc_ "FloatOutwards" trBindsU (floatOutwards f)
 doCorePass CoreDoStaticArgs           = _scc_ "StaticArgs"    trBinds  doStaticArgs
@@ -148,6 +148,8 @@ doCorePass (CoreDoRuleCheck phase pat) = observe (ruleCheck phase pat)
 doCorePass CoreDoNothing              = observe (\ _ _ -> return ())
 #ifdef OLD_STRICTNESS                 
 doCorePass CoreDoOldStrictness        = _scc_ "OldStrictness" trBinds doOldStrictness
 doCorePass CoreDoNothing              = observe (\ _ _ -> return ())
 #ifdef OLD_STRICTNESS                 
 doCorePass CoreDoOldStrictness        = _scc_ "OldStrictness" trBinds doOldStrictness
+#else
+doCorePass CoreDoOldStrictness        = panic "CoreDoOldStrictness"
 #endif
 
 #ifdef OLD_STRICTNESS
 #endif
 
 #ifdef OLD_STRICTNESS
index aee72c8..eabd3bc 100644 (file)
@@ -302,6 +302,7 @@ tcRnExtCore hsc_env (HsExtCore this_mod decls src_binds)
                                mg_types     = final_type_env,
                                mg_insts     = tcg_insts tcg_env,
                                mg_fam_insts = tcg_fam_insts tcg_env,
                                mg_types     = final_type_env,
                                mg_insts     = tcg_insts tcg_env,
                                mg_fam_insts = tcg_fam_insts tcg_env,
+                               mg_fam_inst_env = tcg_fam_inst_env tcg_env,
                                mg_rules     = [],
                                mg_binds     = core_binds,
 
                                mg_rules     = [],
                                mg_binds     = core_binds,