[project @ 2001-06-25 08:09:57 by simonpj]
[ghc-hetmet.git] / ghc / compiler / specialise / Rules.lhs
index 8d8819a..591e4db 100644 (file)
@@ -17,7 +17,7 @@ module Rules (
 
 import CoreSyn         -- All of it
 import OccurAnal       ( occurAnalyseRule )
-import CoreFVs         ( exprFreeVars, ruleRhsFreeVars, ruleSomeLhsFreeVars )
+import CoreFVs         ( exprFreeVars, ruleRhsFreeVars, ruleLhsFreeIds )
 import CoreUnfold      ( isCheapUnfolding, unfoldingTemplate )
 import CoreUtils       ( eqExpr )
 import PprCore         ( pprCoreRule )
@@ -29,11 +29,12 @@ import Id           ( Id, idUnfolding, idSpecialisation, setIdSpecialisation )
 import Var             ( isId )
 import VarSet
 import VarEnv
-import Type            ( mkTyVarTy )
-import qualified Unify ( match )
+import TcType          ( mkTyVarTy )
+import qualified TcType ( match )
+import TypeRep         ( Type(..) )    -- Can see type representation for matching
 
 import Outputable
-import Maybes          ( maybeToBool )
+import Maybe           ( isJust, isNothing, fromMaybe )
 import Util            ( sortLt )
 \end{code}
 
@@ -180,7 +181,7 @@ matchRule in_scope rule@(Rule rn tpl_vars tpl_args rhs) args
                                     mk_result_args subst done)
            Nothing         -> Nothing  -- Failure
       where
-       (done, leftovers) = partition (\v -> maybeToBool (lookupSubstEnv subst_env v))
+       (done, leftovers) = partition (\v -> isJust (lookupSubstEnv subst_env v))
                                      (map zapOccInfo tpl_vars)
                -- Zap the occ info 
        subst_env = substEnv subst
@@ -237,10 +238,10 @@ match (Var v1) e2 tpl_vars kont subst
                         kont (extendSubst subst v1 (DoneEx e2))
 
 
-               | eqExpr (Var v1) e2             -> kont subst
+               | eqExpr (Var v1) e2       -> kont subst
                        -- v1 is not a template variable, so it must be a global constant
 
-       Just (DoneEx e2')  | eqExpr e2'       e2 -> kont subst
+       Just (DoneEx e2')  | eqExpr e2' e2 -> kont subst
 
        other -> match_fail
 
@@ -355,16 +356,10 @@ bind vs1 vs2 matcher tpl_vars kont subst
     subst'        = bindSubstList subst vs1 vs2
 
        -- The unBindSubst relies on no shadowing in the template
-    not_in_subst v = not (maybeToBool (lookupSubst subst v))
+    not_in_subst v = isNothing (lookupSubst subst v)
     bug_msg = sep [ppr vs1, ppr vs2]
 
 ----------------------------------------
-match_ty ty1 ty2 tpl_vars kont subst
-  = case Unify.match False {- for now: KSW 2000-10 -} ty1 ty2 tpl_vars Just (substEnv subst) of
-       Nothing    -> match_fail
-       Just senv' -> kont (setSubstEnv subst senv') 
-
-----------------------------------------
 matches [] [] tpl_vars kont subst 
   = kont subst
 matches (e:es) (e':es') tpl_vars kont subst
@@ -378,6 +373,22 @@ mkVarArg v | isId v    = Var v
           | otherwise = Type (mkTyVarTy v)
 \end{code}
 
+Matching Core types: use the matcher in TcType.
+Notice that we treat newtypes as opaque.  For example, suppose 
+we have a specialised version of a function at a newtype, say 
+       newtype T = MkT Int
+We only want to replace (f T) with f', not (f Int).
+
+\begin{code}
+----------------------------------------
+match_ty ty1 ty2 tpl_vars kont subst
+  = TcType.match ty1 ty2 tpl_vars kont' (substEnv subst)
+  where
+    kont' senv = kont (setSubstEnv subst senv) 
+\end{code}
+
+
+
 %************************************************************************
 %*                                                                     *
 \subsection{Adding a new rule}
@@ -421,7 +432,7 @@ insertRule rules new_rule@(Rule _ tpl_vars tpl_args _)
     go (rule:rules) | new_is_more_specific rule = (new_rule:rule:rules)
                    | otherwise                 = rule : go rules
 
-    new_is_more_specific rule = maybeToBool (matchRule tpl_var_set rule tpl_args)
+    new_is_more_specific rule = isJust (matchRule tpl_var_set rule tpl_args)
 
 addIdSpecialisations :: Id -> [CoreRule] -> Id
 addIdSpecialisations id rules
@@ -458,7 +469,7 @@ data RuleBase = RuleBase
                                -- Held as a set, so that it can simply be the initial
                                -- in-scope set in the simplifier
 
-                    IdSet      -- Ids (whether local or imported) mentioned on 
+                   IdSet       -- Ids (whether local or imported) mentioned on 
                                -- LHS of some rule; these should be black listed
 
        -- This representation is a bit cute, and I wonder if we should
@@ -483,12 +494,15 @@ extendRuleBase (RuleBase rule_ids rule_fvs) (id, rule)
             (rule_fvs `unionVarSet` extendVarSet lhs_fvs id)
   where
     new_id = setIdSpecialisation id (addRule old_rules id rule)
-    old_rules = case lookupVarSet rule_ids id of
-                  Nothing  -> emptyCoreRules
-                  Just id' -> idSpecialisation id'
-    
-    lhs_fvs = ruleSomeLhsFreeVars isId rule
-       -- Find *all* the free Ids of the LHS, not just
+
+    old_rules = idSpecialisation (fromMaybe id (lookupVarSet rule_ids id))
+       -- Get the old rules from rule_ids if the Id is already there, but
+       -- if not, use the Id from the incoming rule.  If may be a PrimOpId,
+       -- in which case it may have rules in its belly already.  Seems
+       -- dreadfully hackoid.
+
+    lhs_fvs = ruleLhsFreeIds rule
+       -- Finds *all* the free Ids of the LHS, not just
        -- locally defined ones!!
 
 pprRuleBase :: RuleBase -> SDoc