[project @ 2000-07-07 09:37:39 by simonmar]
[ghc-hetmet.git] / ghc / compiler / specialise / Rules.lhs
index f1578c2..7a70d51 100644 (file)
@@ -5,31 +5,29 @@
 
 \begin{code}
 module Rules (
-       RuleBase, prepareRuleBase, lookupRule, addRule,
-       addIdSpecialisations,
-       ProtoCoreRule(..), pprProtoCoreRule,
-       orphanRule
+       RuleBase, prepareLocalRuleBase, prepareOrphanRuleBase,
+        unionRuleBase, lookupRule, addRule, addIdSpecialisations,
+       ProtoCoreRule(..), pprProtoCoreRule, pprRuleBase,
+       localRule, orphanRule
     ) where
 
 #include "HsVersions.h"
 
 import CoreSyn         -- All of it
-import Const           ( Con(..), Literal(..) )
-import OccurAnal       ( occurAnalyseExpr, tagBinders, UsageDetails )
+import OccurAnal       ( occurAnalyseRule )
 import BinderInfo      ( markMany )
-import CoreFVs         ( exprFreeVars, idRuleVars, ruleSomeLhsFreeVars )
+import CoreFVs         ( exprFreeVars, idRuleVars, ruleRhsFreeVars, ruleSomeLhsFreeVars )
 import CoreUnfold      ( isCheapUnfolding, unfoldingTemplate )
-import CoreUtils       ( eqExpr, cheapEqExpr )
+import CoreUtils       ( eqExpr )
 import PprCore         ( pprCoreRule )
 import Subst           ( Subst, InScopeSet, substBndr, lookupSubst, extendSubst,
                          mkSubst, substEnv, setSubstEnv, emptySubst, isInScope,
                          unBindSubst, bindSubstList, unBindSubstList, substInScope
                        )
-import Id              ( Id, getIdUnfolding, zapLamIdInfo, 
-                         getIdSpecialisation, setIdSpecialisation,
+import Id              ( Id, idUnfolding, zapLamIdInfo, 
+                         idSpecialisation, setIdSpecialisation,
                          setIdNoDiscard, maybeModifyIdInfo, modifyIdInfo
                        ) 
-import IdInfo          ( setSpecInfo, specInfo )
 import Name            ( Name, isLocallyDefined )
 import Var             ( isTyVar, isId )
 import VarSet
@@ -220,7 +218,7 @@ zapOccInfo bndr | isTyVar bndr = bndr
 \end{code}
 
 \begin{code}
-type Matcher result =  IdOrTyVarSet            -- Template variables
+type Matcher result =  VarSet                  -- Template variables
                    -> (Subst -> Maybe result)  -- Continuation if success
                    -> Subst  -> Maybe result   -- Substitution so far -> result
 -- The *SubstEnv* in these Substs apply to the TEMPLATE only 
@@ -253,9 +251,9 @@ match (Var v1) e2 tpl_vars kont subst
 
        other -> match_fail
 
-match (Con c1 es1) (Con c2 es2) tpl_vars kont subst
-  | c1 == c2
-  = matches es1 es2 tpl_vars kont subst
+match (Lit lit1) (Lit lit2) tpl_vars kont subst
+  | lit1 == lit2
+  = kont subst
 
 match (App f1 a1) (App f2 a2) tpl_vars kont subst
   = match f1 f2 tpl_vars (match a1 a2 tpl_vars kont) subst
@@ -325,7 +323,7 @@ match e1 (Var v2) tpl_vars kont subst
   | isCheapUnfolding unfolding
   = match e1 (unfoldingTemplate unfolding) tpl_vars kont subst
   where
-    unfolding = getIdUnfolding v2
+    unfolding = idUnfolding v2
 
 
 -- We can't cope with lets in the template
@@ -408,38 +406,36 @@ addRule id (Rules rules rhs_fvs) rule@(BuiltinRule _)
   = Rules (rule:rules) rhs_fvs
        -- Put it at the start for lack of anything better
 
-addRule id (Rules rules rhs_fvs) (Rule str tpl_vars tpl_args rhs)
-  = Rules (insert rules) (rhs_fvs `unionVarSet` new_rhs_fvs)
+addRule id (Rules rules rhs_fvs) rule
+  = Rules (insertRule rules new_rule) (rhs_fvs `unionVarSet` new_rhs_fvs)
   where
-    new_rule = Rule str tpl_vars' tpl_args rhs'
-               -- Add occ info to tpl_vars, rhs
-
-    (rhs_uds, rhs')      = occurAnalyseExpr isLocallyDefined rhs
-    (rhs_uds1, tpl_vars') = tagBinders rhs_uds tpl_vars
-
-    insert []                                      = [new_rule]
-    insert (rule:rules) | new_is_more_specific rule = (new_rule:rule:rules)
-                       | otherwise                 = rule : insert rules
-
-    new_is_more_specific rule = maybeToBool (matchRule tpl_var_set rule tpl_args)
-
-    tpl_var_set = mkVarSet tpl_vars'
-       -- Actually we should probably include the free vars of tpl_args,
-       -- but I can't be bothered
-
-    new_rhs_fvs = (exprFreeVars rhs' `minusVarSet` tpl_var_set) `delVarSet` id
+    new_rule    = occurAnalyseRule rule
+    new_rhs_fvs = ruleRhsFreeVars new_rule `delVarSet` id
        -- Hack alert!
        -- Don't include the Id in its own rhs free-var set.
        -- Otherwise the occurrence analyser makes bindings recursive
        -- that shoudn't be.  E.g.
        --      RULE:  f (f x y) z  ==>  f x (f y z)
 
+insertRule rules new_rule@(Rule _ tpl_vars tpl_args _)
+  = go rules
+  where
+    tpl_var_set = mkVarSet tpl_vars
+       -- Actually we should probably include the free vars of tpl_args,
+       -- but I can't be bothered
+
+    go []                                      = [new_rule]
+    go (rule:rules) | new_is_more_specific rule = (new_rule:rule:rules)
+                   | otherwise                 = rule : go rules
+
+    new_is_more_specific rule = maybeToBool (matchRule tpl_var_set rule tpl_args)
+
 addIdSpecialisations :: Id -> [([CoreBndr], [CoreExpr], CoreExpr)] -> Id
 addIdSpecialisations id spec_stuff
   = setIdSpecialisation id new_rules
   where
     rule_name = _PK_ ("SPEC " ++ showSDoc (ppr id))
-    new_rules = foldr add (getIdSpecialisation id) spec_stuff
+    new_rules = foldr add (idSpecialisation id) spec_stuff
     add (vars, args, rhs) rules = addRule id rules (Rule rule_name vars args rhs)
 \end{code}
 
@@ -458,16 +454,19 @@ data ProtoCoreRule
        CoreRule        -- The rule itself
        
 
-pprProtoCoreRule (ProtoCoreRule _ fn rule) = pprCoreRule (Just fn) rule
+pprProtoCoreRule (ProtoCoreRule _ fn rule) = pprCoreRule (ppr fn) rule
 
 lookupRule :: InScopeSet -> Id -> [CoreExpr] -> Maybe (RuleName, CoreExpr)
 lookupRule in_scope fn args
-  = case getIdSpecialisation fn of
+  = case idSpecialisation fn of
        Rules rules _ -> matchRules in_scope rules args
 
+localRule :: ProtoCoreRule -> Bool
+localRule (ProtoCoreRule local _ _) = local
+
 orphanRule :: ProtoCoreRule -> Bool
 -- An "orphan rule" is one that is defined in this 
--- module, but of ran *imported* function.  We need
+-- module, but for an *imported* function.  We need
 -- to track these separately when generating the interface file
 orphanRule (ProtoCoreRule local fn _)
   = local && not (isLocallyDefined fn)
@@ -485,17 +484,37 @@ type RuleBase = (IdSet,           -- Imported Ids that have rules attached
                 IdSet)         -- Ids (whether local or imported) mentioned on 
                                -- LHS of some rule; these should be black listed
 
+unionRuleBase (rule_ids1, black_ids1) (rule_ids2, black_ids2)
+  = (plusUFM_C merge_rules rule_ids1 rule_ids2,
+     unionVarSet black_ids1 black_ids2)
+  where
+    merge_rules id1 id2 = let rules1 = idSpecialisation id1
+                              rules2 = idSpecialisation id2
+                              new_rules = foldl (addRule id1) rules1 (rulesRules rules2)
+                          in
+                          setIdSpecialisation id1 new_rules
+
+pprRuleBase :: RuleBase -> SDoc
+pprRuleBase (rules,_) = vcat [ pprCoreRule (ppr id) rs
+                             | id <- varSetElems rules,
+                               rs <- rulesRules $ idSpecialisation id ]
+
+-- prepareLocalRuleBase takes the CoreBinds and rules defined in this module.
+-- It attaches those rules that are for local Ids to their binders, and
+-- returns the remainder attached to Ids in an IdSet.  It also returns
+-- Ids mentioned on LHS of some rule; these should be blacklisted.
+
 -- The rule Ids and LHS Ids are black-listed; that is, they aren't inlined
 -- so that the opportunity to apply the rule isn't lost too soon
 
-prepareRuleBase :: [CoreBind] -> [ProtoCoreRule] -> ([CoreBind], RuleBase)
-prepareRuleBase binds all_rules
-  = (map zap_bind binds, (imported_rule_ids, rule_lhs_fvs))
+prepareLocalRuleBase :: [CoreBind] -> [ProtoCoreRule] -> ([CoreBind], RuleBase)
+prepareLocalRuleBase binds local_rules
+  = (map zap_bind binds, (imported_id_rule_ids, rule_lhs_fvs))
   where
-    (rule_ids, rule_lhs_fvs) = foldr add_rule (emptyVarSet, emptyVarSet) all_rules
-    imported_rule_ids = filterVarSet (not . isLocallyDefined) rule_ids
+    (rule_ids, rule_lhs_fvs) = foldr add_rule (emptyVarSet, emptyVarSet) local_rules
+    imported_id_rule_ids = filterVarSet (not . isLocallyDefined) rule_ids
 
-       -- rule_fvs is the set of all variables mentioned in rules
+       -- rule_fvs is the set of all variables mentioned in this module's rules
     rule_fvs = foldVarSet (unionVarSet . idRuleVars) rule_lhs_fvs rule_ids
 
        -- Attach the rules for each locally-defined Id to that Id.
@@ -533,5 +552,12 @@ add_rule (ProtoCoreRule _ id rule)
        -- Find *all* the free Ids of the LHS, not just
        -- locally defined ones!!
 
-addRuleToId id rule = setIdSpecialisation id (addRule id (getIdSpecialisation id) rule)
+addRuleToId id rule = setIdSpecialisation id (addRule id (idSpecialisation id) rule)
+
+-- prepareOrphanRuleBase does exactly the same as prepareLocalRuleBase, except that
+-- it assumes that none of the rules can be attached to local Ids.
+
+prepareOrphanRuleBase :: [ProtoCoreRule] -> RuleBase
+prepareOrphanRuleBase imported_rules
+  = foldr add_rule (emptyVarSet, emptyVarSet) imported_rules
 \end{code}