[project @ 2001-06-25 08:08:32 by simonpj]
authorsimonpj <unknown>
Mon, 25 Jun 2001 08:08:32 +0000 (08:08 +0000)
committersimonpj <unknown>
Mon, 25 Jun 2001 08:08:32 +0000 (08:08 +0000)
---------------------------
Add a new case optimisation
---------------------------

I found that lib/std/PrelCError had a case-expression that was
generating terrible code.   Something like this

x | p `is` 1 -> e1
  | p `is` 2 -> e2
...etc...

where @is@ was something like

p `is` n = p /= (-1) && p == n

This gave rise to a horrible sequence of cases

case p of
  (-1) -> $j p
  1    -> e1
  DEFAULT -> $j p

and similarly in cascade for all the join points!

Solution: add the following transformation:

case e of =====>     case e of
  C _ -> <expr>      D v -> ....v....
  D v -> ....v....      DEFAULT -> <expr>
  DEFAULT -> <expr>

The point is that we merge common RHSs, at least for the DEFAULT case.
[One could do something more elaborate but I've never seen it needed.]

This transformation is implemented in SimplUtils.mkCase

*** WARNING ***
To make this transformation easy, I have switched the convention
for DEFAULT clauses.  They must now occur FIRST in the list of
alternatives for a Core case expression.  (The semantics is
unchanged: they still are a catch-all case.)

The reason is that DEFAULT clauses sometimes need special treatment,
and it's a lot easier to find them at the front.

The easiest way to be insensitive to this change is to use
CoreUtils.findDefault to pull the default clause out.

I've made the (surprisingly few) changes consequent on this changed
of convention, but they aren't in this commit.  Instead they are part
of the big commit on newtypes I'm doing at the same time.

ghc/compiler/coreSyn/CoreSyn.lhs
ghc/compiler/coreSyn/CoreUtils.lhs
ghc/compiler/simplCore/SimplUtils.lhs

index 10ffe27..2c89f6e 100644 (file)
@@ -77,7 +77,7 @@ data Expr b   -- "b" for the type of binders,
   | Lam   b (Expr b)
   | Let   (Bind b) (Expr b)
   | Case  (Expr b) b [Alt b]   -- Binder gets bound to value of scrutinee
-                               -- DEFAULT case must be last, if it occurs at all
+                               -- DEFAULT case must be *first*, if it occurs at all
   | Note  Note (Expr b)
   | Type  Type                 -- This should only show up at the top
                                -- level of an Arg
index f7130eb..b3cab66 100644 (file)
@@ -11,7 +11,7 @@ module CoreUtils (
         mkPiType,
 
        -- Taking expressions apart
-       findDefault, findAlt,
+       findDefault, findAlt, hasDefault,
 
        -- Properties of expressions
        exprType, coreAltsType, 
@@ -60,7 +60,7 @@ import IdInfo         ( LBVarInfo(..),
 import Demand          ( appIsBottom )
 import Type            ( Type, mkFunTy, mkForAllTy, splitFunTy_maybe, 
                          applyTys, isUnLiftedType, seqType, mkUTy, mkTyVarTy,
-                         splitForAllTy_maybe, splitNewType_maybe, isForAllTy
+                         splitForAllTy_maybe, isForAllTy, eqType
                        )
 import TysWiredIn      ( boolTy, trueDataCon, falseDataCon )
 import CostCentre      ( CostCentre )
@@ -185,13 +185,13 @@ mkInlineMe e         = Note InlineMe e
 mkCoerce :: Type -> Type -> CoreExpr -> CoreExpr
 
 mkCoerce to_ty from_ty (Note (Coerce to_ty2 from_ty2) expr)
-  = ASSERT( from_ty == to_ty2 )
+  = ASSERT( from_ty `eqType` to_ty2 )
     mkCoerce to_ty from_ty2 expr
 
 mkCoerce to_ty from_ty expr
-  | to_ty == from_ty = expr
-  | otherwise       = ASSERT( from_ty == exprType expr )
-                      Note (Coerce to_ty from_ty) expr
+  | to_ty `eqType` from_ty = expr
+  | otherwise             = ASSERT( from_ty `eqType` exprType expr )
+                            Note (Coerce to_ty from_ty) expr
 \end{code}
 
 \begin{code}
@@ -251,25 +251,31 @@ mkIfThenElse guard then_expr else_expr
 %*                                                                     *
 %************************************************************************
 
+The default alternative must be first, if it exists at all.
+This makes it easy to find, though it makes matching marginally harder.
 
 \begin{code}
+hasDefault :: [CoreAlt] -> Bool
+hasDefault ((DEFAULT,_,_) : alts) = True
+hasDefault _                     = False
+
 findDefault :: [CoreAlt] -> ([CoreAlt], Maybe CoreExpr)
-findDefault []                         = ([], Nothing)
-findDefault ((DEFAULT,args,rhs) : alts) = ASSERT( null alts && null args ) 
-                                         ([], Just rhs)
-findDefault (alt : alts)               = case findDefault alts of 
-                                           (alts', deflt) -> (alt : alts', deflt)
+findDefault ((DEFAULT,args,rhs) : alts) = ASSERT( null args ) (alts, Just rhs)
+findDefault alts                       =                     (alts, Nothing)
 
 findAlt :: AltCon -> [CoreAlt] -> CoreAlt
 findAlt con alts
-  = go alts
+  = case alts of
+       (deflt@(DEFAULT,_,_):alts) -> go alts deflt
+       other                      -> go alts panic_deflt
+
   where
-    go []          = pprPanic "Missing alternative" (ppr con $$ vcat (map ppr alts))
-    go (alt : alts) | matches alt = alt
-                   | otherwise   = go alts
+    panic_deflt = pprPanic "Missing alternative" (ppr con $$ vcat (map ppr alts))
 
-    matches (DEFAULT, _, _) = True
-    matches (con1, _, _)    = con == con1
+    go []                     deflt               = deflt
+    go (alt@(con1,_,_) : alts) deflt | con == con1 = alt
+                                    | otherwise   = ASSERT( not (con1 == DEFAULT) )
+                                                    go alts deflt
 \end{code}
 
 
@@ -755,13 +761,8 @@ etaExpand n us expr ty
                                   (us1, us2) = splitUniqSupply us
                                   uniq       = uniqFromSupply us1 
                                   
-       ; Nothing -> 
-  
-       case splitNewType_maybe ty of {
-         Just ty' -> mkCoerce ty ty' (etaExpand n us (mkCoerce ty' ty expr) ty') ;
-  
-         Nothing -> pprTrace "Bad eta expand" (ppr expr $$ ppr ty) expr
-       }}}
+       ; Nothing -> pprTrace "Bad eta expand" (ppr expr $$ ppr ty) expr
+       }}
 \end{code}
 
 
@@ -818,7 +819,7 @@ cheapEqExpr :: Expr b -> Expr b -> Bool
 
 cheapEqExpr (Var v1)   (Var v2)   = v1==v2
 cheapEqExpr (Lit lit1) (Lit lit2) = lit1 == lit2
-cheapEqExpr (Type t1)  (Type t2)  = t1 == t2
+cheapEqExpr (Type t1)  (Type t2)  = t1 `eqType` t2
 
 cheapEqExpr (App f1 a1) (App f2 a2)
   = f1 `cheapEqExpr` f2 && a1 `cheapEqExpr` a2
@@ -838,6 +839,9 @@ exprIsBig other            = True
 \begin{code}
 eqExpr :: CoreExpr -> CoreExpr -> Bool
        -- Works ok at more general type, but only needed at CoreExpr
+       -- Used in rule matching, so when we find a type we use
+       -- eqTcType, which doesn't look through newtypes
+       -- [And it doesn't risk falling into a black hole either.]
 eqExpr e1 e2
   = eq emptyVarEnv e1 e2
   where
@@ -868,7 +872,7 @@ eqExpr e1 e2
                                       env' = extendVarEnv env v1 v2
 
     eq env (Note n1 e1) (Note n2 e2) = eq_note env n1 n2 && eq env e1 e2
-    eq env (Type t1)    (Type t2)    = t1 == t2
+    eq env (Type t1)    (Type t2)    = t1 `eqType` t2
     eq env e1          e2           = False
                                         
     eq_list env []      []       = True
@@ -879,7 +883,7 @@ eqExpr e1 e2
                                         eq (extendVarEnvList env (vs1 `zip` vs2)) r1 r2
 
     eq_note env (SCC cc1)      (SCC cc2)      = cc1 == cc2
-    eq_note env (Coerce t1 f1) (Coerce t2 f2) = t1==t2 && f1==f2
+    eq_note env (Coerce t1 f1) (Coerce t2 f2) = t1 `eqType` t2 && f1 `eqType` f2
     eq_note env InlineCall     InlineCall     = True
     eq_note env other1        other2         = False
 \end{code}
index 501dd60..d40f151 100644 (file)
@@ -40,9 +40,10 @@ import Demand                ( isStrict )
 import SimplMonad
 import Type            ( Type, mkForAllTys, seqType, repType,
                          splitTyConApp_maybe, tyConAppArgs, mkTyVarTys,
-                         isDictTy, isDataType, isUnLiftedType,
+                         isUnLiftedType,
                          splitRepFunTys
                        )
+import TcType          ( isStrictType )
 import TyCon           ( tyConDataConsIfAvailable )
 import DataCon         ( dataConRepArity )
 import VarEnv          ( SubstEnv )
@@ -246,19 +247,6 @@ getContArgs fun orig_cont
 
          other -> vanilla_stricts      -- Not enough args, or no strictness
 
-
--------------------
-isStrictType :: Type -> Bool
-       -- isStrictType computes whether an argument (or let RHS) should
-       -- be computed strictly or lazily, based only on its type
-isStrictType ty
-  | isUnLiftedType ty                              = True
-  | opt_DictsStrict && isDictTy ty && isDataType ty = True
-  | otherwise                                      = False 
-       -- Return true only for dictionary types where the dictionary
-       -- has more than one component (else we risk poking on the component
-       -- of a newtype dictionary)
-
 -------------------
 interestingArg :: InScopeSet -> InExpr -> SubstEnv -> Bool
        -- An argument is interesting if it has *some* structure
@@ -402,21 +390,16 @@ canUpdateInPlace :: Type -> Bool
 -- small arity.  But arity zero isn't good -- we share the single copy
 -- for that case, so no point in sharing.
 
--- Note the repType: we want to look through newtypes for this purpose
-
 canUpdateInPlace ty 
   | not opt_UF_UpdateInPlace = False
   | otherwise
-  = case splitTyConApp_maybe (repType ty) of {
-                       Nothing         -> False ;
-                       Just (tycon, _) -> 
-
-                     case tyConDataConsIfAvailable tycon of
-                       [dc]  -> arity == 1 || arity == 2
-                             where
-                                arity = dataConRepArity dc
-                       other -> False
-                     }
+  = case splitTyConApp_maybe ty of 
+       Nothing         -> False 
+       Just (tycon, _) -> case tyConDataConsIfAvailable tycon of
+                               [dc]  -> arity == 1 || arity == 2
+                                     where
+                                        arity = dataConRepArity dc
+                               other -> False
 \end{code}
 
 
@@ -774,11 +757,12 @@ mkCase scrut outer_bndr outer_alts
        -- Secondly, if you do, you get an infinite loop, because the bindNonRec
        -- in munge_rhs puts a case into the DEFAULT branch!
   where
-    new_alts = outer_alts_without_deflt ++ munged_inner_alts
+    new_alts = add_default maybe_inner_default
+                          (outer_alts_without_deflt ++ inner_con_alts)
+
     maybe_case_in_default = case findDefault outer_alts of
                                (outer_alts_without_default,
                                 Just (Case (Var scrut_var) inner_bndr inner_alts))
-                                
                                   | outer_bndr == scrut_var
                                   -> Just (outer_alts_without_default, inner_bndr, inner_alts)
                                other -> Nothing
@@ -793,12 +777,17 @@ mkCase scrut outer_bndr outer_alts
                           not (con `elem` outer_cons)  -- Eliminate shadowed inner alts
                        ]
     munge_rhs rhs = bindNonRec inner_bndr (Var outer_bndr) rhs
+
+    (inner_con_alts, maybe_inner_default) = findDefault munged_inner_alts
+
+    add_default (Just rhs) alts = (DEFAULT,[],rhs) : alts
+    add_default Nothing    alts = alts
 \end{code}
 
 Now the identity-case transformation:
 
        case e of               ===> e
-               True -> True;
+               True  -> True;
                False -> False
 
 and similar friends.
@@ -831,11 +820,43 @@ mkCase scrut case_bndr alts
                        other                 -> scrut
 \end{code}
 
-The catch-all case
+The catch-all case.  We do a final transformation that I've
+occasionally seen making a big difference:
+
+       case e of               =====>     case e of
+         C _ -> f x                         D v -> ....v....
+         D v -> ....v....                   DEFAULT -> f x
+         DEFAULT -> f x
 
+The point is that we merge common RHSs, at least for the DEFAULT case.
+[One could do something more elaborate but I've never seen it needed.]
+The case where this came up was like this (lib/std/PrelCError.lhs):
+
+       x | p `is` 1 -> e1
+         | p `is` 2 -> e2
+       ...etc...
+
+where @is@ was something like
+       
+       p `is` n = p /= (-1) && p == n
+
+This gave rise to a horrible sequence of cases
+
+       case p of
+         (-1) -> $j p
+         1    -> e1
+         DEFAULT -> $j p
+
+and similarly in cascade for all the join points!
+         
 \begin{code}
 mkCase other_scrut case_bndr other_alts
-  = returnSmpl (Case other_scrut case_bndr other_alts)
+  = returnSmpl (Case other_scrut case_bndr (mergeDefault other_alts))
+
+mergeDefault (deflt_alt@(DEFAULT,_,deflt_rhs) : con_alts)
+  = deflt_alt : [alt | alt@(con,_,rhs) <- con_alts, not (rhs `cheapEqExpr` deflt_rhs)]
+       -- NB: we can neglect the binders because we won't get equality if the
+       -- binders are mentioned in rhs (no shadowing)
+mergeDefault other_alts
+  = other_alts
 \end{code}
-
-