Add 'rec' to stmts in a 'do', and deprecate 'mdo'
[ghc-hetmet.git] / compiler / typecheck / TcMatches.lhs
index 37fbd19..37b8cbe 100644 (file)
@@ -12,7 +12,8 @@ module TcMatches ( tcMatchesFun, tcGRHSsPat, tcMatchesCase, tcMatchLambda,
                   tcDoStmt, tcMDoStmt, tcGuardStmt
        ) where
 
-import {-# SOURCE #-}  TcExpr( tcSyntaxOp, tcInferRho, tcMonoExpr, tcPolyExpr )
+import {-# SOURCE #-}  TcExpr( tcSyntaxOp, tcInferRhoNC, 
+                                tcMonoExpr, tcMonoExprNC, tcPolyExpr )
 
 import HsSyn
 import TcRnMonad
@@ -24,6 +25,7 @@ import TcType
 import TcBinds
 import TcUnify
 import TcSimplify
+import MkCore
 import Name
 import TysWiredIn
 import PrelNames
@@ -73,7 +75,7 @@ tcMatchesFun fun_name inf matches exp_ty
                -- This is one of two places places we call subFunTys
                -- The point is that if expected_y is a "hole", we want 
                -- to make pat_tys and rhs_ty as "holes" too.
-       ; subFunTys doc n_pats exp_ty     $ \ pat_tys rhs_ty -> 
+       ; subFunTys doc n_pats exp_ty (Just (FunSigCtxt fun_name)) $ \ pat_tys rhs_ty -> 
          tcMatches match_ctxt pat_tys rhs_ty matches
        }
   where
@@ -105,12 +107,13 @@ tcMatchesCase ctxt scrut_ty matches res_ty
 
 tcMatchLambda :: MatchGroup Name -> BoxyRhoType -> TcM (HsWrapper, MatchGroup TcId)
 tcMatchLambda match res_ty 
-  = subFunTys doc n_pats res_ty        $ \ pat_tys rhs_ty ->
+  = subFunTys doc n_pats res_ty Nothing        $ \ pat_tys rhs_ty ->
     tcMatches match_ctxt pat_tys rhs_ty match
   where
     n_pats = matchGroupArity match
     doc = sep [ ptext (sLit "The lambda expression")
-                <+> quotes (pprSetDepth 1 $ pprMatches (LambdaExpr :: HsMatchContext Name) match),
+                <+> quotes (pprSetDepth (PartWay 1) $ 
+                             pprMatches (LambdaExpr :: HsMatchContext Name) match),
                        -- The pprSetDepth makes the abstraction print briefly
                ptext (sLit "has") <+> speakNOf n_pats (ptext (sLit "argument"))]
     match_ctxt = MC { mc_what = LambdaExpr,
@@ -166,7 +169,7 @@ tcMatch ctxt pat_tys rhs_ty match
   where
     tc_match ctxt pat_tys rhs_ty match@(Match pats maybe_rhs_sig grhss)
       = add_match_ctxt match $
-        do { (pats', grhss') <- tcLamPats pats pat_tys rhs_ty $
+        do { (pats', grhss') <- tcPats (mc_what ctxt) pats pat_tys rhs_ty $
                                tc_grhss ctxt maybe_rhs_sig grhss
           ; return (Match pats' Nothing grhss') }
 
@@ -267,7 +270,7 @@ tcDoStmts ctxt _ _ _ = pprPanic "tcDoStmts" (pprStmtContext ctxt)
 tcBody :: LHsExpr Name -> BoxyRhoType -> TcM (LHsExpr TcId)
 tcBody body res_ty
   = do { traceTc (text "tcBody" <+> ppr res_ty)
-       ; body' <- tcPolyExpr body res_ty
+       ; body' <- tcMonoExpr body res_ty
        ; return body' 
         } 
 \end{code}
@@ -326,9 +329,9 @@ tcGuardStmt _ (ExprStmt guard _ _) res_ty thing_inside
        ; thing  <- thing_inside res_ty
        ; return (ExprStmt guard' noSyntaxExpr boolTy, thing) }
 
-tcGuardStmt _ (BindStmt pat rhs _ _) res_ty thing_inside
-  = do { (rhs', rhs_ty) <- tcInferRho rhs
-       ; (pat', thing)  <- tcLamPat pat rhs_ty res_ty thing_inside
+tcGuardStmt ctxt (BindStmt pat rhs _ _) res_ty thing_inside
+  = do { (rhs', rhs_ty) <- tcInferRhoNC rhs    -- Stmt has a context already
+       ; (pat', thing)  <- tcPat (StmtCtxt ctxt) pat rhs_ty res_ty thing_inside
        ; return (BindStmt pat' rhs' noSyntaxExpr noSyntaxExpr, thing) }
 
 tcGuardStmt _ stmt _ _
@@ -342,10 +345,10 @@ tcLcStmt :: TyCon -- The list/Parray type constructor ([] or PArray)
         -> TcStmtChecker
 
 -- A generator, pat <- rhs
-tcLcStmt m_tc _ (BindStmt pat rhs _ _) res_ty thing_inside
+tcLcStmt m_tc ctxt (BindStmt pat rhs _ _) res_ty thing_inside
  = do  { (rhs', pat_ty) <- withBox liftedTypeKind $ \ ty ->
                            tcMonoExpr rhs (mkTyConApp m_tc [ty])
-       ; (pat', thing)  <- tcLamPat pat pat_ty res_ty thing_inside
+       ; (pat', thing)  <- tcPat (StmtCtxt ctxt) pat pat_ty res_ty thing_inside
        ; return (BindStmt pat' rhs' noSyntaxExpr noSyntaxExpr, thing) }
 
 -- A boolean guard
@@ -404,7 +407,7 @@ tcLcStmt m_tc ctxt (TransformStmt (stmts, binders) usingExpr maybeByExpr) elt_ty
                         return (usingExpr', Nothing)
                     Just byExpr -> do
                         -- We must infer a type such that e :: t and then check that usingExpr :: forall a. (a -> t) -> [a] -> [a]
-                        (byExpr', tTy) <- tcInferRho byExpr
+                        (byExpr', tTy) <- tcInferRhoNC byExpr
                         usingExpr' <- tcPolyExpr usingExpr (mkForAllTy alphaTyVar ((alphaTy `mkFunTy` tTy) `mkFunTy` (alphaListTy `mkFunTy` alphaListTy)))
                         return (usingExpr', Just byExpr')
             
@@ -428,7 +431,7 @@ tcLcStmt m_tc ctxt (GroupStmt (stmts, bindersMap) groupByClause) elt_ty thing_in
                             tcPolyExpr usingExpr (mkForAllTy alphaTyVar (alphaListTy `mkFunTy` alphaListListTy)) >>= (return . GroupByNothing)
                         GroupBySomething eitherUsingExpr byExpr -> do
                             -- We must infer a type such that byExpr :: t
-                            (byExpr', tTy) <- tcInferRho byExpr
+                            (byExpr', tTy) <- tcInferRhoNC byExpr
                             
                             -- If it exists, we then check that usingExpr :: forall a. (a -> t) -> [a] -> [[a]]
                             let expectedUsingType = mkForAllTy alphaTyVar ((alphaTy `mkFunTy` tTy) `mkFunTy` (alphaListTy `mkFunTy` alphaListListTy))
@@ -463,25 +466,23 @@ tcLcStmt _ _ stmt _ _
 
 tcDoStmt :: TcStmtChecker
 
-tcDoStmt _ (BindStmt pat rhs bind_op fail_op) res_ty thing_inside
-  = do { (rhs', rhs_ty) <- tcInferRho rhs
-               -- We should use type *inference* for the RHS computations, 
-                -- becuase of GADTs. 
-               --      do { pat <- rhs; <rest> }
-               -- is rather like
-               --      case rhs of { pat -> <rest> }
-               -- We do inference on rhs, so that information about its type 
-                -- can be refined when type-checking the pattern. 
+tcDoStmt ctxt (BindStmt pat rhs bind_op fail_op) res_ty thing_inside
+  = do {       -- Deal with rebindable syntax:
+               --       (>>=) :: rhs_ty -> (pat_ty -> new_res_ty) -> res_ty
+               -- This level of generality is needed for using do-notation
+               -- in full generality; see Trac #1537
+
+               -- I'd like to put this *after* the tcSyntaxOp 
+                -- (see Note [Treat rebindable syntax first], but that breaks 
+               -- the rigidity info for GADTs.  When we move to the new story
+                -- for GADTs, we can move this after tcSyntaxOp
+          (rhs', rhs_ty) <- tcInferRhoNC rhs
 
-       -- Deal with rebindable syntax:
-       --       (>>=) :: rhs_ty -> (pat_ty -> new_res_ty) -> res_ty
-       -- This level of generality is needed for using do-notation
-       -- in full generality; see Trac #1537
        ; ((bind_op', new_res_ty), pat_ty) <- 
             withBox liftedTypeKind $ \ pat_ty ->
             withBox liftedTypeKind $ \ new_res_ty ->
             tcSyntaxOp DoOrigin bind_op 
-                       (mkFunTys [rhs_ty, mkFunTy pat_ty new_res_ty] res_ty)
+                            (mkFunTys [rhs_ty, mkFunTy pat_ty new_res_ty] res_ty)
 
                -- If (but only if) the pattern can fail, 
                -- typecheck the 'fail' operator
@@ -489,31 +490,94 @@ tcDoStmt _ (BindStmt pat rhs bind_op fail_op) res_ty thing_inside
                      then return noSyntaxExpr
                      else tcSyntaxOp DoOrigin fail_op (mkFunTy stringTy new_res_ty)
 
-       ; (pat', thing) <- tcLamPat pat pat_ty new_res_ty thing_inside
+               -- We should typecheck the RHS *before* the pattern,
+                -- because of GADTs. 
+               --      do { pat <- rhs; <rest> }
+               -- is rather like
+               --      case rhs of { pat -> <rest> }
+               -- We do inference on rhs, so that information about its type 
+                -- can be refined when type-checking the pattern. 
+
+       ; (pat', thing) <- tcPat (StmtCtxt ctxt) pat pat_ty new_res_ty thing_inside
 
        ; return (BindStmt pat' rhs' bind_op' fail_op', thing) }
 
 
 tcDoStmt _ (ExprStmt rhs then_op _) res_ty thing_inside
-  = do { (rhs', rhs_ty) <- tcInferRho rhs
-
-       -- Deal with rebindable syntax; (>>) :: rhs_ty -> new_res_ty -> res_ty
-       ; (then_op', new_res_ty) <-
+  = do {       -- Deal with rebindable syntax; 
+                --   (>>) :: rhs_ty -> new_res_ty -> res_ty
+               -- See also Note [Treat rebindable syntax first]
+         ((then_op', rhs_ty), new_res_ty) <-
                withBox liftedTypeKind $ \ new_res_ty ->
+               withBox liftedTypeKind $ \ rhs_ty ->
                tcSyntaxOp DoOrigin then_op 
                           (mkFunTys [rhs_ty, new_res_ty] res_ty)
 
+        ; rhs' <- tcMonoExprNC rhs rhs_ty
        ; thing <- thing_inside new_res_ty
        ; return (ExprStmt rhs' then_op' rhs_ty, thing) }
 
-tcDoStmt ctxt (RecStmt {}) _ _
-  = failWithTc (ptext (sLit "Illegal 'rec' stmt in") <+> pprStmtContext ctxt)
-       -- This case can't be caught in the renamer
-       -- see RnExpr.checkRecStmt
+tcDoStmt ctxt (RecStmt { recS_stmts = stmts, recS_later_ids = later_names
+                       , recS_rec_ids = rec_names, recS_ret_fn = ret_op
+                       , recS_mfix_fn = mfix_op, recS_bind_fn = bind_op }) 
+         res_ty thing_inside
+  = do  { let tup_names = rec_names ++ filterOut (`elem` rec_names) later_names
+        ; tup_elt_tys <- newFlexiTyVarTys (length tup_names) liftedTypeKind
+        ; let tup_ids = zipWith mkLocalId tup_names tup_elt_tys
+             tup_ty  = mkCoreTupTy tup_elt_tys
+
+        ; tcExtendIdEnv tup_ids $ do
+        { ((stmts', (ret_op', tup_rets)), stmts_ty)
+                <- withBox liftedTypeKind $ \ stmts_ty ->
+                   tcStmts ctxt tcDoStmt stmts stmts_ty   $ \ inner_res_ty ->
+                   do { tup_rets <- zipWithM tc_ret tup_names tup_elt_tys
+                     ; ret_op' <- tcSyntaxOp DoOrigin ret_op (mkFunTy tup_ty inner_res_ty)
+                      ; return (ret_op', tup_rets) }
+
+       ; (mfix_op', mfix_res_ty) <- withBox liftedTypeKind $ \ mfix_res_ty ->
+                                     tcSyntaxOp DoOrigin mfix_op
+                                        (mkFunTy (mkFunTy tup_ty stmts_ty) mfix_res_ty)
+
+       ; (bind_op', new_res_ty) <- withBox liftedTypeKind $ \ new_res_ty ->
+                                   tcSyntaxOp DoOrigin bind_op 
+                                       (mkFunTys [mfix_res_ty, mkFunTy tup_ty new_res_ty] res_ty)
+
+        ; (thing,lie) <- getLIE (thing_inside new_res_ty)
+        ; lie_binds <- bindInstsOfLocalFuns lie tup_ids
+  
+        ; let rec_ids = takeList rec_names tup_ids
+       ; later_ids <- tcLookupLocalIds later_names
+       ; traceTc (text "tcdo" <+> vcat [ppr rec_ids <+> ppr (map idType rec_ids),
+                                         ppr later_ids <+> ppr (map idType later_ids)])
+        ; return (RecStmt { recS_stmts = stmts', recS_later_ids = later_ids
+                          , recS_rec_ids = rec_ids, recS_ret_fn = ret_op' 
+                          , recS_mfix_fn = mfix_op', recS_bind_fn = bind_op'
+                          , recS_rec_rets = tup_rets, recS_dicts = lie_binds }, thing)
+        }}
+  where 
+    -- Unify the types of the "final" Ids with those of "knot-tied" Ids
+    tc_ret rec_name mono_ty
+        = do { poly_id <- tcLookupId rec_name
+                -- poly_id may have a polymorphic type
+                -- but mono_ty is just a monomorphic type variable
+             ; co_fn <- tcSubExp DoOrigin (idType poly_id) mono_ty
+             ; return (mkHsWrap co_fn (HsVar poly_id)) }
 
 tcDoStmt _ stmt _ _
   = pprPanic "tcDoStmt: unexpected Stmt" (ppr stmt)
+\end{code}
 
+Note [Treat rebindable syntax first]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+When typechecking
+       do { bar; ... } :: IO ()
+we want to typecheck 'bar' in the knowledge that it should be an IO thing,
+pushing info from the context into the RHS.  To do this, we check the
+rebindable syntax first, and push that information into (tcMonoExprNC rhs).
+Otherwise the error shows up when cheking the rebindable syntax, and
+the expected/inferred stuff is back to front (see Trac #3613).
+
+\begin{code}
 --------------------------------
 --     Mdo-notation
 -- The distinctive features here are
@@ -522,9 +586,9 @@ tcDoStmt _ stmt _ _
 
 tcMDoStmt :: (LHsExpr Name -> TcM (LHsExpr TcId, TcType))      -- RHS inference
          -> TcStmtChecker
-tcMDoStmt tc_rhs _ (BindStmt pat rhs _ _) res_ty thing_inside
+tcMDoStmt tc_rhs ctxt (BindStmt pat rhs _ _) res_ty thing_inside
   = do { (rhs', pat_ty) <- tc_rhs rhs
-       ; (pat', thing)  <- tcLamPat pat pat_ty res_ty thing_inside
+       ; (pat', thing)  <- tcPat (StmtCtxt ctxt) pat pat_ty res_ty thing_inside
        ; return (BindStmt pat' rhs' noSyntaxExpr noSyntaxExpr, thing) }
 
 tcMDoStmt tc_rhs _ (ExprStmt rhs _ _) res_ty thing_inside
@@ -532,7 +596,7 @@ tcMDoStmt tc_rhs _ (ExprStmt rhs _ _) res_ty thing_inside
        ; thing          <- thing_inside res_ty
        ; return (ExprStmt rhs' noSyntaxExpr elt_ty, thing) }
 
-tcMDoStmt tc_rhs ctxt (RecStmt stmts laterNames recNames _ _) res_ty thing_inside
+tcMDoStmt tc_rhs ctxt (RecStmt stmts laterNames recNames _ _ _ _ _) res_ty thing_inside
   = do { rec_tys <- newFlexiTyVarTys (length recNames) liftedTypeKind
        ; let rec_ids = zipWith mkLocalId recNames rec_tys
        ; tcExtendIdEnv rec_ids                 $ do
@@ -550,7 +614,7 @@ tcMDoStmt tc_rhs ctxt (RecStmt stmts laterNames recNames _ _) res_ty thing_insid
                --      (see note [RecStmt] in HsExpr)
        ; lie_binds <- bindInstsOfLocalFuns lie later_ids
   
-       ; return (RecStmt stmts' later_ids rec_ids rec_rets lie_binds, thing)
+       ; return (RecStmt stmts' later_ids rec_ids noSyntaxExpr noSyntaxExpr noSyntaxExpr rec_rets lie_binds, thing)
        }}
   where 
     -- Unify the types of the "final" Ids with those of "knot-tied" Ids