[project @ 2003-02-04 12:33:05 by simonpj]
[ghc-hetmet.git] / ghc / compiler / typecheck / TcMatches.lhs
index 944a300..f1048d8 100644 (file)
@@ -13,21 +13,22 @@ module TcMatches ( tcMatchesFun, tcMatchesCase, tcMatchLambda,
 import {-# SOURCE #-}  TcExpr( tcMonoExpr )
 
 import HsSyn           ( HsExpr(..), HsBinds(..), Match(..), GRHSs(..), GRHS(..),
-                         MonoBinds(..), Stmt(..), HsMatchContext(..), HsDoContext(..),
-                         pprMatch, getMatchLoc, pprMatchContext, isDoExpr,
-                         mkMonoBind, nullMonoBinds, collectSigTysFromPats, andMonoBindList
+                         MonoBinds(..), Stmt(..), HsMatchContext(..), HsStmtContext(..),
+                         pprMatch, getMatchLoc, isDoExpr,
+                         pprMatchContext, pprStmtContext, pprStmtResultContext,
+                         mkMonoBind, collectSigTysFromPats, andMonoBindList
                        )
 import RnHsSyn         ( RenamedMatch, RenamedGRHSs, RenamedStmt, 
                          RenamedPat, RenamedMatchContext )
-import TcHsSyn         ( TcMatch, TcGRHSs, TcStmt, TcDictBinds, 
+import TcHsSyn         ( TcMatch, TcGRHSs, TcStmt, TcDictBinds, TcHsBinds, 
                          TcMonoBinds, TcPat, TcStmt )
 
 import TcRnMonad
 import TcMonoType      ( tcAddScopedTyVars, tcHsSigType, UserTypeCtxt(..) )
 import Inst            ( tcSyntaxName )
-import TcEnv           ( TcId, tcLookupLocalIds, tcExtendLocalValEnv2 )
+import TcEnv           ( TcId, tcLookupLocalIds, tcLookupId, tcExtendLocalValEnv, tcExtendLocalValEnv2 )
 import TcPat           ( tcPat, tcMonoPatBndr )
-import TcMType         ( newTyVarTy, zonkTcType, zapToType )
+import TcMType         ( newTyVarTy, newTyVarTys, zonkTcType, zapToType )
 import TcType          ( TcType, TcTyVar, tyVarsOfType, tidyOpenTypes, tidyOpenType,
                          mkFunTy, isOverloadedTy, liftedTypeKind, openTypeKind, 
                          mkArrowKind, mkAppTy )
@@ -36,9 +37,9 @@ import TcUnify                ( unifyPArrTy,subFunTy, unifyListTy, unifyTauTy,
                          checkSigTyVarsWrt, tcSubExp, isIdCoercion, (<$>) )
 import TcSimplify      ( tcSimplifyCheck, bindInstsOfLocalFuns )
 import Name            ( Name )
-import PrelNames       ( monadNames )
+import PrelNames       ( monadNames, mfixName )
 import TysWiredIn      ( boolTy, mkListTy, mkPArrTy )
-import Id              ( idType, mkSysLocal )
+import Id              ( idType, mkSysLocal, mkLocalId )
 import CoreFVs         ( idFreeTyVars )
 import BasicTypes      ( RecFlag(..) )
 import VarSet
@@ -150,7 +151,7 @@ tcMatch xve1 ctxt match@(Match pats maybe_rhs_sig grhss) expected_ty
   = addSrcLoc (getMatchLoc match)              $       -- At one stage I removed this;
     addErrCtxt (matchCtxt ctxt match)          $       -- I'm not sure why, so I put it back
     tcMatchPats pats expected_ty tc_grhss      `thenM` \ (pats', grhss', ex_binds) ->
-    returnM (Match pats' Nothing (glue_on Recursive ex_binds grhss'))
+    returnM (Match pats' Nothing (glue_on ex_binds grhss'))
 
   where
     tc_grhss rhs_ty 
@@ -180,9 +181,9 @@ lift_grhss co_fn rhs_ty (GRHSs grhss binds ty)
     lift_stmt stmt            = stmt
    
 -- glue_on just avoids stupid dross
-glue_on _ EmptyMonoBinds grhss = grhss         -- The common case
-glue_on is_rec mbinds (GRHSs grhss binds ty)
-  = GRHSs grhss (mkMonoBind mbinds [] is_rec `ThenBinds` binds) ty
+glue_on EmptyBinds grhss = grhss               -- The common case
+glue_on binds1 (GRHSs grhss binds2 ty)
+  = GRHSs grhss (binds1 `ThenBinds` binds2) ty
 
 
 tcGRHSs :: RenamedMatchContext -> RenamedGRHSs
@@ -192,13 +193,15 @@ tcGRHSs :: RenamedMatchContext -> RenamedGRHSs
 tcGRHSs ctxt (GRHSs grhss binds _) expected_ty
   = tcBindsAndThen glue_on binds (tc_grhss grhss)
   where
+    m_ty =  (\ty -> ty, expected_ty) 
+
     tc_grhss grhss
        = mappM tc_grhs grhss       `thenM` \ grhss' ->
          returnM (GRHSs grhss' EmptyBinds expected_ty)
 
     tc_grhs (GRHS guarded locn)
-       = addSrcLoc locn                                $
-         tcStmts ctxt (\ty -> ty, expected_ty) guarded `thenM` \ guarded' ->
+       = addSrcLoc locn                        $
+         tcStmts (PatGuard ctxt) m_ty guarded  `thenM` \ guarded' ->
          returnM (GRHS guarded' locn)
 \end{code}
 
@@ -213,7 +216,7 @@ tcGRHSs ctxt (GRHSs grhss binds _) expected_ty
 tcMatchPats
        :: [RenamedPat] -> TcType
        -> (TcType -> TcM a)
-       -> TcM ([TcPat], a, TcDictBinds)
+       -> TcM ([TcPat], a, TcHsBinds)
 -- Typecheck the patterns, extend the environment to bind the variables,
 -- do the thing inside, use any existentially-bound dictionaries to 
 -- discharge parts of the returning LIE, and deal with pattern type
@@ -243,7 +246,7 @@ tcMatchPats pats expected_ty thing_inside
        --      f (C g) x = g x
        -- Here, result_ty will be simply Int, but expected_ty is (a -> Int).
 
-    returnM (pats', result, ex_binds)
+    returnM (pats', result, mkMonoBind Recursive ex_binds)
 
 tc_match_pats [] expected_ty thing_inside
   = thing_inside expected_ty   `thenM` \ answer ->
@@ -317,26 +320,24 @@ tcCheckExistentialPat ex_tvs ex_ids ex_lie lie_req match_ty
 %************************************************************************
 
 \begin{code}
-tcDoStmts :: HsDoContext -> [RenamedStmt] -> [Name] -> TcType
+tcDoStmts :: HsStmtContext Name -> [RenamedStmt] -> [Name] -> TcType
          -> TcM (TcMonoBinds, [TcStmt], [Id])
 tcDoStmts PArrComp stmts method_names res_ty
-  = unifyPArrTy res_ty                   `thenM` \elt_ty ->
-    tcStmts (DoCtxt PArrComp) 
-           (mkPArrTy, elt_ty) stmts      `thenM` \ stmts' ->
+  = unifyPArrTy res_ty                           `thenM` \elt_ty ->
+    tcStmts PArrComp (mkPArrTy, elt_ty) stmts      `thenM` \ stmts' ->
     returnM (EmptyMonoBinds, stmts', [{- unused -}])
 
 tcDoStmts ListComp stmts method_names res_ty
-  = unifyListTy res_ty                 `thenM` \ elt_ty ->
-    tcStmts (DoCtxt ListComp) 
-           (mkListTy, elt_ty) stmts    `thenM` \ stmts' ->
+  = unifyListTy res_ty                         `thenM` \ elt_ty ->
+    tcStmts ListComp (mkListTy, elt_ty) stmts  `thenM` \ stmts' ->
     returnM (EmptyMonoBinds, stmts', [{- unused -}])
 
-tcDoStmts DoExpr stmts method_names res_ty
+tcDoStmts do_or_mdo_expr stmts method_names res_ty
   = newTyVarTy (mkArrowKind liftedTypeKind liftedTypeKind)     `thenM` \ m_ty ->
     newTyVarTy liftedTypeKind                                  `thenM` \ elt_ty ->
     unifyTauTy res_ty (mkAppTy m_ty elt_ty)                    `thenM_`
 
-    tcStmts (DoCtxt DoExpr) (mkAppTy m_ty, elt_ty) stmts       `thenM` \ stmts' ->
+    tcStmts do_or_mdo_expr (mkAppTy m_ty, elt_ty) stmts                `thenM` \ stmts' ->
 
        -- Build the then and zero methods in case we need them
        -- It's important that "then" and "return" appear just once in the final LIE,
@@ -347,9 +348,12 @@ tcDoStmts DoExpr stmts method_names res_ty
        -- where the second "then" sees that it already exists in the "available" stuff.
        --
     mapAndUnzipM (tc_syn_name m_ty) 
-                (zipEqual "tcDoStmts" monadNames method_names)  `thenM` \ (binds, ids) ->
+                (zipEqual "tcDoStmts" currentMonadNames method_names)  `thenM` \ (binds, ids) ->
     returnM (andMonoBindList binds, stmts', ids)
   where
+    currentMonadNames = case do_or_mdo_expr of
+                         DoExpr  -> monadNames
+                         MDoExpr -> monadNames ++ [mfixName]
     tc_syn_name :: TcType -> (Name,Name) -> TcM (TcMonoBinds, Id)
     tc_syn_name m_ty (std_nm, usr_nm)
        = tcSyntaxName DoOrigin m_ty std_nm usr_nm      `thenM` \ (expr, expr_ty) ->
@@ -398,7 +402,7 @@ tcStmts do_or_lc m_ty stmts
 
 tcStmtsAndThen
        :: (TcStmt -> thing -> thing)   -- Combiner
-       -> RenamedMatchContext
+       -> HsStmtContext Name
         -> (TcType -> TcType, TcType)  -- m, the relationship type of pat and rhs in pat <- rhs
                                        -- elt_ty, where type of the comprehension is (m elt_ty)
         -> [RenamedStmt]
@@ -429,7 +433,7 @@ tcStmtAndThen combine do_or_lc m_ty@(m,elt_ty) stmt@(BindStmt pat exp src_loc) t
        popErrCtxt thing_inside
     )                                                  `thenM` \ ([pat'], thing, dict_binds) ->
     returnM (combine (BindStmt pat' exp' src_loc)
-                    (glue_binds combine Recursive dict_binds thing))
+                    (glue_binds combine dict_binds thing))
 
        -- ParStmt
 tcStmtAndThen combine do_or_lc m_ty (ParStmtOut bndr_stmts_s) thing_inside
@@ -442,7 +446,7 @@ tcStmtAndThen combine do_or_lc m_ty (ParStmtOut bndr_stmts_s) thing_inside
 
     loop ((bndrs,stmts) : pairs)
       = tcStmtsAndThen 
-               combine_par (DoCtxt ListComp) m_ty stmts
+               combine_par ListComp m_ty stmts
                        -- Notice we pass on m_ty; the result type is used only
                        -- to get escaping type variables for checkExistentialPat
                (tcLookupLocalIds bndrs `thenM` \ bndrs' ->
@@ -453,9 +457,36 @@ tcStmtAndThen combine do_or_lc m_ty (ParStmtOut bndr_stmts_s) thing_inside
 
     combine_par stmt (stmts, thing) = (stmt:stmts, thing)
 
+       -- RecStmt
+tcStmtAndThen combine do_or_lc m_ty (RecStmt recNames stmts _) thing_inside
+  = newTyVarTys (length recNames) liftedTypeKind               `thenM` \ recTys ->
+    let
+       mono_ids = zipWith mkLocalId recNames recTys
+    in
+    tcExtendLocalValEnv mono_ids                       $
+    tcStmtsAndThen combine_rec do_or_lc m_ty stmts (
+       mappM tc_ret (recNames `zip` recTys)    `thenM` \ rets ->
+       returnM ([], rets)
+    )                                          `thenM` \ (stmts', rets) ->
+
+       -- NB: it's the mono_ids that scope over this part
+    thing_inside                               `thenM` \ thing ->
+  
+    returnM (combine (RecStmt mono_ids stmts' rets) thing)
+  where 
+    combine_rec stmt (stmts, thing) = (stmt:stmts, thing)
+
+    -- Unify the types of the "final" Ids with those of "knot-tied" Ids
+    tc_ret (rec_name, mono_ty)
+       = tcLookupId rec_name                   `thenM` \ poly_id ->
+               -- poly_id may have a polymorphic type
+               -- but mono_ty is just a monomorphic type variable
+         tcSubExp mono_ty (idType poly_id)     `thenM` \ co_fn ->
+         returnM (co_fn <$> HsVar poly_id) 
+
        -- ExprStmt
 tcStmtAndThen combine do_or_lc m_ty@(m, res_elt_ty) stmt@(ExprStmt exp _ locn) thing_inside
-  = setErrCtxt (stmtCtxt do_or_lc stmt) (
+  = addErrCtxt (stmtCtxt do_or_lc stmt) (
        if isDoExpr do_or_lc then
                newTyVarTy openTypeKind         `thenM` \ any_ty ->
                tcMonoExpr exp (m any_ty)       `thenM` \ exp' ->
@@ -471,7 +502,7 @@ tcStmtAndThen combine do_or_lc m_ty@(m, res_elt_ty) stmt@(ExprStmt exp _ locn) t
 
        -- Result statements
 tcStmtAndThen combine do_or_lc m_ty@(m, res_elt_ty) stmt@(ResultStmt exp locn) thing_inside
-  = setErrCtxt (stmtCtxt do_or_lc stmt) (
+  = addErrCtxt (resCtxt do_or_lc stmt) (
        if isDoExpr do_or_lc then
                tcMonoExpr exp (m res_elt_ty)
        else
@@ -484,9 +515,8 @@ tcStmtAndThen combine do_or_lc m_ty@(m, res_elt_ty) stmt@(ResultStmt exp locn) t
 
 
 ------------------------------
-glue_binds combine is_rec binds thing 
-  | nullMonoBinds binds = thing
-  | otherwise          = combine (LetStmt (mkMonoBind binds [] is_rec)) thing
+glue_binds combine EmptyBinds  thing = thing
+glue_binds combine other_binds thing = combine (LetStmt other_binds) thing
 \end{code}
 
 
@@ -511,8 +541,9 @@ sameNoOfArgs matches = isSingleton (nub (map args_in_match matches))
 varyingArgsErr name matches
   = sep [ptext SLIT("Varying number of arguments for function"), quotes (ppr name)]
 
-matchCtxt ctxt  match  = hang (pprMatchContext ctxt     <> colon) 4 (pprMatch ctxt match)
-stmtCtxt do_or_lc stmt = hang (pprMatchContext do_or_lc <> colon) 4 (ppr stmt)
+matchCtxt ctxt  match  = hang (ptext SLIT("In") <+> pprMatchContext ctxt <> colon) 4 (pprMatch ctxt match)
+stmtCtxt do_or_lc stmt = hang (ptext SLIT("In") <+> pprStmtContext do_or_lc <> colon) 4 (ppr stmt)
+resCtxt  do_or_lc stmt = hang (ptext SLIT("In") <+> pprStmtResultContext do_or_lc <> colon) 4 (ppr stmt)
 
 sigPatCtxt bound_tvs bound_ids match_ty tidy_env 
   = zonkTcType match_ty                `thenM` \ match_ty' ->