One more wibble to FloatOut, fixes HEAD breakage (I hope)
[ghc-hetmet.git] / compiler / simplCore / LiberateCase.lhs
index c29a5b9..9b15734 100644 (file)
@@ -8,18 +8,28 @@ module LiberateCase ( liberateCase ) where
 
 #include "HsVersions.h"
 
 
 #include "HsVersions.h"
 
-import DynFlags                ( DynFlags, DynFlag(..) )
-import StaticFlags     ( opt_LiberateCaseThreshold )
+import DynFlags
+import HscTypes
 import CoreLint                ( showPass, endPass )
 import CoreSyn
 import CoreUnfold      ( couldBeSmallEnoughToInline )
 import CoreLint                ( showPass, endPass )
 import CoreSyn
 import CoreUnfold      ( couldBeSmallEnoughToInline )
-import Id              ( Id, setIdName, idName, setIdNotExported )
+import Rules           ( RuleBase )
+import UniqSupply      ( UniqSupply )
+import SimplMonad      ( SimplCount, zeroSimplCount )
+import Id
+import FamInstEnv
+import Type
+import Coercion
+import TyCon
 import VarEnv
 import Name            ( localiseName )
 import Outputable
 import Util             ( notNull )
 import VarEnv
 import Name            ( localiseName )
 import Outputable
 import Util             ( notNull )
+import Data.IORef      ( readIORef )
 \end{code}
 
 \end{code}
 
+The liberate-case transformation
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 This module walks over @Core@, and looks for @case@ on free variables.
 The criterion is:
        if there is case on a free on the route to the recursive call,
 This module walks over @Core@, and looks for @case@ on free variables.
 The criterion is:
        if there is case on a free on the route to the recursive call,
@@ -27,30 +37,24 @@ The criterion is:
 
 Example
 
 
 Example
 
-\begin{verbatim}
-f = \ t -> case v of
-              V a b -> a : f t
-\end{verbatim}
+   f = \ t -> case v of
+                V a b -> a : f t
 
 => the inner f is replaced.
 
 
 => the inner f is replaced.
 
-\begin{verbatim}
-f = \ t -> case v of
-              V a b -> a : (letrec
+   f = \ t -> case v of
+                V a b -> a : (letrec
                                f =  \ t -> case v of
                                               V a b -> a : f t
                                f =  \ t -> case v of
                                               V a b -> a : f t
-                            in f) t
-\end{verbatim}
+                              in f) t
 (note the NEED for shadowing)
 
 => Simplify
 
 (note the NEED for shadowing)
 
 => Simplify
 
-\begin{verbatim}
-f = \ t -> case v of
-              V a b -> a : (letrec
+  f = \ t -> case v of
+                V a b -> a : (letrec
                                f = \ t -> a : f t
                                f = \ t -> a : f t
-                            in f t)
-\begin{verbatim}
+                              in f t)
 
 Better code, because 'a' is  free inside the inner letrec, rather
 than needing projection from v.
 
 Better code, because 'a' is  free inside the inner letrec, rather
 than needing projection from v.
@@ -72,7 +76,6 @@ We'd like to avoid the redundant pattern match, transforming to
 
        (is this necessarily an improvement)
 
 
        (is this necessarily an improvement)
 
-
 Similarly drop:
 
        drop n [] = []
 Similarly drop:
 
        drop n [] = []
@@ -81,6 +84,15 @@ Similarly drop:
 
 Would like to pass n along unboxed.
        
 
 Would like to pass n along unboxed.
        
+Note [Scrutinee with cast]
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider this:
+    f = \ t -> case (v `cast` co) of
+                V a b -> a : f t
+
+Exactly the same optimisation (unrolling one call to f) will work here, 
+despite the cast.  See mk_alt_env in the Case branch of libCase.
+
 
 To think about (Apr 94)
 ~~~~~~~~~~~~~~
 
 To think about (Apr 94)
 ~~~~~~~~~~~~~~
@@ -96,7 +108,6 @@ big.
 
 Data types
 ~~~~~~~~~~
 
 Data types
 ~~~~~~~~~~
-
 The ``level'' of a binder tells how many
 recursive defns lexically enclose the binding
 A recursive defn "encloses" its RHS, not its
 The ``level'' of a binder tells how many
 recursive defns lexically enclose the binding
 A recursive defn "encloses" its RHS, not its
@@ -110,57 +121,64 @@ scope.  For example:
 Here, the level of @f@ is zero, the level of @g@ is one,
 and the level of @h@ is zero (NB not one).
 
 Here, the level of @f@ is zero, the level of @g@ is one,
 and the level of @h@ is zero (NB not one).
 
-\begin{code}
-type LibCaseLevel = Int
-
-topLevel :: LibCaseLevel
-topLevel = 0
-\end{code}
+Note [Indexed data types]
+~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider
+       data family T :: * -> *
+       data T Int = TI Int
+
+       f :: T Int -> Bool
+       f x = case x of { DEFAULT -> <body> }
+
+We would like to change this to
+       f x = case x `cast` co of { TI p -> <body> }
+
+so that <body> can make use of the fact that x is already evaluated to
+a TI; and a case on a known data type may be more efficient than a
+polymorphic one (not sure this is true any longer).  Anyway the former
+showed up in Roman's experiments.  Example:
+  foo :: FooT Int -> Int -> Int
+  foo t n = t `seq` bar n
+     where
+       bar 0 = 0
+       bar n = bar (n - case t of TI i -> i)
+Here we'd like to avoid repeated evaluating t inside the loop, by 
+taking advantage of the `seq`.
+
+We implement this as part of the liberate-case transformation by 
+spotting
+       case <scrut> of (x::T) tys { DEFAULT ->  <body> }
+where x :: T tys, and T is a indexed family tycon.  Find the
+representation type (T77 tys'), and coercion co, and transform to
+       case <scrut> `cast` co of (y::T77 tys')
+           DEFAULT -> let x = y `cast` sym co in <body>
+
+The "find the representation type" part is done by looking up in the
+family-instance environment.
+
+NB: in fact we re-use x (changing its type) to avoid making a fresh y;
+this entails shadowing, but that's ok.
+
+%************************************************************************
+%*                                                                     *
+        Top-level code
+%*                                                                     *
+%************************************************************************
 
 \begin{code}
 
 \begin{code}
-data LibCaseEnv
-  = LibCaseEnv
-       Int                     -- Bomb-out size for deciding if
-                               -- potential liberatees are too big.
-                               -- (passed in from cmd-line args)
-
-       LibCaseLevel            -- Current level
-
-       (IdEnv LibCaseLevel)    -- Binds all non-top-level in-scope Ids
-                               -- (top-level and imported things have
-                               -- a level of zero)
-
-       (IdEnv CoreBind)        -- Binds *only* recursively defined
-                               -- Ids, to their own binding group,
-                               -- and *only* in their own RHSs
-
-       [(Id,LibCaseLevel)]     -- Each of these Ids was scrutinised by an
-                               -- enclosing case expression, with the
-                               -- specified number of enclosing
-                               -- recursive bindings; furthermore,
-                               -- the Id is bound at a lower level
-                               -- than the case expression.  The
-                               -- order is insignificant; it's a bag
-                               -- really
-
-initEnv :: Int -> LibCaseEnv
-initEnv bomb_size = LibCaseEnv bomb_size 0 emptyVarEnv emptyVarEnv []
-
-bombOutSize (LibCaseEnv bomb_size _ _ _ _) = bomb_size
-\end{code}
-
-
-Programs
-~~~~~~~~
-\begin{code}
-liberateCase :: DynFlags -> [CoreBind] -> IO [CoreBind]
-liberateCase dflags binds
-  = do {
-       showPass dflags "Liberate case" ;
-       let { binds' = do_prog (initEnv opt_LiberateCaseThreshold) binds } ;
-       endPass dflags "Liberate case" Opt_D_verbose_core2core binds'
-                               {- no specific flag for dumping -} 
-    }
+liberateCase :: HscEnv -> UniqSupply -> RuleBase -> ModGuts
+            -> IO (SimplCount, ModGuts)
+liberateCase hsc_env _ _ guts
+  = do { let dflags = hsc_dflags hsc_env
+       ; eps <- readIORef (hsc_EPS hsc_env)
+       ; let fam_envs = (eps_fam_inst_env eps, mg_fam_inst_env guts)
+
+       ; showPass dflags "Liberate case"
+       ; let { env = initEnv dflags fam_envs
+             ; binds' = do_prog env (mg_binds guts) }
+       ; endPass dflags "Liberate case" Opt_D_verbose_core2core binds'
+                       {- no specific flag for dumping -} 
+       ; return (zeroSimplCount dflags, guts { mg_binds = binds' }) }
   where
     do_prog env [] = []
     do_prog env (bind:binds) = bind' : do_prog env' binds
   where
     do_prog env [] = []
     do_prog env (bind:binds) = bind' : do_prog env' binds
@@ -168,9 +186,15 @@ liberateCase dflags binds
                               (env', bind') = libCaseBind env bind
 \end{code}
 
                               (env', bind') = libCaseBind env bind
 \end{code}
 
+
+%************************************************************************
+%*                                                                     *
+        Main payload
+%*                                                                     *
+%************************************************************************
+
 Bindings
 ~~~~~~~~
 Bindings
 ~~~~~~~~
-
 \begin{code}
 libCaseBind :: LibCaseEnv -> CoreBind -> (LibCaseEnv, CoreBind)
 
 \begin{code}
 libCaseBind :: LibCaseEnv -> CoreBind -> (LibCaseEnv, CoreBind)
 
@@ -225,6 +249,7 @@ libCase env (Lit lit)               = Lit lit
 libCase env (Type ty)          = Type ty
 libCase env (App fun arg)       = App (libCase env fun) (libCase env arg)
 libCase env (Note note body)    = Note note (libCase env body)
 libCase env (Type ty)          = Type ty
 libCase env (App fun arg)       = App (libCase env fun) (libCase env arg)
 libCase env (Note note body)    = Note note (libCase env body)
+libCase env (Cast e co)         = Cast (libCase env e) co
 
 libCase env (Lam binder body)
   = Lam binder (libCase (addBinders env [binder]) body)
 
 libCase env (Lam binder body)
   = Lam binder (libCase (addBinders env [binder]) body)
@@ -235,16 +260,34 @@ libCase env (Let bind body)
     (env_body, bind') = libCaseBind env bind
 
 libCase env (Case scrut bndr ty alts)
     (env_body, bind') = libCaseBind env bind
 
 libCase env (Case scrut bndr ty alts)
-  = Case (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts)
+  = mkCase env (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts)
   where
   where
-    env_alts = addBinders env_with_scrut [bndr]
-    env_with_scrut = case scrut of
-                       Var scrut_var -> addScrutedVar env scrut_var
-                       other         -> env
+    env_alts = addBinders (mk_alt_env scrut) [bndr]
+    mk_alt_env (Var scrut_var) = addScrutedVar env scrut_var
+    mk_alt_env (Cast scrut _)  = mk_alt_env scrut      -- Note [Scrutinee with cast]
+    mk_alt_env otehr          = env
 
 libCaseAlt env (con,args,rhs) = (con, args, libCase (addBinders env args) rhs)
 \end{code}
 
 
 libCaseAlt env (con,args,rhs) = (con, args, libCase (addBinders env args) rhs)
 \end{code}
 
+\begin{code}
+mkCase :: LibCaseEnv -> CoreExpr -> Id -> Type -> [CoreAlt] -> CoreExpr
+-- See Note [Indexed data types]
+mkCase env scrut bndr ty [(DEFAULT,_,rhs)]
+  | Just (tycon, tys)   <- splitTyConApp_maybe (idType bndr)
+  , [(subst, fam_inst)] <- lookupFamInstEnv (lc_fams env) tycon tys
+  = let 
+       rep_tc     = famInstTyCon fam_inst
+       rep_tys    = map (substTyVar subst) (tyConTyVars rep_tc)
+       bndr'      = setIdType bndr (mkTyConApp rep_tc rep_tys)
+       Just co_tc = tyConFamilyCoercion_maybe rep_tc
+       co         = mkTyConApp co_tc rep_tys
+       bind       = NonRec bndr (Cast (Var bndr') (mkSymCoercion co))
+    in mkCase env (Cast scrut co) bndr' ty [(DEFAULT,[],Let bind rhs)]
+mkCase env scrut bndr ty alts
+  = Case scrut bndr ty alts
+\end{code}
+
 Ids
 ~~~
 \begin{code}
 Ids
 ~~~
 \begin{code}
@@ -263,19 +306,23 @@ libCaseId env v
 \end{code}
 
 
 \end{code}
 
 
+%************************************************************************
+%*                                                                     *
+       Utility functions
+%*                                                                     *
+%************************************************************************
 
 
-Utility functions
-~~~~~~~~~~~~~~~~~
 \begin{code}
 addBinders :: LibCaseEnv -> [CoreBndr] -> LibCaseEnv
 \begin{code}
 addBinders :: LibCaseEnv -> [CoreBndr] -> LibCaseEnv
-addBinders (LibCaseEnv bomb lvl lvl_env rec_env scruts) binders
-  = LibCaseEnv bomb lvl lvl_env' rec_env scruts
+addBinders env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env }) binders
+  = env { lc_lvl_env = lvl_env' }
   where
     lvl_env' = extendVarEnvList lvl_env (binders `zip` repeat lvl)
 
 addRecBinds :: LibCaseEnv -> [(Id,CoreExpr)] -> LibCaseEnv
   where
     lvl_env' = extendVarEnvList lvl_env (binders `zip` repeat lvl)
 
 addRecBinds :: LibCaseEnv -> [(Id,CoreExpr)] -> LibCaseEnv
-addRecBinds (LibCaseEnv bomb lvl lvl_env rec_env scruts) pairs
-  = LibCaseEnv bomb lvl' lvl_env' rec_env' scruts
+addRecBinds env@(LibCaseEnv {lc_lvl = lvl, lc_lvl_env = lvl_env, 
+                            lc_rec_env = rec_env}) pairs
+  = env { lc_lvl = lvl', lc_lvl_env = lvl_env', lc_rec_env = rec_env' }
   where
     lvl'     = lvl + 1
     lvl_env' = extendVarEnvList lvl_env [(binder,lvl) | (binder,_) <- pairs]
   where
     lvl'     = lvl + 1
     lvl_env' = extendVarEnvList lvl_env [(binder,lvl) | (binder,_) <- pairs]
@@ -285,9 +332,10 @@ addScrutedVar :: LibCaseEnv
              -> Id             -- This Id is being scrutinised by a case expression
              -> LibCaseEnv
 
              -> Id             -- This Id is being scrutinised by a case expression
              -> LibCaseEnv
 
-addScrutedVar env@(LibCaseEnv bomb lvl lvl_env rec_env scruts) scrut_var
+addScrutedVar env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env, 
+                               lc_scruts = scruts }) scrut_var
   | bind_lvl < lvl
   | bind_lvl < lvl
-  = LibCaseEnv bomb lvl lvl_env rec_env scruts'
+  = env { lc_scruts = scruts' }
        -- Add to scruts iff the scrut_var is being scrutinised at
        -- a deeper level than its defn
 
        -- Add to scruts iff the scrut_var is being scrutinised at
        -- a deeper level than its defn
 
@@ -299,19 +347,77 @@ addScrutedVar env@(LibCaseEnv bomb lvl lvl_env rec_env scruts) scrut_var
                 Nothing  -> topLevel
 
 lookupRecId :: LibCaseEnv -> Id -> Maybe CoreBind
                 Nothing  -> topLevel
 
 lookupRecId :: LibCaseEnv -> Id -> Maybe CoreBind
-lookupRecId (LibCaseEnv bomb lvl lvl_env rec_env scruts) id
-  = lookupVarEnv rec_env id
+lookupRecId env id = lookupVarEnv (lc_rec_env env) id
 
 lookupLevel :: LibCaseEnv -> Id -> LibCaseLevel
 
 lookupLevel :: LibCaseEnv -> Id -> LibCaseLevel
-lookupLevel (LibCaseEnv bomb lvl lvl_env rec_env scruts) id
-  = case lookupVarEnv lvl_env id of
-      Just lvl -> lvl
+lookupLevel env id
+  = case lookupVarEnv (lc_lvl_env env) id of
+      Just lvl -> lc_lvl env
       Nothing  -> topLevel
 
 freeScruts :: LibCaseEnv
           -> LibCaseLevel      -- Level of the recursive Id
           -> [Id]              -- Ids that are scrutinised between the binding
                                -- of the recursive Id and here
       Nothing  -> topLevel
 
 freeScruts :: LibCaseEnv
           -> LibCaseLevel      -- Level of the recursive Id
           -> [Id]              -- Ids that are scrutinised between the binding
                                -- of the recursive Id and here
-freeScruts (LibCaseEnv bomb lvl lvl_env rec_env scruts) rec_bind_lvl
-  = [v | (v,scrut_lvl) <- scruts, scrut_lvl > rec_bind_lvl]
+freeScruts env rec_bind_lvl
+  = [v | (v,scrut_lvl) <- lc_scruts env, scrut_lvl > rec_bind_lvl]
+\end{code}
+
+%************************************************************************
+%*                                                                     *
+        The environment
+%*                                                                     *
+%************************************************************************
+
+\begin{code}
+type LibCaseLevel = Int
+
+topLevel :: LibCaseLevel
+topLevel = 0
 \end{code}
 \end{code}
+
+\begin{code}
+data LibCaseEnv
+  = LibCaseEnv {
+       lc_size :: Int,         -- Bomb-out size for deciding if
+                               -- potential liberatees are too big.
+                               -- (passed in from cmd-line args)
+
+       lc_lvl :: LibCaseLevel, -- Current level
+
+       lc_lvl_env :: IdEnv LibCaseLevel,  
+                       -- Binds all non-top-level in-scope Ids
+                       -- (top-level and imported things have
+                       -- a level of zero)
+
+       lc_rec_env :: IdEnv CoreBind, 
+                       -- Binds *only* recursively defined ids, 
+                       -- to their own binding group,
+                       -- and *only* in their own RHSs
+
+       lc_scruts :: [(Id,LibCaseLevel)],
+                       -- Each of these Ids was scrutinised by an
+                       -- enclosing case expression, with the
+                       -- specified number of enclosing
+                       -- recursive bindings; furthermore,
+                       -- the Id is bound at a lower level
+                       -- than the case expression.  The order is
+                       -- insignificant; it's a bag really
+
+       lc_fams :: FamInstEnvs
+                       -- Instance env for indexed data types 
+       }
+
+initEnv :: DynFlags -> FamInstEnvs -> LibCaseEnv
+initEnv dflags fams
+  = LibCaseEnv { lc_size = libCaseThreshold dflags,
+                lc_lvl = 0,
+                lc_lvl_env = emptyVarEnv, 
+                lc_rec_env = emptyVarEnv,
+                lc_scruts = [],
+                lc_fams = fams }
+
+bombOutSize = lc_size
+\end{code}
+
+