X-Git-Url: http://git.megacz.com/?p=ghc-hetmet.git;a=blobdiff_plain;f=compiler%2Ftypecheck%2FTcRules.lhs;h=71c539993da115a6c4aab534236f47517715fcd3;hp=a95251d1b40627c24aa6660f3068ecadae519ac6;hb=5de363ca9ebdb7d85e3c353c1cffdf0a1c11128e;hpb=422028fc34aff8fe02ff6a35f2b3cfa0481d8ed7 diff --git a/compiler/typecheck/TcRules.lhs b/compiler/typecheck/TcRules.lhs index a95251d..71c5399 100644 --- a/compiler/typecheck/TcRules.lhs +++ b/compiler/typecheck/TcRules.lhs @@ -16,51 +16,66 @@ import TcType import TcHsType import TcExpr import TcEnv -import Inst import Id +import Var ( Var ) import Name +import VarSet import SrcLoc import Outputable import FastString +import Data.List( partition ) \end{code} +Note [Typechecking rules] +~~~~~~~~~~~~~~~~~~~~~~~~~ +We *infer* the typ of the LHS, and use that type to *check* the type of +the RHS. That means that higher-rank rules work reasonably well. Here's +an example (test simplCore/should_compile/rule2.hs) produced by Roman: + + foo :: (forall m. m a -> m b) -> m a -> m b + foo f = ... + + bar :: (forall m. m a -> m a) -> m a -> m a + bar f = ... + + {-# RULES "foo/bar" foo = bar #-} + +He wanted the rule to typecheck. + \begin{code} tcRules :: [LRuleDecl Name] -> TcM [LRuleDecl TcId] tcRules decls = mapM (wrapLocM tcRule) decls tcRule :: RuleDecl Name -> TcM (RuleDecl TcId) -tcRule (HsRule name act vars lhs fv_lhs rhs fv_rhs) - = addErrCtxt (ruleCtxt name) $ do - traceTc (ptext (sLit "---- Rule ------") <+> ppr name) - rule_ty <- newFlexiTyVarTy openTypeKind - - -- Deal with the tyvars mentioned in signatures - (ids, lhs', rhs', lhs_lie, rhs_lie) <- - tcRuleBndrs vars $ \ ids -> do - -- Now LHS and RHS - (lhs', lhs_lie) <- getLIE (tcMonoExpr lhs rule_ty) - (rhs', rhs_lie) <- getLIE (tcMonoExpr rhs rule_ty) - return (ids, lhs', rhs', lhs_lie, rhs_lie) - - -- Check that LHS has no overloading at all - (lhs_dicts, lhs_binds) <- tcSimplifyRuleLhs lhs_lie - - -- Gather the template variables and tyvars - let - tpl_ids = map instToId lhs_dicts ++ ids +tcRule (HsRule name act hs_bndrs lhs fv_lhs rhs fv_rhs) + = addErrCtxt (ruleCtxt name) $ + do { traceTc "---- Rule ------" (ppr name) + + -- Note [Typechecking rules] + ; vars <- tcRuleBndrs hs_bndrs + ; let (id_bndrs, tv_bndrs) = partition isId vars + ; (lhs', lhs_lie, rhs', rhs_lie, rule_ty) + <- tcExtendTyVarEnv tv_bndrs $ + tcExtendIdEnv id_bndrs $ + do { ((lhs', rule_ty), lhs_lie) <- getConstraints (tcInferRho lhs) + ; (rhs', rhs_lie) <- getConstraints (tcMonoExpr rhs rule_ty) + ; return (lhs', lhs_lie, rhs', rhs_lie, rule_ty) } + + ; (lhs_dicts, lhs_ev_binds, rhs_ev_binds) + <- simplifyRule name tv_bndrs lhs_lie rhs_lie -- IMPORTANT! We *quantify* over any dicts that appear in the LHS -- Reason: - -- a) The particular dictionary isn't important, because its value + -- (a) The particular dictionary isn't important, because its value -- depends only on the type -- e.g gcd Int $fIntegralInt -- Here we'd like to match against (gcd Int any_d) for any 'any_d' -- - -- b) We'd like to make available the dictionaries bound - -- on the LHS in the RHS, so quantifying over them is good - -- See the 'lhs_dicts' in tcSimplifyAndCheck for the RHS + -- (b) We'd like to make available the dictionaries bound + -- on the LHS in the RHS, so quantifying over them is good + -- See the 'lhs_dicts' in tcSimplifyAndCheck for the RHS - -- We initially quantify over any tyvars free in *either* the rule + -- We quantify over any tyvars free in *either* the rule -- *or* the bound variables. The latter is important. Consider -- ss (x,(y,z)) = (x,z) -- RULE: forall v. fst (ss v) = fst v @@ -68,32 +83,29 @@ tcRule (HsRule name act vars lhs fv_lhs rhs fv_rhs) -- -- We also need to get the free tyvars of the LHS; but we do that -- during zonking (see TcHsSyn.zonkRule) - -- - forall_tvs = tyVarsOfTypes (rule_ty : map idType tpl_ids) - -- RHS can be a bit more lenient. In particular, - -- we let constant dictionaries etc float outwards - -- - -- NB: tcSimplifyInferCheck zonks the forall_tvs, and - -- knocks out any that are constrained by the environment - loc <- getInstLoc (SigOrigin (RuleSkol name)) - (forall_tvs1, rhs_binds) <- tcSimplifyInferCheck loc - forall_tvs - lhs_dicts rhs_lie - - return (HsRule name act - (map (RuleBndr . noLoc) (forall_tvs1 ++ tpl_ids)) -- yuk - (mkHsDictLet lhs_binds lhs') fv_lhs - (mkHsDictLet rhs_binds rhs') fv_rhs) - -tcRuleBndrs :: [RuleBndr Name] -> ([Id] -> TcM a) -> TcM a -tcRuleBndrs [] thing_inside = thing_inside [] -tcRuleBndrs (RuleBndr var : vars) thing_inside + ; let tpl_ids = lhs_dicts ++ id_bndrs + forall_tvs = tyVarsOfTypes (rule_ty : map idType tpl_ids) + + -- Now figure out what to quantify over + -- c.f. TcSimplify.simplifyInfer + ; zonked_forall_tvs <- zonkTcTyVarsAndFV forall_tvs + ; gbl_tvs <- tcGetGlobalTyVars -- Already zonked + ; qtvs <- zonkQuantifiedTyVars (varSetElems (zonked_forall_tvs `minusVarSet` gbl_tvs)) + + ; return (HsRule name act + (map (RuleBndr . noLoc) (qtvs ++ tpl_ids)) -- yuk + (mkHsDictLet lhs_ev_binds lhs') fv_lhs + (mkHsDictLet rhs_ev_binds rhs') fv_rhs) } + +tcRuleBndrs :: [RuleBndr Name] -> TcM [Var] +tcRuleBndrs [] + = return [] +tcRuleBndrs (RuleBndr var : rule_bndrs) = do { ty <- newFlexiTyVarTy openTypeKind - ; let id = mkLocalId (unLoc var) ty - ; tcExtendIdEnv [id] $ - tcRuleBndrs vars (\ids -> thing_inside (id:ids)) } -tcRuleBndrs (RuleBndrSig var rn_ty : vars) thing_inside + ; vars <- tcRuleBndrs rule_bndrs + ; return (mkLocalId (unLoc var) ty : vars) } +tcRuleBndrs (RuleBndrSig var rn_ty : rule_bndrs) -- e.g x :: a->a -- The tyvar 'a' is brought into scope first, just as if you'd written -- a::*, x :: a->a @@ -102,9 +114,11 @@ tcRuleBndrs (RuleBndrSig var rn_ty : vars) thing_inside ; let skol_tvs = tcSkolSigTyVars (SigSkol ctxt) tyvars id_ty = substTyWith tyvars (mkTyVarTys skol_tvs) ty id = mkLocalId (unLoc var) id_ty - ; tcExtendTyVarEnv skol_tvs $ - tcExtendIdEnv [id] $ - tcRuleBndrs vars (\ids -> thing_inside (id:ids)) } + + -- The type variables scope over subsequent bindings; yuk + ; vars <- tcExtendTyVarEnv skol_tvs $ + tcRuleBndrs rule_bndrs + ; return (skol_tvs ++ id : vars) } ruleCtxt :: FastString -> SDoc ruleCtxt name = ptext (sLit "When checking the transformation rule") <+>