Minor refactoring
[ghc-hetmet.git] / compiler / typecheck / TcMatches.lhs
index 40e1ca0..3457f32 100644 (file)
@@ -7,12 +7,13 @@ TcMatches: Typecheck some @Matches@
 
 \begin{code}
 module TcMatches ( tcMatchesFun, tcGRHSsPat, tcMatchesCase, tcMatchLambda,
-                  matchCtxt, TcMatchCtxt(..), 
+                  TcMatchCtxt(..), 
                   tcStmts, tcDoStmts, tcBody,
                   tcDoStmt, tcMDoStmt, tcGuardStmt
        ) where
 
-import {-# SOURCE #-}  TcExpr( tcSyntaxOp, tcInferRho, tcMonoExpr, tcPolyExpr )
+import {-# SOURCE #-}  TcExpr( tcSyntaxOp, tcInferRhoNC, 
+                                tcMonoExpr, tcMonoExprNC, tcPolyExpr )
 
 import HsSyn
 import TcRnMonad
@@ -36,6 +37,8 @@ import SrcLoc
 import FastString
 
 import Control.Monad
+
+#include "HsVersions.h"
 \end{code}
 
 %************************************************************************
@@ -71,7 +74,7 @@ tcMatchesFun fun_name inf matches exp_ty
                -- This is one of two places places we call subFunTys
                -- The point is that if expected_y is a "hole", we want 
                -- to make pat_tys and rhs_ty as "holes" too.
-       ; subFunTys doc n_pats exp_ty     $ \ pat_tys rhs_ty -> 
+       ; subFunTys doc n_pats exp_ty (Just (FunSigCtxt fun_name)) $ \ pat_tys rhs_ty -> 
          tcMatches match_ctxt pat_tys rhs_ty matches
        }
   where
@@ -92,16 +95,24 @@ tcMatchesCase :: TcMatchCtxt                -- Case context
              -> TcM (MatchGroup TcId)  -- Translated alternatives
 
 tcMatchesCase ctxt scrut_ty matches res_ty
+  | isEmptyMatchGroup matches
+  =      -- Allow empty case expressions
+    do {  -- Make sure we follow the invariant that res_ty is filled in
+          res_ty' <- refineBoxToTau res_ty
+       ;  return (MatchGroup [] (mkFunTys [scrut_ty] res_ty')) }
+
+  | otherwise
   = tcMatches ctxt [scrut_ty] res_ty matches
 
 tcMatchLambda :: MatchGroup Name -> BoxyRhoType -> TcM (HsWrapper, MatchGroup TcId)
 tcMatchLambda match res_ty 
-  = subFunTys doc n_pats res_ty        $ \ pat_tys rhs_ty ->
+  = subFunTys doc n_pats res_ty Nothing        $ \ pat_tys rhs_ty ->
     tcMatches match_ctxt pat_tys rhs_ty match
   where
     n_pats = matchGroupArity match
     doc = sep [ ptext (sLit "The lambda expression")
-                <+> quotes (pprSetDepth 1 $ pprMatches (LambdaExpr :: HsMatchContext Name) match),
+                <+> quotes (pprSetDepth (PartWay 1) $ 
+                             pprMatches (LambdaExpr :: HsMatchContext Name) match),
                        -- The pprSetDepth makes the abstraction print briefly
                ptext (sLit "has") <+> speakNOf n_pats (ptext (sLit "argument"))]
     match_ctxt = MC { mc_what = LambdaExpr,
@@ -141,7 +152,8 @@ data TcMatchCtxt    -- c.f. TcStmtCtxt, also in this module
                 -> TcM (LHsExpr TcId) }        
 
 tcMatches ctxt pat_tys rhs_ty (MatchGroup matches _)
-  = do { matches' <- mapM (tcMatch ctxt pat_tys rhs_ty) matches
+  = ASSERT( not (null matches) )       -- Ensure that rhs_ty is filled in
+    do { matches' <- mapM (tcMatch ctxt pat_tys rhs_ty) matches
        ; return (MatchGroup matches' (mkFunTys pat_tys rhs_ty)) }
 
 -------------
@@ -156,7 +168,7 @@ tcMatch ctxt pat_tys rhs_ty match
   where
     tc_match ctxt pat_tys rhs_ty match@(Match pats maybe_rhs_sig grhss)
       = add_match_ctxt match $
-        do { (pats', grhss') <- tcLamPats pats pat_tys rhs_ty $
+        do { (pats', grhss') <- tcPats (mc_what ctxt) pats pat_tys rhs_ty $
                                tc_grhss ctxt maybe_rhs_sig grhss
           ; return (Match pats' Nothing grhss') }
 
@@ -164,17 +176,15 @@ tcMatch ctxt pat_tys rhs_ty match
       = tcGRHSs ctxt grhss rhs_ty      -- No result signature
 
        -- Result type sigs are no longer supported
-    tc_grhss ctxt (Just res_sig) grhss rhs_ty
-      = do { addErr (ptext (sLit "Ignoring (deprecated) result type signature")
-                       <+> ppr res_sig)
-          ; tcGRHSs ctxt grhss rhs_ty }
+    tc_grhss _ (Just {}) _ _
+      = panic "tc_ghrss"       -- Rejected by renamer
 
        -- For (\x -> e), tcExpr has already said "In the expresssion \x->e"
        -- so we don't want to add "In the lambda abstraction \x->e"
     add_match_ctxt match thing_inside
        = case mc_what ctxt of
            LambdaExpr -> thing_inside
-           m_ctxt     -> addErrCtxt (matchCtxt m_ctxt match) thing_inside
+           m_ctxt     -> addErrCtxt (pprMatchInCtxt m_ctxt match) thing_inside
 
 -------------
 tcGRHSs :: TcMatchCtxt -> GRHSs Name -> BoxyRhoType
@@ -259,7 +269,7 @@ tcDoStmts ctxt _ _ _ = pprPanic "tcDoStmts" (pprStmtContext ctxt)
 tcBody :: LHsExpr Name -> BoxyRhoType -> TcM (LHsExpr TcId)
 tcBody body res_ty
   = do { traceTc (text "tcBody" <+> ppr res_ty)
-       ; body' <- tcPolyExpr body res_ty
+       ; body' <- tcMonoExpr body res_ty
        ; return body' 
         } 
 \end{code}
@@ -303,7 +313,7 @@ tcStmts ctxt stmt_chk (L loc (LetStmt binds) : stmts) res_ty thing_inside
 tcStmts ctxt stmt_chk (L loc stmt : stmts) res_ty thing_inside
   = do         { (stmt', (stmts', thing)) <- 
                setSrcSpan loc                          $
-               addErrCtxt (stmtCtxt ctxt stmt)         $
+               addErrCtxt (pprStmtInCtxt ctxt stmt)    $
                stmt_chk ctxt stmt res_ty               $ \ res_ty' ->
                popErrCtxt                              $
                tcStmts ctxt stmt_chk stmts res_ty'     $
@@ -318,9 +328,9 @@ tcGuardStmt _ (ExprStmt guard _ _) res_ty thing_inside
        ; thing  <- thing_inside res_ty
        ; return (ExprStmt guard' noSyntaxExpr boolTy, thing) }
 
-tcGuardStmt _ (BindStmt pat rhs _ _) res_ty thing_inside
-  = do { (rhs', rhs_ty) <- tcInferRho rhs
-       ; (pat', thing)  <- tcLamPat pat rhs_ty res_ty thing_inside
+tcGuardStmt ctxt (BindStmt pat rhs _ _) res_ty thing_inside
+  = do { (rhs', rhs_ty) <- tcInferRhoNC rhs    -- Stmt has a context already
+       ; (pat', thing)  <- tcPat (StmtCtxt ctxt) pat rhs_ty res_ty thing_inside
        ; return (BindStmt pat' rhs' noSyntaxExpr noSyntaxExpr, thing) }
 
 tcGuardStmt _ stmt _ _
@@ -334,10 +344,10 @@ tcLcStmt :: TyCon -- The list/Parray type constructor ([] or PArray)
         -> TcStmtChecker
 
 -- A generator, pat <- rhs
-tcLcStmt m_tc _ (BindStmt pat rhs _ _) res_ty thing_inside
+tcLcStmt m_tc ctxt (BindStmt pat rhs _ _) res_ty thing_inside
  = do  { (rhs', pat_ty) <- withBox liftedTypeKind $ \ ty ->
                            tcMonoExpr rhs (mkTyConApp m_tc [ty])
-       ; (pat', thing)  <- tcLamPat pat pat_ty res_ty thing_inside
+       ; (pat', thing)  <- tcPat (StmtCtxt ctxt) pat pat_ty res_ty thing_inside
        ; return (BindStmt pat' rhs' noSyntaxExpr noSyntaxExpr, thing) }
 
 -- A boolean guard
@@ -396,7 +406,7 @@ tcLcStmt m_tc ctxt (TransformStmt (stmts, binders) usingExpr maybeByExpr) elt_ty
                         return (usingExpr', Nothing)
                     Just byExpr -> do
                         -- We must infer a type such that e :: t and then check that usingExpr :: forall a. (a -> t) -> [a] -> [a]
-                        (byExpr', tTy) <- tcInferRho byExpr
+                        (byExpr', tTy) <- tcInferRhoNC byExpr
                         usingExpr' <- tcPolyExpr usingExpr (mkForAllTy alphaTyVar ((alphaTy `mkFunTy` tTy) `mkFunTy` (alphaListTy `mkFunTy` alphaListTy)))
                         return (usingExpr', Just byExpr')
             
@@ -420,7 +430,7 @@ tcLcStmt m_tc ctxt (GroupStmt (stmts, bindersMap) groupByClause) elt_ty thing_in
                             tcPolyExpr usingExpr (mkForAllTy alphaTyVar (alphaListTy `mkFunTy` alphaListListTy)) >>= (return . GroupByNothing)
                         GroupBySomething eitherUsingExpr byExpr -> do
                             -- We must infer a type such that byExpr :: t
-                            (byExpr', tTy) <- tcInferRho byExpr
+                            (byExpr', tTy) <- tcInferRhoNC byExpr
                             
                             -- If it exists, we then check that usingExpr :: forall a. (a -> t) -> [a] -> [[a]]
                             let expectedUsingType = mkForAllTy alphaTyVar ((alphaTy `mkFunTy` tTy) `mkFunTy` (alphaListTy `mkFunTy` alphaListListTy))
@@ -455,25 +465,23 @@ tcLcStmt _ _ stmt _ _
 
 tcDoStmt :: TcStmtChecker
 
-tcDoStmt _ (BindStmt pat rhs bind_op fail_op) res_ty thing_inside
-  = do { (rhs', rhs_ty) <- tcInferRho rhs
-               -- We should use type *inference* for the RHS computations, 
-                -- becuase of GADTs. 
-               --      do { pat <- rhs; <rest> }
-               -- is rather like
-               --      case rhs of { pat -> <rest> }
-               -- We do inference on rhs, so that information about its type 
-                -- can be refined when type-checking the pattern. 
+tcDoStmt ctxt (BindStmt pat rhs bind_op fail_op) res_ty thing_inside
+  = do {       -- Deal with rebindable syntax:
+               --       (>>=) :: rhs_ty -> (pat_ty -> new_res_ty) -> res_ty
+               -- This level of generality is needed for using do-notation
+               -- in full generality; see Trac #1537
+
+               -- I'd like to put this *after* the tcSyntaxOp 
+                -- (see Note [Treat rebindable syntax first], but that breaks 
+               -- the rigidity info for GADTs.  When we move to the new story
+                -- for GADTs, we can move this after tcSyntaxOp
+          (rhs', rhs_ty) <- tcInferRhoNC rhs
 
-       -- Deal with rebindable syntax:
-       --       (>>=) :: rhs_ty -> (pat_ty -> new_res_ty) -> res_ty
-       -- This level of generality is needed for using do-notation
-       -- in full generality; see Trac #1537
        ; ((bind_op', new_res_ty), pat_ty) <- 
             withBox liftedTypeKind $ \ pat_ty ->
             withBox liftedTypeKind $ \ new_res_ty ->
             tcSyntaxOp DoOrigin bind_op 
-                       (mkFunTys [rhs_ty, mkFunTy pat_ty new_res_ty] res_ty)
+                            (mkFunTys [rhs_ty, mkFunTy pat_ty new_res_ty] res_ty)
 
                -- If (but only if) the pattern can fail, 
                -- typecheck the 'fail' operator
@@ -481,31 +489,94 @@ tcDoStmt _ (BindStmt pat rhs bind_op fail_op) res_ty thing_inside
                      then return noSyntaxExpr
                      else tcSyntaxOp DoOrigin fail_op (mkFunTy stringTy new_res_ty)
 
-       ; (pat', thing) <- tcLamPat pat pat_ty new_res_ty thing_inside
+               -- We should typecheck the RHS *before* the pattern,
+                -- because of GADTs. 
+               --      do { pat <- rhs; <rest> }
+               -- is rather like
+               --      case rhs of { pat -> <rest> }
+               -- We do inference on rhs, so that information about its type 
+                -- can be refined when type-checking the pattern. 
+
+       ; (pat', thing) <- tcPat (StmtCtxt ctxt) pat pat_ty new_res_ty thing_inside
 
        ; return (BindStmt pat' rhs' bind_op' fail_op', thing) }
 
 
 tcDoStmt _ (ExprStmt rhs then_op _) res_ty thing_inside
-  = do { (rhs', rhs_ty) <- tcInferRho rhs
-
-       -- Deal with rebindable syntax; (>>) :: rhs_ty -> new_res_ty -> res_ty
-       ; (then_op', new_res_ty) <-
+  = do {       -- Deal with rebindable syntax; 
+                --   (>>) :: rhs_ty -> new_res_ty -> res_ty
+               -- See also Note [Treat rebindable syntax first]
+         ((then_op', rhs_ty), new_res_ty) <-
                withBox liftedTypeKind $ \ new_res_ty ->
+               withBox liftedTypeKind $ \ rhs_ty ->
                tcSyntaxOp DoOrigin then_op 
                           (mkFunTys [rhs_ty, new_res_ty] res_ty)
 
+        ; rhs' <- tcMonoExprNC rhs rhs_ty
        ; thing <- thing_inside new_res_ty
        ; return (ExprStmt rhs' then_op' rhs_ty, thing) }
 
-tcDoStmt ctxt (RecStmt {}) _ _
-  = failWithTc (ptext (sLit "Illegal 'rec' stmt in") <+> pprStmtContext ctxt)
-       -- This case can't be caught in the renamer
-       -- see RnExpr.checkRecStmt
+tcDoStmt ctxt (RecStmt { recS_stmts = stmts, recS_later_ids = later_names
+                       , recS_rec_ids = rec_names, recS_ret_fn = ret_op
+                       , recS_mfix_fn = mfix_op, recS_bind_fn = bind_op }) 
+         res_ty thing_inside
+  = do  { let tup_names = rec_names ++ filterOut (`elem` rec_names) later_names
+        ; tup_elt_tys <- newFlexiTyVarTys (length tup_names) liftedTypeKind
+        ; let tup_ids = zipWith mkLocalId tup_names tup_elt_tys
+             tup_ty  = mkBoxedTupleTy tup_elt_tys
+
+        ; tcExtendIdEnv tup_ids $ do
+        { ((stmts', (ret_op', tup_rets)), stmts_ty)
+                <- withBox liftedTypeKind $ \ stmts_ty ->
+                   tcStmts ctxt tcDoStmt stmts stmts_ty   $ \ inner_res_ty ->
+                   do { tup_rets <- zipWithM tc_ret tup_names tup_elt_tys
+                     ; ret_op' <- tcSyntaxOp DoOrigin ret_op (mkFunTy tup_ty inner_res_ty)
+                      ; return (ret_op', tup_rets) }
+
+       ; (mfix_op', mfix_res_ty) <- withBox liftedTypeKind $ \ mfix_res_ty ->
+                                     tcSyntaxOp DoOrigin mfix_op
+                                        (mkFunTy (mkFunTy tup_ty stmts_ty) mfix_res_ty)
+
+       ; (bind_op', new_res_ty) <- withBox liftedTypeKind $ \ new_res_ty ->
+                                   tcSyntaxOp DoOrigin bind_op 
+                                       (mkFunTys [mfix_res_ty, mkFunTy tup_ty new_res_ty] res_ty)
+
+        ; (thing,lie) <- getLIE (thing_inside new_res_ty)
+        ; lie_binds <- bindInstsOfLocalFuns lie tup_ids
+  
+        ; let rec_ids = takeList rec_names tup_ids
+       ; later_ids <- tcLookupLocalIds later_names
+       ; traceTc (text "tcdo" <+> vcat [ppr rec_ids <+> ppr (map idType rec_ids),
+                                         ppr later_ids <+> ppr (map idType later_ids)])
+        ; return (RecStmt { recS_stmts = stmts', recS_later_ids = later_ids
+                          , recS_rec_ids = rec_ids, recS_ret_fn = ret_op' 
+                          , recS_mfix_fn = mfix_op', recS_bind_fn = bind_op'
+                          , recS_rec_rets = tup_rets, recS_dicts = lie_binds }, thing)
+        }}
+  where 
+    -- Unify the types of the "final" Ids with those of "knot-tied" Ids
+    tc_ret rec_name mono_ty
+        = do { poly_id <- tcLookupId rec_name
+                -- poly_id may have a polymorphic type
+                -- but mono_ty is just a monomorphic type variable
+             ; co_fn <- tcSubExp DoOrigin (idType poly_id) mono_ty
+             ; return (mkHsWrap co_fn (HsVar poly_id)) }
 
 tcDoStmt _ stmt _ _
   = pprPanic "tcDoStmt: unexpected Stmt" (ppr stmt)
+\end{code}
+
+Note [Treat rebindable syntax first]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+When typechecking
+       do { bar; ... } :: IO ()
+we want to typecheck 'bar' in the knowledge that it should be an IO thing,
+pushing info from the context into the RHS.  To do this, we check the
+rebindable syntax first, and push that information into (tcMonoExprNC rhs).
+Otherwise the error shows up when cheking the rebindable syntax, and
+the expected/inferred stuff is back to front (see Trac #3613).
 
+\begin{code}
 --------------------------------
 --     Mdo-notation
 -- The distinctive features here are
@@ -514,9 +585,9 @@ tcDoStmt _ stmt _ _
 
 tcMDoStmt :: (LHsExpr Name -> TcM (LHsExpr TcId, TcType))      -- RHS inference
          -> TcStmtChecker
-tcMDoStmt tc_rhs _ (BindStmt pat rhs _ _) res_ty thing_inside
+tcMDoStmt tc_rhs ctxt (BindStmt pat rhs _ _) res_ty thing_inside
   = do { (rhs', pat_ty) <- tc_rhs rhs
-       ; (pat', thing)  <- tcLamPat pat pat_ty res_ty thing_inside
+       ; (pat', thing)  <- tcPat (StmtCtxt ctxt) pat pat_ty res_ty thing_inside
        ; return (BindStmt pat' rhs' noSyntaxExpr noSyntaxExpr, thing) }
 
 tcMDoStmt tc_rhs _ (ExprStmt rhs _ _) res_ty thing_inside
@@ -524,7 +595,7 @@ tcMDoStmt tc_rhs _ (ExprStmt rhs _ _) res_ty thing_inside
        ; thing          <- thing_inside res_ty
        ; return (ExprStmt rhs' noSyntaxExpr elt_ty, thing) }
 
-tcMDoStmt tc_rhs ctxt (RecStmt stmts laterNames recNames _ _) res_ty thing_inside
+tcMDoStmt tc_rhs ctxt (RecStmt stmts laterNames recNames _ _ _ _ _) res_ty thing_inside
   = do { rec_tys <- newFlexiTyVarTys (length recNames) liftedTypeKind
        ; let rec_ids = zipWith mkLocalId recNames rec_tys
        ; tcExtendIdEnv rec_ids                 $ do
@@ -542,7 +613,7 @@ tcMDoStmt tc_rhs ctxt (RecStmt stmts laterNames recNames _ _) res_ty thing_insid
                --      (see note [RecStmt] in HsExpr)
        ; lie_binds <- bindInstsOfLocalFuns lie later_ids
   
-       ; return (RecStmt stmts' later_ids rec_ids rec_rets lie_binds, thing)
+       ; return (RecStmt stmts' later_ids rec_ids noSyntaxExpr noSyntaxExpr noSyntaxExpr rec_rets lie_binds, thing)
        }}
   where 
     -- Unify the types of the "final" Ids with those of "knot-tied" Ids
@@ -586,12 +657,3 @@ checkArgs fun (MatchGroup (match1:matches) _)
 checkArgs _ _ = panic "TcPat.checkArgs" -- Matches always non-empty
 \end{code}
 
-\begin{code}
-matchCtxt :: HsMatchContext Name -> Match Name -> SDoc
-matchCtxt ctxt match  = hang (ptext (sLit "In") <+> pprMatchContext ctxt <> colon) 
-                          4 (pprMatch ctxt match)
-
-stmtCtxt :: HsStmtContext Name -> StmtLR Name Name -> SDoc
-stmtCtxt ctxt stmt = hang (ptext (sLit "In a stmt of") <+> pprStmtContext ctxt <> colon)
-                       4 (ppr stmt)
-\end{code}