From c86e9006fbdc9cb229080dd6a64ce462e9e460af Mon Sep 17 00:00:00 2001 From: simonpj Date: Wed, 26 Feb 2003 17:04:16 +0000 Subject: [PATCH] [project @ 2003-02-26 17:04:11 by simonpj] ---------------------------------- Improve higher-rank type inference ---------------------------------- Yanling Wang pointed out that if we have f = \ (x :: forall a. a->a). x it would be reasonable to expect that type inference would get the "right" rank-2 type for f. She also found that the plausible definition f :: (forall a. a->a) = \x -> x acutally failed to type check. This commit fixes up TcBinds.tcMonoBinds so that it does a better job. The main idea is that there are three cases to consider in a function binding: a) 'f' has a separate type signature In this case, we know f's type everywhere b) The binding is recursive, and there is no type sig In this case we must give f a monotype in its RHS c) The binding is non-recursive, and there is no type sig Then we do not need to add 'f' to the envt, and can simply infer a type for the RHS, which may be higher ranked. --- ghc/compiler/typecheck/Inst.lhs | 7 +- ghc/compiler/typecheck/TcBinds.lhs | 166 ++++++++++++++-------------- ghc/compiler/typecheck/TcClassDcl.lhs | 2 +- ghc/compiler/typecheck/TcExpr.lhs | 26 ++--- ghc/compiler/typecheck/TcHsSyn.lhs | 52 ++++++++- ghc/compiler/typecheck/TcMatches.hi-boot | 3 +- ghc/compiler/typecheck/TcMatches.hi-boot-5 | 3 +- ghc/compiler/typecheck/TcMatches.hi-boot-6 | 4 +- ghc/compiler/typecheck/TcMatches.lhs | 71 +++++++----- ghc/compiler/typecheck/TcPat.lhs | 8 +- ghc/compiler/typecheck/TcUnify.lhs | 45 +------- 11 files changed, 200 insertions(+), 187 deletions(-) diff --git a/ghc/compiler/typecheck/Inst.lhs b/ghc/compiler/typecheck/Inst.lhs index 981731c..fc21ca1 100644 --- a/ghc/compiler/typecheck/Inst.lhs +++ b/ghc/compiler/typecheck/Inst.lhs @@ -40,7 +40,8 @@ import {-# SOURCE #-} TcExpr( tcExpr ) import HsSyn ( HsLit(..), HsOverLit(..), HsExpr(..) ) import TcHsSyn ( TcExpr, TcId, TcIdSet, TypecheckedHsExpr, - mkHsTyApp, mkHsDictApp, mkHsConApp, zonkId + mkHsTyApp, mkHsDictApp, mkHsConApp, zonkId, + mkCoercion, ExprCoFn ) import TcRnMonad import TcEnv ( tcGetInstEnv, tcLookupId, tcLookupTyCon, checkWellStaged, topIdLvl ) @@ -256,7 +257,7 @@ newIPDict orig ip_name ty \begin{code} -tcInstCall :: InstOrigin -> TcType -> TcM (TypecheckedHsExpr -> TypecheckedHsExpr, TcType) +tcInstCall :: InstOrigin -> TcType -> TcM (ExprCoFn, TcType) tcInstCall orig fun_ty -- fun_ty is usually a sigma-type = tcInstType VanillaTv fun_ty `thenM` \ (tyvars, theta, tau) -> newDicts orig theta `thenM` \ dicts -> @@ -264,7 +265,7 @@ tcInstCall orig fun_ty -- fun_ty is usually a sigma-type let inst_fn e = mkHsDictApp (mkHsTyApp e (mkTyVarTys tyvars)) (map instToId dicts) in - returnM (inst_fn, tau) + returnM (mkCoercion inst_fn, tau) tcInstDataCon :: InstOrigin -> DataCon -> TcM ([TcType], -- Types to instantiate at diff --git a/ghc/compiler/typecheck/TcBinds.lhs b/ghc/compiler/typecheck/TcBinds.lhs index 30cad8c..7171ed2 100644 --- a/ghc/compiler/typecheck/TcBinds.lhs +++ b/ghc/compiler/typecheck/TcBinds.lhs @@ -255,8 +255,9 @@ tcBindWithSigs top_lvl mbind sigs is_rec ) $ -- TYPECHECK THE BINDINGS - getLIE (tcMonoBinds mbind tc_ty_sigs is_rec) `thenM` \ ((mbind', binder_names, mono_ids), lie_req) -> + getLIE (tcMonoBinds mbind tc_ty_sigs is_rec) `thenM` \ ((mbind', bndr_names_w_ids), lie_req) -> let + (binder_names, mono_ids) = unzip (bagToList bndr_names_w_ids) tau_tvs = foldr (unionVarSet . tyVarsOfType . idType) emptyVarSet mono_ids in @@ -620,91 +621,86 @@ The signatures have been dealt with already. \begin{code} tcMonoBinds :: RenamedMonoBinds - -> [TcSigInfo] - -> RecFlag + -> [TcSigInfo] -> RecFlag -> TcM (TcMonoBinds, - [Name], -- Bound names - [TcId]) -- Corresponding monomorphic bound things + Bag (Name, -- Bound names + TcId)) -- Corresponding monomorphic bound things tcMonoBinds mbinds tc_ty_sigs is_rec - = tc_mb_pats mbinds `thenM` \ (complete_it, tvs, ids, lie_avail) -> - let - id_list = bagToList ids - (names, mono_ids) = unzip id_list - - -- This last defn is the key one: - -- extend the val envt with bindings for the - -- things bound in this group, overriding the monomorphic - -- ids with the polymorphic ones from the pattern - extra_val_env = case is_rec of - Recursive -> map mk_bind id_list - NonRecursive -> [] - in - -- Don't know how to deal with pattern-bound existentials yet - checkTc (isEmptyBag tvs && null lie_avail) - (existentialExplode mbinds) `thenM_` - - -- *Before* checking the RHSs, but *after* checking *all* the patterns, - -- extend the envt with bindings for all the bound ids; - -- and *then* override with the polymorphic Ids from the signatures - -- That is the whole point of the "complete_it" stuff. - -- - -- There's a further wrinkle: we have to delay extending the environment - -- until after we've dealt with any pattern-bound signature type variables - -- Consider f (x::a) = ...f... - -- We're going to check that a isn't unified with anything in the envt, - -- so f itself had better not be! So we pass the envt binding f into - -- complete_it, which extends the actual envt in TcMatches.tcMatch, after - -- dealing with the signature tyvars - - complete_it extra_val_env `thenM` \ mbinds' -> - - returnM (mbinds', names, mono_ids) + -- Three stages: + -- 1. Check the patterns, building up an environment binding + -- the variables in this group (in the recursive case) + -- 2. Extend the environment + -- 3. Check the RHSs + = tc_mb_pats mbinds `thenM` \ (complete_it, xve) -> + tcExtendLocalValEnv2 (bagToList xve) complete_it where - - mk_bind (name, mono_id) = case maybeSig tc_ty_sigs name of - Nothing -> (name, mono_id) - Just sig -> (idName poly_id, poly_id) - where - poly_id = tcSigPolyId sig - - tc_mb_pats EmptyMonoBinds - = returnM (\ xve -> returnM EmptyMonoBinds, emptyBag, emptyBag, []) + tc_mb_pats EmptyMonoBinds + = returnM (returnM (EmptyMonoBinds, emptyBag), emptyBag) tc_mb_pats (AndMonoBinds mb1 mb2) - = tc_mb_pats mb1 `thenM` \ (complete_it1, tvs1, ids1, lie_avail1) -> - tc_mb_pats mb2 `thenM` \ (complete_it2, tvs2, ids2, lie_avail2) -> + = tc_mb_pats mb1 `thenM` \ (complete_it1, xve1) -> + tc_mb_pats mb2 `thenM` \ (complete_it2, xve2) -> let - complete_it xve = complete_it1 xve `thenM` \ mb1' -> - complete_it2 xve `thenM` \ mb2' -> - returnM (AndMonoBinds mb1' mb2') + complete_it = complete_it1 `thenM` \ (mb1', bs1) -> + complete_it2 `thenM` \ (mb2', bs2) -> + returnM (AndMonoBinds mb1' mb2', bs1 `unionBags` bs2) in - returnM (complete_it, - tvs1 `unionBags` tvs2, - ids1 `unionBags` ids2, - lie_avail1 ++ lie_avail2) + returnM (complete_it, xve1 `unionBags` xve2) tc_mb_pats (FunMonoBind name inf matches locn) - = (case maybeSig tc_ty_sigs name of - Just sig -> returnM (tcSigMonoId sig) - Nothing -> newLocalName name `thenM` \ bndr_name -> - newTyVarTy openTypeKind `thenM` \ bndr_ty -> - -- NB: not a 'hole' tyvar; since there is no type - -- signature, we revert to ordinary H-M typechecking - -- which means the variable gets an inferred tau-type - returnM (mkLocalId bndr_name bndr_ty) - ) `thenM` \ bndr_id -> + -- Three cases: + -- a) Type sig supplied + -- b) No type sig and recursive + -- c) No type sig and non-recursive + + | Just sig <- maybeSig tc_ty_sigs name + = let -- (a) There is a type signature + -- Use it for the environment extension, and check + -- the RHS has the appropriate type (with outer for-alls stripped off) + mono_id = tcSigMonoId sig + mono_ty = idType mono_id + complete_it = addSrcLoc locn $ + tcMatchesFun name mono_ty matches `thenM` \ matches' -> + returnM (FunMonoBind mono_id inf matches' locn, + unitBag (name, mono_id)) + in + returnM (complete_it, if isRec is_rec then unitBag (name,tcSigPolyId sig) + else emptyBag) + + | isRec is_rec + = -- (b) No type signature, and recursive + -- So we must use an ordinary H-M type variable + -- which means the variable gets an inferred tau-type + newLocalName name `thenM` \ mono_name -> + newTyVarTy openTypeKind `thenM` \ mono_ty -> let - bndr_ty = idType bndr_id - complete_it xve = addSrcLoc locn $ - tcMatchesFun xve name bndr_ty matches `thenM` \ matches' -> - returnM (FunMonoBind bndr_id inf matches' locn) + mono_id = mkLocalId mono_name mono_ty + complete_it = addSrcLoc locn $ + tcMatchesFun name mono_ty matches `thenM` \ matches' -> + returnM (FunMonoBind mono_id inf matches' locn, + unitBag (name, mono_id)) in - returnM (complete_it, emptyBag, unitBag (name, bndr_id), []) - + returnM (complete_it, unitBag (name, mono_id)) + + | otherwise -- (c) No type signature, and non-recursive + = let -- So we can use a 'hole' type to infer a higher-rank type + complete_it + = addSrcLoc locn $ + newHoleTyVarTy `thenM` \ fun_ty -> + tcMatchesFun name fun_ty matches `thenM` \ matches' -> + readHoleResult fun_ty `thenM` \ fun_ty' -> + newLocalName name `thenM` \ mono_name -> + let + mono_id = mkLocalId mono_name fun_ty' + in + returnM (FunMonoBind mono_id inf matches' locn, + unitBag (name, mono_id)) + in + returnM (complete_it, emptyBag) + tc_mb_pats bind@(PatMonoBind pat grhss locn) = addSrcLoc locn $ - newHoleTyVarTy `thenM` \ pat_ty -> -- Now typecheck the pattern -- We do now support binding fresh (not-already-in-scope) scoped @@ -714,16 +710,21 @@ tcMonoBinds mbinds tc_ty_sigs is_rec -- The type variables are brought into scope in tc_binds_and_then, -- so we don't have to do anything here. - tcPat tc_pat_bndr pat pat_ty `thenM` \ (pat', tvs, ids, lie_avail) -> - readHoleResult pat_ty `thenM` \ pat_ty' -> + newHoleTyVarTy `thenM` \ pat_ty -> + tcPat tc_pat_bndr pat pat_ty `thenM` \ (pat', tvs, ids, lie_avail) -> + readHoleResult pat_ty `thenM` \ pat_ty' -> + + -- Don't know how to deal with pattern-bound existentials yet + checkTc (isEmptyBag tvs && null lie_avail) + (existentialExplode bind) `thenM_` + let - complete_it xve = addSrcLoc locn $ - addErrCtxt (patMonoBindsCtxt bind) $ - tcExtendLocalValEnv2 xve $ - tcGRHSs PatBindRhs grhss pat_ty' `thenM` \ grhss' -> - returnM (PatMonoBind pat' grhss' locn) + complete_it = addSrcLoc locn $ + addErrCtxt (patMonoBindsCtxt bind) $ + tcGRHSs PatBindRhs grhss pat_ty' `thenM` \ grhss' -> + returnM (PatMonoBind pat' grhss' locn, ids) in - returnM (complete_it, tvs, ids, lie_avail) + returnM (complete_it, if isRec is_rec then ids else emptyBag) -- tc_pat_bndr is used when dealing with a LHS binder in a pattern. -- If there was a type sig for that Id, we want to make it much @@ -735,9 +736,8 @@ tcMonoBinds mbinds tc_ty_sigs is_rec tc_pat_bndr name pat_ty = case maybeSig tc_ty_sigs name of - Nothing - -> newLocalName name `thenM` \ bndr_name -> - tcMonoPatBndr bndr_name pat_ty + Nothing -> newLocalName name `thenM` \ bndr_name -> + tcMonoPatBndr bndr_name pat_ty Just sig -> addSrcLoc (getSrcLoc name) $ tcSubPat (idType mono_id) pat_ty `thenM` \ co_fn -> diff --git a/ghc/compiler/typecheck/TcClassDcl.lhs b/ghc/compiler/typecheck/TcClassDcl.lhs index 2ebe668..bf829aa 100644 --- a/ghc/compiler/typecheck/TcClassDcl.lhs +++ b/ghc/compiler/typecheck/TcClassDcl.lhs @@ -457,7 +457,7 @@ tcMethodBind xtve inst_tyvars inst_theta avail_insts prags tcExtendTyVarEnv2 xtve ( addErrCtxt (methodCtxt sel_id) $ getLIE (tcMonoBinds meth_bind [meth_sig] NonRecursive) - ) `thenM` \ ((meth_bind, _, _), meth_lie) -> + ) `thenM` \ ((meth_bind, _), meth_lie) -> -- Now do context reduction. We simplify wrt both the local tyvars -- and the ones of the class/instance decl, so that there is diff --git a/ghc/compiler/typecheck/TcExpr.lhs b/ghc/compiler/typecheck/TcExpr.lhs index 296c504..6cfd445 100644 --- a/ghc/compiler/typecheck/TcExpr.lhs +++ b/ghc/compiler/typecheck/TcExpr.lhs @@ -19,11 +19,10 @@ import qualified DsMeta import HsSyn ( HsExpr(..), HsLit(..), ArithSeqInfo(..), recBindFields ) import RnHsSyn ( RenamedHsExpr, RenamedRecordBinds ) -import TcHsSyn ( TcExpr, TcRecordBinds, hsLitType, mkHsDictApp, mkHsTyApp, mkHsLet ) +import TcHsSyn ( TcExpr, TcRecordBinds, hsLitType, mkHsDictApp, mkHsTyApp, mkHsLet, (<$>) ) import TcRnMonad -import TcUnify ( tcSubExp, tcGen, (<$>), - unifyTauTy, unifyFunTy, unifyListTy, unifyPArrTy, - unifyTupleTy ) +import TcUnify ( tcSubExp, tcGen, + unifyTauTy, unifyFunTy, unifyListTy, unifyPArrTy, unifyTupleTy ) import BasicTypes ( isMarkedStrict ) import Inst ( InstOrigin(..), newOverloadedLit, newMethodFromName, newIPDict, @@ -34,7 +33,7 @@ import TcBinds ( tcBindsAndThen ) import TcEnv ( tcLookupClass, tcLookupGlobal_maybe, tcLookupIdLvl, tcLookupTyCon, tcLookupDataCon, tcLookupId ) -import TcMatches ( tcMatchesCase, tcMatchLambda, tcDoStmts ) +import TcMatches ( tcMatchesCase, tcMatchLambda, tcDoStmts, tcThingWithSig ) import TcMonoType ( tcHsSigType, UserTypeCtxt(..) ) import TcPat ( badFieldCon ) import TcMType ( tcInstTyVars, tcInstType, newHoleTyVarTy, zapToType, @@ -136,17 +135,10 @@ tcMonoExpr (HsIPVar ip) res_ty \begin{code} tcMonoExpr in_expr@(ExprWithTySig expr poly_ty) res_ty - = addErrCtxt (exprSigCtxt in_expr) $ - tcHsSigType ExprSigCtxt poly_ty `thenM` \ sig_tc_ty -> - tcExpr expr sig_tc_ty `thenM` \ expr' -> - - -- Must instantiate the outer for-alls of sig_tc_ty - -- else we risk instantiating a ? res_ty to a forall-type - -- which breaks the invariant that tcMonoExpr only returns phi-types - tcInstCall SignatureOrigin sig_tc_ty `thenM` \ (inst_fn, inst_sig_ty) -> - tcSubExp res_ty inst_sig_ty `thenM` \ co_fn -> - - returnM (co_fn <$> inst_fn expr') + = addErrCtxt (exprSigCtxt in_expr) $ + tcHsSigType ExprSigCtxt poly_ty `thenM` \ sig_tc_ty -> + tcThingWithSig sig_tc_ty (tcMonoExpr expr) res_ty `thenM` \ (co_fn, expr') -> + returnM (co_fn <$> expr') tcMonoExpr (HsType ty) res_ty = failWithTc (text "Can't handle type argument:" <+> ppr ty) @@ -832,7 +824,7 @@ tcId name -- Look up the Id and instantiate its type loop fun fun_ty | isSigmaTy fun_ty = tcInstCall orig fun_ty `thenM` \ (inst_fn, tau) -> - loop (inst_fn fun) tau + loop (inst_fn <$> fun) tau | otherwise = returnM (fun, fun_ty) diff --git a/ghc/compiler/typecheck/TcHsSyn.lhs b/ghc/compiler/typecheck/TcHsSyn.lhs index a09eb59..24dc515 100644 --- a/ghc/compiler/typecheck/TcHsSyn.lhs +++ b/ghc/compiler/typecheck/TcHsSyn.lhs @@ -27,6 +27,11 @@ module TcHsSyn ( mkHsTyLam, mkHsDictLam, mkHsLet, hsLitType, hsPatType, + -- Coercions + Coercion, ExprCoFn, PatCoFn, + (<$>), (<.>), mkCoercion, + idCoercion, isIdCoercion, + -- re-exported from TcMonad TcId, TcIdSet, @@ -65,6 +70,7 @@ import VarSet import VarEnv import BasicTypes ( RecFlag(..), Boxity(..), IPName(..), ipNameName, mapIPName ) import Maybes ( orElse ) +import Maybe ( isNothing ) import Unique ( Uniquable(..) ) import SrcLoc ( noSrcLoc ) import Bag @@ -182,12 +188,37 @@ hsLitType (HsDoublePrim d) = doublePrimTy hsLitType (HsLitLit _ ty) = ty \end{code} +%************************************************************************ +%* * +\subsection{Coercion functions} +%* * +%************************************************************************ + \begin{code} --- zonkId is used *during* typechecking just to zonk the Id's type -zonkId :: TcId -> TcM TcId -zonkId id - = zonkTcType (idType id) `thenM` \ ty' -> - returnM (setIdType id ty') +type Coercion a = Maybe (a -> a) + -- Nothing => identity fn + +type ExprCoFn = Coercion TypecheckedHsExpr +type PatCoFn = Coercion TcPat + +(<.>) :: Coercion a -> Coercion a -> Coercion a -- Composition +Nothing <.> Nothing = Nothing +Nothing <.> Just f = Just f +Just f <.> Nothing = Just f +Just f1 <.> Just f2 = Just (f1 . f2) + +(<$>) :: Coercion a -> a -> a +Just f <$> e = f e +Nothing <$> e = e + +mkCoercion :: (a -> a) -> Coercion a +mkCoercion f = Just f + +idCoercion :: Coercion a +idCoercion = Nothing + +isIdCoercion :: Coercion a -> Bool +isIdCoercion = isNothing \end{code} @@ -197,7 +228,16 @@ zonkId id %* * %************************************************************************ -This zonking pass runs over the bindings +\begin{code} +-- zonkId is used *during* typechecking just to zonk the Id's type +zonkId :: TcId -> TcM TcId +zonkId id + = zonkTcType (idType id) `thenM` \ ty' -> + returnM (setIdType id ty') +\end{code} + +The rest of the zonking is done *after* typechecking. +The main zonking pass runs over the bindings a) to convert TcTyVars to TyVars etc, dereferencing any bindings etc b) convert unbound TcTyVar to Void diff --git a/ghc/compiler/typecheck/TcMatches.hi-boot b/ghc/compiler/typecheck/TcMatches.hi-boot index 735e159..cdb14ff 100644 --- a/ghc/compiler/typecheck/TcMatches.hi-boot +++ b/ghc/compiler/typecheck/TcMatches.hi-boot @@ -8,8 +8,7 @@ _declarations_ -> TcType.TcType -> TcMonad.TcM s (TcHsSyn.TcGRHSs, TcMonad.LIE) ;; 3 tcMatchesFun _:_ _forall_ [s] => - [(Name.Name,Var.Id)] - -> Name.Name + Name.Name -> TcType.TcType -> [RnHsSyn.RenamedMatch] -> TcMonad.TcM s ([TcHsSyn.TcMatch], TcMonad.LIE) ;; diff --git a/ghc/compiler/typecheck/TcMatches.hi-boot-5 b/ghc/compiler/typecheck/TcMatches.hi-boot-5 index 881a6cf..3ab2fb8 100644 --- a/ghc/compiler/typecheck/TcMatches.hi-boot-5 +++ b/ghc/compiler/typecheck/TcMatches.hi-boot-5 @@ -5,8 +5,7 @@ __export TcMatches tcGRHSs tcMatchesFun; -> TcType.TcType -> TcRnTypes.TcM TcHsSyn.TcGRHSs ; 1 tcMatchesFun :: - [(Name.Name,Var.Id)] - -> Name.Name + Name.Name -> TcType.TcType -> [RnHsSyn.RenamedMatch] -> TcRnTypes.TcM [TcHsSyn.TcMatch] ; diff --git a/ghc/compiler/typecheck/TcMatches.hi-boot-6 b/ghc/compiler/typecheck/TcMatches.hi-boot-6 index c35bfee..624f36b 100644 --- a/ghc/compiler/typecheck/TcMatches.hi-boot-6 +++ b/ghc/compiler/typecheck/TcMatches.hi-boot-6 @@ -5,9 +5,7 @@ tcGRHSs :: HsExpr.HsMatchContext Name.Name -> TcType.TcType -> TcRnTypes.TcM TcHsSyn.TcGRHSs -tcMatchesFun :: - [(Name.Name,Var.Id)] - -> Name.Name +tcMatchesFun :: Name.Name -> TcType.TcType -> [RnHsSyn.RenamedMatch] -> TcRnTypes.TcM [TcHsSyn.TcMatch] diff --git a/ghc/compiler/typecheck/TcMatches.lhs b/ghc/compiler/typecheck/TcMatches.lhs index f1048d8..55c7a0c 100644 --- a/ghc/compiler/typecheck/TcMatches.lhs +++ b/ghc/compiler/typecheck/TcMatches.lhs @@ -5,7 +5,7 @@ \begin{code} module TcMatches ( tcMatchesFun, tcMatchesCase, tcMatchLambda, - tcDoStmts, tcStmtsAndThen, tcGRHSs + tcDoStmts, tcStmtsAndThen, tcGRHSs, tcThingWithSig ) where #include "HsVersions.h" @@ -21,20 +21,22 @@ import HsSyn ( HsExpr(..), HsBinds(..), Match(..), GRHSs(..), GRHS(..), import RnHsSyn ( RenamedMatch, RenamedGRHSs, RenamedStmt, RenamedPat, RenamedMatchContext ) import TcHsSyn ( TcMatch, TcGRHSs, TcStmt, TcDictBinds, TcHsBinds, - TcMonoBinds, TcPat, TcStmt ) + TcMonoBinds, TcPat, TcStmt, ExprCoFn, + isIdCoercion, (<$>), (<.>) ) import TcRnMonad import TcMonoType ( tcAddScopedTyVars, tcHsSigType, UserTypeCtxt(..) ) -import Inst ( tcSyntaxName ) +import Inst ( tcSyntaxName, tcInstCall ) import TcEnv ( TcId, tcLookupLocalIds, tcLookupId, tcExtendLocalValEnv, tcExtendLocalValEnv2 ) import TcPat ( tcPat, tcMonoPatBndr ) import TcMType ( newTyVarTy, newTyVarTys, zonkTcType, zapToType ) -import TcType ( TcType, TcTyVar, tyVarsOfType, tidyOpenTypes, tidyOpenType, +import TcType ( TcType, TcTyVar, TcSigmaType, TcRhoType, + tyVarsOfType, tidyOpenTypes, tidyOpenType, isSigmaTy, mkFunTy, isOverloadedTy, liftedTypeKind, openTypeKind, mkArrowKind, mkAppTy ) import TcBinds ( tcBindsAndThen ) import TcUnify ( unifyPArrTy,subFunTy, unifyListTy, unifyTauTy, - checkSigTyVarsWrt, tcSubExp, isIdCoercion, (<$>) ) + checkSigTyVarsWrt, tcSubExp, tcGen ) import TcSimplify ( tcSimplifyCheck, bindInstsOfLocalFuns ) import Name ( Name ) import PrelNames ( monadNames, mfixName ) @@ -63,13 +65,12 @@ is used in error messages. It checks that all the equations have the same number of arguments before using @tcMatches@ to do the work. \begin{code} -tcMatchesFun :: [(Name,Id)] -- Bindings for the variables bound in this group - -> Name +tcMatchesFun :: Name -> TcType -- Expected type -> [RenamedMatch] -> TcM [TcMatch] -tcMatchesFun xve fun_name expected_ty matches@(first_match:_) +tcMatchesFun fun_name expected_ty matches@(first_match:_) = -- Check that they all have the same no of arguments -- Set the location to that of the first equation, so that -- any inter-equation error messages get some vaguely @@ -86,7 +87,7 @@ tcMatchesFun xve fun_name expected_ty matches@(first_match:_) -- may show up as something wrong with the (non-existent) type signature -- No need to zonk expected_ty, because subFunTy does that on the fly - tcMatches xve (FunRhs fun_name) matches expected_ty + tcMatches (FunRhs fun_name) matches expected_ty \end{code} @tcMatchesCase@ doesn't do the argument-count check because the @@ -100,22 +101,21 @@ tcMatchesCase :: [RenamedMatch] -- The case alternatives tcMatchesCase matches expr_ty = newTyVarTy openTypeKind `thenM` \ scrut_ty -> - tcMatches [] CaseAlt matches (mkFunTy scrut_ty expr_ty) `thenM` \ matches' -> + tcMatches CaseAlt matches (mkFunTy scrut_ty expr_ty) `thenM` \ matches' -> returnM (scrut_ty, matches') tcMatchLambda :: RenamedMatch -> TcType -> TcM TcMatch -tcMatchLambda match res_ty = tcMatch [] LambdaExpr match res_ty +tcMatchLambda match res_ty = tcMatch LambdaExpr match res_ty \end{code} \begin{code} -tcMatches :: [(Name,Id)] - -> RenamedMatchContext +tcMatches :: RenamedMatchContext -> [RenamedMatch] -> TcType -> TcM [TcMatch] -tcMatches xve ctxt matches expected_ty +tcMatches ctxt matches expected_ty = -- If there is more than one branch, and expected_ty is a 'hole', -- all branches must be types, not type schemes, otherwise the -- in which we check them would affect the result. @@ -126,7 +126,7 @@ tcMatches xve ctxt matches expected_ty mappM (tc_match expected_ty') matches where - tc_match expected_ty match = tcMatch xve ctxt match expected_ty + tc_match expected_ty match = tcMatch ctxt match expected_ty \end{code} @@ -137,8 +137,7 @@ tcMatches xve ctxt matches expected_ty %************************************************************************ \begin{code} -tcMatch :: [(Name,Id)] - -> RenamedMatchContext +tcMatch :: RenamedMatchContext -> RenamedMatch -> TcType -- Expected result-type of the Match. -- Early unification with this guy gives better error messages @@ -147,7 +146,7 @@ tcMatch :: [(Name,Id)] -- where there are n patterns. -> TcM TcMatch -tcMatch xve1 ctxt match@(Match pats maybe_rhs_sig grhss) expected_ty +tcMatch ctxt match@(Match pats maybe_rhs_sig grhss) expected_ty = addSrcLoc (getMatchLoc match) $ -- At one stage I removed this; addErrCtxt (matchCtxt ctxt match) $ -- I'm not sure why, so I put it back tcMatchPats pats expected_ty tc_grhss `thenM` \ (pats', grhss', ex_binds) -> @@ -155,17 +154,14 @@ tcMatch xve1 ctxt match@(Match pats maybe_rhs_sig grhss) expected_ty where tc_grhss rhs_ty - = tcExtendLocalValEnv2 xve1 $ - - -- Deal with the result signature + = -- Deal with the result signature case maybe_rhs_sig of Nothing -> tcGRHSs ctxt grhss rhs_ty Just sig -> tcAddScopedTyVars [sig] $ -- Bring into scope the type variables in the signature - tcHsSigType ResSigCtxt sig `thenM` \ sig_ty -> - tcGRHSs ctxt grhss sig_ty `thenM` \ grhss' -> - tcSubExp rhs_ty sig_ty `thenM` \ co_fn -> + tcHsSigType ResSigCtxt sig `thenM` \ sig_ty -> + tcThingWithSig sig_ty (tcGRHSs ctxt grhss) rhs_ty `thenM` \ (co_fn, grhss') -> returnM (lift_grhss co_fn rhs_ty grhss') -- lift_grhss pushes the coercion down to the right hand sides, @@ -173,7 +169,7 @@ tcMatch xve1 ctxt match@(Match pats maybe_rhs_sig grhss) expected_ty lift_grhss co_fn rhs_ty grhss | isIdCoercion co_fn = grhss lift_grhss co_fn rhs_ty (GRHSs grhss binds ty) - = GRHSs (map lift_grhs grhss) binds rhs_ty -- Change the type, since we + = GRHSs (map lift_grhs grhss) binds rhs_ty -- Change the type, since the coercion does where lift_grhs (GRHS stmts loc) = GRHS (map lift_stmt stmts) loc @@ -206,6 +202,31 @@ tcGRHSs ctxt (GRHSs grhss binds _) expected_ty \end{code} +\begin{code} +tcThingWithSig :: TcSigmaType -- Type signature + -> (TcRhoType -> TcM r) -- How to type check the thing inside + -> TcRhoType -- Overall expected result type + -> TcM (ExprCoFn, r) +-- Used for expressions with a type signature, and for result type signatures + +tcThingWithSig sig_ty thing_inside res_ty + | not (isSigmaTy sig_ty) + = thing_inside sig_ty `thenM` \ result -> + tcSubExp res_ty sig_ty `thenM` \ co_fn -> + returnM (co_fn, result) + + | otherwise -- The signature has some outer foralls + = -- Must instantiate the outer for-alls of sig_tc_ty + -- else we risk instantiating a ? res_ty to a forall-type + -- which breaks the invariant that tcMonoExpr only returns phi-types + tcGen sig_ty emptyVarSet thing_inside `thenM` \ (gen_fn, result) -> + tcInstCall SignatureOrigin sig_ty `thenM` \ (inst_fn, inst_sig_ty) -> + tcSubExp res_ty inst_sig_ty `thenM` \ co_fn -> + returnM (co_fn <.> inst_fn <.> gen_fn, result) + -- Note that we generalise, then instantiate. Ah well. +\end{code} + + %************************************************************************ %* * \subsection{tcMatchPats} diff --git a/ghc/compiler/typecheck/TcPat.lhs b/ghc/compiler/typecheck/TcPat.lhs index 2f23094..bfd90a6 100644 --- a/ghc/compiler/typecheck/TcPat.lhs +++ b/ghc/compiler/typecheck/TcPat.lhs @@ -12,7 +12,9 @@ module TcPat ( tcPat, tcMonoPatBndr, tcSubPat, import HsSyn ( Pat(..), HsConDetails(..), HsLit(..), HsOverLit(..), HsExpr(..) ) import RnHsSyn ( RenamedPat ) -import TcHsSyn ( TcPat, TcId, hsLitType ) +import TcHsSyn ( TcPat, TcId, hsLitType, + mkCoercion, idCoercion, isIdCoercion, + (<$>), PatCoFn ) import TcRnMonad import Inst ( InstOrigin(..), @@ -27,9 +29,7 @@ import TcMType ( newTyVarTy, zapToType, arityErr ) import TcType ( TcType, TcTyVar, TcSigmaType, mkClassPred, liftedTypeKind ) import TcUnify ( tcSubOff, TcHoleType, - unifyTauTy, unifyListTy, unifyPArrTy, unifyTupleTy, - mkCoercion, idCoercion, isIdCoercion, - (<$>), PatCoFn ) + unifyTauTy, unifyListTy, unifyPArrTy, unifyTupleTy ) import TcMonoType ( tcHsSigType, UserTypeCtxt(..) ) import TysWiredIn ( stringTy ) diff --git a/ghc/compiler/typecheck/TcUnify.lhs b/ghc/compiler/typecheck/TcUnify.lhs index c04d310..1a05fd5 100644 --- a/ghc/compiler/typecheck/TcUnify.lhs +++ b/ghc/compiler/typecheck/TcUnify.lhs @@ -12,12 +12,7 @@ module TcUnify ( -- Various unifications unifyTauTy, unifyTauTyList, unifyTauTyLists, unifyFunTy, unifyListTy, unifyPArrTy, unifyTupleTy, - unifyKind, unifyKinds, unifyOpenTypeKind, unifyFunKind, - - -- Coercions - Coercion, ExprCoFn, PatCoFn, - (<$>), (<.>), mkCoercion, - idCoercion, isIdCoercion + unifyKind, unifyKinds, unifyOpenTypeKind, unifyFunKind ) where @@ -25,7 +20,8 @@ module TcUnify ( import HsSyn ( HsExpr(..) ) -import TcHsSyn ( TypecheckedHsExpr, TcPat, mkHsLet ) +import TcHsSyn ( TypecheckedHsExpr, TcPat, mkHsLet, + ExprCoFn, idCoercion, isIdCoercion, mkCoercion, (<.>), (<$>) ) import TypeRep ( Type(..), SourceType(..), TyNote(..), openKindCon ) import TcRnMonad -- TcType, amongst others @@ -181,7 +177,7 @@ tc_sub exp_sty expected_ty act_sty actual_ty | isSigmaTy actual_ty = tcInstCall Rank2Origin actual_ty `thenM` \ (inst_fn, body_ty) -> tc_sub exp_sty expected_ty body_ty body_ty `thenM` \ co_fn -> - returnM (co_fn <.> mkCoercion inst_fn) + returnM (co_fn <.> inst_fn) ----------------------------------- -- Function case @@ -353,39 +349,6 @@ tcGen expected_ty extra_tvs thing_inside -- We expect expected_ty to be a forall %************************************************************************ %* * -\subsection{Coercion functions} -%* * -%************************************************************************ - -\begin{code} -type Coercion a = Maybe (a -> a) - -- Nothing => identity fn - -type ExprCoFn = Coercion TypecheckedHsExpr -type PatCoFn = Coercion TcPat - -(<.>) :: Coercion a -> Coercion a -> Coercion a -- Composition -Nothing <.> Nothing = Nothing -Nothing <.> Just f = Just f -Just f <.> Nothing = Just f -Just f1 <.> Just f2 = Just (f1 . f2) - -(<$>) :: Coercion a -> a -> a -Just f <$> e = f e -Nothing <$> e = e - -mkCoercion :: (a -> a) -> Coercion a -mkCoercion f = Just f - -idCoercion :: Coercion a -idCoercion = Nothing - -isIdCoercion :: Coercion a -> Bool -isIdCoercion = isNothing -\end{code} - -%************************************************************************ -%* * \subsection[Unify-exported]{Exported unification functions} %* * %************************************************************************ -- 1.7.10.4