Refactoring: mainly rename ic_env_tvs to ic_untch
[ghc-hetmet.git] / compiler / typecheck / TcRules.lhs
index 3ee95cf..71c5399 100644 (file)
@@ -6,13 +6,6 @@
 TcRules: Typechecking transformation rules
 
 \begin{code}
-{-# OPTIONS -w #-}
--- The above warning supression flag is a temporary kludge.
--- While working on this module you are encouraged to remove it and fix
--- any warnings in the module. See
---     http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings
--- for details
-
 module TcRules ( tcRules ) where
 
 import HsSyn
@@ -23,51 +16,66 @@ import TcType
 import TcHsType
 import TcExpr
 import TcEnv
-import Inst
 import Id
+import Var     ( Var )
 import Name
+import VarSet
 import SrcLoc
 import Outputable
 import FastString
+import Data.List( partition )
 \end{code}
 
+Note [Typechecking rules]
+~~~~~~~~~~~~~~~~~~~~~~~~~
+We *infer* the typ of the LHS, and use that type to *check* the type of 
+the RHS.  That means that higher-rank rules work reasonably well. Here's
+an example (test simplCore/should_compile/rule2.hs) produced by Roman:
+
+   foo :: (forall m. m a -> m b) -> m a -> m b
+   foo f = ...
+
+   bar :: (forall m. m a -> m a) -> m a -> m a
+   bar f = ...
+
+   {-# RULES "foo/bar" foo = bar #-}
+
+He wanted the rule to typecheck.
+
 \begin{code}
 tcRules :: [LRuleDecl Name] -> TcM [LRuleDecl TcId]
 tcRules decls = mapM (wrapLocM tcRule) decls
 
 tcRule :: RuleDecl Name -> TcM (RuleDecl TcId)
-tcRule (HsRule name act vars lhs fv_lhs rhs fv_rhs)
-  = addErrCtxt (ruleCtxt name)                 $ do
-    traceTc (ptext (sLit "---- Rule ------") <+> ppr name)
-    rule_ty <- newFlexiTyVarTy openTypeKind
-
-       -- Deal with the tyvars mentioned in signatures
-    (ids, lhs', rhs', lhs_lie, rhs_lie) <-
-      tcRuleBndrs vars $ \ ids -> do
-               -- Now LHS and RHS
-        (lhs', lhs_lie) <- getLIE (tcMonoExpr lhs rule_ty)
-        (rhs', rhs_lie) <- getLIE (tcMonoExpr rhs rule_ty)
-        return (ids, lhs', rhs', lhs_lie, rhs_lie)
-
-               -- Check that LHS has no overloading at all
-    (lhs_dicts, lhs_binds) <- tcSimplifyRuleLhs lhs_lie
-
-       -- Gather the template variables and tyvars
-    let
-       tpl_ids = map instToId lhs_dicts ++ ids
+tcRule (HsRule name act hs_bndrs lhs fv_lhs rhs fv_rhs)
+  = addErrCtxt (ruleCtxt name) $
+    do { traceTc "---- Rule ------" (ppr name)
+
+       -- Note [Typechecking rules]
+       ; vars <- tcRuleBndrs hs_bndrs
+       ; let (id_bndrs, tv_bndrs) = partition isId vars
+       ; (lhs', lhs_lie, rhs', rhs_lie, rule_ty)
+            <- tcExtendTyVarEnv tv_bndrs $
+               tcExtendIdEnv id_bndrs $
+               do { ((lhs', rule_ty), lhs_lie) <- getConstraints (tcInferRho lhs)
+                  ; (rhs', rhs_lie) <- getConstraints (tcMonoExpr rhs rule_ty)
+                  ; return (lhs', lhs_lie, rhs', rhs_lie, rule_ty) }
+
+       ; (lhs_dicts, lhs_ev_binds, rhs_ev_binds) 
+             <- simplifyRule name tv_bndrs lhs_lie rhs_lie
 
        -- IMPORTANT!  We *quantify* over any dicts that appear in the LHS
        -- Reason: 
-       --      a) The particular dictionary isn't important, because its value
+       --      (a) The particular dictionary isn't important, because its value
        --         depends only on the type
        --              e.g     gcd Int $fIntegralInt
        --         Here we'd like to match against (gcd Int any_d) for any 'any_d'
        --
-       --      b) We'd like to make available the dictionaries bound 
-       --         on the LHS in the RHS, so quantifying over them is good
-       --         See the 'lhs_dicts' in tcSimplifyAndCheck for the RHS
+       --      (b) We'd like to make available the dictionaries bound 
+       --          on the LHS in the RHS, so quantifying over them is good
+       --          See the 'lhs_dicts' in tcSimplifyAndCheck for the RHS
 
-       -- We initially quantify over any tyvars free in *either* the rule
+       -- We quantify over any tyvars free in *either* the rule
        --  *or* the bound variables.  The latter is important.  Consider
        --      ss (x,(y,z)) = (x,z)
        --      RULE:  forall v. fst (ss v) = fst v
@@ -75,32 +83,29 @@ tcRule (HsRule name act vars lhs fv_lhs rhs fv_rhs)
        --
        -- We also need to get the free tyvars of the LHS; but we do that
        -- during zonking (see TcHsSyn.zonkRule)
-       --
-       forall_tvs = tyVarsOfTypes (rule_ty : map idType tpl_ids)
 
-       -- RHS can be a bit more lenient.  In particular,
-       -- we let constant dictionaries etc float outwards
-       --
-       -- NB: tcSimplifyInferCheck zonks the forall_tvs, and 
-       --     knocks out any that are constrained by the environment
-    loc <- getInstLoc (SigOrigin (RuleSkol name))
-    (forall_tvs1, rhs_binds) <- tcSimplifyInferCheck loc
-                                        forall_tvs
-                                        lhs_dicts rhs_lie
+       ; let tpl_ids    = lhs_dicts ++ id_bndrs
+             forall_tvs = tyVarsOfTypes (rule_ty : map idType tpl_ids)
 
-    return (HsRule name act
-                   (map (RuleBndr . noLoc) (forall_tvs1 ++ tpl_ids))   -- yuk
-                   (mkHsDictLet lhs_binds lhs') fv_lhs
-                   (mkHsDictLet rhs_binds rhs') fv_rhs)
+            -- Now figure out what to quantify over
+            -- c.f. TcSimplify.simplifyInfer
+       ; zonked_forall_tvs <- zonkTcTyVarsAndFV forall_tvs
+       ; gbl_tvs           <- tcGetGlobalTyVars             -- Already zonked
+       ; qtvs <- zonkQuantifiedTyVars (varSetElems (zonked_forall_tvs `minusVarSet` gbl_tvs))
 
+       ; return (HsRule name act
+                   (map (RuleBndr . noLoc) (qtvs ++ tpl_ids))  -- yuk
+                   (mkHsDictLet lhs_ev_binds lhs') fv_lhs
+                   (mkHsDictLet rhs_ev_binds rhs') fv_rhs) }
 
-tcRuleBndrs [] thing_inside = thing_inside []
-tcRuleBndrs (RuleBndr var : vars) thing_inside
+tcRuleBndrs :: [RuleBndr Name] -> TcM [Var]
+tcRuleBndrs [] 
+  = return []
+tcRuleBndrs (RuleBndr var : rule_bndrs)
   = do         { ty <- newFlexiTyVarTy openTypeKind
-       ; let id = mkLocalId (unLoc var) ty
-       ; tcExtendIdEnv [id] $
-         tcRuleBndrs vars (\ids -> thing_inside (id:ids)) }
-tcRuleBndrs (RuleBndrSig var rn_ty : vars) thing_inside
+        ; vars <- tcRuleBndrs rule_bndrs
+       ; return (mkLocalId (unLoc var) ty : vars) }
+tcRuleBndrs (RuleBndrSig var rn_ty : rule_bndrs)
 --  e.g        x :: a->a
 --  The tyvar 'a' is brought into scope first, just as if you'd written
 --             a::*, x :: a->a
@@ -109,10 +114,13 @@ tcRuleBndrs (RuleBndrSig var rn_ty : vars) thing_inside
        ; let skol_tvs = tcSkolSigTyVars (SigSkol ctxt) tyvars
              id_ty = substTyWith tyvars (mkTyVarTys skol_tvs) ty
              id = mkLocalId (unLoc var) id_ty
-       ; tcExtendTyVarEnv skol_tvs $
-         tcExtendIdEnv [id] $
-         tcRuleBndrs vars (\ids -> thing_inside (id:ids)) }
 
+             -- The type variables scope over subsequent bindings; yuk
+        ; vars <- tcExtendTyVarEnv skol_tvs $ 
+                  tcRuleBndrs rule_bndrs 
+       ; return (skol_tvs ++ id : vars) }
+
+ruleCtxt :: FastString -> SDoc
 ruleCtxt name = ptext (sLit "When checking the transformation rule") <+> 
                doubleQuotes (ftext name)
 \end{code}