[project @ 2003-02-26 17:04:11 by simonpj]
authorsimonpj <unknown>
Wed, 26 Feb 2003 17:04:16 +0000 (17:04 +0000)
committersimonpj <unknown>
Wed, 26 Feb 2003 17:04:16 +0000 (17:04 +0000)
----------------------------------
Improve higher-rank type inference
----------------------------------

Yanling Wang pointed out that if we have

f = \ (x :: forall a. a->a). x

it would be reasonable to expect that type inference would get the "right"
rank-2 type for f.  She also found that the plausible definition

f :: (forall a. a->a) = \x -> x

acutally failed to type check.

This commit fixes up TcBinds.tcMonoBinds so that it does a better job.
The main idea is that there are three cases to consider in a function binding:

  a) 'f' has a separate type signature
In this case, we know f's type everywhere

  b) The binding is recursive, and there is no type sig
In this case we must give f a monotype in its RHS

  c) The binding is non-recursive, and there is no type sig
Then we do not need to add 'f' to the envt, and can
simply infer a type for the RHS, which may be higher
ranked.

ghc/compiler/typecheck/Inst.lhs
ghc/compiler/typecheck/TcBinds.lhs
ghc/compiler/typecheck/TcClassDcl.lhs
ghc/compiler/typecheck/TcExpr.lhs
ghc/compiler/typecheck/TcHsSyn.lhs
ghc/compiler/typecheck/TcMatches.hi-boot
ghc/compiler/typecheck/TcMatches.hi-boot-5
ghc/compiler/typecheck/TcMatches.hi-boot-6
ghc/compiler/typecheck/TcMatches.lhs
ghc/compiler/typecheck/TcPat.lhs
ghc/compiler/typecheck/TcUnify.lhs

index 981731c..fc21ca1 100644 (file)
@@ -40,7 +40,8 @@ import {-# SOURCE #-} TcExpr( tcExpr )
 
 import HsSyn   ( HsLit(..), HsOverLit(..), HsExpr(..) )
 import TcHsSyn ( TcExpr, TcId, TcIdSet, TypecheckedHsExpr,
-                 mkHsTyApp, mkHsDictApp, mkHsConApp, zonkId
+                 mkHsTyApp, mkHsDictApp, mkHsConApp, zonkId,
+                 mkCoercion, ExprCoFn
                )
 import TcRnMonad
 import TcEnv   ( tcGetInstEnv, tcLookupId, tcLookupTyCon, checkWellStaged, topIdLvl )
@@ -256,7 +257,7 @@ newIPDict orig ip_name ty
 
 
 \begin{code}
-tcInstCall :: InstOrigin  -> TcType -> TcM (TypecheckedHsExpr -> TypecheckedHsExpr, TcType)
+tcInstCall :: InstOrigin  -> TcType -> TcM (ExprCoFn, TcType)
 tcInstCall orig fun_ty -- fun_ty is usually a sigma-type
   = tcInstType VanillaTv fun_ty        `thenM` \ (tyvars, theta, tau) ->
     newDicts orig theta                `thenM` \ dicts ->
@@ -264,7 +265,7 @@ tcInstCall orig fun_ty      -- fun_ty is usually a sigma-type
     let
        inst_fn e = mkHsDictApp (mkHsTyApp e (mkTyVarTys tyvars)) (map instToId dicts)
     in
-    returnM (inst_fn, tau)
+    returnM (mkCoercion inst_fn, tau)
 
 tcInstDataCon :: InstOrigin -> DataCon
              -> TcM ([TcType], -- Types to instantiate at
index 30cad8c..7171ed2 100644 (file)
@@ -255,8 +255,9 @@ tcBindWithSigs top_lvl mbind sigs is_rec
     )                                          $
 
        -- TYPECHECK THE BINDINGS
-    getLIE (tcMonoBinds mbind tc_ty_sigs is_rec)       `thenM` \ ((mbind', binder_names, mono_ids), lie_req) ->
+    getLIE (tcMonoBinds mbind tc_ty_sigs is_rec)       `thenM` \ ((mbind', bndr_names_w_ids), lie_req) ->
     let
+       (binder_names, mono_ids) = unzip (bagToList bndr_names_w_ids)
        tau_tvs = foldr (unionVarSet . tyVarsOfType . idType) emptyVarSet mono_ids
     in
 
@@ -620,91 +621,86 @@ The signatures have been dealt with already.
 
 \begin{code}
 tcMonoBinds :: RenamedMonoBinds 
-           -> [TcSigInfo]
-           -> RecFlag
+           -> [TcSigInfo] -> RecFlag
            -> TcM (TcMonoBinds, 
-                     [Name],           -- Bound names
-                     [TcId])           -- Corresponding monomorphic bound things
+                   Bag (Name,          -- Bound names
+                        TcId))         -- Corresponding monomorphic bound things
 
 tcMonoBinds mbinds tc_ty_sigs is_rec
-  = tc_mb_pats mbinds          `thenM` \ (complete_it, tvs, ids, lie_avail) ->
-    let
-       id_list           = bagToList ids
-       (names, mono_ids) = unzip id_list
-
-               -- This last defn is the key one:
-               -- extend the val envt with bindings for the 
-               -- things bound in this group, overriding the monomorphic
-               -- ids with the polymorphic ones from the pattern
-       extra_val_env = case is_rec of
-                         Recursive    -> map mk_bind id_list
-                         NonRecursive -> []
-    in
-       -- Don't know how to deal with pattern-bound existentials yet
-    checkTc (isEmptyBag tvs && null lie_avail) 
-           (existentialExplode mbinds)                 `thenM_` 
-
-       -- *Before* checking the RHSs, but *after* checking *all* the patterns,
-       -- extend the envt with bindings for all the bound ids;
-       --   and *then* override with the polymorphic Ids from the signatures
-       -- That is the whole point of the "complete_it" stuff.
-       --
-       -- There's a further wrinkle: we have to delay extending the environment
-       -- until after we've dealt with any pattern-bound signature type variables
-       -- Consider  f (x::a) = ...f...
-       -- We're going to check that a isn't unified with anything in the envt, 
-       -- so f itself had better not be!  So we pass the envt binding f into
-       -- complete_it, which extends the actual envt in TcMatches.tcMatch, after
-       -- dealing with the signature tyvars
-
-    complete_it extra_val_env                          `thenM` \ mbinds' ->
-
-    returnM (mbinds', names, mono_ids)
+       -- Three stages: 
+       -- 1. Check the patterns, building up an environment binding
+       --    the variables in this group (in the recursive case)
+       -- 2. Extend the environment
+       -- 3. Check the RHSs
+  = tc_mb_pats mbinds          `thenM` \ (complete_it, xve) ->
+    tcExtendLocalValEnv2 (bagToList xve) complete_it
   where
-
-    mk_bind (name, mono_id) = case maybeSig tc_ty_sigs name of
-                               Nothing  -> (name, mono_id)
-                               Just sig -> (idName poly_id, poly_id)
-                                        where
-                                           poly_id = tcSigPolyId sig
-
-    tc_mb_pats EmptyMonoBinds
-      = returnM (\ xve -> returnM EmptyMonoBinds, emptyBag, emptyBag, [])
+    tc_mb_pats EmptyMonoBinds 
+      = returnM (returnM (EmptyMonoBinds, emptyBag), emptyBag)
 
     tc_mb_pats (AndMonoBinds mb1 mb2)
-      = tc_mb_pats mb1         `thenM` \ (complete_it1, tvs1, ids1, lie_avail1) ->
-        tc_mb_pats mb2         `thenM` \ (complete_it2, tvs2, ids2, lie_avail2) ->
+      = tc_mb_pats mb1         `thenM` \ (complete_it1, xve1) ->
+        tc_mb_pats mb2         `thenM` \ (complete_it2, xve2) ->
        let
-          complete_it xve = complete_it1 xve   `thenM` \ mb1' ->
-                            complete_it2 xve   `thenM` \ mb2' ->
-                            returnM (AndMonoBinds mb1' mb2')
+          complete_it = complete_it1   `thenM` \ (mb1', bs1) ->
+                        complete_it2   `thenM` \ (mb2', bs2) ->
+                        returnM (AndMonoBinds mb1' mb2', bs1 `unionBags` bs2)
        in
-       returnM (complete_it,
-                 tvs1 `unionBags` tvs2,
-                 ids1 `unionBags` ids2,
-                 lie_avail1 ++ lie_avail2)
+       returnM (complete_it, xve1 `unionBags` xve2)
 
     tc_mb_pats (FunMonoBind name inf matches locn)
-      = (case maybeSig tc_ty_sigs name of
-           Just sig -> returnM (tcSigMonoId sig)
-           Nothing  -> newLocalName name       `thenM` \ bndr_name ->
-                       newTyVarTy openTypeKind `thenM` \ bndr_ty -> 
-                       -- NB: not a 'hole' tyvar; since there is no type 
-                       -- signature, we revert to ordinary H-M typechecking
-                       -- which means the variable gets an inferred tau-type
-                       returnM (mkLocalId bndr_name bndr_ty)
-       )                                       `thenM` \ bndr_id ->
+               -- Three cases:
+               --      a) Type sig supplied
+               --      b) No type sig and recursive
+               --      c) No type sig and non-recursive
+
+      | Just sig <- maybeSig tc_ty_sigs name 
+      = let    -- (a) There is a type signature
+               -- Use it for the environment extension, and check
+               -- the RHS has the appropriate type (with outer for-alls stripped off)
+          mono_id = tcSigMonoId sig
+          mono_ty = idType mono_id
+          complete_it = addSrcLoc locn                         $
+                        tcMatchesFun name mono_ty matches      `thenM` \ matches' ->
+                        returnM (FunMonoBind mono_id inf matches' locn, 
+                                 unitBag (name, mono_id))
+       in
+       returnM (complete_it, if isRec is_rec then unitBag (name,tcSigPolyId sig) 
+                                             else emptyBag)
+
+      | isRec is_rec
+      =                -- (b) No type signature, and recursive
+               -- So we must use an ordinary H-M type variable
+               -- which means the variable gets an inferred tau-type
+       newLocalName name               `thenM` \ mono_name ->
+       newTyVarTy openTypeKind         `thenM` \ mono_ty ->
        let
-          bndr_ty         = idType bndr_id
-          complete_it xve = addSrcLoc locn                             $
-                            tcMatchesFun xve name bndr_ty matches      `thenM` \ matches' ->
-                            returnM (FunMonoBind bndr_id inf matches' locn)
+          mono_id     = mkLocalId mono_name mono_ty
+          complete_it = addSrcLoc locn                         $
+                        tcMatchesFun name mono_ty matches      `thenM` \ matches' ->
+                        returnM (FunMonoBind mono_id inf matches' locn, 
+                                 unitBag (name, mono_id))
        in
-       returnM (complete_it, emptyBag, unitBag (name, bndr_id), [])
-
+       returnM (complete_it, unitBag (name, mono_id))
+
+      | otherwise      -- (c) No type signature, and non-recursive
+      =        let             -- So we can use a 'hole' type to infer a higher-rank type
+          complete_it 
+               = addSrcLoc locn                        $
+                 newHoleTyVarTy                        `thenM` \ fun_ty -> 
+                 tcMatchesFun name fun_ty matches      `thenM` \ matches' ->
+                 readHoleResult fun_ty                 `thenM` \ fun_ty' ->
+                 newLocalName name                     `thenM` \ mono_name ->
+                 let
+                    mono_id = mkLocalId mono_name fun_ty'
+                 in
+                 returnM (FunMonoBind mono_id inf matches' locn, 
+                          unitBag (name, mono_id))
+       in
+       returnM (complete_it, emptyBag)
+       
     tc_mb_pats bind@(PatMonoBind pat grhss locn)
       = addSrcLoc locn         $
-       newHoleTyVarTy                  `thenM` \ pat_ty -> 
 
                --      Now typecheck the pattern
                -- We do now support binding fresh (not-already-in-scope) scoped 
@@ -714,16 +710,21 @@ tcMonoBinds mbinds tc_ty_sigs is_rec
                -- The type variables are brought into scope in tc_binds_and_then,
                -- so we don't have to do anything here.
 
-       tcPat tc_pat_bndr pat pat_ty            `thenM` \ (pat', tvs, ids, lie_avail) ->
-       readHoleResult pat_ty                   `thenM` \ pat_ty' ->
+       newHoleTyVarTy                  `thenM` \ pat_ty -> 
+       tcPat tc_pat_bndr pat pat_ty    `thenM` \ (pat', tvs, ids, lie_avail) ->
+       readHoleResult pat_ty           `thenM` \ pat_ty' ->
+
+       -- Don't know how to deal with pattern-bound existentials yet
+        checkTc (isEmptyBag tvs && null lie_avail) 
+               (existentialExplode bind)       `thenM_` 
+
        let
-          complete_it xve = addSrcLoc locn                             $
-                            addErrCtxt (patMonoBindsCtxt bind) $
-                            tcExtendLocalValEnv2 xve                   $
-                            tcGRHSs PatBindRhs grhss pat_ty'           `thenM` \ grhss' ->
-                            returnM (PatMonoBind pat' grhss' locn)
+          complete_it = addSrcLoc locn                         $
+                        addErrCtxt (patMonoBindsCtxt bind)     $
+                        tcGRHSs PatBindRhs grhss pat_ty'       `thenM` \ grhss' ->
+                        returnM (PatMonoBind pat' grhss' locn, ids)
        in
-       returnM (complete_it, tvs, ids, lie_avail)
+       returnM (complete_it, if isRec is_rec then ids else emptyBag)
 
        -- tc_pat_bndr is used when dealing with a LHS binder in a pattern.
        -- If there was a type sig for that Id, we want to make it much
@@ -735,9 +736,8 @@ tcMonoBinds mbinds tc_ty_sigs is_rec
        
     tc_pat_bndr name pat_ty
        = case maybeSig tc_ty_sigs name of
-           Nothing
-               -> newLocalName name    `thenM` \ bndr_name ->
-                  tcMonoPatBndr bndr_name pat_ty
+           Nothing  -> newLocalName name                       `thenM` \ bndr_name ->
+                       tcMonoPatBndr bndr_name pat_ty
 
            Just sig -> addSrcLoc (getSrcLoc name)              $
                        tcSubPat (idType mono_id) pat_ty        `thenM` \ co_fn ->
index 2ebe668..bf829aa 100644 (file)
@@ -457,7 +457,7 @@ tcMethodBind xtve inst_tyvars inst_theta avail_insts prags
      tcExtendTyVarEnv2 xtve (
        addErrCtxt (methodCtxt sel_id)          $
        getLIE (tcMonoBinds meth_bind [meth_sig] NonRecursive)
-     )                                         `thenM` \ ((meth_bind, _, _), meth_lie) ->
+     )                                         `thenM` \ ((meth_bind, _), meth_lie) ->
 
        -- Now do context reduction.   We simplify wrt both the local tyvars
        -- and the ones of the class/instance decl, so that there is
index 296c504..6cfd445 100644 (file)
@@ -19,11 +19,10 @@ import qualified DsMeta
 
 import HsSyn           ( HsExpr(..), HsLit(..), ArithSeqInfo(..), recBindFields )
 import RnHsSyn         ( RenamedHsExpr, RenamedRecordBinds )
-import TcHsSyn         ( TcExpr, TcRecordBinds, hsLitType, mkHsDictApp, mkHsTyApp, mkHsLet )
+import TcHsSyn         ( TcExpr, TcRecordBinds, hsLitType, mkHsDictApp, mkHsTyApp, mkHsLet, (<$>) )
 import TcRnMonad
-import TcUnify         ( tcSubExp, tcGen, (<$>),
-                         unifyTauTy, unifyFunTy, unifyListTy, unifyPArrTy,
-                         unifyTupleTy )
+import TcUnify         ( tcSubExp, tcGen,
+                         unifyTauTy, unifyFunTy, unifyListTy, unifyPArrTy, unifyTupleTy )
 import BasicTypes      ( isMarkedStrict )
 import Inst            ( InstOrigin(..), 
                          newOverloadedLit, newMethodFromName, newIPDict,
@@ -34,7 +33,7 @@ import TcBinds                ( tcBindsAndThen )
 import TcEnv           ( tcLookupClass, tcLookupGlobal_maybe, tcLookupIdLvl,
                          tcLookupTyCon, tcLookupDataCon, tcLookupId
                        )
-import TcMatches       ( tcMatchesCase, tcMatchLambda, tcDoStmts )
+import TcMatches       ( tcMatchesCase, tcMatchLambda, tcDoStmts, tcThingWithSig )
 import TcMonoType      ( tcHsSigType, UserTypeCtxt(..) )
 import TcPat           ( badFieldCon )
 import TcMType         ( tcInstTyVars, tcInstType, newHoleTyVarTy, zapToType,
@@ -136,17 +135,10 @@ tcMonoExpr (HsIPVar ip) res_ty
 
 \begin{code}
 tcMonoExpr in_expr@(ExprWithTySig expr poly_ty) res_ty
- = addErrCtxt (exprSigCtxt in_expr)    $
-   tcHsSigType ExprSigCtxt poly_ty     `thenM` \ sig_tc_ty ->
-   tcExpr expr sig_tc_ty               `thenM` \ expr' ->
-
-       -- Must instantiate the outer for-alls of sig_tc_ty
-       -- else we risk instantiating a ? res_ty to a forall-type
-       -- which breaks the invariant that tcMonoExpr only returns phi-types
-   tcInstCall SignatureOrigin sig_tc_ty        `thenM` \ (inst_fn, inst_sig_ty) ->
-   tcSubExp res_ty inst_sig_ty         `thenM` \ co_fn ->
-
-   returnM (co_fn <$> inst_fn expr')
+ = addErrCtxt (exprSigCtxt in_expr)                    $
+   tcHsSigType ExprSigCtxt poly_ty                     `thenM` \ sig_tc_ty ->
+   tcThingWithSig sig_tc_ty (tcMonoExpr expr) res_ty   `thenM` \ (co_fn, expr') ->
+   returnM (co_fn <$> expr')
 
 tcMonoExpr (HsType ty) res_ty
   = failWithTc (text "Can't handle type argument:" <+> ppr ty)
@@ -832,7 +824,7 @@ tcId name   -- Look up the Id and instantiate its type
     loop fun fun_ty 
        | isSigmaTy fun_ty
        = tcInstCall orig fun_ty        `thenM` \ (inst_fn, tau) ->
-         loop (inst_fn fun) tau
+         loop (inst_fn <$> fun) tau
 
        | otherwise
        = returnM (fun, fun_ty)
index a09eb59..24dc515 100644 (file)
@@ -27,6 +27,11 @@ module TcHsSyn (
        mkHsTyLam, mkHsDictLam, mkHsLet,
        hsLitType, hsPatType, 
 
+       -- Coercions
+       Coercion, ExprCoFn, PatCoFn, 
+       (<$>), (<.>), mkCoercion, 
+       idCoercion, isIdCoercion,
+
        -- re-exported from TcMonad
        TcId, TcIdSet,
 
@@ -65,6 +70,7 @@ import VarSet
 import VarEnv
 import BasicTypes ( RecFlag(..), Boxity(..), IPName(..), ipNameName, mapIPName )
 import Maybes    ( orElse )
+import Maybe     ( isNothing )
 import Unique    ( Uniquable(..) )
 import SrcLoc    ( noSrcLoc )
 import Bag
@@ -182,12 +188,37 @@ hsLitType (HsDoublePrim d) = doublePrimTy
 hsLitType (HsLitLit _ ty)  = ty
 \end{code}
 
+%************************************************************************
+%*                                                                     *
+\subsection{Coercion functions}
+%*                                                                     *
+%************************************************************************
+
 \begin{code}
--- zonkId is used *during* typechecking just to zonk the Id's type
-zonkId :: TcId -> TcM TcId
-zonkId id
-  = zonkTcType (idType id) `thenM` \ ty' ->
-    returnM (setIdType id ty')
+type Coercion a = Maybe (a -> a)
+       -- Nothing => identity fn
+
+type ExprCoFn = Coercion TypecheckedHsExpr
+type PatCoFn  = Coercion TcPat
+
+(<.>) :: Coercion a -> Coercion a -> Coercion a        -- Composition
+Nothing <.> Nothing = Nothing
+Nothing <.> Just f  = Just f
+Just f  <.> Nothing = Just f
+Just f1 <.> Just f2 = Just (f1 . f2)
+
+(<$>) :: Coercion a -> a -> a
+Just f  <$> e = f e
+Nothing <$> e = e
+
+mkCoercion :: (a -> a) -> Coercion a
+mkCoercion f = Just f
+
+idCoercion :: Coercion a
+idCoercion = Nothing
+
+isIdCoercion :: Coercion a -> Bool
+isIdCoercion = isNothing
 \end{code}
 
 
@@ -197,7 +228,16 @@ zonkId id
 %*                                                                     *
 %************************************************************************
 
-This zonking pass runs over the bindings
+\begin{code}
+-- zonkId is used *during* typechecking just to zonk the Id's type
+zonkId :: TcId -> TcM TcId
+zonkId id
+  = zonkTcType (idType id) `thenM` \ ty' ->
+    returnM (setIdType id ty')
+\end{code}
+
+The rest of the zonking is done *after* typechecking.
+The main zonking pass runs over the bindings
 
  a) to convert TcTyVars to TyVars etc, dereferencing any bindings etc
  b) convert unbound TcTyVar to Void
index 735e159..cdb14ff 100644 (file)
@@ -8,8 +8,7 @@ _declarations_
              -> TcType.TcType
              -> TcMonad.TcM s (TcHsSyn.TcGRHSs, TcMonad.LIE) ;;
 3 tcMatchesFun _:_ _forall_ [s] => 
-               [(Name.Name,Var.Id)]
-            -> Name.Name
+               Name.Name
             -> TcType.TcType
             -> [RnHsSyn.RenamedMatch]
             -> TcMonad.TcM s ([TcHsSyn.TcMatch], TcMonad.LIE) ;;
index 881a6cf..3ab2fb8 100644 (file)
@@ -5,8 +5,7 @@ __export TcMatches tcGRHSs tcMatchesFun;
              -> TcType.TcType
              -> TcRnTypes.TcM TcHsSyn.TcGRHSs ;
 1 tcMatchesFun :: 
-               [(Name.Name,Var.Id)]
-            -> Name.Name
+               Name.Name
             -> TcType.TcType
             -> [RnHsSyn.RenamedMatch]
             -> TcRnTypes.TcM [TcHsSyn.TcMatch] ;
index c35bfee..624f36b 100644 (file)
@@ -5,9 +5,7 @@ tcGRHSs ::  HsExpr.HsMatchContext Name.Name
              -> TcType.TcType
              -> TcRnTypes.TcM TcHsSyn.TcGRHSs
 
-tcMatchesFun :: 
-               [(Name.Name,Var.Id)]
-            -> Name.Name
+tcMatchesFun :: Name.Name
             -> TcType.TcType
             -> [RnHsSyn.RenamedMatch]
             -> TcRnTypes.TcM [TcHsSyn.TcMatch]
index f1048d8..55c7a0c 100644 (file)
@@ -5,7 +5,7 @@
 
 \begin{code}
 module TcMatches ( tcMatchesFun, tcMatchesCase, tcMatchLambda, 
-                  tcDoStmts, tcStmtsAndThen, tcGRHSs 
+                  tcDoStmts, tcStmtsAndThen, tcGRHSs, tcThingWithSig
        ) where
 
 #include "HsVersions.h"
@@ -21,20 +21,22 @@ import HsSyn                ( HsExpr(..), HsBinds(..), Match(..), GRHSs(..), GRHS(..),
 import RnHsSyn         ( RenamedMatch, RenamedGRHSs, RenamedStmt, 
                          RenamedPat, RenamedMatchContext )
 import TcHsSyn         ( TcMatch, TcGRHSs, TcStmt, TcDictBinds, TcHsBinds, 
-                         TcMonoBinds, TcPat, TcStmt )
+                         TcMonoBinds, TcPat, TcStmt, ExprCoFn,
+                         isIdCoercion, (<$>), (<.>) )
 
 import TcRnMonad
 import TcMonoType      ( tcAddScopedTyVars, tcHsSigType, UserTypeCtxt(..) )
-import Inst            ( tcSyntaxName )
+import Inst            ( tcSyntaxName, tcInstCall )
 import TcEnv           ( TcId, tcLookupLocalIds, tcLookupId, tcExtendLocalValEnv, tcExtendLocalValEnv2 )
 import TcPat           ( tcPat, tcMonoPatBndr )
 import TcMType         ( newTyVarTy, newTyVarTys, zonkTcType, zapToType )
-import TcType          ( TcType, TcTyVar, tyVarsOfType, tidyOpenTypes, tidyOpenType,
+import TcType          ( TcType, TcTyVar, TcSigmaType, TcRhoType,
+                         tyVarsOfType, tidyOpenTypes, tidyOpenType, isSigmaTy,
                          mkFunTy, isOverloadedTy, liftedTypeKind, openTypeKind, 
                          mkArrowKind, mkAppTy )
 import TcBinds         ( tcBindsAndThen )
 import TcUnify         ( unifyPArrTy,subFunTy, unifyListTy, unifyTauTy,
-                         checkSigTyVarsWrt, tcSubExp, isIdCoercion, (<$>) )
+                         checkSigTyVarsWrt, tcSubExp, tcGen )
 import TcSimplify      ( tcSimplifyCheck, bindInstsOfLocalFuns )
 import Name            ( Name )
 import PrelNames       ( monadNames, mfixName )
@@ -63,13 +65,12 @@ is used in error messages.  It checks that all the equations have the
 same number of arguments before using @tcMatches@ to do the work.
 
 \begin{code}
-tcMatchesFun :: [(Name,Id)]    -- Bindings for the variables bound in this group
-            -> Name
+tcMatchesFun :: Name
             -> TcType          -- Expected type
             -> [RenamedMatch]
             -> TcM [TcMatch]
 
-tcMatchesFun xve fun_name expected_ty matches@(first_match:_)
+tcMatchesFun fun_name expected_ty matches@(first_match:_)
   =     -- Check that they all have the same no of arguments
         -- Set the location to that of the first equation, so that
         -- any inter-equation error messages get some vaguely
@@ -86,7 +87,7 @@ tcMatchesFun xve fun_name expected_ty matches@(first_match:_)
        -- may show up as something wrong with the (non-existent) type signature
 
        -- No need to zonk expected_ty, because subFunTy does that on the fly
-    tcMatches xve (FunRhs fun_name) matches expected_ty
+    tcMatches (FunRhs fun_name) matches expected_ty
 \end{code}
 
 @tcMatchesCase@ doesn't do the argument-count check because the
@@ -100,22 +101,21 @@ tcMatchesCase :: [RenamedMatch]           -- The case alternatives
 
 tcMatchesCase matches expr_ty
   = newTyVarTy openTypeKind                                    `thenM` \ scrut_ty ->
-    tcMatches [] CaseAlt matches (mkFunTy scrut_ty expr_ty)    `thenM` \ matches' ->
+    tcMatches CaseAlt matches (mkFunTy scrut_ty expr_ty)       `thenM` \ matches' ->
     returnM (scrut_ty, matches')
 
 tcMatchLambda :: RenamedMatch -> TcType -> TcM TcMatch
-tcMatchLambda match res_ty = tcMatch [] LambdaExpr match res_ty
+tcMatchLambda match res_ty = tcMatch LambdaExpr match res_ty
 \end{code}
 
 
 \begin{code}
-tcMatches :: [(Name,Id)]
-         -> RenamedMatchContext 
+tcMatches :: RenamedMatchContext 
          -> [RenamedMatch]
          -> TcType
          -> TcM [TcMatch]
 
-tcMatches xve ctxt matches expected_ty
+tcMatches ctxt matches expected_ty
   =    -- If there is more than one branch, and expected_ty is a 'hole',
        -- all branches must be types, not type schemes, otherwise the
        -- in which we check them would affect the result.
@@ -126,7 +126,7 @@ tcMatches xve ctxt matches expected_ty
 
     mappM (tc_match expected_ty') matches
   where
-    tc_match expected_ty match = tcMatch xve ctxt match expected_ty
+    tc_match expected_ty match = tcMatch ctxt match expected_ty
 \end{code}
 
 
@@ -137,8 +137,7 @@ tcMatches xve ctxt matches expected_ty
 %************************************************************************
 
 \begin{code}
-tcMatch :: [(Name,Id)]
-       -> RenamedMatchContext
+tcMatch :: RenamedMatchContext
        -> RenamedMatch
        -> TcType       -- Expected result-type of the Match.
                        -- Early unification with this guy gives better error messages
@@ -147,7 +146,7 @@ tcMatch :: [(Name,Id)]
                        -- where there are n patterns.
        -> TcM TcMatch
 
-tcMatch xve1 ctxt match@(Match pats maybe_rhs_sig grhss) expected_ty
+tcMatch ctxt match@(Match pats maybe_rhs_sig grhss) expected_ty
   = addSrcLoc (getMatchLoc match)              $       -- At one stage I removed this;
     addErrCtxt (matchCtxt ctxt match)          $       -- I'm not sure why, so I put it back
     tcMatchPats pats expected_ty tc_grhss      `thenM` \ (pats', grhss', ex_binds) ->
@@ -155,17 +154,14 @@ tcMatch xve1 ctxt match@(Match pats maybe_rhs_sig grhss) expected_ty
 
   where
     tc_grhss rhs_ty 
-       = tcExtendLocalValEnv2 xve1                     $
-
-               -- Deal with the result signature
+       =       -- Deal with the result signature
          case maybe_rhs_sig of
            Nothing ->  tcGRHSs ctxt grhss rhs_ty
 
            Just sig ->  tcAddScopedTyVars [sig]        $
                                -- Bring into scope the type variables in the signature
-                        tcHsSigType ResSigCtxt sig     `thenM` \ sig_ty ->
-                        tcGRHSs ctxt grhss sig_ty      `thenM` \ grhss' ->
-                        tcSubExp rhs_ty sig_ty         `thenM` \ co_fn  ->
+                        tcHsSigType ResSigCtxt sig                             `thenM` \ sig_ty ->
+                        tcThingWithSig sig_ty (tcGRHSs ctxt grhss) rhs_ty      `thenM` \ (co_fn, grhss') ->
                         returnM (lift_grhss co_fn rhs_ty grhss')
 
 -- lift_grhss pushes the coercion down to the right hand sides,
@@ -173,7 +169,7 @@ tcMatch xve1 ctxt match@(Match pats maybe_rhs_sig grhss) expected_ty
 lift_grhss co_fn rhs_ty grhss 
   | isIdCoercion co_fn = grhss
 lift_grhss co_fn rhs_ty (GRHSs grhss binds ty)
-  = GRHSs (map lift_grhs grhss) binds rhs_ty   -- Change the type, since we
+  = GRHSs (map lift_grhs grhss) binds rhs_ty   -- Change the type, since the coercion does
   where
     lift_grhs (GRHS stmts loc) = GRHS (map lift_stmt stmts) loc
              
@@ -206,6 +202,31 @@ tcGRHSs ctxt (GRHSs grhss binds _) expected_ty
 \end{code}
 
 
+\begin{code}
+tcThingWithSig :: TcSigmaType          -- Type signature
+              -> (TcRhoType -> TcM r)  -- How to type check the thing inside
+              -> TcRhoType             -- Overall expected result type
+              -> TcM (ExprCoFn, r)
+-- Used for expressions with a type signature, and for result type signatures
+
+tcThingWithSig sig_ty thing_inside res_ty
+  | not (isSigmaTy sig_ty)
+  = thing_inside sig_ty                `thenM` \ result ->
+    tcSubExp res_ty sig_ty     `thenM` \ co_fn ->
+    returnM (co_fn, result)
+
+  | otherwise  -- The signature has some outer foralls
+  =    -- Must instantiate the outer for-alls of sig_tc_ty
+       -- else we risk instantiating a ? res_ty to a forall-type
+       -- which breaks the invariant that tcMonoExpr only returns phi-types
+    tcGen sig_ty emptyVarSet thing_inside      `thenM` \ (gen_fn, result) ->
+    tcInstCall SignatureOrigin sig_ty          `thenM` \ (inst_fn, inst_sig_ty) ->
+    tcSubExp res_ty inst_sig_ty                        `thenM` \ co_fn ->
+    returnM (co_fn <.> inst_fn <.> gen_fn,  result)
+       -- Note that we generalise, then instantiate. Ah well.
+\end{code}
+
+
 %************************************************************************
 %*                                                                     *
 \subsection{tcMatchPats}
index 2f23094..bfd90a6 100644 (file)
@@ -12,7 +12,9 @@ module TcPat ( tcPat, tcMonoPatBndr, tcSubPat,
 
 import HsSyn           ( Pat(..), HsConDetails(..), HsLit(..), HsOverLit(..), HsExpr(..) )
 import RnHsSyn         ( RenamedPat )
-import TcHsSyn         ( TcPat, TcId, hsLitType )
+import TcHsSyn         ( TcPat, TcId, hsLitType,
+                         mkCoercion, idCoercion, isIdCoercion,
+                         (<$>), PatCoFn )
 
 import TcRnMonad
 import Inst            ( InstOrigin(..),
@@ -27,9 +29,7 @@ import TcMType                ( newTyVarTy, zapToType, arityErr )
 import TcType          ( TcType, TcTyVar, TcSigmaType, 
                          mkClassPred, liftedTypeKind )
 import TcUnify         ( tcSubOff, TcHoleType, 
-                         unifyTauTy, unifyListTy, unifyPArrTy, unifyTupleTy,  
-                         mkCoercion, idCoercion, isIdCoercion,
-                         (<$>), PatCoFn )
+                         unifyTauTy, unifyListTy, unifyPArrTy, unifyTupleTy )  
 import TcMonoType      ( tcHsSigType, UserTypeCtxt(..) )
 
 import TysWiredIn      ( stringTy )
index c04d310..1a05fd5 100644 (file)
@@ -12,12 +12,7 @@ module TcUnify (
        -- Various unifications
   unifyTauTy, unifyTauTyList, unifyTauTyLists, 
   unifyFunTy, unifyListTy, unifyPArrTy, unifyTupleTy,
-  unifyKind, unifyKinds, unifyOpenTypeKind, unifyFunKind,
-
-       -- Coercions
-  Coercion, ExprCoFn, PatCoFn, 
-  (<$>), (<.>), mkCoercion, 
-  idCoercion, isIdCoercion
+  unifyKind, unifyKinds, unifyOpenTypeKind, unifyFunKind
 
   ) where
 
@@ -25,7 +20,8 @@ module TcUnify (
 
 
 import HsSyn           ( HsExpr(..) )
-import TcHsSyn         ( TypecheckedHsExpr, TcPat, mkHsLet )
+import TcHsSyn         ( TypecheckedHsExpr, TcPat, mkHsLet,
+                         ExprCoFn, idCoercion, isIdCoercion, mkCoercion, (<.>), (<$>) )
 import TypeRep         ( Type(..), SourceType(..), TyNote(..), openKindCon )
 
 import TcRnMonad         -- TcType, amongst others
@@ -181,7 +177,7 @@ tc_sub exp_sty expected_ty act_sty actual_ty
   | isSigmaTy actual_ty
   = tcInstCall Rank2Origin actual_ty           `thenM` \ (inst_fn, body_ty) ->
     tc_sub exp_sty expected_ty body_ty body_ty `thenM` \ co_fn ->
-    returnM (co_fn <.> mkCoercion inst_fn)
+    returnM (co_fn <.> inst_fn)
 
 -----------------------------------
 -- Function case
@@ -353,39 +349,6 @@ tcGen expected_ty extra_tvs thing_inside   -- We expect expected_ty to be a forall
 
 %************************************************************************
 %*                                                                     *
-\subsection{Coercion functions}
-%*                                                                     *
-%************************************************************************
-
-\begin{code}
-type Coercion a = Maybe (a -> a)
-       -- Nothing => identity fn
-
-type ExprCoFn = Coercion TypecheckedHsExpr
-type PatCoFn  = Coercion TcPat
-
-(<.>) :: Coercion a -> Coercion a -> Coercion a        -- Composition
-Nothing <.> Nothing = Nothing
-Nothing <.> Just f  = Just f
-Just f  <.> Nothing = Just f
-Just f1 <.> Just f2 = Just (f1 . f2)
-
-(<$>) :: Coercion a -> a -> a
-Just f  <$> e = f e
-Nothing <$> e = e
-
-mkCoercion :: (a -> a) -> Coercion a
-mkCoercion f = Just f
-
-idCoercion :: Coercion a
-idCoercion = Nothing
-
-isIdCoercion :: Coercion a -> Bool
-isIdCoercion = isNothing
-\end{code}
-
-%************************************************************************
-%*                                                                     *
 \subsection[Unify-exported]{Exported unification functions}
 %*                                                                     *
 %************************************************************************