[project @ 2004-11-25 11:36:34 by simonpj]
[ghc-hetmet.git] / ghc / compiler / specialise / Rules.lhs
index 8a489a0..f627d46 100644 (file)
@@ -6,9 +6,8 @@
 \begin{code}
 module Rules (
        RuleBase, emptyRuleBase, 
-       extendRuleBase, extendRuleBaseList, addRuleBaseFVs, 
-       ruleBaseIds, ruleBaseFVs,
-       pprRuleBase, ruleCheckProgram,
+       extendRuleBaseList, 
+       ruleBaseIds, pprRuleBase, ruleCheckProgram,
 
         lookupRule, addRule, addIdSpecialisations
     ) where
@@ -17,25 +16,27 @@ module Rules (
 
 import CoreSyn         -- All of it
 import OccurAnal       ( occurAnalyseRule )
-import CoreFVs         ( exprFreeVars, ruleRhsFreeVars, ruleLhsFreeIds )
+import CoreFVs         ( exprFreeVars, ruleRhsFreeVars )
 import CoreUnfold      ( isCheapUnfolding, unfoldingTemplate )
 import CoreUtils       ( eqExpr )
 import CoreTidy                ( pprTidyIdRules )
-import Subst           ( Subst, InScopeSet, mkInScopeSet, lookupSubst, extendSubst,
-                         substEnv, setSubstEnv, emptySubst, isInScope, emptyInScopeSet,
-                         bindSubstList, unBindSubstList, substInScope, uniqAway
+import Subst           ( Subst, SubstResult(..), extendIdSubst,
+                         getTvSubstEnv, setTvSubstEnv,
+                         emptySubst, isInScope, lookupIdSubst, lookupTvSubst,
+                         bindSubstList, unBindSubstList, substInScope
                        )
 import Id              ( Id, idUnfolding, idSpecialisation, setIdSpecialisation ) 
-import Var             ( isId )
+import Var             ( Var, isId )
 import VarSet
 import VarEnv
 import TcType          ( mkTyVarTy )
-import qualified TcType ( match )
+import qualified Unify  ( matchTyX )
 import BasicTypes      ( Activation, CompilerPhase, isActive )
 
 import Outputable
+import FastString
 import Maybe           ( isJust, isNothing, fromMaybe )
-import Util            ( sortLt )
+import Util            ( sortLe )
 import Bag
 import List            ( isPrefixOf )
 \end{code}
@@ -171,13 +172,19 @@ matchRule is_active in_scope rule@(Rule rn act tpl_vars tpl_args rhs) args
 
    -----------------------
    app_match subst fn vs = foldl go fn vs
-       where   
-         senv    = substEnv subst
-         go fn v = case lookupSubstEnv senv v of
-                       Just (DoneEx ex)  -> fn `App` ex 
-                       Just (DoneTy ty)  -> fn `App` Type ty
-                       -- Substitution should bind them all!
-
+     where     
+       go fn v = case lookupVar subst v of
+                   Just e  -> fn `App` e 
+                   Nothing -> pprPanic "app_match: unbound tpl" (ppr v)
+
+lookupVar :: Subst -> Var -> Maybe CoreExpr
+lookupVar subst v
+   | isId v    = case lookupIdSubst subst v of
+                  Just (DoneEx ex) -> Just ex
+                  other            -> Nothing
+   | otherwise = case lookupTvSubst subst v of
+                  Just ty -> Just (Type ty)
+                  Nothing -> Nothing
 
    -----------------------
 {-     The code below tries to match even if there are more 
@@ -229,10 +236,13 @@ type Matcher result =  VarSet                     -- Template variables
                    -> Subst  -> Maybe result   -- Substitution so far -> result
 -- The *SubstEnv* in these Substs apply to the TEMPLATE only 
 
--- The *InScopeSet* in these Substs gives variables bound so far in the
+-- The *InScopeSet* in these Substs is HIJACKED,
+--     to give the set of variables bound so far in the
 --     target term.  So when matching forall a. (\x. a x) against (\y. y y)
 --     while processing the body of the lambdas, the in-scope set will be {y}.
 --     That lets us do the occurs-check when matching 'a' against 'y'
+--
+--     It starts off empty
 
 match :: CoreExpr              -- Template
       -> CoreExpr              -- Target
@@ -240,14 +250,18 @@ match :: CoreExpr         -- Template
 
 match_fail = Nothing
 
-match (Var v1) e2 tpl_vars kont subst
-  = case lookupSubst subst v1 of
+-- ToDo: remove this debugging junk
+-- match e1 e2 tpls kont subst = pprTrace "match" (ppr e1 <+> ppr e2 <+> ppr subst) $ match_ e1 e2 tpls kont subst
+match = match_
+
+match_ (Var v1) e2 tpl_vars kont subst
+  = case lookupIdSubst subst v1 of
        Nothing | v1 `elemVarSet` tpl_vars      -- v1 is a template variable
                -> if (any (`isInScope` subst) (varSetElems (exprFreeVars e2))) then
                         match_fail             -- Occurs check failure
                                                -- e.g. match forall a. (\x-> a x) against (\y. y y)
                   else
-                        kont (extendSubst subst v1 (DoneEx e2))
+                        kont (extendIdSubst subst v1 (DoneEx e2))
 
 
                | eqExpr (Var v1) e2       -> kont subst
@@ -257,27 +271,32 @@ match (Var v1) e2 tpl_vars kont subst
 
        other -> match_fail
 
-match (Lit lit1) (Lit lit2) tpl_vars kont subst
+match_ (Lit lit1) (Lit lit2) tpl_vars kont subst
   | lit1 == lit2
   = kont subst
 
-match (App f1 a1) (App f2 a2) tpl_vars kont subst
+match_ (App f1 a1) (App f2 a2) tpl_vars kont subst
   = match f1 f2 tpl_vars (match a1 a2 tpl_vars kont) subst
 
-match (Lam x1 e1) (Lam x2 e2) tpl_vars kont subst
+match_ (Lam x1 e1) (Lam x2 e2) tpl_vars kont subst
   = bind [x1] [x2] (match e1 e2) tpl_vars kont subst
 
 -- This rule does eta expansion
 --             (\x.M)  ~  N    iff     M  ~  N x
 -- See assumption A3
-match (Lam x1 e1) e2 tpl_vars kont subst
+match_ (Lam x1 e1) e2 tpl_vars kont subst
   = bind [x1] [x1] (match e1 (App e2 (mkVarArg x1))) tpl_vars kont subst
 
 -- Eta expansion the other way
 --     M  ~  (\y.N)    iff   \y.M y  ~  \y.N
 --                     iff   M y     ~  N
 -- Remembering that by (A), y can't be free in M, we get this
-match e1 (Lam x2 e2) tpl_vars kont subst
+match_ e1 (Lam x2 e2) tpl_vars kont subst
+  | new_id == x2       -- If the two are equal, don't bind, else we get
+                       -- a substitution looking like x->x, and that sends
+                       -- Unify.matchTy into a loop
+  = match (App e1 (mkVarArg new_id)) e2 tpl_vars kont subst
+  | otherwise
   = bind [new_id] [x2] (match (App e1 (mkVarArg new_id)) e2) tpl_vars kont subst
   where
     new_id = uniqAway (substInScope subst) x2
@@ -289,16 +308,18 @@ match e1 (Lam x2 e2) tpl_vars kont subst
        -- The first \x is ok, but when we inline k, hoping it might
        -- match (:) we find a second \x.
 
-match (Case e1 x1 alts1) (Case e2 x2 alts2) tpl_vars kont subst
-  = match e1 e2 tpl_vars case_kont subst
+-- gaw 2004
+match_ (Case e1 x1 ty1 alts1) (Case e2 x2 ty2 alts2) tpl_vars kont subst
+  = (match_ty ty1 ty2 tpl_vars $
+     match e1 e2 tpl_vars case_kont) subst
   where
-    case_kont subst = bind [x1] [x2] (match_alts alts1 (sortLt lt_alt alts2))
+    case_kont subst = bind [x1] [x2] (match_alts alts1 (sortLe le_alt alts2))
                                     tpl_vars kont subst
 
-match (Type ty1) (Type ty2) tpl_vars kont subst
+match_ (Type ty1) (Type ty2) tpl_vars kont subst
   = match_ty ty1 ty2 tpl_vars kont subst
 
-match (Note (Coerce to1 from1) e1) (Note (Coerce to2 from2) e2)
+match_ (Note (Coerce to1 from1) e1) (Note (Coerce to2 from2) e2)
       tpl_vars kont subst
   = (match_ty to1   to2   tpl_vars $
      match_ty from1 from2 tpl_vars $
@@ -325,7 +346,7 @@ match e1 (Let bind e2) tpl_vars kont subst
 -- variable, we expand it so long as its unfolding is a WHNF
 -- (Its occurrence information is not necessarily up to date,
 --  so we don't use it.)
-match e1 (Var v2) tpl_vars kont subst
+match_ e1 (Var v2) tpl_vars kont subst
   | isCheapUnfolding unfolding
   = match e1 (unfoldingTemplate unfolding) tpl_vars kont subst
   where
@@ -334,7 +355,7 @@ match e1 (Var v2) tpl_vars kont subst
 
 -- We can't cope with lets in the template
 
-match e1 e2 tpl_vars kont subst = match_fail
+match_ e1 e2 tpl_vars kont subst = match_fail
 
 
 ------------------------------------------
@@ -347,7 +368,7 @@ match_alts ((c1,vs1,r1):alts1) ((c2,vs2,r2):alts2) tpl_vars kont subst
                 subst
 match_alts alts1 alts2 tpl_vars kont subst = match_fail
 
-lt_alt (con1, _, _) (con2, _, _) = con1 < con2
+le_alt (con1, _, _) (con2, _, _) = con1 <= con2
 
 ----------------------------------------
 bind :: [CoreBndr]     -- Template binders
@@ -368,18 +389,10 @@ bind vs1 vs2 matcher tpl_vars kont subst
     subst'        = bindSubstList subst vs1 vs2
 
        -- The unBindSubst relies on no shadowing in the template
-    not_in_subst v = isNothing (lookupSubst subst v)
+    not_in_subst v = isNothing (lookupVar subst v)
     bug_msg = sep [ppr vs1, ppr vs2]
 
 ----------------------------------------
-matches [] [] tpl_vars kont subst 
-  = kont subst
-matches (e:es) (e':es') tpl_vars kont subst
-  = match e e' tpl_vars (matches es es' tpl_vars kont) subst
-matches es es' tpl_vars kont subst 
-  = match_fail
-
-----------------------------------------
 mkVarArg :: CoreBndr -> CoreArg
 mkVarArg v | isId v    = Var v
           | otherwise = Type (mkTyVarTy v)
@@ -394,9 +407,9 @@ We only want to replace (f T) with f', not (f Int).
 \begin{code}
 ----------------------------------------
 match_ty ty1 ty2 tpl_vars kont subst
-  = TcType.match ty1 ty2 tpl_vars kont' (substEnv subst)
-  where
-    kont' senv = kont (setSubstEnv subst senv) 
+  = case Unify.matchTyX tpl_vars (getTvSubstEnv subst) ty1 ty2 of
+       Just tv_env' -> kont (setTvSubstEnv subst tv_env')
+       Nothing      -> match_fail
 \end{code}
 
 
@@ -522,8 +535,9 @@ ruleCheck env (App f a)     = ruleCheckApp env (App f a) []
 ruleCheck env (Note n e)    = ruleCheck env e
 ruleCheck env (Let bd e)    = ruleCheckBind env bd `unionBags` ruleCheck env e
 ruleCheck env (Lam b e)     = ruleCheck env e
-ruleCheck env (Case e _ as) = ruleCheck env e `unionBags` 
-                             unionManyBags [ruleCheck env r | (_,_,r) <- as]
+-- gaw 2004
+ruleCheck env (Case e _ _ as) = ruleCheck env e `unionBags` 
+                               unionManyBags [ruleCheck env r | (_,_,r) <- as]
 
 ruleCheckApp env (App f a) as = ruleCheck env a `unionBags` ruleCheckApp env f (a:as)
 ruleCheckApp env (Var f) as   = ruleCheckFun env f as
@@ -541,7 +555,7 @@ ruleCheckFun (phase, pat) fn args
   where
     name_match_rules = case idSpecialisation fn of
                          Rules rules _ -> filter match rules
-    match rule = pat `isPrefixOf` _UNPK_ (ruleName rule)
+    match rule = pat `isPrefixOf` unpackFS (ruleName rule)
 
 ruleAppCheck_help :: CompilerPhase -> Id -> [CoreExpr] -> [CoreRule] -> SDoc
 ruleAppCheck_help phase fn args rules
@@ -554,8 +568,10 @@ ruleAppCheck_help phase fn args rules
 
     check_rule rule = rule_herald rule <> colon <+> rule_info rule
 
-    rule_herald (BuiltinRule name _) = text "Builtin rule" <+> doubleQuotes (ptext name)
-    rule_herald (Rule name _ _ _ _)  = text "Rule" <+> doubleQuotes (ptext name)
+    rule_herald (BuiltinRule name _) = 
+       ptext SLIT("Builtin rule") <+> doubleQuotes (ftext name)
+    rule_herald (Rule name _ _ _ _)  = 
+       ptext SLIT("Rule") <+> doubleQuotes (ftext name)
 
     rule_info rule
        | Just (name,_) <- matchRule noBlackList emptyInScopeSet rule args
@@ -591,43 +607,27 @@ data RuleBase = RuleBase
                    IdSet       -- Ids with their rules in their specialisations
                                -- Held as a set, so that it can simply be the initial
                                -- in-scope set in the simplifier
-
-                   IdSet       -- Ids (whether local or imported) mentioned on 
-                               -- LHS of some rule; these should be black listed
-
        -- This representation is a bit cute, and I wonder if we should
        -- change it to use (IdEnv CoreRule) which seems a bit more natural
 
-ruleBaseIds (RuleBase ids _) = ids
-ruleBaseFVs (RuleBase _ fvs) = fvs
-
-emptyRuleBase = RuleBase emptyVarSet emptyVarSet
+ruleBaseIds (RuleBase ids) = ids
+emptyRuleBase = RuleBase emptyVarSet
 
-addRuleBaseFVs :: RuleBase -> IdSet -> RuleBase
-addRuleBaseFVs (RuleBase rules fvs) extra_fvs 
-  = RuleBase rules (fvs `unionVarSet` extra_fvs)
-
-extendRuleBaseList :: RuleBase -> [(Id,CoreRule)] -> RuleBase
+extendRuleBaseList :: RuleBase -> [IdCoreRule] -> RuleBase
 extendRuleBaseList rule_base new_guys
   = foldl extendRuleBase rule_base new_guys
 
-extendRuleBase :: RuleBase -> (Id,CoreRule) -> RuleBase
-extendRuleBase (RuleBase rule_ids rule_fvs) (id, rule)
+extendRuleBase :: RuleBase -> IdCoreRule -> RuleBase
+extendRuleBase (RuleBase rule_ids) (IdCoreRule id _ rule)
   = RuleBase (extendVarSet rule_ids new_id)
-            (rule_fvs `unionVarSet` extendVarSet lhs_fvs id)
   where
-    new_id = setIdSpecialisation id (addRule id old_rules rule)
-
+    new_id    = setIdSpecialisation id (addRule id old_rules rule)
     old_rules = idSpecialisation (fromMaybe id (lookupVarSet rule_ids id))
        -- Get the old rules from rule_ids if the Id is already there, but
        -- if not, use the Id from the incoming rule.  If may be a PrimOpId,
        -- in which case it may have rules in its belly already.  Seems
        -- dreadfully hackoid.
 
-    lhs_fvs = ruleLhsFreeIds rule
-       -- Finds *all* the free Ids of the LHS, not just
-       -- locally defined ones!!
-
 pprRuleBase :: RuleBase -> SDoc
-pprRuleBase (RuleBase rules _) = vcat [ pprTidyIdRules id | id <- varSetElems rules ]
+pprRuleBase (RuleBase rules) = vcat [ pprTidyIdRules id | id <- varSetElems rules ]
 \end{code}