Improve the interaction of 'seq' and associated data types
authorsimonpj@microsoft.com <unknown>
Wed, 23 May 2007 11:48:18 +0000 (11:48 +0000)
committersimonpj@microsoft.com <unknown>
Wed, 23 May 2007 11:48:18 +0000 (11:48 +0000)
Roman produced programs involving associated types that did not optimise well.
His programs were something like this:

  data family T a
  data instance T Int = MkT Bool Char

  bar :: T Int -> Int
  bar t = t `seq` loop 0
where
  loop = ...

You'd think that the `seq` should unbox 't' outside the loop, since
a (T Int) is just a MkT pair.

The most robust way to make this happen is for the simplifier to understand
a bit about type-family instances.   See
Note [Improving seq]
in Simplify.lhs.  We use FamInstEnv.topNormaliseType to do the interesting
work.

To make this happen I did a bit of refactoring to the simplifier
monad.

I'd previously done a very similar transformation in LiberateCase, but it
was happening too late.  So this patch takes it out of LiberateCase as
well as adding it to Simplify.

compiler/simplCore/LiberateCase.lhs
compiler/simplCore/SimplCore.lhs
compiler/simplCore/SimplEnv.lhs
compiler/simplCore/SimplMonad.lhs
compiler/simplCore/Simplify.lhs
compiler/types/FamInstEnv.lhs

index 9f03adf..0df9b37 100644 (file)
@@ -17,14 +17,9 @@ import Rules         ( RuleBase )
 import UniqSupply      ( UniqSupply )
 import SimplMonad      ( SimplCount, zeroSimplCount )
 import Id
-import FamInstEnv
-import Type
-import Coercion
-import TyCon
 import VarEnv
 import Name            ( localiseName )
 import Util             ( notNull )
-import Data.IORef      ( readIORef )
 \end{code}
 
 The liberate-case transformation
@@ -120,43 +115,6 @@ scope.  For example:
 Here, the level of @f@ is zero, the level of @g@ is one,
 and the level of @h@ is zero (NB not one).
 
-Note [Indexed data types]
-~~~~~~~~~~~~~~~~~~~~~~~~~
-Consider
-       data family T :: * -> *
-       data T Int = TI Int
-
-       f :: T Int -> Bool
-       f x = case x of { DEFAULT -> <body> }
-
-We would like to change this to
-       f x = case x `cast` co of { TI p -> <body> }
-
-so that <body> can make use of the fact that x is already evaluated to
-a TI; and a case on a known data type may be more efficient than a
-polymorphic one (not sure this is true any longer).  Anyway the former
-showed up in Roman's experiments.  Example:
-  foo :: FooT Int -> Int -> Int
-  foo t n = t `seq` bar n
-     where
-       bar 0 = 0
-       bar n = bar (n - case t of TI i -> i)
-Here we'd like to avoid repeated evaluating t inside the loop, by 
-taking advantage of the `seq`.
-
-We implement this as part of the liberate-case transformation by 
-spotting
-       case <scrut> of (x::T) tys { DEFAULT ->  <body> }
-where x :: T tys, and T is a indexed family tycon.  Find the
-representation type (T77 tys'), and coercion co, and transform to
-       case <scrut> `cast` co of (y::T77 tys')
-           DEFAULT -> let x = y `cast` sym co in <body>
-
-The "find the representation type" part is done by looking up in the
-family-instance environment.
-
-NB: in fact we re-use x (changing its type) to avoid making a fresh y;
-this entails shadowing, but that's ok.
 
 %************************************************************************
 %*                                                                     *
@@ -169,11 +127,9 @@ liberateCase :: HscEnv -> UniqSupply -> RuleBase -> ModGuts
             -> IO (SimplCount, ModGuts)
 liberateCase hsc_env _ _ guts
   = do { let dflags = hsc_dflags hsc_env
-       ; eps <- readIORef (hsc_EPS hsc_env)
-       ; let fam_envs = (eps_fam_inst_env eps, mg_fam_inst_env guts)
 
        ; showPass dflags "Liberate case"
-       ; let { env = initEnv dflags fam_envs
+       ; let { env = initEnv dflags
              ; binds' = do_prog env (mg_binds guts) }
        ; endPass dflags "Liberate case" Opt_D_verbose_core2core binds'
                        {- no specific flag for dumping -} 
@@ -259,7 +215,7 @@ libCase env (Let bind body)
     (env_body, bind') = libCaseBind env bind
 
 libCase env (Case scrut bndr ty alts)
-  = mkCase env (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts)
+  = Case (libCase env scrut) bndr ty (map (libCaseAlt env_alts) alts)
   where
     env_alts = addBinders (mk_alt_env scrut) [bndr]
     mk_alt_env (Var scrut_var) = addScrutedVar env scrut_var
@@ -269,22 +225,6 @@ libCase env (Case scrut bndr ty alts)
 libCaseAlt env (con,args,rhs) = (con, args, libCase (addBinders env args) rhs)
 \end{code}
 
-\begin{code}
-mkCase :: LibCaseEnv -> CoreExpr -> Id -> Type -> [CoreAlt] -> CoreExpr
--- See Note [Indexed data types]
-mkCase env scrut bndr ty [(DEFAULT,_,rhs)]
-  | Just (tycon, tys)   <- splitTyConApp_maybe (idType bndr)
-  , [(fam_inst, rep_tys)] <- lookupFamInstEnv (lc_fams env) tycon tys
-  = let 
-       rep_tc     = famInstTyCon fam_inst
-       bndr'      = setIdType bndr (mkTyConApp rep_tc rep_tys)
-       Just co_tc = tyConFamilyCoercion_maybe rep_tc
-       co         = mkTyConApp co_tc rep_tys
-       bind       = NonRec bndr (Cast (Var bndr') (mkSymCoercion co))
-    in mkCase env (Cast scrut co) bndr' ty [(DEFAULT,[],Let bind rhs)]
-mkCase env scrut bndr ty alts
-  = Case scrut bndr ty alts
-\end{code}
 
 Ids
 ~~~
@@ -393,7 +333,7 @@ data LibCaseEnv
                        -- to their own binding group,
                        -- and *only* in their own RHSs
 
-       lc_scruts :: [(Id,LibCaseLevel)],
+       lc_scruts :: [(Id,LibCaseLevel)]
                        -- Each of these Ids was scrutinised by an
                        -- enclosing case expression, with the
                        -- specified number of enclosing
@@ -401,19 +341,15 @@ data LibCaseEnv
                        -- the Id is bound at a lower level
                        -- than the case expression.  The order is
                        -- insignificant; it's a bag really
-
-       lc_fams :: FamInstEnvs
-                       -- Instance env for indexed data types 
        }
 
-initEnv :: DynFlags -> FamInstEnvs -> LibCaseEnv
-initEnv dflags fams
+initEnv :: DynFlags -> LibCaseEnv
+initEnv dflags 
   = LibCaseEnv { lc_size = specThreshold dflags,
                 lc_lvl = 0,
                 lc_lvl_env = emptyVarEnv, 
                 lc_rec_env = emptyVarEnv,
-                lc_scruts = [],
-                lc_fams = fams }
+                lc_scruts = [] }
 
 bombOutSize = lc_size
 \end{code}
index 200ebc4..032e3b0 100644 (file)
@@ -33,6 +33,7 @@ import ErrUtils               ( dumpIfSet, dumpIfSet_dyn, showPass )
 import CoreLint                ( endPass )
 import FloatIn         ( floatInwards )
 import FloatOut                ( floatOutwards )
+import FamInstEnv
 import Id              ( Id, modifyIdInfo, idInfo, isExportedId, isLocalId,
                          idSpecialisation, idName )
 import VarSet
@@ -101,7 +102,7 @@ simplifyExpr dflags expr
 
        ; us <-  mkSplitUniqSupply 's'
 
-       ; let (expr', _counts) = initSmpl dflags us $
+       ; let (expr', _counts) = initSmpl dflags emptyRuleBase emptyFamInstEnvs us $
                                 simplExprGently gentleSimplEnv expr
 
        ; dumpIfSet_dyn dflags Opt_D_dump_simpl "Simplified expression"
@@ -111,9 +112,7 @@ simplifyExpr dflags expr
        }
 
 gentleSimplEnv :: SimplEnv
-gentleSimplEnv = mkSimplEnv SimplGently 
-                           (isAmongSimpl [])
-                           emptyRuleBase
+gentleSimplEnv = mkSimplEnv SimplGently  (isAmongSimpl [])
 
 doCorePasses :: HscEnv
              -> RuleBase        -- the imported main rule base
@@ -232,7 +231,8 @@ prepareRules hsc_env@(HscEnv { hsc_dflags = dflags, hsc_HPT = hpt })
                -- from the local binders, to avoid warnings from Simplify.simplVar
              local_ids        = mkInScopeSet (mkVarSet (bindersOfBinds binds))
              env              = setInScopeSet gentleSimplEnv local_ids 
-             (better_rules,_) = initSmpl dflags us (mapSmpl (simplRule env) local_rules)
+             (better_rules,_) = initSmpl dflags emptyRuleBase emptyFamInstEnvs us $
+                                (mapSmpl (simplRule env) local_rules)
              home_pkg_rules   = hptRules hsc_env (dep_mods deps)
 
                -- Find the rules for locally-defined Ids; then we can attach them
@@ -445,7 +445,10 @@ simplifyPgm mode switches hsc_env us imp_rule_base guts
                -- miss the rules for Ids hidden inside imported inlinings
           eps <- hscEPS hsc_env ;
           let  { rule_base' = unionRuleBase imp_rule_base (eps_rule_base eps)
-               ; simpl_env  = mkSimplEnv mode sw_chkr rule_base' } ;
+               ; simpl_env  = mkSimplEnv mode sw_chkr 
+               ; simpl_binds = _scc_ "SimplTopBinds" 
+                               simplTopBinds simpl_env tagged_binds
+               ; fam_envs = (eps_fam_inst_env eps, mg_fam_inst_env guts) } ;
           
                -- Simplify the program
                -- We do this with a *case* not a *let* because lazy pattern
@@ -458,7 +461,7 @@ simplifyPgm mode switches hsc_env us imp_rule_base guts
                --      case t of {(_,counts') -> if counts'=0 then ... }
                -- So the conditional didn't force counts', because the
                -- selection got duplicated.  Sigh!
-          case initSmpl dflags us1 (_scc_ "SimplTopBinds" simplTopBinds simpl_env tagged_binds) of {
+          case initSmpl dflags rule_base' fam_envs us1 simpl_binds of {
                (binds', counts') -> do {
 
           let  { all_counts = counts `plusSimplCount` counts'
index 2fedf87..1d7d2e4 100644 (file)
@@ -101,9 +101,6 @@ data SimplEnv
        seChkr      :: SwitchChecker,
        seCC        :: CostCentreStack, -- The enclosing CCS (when profiling)
 
-       -- Rules from other modules
-       seExtRules  :: RuleBase,
-
        -- The current set of in-scope variables
        -- They are all OutVars, and all bound in this module
        seInScope   :: InScopeSet,      -- OutVars only
@@ -207,11 +204,11 @@ seIdSubst:
 
 
 \begin{code}
-mkSimplEnv :: SimplifierMode -> SwitchChecker -> RuleBase -> SimplEnv
-mkSimplEnv mode switches rules
+mkSimplEnv :: SimplifierMode -> SwitchChecker -> SimplEnv
+mkSimplEnv mode switches
   = SimplEnv { seChkr = switches, seCC = subsumedCCS, 
               seMode = mode, seInScope = emptyInScopeSet, 
-              seExtRules = rules, seFloats = emptyFloats,
+              seFloats = emptyFloats,
               seTvSubst = emptyVarEnv, seIdSubst = emptyVarEnv }
        -- The top level "enclosing CC" is "SUBSUMED".
 
@@ -289,10 +286,6 @@ mkContEx (SimplEnv { seTvSubst = tvs, seIdSubst = ids }) e = ContEx tvs ids e
 isEmptySimplSubst :: SimplEnv -> Bool
 isEmptySimplSubst (SimplEnv { seTvSubst = tvs, seIdSubst = ids })
   = isEmptyVarEnv tvs && isEmptyVarEnv ids
-
----------------------
-getRules :: SimplEnv -> RuleBase
-getRules = seExtRules
 \end{code}
 
 
@@ -639,8 +632,8 @@ substLetIdBndr env@(SimplEnv { seInScope = in_scope, seIdSubst = id_subst }) old
              = delVarEnv id_subst old_id
 \end{code}
 
-Add IdInfo back onto a let-bound Id
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Note [Add IdInfo back onto a let-bound Id]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 We must transfer the IdInfo of the original binder to the new binder.
 This is crucial, to preserve
        strictness
index a198b32..7126883 100644 (file)
@@ -9,7 +9,7 @@ module SimplMonad (
        SimplM,
        initSmpl, returnSmpl, thenSmpl, thenSmpl_,
        mapSmpl, mapAndUnzipSmpl, mapAccumLSmpl,
-       getDOptsSmpl,
+       getDOptsSmpl, getRules, getFamEnvs,
 
         -- Unique supply
         getUniqueSmpl, getUniquesSmpl, getUniqSupplySmpl, newId,
@@ -29,6 +29,8 @@ module SimplMonad (
 
 import Id              ( Id, mkSysLocal )
 import Type             ( Type )
+import FamInstEnv      ( FamInstEnv )
+import Rules           ( RuleBase )
 import UniqSupply      ( uniqsFromSupply, uniqFromSupply, splitUniqSupply,
                          UniqSupply
                        )
@@ -61,22 +63,28 @@ For the simplifier monad, we want to {\em thread} a unique supply and a counter.
 
 \begin{code}
 newtype SimplM result
-  =  SM  { unSM :: DynFlags            -- We thread the unique supply because
-                  -> UniqSupply        -- constantly splitting it is rather expensive
-                  -> SimplCount 
-                  -> (result, UniqSupply, SimplCount)}
+  =  SM  { unSM :: SimplTopEnv -- Envt that does not change much
+               -> UniqSupply   -- We thread the unique supply because
+                               -- constantly splitting it is rather expensive
+               -> SimplCount 
+               -> (result, UniqSupply, SimplCount)}
+
+data SimplTopEnv = STE { st_flags :: DynFlags 
+                       , st_rules :: RuleBase
+                       , st_fams  :: (FamInstEnv, FamInstEnv) }
 \end{code}
 
 \begin{code}
-initSmpl :: DynFlags
+initSmpl :: DynFlags -> RuleBase -> (FamInstEnv, FamInstEnv) 
         -> UniqSupply          -- No init count; set to 0
         -> SimplM a
         -> (a, SimplCount)
 
-initSmpl dflags us m
-  = case unSM m dflags us (zeroSimplCount dflags) of 
+initSmpl dflags rules fam_envs us m
+  = case unSM m env us (zeroSimplCount dflags) of 
        (result, _, count) -> (result, count)
-
+  where
+    env = STE { st_flags = dflags, st_rules = rules, st_fams = fam_envs }
 
 {-# INLINE thenSmpl #-}
 {-# INLINE thenSmpl_ #-}
@@ -88,20 +96,20 @@ instance Monad SimplM where
    return = returnSmpl
 
 returnSmpl :: a -> SimplM a
-returnSmpl e = SM (\ dflags us sc -> (e, us, sc))
+returnSmpl e = SM (\ st_env us sc -> (e, us, sc))
 
 thenSmpl  :: SimplM a -> (a -> SimplM b) -> SimplM b
 thenSmpl_ :: SimplM a -> SimplM b -> SimplM b
 
 thenSmpl m k 
-  = SM (\ dflags us0 sc0 ->
-         case (unSM m dflags us0 sc0) of 
-               (m_result, us1, sc1) -> unSM (k m_result) dflags us1 sc1 )
+  = SM (\ st_env us0 sc0 ->
+         case (unSM m st_env us0 sc0) of 
+               (m_result, us1, sc1) -> unSM (k m_result) st_env us1 sc1 )
 
 thenSmpl_ m k 
-  = SM (\dflags us0 sc0 ->
-        case (unSM m dflags us0 sc0) of 
-               (_, us1, sc1) -> unSM k dflags us1 sc1)
+  = SM (\st_env us0 sc0 ->
+        case (unSM m st_env us0 sc0) of 
+               (_, us1, sc1) -> unSM k st_env us1 sc1)
 \end{code}
 
 
@@ -138,22 +146,27 @@ mapAccumLSmpl f acc (x:xs) = f acc x      `thenSmpl` \ (acc', x') ->
 \begin{code}
 getUniqSupplySmpl :: SimplM UniqSupply
 getUniqSupplySmpl 
-   = SM (\dflags us sc -> case splitUniqSupply us of
+   = SM (\st_env us sc -> case splitUniqSupply us of
                                (us1, us2) -> (us1, us2, sc))
 
 getUniqueSmpl :: SimplM Unique
 getUniqueSmpl 
-   = SM (\dflags us sc -> case splitUniqSupply us of
+   = SM (\st_env us sc -> case splitUniqSupply us of
                                (us1, us2) -> (uniqFromSupply us1, us2, sc))
 
 getUniquesSmpl :: SimplM [Unique]
 getUniquesSmpl 
-   = SM (\dflags us sc -> case splitUniqSupply us of
+   = SM (\st_env us sc -> case splitUniqSupply us of
                                (us1, us2) -> (uniqsFromSupply us1, us2, sc))
 
 getDOptsSmpl :: SimplM DynFlags
-getDOptsSmpl 
-   = SM (\dflags us sc -> (dflags, us, sc))
+getDOptsSmpl = SM (\st_env us sc -> (st_flags st_env, us, sc))
+
+getRules :: SimplM RuleBase
+getRules = SM (\st_env us sc -> (st_rules st_env, us, sc))
+
+getFamEnvs :: SimplM (FamInstEnv, FamInstEnv)
+getFamEnvs = SM (\st_env us sc -> (st_fams st_env, us, sc))
 
 newId :: FastString -> Type -> SimplM Id
 newId fs ty = getUniqueSmpl    `thenSmpl` \ uniq ->
@@ -169,18 +182,18 @@ newId fs ty = getUniqueSmpl       `thenSmpl` \ uniq ->
 
 \begin{code}
 getSimplCount :: SimplM SimplCount
-getSimplCount = SM (\dflags us sc -> (sc, us, sc))
+getSimplCount = SM (\st_env us sc -> (sc, us, sc))
 
 tick :: Tick -> SimplM ()
 tick t 
-   = SM (\dflags us sc -> let sc' = doTick t sc 
+   = SM (\st_env us sc -> let sc' = doTick t sc 
                          in sc' `seq` ((), us, sc'))
 
 freeTick :: Tick -> SimplM ()
 -- Record a tick, but don't add to the total tick count, which is
 -- used to decide when nothing further has happened
 freeTick t 
-   = SM (\dflags us sc -> let sc' = doFreeTick t sc
+   = SM (\st_env us sc -> let sc' = doFreeTick t sc
                          in sc' `seq` ((), us, sc'))
 \end{code}
 
index d97249f..ac1f790 100644 (file)
@@ -17,6 +17,7 @@ import Id
 import Var
 import IdInfo
 import Coercion
+import FamInstEnv      ( topNormaliseType )
 import DataCon         ( dataConRepStrictness, dataConUnivTyVars )
 import CoreSyn
 import NewDemand       ( isStrictDmd )
@@ -870,7 +871,7 @@ simplNonRecE env bndr (rhs, rhs_se) (bndrs, body) cont
                     (StrictBind bndr bndrs body env cont) }
 
   | otherwise
-  = do { (env, bndr') <- simplBinder env bndr
+  = do { (env, bndr') <- simplNonRecBndr env bndr
        ; env <- simplLazyBind env NotTopLevel NonRecursive bndr bndr' rhs rhs_se
        ; simplLam env bndrs body cont }
 \end{code}
@@ -962,8 +963,8 @@ completeCall env var cont
        -- is recursive, and hence a loop breaker:
        --      foldr k z (build g) = g k z
        -- So it's up to the programmer: rules can cause divergence
+       ; rules <- getRules
        ; let   in_scope   = getInScope env
-               rules      = getRules env
                maybe_rule = case activeRule dflags env of
                                Nothing     -> Nothing  -- No rules apply
                                Just act_fn -> lookupRule act_fn in_scope 
@@ -1030,7 +1031,7 @@ rebuildCall env fun fun_ty (has_rules, []) cont
   -- Then, especially in the first of these cases, we'd like to discard
   -- the continuation, leaving just the bottoming expression.  But the
   -- type might not be right, so we may have to add a coerce.
-  | not (contIsTrivial cont)    -- Only do thia if there is a non-trivial
+  | not (contIsTrivial cont)    -- Only do this if there is a non-trivial
   = return (env, mk_coerce fun)  -- contination to discard, else we do it
   where                                 -- again and again!
     cont_ty = contResultType cont
@@ -1177,9 +1178,9 @@ rebuildCase env scrut case_bndr alts cont
          (env, dup_cont, nodup_cont) <- prepareCaseCont env alts cont
 
        -- Simplify the alternatives
-       ; (case_bndr', alts') <- simplAlts env scrut case_bndr alts dup_cont
+       ; (scrut', case_bndr', alts') <- simplAlts env scrut case_bndr alts dup_cont
        ; let res_ty' = contResultType dup_cont
-       ; case_expr <- mkCase scrut case_bndr' res_ty' alts'
+       ; case_expr <- mkCase scrut' case_bndr' res_ty' alts'
 
        -- Notice that rebuildDone returns the in-scope set from env, not alt_env
        -- The case binder *not* scope over the whole returned case-expression
@@ -1277,6 +1278,35 @@ arranging that inside the outer case we add the unfolding
        v |-> x `cast` (sym co)
 to v.  Then we should inline v at the inner case, cancel the casts, and away we go
        
+Note [Improving seq]
+~~~~~~~~~~~~~~~~~~~
+Consider
+       type family F :: * -> *
+       type instance F Int = Int
+
+       ... case e of x { DEFAULT -> rhs } ...
+
+where x::F Int.  Then we'd like to rewrite (F Int) to Int, getting
+
+       case e `cast` co of x'::Int
+          I# x# -> let x = x' `cast` sym co 
+                   in rhs
+
+so that 'rhs' can take advantage of hte form of x'.  Notice that Note
+[Case of cast] may then apply to the result.
+
+This showed up in Roman's experiments.  Example:
+  foo :: F Int -> Int -> Int
+  foo t n = t `seq` bar n
+     where
+       bar 0 = 0
+       bar n = bar (n - case t of TI i -> i)
+Here we'd like to avoid repeated evaluating t inside the loop, by 
+taking advantage of the `seq`.
+
+At one point I did transformation in LiberateCase, but it's more robust here.
+(Otherwise, there's a danger that we'll simply drop the 'seq' altogether, before
+LiberateCase gets to see it.)
 
 Note [Case elimination]
 ~~~~~~~~~~~~~~~~~~~~~~~
@@ -1366,30 +1396,56 @@ I don't really know how to improve this situation.
 
 
 \begin{code}
-simplCaseBinder :: SimplEnv -> OutExpr -> InId -> SimplM (SimplEnv, OutId)
-simplCaseBinder env scrut case_bndr
-  | switchIsOn (getSwitchChecker env) NoCaseOfCase
-       -- See Note [no-case-of-case]
-  = do { (env, case_bndr') <- simplBinder env case_bndr
-       ; return (env, case_bndr') }
-
-simplCaseBinder env (Var v) case_bndr
--- Failed try [see Note 2 above]
---     not (isEvaldUnfolding (idUnfolding v))
-  = do { (env, case_bndr') <- simplBinder env (zapOccInfo case_bndr)
-       ; return (modifyInScope env v case_bndr', case_bndr') }
-       -- We could extend the substitution instead, but it would be
-       -- a hack because then the substitution wouldn't be idempotent
-       -- any more (v is an OutId).  And this does just as well.
-           
-simplCaseBinder env (Cast (Var v) co) case_bndr                -- Note [Case of cast]
-  = do { (env, case_bndr') <- simplBinder env (zapOccInfo case_bndr)
-       ; let rhs = Cast (Var case_bndr') (mkSymCoercion co)
-       ; return (addBinderUnfolding env v rhs, case_bndr') }
-
-simplCaseBinder env other_scrut case_bndr 
-  = do { (env, case_bndr') <- simplBinder env case_bndr
-       ; return (env, case_bndr') }
+simplCaseBinder :: SimplEnv -> OutExpr -> OutId -> [InAlt]
+               -> SimplM (SimplEnv, OutExpr, OutId)
+simplCaseBinder env scrut case_bndr alts
+  = do { (env1, case_bndr1) <- simplBinder env case_bndr
+
+       ; fam_envs <- getFamEnvs
+       ; (env2, scrut2, case_bndr2) <- improve_seq fam_envs env1 scrut 
+                                               case_bndr case_bndr1 alts
+                       -- Note [Improving seq]
+
+       ; let (env3, case_bndr3) = improve_case_bndr env2 scrut2 case_bndr2
+                       -- Note [Case of cast]
+
+       ; return (env3, scrut2, case_bndr3) }
+  where
+
+    improve_seq fam_envs env1 scrut case_bndr case_bndr1 [(DEFAULT,_,_)] 
+       | Just (co, ty2) <- topNormaliseType fam_envs (idType case_bndr1)
+       =  do { case_bndr2 <- newId FSLIT("nt") ty2
+             ; let rhs  = DoneEx (Var case_bndr2 `Cast` mkSymCoercion co)
+                   env2 = extendIdSubst env1 case_bndr rhs
+             ; return (env2, scrut `Cast` co, case_bndr2) }
+
+    improve_seq fam_envs env1 scrut case_bndr case_bndr1 alts
+       = return (env1, scrut, case_bndr1)
+
+
+    improve_case_bndr env scrut case_bndr
+       | switchIsOn (getSwitchChecker env) NoCaseOfCase
+               -- See Note [no-case-of-case]
+       = (env, case_bndr)
+
+       | otherwise     -- Failed try [see Note 2 above]
+                       --     not (isEvaldUnfolding (idUnfolding v))
+       = case scrut of
+           Var v -> (modifyInScope env1 v case_bndr', case_bndr')
+               -- Note about using modifyInScope for v here
+               -- We could extend the substitution instead, but it would be
+               -- a hack because then the substitution wouldn't be idempotent
+               -- any more (v is an OutId).  And this does just as well.
+
+           Cast (Var v) co -> (addBinderUnfolding env1 v rhs, case_bndr')
+                           where
+                               rhs = Cast (Var case_bndr') (mkSymCoercion co)
+
+           other -> (env, case_bndr)
+       where
+         case_bndr' = zapOccInfo case_bndr
+         env1       = modifyInScope env case_bndr case_bndr'
+
 
 zapOccInfo :: InId -> InId     -- See Note [zapOccInfo]
 zapOccInfo b = b `setIdOccInfo` NoOccInfo
@@ -1441,19 +1497,19 @@ simplAlts :: SimplEnv
          -> OutExpr
          -> InId                       -- Case binder
          -> [InAlt] -> SimplCont
-         -> SimplM (OutId, [OutAlt])   -- Includes the continuation
+         -> SimplM (OutExpr, OutId, [OutAlt])  -- Includes the continuation
 -- Like simplExpr, this just returns the simplified alternatives;
 -- it not return an environment
 
 simplAlts env scrut case_bndr alts cont'
   = -- pprTrace "simplAlts" (ppr alts $$ ppr (seIdSubst env)) $
     do { let alt_env = zapFloats env
-       ; (alt_env, case_bndr') <- simplCaseBinder alt_env scrut case_bndr
+       ; (alt_env, scrut', case_bndr') <- simplCaseBinder alt_env scrut case_bndr alts
 
        ; (imposs_deflt_cons, in_alts) <- prepareAlts scrut case_bndr' alts
 
        ; alts' <- mapM (simplAlt alt_env imposs_deflt_cons case_bndr' cont') in_alts
-       ; return (case_bndr', alts') }
+       ; return (scrut', case_bndr', alts') }
 
 ------------------------------------
 simplAlt :: SimplEnv
index d1a3445..8751e40 100644 (file)
@@ -10,14 +10,14 @@ module FamInstEnv (
        pprFamInst, pprFamInstHdr, pprFamInsts, 
        famInstHead, mkLocalFamInst, mkImportedFamInst,
 
-       FamInstEnvs, FamInstEnv, emptyFamInstEnv, 
+       FamInstEnvs, FamInstEnv, emptyFamInstEnv, emptyFamInstEnvs, 
        extendFamInstEnv, extendFamInstEnvList, 
        famInstEnvElts, familyInstances,
 
        lookupFamInstEnv, lookupFamInstEnvUnify,
        
        -- Normalisation
-       toplevelNormaliseFamInst
+       topNormaliseType
     ) where
 
 #include "HsVersions.h"
@@ -168,6 +168,9 @@ data FamilyInstEnv
 --  * The fs_tvs are distinct in each FamInst
 --     of a range value of the map (so we can safely unify them)
 
+emptyFamInstEnvs :: (FamInstEnv, FamInstEnv)
+emptyFamInstEnvs = (emptyFamInstEnv, emptyFamInstEnv)
+
 emptyFamInstEnv :: FamInstEnv
 emptyFamInstEnv = emptyUFM
 
@@ -196,7 +199,7 @@ extendFamInstEnv inst_env ins_item@(FamInst {fi_fam = cls_nm, fi_tcs = mb_tcs})
 
 %************************************************************************
 %*                                                                     *
-\subsection{Looking up a family instance}
+               Looking up a family instance
 %*                                                                     *
 %************************************************************************
 
@@ -224,6 +227,9 @@ lookupFamInstEnv :: FamInstEnvs
                 -> TyCon -> [Type]             -- What we are looking for
                 -> [FamInstMatch]              -- Successful matches
 lookupFamInstEnv (pkg_ie, home_ie) fam tys
+  | not (isOpenTyCon fam) 
+  = []
+  | otherwise
   = home_matches ++ pkg_matches
   where
     rough_tcs    = roughMatchTcs tys
@@ -273,6 +279,9 @@ indexed synonyms and we don't want to slow that down by needless unification.
 lookupFamInstEnvUnify :: (FamInstEnv, FamInstEnv) -> TyCon -> [Type]
                      -> [(FamInstMatch)]
 lookupFamInstEnvUnify (pkg_ie, home_ie) fam tys
+  | not (isOpenTyCon fam) 
+  = []
+  | otherwise
   = home_matches ++ pkg_matches
   where
     rough_tcs    = roughMatchTcs tys
@@ -318,98 +327,94 @@ bind_fn tv | isTcTyVar tv && isExistentialTyVar tv = Skolem
           | otherwise                             = BindMe
 \end{code}
 
---------------------------------------
--- Normalisation 
+%************************************************************************
+%*                                                                     *
+               Looking up a family instance
+%*                                                                     *
+%************************************************************************
 
 \begin{code}
-       -- get rid of TOPLEVEL type functions by rewriting them 
-       -- i.e. treating their equations as a TRS
-toplevelNormaliseFamInst :: FamInstEnvs ->
-                           Type ->
-                           (CoercionI,Type)
-toplevelNormaliseFamInst env ty
-       | Just ty' <- tcView ty = normaliseFamInst env ty'
-toplevelNormaliseFamInst env ty@(TyConApp tyCon tys)
-       | isOpenTyCon tyCon
-       = normaliseFamInst env ty
-toplevelNormaliseFamInst env ty
-       = (IdCo,ty)
+topNormaliseType :: FamInstEnvs
+                     -> Type
+                     -> Maybe (Coercion, Type)
+
+-- Get rid of *outermost* (or toplevel) type functions by rewriting them
+-- By "outer" we mean that toplevelNormaliseType guarantees to return
+-- a type that does not have a reducible redex (F ty1 .. tyn) as its
+-- outermost form.  It *can* return something like (Maybe (F ty)), where
+-- (F ty) is a redex.
+
+topNormaliseType env ty
+  | Just ty' <- tcView ty = topNormaliseType env ty'
+
+topNormaliseType env ty@(TyConApp tc tys)
+  | isOpenTyCon tc
+  , (ACo co, ty) <- normaliseType env ty
+  = Just (co, ty)
+
+topNormaliseType env ty
+  = Nothing
         
 
-       -- get rid of ALL type functions by rewriting them 
-       -- i.e. treating their equations as a TRS
-normaliseFamInst :: FamInstEnvs ->     -- environment with family instances
-                   Type ->             -- old type
-                   (CoercionI,Type)    -- (coercion,new type)
-normaliseFamInst env ty 
-       | Just ty' <- tcView ty = normaliseFamInst env ty' 
-normaliseFamInst env ty@(TyConApp tyCon tys) =
-       let (cois,ntys) = mapAndUnzip (normaliseFamInst env) tys
-           tycon_coi   = mkTyConAppCoI tyCon ntys cois
-           maybe_ty_co = lookupFamInst env tyCon ntys
-        in case maybe_ty_co of
-               -- a matching family instance exists
-               Just (ty',co) ->
-                       let first_coi      = mkTransCoI tycon_coi (ACo co)
-                           (rest_coi,nty) = normaliseFamInst env ty'
-                           fix_coi        = mkTransCoI first_coi rest_coi
-                       in (fix_coi,nty)
-               -- no matching family instance exists
+normaliseType :: FamInstEnvs           -- environment with family instances
+             -> Type                   -- old type
+             -> (CoercionI,Type)       -- (coercion,new type), where
+                                       -- co :: old-type ~ new_type
+-- Normalise the input type, by eliminating all type-function redexes
+
+normaliseType env ty 
+  | Just ty' <- coreView ty = normaliseType env ty' 
+
+normaliseType env ty@(TyConApp tyCon tys)
+  = let        -- First normalise the arg types
+       (cois, ntys) = mapAndUnzip (normaliseType env) tys
+       tycon_coi    = mkTyConAppCoI tyCon ntys cois
+    in         -- Now try the top-level redex
+    case lookupFamInstEnv env tyCon ntys of
+               -- A matching family instance exists
+       [(fam_inst, tys)] -> (fix_coi, nty)
+           where
+               rep_tc         = famInstTyCon fam_inst
+               co_tycon       = expectJust "lookupFamInst" (tyConFamilyCoercion_maybe rep_tc)
+               co             = mkTyConApp co_tycon tys
+               first_coi      = mkTransCoI tycon_coi (ACo co)
+               (rest_coi,nty) = normaliseType env (mkTyConApp rep_tc tys)
+               fix_coi        = mkTransCoI first_coi rest_coi
+
+               -- No unique matching family instance exists;
                -- we do not do anything
-               Nothing -> 
-                       (tycon_coi,TyConApp tyCon ntys)
-normaliseFamInst env ty@(AppTy ty1 ty2)        =
-       let (coi1,nty1) = normaliseFamInst env ty1
-           (coi2,nty2) = normaliseFamInst env ty2
+       other -> (tycon_coi, TyConApp tyCon ntys)
+
+  where
+
+normaliseType env ty@(AppTy ty1 ty2)
+  =    let (coi1,nty1) = normaliseType env ty1
+           (coi2,nty2) = normaliseType env ty2
        in  (mkAppTyCoI nty1 coi1 nty2 coi2, AppTy nty1 nty2)
-normaliseFamInst env ty@(FunTy ty1 ty2)        =
-       let (coi1,nty1) = normaliseFamInst env ty1
-           (coi2,nty2) = normaliseFamInst env ty2
+normaliseType env ty@(FunTy ty1 ty2)
+  =    let (coi1,nty1) = normaliseType env ty1
+           (coi2,nty2) = normaliseType env ty2
        in  (mkFunTyCoI nty1 coi1 nty2 coi2, FunTy nty1 nty2)
-normaliseFamInst env ty@(ForAllTy tyvar ty1)   =
-       let (coi,nty1) = normaliseFamInst env ty1
+normaliseType env ty@(ForAllTy tyvar ty1)
+  =    let (coi,nty1) = normaliseType env ty1
        in  (mkForAllTyCoI tyvar coi,ForAllTy tyvar nty1)
-normaliseFamInst env ty@(NoteTy note ty1)      =
-       let (coi,nty1) = normaliseFamInst env ty1
+normaliseType env ty@(NoteTy note ty1)
+  =    let (coi,nty1) = normaliseType env ty1
        in  (mkNoteTyCoI note coi,NoteTy note nty1)
-normaliseFamInst env ty@(TyVarTy _) =
-       (IdCo,ty)
-normaliseFamInst env (PredTy predty) =
-       normaliseFamInstPred env predty 
-
-normaliseFamInstPred :: FamInstEnvs -> PredType -> (CoercionI,Type)
-normaliseFamInstPred env (ClassP cls tys) =
-       let (cois,tys') = mapAndUnzip (normaliseFamInst env) tys
+normaliseType env ty@(TyVarTy _)
+  =    (IdCo,ty)
+normaliseType env (PredTy predty)
+  =    normalisePred env predty        
+
+normalisePred :: FamInstEnvs -> PredType -> (CoercionI,Type)
+normalisePred env (ClassP cls tys)
+  =    let (cois,tys') = mapAndUnzip (normaliseType env) tys
        in  (mkClassPPredCoI cls tys' cois, PredTy $ ClassP cls tys')
-normaliseFamInstPred env (IParam ipn ty) =
-       let (coi,ty') = normaliseFamInst env ty
+normalisePred env (IParam ipn ty)
+  =    let (coi,ty') = normaliseType env ty
        in  (mkIParamPredCoI ipn coi, PredTy $ IParam ipn ty')
-normaliseFamInstPred env (EqPred ty1 ty2) =
-       let (coi1,ty1') = normaliseFamInst env ty1
-            (coi2,ty2') = normaliseFamInst env ty2
+normalisePred env (EqPred ty1 ty2)
+  =    let (coi1,ty1') = normaliseType env ty1
+            (coi2,ty2') = normaliseType env ty2
        in  (mkEqPredCoI ty1' coi1 ty2' coi2, PredTy $ EqPred ty1' ty2')
-lookupFamInst :: FamInstEnvs -> TyCon -> [Type] -> Maybe (Type,Coercion)
-
--- (lookupFamInst F tys) looks for a top-level instance
---     co : forall a. F tys' = G a
---   (The rhs is always of form G a; see Note [The FamInst structure]
---     in FamInst.)
--- where we can instantiate 'a' with t to make tys'[t/a] = tys
--- Hence   (co t) : F tys ~ G t
--- Then we return (Just (G t, co t))
-
-lookupFamInst env tycon tys 
-  | not (isOpenTyCon tycon)            -- Dead code; fix.
-  = Nothing
-  | otherwise
-  = case lookupFamInstEnv env tycon tys of
-          [(subst, fam_inst)] -> 
-            Just (mkTyConApp rep_tc substituted_vars, mkTyConApp coercion_tycon substituted_vars)
-               where   -- NB: invariant of lookupFamInstEnv is that (tyConTyVars rep_tc)
-                       --     is in the domain of the substitution
-                 rep_tc           = famInstTyCon fam_inst
-                 coercion_tycon   = expectJust "lookupFamInst" (tyConFamilyCoercion_maybe rep_tc)
-                 substituted_vars = substTyVars subst (tyConTyVars rep_tc)
-          other -> Nothing
 \end{code}