[project @ 2001-11-16 15:42:26 by simonpj]
[ghc-hetmet.git] / ghc / compiler / simplCore / SimplUtils.lhs
index afc53dc..e894bc0 100644 (file)
@@ -5,8 +5,8 @@
 
 \begin{code}
 module SimplUtils (
-       simplBinder, simplBinders, simplRecIds, simplLetId, simplLamBinders,
-       tryEtaExpansion,
+       simplBinder, simplBinders, simplRecBndrs, 
+       simplLetBndr, simplLamBndrs, 
        newId, mkLam, mkCase,
 
        -- The continuation type
@@ -30,8 +30,8 @@ import CoreUtils      ( cheapEqExpr, exprType,
                          findDefault, exprOkForSpeculation, exprIsValue
                        )
 import qualified Subst ( simplBndrs, simplBndr, simplLetId, simplLamBndr )
-import Id              ( Id, idType, idInfo,
-                         mkSysLocal, hasNoBinding, isDeadBinder, idNewDemandInfo,
+import Id              ( Id, idType, idInfo, 
+                         mkSysLocal, isDeadBinder, idNewDemandInfo,
                          idUnfolding, idNewStrictness
                        )
 import NewDemand       ( isStrictDmd, isBotRes, splitStrictSig )
@@ -41,7 +41,7 @@ import Type           ( Type, seqType,
                          splitRepFunTys, isStrictType
                        )
 import OccName         ( UserFS )
-import TyCon           ( tyConDataConsIfAvailable, isDataTyCon )
+import TyCon           ( tyConDataConsIfAvailable, isAlgTyCon, isNewTyCon )
 import DataCon         ( dataConRepArity, dataConSig, dataConArgTys )
 import Var             ( mkSysTyVar, tyVarKind )
 import Util            ( lengthExceeds, mapAccumL )
@@ -77,14 +77,16 @@ data SimplCont              -- Strict contexts
             InId [InAlt] SimplEnv      -- The case binder, alts, and subst-env
             SimplCont
 
-  | ArgOf    DupFlag           -- An arbitrary strict context: the argument 
+  | ArgOf    LetRhsFlag                -- An arbitrary strict context: the argument 
                                --      of a strict function, or a primitive-arg fn
                                --      or a PrimOp
-            LetRhsFlag
+                               -- No DupFlag because we never duplicate it
+            OutType            -- arg_ty: type of the argument itself
             OutType            -- cont_ty: the type of the expression being sought by the context
                                --      f (error "foo") ==> coerce t (error "foo")
                                -- when f is strict
                                -- We need to know the type t, to which to coerce.
+
             (SimplEnv -> OutExpr -> SimplM FloatsWithExpr)     -- What to do with the result
                                -- The result expression in the OutExprStuff has type cont_ty
 
@@ -98,7 +100,7 @@ instance Outputable LetRhsFlag where
 instance Outputable SimplCont where
   ppr (Stop _ is_rhs _)             = ptext SLIT("Stop") <> brackets (ppr is_rhs)
   ppr (ApplyTo dup arg se cont)      = (ptext SLIT("ApplyTo") <+> ppr dup <+> ppr arg) $$ ppr cont
-  ppr (ArgOf   dup _ _ _)           = ptext SLIT("ArgOf...") <+> ppr dup
+  ppr (ArgOf _ _ _ _)               = ptext SLIT("ArgOf...")
   ppr (Select dup bndr alts se cont) = (ptext SLIT("Select") <+> ppr dup <+> ppr bndr) $$ 
                                       (nest 4 (ppr alts)) $$ ppr cont
   ppr (CoerceIt ty cont)            = (ptext SLIT("CoerceIt") <+> ppr ty) $$ ppr cont
@@ -120,7 +122,7 @@ mkStop ty is_rhs = Stop ty is_rhs (canUpdateInPlace ty)
 
 contIsRhs :: SimplCont -> Bool
 contIsRhs (Stop _ AnRhs _)    = True
-contIsRhs (ArgOf _ AnRhs _ _) = True
+contIsRhs (ArgOf AnRhs _ _ _) = True
 contIsRhs other                      = False
 
 contIsRhsOrArg (Stop _ _ _)    = True
@@ -131,7 +133,6 @@ contIsRhsOrArg other               = False
 contIsDupable :: SimplCont -> Bool
 contIsDupable (Stop _ _ _)                      = True
 contIsDupable (ApplyTo  OkToDup _ _ _)   = True
-contIsDupable (ArgOf    OkToDup _ _ _)   = True
 contIsDupable (Select   OkToDup _ _ _ _) = True
 contIsDupable (CoerceIt _ cont)          = contIsDupable cont
 contIsDupable (InlinePlease cont)       = contIsDupable cont
@@ -439,29 +440,25 @@ simplBinder env bndr
     returnSmpl (setSubst env subst', bndr')
 
 
-simplLamBinders :: SimplEnv -> [InBinder] -> SimplM (SimplEnv, [OutBinder])
-simplLamBinders env bndrs
+simplLetBndr :: SimplEnv -> InBinder -> SimplM (SimplEnv, OutBinder)
+simplLetBndr env id
   = let
-       (subst', bndrs') = mapAccumL Subst.simplLamBndr (getSubst env) bndrs
+       (subst', id') = Subst.simplLetId (getSubst env) id
     in
-    seqBndrs bndrs'    `seq`
-    returnSmpl (setSubst env subst', bndrs')
+    seqBndr id'                `seq`
+    returnSmpl (setSubst env subst', id')
 
-simplRecIds :: SimplEnv -> [InBinder] -> SimplM (SimplEnv, [OutBinder])
-simplRecIds env ids
-  = let
-       (subst', ids') = mapAccumL Subst.simplLetId (getSubst env) ids
-    in
-    seqBndrs ids'      `seq`
-    returnSmpl (setSubst env subst', ids')
+simplLamBndrs, simplRecBndrs 
+       :: SimplEnv -> [InBinder] -> SimplM (SimplEnv, [OutBinder])
+simplRecBndrs = simplBndrs Subst.simplLetId
+simplLamBndrs = simplBndrs Subst.simplLamBndr
 
-simplLetId :: SimplEnv -> InBinder -> SimplM (SimplEnv, OutBinder)
-simplLetId env id
+simplBndrs simpl_bndr env bndrs
   = let
-       (subst', id') = Subst.simplLetId (getSubst env) id
+       (subst', bndrs') = mapAccumL simpl_bndr (getSubst env) bndrs
     in
-    seqBndr id'                `seq`
-    returnSmpl (setSubst env subst', id')
+    seqBndrs bndrs'    `seq`
+    returnSmpl (setSubst env subst', bndrs')
 
 seqBndrs [] = ()
 seqBndrs (b:bs) = seqBndr b `seq` seqBndrs bs
@@ -550,7 +547,11 @@ tryEtaReduce bndrs body
     go []       (Var fun)     | ok_fun fun   = Just (Var fun)  -- Success!
     go _        _                           = Nothing          -- Failure!
 
-    ok_fun fun   = not (fun `elem` bndrs) && not (hasNoBinding fun)
+    ok_fun fun   = not (fun `elem` bndrs) && 
+                  isEvaldUnfolding (idUnfolding fun)
+       -- The exprIsValue is because eta reduction is not 
+       -- valid in general:  \x. bot  /=  bot
+       -- So we need to be sure that the "fun" is a value.
     ok_arg b arg = varToCoreExpr b `cheapEqExpr` arg
 \end{code}
 
@@ -579,14 +580,10 @@ actually computing the expansion.
 tryEtaExpansion :: OutExpr -> SimplM OutExpr
 -- There is at least one runtime binder in the binders
 tryEtaExpansion body
-  | arity_is_manifest          -- Some lambdas but not enough
-  = returnSmpl body
-
-  | otherwise
   = getUniquesSmpl                     `thenSmpl` \ us ->
     returnSmpl (etaExpand fun_arity us body (exprType body))
   where
-    (fun_arity, arity_is_manifest) = exprEtaExpandArity body
+    fun_arity = exprEtaExpandArity body
 \end{code}
 
 
@@ -782,10 +779,10 @@ tryRhsTyLam env tyvars body               -- Only does something if there's a let
 mkCase puts a case expression back together, trying various transformations first.
 
 \begin{code}
-mkCase :: OutExpr -> OutId -> [OutAlt] -> SimplM OutExpr
+mkCase :: OutExpr -> [AltCon] -> OutId -> [OutAlt] -> SimplM OutExpr
 
-mkCase scrut case_bndr alts
-  = mkAlts scrut case_bndr alts        `thenSmpl` \ better_alts ->
+mkCase scrut handled_cons case_bndr alts
+  = mkAlts scrut handled_cons case_bndr alts   `thenSmpl` \ better_alts ->
     mkCase1 scrut case_bndr better_alts
 \end{code}
 
@@ -860,7 +857,7 @@ and similarly in cascade for all the join points!
 --------------------------------------------------
 --     1. Merge identical branches
 --------------------------------------------------
-mkAlts scrut case_bndr alts@((con1,bndrs1,rhs1) : con_alts)
+mkAlts scrut handled_cons case_bndr alts@((con1,bndrs1,rhs1) : con_alts)
   | all isDeadBinder bndrs1,                   -- Remember the default 
     length filtered_alts < length con_alts     -- alternative comes first
   = tick (AltMerge case_bndr)                  `thenSmpl_`
@@ -875,13 +872,20 @@ mkAlts scrut case_bndr alts@((con1,bndrs1,rhs1) : con_alts)
 --     2. Fill in missing constructor
 --------------------------------------------------
 
-mkAlts scrut case_bndr alts
-  | Just (tycon, inst_tys) <- splitTyConApp_maybe (idType case_bndr),
-    isDataTyCon tycon,                 -- It's a data type
-    (alts_no_deflt, Just rhs) <- findDefault alts,
-               -- There is a DEFAULT case
-    [missing_con] <- filter is_missing (tyConDataConsIfAvailable tycon)
-               -- There is just one missing constructor!
+mkAlts scrut handled_cons case_bndr alts
+  | (alts_no_deflt, Just rhs) <- findDefault alts,
+                       -- There is a DEFAULT case
+
+    Just (tycon, inst_tys) <- splitTyConApp_maybe (idType case_bndr),
+    isAlgTyCon tycon,          -- It's a data type, tuple, or unboxed tuples.  
+    not (isNewTyCon tycon),    -- We can have a newtype, if we are just doing an eval:
+                               --      case x of { DEFAULT -> e }
+                               -- and we don't want to fill in a default for them!
+
+    [missing_con] <- [con | con <- tyConDataConsIfAvailable tycon,
+                           not (con `elem` handled_data_cons)]
+                       -- There is just one missing constructor!
+
   = tick (FillInCaseDefault case_bndr) `thenSmpl_`
     getUniquesSmpl                     `thenSmpl` \ tv_uniqs ->
     getUniquesSmpl                     `thenSmpl` \ id_uniqs ->
@@ -895,16 +899,13 @@ mkAlts scrut case_bndr alts
     in
     returnSmpl better_alts
   where
-    impossible_cons   = otherCons (idUnfolding case_bndr)
-    handled_data_cons = [data_con | DataAlt data_con         <- impossible_cons] ++
-                       [data_con | (DataAlt data_con, _, _) <- alts]
-    is_missing con    = not (con `elem` handled_data_cons)
+    handled_data_cons = [data_con | DataAlt data_con <- handled_cons]
 
 --------------------------------------------------
 --     3.  Merge nested cases
 --------------------------------------------------
 
-mkAlts scrut outer_bndr outer_alts
+mkAlts scrut handled_cons outer_bndr outer_alts
   | opt_SimplCaseMerge,
     (outer_alts_without_deflt, maybe_outer_deflt)   <- findDefault outer_alts,
     Just (Case (Var scrut_var) inner_bndr inner_alts) <- maybe_outer_deflt,
@@ -948,7 +949,7 @@ mkAlts scrut outer_bndr outer_alts
 --     Catch-all
 --------------------------------------------------
 
-mkAlts scrut case_bndr other_alts = returnSmpl other_alts
+mkAlts scrut handled_cons case_bndr other_alts = returnSmpl other_alts
 \end{code}