Warning fix for unused and redundant imports
[ghc-hetmet.git] / compiler / simplCore / LiberateCase.lhs
index eebb11c..02a3fab 100644 (file)
@@ -8,18 +8,27 @@ module LiberateCase ( liberateCase ) where
 
 #include "HsVersions.h"
 
-import DynFlags                ( DynFlags, DynFlag(..) )
-import StaticFlags     ( opt_LiberateCaseThreshold )
+import DynFlags
+import HscTypes
 import CoreLint                ( showPass, endPass )
 import CoreSyn
 import CoreUnfold      ( couldBeSmallEnoughToInline )
-import Id              ( Id, setIdName, idName, setIdNotExported )
+import Rules           ( RuleBase )
+import UniqSupply      ( UniqSupply )
+import SimplMonad      ( SimplCount, zeroSimplCount )
+import Id
+import FamInstEnv
+import Type
+import Coercion
+import TyCon
 import VarEnv
 import Name            ( localiseName )
-import Outputable
 import Util             ( notNull )
+import Data.IORef      ( readIORef )
 \end{code}
 
+The liberate-case transformation
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 This module walks over @Core@, and looks for @case@ on free variables.
 The criterion is:
        if there is case on a free on the route to the recursive call,
@@ -27,30 +36,24 @@ The criterion is:
 
 Example
 
-\begin{verbatim}
-f = \ t -> case v of
-              V a b -> a : f t
-\end{verbatim}
+   f = \ t -> case v of
+                V a b -> a : f t
 
 => the inner f is replaced.
 
-\begin{verbatim}
-f = \ t -> case v of
-              V a b -> a : (letrec
+   f = \ t -> case v of
+                V a b -> a : (letrec
                                f =  \ t -> case v of
                                               V a b -> a : f t
-                            in f) t
-\end{verbatim}
+                              in f) t
 (note the NEED for shadowing)
 
 => Simplify
 
-\begin{verbatim}
-f = \ t -> case v of
-              V a b -> a : (letrec
+  f = \ t -> case v of
+                V a b -> a : (letrec
                                f = \ t -> a : f t
-                            in f t)
-\begin{verbatim}
+                              in f t)
 
 Better code, because 'a' is  free inside the inner letrec, rather
 than needing projection from v.
@@ -72,7 +75,6 @@ We'd like to avoid the redundant pattern match, transforming to
 
        (is this necessarily an improvement)
 
-
 Similarly drop:
 
        drop n [] = []
@@ -81,6 +83,15 @@ Similarly drop:
 
 Would like to pass n along unboxed.
        
+Note [Scrutinee with cast]
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider this:
+    f = \ t -> case (v `cast` co) of
+                V a b -> a : f t
+
+Exactly the same optimisation (unrolling one call to f) will work here, 
+despite the cast.  See mk_alt_env in the Case branch of libCase.
+
 
 To think about (Apr 94)
 ~~~~~~~~~~~~~~
@@ -96,7 +107,6 @@ big.
 
 Data types
 ~~~~~~~~~~
-
 The ``level'' of a binder tells how many
 recursive defns lexically enclose the binding
 A recursive defn "encloses" its RHS, not its
@@ -110,57 +120,64 @@ scope.  For example:
 Here, the level of @f@ is zero, the level of @g@ is one,
 and the level of @h@ is zero (NB not one).
 
-\begin{code}
-type LibCaseLevel = Int
-
-topLevel :: LibCaseLevel
-topLevel = 0
-\end{code}
+Note [Indexed data types]
+~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider
+       data family T :: * -> *
+       data T Int = TI Int
+
+       f :: T Int -> Bool
+       f x = case x of { DEFAULT -> <body> }
+
+We would like to change this to
+       f x = case x `cast` co of { TI p -> <body> }
+
+so that <body> can make use of the fact that x is already evaluated to
+a TI; and a case on a known data type may be more efficient than a
+polymorphic one (not sure this is true any longer).  Anyway the former
+showed up in Roman's experiments.  Example:
+  foo :: FooT Int -> Int -> Int
+  foo t n = t `seq` bar n
+     where
+       bar 0 = 0
+       bar n = bar (n - case t of TI i -> i)
+Here we'd like to avoid repeated evaluating t inside the loop, by 
+taking advantage of the `seq`.
+
+We implement this as part of the liberate-case transformation by 
+spotting
+       case <scrut> of (x::T) tys { DEFAULT ->  <body> }
+where x :: T tys, and T is a indexed family tycon.  Find the
+representation type (T77 tys'), and coercion co, and transform to
+       case <scrut> `cast` co of (y::T77 tys')
+           DEFAULT -> let x = y `cast` sym co in <body>
+
+The "find the representation type" part is done by looking up in the
+family-instance environment.
+
+NB: in fact we re-use x (changing its type) to avoid making a fresh y;
+this entails shadowing, but that's ok.
+
+%************************************************************************
+%*                                                                     *
+        Top-level code
+%*                                                                     *
+%************************************************************************
 
 \begin{code}
-data LibCaseEnv
-  = LibCaseEnv
-       Int                     -- Bomb-out size for deciding if
-                               -- potential liberatees are too big.
-                               -- (passed in from cmd-line args)
-
-       LibCaseLevel            -- Current level
-
-       (IdEnv LibCaseLevel)    -- Binds all non-top-level in-scope Ids
-                               -- (top-level and imported things have
-                               -- a level of zero)
-
-       (IdEnv CoreBind)        -- Binds *only* recursively defined
-                               -- Ids, to their own binding group,
-                               -- and *only* in their own RHSs
-
-       [(Id,LibCaseLevel)]     -- Each of these Ids was scrutinised by an
-                               -- enclosing case expression, with the
-                               -- specified number of enclosing
-                               -- recursive bindings; furthermore,
-                               -- the Id is bound at a lower level
-                               -- than the case expression.  The
-                               -- order is insignificant; it's a bag
-                               -- really
-
-initEnv :: Int -> LibCaseEnv
-initEnv bomb_size = LibCaseEnv bomb_size 0 emptyVarEnv emptyVarEnv []
-
-bombOutSize (LibCaseEnv bomb_size _ _ _ _) = bomb_size
-\end{code}
-
-
-Programs
-~~~~~~~~
-\begin{code}
-liberateCase :: DynFlags -> [CoreBind] -> IO [CoreBind]
-liberateCase dflags binds
-  = do {
-       showPass dflags "Liberate case" ;
-       let { binds' = do_prog (initEnv opt_LiberateCaseThreshold) binds } ;
-       endPass dflags "Liberate case" Opt_D_verbose_core2core binds'
-                               {- no specific flag for dumping -} 
-    }
+liberateCase :: HscEnv -> UniqSupply -> RuleBase -> ModGuts
+            -> IO (SimplCount, ModGuts)
+liberateCase hsc_env _ _ guts
+  = do { let dflags = hsc_dflags hsc_env
+       ; eps <- readIORef (hsc_EPS hsc_env)
+       ; let fam_envs = (eps_fam_inst_env eps, mg_fam_inst_env guts)
+
+       ; showPass dflags "Liberate case"
+       ; let { env = initEnv dflags fam_envs
+             ; binds' = do_prog env (mg_binds guts) }
+       ; endPass dflags "Liberate case" Opt_D_verbose_core2core binds'
+                       {- no specific flag for dumping -} 
+       ; return (zeroSimplCount dflags, guts { mg_binds = binds' }) }
   where
     do_prog env [] = []
     do_prog env (bind:binds) = bind' : do_prog env' binds
@@ -168,9 +185,15 @@ liberateCase dflags binds
                               (env', bind') = libCaseBind env bind
 \end{code}
 
+
+%************************************************************************
+%*                                                                     *
+        Main payload
+%*                                                                     *
+%************************************************************************
+
 Bindings
 ~~~~~~~~
-
 \begin{code}
 libCaseBind :: LibCaseEnv -> CoreBind -> (LibCaseEnv, CoreBind)
 
@@ -236,16 +259,34 @@ libCase env (Let bind body)
     (env_body, bind') = libCaseBind env bind
 
 libCase env (Case scrut bndr ty alts)
-  = Case (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts)
+  = mkCase env (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts)
   where
-    env_alts = addBinders env_with_scrut [bndr]
-    env_with_scrut = case scrut of
-                       Var scrut_var -> addScrutedVar env scrut_var
-                       other         -> env
+    env_alts = addBinders (mk_alt_env scrut) [bndr]
+    mk_alt_env (Var scrut_var) = addScrutedVar env scrut_var
+    mk_alt_env (Cast scrut _)  = mk_alt_env scrut      -- Note [Scrutinee with cast]
+    mk_alt_env otehr          = env
 
 libCaseAlt env (con,args,rhs) = (con, args, libCase (addBinders env args) rhs)
 \end{code}
 
+\begin{code}
+mkCase :: LibCaseEnv -> CoreExpr -> Id -> Type -> [CoreAlt] -> CoreExpr
+-- See Note [Indexed data types]
+mkCase env scrut bndr ty [(DEFAULT,_,rhs)]
+  | Just (tycon, tys)   <- splitTyConApp_maybe (idType bndr)
+  , [(subst, fam_inst)] <- lookupFamInstEnv (lc_fams env) tycon tys
+  = let 
+       rep_tc     = famInstTyCon fam_inst
+       rep_tys    = map (substTyVar subst) (tyConTyVars rep_tc)
+       bndr'      = setIdType bndr (mkTyConApp rep_tc rep_tys)
+       Just co_tc = tyConFamilyCoercion_maybe rep_tc
+       co         = mkTyConApp co_tc rep_tys
+       bind       = NonRec bndr (Cast (Var bndr') (mkSymCoercion co))
+    in mkCase env (Cast scrut co) bndr' ty [(DEFAULT,[],Let bind rhs)]
+mkCase env scrut bndr ty alts
+  = Case scrut bndr ty alts
+\end{code}
+
 Ids
 ~~~
 \begin{code}
@@ -264,19 +305,23 @@ libCaseId env v
 \end{code}
 
 
+%************************************************************************
+%*                                                                     *
+       Utility functions
+%*                                                                     *
+%************************************************************************
 
-Utility functions
-~~~~~~~~~~~~~~~~~
 \begin{code}
 addBinders :: LibCaseEnv -> [CoreBndr] -> LibCaseEnv
-addBinders (LibCaseEnv bomb lvl lvl_env rec_env scruts) binders
-  = LibCaseEnv bomb lvl lvl_env' rec_env scruts
+addBinders env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env }) binders
+  = env { lc_lvl_env = lvl_env' }
   where
     lvl_env' = extendVarEnvList lvl_env (binders `zip` repeat lvl)
 
 addRecBinds :: LibCaseEnv -> [(Id,CoreExpr)] -> LibCaseEnv
-addRecBinds (LibCaseEnv bomb lvl lvl_env rec_env scruts) pairs
-  = LibCaseEnv bomb lvl' lvl_env' rec_env' scruts
+addRecBinds env@(LibCaseEnv {lc_lvl = lvl, lc_lvl_env = lvl_env, 
+                            lc_rec_env = rec_env}) pairs
+  = env { lc_lvl = lvl', lc_lvl_env = lvl_env', lc_rec_env = rec_env' }
   where
     lvl'     = lvl + 1
     lvl_env' = extendVarEnvList lvl_env [(binder,lvl) | (binder,_) <- pairs]
@@ -286,9 +331,10 @@ addScrutedVar :: LibCaseEnv
              -> Id             -- This Id is being scrutinised by a case expression
              -> LibCaseEnv
 
-addScrutedVar env@(LibCaseEnv bomb lvl lvl_env rec_env scruts) scrut_var
+addScrutedVar env@(LibCaseEnv { lc_lvl = lvl, lc_lvl_env = lvl_env, 
+                               lc_scruts = scruts }) scrut_var
   | bind_lvl < lvl
-  = LibCaseEnv bomb lvl lvl_env rec_env scruts'
+  = env { lc_scruts = scruts' }
        -- Add to scruts iff the scrut_var is being scrutinised at
        -- a deeper level than its defn
 
@@ -300,19 +346,77 @@ addScrutedVar env@(LibCaseEnv bomb lvl lvl_env rec_env scruts) scrut_var
                 Nothing  -> topLevel
 
 lookupRecId :: LibCaseEnv -> Id -> Maybe CoreBind
-lookupRecId (LibCaseEnv bomb lvl lvl_env rec_env scruts) id
-  = lookupVarEnv rec_env id
+lookupRecId env id = lookupVarEnv (lc_rec_env env) id
 
 lookupLevel :: LibCaseEnv -> Id -> LibCaseLevel
-lookupLevel (LibCaseEnv bomb lvl lvl_env rec_env scruts) id
-  = case lookupVarEnv lvl_env id of
-      Just lvl -> lvl
+lookupLevel env id
+  = case lookupVarEnv (lc_lvl_env env) id of
+      Just lvl -> lc_lvl env
       Nothing  -> topLevel
 
 freeScruts :: LibCaseEnv
           -> LibCaseLevel      -- Level of the recursive Id
           -> [Id]              -- Ids that are scrutinised between the binding
                                -- of the recursive Id and here
-freeScruts (LibCaseEnv bomb lvl lvl_env rec_env scruts) rec_bind_lvl
-  = [v | (v,scrut_lvl) <- scruts, scrut_lvl > rec_bind_lvl]
+freeScruts env rec_bind_lvl
+  = [v | (v,scrut_lvl) <- lc_scruts env, scrut_lvl > rec_bind_lvl]
+\end{code}
+
+%************************************************************************
+%*                                                                     *
+        The environment
+%*                                                                     *
+%************************************************************************
+
+\begin{code}
+type LibCaseLevel = Int
+
+topLevel :: LibCaseLevel
+topLevel = 0
 \end{code}
+
+\begin{code}
+data LibCaseEnv
+  = LibCaseEnv {
+       lc_size :: Int,         -- Bomb-out size for deciding if
+                               -- potential liberatees are too big.
+                               -- (passed in from cmd-line args)
+
+       lc_lvl :: LibCaseLevel, -- Current level
+
+       lc_lvl_env :: IdEnv LibCaseLevel,  
+                       -- Binds all non-top-level in-scope Ids
+                       -- (top-level and imported things have
+                       -- a level of zero)
+
+       lc_rec_env :: IdEnv CoreBind, 
+                       -- Binds *only* recursively defined ids, 
+                       -- to their own binding group,
+                       -- and *only* in their own RHSs
+
+       lc_scruts :: [(Id,LibCaseLevel)],
+                       -- Each of these Ids was scrutinised by an
+                       -- enclosing case expression, with the
+                       -- specified number of enclosing
+                       -- recursive bindings; furthermore,
+                       -- the Id is bound at a lower level
+                       -- than the case expression.  The order is
+                       -- insignificant; it's a bag really
+
+       lc_fams :: FamInstEnvs
+                       -- Instance env for indexed data types 
+       }
+
+initEnv :: DynFlags -> FamInstEnvs -> LibCaseEnv
+initEnv dflags fams
+  = LibCaseEnv { lc_size = specThreshold dflags,
+                lc_lvl = 0,
+                lc_lvl_env = emptyVarEnv, 
+                lc_rec_env = emptyVarEnv,
+                lc_scruts = [],
+                lc_fams = fams }
+
+bombOutSize = lc_size
+\end{code}
+
+