[project @ 2003-03-27 08:18:21 by simonpj]
[ghc-hetmet.git] / ghc / compiler / typecheck / TcMatches.lhs
index 91d5aef..55c7a0c 100644 (file)
@@ -5,7 +5,7 @@
 
 \begin{code}
 module TcMatches ( tcMatchesFun, tcMatchesCase, tcMatchLambda, 
-                  tcDoStmts, tcStmtsAndThen, tcGRHSs 
+                  tcDoStmts, tcStmtsAndThen, tcGRHSs, tcThingWithSig
        ) where
 
 #include "HsVersions.h"
@@ -16,25 +16,27 @@ import HsSyn                ( HsExpr(..), HsBinds(..), Match(..), GRHSs(..), GRHS(..),
                          MonoBinds(..), Stmt(..), HsMatchContext(..), HsStmtContext(..),
                          pprMatch, getMatchLoc, isDoExpr,
                          pprMatchContext, pprStmtContext, pprStmtResultContext,
-                         mkMonoBind, nullMonoBinds, collectSigTysFromPats, andMonoBindList
+                         mkMonoBind, collectSigTysFromPats, andMonoBindList
                        )
 import RnHsSyn         ( RenamedMatch, RenamedGRHSs, RenamedStmt, 
                          RenamedPat, RenamedMatchContext )
-import TcHsSyn         ( TcMatch, TcGRHSs, TcStmt, TcDictBinds, 
-                         TcMonoBinds, TcPat, TcStmt )
+import TcHsSyn         ( TcMatch, TcGRHSs, TcStmt, TcDictBinds, TcHsBinds, 
+                         TcMonoBinds, TcPat, TcStmt, ExprCoFn,
+                         isIdCoercion, (<$>), (<.>) )
 
 import TcRnMonad
 import TcMonoType      ( tcAddScopedTyVars, tcHsSigType, UserTypeCtxt(..) )
-import Inst            ( tcSyntaxName )
-import TcEnv           ( TcId, tcLookupLocalIds, tcExtendLocalValEnv, tcExtendLocalValEnv2 )
+import Inst            ( tcSyntaxName, tcInstCall )
+import TcEnv           ( TcId, tcLookupLocalIds, tcLookupId, tcExtendLocalValEnv, tcExtendLocalValEnv2 )
 import TcPat           ( tcPat, tcMonoPatBndr )
 import TcMType         ( newTyVarTy, newTyVarTys, zonkTcType, zapToType )
-import TcType          ( TcType, TcTyVar, tyVarsOfType, tidyOpenTypes, tidyOpenType,
+import TcType          ( TcType, TcTyVar, TcSigmaType, TcRhoType,
+                         tyVarsOfType, tidyOpenTypes, tidyOpenType, isSigmaTy,
                          mkFunTy, isOverloadedTy, liftedTypeKind, openTypeKind, 
                          mkArrowKind, mkAppTy )
 import TcBinds         ( tcBindsAndThen )
 import TcUnify         ( unifyPArrTy,subFunTy, unifyListTy, unifyTauTy,
-                         checkSigTyVarsWrt, tcSubExp, isIdCoercion, (<$>), unifyTauTyLists )
+                         checkSigTyVarsWrt, tcSubExp, tcGen )
 import TcSimplify      ( tcSimplifyCheck, bindInstsOfLocalFuns )
 import Name            ( Name )
 import PrelNames       ( monadNames, mfixName )
@@ -63,13 +65,12 @@ is used in error messages.  It checks that all the equations have the
 same number of arguments before using @tcMatches@ to do the work.
 
 \begin{code}
-tcMatchesFun :: [(Name,Id)]    -- Bindings for the variables bound in this group
-            -> Name
+tcMatchesFun :: Name
             -> TcType          -- Expected type
             -> [RenamedMatch]
             -> TcM [TcMatch]
 
-tcMatchesFun xve fun_name expected_ty matches@(first_match:_)
+tcMatchesFun fun_name expected_ty matches@(first_match:_)
   =     -- Check that they all have the same no of arguments
         -- Set the location to that of the first equation, so that
         -- any inter-equation error messages get some vaguely
@@ -86,7 +87,7 @@ tcMatchesFun xve fun_name expected_ty matches@(first_match:_)
        -- may show up as something wrong with the (non-existent) type signature
 
        -- No need to zonk expected_ty, because subFunTy does that on the fly
-    tcMatches xve (FunRhs fun_name) matches expected_ty
+    tcMatches (FunRhs fun_name) matches expected_ty
 \end{code}
 
 @tcMatchesCase@ doesn't do the argument-count check because the
@@ -100,22 +101,21 @@ tcMatchesCase :: [RenamedMatch]           -- The case alternatives
 
 tcMatchesCase matches expr_ty
   = newTyVarTy openTypeKind                                    `thenM` \ scrut_ty ->
-    tcMatches [] CaseAlt matches (mkFunTy scrut_ty expr_ty)    `thenM` \ matches' ->
+    tcMatches CaseAlt matches (mkFunTy scrut_ty expr_ty)       `thenM` \ matches' ->
     returnM (scrut_ty, matches')
 
 tcMatchLambda :: RenamedMatch -> TcType -> TcM TcMatch
-tcMatchLambda match res_ty = tcMatch [] LambdaExpr match res_ty
+tcMatchLambda match res_ty = tcMatch LambdaExpr match res_ty
 \end{code}
 
 
 \begin{code}
-tcMatches :: [(Name,Id)]
-         -> RenamedMatchContext 
+tcMatches :: RenamedMatchContext 
          -> [RenamedMatch]
          -> TcType
          -> TcM [TcMatch]
 
-tcMatches xve ctxt matches expected_ty
+tcMatches ctxt matches expected_ty
   =    -- If there is more than one branch, and expected_ty is a 'hole',
        -- all branches must be types, not type schemes, otherwise the
        -- in which we check them would affect the result.
@@ -126,7 +126,7 @@ tcMatches xve ctxt matches expected_ty
 
     mappM (tc_match expected_ty') matches
   where
-    tc_match expected_ty match = tcMatch xve ctxt match expected_ty
+    tc_match expected_ty match = tcMatch ctxt match expected_ty
 \end{code}
 
 
@@ -137,8 +137,7 @@ tcMatches xve ctxt matches expected_ty
 %************************************************************************
 
 \begin{code}
-tcMatch :: [(Name,Id)]
-       -> RenamedMatchContext
+tcMatch :: RenamedMatchContext
        -> RenamedMatch
        -> TcType       -- Expected result-type of the Match.
                        -- Early unification with this guy gives better error messages
@@ -147,25 +146,22 @@ tcMatch :: [(Name,Id)]
                        -- where there are n patterns.
        -> TcM TcMatch
 
-tcMatch xve1 ctxt match@(Match pats maybe_rhs_sig grhss) expected_ty
+tcMatch ctxt match@(Match pats maybe_rhs_sig grhss) expected_ty
   = addSrcLoc (getMatchLoc match)              $       -- At one stage I removed this;
     addErrCtxt (matchCtxt ctxt match)          $       -- I'm not sure why, so I put it back
     tcMatchPats pats expected_ty tc_grhss      `thenM` \ (pats', grhss', ex_binds) ->
-    returnM (Match pats' Nothing (glue_on Recursive ex_binds grhss'))
+    returnM (Match pats' Nothing (glue_on ex_binds grhss'))
 
   where
     tc_grhss rhs_ty 
-       = tcExtendLocalValEnv2 xve1                     $
-
-               -- Deal with the result signature
+       =       -- Deal with the result signature
          case maybe_rhs_sig of
            Nothing ->  tcGRHSs ctxt grhss rhs_ty
 
            Just sig ->  tcAddScopedTyVars [sig]        $
                                -- Bring into scope the type variables in the signature
-                        tcHsSigType ResSigCtxt sig     `thenM` \ sig_ty ->
-                        tcGRHSs ctxt grhss sig_ty      `thenM` \ grhss' ->
-                        tcSubExp rhs_ty sig_ty         `thenM` \ co_fn  ->
+                        tcHsSigType ResSigCtxt sig                             `thenM` \ sig_ty ->
+                        tcThingWithSig sig_ty (tcGRHSs ctxt grhss) rhs_ty      `thenM` \ (co_fn, grhss') ->
                         returnM (lift_grhss co_fn rhs_ty grhss')
 
 -- lift_grhss pushes the coercion down to the right hand sides,
@@ -173,7 +169,7 @@ tcMatch xve1 ctxt match@(Match pats maybe_rhs_sig grhss) expected_ty
 lift_grhss co_fn rhs_ty grhss 
   | isIdCoercion co_fn = grhss
 lift_grhss co_fn rhs_ty (GRHSs grhss binds ty)
-  = GRHSs (map lift_grhs grhss) binds rhs_ty   -- Change the type, since we
+  = GRHSs (map lift_grhs grhss) binds rhs_ty   -- Change the type, since the coercion does
   where
     lift_grhs (GRHS stmts loc) = GRHS (map lift_stmt stmts) loc
              
@@ -181,9 +177,9 @@ lift_grhss co_fn rhs_ty (GRHSs grhss binds ty)
     lift_stmt stmt            = stmt
    
 -- glue_on just avoids stupid dross
-glue_on _ EmptyMonoBinds grhss = grhss         -- The common case
-glue_on is_rec mbinds (GRHSs grhss binds ty)
-  = GRHSs grhss (mkMonoBind mbinds [] is_rec `ThenBinds` binds) ty
+glue_on EmptyBinds grhss = grhss               -- The common case
+glue_on binds1 (GRHSs grhss binds2 ty)
+  = GRHSs grhss (binds1 `ThenBinds` binds2) ty
 
 
 tcGRHSs :: RenamedMatchContext -> RenamedGRHSs
@@ -206,6 +202,31 @@ tcGRHSs ctxt (GRHSs grhss binds _) expected_ty
 \end{code}
 
 
+\begin{code}
+tcThingWithSig :: TcSigmaType          -- Type signature
+              -> (TcRhoType -> TcM r)  -- How to type check the thing inside
+              -> TcRhoType             -- Overall expected result type
+              -> TcM (ExprCoFn, r)
+-- Used for expressions with a type signature, and for result type signatures
+
+tcThingWithSig sig_ty thing_inside res_ty
+  | not (isSigmaTy sig_ty)
+  = thing_inside sig_ty                `thenM` \ result ->
+    tcSubExp res_ty sig_ty     `thenM` \ co_fn ->
+    returnM (co_fn, result)
+
+  | otherwise  -- The signature has some outer foralls
+  =    -- Must instantiate the outer for-alls of sig_tc_ty
+       -- else we risk instantiating a ? res_ty to a forall-type
+       -- which breaks the invariant that tcMonoExpr only returns phi-types
+    tcGen sig_ty emptyVarSet thing_inside      `thenM` \ (gen_fn, result) ->
+    tcInstCall SignatureOrigin sig_ty          `thenM` \ (inst_fn, inst_sig_ty) ->
+    tcSubExp res_ty inst_sig_ty                        `thenM` \ co_fn ->
+    returnM (co_fn <.> inst_fn <.> gen_fn,  result)
+       -- Note that we generalise, then instantiate. Ah well.
+\end{code}
+
+
 %************************************************************************
 %*                                                                     *
 \subsection{tcMatchPats}
@@ -216,7 +237,7 @@ tcGRHSs ctxt (GRHSs grhss binds _) expected_ty
 tcMatchPats
        :: [RenamedPat] -> TcType
        -> (TcType -> TcM a)
-       -> TcM ([TcPat], a, TcDictBinds)
+       -> TcM ([TcPat], a, TcHsBinds)
 -- Typecheck the patterns, extend the environment to bind the variables,
 -- do the thing inside, use any existentially-bound dictionaries to 
 -- discharge parts of the returning LIE, and deal with pattern type
@@ -246,7 +267,7 @@ tcMatchPats pats expected_ty thing_inside
        --      f (C g) x = g x
        -- Here, result_ty will be simply Int, but expected_ty is (a -> Int).
 
-    returnM (pats', result, ex_binds)
+    returnM (pats', result, mkMonoBind Recursive ex_binds)
 
 tc_match_pats [] expected_ty thing_inside
   = thing_inside expected_ty   `thenM` \ answer ->
@@ -433,7 +454,7 @@ tcStmtAndThen combine do_or_lc m_ty@(m,elt_ty) stmt@(BindStmt pat exp src_loc) t
        popErrCtxt thing_inside
     )                                                  `thenM` \ ([pat'], thing, dict_binds) ->
     returnM (combine (BindStmt pat' exp' src_loc)
-                    (glue_binds combine Recursive dict_binds thing))
+                    (glue_binds combine dict_binds thing))
 
        -- ParStmt
 tcStmtAndThen combine do_or_lc m_ty (ParStmtOut bndr_stmts_s) thing_inside
@@ -458,23 +479,32 @@ tcStmtAndThen combine do_or_lc m_ty (ParStmtOut bndr_stmts_s) thing_inside
     combine_par stmt (stmts, thing) = (stmt:stmts, thing)
 
        -- RecStmt
-tcStmtAndThen combine do_or_lc m_ty (RecStmt recNames stmts) thing_inside
+tcStmtAndThen combine do_or_lc m_ty (RecStmt recNames stmts _) thing_inside
   = newTyVarTys (length recNames) liftedTypeKind               `thenM` \ recTys ->
-    tcExtendLocalValEnv (zipWith mkLocalId recNames recTys)    $
+    let
+       mono_ids = zipWith mkLocalId recNames recTys
+    in
+    tcExtendLocalValEnv mono_ids                       $
     tcStmtsAndThen combine_rec do_or_lc m_ty stmts (
-       tcLookupLocalIds recNames  `thenM` \ rn ->
-       returnM ([], rn)
-    )                                                          `thenM` \ (stmts', recNames') ->
+       mappM tc_ret (recNames `zip` recTys)    `thenM` \ rets ->
+       returnM ([], rets)
+    )                                          `thenM` \ (stmts', rets) ->
 
-    -- Unify the types of the "final" Ids with those of "knot-tied" Ids
-    unifyTauTyLists recTys (map idType recNames')      `thenM_`
-  
-    thing_inside                                       `thenM` \ thing ->
+       -- NB: it's the mono_ids that scope over this part
+    thing_inside                               `thenM` \ thing ->
   
-    returnM (combine (RecStmt recNames' stmts') thing)
+    returnM (combine (RecStmt mono_ids stmts' rets) thing)
   where 
     combine_rec stmt (stmts, thing) = (stmt:stmts, thing)
 
+    -- Unify the types of the "final" Ids with those of "knot-tied" Ids
+    tc_ret (rec_name, mono_ty)
+       = tcLookupId rec_name                   `thenM` \ poly_id ->
+               -- poly_id may have a polymorphic type
+               -- but mono_ty is just a monomorphic type variable
+         tcSubExp mono_ty (idType poly_id)     `thenM` \ co_fn ->
+         returnM (co_fn <$> HsVar poly_id) 
+
        -- ExprStmt
 tcStmtAndThen combine do_or_lc m_ty@(m, res_elt_ty) stmt@(ExprStmt exp _ locn) thing_inside
   = addErrCtxt (stmtCtxt do_or_lc stmt) (
@@ -506,9 +536,8 @@ tcStmtAndThen combine do_or_lc m_ty@(m, res_elt_ty) stmt@(ResultStmt exp locn) t
 
 
 ------------------------------
-glue_binds combine is_rec binds thing 
-  | nullMonoBinds binds = thing
-  | otherwise          = combine (LetStmt (mkMonoBind binds [] is_rec)) thing
+glue_binds combine EmptyBinds  thing = thing
+glue_binds combine other_binds thing = combine (LetStmt other_binds) thing
 \end{code}