[project @ 2002-07-11 06:52:23 by ken]
[ghc-hetmet.git] / ghc / compiler / simplCore / CSE.lhs
index b2e124a..3354bf8 100644 (file)
@@ -10,18 +10,19 @@ module CSE (
 
 #include "HsVersions.h"
 
-import CmdLineOpts     ( DynFlag(..), DynFlags, dopt )
-import Id              ( Id, idType )
+import CmdLineOpts     ( DynFlag(..), DynFlags )
+import Id              ( Id, idType, idWorkerInfo )
+import IdInfo          ( workerExists )
 import CoreUtils       ( hashExpr, cheapEqExpr, exprIsBig, mkAltExpr )
 import DataCon         ( isUnboxedTupleCon )
-import Type            ( splitTyConApp_maybe )
+import Type            ( tyConAppArgs )
 import Subst           ( InScopeSet, uniqAway, emptyInScopeSet, 
                          extendInScopeSet, elemInScopeSet )
 import CoreSyn
 import VarEnv  
-import CoreLint                ( beginPass, endPass )
+import CoreLint                ( showPass, endPass )
 import Outputable
-import Util            ( mapAccumL )
+import Util            ( mapAccumL, lengthExceeds )
 import UniqFM
 \end{code}
 
@@ -107,11 +108,9 @@ cseProgram :: DynFlags -> [CoreBind] -> IO [CoreBind]
 
 cseProgram dflags binds
   = do {
-       beginPass dflags "Common sub-expression";
+       showPass dflags "Common sub-expression";
        let { binds' = cseBinds emptyCSEnv binds };
-       endPass dflags "Common sub-expression" 
-               (dopt Opt_D_dump_cse dflags || dopt Opt_D_verbose_core2core dflags)
-               binds'  
+       endPass dflags "Common sub-expression"  Opt_D_dump_cse binds'   
     }
 
 cseBinds :: CSEnv -> [CoreBind] -> [CoreBind]
@@ -128,12 +127,23 @@ cseBind env (Rec pairs)  = let (env', pairs') = mapAccumL do_one env pairs
                           in (env', Rec pairs')
                         
 
-do_one env (id, rhs) = case lookupCSEnv env rhs' of
-                         Just other_id -> (extendSubst env' id other_id, (id', Var other_id))
-                         Nothing       -> (addCSEnvItem env' id' rhs',   (id', rhs'))
-                    where
-                       (env', id') = addBinder env id
-                       rhs'        = cseExpr env' rhs
+do_one env (id, rhs) 
+  = case lookupCSEnv env rhs' of
+       Just other_id -> (extendSubst env' id other_id, (id', Var other_id))
+       Nothing       -> (addCSEnvItem env' id' rhs',   (id', rhs'))
+  where
+    (env', id') = addBinder env id
+    rhs' | not (workerExists (idWorkerInfo id)) = cseExpr env' rhs
+
+               -- Hack alert: don't do CSE on wrapper RHSs.
+               -- Otherwise we find:
+               --      $wf = h
+               --      f = \x -> ...$wf...
+               -- ===>
+               --      f = \x -> ...h...
+               -- But the WorkerInfo for f still says $wf, which is now dead!
+         | otherwise = rhs
+
 
 tryForCSE :: CSEnv -> CoreExpr -> CoreExpr
 tryForCSE env (Type t) = Type t
@@ -170,9 +180,7 @@ cseAlts env scrut' bndr bndr' alts
                other ->  (bndr', extendCSEnv env bndr' scrut') -- See "yet another wrinkle"
                                                                -- map: scrut' -> bndr'
 
-    arg_tys = case splitTyConApp_maybe (idType bndr) of
-               Just (_, arg_tys) -> arg_tys
-               other             -> pprPanic "cseAlts" (ppr bndr)
+    arg_tys = tyConAppArgs (idType bndr)
 
     cse_alt (DataAlt con, args, rhs)
        | not (null args || isUnboxedTupleCon con)
@@ -224,12 +232,14 @@ lookup_list ((x,e):es) expr | cheapEqExpr e expr = Just x
 
 addCSEnvItem env id expr | exprIsBig expr = env
                         | otherwise      = extendCSEnv env id expr
+   -- We don't try to CSE big expressions, because they are expensive to compare
+   -- (and are unlikely to be the same anyway)
 
 extendCSEnv (CS cs in_scope sub) id expr
   = CS (addToUFM_C combine cs hash [(id, expr)]) in_scope sub
   where
     hash   = hashExpr expr
-    combine old new = WARN( length result > 4, text "extendCSEnv: long list:" <+> ppr result )
+    combine old new = WARN( result `lengthExceeds` 4, text "extendCSEnv: long list:" <+> ppr result )
                      result
                    where
                      result = new ++ old