Minor refactoring
[ghc-hetmet.git] / compiler / typecheck / TcMatches.lhs
index 4748901..3457f32 100644 (file)
@@ -12,7 +12,8 @@ module TcMatches ( tcMatchesFun, tcGRHSsPat, tcMatchesCase, tcMatchLambda,
                   tcDoStmt, tcMDoStmt, tcGuardStmt
        ) where
 
-import {-# SOURCE #-}  TcExpr( tcSyntaxOp, tcInferRhoNC, tcMonoExpr, tcPolyExpr )
+import {-# SOURCE #-}  TcExpr( tcSyntaxOp, tcInferRhoNC, 
+                                tcMonoExpr, tcMonoExprNC, tcPolyExpr )
 
 import HsSyn
 import TcRnMonad
@@ -110,7 +111,8 @@ tcMatchLambda match res_ty
   where
     n_pats = matchGroupArity match
     doc = sep [ ptext (sLit "The lambda expression")
-                <+> quotes (pprSetDepth 1 $ pprMatches (LambdaExpr :: HsMatchContext Name) match),
+                <+> quotes (pprSetDepth (PartWay 1) $ 
+                             pprMatches (LambdaExpr :: HsMatchContext Name) match),
                        -- The pprSetDepth makes the abstraction print briefly
                ptext (sLit "has") <+> speakNOf n_pats (ptext (sLit "argument"))]
     match_ctxt = MC { mc_what = LambdaExpr,
@@ -166,7 +168,7 @@ tcMatch ctxt pat_tys rhs_ty match
   where
     tc_match ctxt pat_tys rhs_ty match@(Match pats maybe_rhs_sig grhss)
       = add_match_ctxt match $
-        do { (pats', grhss') <- tcLamPats pats pat_tys rhs_ty $
+        do { (pats', grhss') <- tcPats (mc_what ctxt) pats pat_tys rhs_ty $
                                tc_grhss ctxt maybe_rhs_sig grhss
           ; return (Match pats' Nothing grhss') }
 
@@ -326,9 +328,9 @@ tcGuardStmt _ (ExprStmt guard _ _) res_ty thing_inside
        ; thing  <- thing_inside res_ty
        ; return (ExprStmt guard' noSyntaxExpr boolTy, thing) }
 
-tcGuardStmt _ (BindStmt pat rhs _ _) res_ty thing_inside
+tcGuardStmt ctxt (BindStmt pat rhs _ _) res_ty thing_inside
   = do { (rhs', rhs_ty) <- tcInferRhoNC rhs    -- Stmt has a context already
-       ; (pat', thing)  <- tcLamPat pat rhs_ty res_ty thing_inside
+       ; (pat', thing)  <- tcPat (StmtCtxt ctxt) pat rhs_ty res_ty thing_inside
        ; return (BindStmt pat' rhs' noSyntaxExpr noSyntaxExpr, thing) }
 
 tcGuardStmt _ stmt _ _
@@ -342,10 +344,10 @@ tcLcStmt :: TyCon -- The list/Parray type constructor ([] or PArray)
         -> TcStmtChecker
 
 -- A generator, pat <- rhs
-tcLcStmt m_tc _ (BindStmt pat rhs _ _) res_ty thing_inside
+tcLcStmt m_tc ctxt (BindStmt pat rhs _ _) res_ty thing_inside
  = do  { (rhs', pat_ty) <- withBox liftedTypeKind $ \ ty ->
                            tcMonoExpr rhs (mkTyConApp m_tc [ty])
-       ; (pat', thing)  <- tcLamPat pat pat_ty res_ty thing_inside
+       ; (pat', thing)  <- tcPat (StmtCtxt ctxt) pat pat_ty res_ty thing_inside
        ; return (BindStmt pat' rhs' noSyntaxExpr noSyntaxExpr, thing) }
 
 -- A boolean guard
@@ -463,25 +465,23 @@ tcLcStmt _ _ stmt _ _
 
 tcDoStmt :: TcStmtChecker
 
-tcDoStmt _ (BindStmt pat rhs bind_op fail_op) res_ty thing_inside
-  = do { (rhs', rhs_ty) <- tcInferRhoNC rhs
-               -- We should use type *inference* for the RHS computations, 
-                -- becuase of GADTs. 
-               --      do { pat <- rhs; <rest> }
-               -- is rather like
-               --      case rhs of { pat -> <rest> }
-               -- We do inference on rhs, so that information about its type 
-                -- can be refined when type-checking the pattern. 
+tcDoStmt ctxt (BindStmt pat rhs bind_op fail_op) res_ty thing_inside
+  = do {       -- Deal with rebindable syntax:
+               --       (>>=) :: rhs_ty -> (pat_ty -> new_res_ty) -> res_ty
+               -- This level of generality is needed for using do-notation
+               -- in full generality; see Trac #1537
+
+               -- I'd like to put this *after* the tcSyntaxOp 
+                -- (see Note [Treat rebindable syntax first], but that breaks 
+               -- the rigidity info for GADTs.  When we move to the new story
+                -- for GADTs, we can move this after tcSyntaxOp
+          (rhs', rhs_ty) <- tcInferRhoNC rhs
 
-       -- Deal with rebindable syntax:
-       --       (>>=) :: rhs_ty -> (pat_ty -> new_res_ty) -> res_ty
-       -- This level of generality is needed for using do-notation
-       -- in full generality; see Trac #1537
        ; ((bind_op', new_res_ty), pat_ty) <- 
             withBox liftedTypeKind $ \ pat_ty ->
             withBox liftedTypeKind $ \ new_res_ty ->
             tcSyntaxOp DoOrigin bind_op 
-                       (mkFunTys [rhs_ty, mkFunTy pat_ty new_res_ty] res_ty)
+                            (mkFunTys [rhs_ty, mkFunTy pat_ty new_res_ty] res_ty)
 
                -- If (but only if) the pattern can fail, 
                -- typecheck the 'fail' operator
@@ -489,31 +489,94 @@ tcDoStmt _ (BindStmt pat rhs bind_op fail_op) res_ty thing_inside
                      then return noSyntaxExpr
                      else tcSyntaxOp DoOrigin fail_op (mkFunTy stringTy new_res_ty)
 
-       ; (pat', thing) <- tcLamPat pat pat_ty new_res_ty thing_inside
+               -- We should typecheck the RHS *before* the pattern,
+                -- because of GADTs. 
+               --      do { pat <- rhs; <rest> }
+               -- is rather like
+               --      case rhs of { pat -> <rest> }
+               -- We do inference on rhs, so that information about its type 
+                -- can be refined when type-checking the pattern. 
+
+       ; (pat', thing) <- tcPat (StmtCtxt ctxt) pat pat_ty new_res_ty thing_inside
 
        ; return (BindStmt pat' rhs' bind_op' fail_op', thing) }
 
 
 tcDoStmt _ (ExprStmt rhs then_op _) res_ty thing_inside
-  = do { (rhs', rhs_ty) <- tcInferRhoNC rhs
-
-       -- Deal with rebindable syntax; (>>) :: rhs_ty -> new_res_ty -> res_ty
-       ; (then_op', new_res_ty) <-
+  = do {       -- Deal with rebindable syntax; 
+                --   (>>) :: rhs_ty -> new_res_ty -> res_ty
+               -- See also Note [Treat rebindable syntax first]
+         ((then_op', rhs_ty), new_res_ty) <-
                withBox liftedTypeKind $ \ new_res_ty ->
+               withBox liftedTypeKind $ \ rhs_ty ->
                tcSyntaxOp DoOrigin then_op 
                           (mkFunTys [rhs_ty, new_res_ty] res_ty)
 
+        ; rhs' <- tcMonoExprNC rhs rhs_ty
        ; thing <- thing_inside new_res_ty
        ; return (ExprStmt rhs' then_op' rhs_ty, thing) }
 
-tcDoStmt ctxt (RecStmt {}) _ _
-  = failWithTc (ptext (sLit "Illegal 'rec' stmt in") <+> pprStmtContext ctxt)
-       -- This case can't be caught in the renamer
-       -- see RnExpr.checkRecStmt
+tcDoStmt ctxt (RecStmt { recS_stmts = stmts, recS_later_ids = later_names
+                       , recS_rec_ids = rec_names, recS_ret_fn = ret_op
+                       , recS_mfix_fn = mfix_op, recS_bind_fn = bind_op }) 
+         res_ty thing_inside
+  = do  { let tup_names = rec_names ++ filterOut (`elem` rec_names) later_names
+        ; tup_elt_tys <- newFlexiTyVarTys (length tup_names) liftedTypeKind
+        ; let tup_ids = zipWith mkLocalId tup_names tup_elt_tys
+             tup_ty  = mkBoxedTupleTy tup_elt_tys
+
+        ; tcExtendIdEnv tup_ids $ do
+        { ((stmts', (ret_op', tup_rets)), stmts_ty)
+                <- withBox liftedTypeKind $ \ stmts_ty ->
+                   tcStmts ctxt tcDoStmt stmts stmts_ty   $ \ inner_res_ty ->
+                   do { tup_rets <- zipWithM tc_ret tup_names tup_elt_tys
+                     ; ret_op' <- tcSyntaxOp DoOrigin ret_op (mkFunTy tup_ty inner_res_ty)
+                      ; return (ret_op', tup_rets) }
+
+       ; (mfix_op', mfix_res_ty) <- withBox liftedTypeKind $ \ mfix_res_ty ->
+                                     tcSyntaxOp DoOrigin mfix_op
+                                        (mkFunTy (mkFunTy tup_ty stmts_ty) mfix_res_ty)
+
+       ; (bind_op', new_res_ty) <- withBox liftedTypeKind $ \ new_res_ty ->
+                                   tcSyntaxOp DoOrigin bind_op 
+                                       (mkFunTys [mfix_res_ty, mkFunTy tup_ty new_res_ty] res_ty)
+
+        ; (thing,lie) <- getLIE (thing_inside new_res_ty)
+        ; lie_binds <- bindInstsOfLocalFuns lie tup_ids
+  
+        ; let rec_ids = takeList rec_names tup_ids
+       ; later_ids <- tcLookupLocalIds later_names
+       ; traceTc (text "tcdo" <+> vcat [ppr rec_ids <+> ppr (map idType rec_ids),
+                                         ppr later_ids <+> ppr (map idType later_ids)])
+        ; return (RecStmt { recS_stmts = stmts', recS_later_ids = later_ids
+                          , recS_rec_ids = rec_ids, recS_ret_fn = ret_op' 
+                          , recS_mfix_fn = mfix_op', recS_bind_fn = bind_op'
+                          , recS_rec_rets = tup_rets, recS_dicts = lie_binds }, thing)
+        }}
+  where 
+    -- Unify the types of the "final" Ids with those of "knot-tied" Ids
+    tc_ret rec_name mono_ty
+        = do { poly_id <- tcLookupId rec_name
+                -- poly_id may have a polymorphic type
+                -- but mono_ty is just a monomorphic type variable
+             ; co_fn <- tcSubExp DoOrigin (idType poly_id) mono_ty
+             ; return (mkHsWrap co_fn (HsVar poly_id)) }
 
 tcDoStmt _ stmt _ _
   = pprPanic "tcDoStmt: unexpected Stmt" (ppr stmt)
+\end{code}
 
+Note [Treat rebindable syntax first]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+When typechecking
+       do { bar; ... } :: IO ()
+we want to typecheck 'bar' in the knowledge that it should be an IO thing,
+pushing info from the context into the RHS.  To do this, we check the
+rebindable syntax first, and push that information into (tcMonoExprNC rhs).
+Otherwise the error shows up when cheking the rebindable syntax, and
+the expected/inferred stuff is back to front (see Trac #3613).
+
+\begin{code}
 --------------------------------
 --     Mdo-notation
 -- The distinctive features here are
@@ -522,9 +585,9 @@ tcDoStmt _ stmt _ _
 
 tcMDoStmt :: (LHsExpr Name -> TcM (LHsExpr TcId, TcType))      -- RHS inference
          -> TcStmtChecker
-tcMDoStmt tc_rhs _ (BindStmt pat rhs _ _) res_ty thing_inside
+tcMDoStmt tc_rhs ctxt (BindStmt pat rhs _ _) res_ty thing_inside
   = do { (rhs', pat_ty) <- tc_rhs rhs
-       ; (pat', thing)  <- tcLamPat pat pat_ty res_ty thing_inside
+       ; (pat', thing)  <- tcPat (StmtCtxt ctxt) pat pat_ty res_ty thing_inside
        ; return (BindStmt pat' rhs' noSyntaxExpr noSyntaxExpr, thing) }
 
 tcMDoStmt tc_rhs _ (ExprStmt rhs _ _) res_ty thing_inside
@@ -532,7 +595,7 @@ tcMDoStmt tc_rhs _ (ExprStmt rhs _ _) res_ty thing_inside
        ; thing          <- thing_inside res_ty
        ; return (ExprStmt rhs' noSyntaxExpr elt_ty, thing) }
 
-tcMDoStmt tc_rhs ctxt (RecStmt stmts laterNames recNames _ _) res_ty thing_inside
+tcMDoStmt tc_rhs ctxt (RecStmt stmts laterNames recNames _ _ _ _ _) res_ty thing_inside
   = do { rec_tys <- newFlexiTyVarTys (length recNames) liftedTypeKind
        ; let rec_ids = zipWith mkLocalId recNames rec_tys
        ; tcExtendIdEnv rec_ids                 $ do
@@ -550,7 +613,7 @@ tcMDoStmt tc_rhs ctxt (RecStmt stmts laterNames recNames _ _) res_ty thing_insid
                --      (see note [RecStmt] in HsExpr)
        ; lie_binds <- bindInstsOfLocalFuns lie later_ids
   
-       ; return (RecStmt stmts' later_ids rec_ids rec_rets lie_binds, thing)
+       ; return (RecStmt stmts' later_ids rec_ids noSyntaxExpr noSyntaxExpr noSyntaxExpr rec_rets lie_binds, thing)
        }}
   where 
     -- Unify the types of the "final" Ids with those of "knot-tied" Ids