Improve the interaction of 'seq' and associated data types
[ghc-hetmet.git] / compiler / simplCore / Simplify.lhs
index d97249f..ac1f790 100644 (file)
@@ -17,6 +17,7 @@ import Id
 import Var
 import IdInfo
 import Coercion
+import FamInstEnv      ( topNormaliseType )
 import DataCon         ( dataConRepStrictness, dataConUnivTyVars )
 import CoreSyn
 import NewDemand       ( isStrictDmd )
@@ -870,7 +871,7 @@ simplNonRecE env bndr (rhs, rhs_se) (bndrs, body) cont
                     (StrictBind bndr bndrs body env cont) }
 
   | otherwise
-  = do { (env, bndr') <- simplBinder env bndr
+  = do { (env, bndr') <- simplNonRecBndr env bndr
        ; env <- simplLazyBind env NotTopLevel NonRecursive bndr bndr' rhs rhs_se
        ; simplLam env bndrs body cont }
 \end{code}
@@ -962,8 +963,8 @@ completeCall env var cont
        -- is recursive, and hence a loop breaker:
        --      foldr k z (build g) = g k z
        -- So it's up to the programmer: rules can cause divergence
+       ; rules <- getRules
        ; let   in_scope   = getInScope env
-               rules      = getRules env
                maybe_rule = case activeRule dflags env of
                                Nothing     -> Nothing  -- No rules apply
                                Just act_fn -> lookupRule act_fn in_scope 
@@ -1030,7 +1031,7 @@ rebuildCall env fun fun_ty (has_rules, []) cont
   -- Then, especially in the first of these cases, we'd like to discard
   -- the continuation, leaving just the bottoming expression.  But the
   -- type might not be right, so we may have to add a coerce.
-  | not (contIsTrivial cont)    -- Only do thia if there is a non-trivial
+  | not (contIsTrivial cont)    -- Only do this if there is a non-trivial
   = return (env, mk_coerce fun)  -- contination to discard, else we do it
   where                                 -- again and again!
     cont_ty = contResultType cont
@@ -1177,9 +1178,9 @@ rebuildCase env scrut case_bndr alts cont
          (env, dup_cont, nodup_cont) <- prepareCaseCont env alts cont
 
        -- Simplify the alternatives
-       ; (case_bndr', alts') <- simplAlts env scrut case_bndr alts dup_cont
+       ; (scrut', case_bndr', alts') <- simplAlts env scrut case_bndr alts dup_cont
        ; let res_ty' = contResultType dup_cont
-       ; case_expr <- mkCase scrut case_bndr' res_ty' alts'
+       ; case_expr <- mkCase scrut' case_bndr' res_ty' alts'
 
        -- Notice that rebuildDone returns the in-scope set from env, not alt_env
        -- The case binder *not* scope over the whole returned case-expression
@@ -1277,6 +1278,35 @@ arranging that inside the outer case we add the unfolding
        v |-> x `cast` (sym co)
 to v.  Then we should inline v at the inner case, cancel the casts, and away we go
        
+Note [Improving seq]
+~~~~~~~~~~~~~~~~~~~
+Consider
+       type family F :: * -> *
+       type instance F Int = Int
+
+       ... case e of x { DEFAULT -> rhs } ...
+
+where x::F Int.  Then we'd like to rewrite (F Int) to Int, getting
+
+       case e `cast` co of x'::Int
+          I# x# -> let x = x' `cast` sym co 
+                   in rhs
+
+so that 'rhs' can take advantage of hte form of x'.  Notice that Note
+[Case of cast] may then apply to the result.
+
+This showed up in Roman's experiments.  Example:
+  foo :: F Int -> Int -> Int
+  foo t n = t `seq` bar n
+     where
+       bar 0 = 0
+       bar n = bar (n - case t of TI i -> i)
+Here we'd like to avoid repeated evaluating t inside the loop, by 
+taking advantage of the `seq`.
+
+At one point I did transformation in LiberateCase, but it's more robust here.
+(Otherwise, there's a danger that we'll simply drop the 'seq' altogether, before
+LiberateCase gets to see it.)
 
 Note [Case elimination]
 ~~~~~~~~~~~~~~~~~~~~~~~
@@ -1366,30 +1396,56 @@ I don't really know how to improve this situation.
 
 
 \begin{code}
-simplCaseBinder :: SimplEnv -> OutExpr -> InId -> SimplM (SimplEnv, OutId)
-simplCaseBinder env scrut case_bndr
-  | switchIsOn (getSwitchChecker env) NoCaseOfCase
-       -- See Note [no-case-of-case]
-  = do { (env, case_bndr') <- simplBinder env case_bndr
-       ; return (env, case_bndr') }
-
-simplCaseBinder env (Var v) case_bndr
--- Failed try [see Note 2 above]
---     not (isEvaldUnfolding (idUnfolding v))
-  = do { (env, case_bndr') <- simplBinder env (zapOccInfo case_bndr)
-       ; return (modifyInScope env v case_bndr', case_bndr') }
-       -- We could extend the substitution instead, but it would be
-       -- a hack because then the substitution wouldn't be idempotent
-       -- any more (v is an OutId).  And this does just as well.
-           
-simplCaseBinder env (Cast (Var v) co) case_bndr                -- Note [Case of cast]
-  = do { (env, case_bndr') <- simplBinder env (zapOccInfo case_bndr)
-       ; let rhs = Cast (Var case_bndr') (mkSymCoercion co)
-       ; return (addBinderUnfolding env v rhs, case_bndr') }
-
-simplCaseBinder env other_scrut case_bndr 
-  = do { (env, case_bndr') <- simplBinder env case_bndr
-       ; return (env, case_bndr') }
+simplCaseBinder :: SimplEnv -> OutExpr -> OutId -> [InAlt]
+               -> SimplM (SimplEnv, OutExpr, OutId)
+simplCaseBinder env scrut case_bndr alts
+  = do { (env1, case_bndr1) <- simplBinder env case_bndr
+
+       ; fam_envs <- getFamEnvs
+       ; (env2, scrut2, case_bndr2) <- improve_seq fam_envs env1 scrut 
+                                               case_bndr case_bndr1 alts
+                       -- Note [Improving seq]
+
+       ; let (env3, case_bndr3) = improve_case_bndr env2 scrut2 case_bndr2
+                       -- Note [Case of cast]
+
+       ; return (env3, scrut2, case_bndr3) }
+  where
+
+    improve_seq fam_envs env1 scrut case_bndr case_bndr1 [(DEFAULT,_,_)] 
+       | Just (co, ty2) <- topNormaliseType fam_envs (idType case_bndr1)
+       =  do { case_bndr2 <- newId FSLIT("nt") ty2
+             ; let rhs  = DoneEx (Var case_bndr2 `Cast` mkSymCoercion co)
+                   env2 = extendIdSubst env1 case_bndr rhs
+             ; return (env2, scrut `Cast` co, case_bndr2) }
+
+    improve_seq fam_envs env1 scrut case_bndr case_bndr1 alts
+       = return (env1, scrut, case_bndr1)
+
+
+    improve_case_bndr env scrut case_bndr
+       | switchIsOn (getSwitchChecker env) NoCaseOfCase
+               -- See Note [no-case-of-case]
+       = (env, case_bndr)
+
+       | otherwise     -- Failed try [see Note 2 above]
+                       --     not (isEvaldUnfolding (idUnfolding v))
+       = case scrut of
+           Var v -> (modifyInScope env1 v case_bndr', case_bndr')
+               -- Note about using modifyInScope for v here
+               -- We could extend the substitution instead, but it would be
+               -- a hack because then the substitution wouldn't be idempotent
+               -- any more (v is an OutId).  And this does just as well.
+
+           Cast (Var v) co -> (addBinderUnfolding env1 v rhs, case_bndr')
+                           where
+                               rhs = Cast (Var case_bndr') (mkSymCoercion co)
+
+           other -> (env, case_bndr)
+       where
+         case_bndr' = zapOccInfo case_bndr
+         env1       = modifyInScope env case_bndr case_bndr'
+
 
 zapOccInfo :: InId -> InId     -- See Note [zapOccInfo]
 zapOccInfo b = b `setIdOccInfo` NoOccInfo
@@ -1441,19 +1497,19 @@ simplAlts :: SimplEnv
          -> OutExpr
          -> InId                       -- Case binder
          -> [InAlt] -> SimplCont
-         -> SimplM (OutId, [OutAlt])   -- Includes the continuation
+         -> SimplM (OutExpr, OutId, [OutAlt])  -- Includes the continuation
 -- Like simplExpr, this just returns the simplified alternatives;
 -- it not return an environment
 
 simplAlts env scrut case_bndr alts cont'
   = -- pprTrace "simplAlts" (ppr alts $$ ppr (seIdSubst env)) $
     do { let alt_env = zapFloats env
-       ; (alt_env, case_bndr') <- simplCaseBinder alt_env scrut case_bndr
+       ; (alt_env, scrut', case_bndr') <- simplCaseBinder alt_env scrut case_bndr alts
 
        ; (imposs_deflt_cons, in_alts) <- prepareAlts scrut case_bndr' alts
 
        ; alts' <- mapM (simplAlt alt_env imposs_deflt_cons case_bndr' cont') in_alts
-       ; return (case_bndr', alts') }
+       ; return (scrut', case_bndr', alts') }
 
 ------------------------------------
 simplAlt :: SimplEnv