[project @ 2005-02-25 13:06:31 by simonpj]
[ghc-hetmet.git] / ghc / compiler / typecheck / TcBinds.lhs
index f0de50a..a16cddc 100644 (file)
@@ -23,8 +23,10 @@ import TcHsSyn               ( TcId, TcDictBinds, zonkId, mkHsLet )
 import TcRnMonad
 import Inst            ( InstOrigin(..), newDictsAtLoc, newIPDict, instToId )
 import TcEnv           ( tcExtendIdEnv, tcExtendIdEnv2, tcExtendTyVarEnv2, 
-                         newLocalName, tcLookupLocalIds, pprBinders )
-import TcUnify         ( Expected(..), tcInfer, checkSigTyVars, sigCtxt )
+                         newLocalName, tcLookupLocalIds, pprBinders,
+                         tcGetGlobalTyVars )
+import TcUnify         ( Expected(..), tcInfer, unifyTheta, 
+                         bleatEscapedTvs, sigCtxt )
 import TcSimplify      ( tcSimplifyInfer, tcSimplifyInferCheck, tcSimplifyRestricted, 
                          tcSimplifyToDicts, tcSimplifyIPs )
 import TcHsType                ( tcHsSigType, UserTypeCtxt(..), tcAddLetBoundTyVars,
@@ -32,16 +34,15 @@ import TcHsType             ( tcHsSigType, UserTypeCtxt(..), tcAddLetBoundTyVars,
                        )
 import TcPat           ( tcPat, PatCtxt(..) )
 import TcSimplify      ( bindInstsOfLocalFuns )
-import TcMType         ( newTyFlexiVarTy, tcSkolType, zonkQuantifiedTyVar, zonkTcTypes )
+import TcMType         ( newTyFlexiVarTy, zonkQuantifiedTyVar, 
+                         tcInstSigType, zonkTcTypes, zonkTcTyVar )
 import TcType          ( TcTyVar, SkolemInfo(SigSkol), 
                          TcTauType, TcSigmaType, 
-                         TvSubstEnv, mkOpenTvSubst, substTheta, substTy, 
                          mkTyVarTy, mkForAllTys, mkFunTys, tyVarsOfType, 
                          mkForAllTy, isUnLiftedType, tcGetTyVar_maybe, 
-                         mkTyVarTys )
-import Unify           ( tcMatchPreds )
+                         mkTyVarTys, tidyOpenTyVar, tidyOpenType )
 import Kind            ( argTypeKind )
-import VarEnv          ( lookupVarEnv ) 
+import VarEnv          ( TyVarEnv, emptyVarEnv, lookupVarEnv, extendVarEnv, emptyTidyEnv ) 
 import TysPrim         ( alphaTyVar )
 import Id              ( mkLocalId, mkSpecPragmaId, setInlinePragma )
 import Var             ( idType, idName )
@@ -51,7 +52,6 @@ import VarSet
 import SrcLoc          ( Located(..), unLoc, noLoc, getLoc )
 import Bag
 import Util            ( isIn )
-import Maybes          ( orElse )
 import BasicTypes      ( TopLevelFlag(..), RecFlag(..), isNonRec, isRec, 
                          isNotTopLevel, isAlwaysActive )
 import FiniteMap       ( listToFM, lookupFM )
@@ -464,6 +464,7 @@ tcMonoBinds binds lookup_sig is_rec
 
        ; binds' <- tcExtendTyVarEnv2 rhs_tvs   $
                    tcExtendIdEnv2   rhs_id_env $
+                   traceTc (text "tcMonoBinds" <+> vcat [ppr n <+> ppr id <+> ppr (idType id) | (n,id) <- rhs_id_env]) `thenM_`
                    mapBagM (wrapLocM tcRhs) tc_binds
        ; return (binds', mono_info) }
    where
@@ -560,6 +561,27 @@ getMonoBindInfo tc_binds
 %*                                                                     *
 %************************************************************************
 
+Type signatures are tricky.  Consider
+
+  x :: [a]
+  y :: b
+  (x,y,z) = ([y,z], z, head x)
+
+Here, x and y have type sigs, which go into the environment.  We used to
+instantiate their types with skolem constants, and push those types into
+the RHS, so we'd typecheck the RHS with type
+       ( [a*], b*, c )
+where a*, b* are skolem constants, and c is an ordinary meta type varible.
+
+The trouble is that the occurrences of z in the RHS force a* and b* to 
+be the *same*, so we can't make them into skolem constants that don't unify
+with each other.  Alas.
+
+Current solution: don't use skolems at all.  Instead, instantiate the type
+signatures with ordinary meta type variables, and check at the end that
+each group has remained distinct.
+
+
 \begin{code}
 tcTySigs :: [LSig Name] -> TcM [TcSigInfo]
 -- The trick here is that all the signatures should have the same
@@ -570,8 +592,21 @@ tcTySigs [] = return []
 
 tcTySigs sigs
   = do { (tc_sig1 : tc_sigs) <- mappM tcTySig sigs
-       ; tc_sigs'            <- mapM (checkSigCtxt tc_sig1) tc_sigs
-        ; return (tc_sig1 : tc_sigs') }
+       ; mapM (check_ctxt tc_sig1) tc_sigs
+        ; return (tc_sig1 : tc_sigs) }
+  where
+       -- Check tha all the signature contexts are the same
+       -- The type signatures on a mutually-recursive group of definitions
+       -- must all have the same context (or none).
+       --
+       -- We unify them because, with polymorphic recursion, their types
+       -- might not otherwise be related.  This is a rather subtle issue.
+    check_ctxt :: TcSigInfo -> TcSigInfo -> TcM ()
+    check_ctxt sig1@(TcSigInfo { sig_theta = theta1 }) sig@(TcSigInfo { sig_theta = theta })
+       = setSrcSpan (instLocSrcSpan (sig_loc sig))     $
+         addErrCtxt (sigContextsCtxt sig1 sig)         $
+         unifyTheta theta1 theta
+
 
 tcTySig :: LSig Name -> TcM TcSigInfo
 tcTySig (L span (Sig (L _ name) ty))
@@ -587,51 +622,11 @@ tcTySig (L span (Sig (L _ name) ty))
                                L _ (HsForAllTy _ tvs _ _) -> hsLTyVarNames tvs
                                other                      -> []
 
-       ; (tvs, theta, tau) <- tcSkolType rigid_info sigma_ty
+       ; (tvs, theta, tau) <- tcInstSigType sigma_ty
        ; loc <- getInstLoc (SigOrigin rigid_info)
        ; return (TcSigInfo { sig_id = poly_id, sig_scoped = scoped_names,
                              sig_tvs = tvs, sig_theta = theta, sig_tau = tau, 
                              sig_loc = loc }) }
-
-checkSigCtxt :: TcSigInfo -> TcSigInfo -> TcM TcSigInfo
-checkSigCtxt sig1 sig@(TcSigInfo { sig_tvs = tvs, sig_theta = theta, sig_tau = tau })
-  =    -- Try to match the context of this signature with 
-       -- that of the first signature
-    case tcMatchPreds (sig_tvs sig) (sig_theta sig) (sig_theta sig1) of {
-       Nothing   -> bale_out ;
-       Just tenv ->
-
-    case check_tvs tenv tvs of {
-       Nothing   -> bale_out ;
-       Just tvs' -> 
-
-    let 
-       subst  = mkOpenTvSubst tenv
-    in
-    return (sig { sig_tvs   = tvs', 
-                 sig_theta = substTheta subst theta, 
-                 sig_tau   = substTy subst tau }) }}
-
-  where
-    bale_out = setSrcSpan (instLocSrcSpan (sig_loc sig)) $
-               failWithTc $
-               sigContextsErr (sig_id sig1) (sig_id sig)
-
-       -- Rather tedious check that the type variables
-       -- have been matched only with another type variable,
-       -- and that two type variables have not been matched
-       -- with the same one
-       -- A return of Nothing indicates that one of the bad
-       -- things has happened
-    check_tvs :: TvSubstEnv -> [TcTyVar] -> Maybe [TcTyVar]
-    check_tvs tenv [] = Just []
-    check_tvs tenv (tv:tvs) 
-       = do { let ty = lookupVarEnv tenv tv `orElse` mkTyVarTy tv
-            ; tv'  <- tcGetTyVar_maybe ty
-            ; tvs' <- check_tvs tenv tvs
-            ; if tv' `elem` tvs'
-              then Nothing
-              else Just (tv':tvs') }
 \end{code}
 
 \begin{code}
@@ -680,34 +675,74 @@ generalise top_lvl is_unrestricted mono_infos sigs lie_req
     is_mono_sig sig = null (sig_theta sig)
     doc = ptext SLIT("type signature(s) for") <+> pprBinders bndr_names
 
-mkMethInst (TcSigInfo { sig_id = poly_id, sig_tvs = tvs, 
-                       sig_theta = theta, sig_tau = tau, sig_loc = loc }) mono_id
-  = Method mono_id poly_id (mkTyVarTys tvs) theta tau loc
+    mkMethInst (TcSigInfo { sig_id = poly_id, sig_tvs = tvs, 
+                           sig_theta = theta, sig_tau = tau, sig_loc = loc }) mono_id
+      = Method mono_id poly_id (mkTyVarTys tvs) theta tau loc
 
 checkSigsTyVars :: [TcTyVar] -> [TcSigInfo] -> TcM [TcTyVar]
 checkSigsTyVars qtvs sigs 
-  = mappM check_one sigs       `thenM` \ sig_tvs_s ->
-    let
-       -- Sigh.  Make sure that all the tyvars in the type sigs
-       -- appear in the returned ty var list, which is what we are
-       -- going to generalise over.  Reason: we occasionally get
-       -- silly types like
-       --      type T a = () -> ()
-       --      f :: T a
-       --      f () = ()
-       -- Here, 'a' won't appear in qtvs, so we have to add it
-
-       sig_tvs = foldl extendVarSetList emptyVarSet sig_tvs_s
-       all_tvs = extendVarSetList sig_tvs qtvs
-    in
-    returnM (varSetElems all_tvs)
+  = do { gbl_tvs <- tcGetGlobalTyVars
+       ; sig_tvs_s <- mappM (check_sig gbl_tvs) sigs
+
+       ; let   -- Sigh.  Make sure that all the tyvars in the type sigs
+               -- appear in the returned ty var list, which is what we are
+               -- going to generalise over.  Reason: we occasionally get
+               -- silly types like
+               --      type T a = () -> ()
+               --      f :: T a
+               --      f () = ()
+               -- Here, 'a' won't appear in qtvs, so we have to add it
+               sig_tvs = foldl extendVarSetList emptyVarSet sig_tvs_s
+               all_tvs = varSetElems (extendVarSetList sig_tvs qtvs)
+       ; returnM all_tvs }
   where
-    check_one (TcSigInfo {sig_id = id, sig_tvs = tvs, sig_theta = theta, sig_tau = tau})
-      = addErrCtxt (ptext SLIT("In the type signature for") 
-                     <+> quotes (ppr id))              $
-       addErrCtxtM (sigCtxt id tvs theta tau)          $
-       do { checkSigTyVars tvs; return tvs }
-\end{code}
+    check_sig gbl_tvs (TcSigInfo {sig_id = id, sig_tvs = tvs, 
+                                 sig_theta = theta, sig_tau = tau})
+      = addErrCtxt (ptext SLIT("In the type signature for") <+> quotes (ppr id))       $
+       addErrCtxtM (sigCtxt id tvs theta tau)                                          $
+       do { tvs' <- checkDistinctTyVars tvs
+          ; ifM (any (`elemVarSet` gbl_tvs) tvs')
+                (bleatEscapedTvs gbl_tvs tvs tvs') 
+          ; return tvs' }
+
+checkDistinctTyVars :: [TcTyVar] -> TcM [TcTyVar]
+-- (checkDistinctTyVars tvs) checks that the tvs from one type signature
+-- are still all type variables, and all distinct from each other.  
+-- It returns a zonked set of type variables.
+-- For example, if the type sig is
+--     f :: forall a b. a -> b -> b
+-- we want to check that 'a' and 'b' haven't 
+--     (a) been unified with a non-tyvar type
+--     (b) been unified with each other (all distinct)
+
+checkDistinctTyVars sig_tvs
+  = do { zonked_tvs <- mapM zonk_one sig_tvs
+       ; foldlM check_dup emptyVarEnv (sig_tvs `zip` zonked_tvs)
+       ; return zonked_tvs }
+  where
+    zonk_one sig_tv = do { ty <- zonkTcTyVar sig_tv
+                        ; case tcGetTyVar_maybe ty of
+                            Just tv' -> return tv'
+                            Nothing  -> bomb_out sig_tv "a type" ty }
+
+    check_dup :: TyVarEnv TcTyVar -> (TcTyVar, TcTyVar) -> TcM (TyVarEnv TcTyVar)
+       -- The TyVarEnv maps each zonked type variable back to its
+       -- corresponding user-written signature type variable
+    check_dup acc (sig_tv, zonked_tv)
+       = case lookupVarEnv acc zonked_tv of
+               Just sig_tv' -> bomb_out sig_tv "another quantified type variable" 
+                                               (mkTyVarTy sig_tv')
+
+               Nothing -> return (extendVarEnv acc zonked_tv sig_tv)
+
+    bomb_out sig_tv doc ty 
+       = failWithTc (ptext SLIT("Quantified type variable") <+> quotes (ppr tidy_tv) 
+                    <+> ptext SLIT("is unified with") <+> text doc <+> ppr tidy_ty)
+       where
+        (env1,  tidy_tv) = tidyOpenTyVar emptyTidyEnv sig_tv
+        (_env2, tidy_ty) = tidyOpenType  env1         ty
+\end{code}    
+
 
 @getTyVarsToGen@ decides what type variables to generalise over.
 
@@ -865,11 +900,14 @@ valSpecSigCtxt v ty
         nest 4 (ppr v <+> dcolon <+> ppr ty)]
 
 -----------------------------------------------
-sigContextsErr id1 id2
-  = vcat [ptext SLIT("Mis-match between the contexts of the signatures for"), 
+sigContextsCtxt sig1 sig2
+  = vcat [ptext SLIT("When matching the contexts of the signatures for"), 
          nest 2 (vcat [ppr id1 <+> dcolon <+> ppr (idType id1),
                        ppr id2 <+> dcolon <+> ppr (idType id2)]),
          ptext SLIT("The signature contexts in a mutually recursive group should all be identical")]
+  where
+    id1 = sig_id sig1
+    id2 = sig_id sig2
 
 
 -----------------------------------------------