From: simonpj@microsoft.com Date: Thu, 25 May 2006 15:44:47 +0000 (+0000) Subject: Make rule-matching robust to lets X-Git-Url: http://git.megacz.com/?p=ghc-hetmet.git;a=commitdiff_plain;h=7656f8c4bd8d786bf83c1ab2dca0cdd1a903e5bf Make rule-matching robust to lets Consider a RULE like forall arr. splitD (joinD arr) = arr Until now, this rule would not match code of form splitD (let { d = ... } in joinD (...d...)) because the 'let' got in the way. This patch makes the rule-matcher robust to lets. See comments with the Let case of Rules.match. This improvement is highly desirable in the fusion rules for NDP stuff that Roman is working on, where we are doing fusion of *overloaded* functions (which may look lazy). The let expression that Roman tripped up on was a dictionary binding. --- diff --git a/compiler/basicTypes/VarEnv.lhs b/compiler/basicTypes/VarEnv.lhs index bfeecdc..da2f960 100644 --- a/compiler/basicTypes/VarEnv.lhs +++ b/compiler/basicTypes/VarEnv.lhs @@ -26,7 +26,8 @@ module VarEnv ( -- RnEnv2 and its operations RnEnv2, mkRnEnv2, rnBndr2, rnBndrs2, rnOccL, rnOccR, inRnEnvL, inRnEnvR, - rnBndrL, rnBndrR, nukeRnEnvL, nukeRnEnvR, + rnBndrL, rnBndrR, nukeRnEnvL, nukeRnEnvR, extendRnInScopeList, + rnInScope, -- TidyEnvs TidyEnv, emptyTidyEnv @@ -40,7 +41,7 @@ import VarSet import UniqFM import Unique ( Unique, deriveUnique, getUnique ) import Util ( zipEqual, foldl2 ) -import Maybes ( orElse, isJust ) +import Maybes ( orElse ) import StaticFlags( opt_PprStyle_Debug ) import Outputable import FastTypes @@ -183,6 +184,13 @@ mkRnEnv2 vars = RV2 { envL = emptyVarEnv , envR = emptyVarEnv , in_scope = vars } +extendRnInScopeList :: RnEnv2 -> [Var] -> RnEnv2 +extendRnInScopeList env vs + = env { in_scope = extendInScopeSetList (in_scope env) vs } + +rnInScope :: Var -> RnEnv2 -> Bool +rnInScope x env = x `elemInScopeSet` in_scope env + rnBndrs2 :: RnEnv2 -> [Var] -> [Var] -> RnEnv2 -- Arg lists must be of equal length rnBndrs2 env bsL bsR = foldl2 rnBndr2 env bsL bsR 
@@ -236,8 +244,8 @@ rnOccR (RV2 { envR = env }) v = lookupVarEnv env v `orElse` v inRnEnvL, inRnEnvR :: RnEnv2 -> Var -> Bool -- Tells whether a variable is locally bound -inRnEnvL (RV2 { envL = env }) v = isJust (lookupVarEnv env v) -inRnEnvR (RV2 { envR = env }) v = isJust (lookupVarEnv env v) +inRnEnvL (RV2 { envL = env }) v = v `elemVarEnv` env +inRnEnvR (RV2 { envR = env }) v = v `elemVarEnv` env nukeRnEnvL, nukeRnEnvR :: RnEnv2 -> RnEnv2 nukeRnEnvL env = env { envL = emptyVarEnv } diff --git a/compiler/coreSyn/CoreFVs.lhs b/compiler/coreSyn/CoreFVs.lhs index fb6017e..2fae6ac 100644 --- a/compiler/coreSyn/CoreFVs.lhs +++ b/compiler/coreSyn/CoreFVs.lhs @@ -5,8 +5,9 @@ Taken quite directly from the Peyton Jones/Lester paper. \begin{code} module CoreFVs ( - exprFreeVars, -- CoreExpr -> VarSet -- Find all locally-defined free Ids or tyvars + exprFreeVars, -- CoreExpr -> VarSet -- Find all locally-defined free Ids or tyvars exprsFreeVars, -- [CoreExpr] -> VarSet + bindFreeVars, -- CoreBind -> VarSet exprSomeFreeVars, exprsSomeFreeVars, exprFreeNames, exprsFreeNames, @@ -59,6 +60,12 @@ exprFreeVars = exprSomeFreeVars isLocalVar exprsFreeVars :: [CoreExpr] -> VarSet exprsFreeVars = foldr (unionVarSet . exprFreeVars) emptyVarSet +bindFreeVars :: CoreBind -> VarSet +bindFreeVars (NonRec b r) = exprFreeVars r +bindFreeVars (Rec prs) = addBndrs (map fst prs) + (foldr (union . 
rhs_fvs) noVars prs) + isLocalVar emptyVarSet + exprSomeFreeVars :: InterestingVarFun -- Says which Vars are interesting -> CoreExpr -> VarSet diff --git a/compiler/specialise/Rules.lhs b/compiler/specialise/Rules.lhs index 4c223d4..b12147d 100644 --- a/compiler/specialise/Rules.lhs +++ b/compiler/specialise/Rules.lhs @@ -18,7 +18,7 @@ module Rules ( import CoreSyn -- All of it import OccurAnal ( occurAnalyseExpr ) -import CoreFVs ( exprFreeVars, exprsFreeVars, rulesRhsFreeVars ) +import CoreFVs ( exprFreeVars, exprsFreeVars, bindFreeVars, rulesRhsFreeVars ) import CoreUnfold ( isCheapUnfolding, unfoldingTemplate ) import CoreUtils ( tcEqExprX ) import PprCore ( pprRules ) @@ -33,7 +33,8 @@ import VarEnv ( IdEnv, InScopeSet, emptyTidyEnv, emptyInScopeSet, mkInScopeSet, extendInScopeSetList, emptyVarEnv, lookupVarEnv, extendVarEnv, nukeRnEnvL, mkRnEnv2, rnOccR, rnOccL, inRnEnvR, - rnBndrR, rnBndr2, rnBndrL, rnBndrs2 ) + rnBndrR, rnBndr2, rnBndrL, rnBndrs2, + rnInScope, extendRnInScopeList ) import VarSet import Name ( Name, NamedThing(..), nameOccName ) import NameEnv @@ -42,8 +43,9 @@ import BasicTypes ( Activation, CompilerPhase, isActive ) import Outputable import FastString import Maybes ( isJust, orElse ) +import OrdList import Bag -import Util ( singleton ) +import Util ( singleton, mapAccumL ) import List ( isPrefixOf ) \end{code} @@ -305,9 +307,10 @@ matchRule is_active in_scope args rough_args | otherwise = case matchN in_scope tpl_vars tpl_args args of Nothing -> Nothing - Just (tpl_vals, leftovers) -> Just (rule_fn - `mkApps` tpl_vals - `mkApps` leftovers) + Just (binds, tpl_vals, leftovers) -> Just (mkLets binds $ + rule_fn + `mkApps` tpl_vals + `mkApps` leftovers) where rule_fn = occurAnalyseExpr (mkLams tpl_vars rhs) -- We could do this when putting things into the rulebase, I guess @@ -318,12 +321,16 @@ matchN :: InScopeSet -> [Var] -- Template tyvars -> [CoreExpr] -- Template -> [CoreExpr] -- Target; can have more elts than template - -> Maybe 
([CoreExpr], -- What is substituted for each template var + -> Maybe ([CoreBind], -- Bindings to wrap around the entire result + [CoreExpr], -- What is substituted for each template var [CoreExpr]) -- Leftover target exprs matchN in_scope tmpl_vars tmpl_es target_es - = do { (subst, leftover_es) <- go init_menv emptySubstEnv tmpl_es target_es - ; return (map (lookup_tmpl subst) tmpl_vars, leftover_es) } + = do { ((tv_subst, id_subst, binds), leftover_es) + <- go init_menv emptySubstEnv tmpl_es target_es + ; return (fromOL binds, + map (lookup_tmpl tv_subst id_subst) tmpl_vars, + leftover_es) } where init_menv = ME { me_tmpls = mkVarSet tmpl_vars, me_env = init_rn_env } init_rn_env = mkRnEnv2 (extendInScopeSetList in_scope tmpl_vars) @@ -333,8 +340,8 @@ matchN in_scope tmpl_vars tmpl_es target_es go menv subst (t:ts) (e:es) = do { subst1 <- match menv subst t e ; go menv subst1 ts es } - lookup_tmpl :: (TvSubstEnv, IdSubstEnv) -> Var -> CoreExpr - lookup_tmpl (tv_subst, id_subst) tmpl_var + lookup_tmpl :: TvSubstEnv -> IdSubstEnv -> Var -> CoreExpr + lookup_tmpl tv_subst id_subst tmpl_var | isTyVar tmpl_var = case lookupVarEnv tv_subst tmpl_var of Just ty -> Type ty Nothing -> unbound tmpl_var @@ -353,13 +360,18 @@ matchN in_scope tmpl_vars tmpl_es target_es \begin{code} -- These two definitions are not the same as in Subst, -- but they simple and direct, and purely local to this module --- The third, for TvSubstEnv, is the same as in VarEnv, but repeated here --- for uniformity with IdSubstEnv -type SubstEnv = (TvSubstEnv, IdSubstEnv) +-- +-- * The domain of the TvSubstEnv and IdSubstEnv are the template +-- variables passed into the match. 
+-- +-- * The (OrdList CoreBind) in a SubstEnv are the bindings floated out +-- from nested matches; see the Let case of match, below +-- +type SubstEnv = (TvSubstEnv, IdSubstEnv, OrdList CoreBind) type IdSubstEnv = IdEnv CoreExpr emptySubstEnv :: SubstEnv -emptySubstEnv = (emptyVarEnv, emptyVarEnv) +emptySubstEnv = (emptyVarEnv, emptyVarEnv, nilOL) -- At one stage I tried to match even if there are more @@ -393,29 +405,10 @@ match :: MatchEnv -- succeed in matching what looks like the template variable 'a' against 3. -- The Var case follows closely what happens in Unify.match -match menv subst@(tv_subst, id_subst) (Var v1) e2 - | v1 `elemVarSet` me_tmpls menv - = case lookupVarEnv id_subst v1' of - Nothing | any (inRnEnvR rn_env) (varSetElems (exprFreeVars e2)) - -> Nothing -- Occurs check failure - -- e.g. match forall a. (\x-> a x) against (\y. y y) - - | otherwise - -> Just (tv_subst, extendVarEnv id_subst v1 e2) - - Just e2' | tcEqExprX (nukeRnEnvL rn_env) e2' e2 - -> Just subst - - other -> Nothing - - | -- v1 is not a template variable; check for an exact match with e2 - Var v2 <- e2, v1' == rnOccR rn_env v2 +match menv subst (Var v1) e2 + | Just subst <- match_var menv subst v1 e2 = Just subst - where - rn_env = me_env menv - v1' = rnOccL rn_env v1 - -- Here is another important rule: if the term being matched is a -- variable, we expand it so long as its unfolding is a WHNF -- (Its occurrence information is not necessarily up to date, @@ -470,26 +463,101 @@ match menv subst (Note (Coerce to1 from1) e1) (Note (Coerce to2 from2) e2) ; subst2 <- match_ty menv subst1 from1 from2 ; match menv subst2 e1 e2 } +-- Matching a let-expression. Consider +-- RULE forall x. f (g x) = <rhs> +-- and target expression +-- f (let { w=R } in g E)) +-- Then we'd like the rule to match, to generate +-- let { w=R } in (\x. <rhs>) E +-- In effect, we want to float the let-binding outward, to enable +-- the match to happen. 
This is the WHOLE REASON for accumulating +-- bindings in the SubstEnv +-- +-- We can only do this if +-- (a) Widening the scope of w does not capture any variables +-- We use a conservative test: w is not already in scope +-- (b) The free variables of R are not bound by the part of the +-- target expression outside the let binding; e.g. +-- f (\v. let w = v+1 in g E) +-- Here we obviously cannot float the let-binding for w. + +match menv subst@(tv_subst, id_subst, binds) e1 (Let bind e2) + | all freshly_bound bndrs, + not (any locally_bound bind_fvs) + = match (menv { me_env = rn_env' }) + (tv_subst, id_subst, binds `snocOL` bind) + e1 e2 + where + rn_env = me_env menv + bndrs = bindersOf bind + bind_fvs = varSetElems (bindFreeVars bind) + freshly_bound x = not (x `rnInScope` rn_env) + locally_bound x = inRnEnvR rn_env x + rn_env' = extendRnInScopeList rn_env bndrs + -- This is an interesting rule: we simply ignore lets in the -- term being matched against! The unfolding inside it is (by assumption) -- already inside any occurrences of the bound variables, so we'll expand --- them when we encounter them. -match menv subst e1 (Let (NonRec x2 r2) e2) - = match menv' subst e1 e2 +-- them when we encounter them. This gives a chance of matching +-- forall x,y. f (g (x,y)) +-- against +-- f (let v = (a,b) in g v) + +match menv subst e1 (Let bind e2) + = match (menv { me_env = rn_env' }) subst e1 e2 where - menv' = menv { me_env = fst (rnBndrR (me_env menv) x2) } - -- It's important to do this renaming. For example: + (rn_env', _bndrs') = mapAccumL rnBndrR (me_env menv) (bindersOf bind) + -- It's important to do this renaming, so that the bndrs + -- are brought into the local scope. For example: -- Matching -- forall f,x,xs. f (x:xs) -- against -- f (let y = e in (y:[])) - -- We must not get success with x->y! Instead, we - -- need an occurs check. + -- We must not get success with x->y! So we record that y is + -- locally bound (with rnBndrR), and proceed. 
The Var case + -- will fail when trying to bind x->y + -- -- Everything else fails match menv subst e1 e2 = Nothing ------------------------------------------ +match_var :: MatchEnv + -> SubstEnv + -> Var -- Template + -> CoreExpr -- Target + -> Maybe SubstEnv +match_var menv subst@(tv_subst, id_subst, binds) v1 e2 + | v1' `elemVarSet` me_tmpls menv + = case lookupVarEnv id_subst v1' of + Nothing | any (inRnEnvR rn_env) (varSetElems (exprFreeVars e2)) + -> Nothing -- Occurs check failure + -- e.g. match forall a. (\x-> a x) against (\y. y y) + + | otherwise + -> Just (tv_subst, extendVarEnv id_subst v1 e2, binds) + + Just e2' | tcEqExprX (nukeRnEnvL rn_env) e2' e2 + -> Just subst + + | otherwise + -> Nothing + + | otherwise -- v1 is not a template variable; check for an exact match with e2 + = case e2 of + Var v2 | v1' == rnOccR rn_env v2 -> Just subst + other -> Nothing + + where + rn_env = me_env menv + v1' = rnOccL rn_env v1 + -- If the template is + -- forall x. f x (\x -> x) = ... + -- Then the x inside the lambda isn't the + -- template x, so we must rename first! + + +------------------------------------------ match_alts :: MatchEnv -> SubstEnv -> [CoreAlt] -- Template @@ -517,9 +585,9 @@ We only want to replace (f T) with f', not (f Int). \begin{code} ------------------------------------------ -match_ty menv (tv_subst, id_subst) ty1 ty2 +match_ty menv (tv_subst, id_subst, binds) ty1 ty2 = do { tv_subst' <- Unify.ruleMatchTyX menv tv_subst ty1 ty2 - ; return (tv_subst', id_subst) } + ; return (tv_subst', id_subst, binds) } \end{code}