(F)SLIT -> (f)sLit in Specialse
[ghc-hetmet.git] / compiler / specialise / Rules.lhs
index 03cc6c1..000df94 100644 (file)
@@ -4,6 +4,13 @@
 \section[CoreRules]{Transformation rules}
 
 \begin{code}
 \section[CoreRules]{Transformation rules}
 
 \begin{code}
+{-# OPTIONS -w #-}
+-- The above warning supression flag is a temporary kludge.
+-- While working on this module you are encouraged to remove it and fix
+-- any warnings in the module. See
+--     http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings
+-- for details
+
 module Rules (
        RuleBase, emptyRuleBase, mkRuleBase, extendRuleBaseList, 
        unionRuleBase, pprRuleBase, ruleCheckProgram,
 module Rules (
        RuleBase, emptyRuleBase, mkRuleBase, extendRuleBaseList, 
        unionRuleBase, pprRuleBase, ruleCheckProgram,
@@ -19,18 +26,16 @@ module Rules (
 #include "HsVersions.h"
 
 import CoreSyn         -- All of it
 #include "HsVersions.h"
 
 import CoreSyn         -- All of it
-import CoreSubst       ( substExpr, mkSubst )
 import OccurAnal       ( occurAnalyseExpr )
 import OccurAnal       ( occurAnalyseExpr )
-import CoreFVs         ( exprFreeVars, exprsFreeVars, bindFreeVars, rulesRhsFreeVars )
+import CoreFVs         ( exprFreeVars, exprsFreeVars, bindFreeVars, rulesFreeVars )
 import CoreUnfold      ( isCheapUnfolding, unfoldingTemplate )
 import CoreUnfold      ( isCheapUnfolding, unfoldingTemplate )
-import CoreUtils       ( tcEqExprX )
+import CoreUtils       ( tcEqExprX, exprType )
 import PprCore         ( pprRules )
 import PprCore         ( pprRules )
-import Type            ( TvSubstEnv )
+import Type            ( Type, TvSubstEnv )
 import Coercion         ( coercionKind )
 import TcType          ( tcSplitTyConApp_maybe )
 import CoreTidy                ( tidyRules )
 import Coercion         ( coercionKind )
 import TcType          ( tcSplitTyConApp_maybe )
 import CoreTidy                ( tidyRules )
-import Id              ( Id, idUnfolding, isLocalId, isGlobalId, idName,
-                         idSpecialisation, idCoreRules, setIdSpecialisation ) 
+import Id
 import IdInfo          ( SpecInfo( SpecInfo ) )
 import Var             ( Var )
 import VarEnv
 import IdInfo          ( SpecInfo( SpecInfo ) )
 import Var             ( Var )
 import VarEnv
@@ -39,13 +44,14 @@ import Name         ( Name, NamedThing(..) )
 import NameEnv
 import Unify           ( ruleMatchTyX, MatchEnv(..) )
 import BasicTypes      ( Activation, CompilerPhase, isActive )
 import NameEnv
 import Unify           ( ruleMatchTyX, MatchEnv(..) )
 import BasicTypes      ( Activation, CompilerPhase, isActive )
+import StaticFlags     ( opt_PprStyle_Debug )
 import Outputable
 import FastString
 import Maybes
 import OrdList
 import Bag
 import Util
 import Outputable
 import FastString
 import Maybes
 import OrdList
 import Bag
 import Util
-import List hiding( mapAccumL )        -- Also defined in Util
+import Data.List
 \end{code}
 
 
 \end{code}
 
 
@@ -117,11 +123,13 @@ ruleCantMatch :: [Maybe Name] -> [Maybe Name] -> Bool
 -- It's only a one-way match; unlike instance matching we 
 -- don't consider unification
 -- 
 -- It's only a one-way match; unlike instance matching we 
 -- don't consider unification
 -- 
--- Notice that there is no case
---     ruleCantMatch (Just n1 : ts) (Nothing : as) = True
--- Reason: a local variable 'v' in the actuals might 
---        have an unfolding which is a global.
---        This quite often happens with case scrutinees.
+-- Notice that [_$_]
+--     ruleCantMatch [Nothing] [Just n2] = False
+--      Reason: a template variable can be instantiated by a constant
+-- Also:
+--     ruleCantMatch [Just n1] [Nothing] = False
+--      Reason: a local variable 'v' in the actuals might [_$_]
+
 ruleCantMatch (Just n1 : ts) (Just n2 : as) = n1 /= n2 || ruleCantMatch ts as
 ruleCantMatch (t       : ts) (a       : as) = ruleCantMatch ts as
 ruleCantMatch ts            as             = False
 ruleCantMatch (Just n1 : ts) (Just n2 : as) = n1 /= n2 || ruleCantMatch ts as
 ruleCantMatch (t       : ts) (a       : as) = ruleCantMatch ts as
 ruleCantMatch ts            as             = False
@@ -136,11 +144,11 @@ ruleCantMatch ts       as             = False
 
 \begin{code}
 mkSpecInfo :: [CoreRule] -> SpecInfo
 
 \begin{code}
 mkSpecInfo :: [CoreRule] -> SpecInfo
-mkSpecInfo rules = SpecInfo rules (rulesRhsFreeVars rules)
+mkSpecInfo rules = SpecInfo rules (rulesFreeVars rules)
 
 extendSpecInfo :: SpecInfo -> [CoreRule] -> SpecInfo
 extendSpecInfo (SpecInfo rs1 fvs1) rs2
 
 extendSpecInfo :: SpecInfo -> [CoreRule] -> SpecInfo
 extendSpecInfo (SpecInfo rs1 fvs1) rs2
-  = SpecInfo (rs2 ++ rs1) (rulesRhsFreeVars rs2 `unionVarSet` fvs1)
+  = SpecInfo (rs2 ++ rs1) (rulesFreeVars rs2 `unionVarSet` fvs1)
 
 addSpecInfo :: SpecInfo -> SpecInfo -> SpecInfo
 addSpecInfo (SpecInfo rs1 fvs1) (SpecInfo rs2 fvs2) 
 
 addSpecInfo :: SpecInfo -> SpecInfo -> SpecInfo
 addSpecInfo (SpecInfo rs1 fvs1) (SpecInfo rs2 fvs2) 
@@ -196,20 +204,41 @@ pprRuleBase rules = vcat [ pprRules (tidyRules emptyTidyEnv rs)
 %*                                                                     *
 %************************************************************************
 
 %*                                                                     *
 %************************************************************************
 
+Note [Extra args in rule matching]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+If we find a matching rule, we return (Just (rule, rhs)), 
+but the rule firing has only consumed as many of the input args
+as the ruleArity says.  It's up to the caller to keep track
+of any left-over args.  E.g. if you call
+       lookupRule ... f [e1, e2, e3]
+and it returns Just (r, rhs), where r has ruleArity 2
+then the real rewrite is
+       f e1 e2 e3 ==> rhs e3
+
+You might think it'd be cleaner for lookupRule to deal with the
+leftover arguments, by applying 'rhs' to them, but the main call
+in the Simplifier works better as it is.  Reason: the 'args' passed
+to lookupRule are the result of a lazy substitution
+
 \begin{code}
 lookupRule :: (Activation -> Bool) -> InScopeSet
           -> RuleBase  -- Imported rules
           -> Id -> [CoreExpr] -> Maybe (CoreRule, CoreExpr)
 \begin{code}
 lookupRule :: (Activation -> Bool) -> InScopeSet
           -> RuleBase  -- Imported rules
           -> Id -> [CoreExpr] -> Maybe (CoreRule, CoreExpr)
+-- See Note [Extra argsin rule matching]
 lookupRule is_active in_scope rule_base fn args
 lookupRule is_active in_scope rule_base fn args
-  = matchRules is_active in_scope fn args rules
-  where
+  = matchRules is_active in_scope fn args (getRules rule_base fn)
+
+getRules :: RuleBase -> Id -> [CoreRule]
        -- The rules for an Id come from two places:
        --      (a) the ones it is born with (idCoreRules fn)
        --      (b) rules added in subsequent modules (extra_rules)
        -- PrimOps, for example, are born with a bunch of rules under (a)
        -- The rules for an Id come from two places:
        --      (a) the ones it is born with (idCoreRules fn)
        --      (b) rules added in subsequent modules (extra_rules)
        -- PrimOps, for example, are born with a bunch of rules under (a)
-    rules = extra_rules ++ idCoreRules fn
-    extra_rules | isLocalId fn = []
-               | otherwise    = lookupNameEnv rule_base (idName fn) `orElse` []
+getRules rule_base fn
+  | isLocalId fn  = idCoreRules fn
+  | otherwise     = WARN( not (isPrimOpId fn) && notNull (idCoreRules fn), 
+                         ppr fn <+> ppr (idCoreRules fn) )
+                   idCoreRules fn ++ (lookupNameEnv rule_base (idName fn) `orElse` [])
+       -- Only PrimOpIds have rules inside themselves, and perhaps more besides
 
 matchRules :: (Activation -> Bool) -> InScopeSet
           -> Id -> [CoreExpr]
 
 matchRules :: (Activation -> Bool) -> InScopeSet
           -> Id -> [CoreExpr]
@@ -241,15 +270,17 @@ findBest target (rule,ans)   [] = (rule,ans)
 findBest target (rule1,ans1) ((rule2,ans2):prs)
   | rule1 `isMoreSpecific` rule2 = findBest target (rule1,ans1) prs
   | rule2 `isMoreSpecific` rule1 = findBest target (rule2,ans2) prs
 findBest target (rule1,ans1) ((rule2,ans2):prs)
   | rule1 `isMoreSpecific` rule2 = findBest target (rule1,ans1) prs
   | rule2 `isMoreSpecific` rule1 = findBest target (rule2,ans2) prs
-#ifdef DEBUG
-  | otherwise = pprTrace "Rules.findBest: rule overlap (Rule 1 wins)"
-                        (vcat [ptext SLIT("Expression to match:") <+> ppr fn <+> sep (map ppr args),
-                               ptext SLIT("Rule 1:") <+> ppr rule1, 
-                               ptext SLIT("Rule 2:") <+> ppr rule2]) $
+  | debugIsOn = let pp_rule rule
+                       | opt_PprStyle_Debug = ppr rule
+                       | otherwise          = doubleQuotes (ftext (ru_name rule))
+               in pprTrace "Rules.findBest: rule overlap (Rule 1 wins)"
+                        (vcat [if opt_PprStyle_Debug then 
+                                  ptext (sLit "Expression to match:") <+> ppr fn <+> sep (map ppr args)
+                               else empty,
+                               ptext (sLit "Rule 1:") <+> pp_rule rule1, 
+                               ptext (sLit "Rule 2:") <+> pp_rule rule2]) $
                findBest target (rule1,ans1) prs
                findBest target (rule1,ans1) prs
-#else
   | otherwise = findBest target (rule1,ans1) prs
   | otherwise = findBest target (rule1,ans1) prs
-#endif
   where
     (fn,args) = target
 
   where
     (fn,args) = target
 
@@ -506,7 +537,6 @@ match menv subst@(tv_subst, id_subst, binds) e1 (Let bind e2)
   where
     rn_env   = me_env menv
     bndrs    = bindersOf  bind
   where
     rn_env   = me_env menv
     bndrs    = bindersOf  bind
-    rhss     = rhssOfBind bind
     bind_fvs = varSetElems (bindFreeVars bind)
     locally_bound x   = inRnEnvR rn_env x
     freshly_bound x = not (x `rnInScope` rn_env)
     bind_fvs = varSetElems (bindFreeVars bind)
     locally_bound x   = inRnEnvR rn_env x
     freshly_bound x = not (x `rnInScope` rn_env)
@@ -567,11 +597,8 @@ match menv subst (Type ty1) (Type ty2)
   = match_ty menv subst ty1 ty2
 
 match menv subst (Cast e1 co1) (Cast e2 co2)
   = match_ty menv subst ty1 ty2
 
 match menv subst (Cast e1 co1) (Cast e2 co2)
-  | (from1, to1) <- coercionKind co1
-  , (from2, to2) <- coercionKind co2
-  = do { subst1 <- match_ty menv subst  to1   to2
-       ; subst2 <- match_ty menv subst1 from1 from2
-       ; match menv subst2 e1 e2 }
+  = do { subst1 <- match_ty menv subst co1 co2
+       ; match menv subst1 e1 e2 }
 
 {-     REMOVING OLD CODE: I think that the above handling for let is 
                           better than the stuff here, which looks 
 
 {-     REMOVING OLD CODE: I think that the above handling for let is 
                           better than the stuff here, which looks 
@@ -616,8 +643,21 @@ match_var menv subst@(tv_subst, id_subst, binds) v1 e2
                -> Nothing      -- Occurs check failure
                -- e.g. match forall a. (\x-> a x) against (\y. y y)
 
                -> Nothing      -- Occurs check failure
                -- e.g. match forall a. (\x-> a x) against (\y. y y)
 
-               | otherwise     -- No renaming to do on e2
-               -> Just (tv_subst, extendVarEnv id_subst v1' e2, binds)
+               | otherwise     -- No renaming to do on e2, because no free var
+                               -- of e2 is in the rnEnvR of the envt
+               -- Note [Matching variable types]
+               -- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+               -- However, we must match the *types*; e.g.
+               --   forall (c::Char->Int) (x::Char). 
+               --      f (c x) = "RULE FIRED"
+               -- We must only match on args that have the right type
+               -- It's actually quite difficult to come up with an example that shows
+               -- you need type matching, esp since matching is left-to-right, so type
+               -- args get matched first.  But it's possible (e.g. simplrun008) and
+               -- this is the Right Thing to do
+               -> do   { tv_subst' <- Unify.ruleMatchTyX menv tv_subst (idType v1') (exprType e2)
+                                               -- c.f. match_ty below
+                       ; return (tv_subst', extendVarEnv id_subst v1' e2, binds) }
 
        Just e1' | tcEqExprX (nukeRnEnvL rn_env) e1' e2 
                 -> Just subst
 
        Just e1' | tcEqExprX (nukeRnEnvL rn_env) e1' e2 
                 -> Just subst
@@ -667,6 +707,11 @@ We only want to replace (f T) with f', not (f Int).
 
 \begin{code}
 ------------------------------------------
 
 \begin{code}
 ------------------------------------------
+match_ty :: MatchEnv
+        -> SubstEnv
+        -> Type                -- Template
+        -> Type                -- Target
+        -> Maybe SubstEnv
 match_ty menv (tv_subst, id_subst, binds) ty1 ty2
   = do { tv_subst' <- Unify.ruleMatchTyX menv tv_subst ty1 ty2
        ; return (tv_subst', id_subst, binds) }
 match_ty menv (tv_subst, id_subst, binds) ty1 ty2
   = do { tv_subst' <- Unify.ruleMatchTyX menv tv_subst ty1 ty2
        ; return (tv_subst', id_subst, binds) }
@@ -722,16 +767,11 @@ is so important.
 We want to know what sites have rules that could have fired but didn't.
 This pass runs over the tree (without changing it) and reports such.
 
 We want to know what sites have rules that could have fired but didn't.
 This pass runs over the tree (without changing it) and reports such.
 
-NB: we assume that this follows a run of the simplifier, so every Id
-occurrence (including occurrences of imported Ids) is decorated with
-all its (active) rules.  No need to construct a rule base or anything
-like that.
-
 \begin{code}
 \begin{code}
-ruleCheckProgram :: CompilerPhase -> String -> [CoreBind] -> SDoc
+ruleCheckProgram :: CompilerPhase -> String -> RuleBase -> [CoreBind] -> SDoc
 -- Report partial matches for rules beginning 
 -- with the specified string
 -- Report partial matches for rules beginning 
 -- with the specified string
-ruleCheckProgram phase rule_pat binds 
+ruleCheckProgram phase rule_pat rule_base binds 
   | isEmptyBag results
   = text "Rule check results: no rule application sites"
   | otherwise
   | isEmptyBag results
   = text "Rule check results: no rule application sites"
   | otherwise
@@ -740,10 +780,10 @@ ruleCheckProgram phase rule_pat binds
          vcat [ p $$ line | p <- bagToList results ]
         ]
   where
          vcat [ p $$ line | p <- bagToList results ]
         ]
   where
-    results = unionManyBags (map (ruleCheckBind (phase, rule_pat)) binds)
+    results = unionManyBags (map (ruleCheckBind (phase, rule_pat, rule_base)) binds)
     line = text (replicate 20 '-')
          
     line = text (replicate 20 '-')
          
-type RuleCheckEnv = (CompilerPhase, String)    -- Phase and Pattern
+type RuleCheckEnv = (CompilerPhase, String, RuleBase)  -- Phase and Pattern
 
 ruleCheckBind :: RuleCheckEnv -> CoreBind -> Bag SDoc
    -- The Bag returned has one SDoc for each call site found
 
 ruleCheckBind :: RuleCheckEnv -> CoreBind -> Bag SDoc
    -- The Bag returned has one SDoc for each call site found
@@ -772,11 +812,11 @@ ruleCheckFun :: RuleCheckEnv -> Id -> [CoreExpr] -> Bag SDoc
 -- Produce a report for all rules matching the predicate
 -- saying why it doesn't match the specified application
 
 -- Produce a report for all rules matching the predicate
 -- saying why it doesn't match the specified application
 
-ruleCheckFun (phase, pat) fn args
+ruleCheckFun (phase, pat, rule_base) fn args
   | null name_match_rules = emptyBag
   | otherwise            = unitBag (ruleAppCheck_help phase fn args name_match_rules)
   where
   | null name_match_rules = emptyBag
   | otherwise            = unitBag (ruleAppCheck_help phase fn args name_match_rules)
   where
-    name_match_rules = filter match (idCoreRules fn)
+    name_match_rules = filter match (getRules rule_base fn)
     match rule = pat `isPrefixOf` unpackFS (ruleName rule)
 
 ruleAppCheck_help :: CompilerPhase -> Id -> [CoreExpr] -> [CoreRule] -> SDoc
     match rule = pat `isPrefixOf` unpackFS (ruleName rule)
 
 ruleAppCheck_help :: CompilerPhase -> Id -> [CoreExpr] -> [CoreRule] -> SDoc
@@ -792,9 +832,9 @@ ruleAppCheck_help phase fn args rules
     check_rule rule = rule_herald rule <> colon <+> rule_info rule
 
     rule_herald (BuiltinRule { ru_name = name })
     check_rule rule = rule_herald rule <> colon <+> rule_info rule
 
     rule_herald (BuiltinRule { ru_name = name })
-       = ptext SLIT("Builtin rule") <+> doubleQuotes (ftext name)
+       = ptext (sLit "Builtin rule") <+> doubleQuotes (ftext name)
     rule_herald (Rule { ru_name = name })
     rule_herald (Rule { ru_name = name })
-       = ptext SLIT("Rule") <+> doubleQuotes (ftext name)
+       = ptext (sLit "Rule") <+> doubleQuotes (ftext name)
 
     rule_info rule
        | Just _ <- matchRule noBlackList emptyInScopeSet args rough_args rule
 
     rule_info rule
        | Just _ <- matchRule noBlackList emptyInScopeSet args rough_args rule