Make rule-matching robust to lets
authorsimonpj@microsoft.com <unknown>
Thu, 25 May 2006 15:44:47 +0000 (15:44 +0000)
committersimonpj@microsoft.com <unknown>
Thu, 25 May 2006 15:44:47 +0000 (15:44 +0000)
Consider a RULE like
forall arr. splitD (joinD arr) = arr

Until now, this rule would not match code of form
splitD (let { d = ... } in joinD (...d...))
because the 'let' got in the way.

This patch makes the rule-matcher robust to lets.  See comments with
the Let case of Rules.match.

This improvement is highly desirable in the fusion rules for NDP
stuff that Roman is working on, where we are doing fusion of *overloaded*
functions (which may look lazy).  The let expression that Roman tripped
up on was a dictioary binding.

compiler/basicTypes/VarEnv.lhs
compiler/coreSyn/CoreFVs.lhs
compiler/specialise/Rules.lhs

index bfeecdc..da2f960 100644 (file)
@@ -26,7 +26,8 @@ module VarEnv (
 
        -- RnEnv2 and its operations
        RnEnv2, mkRnEnv2, rnBndr2, rnBndrs2, rnOccL, rnOccR, inRnEnvL, inRnEnvR,
 
        -- RnEnv2 and its operations
        RnEnv2, mkRnEnv2, rnBndr2, rnBndrs2, rnOccL, rnOccR, inRnEnvL, inRnEnvR,
-               rnBndrL, rnBndrR, nukeRnEnvL, nukeRnEnvR,
+               rnBndrL, rnBndrR, nukeRnEnvL, nukeRnEnvR, extendRnInScopeList,
+               rnInScope,
 
        -- TidyEnvs
        TidyEnv, emptyTidyEnv
 
        -- TidyEnvs
        TidyEnv, emptyTidyEnv
@@ -40,7 +41,7 @@ import VarSet
 import UniqFM  
 import Unique    ( Unique, deriveUnique, getUnique )
 import Util      ( zipEqual, foldl2 )
 import UniqFM  
 import Unique    ( Unique, deriveUnique, getUnique )
 import Util      ( zipEqual, foldl2 )
-import Maybes    ( orElse, isJust )
+import Maybes    ( orElse )
 import StaticFlags( opt_PprStyle_Debug )
 import Outputable
 import FastTypes
 import StaticFlags( opt_PprStyle_Debug )
 import Outputable
 import FastTypes
@@ -183,6 +184,13 @@ mkRnEnv2 vars = RV2        { envL     = emptyVarEnv
                        , envR     = emptyVarEnv
                        , in_scope = vars }
 
                        , envR     = emptyVarEnv
                        , in_scope = vars }
 
+extendRnInScopeList :: RnEnv2 -> [Var] -> RnEnv2
+extendRnInScopeList env vs
+  = env { in_scope = extendInScopeSetList (in_scope env) vs }
+
+rnInScope :: Var -> RnEnv2 -> Bool
+rnInScope x env = x `elemInScopeSet` in_scope env
+
 rnBndrs2 :: RnEnv2 -> [Var] -> [Var] -> RnEnv2
 -- Arg lists must be of equal length
 rnBndrs2 env bsL bsR = foldl2 rnBndr2 env bsL bsR 
 rnBndrs2 :: RnEnv2 -> [Var] -> [Var] -> RnEnv2
 -- Arg lists must be of equal length
 rnBndrs2 env bsL bsR = foldl2 rnBndr2 env bsL bsR 
@@ -236,8 +244,8 @@ rnOccR (RV2 { envR = env }) v = lookupVarEnv env v `orElse` v
 
 inRnEnvL, inRnEnvR :: RnEnv2 -> Var -> Bool
 -- Tells whether a variable is locally bound
 
 inRnEnvL, inRnEnvR :: RnEnv2 -> Var -> Bool
 -- Tells whether a variable is locally bound
-inRnEnvL (RV2 { envL = env }) v = isJust (lookupVarEnv env v)
-inRnEnvR (RV2 { envR = env }) v = isJust (lookupVarEnv env v)
+inRnEnvL (RV2 { envL = env }) v = v `elemVarEnv` env
+inRnEnvR (RV2 { envR = env }) v = v `elemVarEnv` env
 
 nukeRnEnvL, nukeRnEnvR :: RnEnv2 -> RnEnv2
 nukeRnEnvL env = env { envL = emptyVarEnv }
 
 nukeRnEnvL, nukeRnEnvR :: RnEnv2 -> RnEnv2
 nukeRnEnvL env = env { envL = emptyVarEnv }
index fb6017e..2fae6ac 100644 (file)
@@ -5,8 +5,9 @@ Taken quite directly from the Peyton Jones/Lester paper.
 
 \begin{code}
 module CoreFVs (
 
 \begin{code}
 module CoreFVs (
-       exprFreeVars,   -- CoreExpr -> VarSet   -- Find all locally-defined free Ids or tyvars
+       exprFreeVars,   -- CoreExpr   -> VarSet -- Find all locally-defined free Ids or tyvars
        exprsFreeVars,  -- [CoreExpr] -> VarSet
        exprsFreeVars,  -- [CoreExpr] -> VarSet
+       bindFreeVars,   -- CoreBind   -> VarSet
 
        exprSomeFreeVars, exprsSomeFreeVars,
        exprFreeNames, exprsFreeNames,
 
        exprSomeFreeVars, exprsSomeFreeVars,
        exprFreeNames, exprsFreeNames,
@@ -59,6 +60,12 @@ exprFreeVars = exprSomeFreeVars isLocalVar
 exprsFreeVars :: [CoreExpr] -> VarSet
 exprsFreeVars = foldr (unionVarSet . exprFreeVars) emptyVarSet
 
 exprsFreeVars :: [CoreExpr] -> VarSet
 exprsFreeVars = foldr (unionVarSet . exprFreeVars) emptyVarSet
 
+bindFreeVars :: CoreBind -> VarSet
+bindFreeVars (NonRec b r) = exprFreeVars r
+bindFreeVars (Rec prs)    = addBndrs (map fst prs) 
+                                    (foldr (union . rhs_fvs) noVars prs)
+                                    isLocalVar emptyVarSet
+
 exprSomeFreeVars :: InterestingVarFun  -- Says which Vars are interesting
                 -> CoreExpr
                 -> VarSet
 exprSomeFreeVars :: InterestingVarFun  -- Says which Vars are interesting
                 -> CoreExpr
                 -> VarSet
index 4c223d4..b12147d 100644 (file)
@@ -18,7 +18,7 @@ module Rules (
 
 import CoreSyn         -- All of it
 import OccurAnal       ( occurAnalyseExpr )
 
 import CoreSyn         -- All of it
 import OccurAnal       ( occurAnalyseExpr )
-import CoreFVs         ( exprFreeVars, exprsFreeVars, rulesRhsFreeVars )
+import CoreFVs         ( exprFreeVars, exprsFreeVars, bindFreeVars, rulesRhsFreeVars )
 import CoreUnfold      ( isCheapUnfolding, unfoldingTemplate )
 import CoreUtils       ( tcEqExprX )
 import PprCore         ( pprRules )
 import CoreUnfold      ( isCheapUnfolding, unfoldingTemplate )
 import CoreUtils       ( tcEqExprX )
 import PprCore         ( pprRules )
@@ -33,7 +33,8 @@ import VarEnv         ( IdEnv, InScopeSet, emptyTidyEnv,
                          emptyInScopeSet, mkInScopeSet, extendInScopeSetList, 
                          emptyVarEnv, lookupVarEnv, extendVarEnv, 
                          nukeRnEnvL, mkRnEnv2, rnOccR, rnOccL, inRnEnvR,
                          emptyInScopeSet, mkInScopeSet, extendInScopeSetList, 
                          emptyVarEnv, lookupVarEnv, extendVarEnv, 
                          nukeRnEnvL, mkRnEnv2, rnOccR, rnOccL, inRnEnvR,
-                         rnBndrR, rnBndr2, rnBndrL, rnBndrs2 )
+                         rnBndrR, rnBndr2, rnBndrL, rnBndrs2,
+                         rnInScope, extendRnInScopeList )
 import VarSet
 import Name            ( Name, NamedThing(..), nameOccName )
 import NameEnv
 import VarSet
 import Name            ( Name, NamedThing(..), nameOccName )
 import NameEnv
@@ -42,8 +43,9 @@ import BasicTypes     ( Activation, CompilerPhase, isActive )
 import Outputable
 import FastString
 import Maybes          ( isJust, orElse )
 import Outputable
 import FastString
 import Maybes          ( isJust, orElse )
+import OrdList
 import Bag
 import Bag
-import Util            ( singleton )
+import Util            ( singleton, mapAccumL )
 import List            ( isPrefixOf )
 \end{code}
 
 import List            ( isPrefixOf )
 \end{code}
 
@@ -305,9 +307,10 @@ matchRule is_active in_scope args rough_args
   | otherwise
   = case matchN in_scope tpl_vars tpl_args args of
        Nothing                    -> Nothing
   | otherwise
   = case matchN in_scope tpl_vars tpl_args args of
        Nothing                    -> Nothing
-       Just (tpl_vals, leftovers) -> Just (rule_fn
-                                           `mkApps` tpl_vals
-                                           `mkApps` leftovers)
+       Just (binds, tpl_vals, leftovers) -> Just (mkLets binds $
+                                                  rule_fn
+                                                   `mkApps` tpl_vals
+                                                   `mkApps` leftovers)
   where
     rule_fn = occurAnalyseExpr (mkLams tpl_vars rhs)
        -- We could do this when putting things into the rulebase, I guess
   where
     rule_fn = occurAnalyseExpr (mkLams tpl_vars rhs)
        -- We could do this when putting things into the rulebase, I guess
@@ -318,12 +321,16 @@ matchN    :: InScopeSet
        -> [Var]                -- Template tyvars
        -> [CoreExpr]           -- Template
        -> [CoreExpr]           -- Target; can have more elts than template
        -> [Var]                -- Template tyvars
        -> [CoreExpr]           -- Template
        -> [CoreExpr]           -- Target; can have more elts than template
-       -> Maybe ([CoreExpr],   -- What is substituted for each template var
+       -> Maybe ([CoreBind],   -- Bindings to wrap around the entire result
+                 [CoreExpr],   -- What is substituted for each template var
                  [CoreExpr])   -- Leftover target exprs
 
 matchN in_scope tmpl_vars tmpl_es target_es
                  [CoreExpr])   -- Leftover target exprs
 
 matchN in_scope tmpl_vars tmpl_es target_es
-  = do { (subst, leftover_es) <- go init_menv emptySubstEnv tmpl_es target_es
-       ; return (map (lookup_tmpl subst) tmpl_vars, leftover_es) }
+  = do { ((tv_subst, id_subst, binds), leftover_es)
+               <- go init_menv emptySubstEnv tmpl_es target_es
+       ; return (fromOL binds, 
+                 map (lookup_tmpl tv_subst id_subst) tmpl_vars, 
+                 leftover_es) }
   where
     init_menv = ME { me_tmpls = mkVarSet tmpl_vars, me_env = init_rn_env }
     init_rn_env = mkRnEnv2 (extendInScopeSetList in_scope tmpl_vars)
   where
     init_menv = ME { me_tmpls = mkVarSet tmpl_vars, me_env = init_rn_env }
     init_rn_env = mkRnEnv2 (extendInScopeSetList in_scope tmpl_vars)
@@ -333,8 +340,8 @@ matchN in_scope tmpl_vars tmpl_es target_es
     go menv subst (t:ts) (e:es) = do { subst1 <- match menv subst t e 
                                     ; go menv subst1 ts es }
 
     go menv subst (t:ts) (e:es) = do { subst1 <- match menv subst t e 
                                     ; go menv subst1 ts es }
 
-    lookup_tmpl :: (TvSubstEnv, IdSubstEnv) -> Var -> CoreExpr
-    lookup_tmpl (tv_subst, id_subst) tmpl_var
+    lookup_tmpl :: TvSubstEnv -> IdSubstEnv -> Var -> CoreExpr
+    lookup_tmpl tv_subst id_subst tmpl_var
        | isTyVar tmpl_var = case lookupVarEnv tv_subst tmpl_var of
                                Just ty         -> Type ty
                                Nothing         -> unbound tmpl_var
        | isTyVar tmpl_var = case lookupVarEnv tv_subst tmpl_var of
                                Just ty         -> Type ty
                                Nothing         -> unbound tmpl_var
@@ -353,13 +360,18 @@ matchN in_scope tmpl_vars tmpl_es target_es
 \begin{code}
 -- These two definitions are not the same as in Subst,
 -- but they simple and direct, and purely local to this module
 \begin{code}
 -- These two definitions are not the same as in Subst,
 -- but they simple and direct, and purely local to this module
--- The third, for TvSubstEnv, is the same as in VarEnv, but repeated here
--- for uniformity with IdSubstEnv
-type SubstEnv   = (TvSubstEnv, IdSubstEnv)     
+--
+-- * The domain of the TvSubstEnv and IdSubstEnv are the template
+--   variables passed into the match.
+--
+-- * The (OrdList CoreBind) in a SubstEnv are the bindings floated out
+--   from nested matches; see the Let case of match, below
+--
+type SubstEnv   = (TvSubstEnv, IdSubstEnv, OrdList CoreBind)
 type IdSubstEnv = IdEnv    CoreExpr            
 
 emptySubstEnv :: SubstEnv
 type IdSubstEnv = IdEnv    CoreExpr            
 
 emptySubstEnv :: SubstEnv
-emptySubstEnv = (emptyVarEnv, emptyVarEnv)
+emptySubstEnv = (emptyVarEnv, emptyVarEnv, nilOL)
 
 
 --     At one stage I tried to match even if there are more 
 
 
 --     At one stage I tried to match even if there are more 
@@ -393,29 +405,10 @@ match :: MatchEnv
 -- succeed in matching what looks like the template variable 'a' against 3.
 
 -- The Var case follows closely what happens in Unify.match
 -- succeed in matching what looks like the template variable 'a' against 3.
 
 -- The Var case follows closely what happens in Unify.match
-match menv subst@(tv_subst, id_subst) (Var v1) e2 
-  | v1 `elemVarSet` me_tmpls menv
-  = case lookupVarEnv id_subst v1' of
-       Nothing | any (inRnEnvR rn_env) (varSetElems (exprFreeVars e2))
-               -> Nothing      -- Occurs check failure
-               -- e.g. match forall a. (\x-> a x) against (\y. y y)
-
-               | otherwise
-               -> Just (tv_subst, extendVarEnv id_subst v1 e2)
-
-       Just e2' | tcEqExprX (nukeRnEnvL rn_env) e2' e2 
-                -> Just subst
-
-       other -> Nothing
-
-  |    -- v1 is not a template variable; check for an exact match with e2
-    Var v2 <- e2, v1' == rnOccR rn_env v2
+match menv subst (Var v1) e2 
+  | Just subst <- match_var menv subst v1 e2
   = Just subst
 
   = Just subst
 
-  where
-    rn_env = me_env menv
-    v1'    = rnOccL rn_env v1
-
 -- Here is another important rule: if the term being matched is a
 -- variable, we expand it so long as its unfolding is a WHNF
 -- (Its occurrence information is not necessarily up to date,
 -- Here is another important rule: if the term being matched is a
 -- variable, we expand it so long as its unfolding is a WHNF
 -- (Its occurrence information is not necessarily up to date,
@@ -470,26 +463,101 @@ match menv subst (Note (Coerce to1 from1) e1) (Note (Coerce to2 from2) e2)
        ; subst2 <- match_ty menv subst1 from1 from2
        ; match menv subst2 e1 e2 }
 
        ; subst2 <- match_ty menv subst1 from1 from2
        ; match menv subst2 e1 e2 }
 
+-- Matching a let-expression.  Consider
+--     RULE forall x.  f (g x) = <rhs>
+-- and target expression
+--     f (let { w=R } in g E))
+-- Then we'd like the rule to match, to generate
+--     let { w=R } in (\x. <rhs>) E
+-- In effect, we want to float the let-binding outward, to enable
+-- the match to happen.  This is the WHOLE REASON for accumulating
+-- bindings in the SubstEnv
+--
+-- We can only do this if
+--     (a) Widening the scope of w does not capture any variables
+--         We use a conservative test: w is not already in scope
+--     (b) The free variables of R are not bound by the part of the
+--         target expression outside the let binding; e.g.
+--             f (\v. let w = v+1 in g E)
+--         Here we obviously cannot float the let-binding for w.
+
+match menv subst@(tv_subst, id_subst, binds) e1 (Let bind e2)
+  | all freshly_bound bndrs,
+    not (any locally_bound bind_fvs)
+  = match (menv { me_env = rn_env' }) 
+         (tv_subst, id_subst, binds `snocOL` bind)
+         e1 e2
+  where
+    rn_env   = me_env menv
+    bndrs    = bindersOf bind
+    bind_fvs = varSetElems (bindFreeVars bind)
+    freshly_bound x = not (x `rnInScope` rn_env)
+    locally_bound x = inRnEnvR rn_env x
+    rn_env' = extendRnInScopeList rn_env bndrs
+
 -- This is an interesting rule: we simply ignore lets in the 
 -- term being matched against!  The unfolding inside it is (by assumption)
 -- already inside any occurrences of the bound variables, so we'll expand
 -- This is an interesting rule: we simply ignore lets in the 
 -- term being matched against!  The unfolding inside it is (by assumption)
 -- already inside any occurrences of the bound variables, so we'll expand
--- them when we encounter them.
-match menv subst e1 (Let (NonRec x2 r2) e2)
-  = match menv' subst e1 e2
+-- them when we encounter them.  This gives a chance of matching
+--     forall x,y.  f (g (x,y))
+-- against
+--     f (let v = (a,b) in g v)
+
+match menv subst e1 (Let bind e2)
+  = match (menv { me_env = rn_env' }) subst e1 e2
   where
   where
-    menv' = menv { me_env = fst (rnBndrR (me_env menv) x2) }
-       -- It's important to do this renaming. For example:
+    (rn_env', _bndrs') = mapAccumL rnBndrR (me_env menv) (bindersOf bind)
+       -- It's important to do this renaming, so that the bndrs
+       -- are brought into the local scope. For example:
        -- Matching
        --      forall f,x,xs. f (x:xs)
        --   against
        --      f (let y = e in (y:[]))
        -- Matching
        --      forall f,x,xs. f (x:xs)
        --   against
        --      f (let y = e in (y:[]))
-       -- We must not get success with x->y!  Instead, we 
-       -- need an occurs check.
+       -- We must not get success with x->y!  So we record that y is
+       -- locally bound (with rnBndrR), and proceed.  The Var case
+       -- will fail when trying to bind x->y
+       --
 
 -- Everything else fails
 match menv subst e1 e2 = Nothing
 
 ------------------------------------------
 
 -- Everything else fails
 match menv subst e1 e2 = Nothing
 
 ------------------------------------------
+match_var :: MatchEnv
+         -> SubstEnv
+         -> Var                -- Template
+         -> CoreExpr           -- Target
+         -> Maybe SubstEnv
+match_var menv subst@(tv_subst, id_subst, binds) v1 e2
+  | v1' `elemVarSet` me_tmpls menv
+  = case lookupVarEnv id_subst v1' of
+       Nothing | any (inRnEnvR rn_env) (varSetElems (exprFreeVars e2))
+               -> Nothing      -- Occurs check failure
+               -- e.g. match forall a. (\x-> a x) against (\y. y y)
+
+               | otherwise
+               -> Just (tv_subst, extendVarEnv id_subst v1 e2, binds)
+
+       Just e2' | tcEqExprX (nukeRnEnvL rn_env) e2' e2 
+                -> Just subst
+
+                | otherwise
+                -> Nothing
+
+  | otherwise  -- v1 is not a template variable; check for an exact match with e2
+  = case e2 of
+       Var v2 | v1' == rnOccR rn_env v2 -> Just subst
+       other                           -> Nothing
+
+  where
+    rn_env = me_env menv
+    v1'    = rnOccL rn_env v1  
+       -- If the template is
+       --      forall x. f x (\x -> x) = ...
+       -- Then the x inside the lambda isn't the 
+       -- template x, so we must rename first!
+                               
+
+------------------------------------------
 match_alts :: MatchEnv
       -> SubstEnv
       -> [CoreAlt]             -- Template
 match_alts :: MatchEnv
       -> SubstEnv
       -> [CoreAlt]             -- Template
@@ -517,9 +585,9 @@ We only want to replace (f T) with f', not (f Int).
 
 \begin{code}
 ------------------------------------------
 
 \begin{code}
 ------------------------------------------
-match_ty menv (tv_subst, id_subst) ty1 ty2
+match_ty menv (tv_subst, id_subst, binds) ty1 ty2
   = do { tv_subst' <- Unify.ruleMatchTyX menv tv_subst ty1 ty2
   = do { tv_subst' <- Unify.ruleMatchTyX menv tv_subst ty1 ty2
-       ; return (tv_subst', id_subst) }
+       ; return (tv_subst', id_subst, binds) }
 \end{code}
 
 
 \end{code}