[project @ 2000-10-24 17:09:44 by simonpj]
[ghc-hetmet.git] / ghc / compiler / specialise / Rules.lhs
index 6e7c6c2..ab1436b 100644 (file)
@@ -5,9 +5,10 @@
 
 \begin{code}
 module Rules (
-       RuleBase, prepareLocalRuleBase, prepareOrphanRuleBase,
+       RuleBase, emptyRuleBase, extendRuleBase, extendRuleBaseList,
+       prepareLocalRuleBase, prepareOrphanRuleBase,
         unionRuleBase, lookupRule, addRule, addIdSpecialisations,
-       ProtoCoreRule(..), pprProtoCoreRule,
+       ProtoCoreRule(..), pprProtoCoreRule, pprRuleBase,
        localRule, orphanRule
     ) where
 
@@ -15,32 +16,28 @@ module Rules (
 
 import CoreSyn         -- All of it
 import OccurAnal       ( occurAnalyseRule )
-import BinderInfo      ( markMany )
 import CoreFVs         ( exprFreeVars, idRuleVars, ruleRhsFreeVars, ruleSomeLhsFreeVars )
 import CoreUnfold      ( isCheapUnfolding, unfoldingTemplate )
 import CoreUtils       ( eqExpr )
 import PprCore         ( pprCoreRule )
-import Subst           ( Subst, InScopeSet, substBndr, lookupSubst, extendSubst,
-                         mkSubst, substEnv, setSubstEnv, emptySubst, isInScope,
-                         unBindSubst, bindSubstList, unBindSubstList, substInScope
+import Subst           ( Subst, InScopeSet, mkInScopeSet, lookupSubst, extendSubst,
+                         substEnv, setSubstEnv, emptySubst, isInScope,
+                         bindSubstList, unBindSubstList, substInScope, uniqAway
                        )
 import Id              ( Id, idUnfolding, zapLamIdInfo, 
                          idSpecialisation, setIdSpecialisation,
-                         setIdNoDiscard, maybeModifyIdInfo, modifyIdInfo
+                         setIdNoDiscard
                        ) 
-import Name            ( Name, isLocallyDefined )
+import Name            ( isLocallyDefined )
 import Var             ( isTyVar, isId )
 import VarSet
 import VarEnv
-import Type            ( mkTyVarTy, getTyVar_maybe )
+import Type            ( mkTyVarTy )
 import qualified Unify ( match )
-import CmdLineOpts     ( opt_D_dump_simpl, opt_D_verbose_core2core )
 
 import UniqFM
-import ErrUtils                ( dumpIfSet )
 import Outputable
 import Maybes          ( maybeToBool )
-import List            ( partition )
 import Util            ( sortLt )
 \end{code}
 
@@ -420,7 +417,7 @@ addRule id (Rules rules rhs_fvs) rule
 insertRule rules new_rule@(Rule _ tpl_vars tpl_args _)
   = go rules
   where
-    tpl_var_set = mkVarSet tpl_vars
+    tpl_var_set = mkInScopeSet (mkVarSet tpl_vars)
        -- Actually we should probably include the free vars of tpl_args,
        -- but I can't be bothered
 
@@ -480,9 +477,26 @@ orphanRule (ProtoCoreRule local fn _)
 %************************************************************************
 
 \begin{code}
-type RuleBase = (IdSet,                -- Imported Ids that have rules attached
-                IdSet)         -- Ids (whether local or imported) mentioned on 
-                               -- LHS of some rule; these should be black listed
+data RuleBase = RuleBase (IdEnv CoreRules)     -- Maps an Id to its rules
+                        IdSet                  -- Ids (whether local or imported) mentioned on 
+                                               -- LHS of some rule; these should be black listed
+
+emptyRuleBase = RuleBase emptyVarEnv emptyVarSet
+
+extendRuleBaseList :: RuleBase -> [(Name,CoreRule)] -> RuleBase
+extendRuleBaseList rule_base new_guys
+  = foldr extendRuleBase rule_base new_guys
+
+extendRuleBase :: RuleBase -> (Name,CoreRule) -> RuleBase
+extendRuleBase (RuleBase rule_env rule_fvs) (id, rule)
+  = RuleBase (extendVarEnv rule_env id (addRule id rules_for_id rule))
+            (rule_fvs `unionVarSet` extendVarSet lhs_fvs id)
+  where
+    rules_for_id = case lookupWithDefaultVarEnv rule_env emptyCoreRules id
+
+    lhs_fvs = ruleSomeLhsFreeVars isId rule
+       -- Find *all* the free Ids of the LHS, not just
+       -- locally defined ones!!
 
 unionRuleBase (rule_ids1, black_ids1) (rule_ids2, black_ids2)
   = (plusUFM_C merge_rules rule_ids1 rule_ids2,
@@ -494,6 +508,11 @@ unionRuleBase (rule_ids1, black_ids1) (rule_ids2, black_ids2)
                           in
                           setIdSpecialisation id1 new_rules
 
+pprRuleBase :: RuleBase -> SDoc
+pprRuleBase (rules,_) = vcat [ pprCoreRule (ppr id) rs
+                             | id <- varSetElems rules,
+                               rs <- rulesRules $ idSpecialisation id ]
+
 -- prepareLocalRuleBase takes the CoreBinds and rules defined in this module.
 -- It attaches those rules that are for local Ids to their binders, and
 -- returns the remainder attached to Ids in an IdSet.  It also returns
@@ -506,7 +525,7 @@ prepareLocalRuleBase :: [CoreBind] -> [ProtoCoreRule] -> ([CoreBind], RuleBase)
 prepareLocalRuleBase binds local_rules
   = (map zap_bind binds, (imported_id_rule_ids, rule_lhs_fvs))
   where
-    (rule_ids, rule_lhs_fvs) = foldr add_rule (emptyVarSet, emptyVarSet) local_rules
+    (rule_ids, rule_lhs_fvs) = foldr add_rule emptyRuleBase local_rules
     imported_id_rule_ids = filterVarSet (not . isLocallyDefined) rule_ids
 
        -- rule_fvs is the set of all variables mentioned in this module's rules
@@ -534,18 +553,6 @@ prepareLocalRuleBase binds local_rules
                          Just bndr'                           -> setIdNoDiscard bndr'
                          Nothing | bndr `elemVarSet` rule_fvs -> setIdNoDiscard bndr
                                  | otherwise                  -> bndr
-                 
-add_rule (ProtoCoreRule _ id rule)
-        (rule_id_set, rule_fvs)
-  = (rule_id_set `extendVarSet` new_id,
-     rule_fvs `unionVarSet` extendVarSet lhs_fvs id)
-  where
-    new_id = case lookupVarSet rule_id_set id of
-               Just id' -> addRuleToId id' rule
-               Nothing  -> addRuleToId id  rule
-    lhs_fvs = ruleSomeLhsFreeVars isId rule
-       -- Find *all* the free Ids of the LHS, not just
-       -- locally defined ones!!
 
 addRuleToId id rule = setIdSpecialisation id (addRule id (idSpecialisation id) rule)