Fix Trac #2111: improve error handling for 'rec' in do-notation
[ghc-hetmet.git] / compiler / typecheck / TcMatches.lhs
index bd83a55..452bae7 100644 (file)
@@ -6,6 +6,13 @@
 TcMatches: Typecheck some @Matches@
 
 \begin{code}
+{-# OPTIONS -w #-}
+-- The above warning supression flag is a temporary kludge.
+-- While working on this module you are encouraged to remove it and fix
+-- any warnings in the module. See
+--     http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings
+-- for details
+
 module TcMatches ( tcMatchesFun, tcGRHSsPat, tcMatchesCase, tcMatchLambda,
                   matchCtxt, TcMatchCtxt(..), 
                   tcStmts, tcDoStmts, tcBody,
@@ -32,8 +39,12 @@ import TysWiredIn
 import PrelNames
 import Id
 import TyCon
+import TysPrim
 import Outputable
+import Util
 import SrcLoc
+
+import Control.Monad
 \end{code}
 
 %************************************************************************
@@ -99,7 +110,7 @@ tcMatchLambda match res_ty
   where
     n_pats = matchGroupArity match
     doc = sep [ ptext SLIT("The lambda expression")
-                <+> quotes (pprSetDepth 1 $ pprMatches LambdaExpr match),
+                <+> quotes (pprSetDepth 1 $ pprMatches (LambdaExpr :: HsMatchContext Name) match),
                        -- The pprSetDepth makes the abstraction print briefly
                ptext SLIT("has") <+> speakNOf n_pats (ptext SLIT("argument"))]
     match_ctxt = MC { mc_what = LambdaExpr,
@@ -186,9 +197,9 @@ tcGRHSs :: TcMatchCtxt -> GRHSs Name -> (Refinement, BoxyRhoType)
 
 tcGRHSs ctxt (GRHSs grhss binds) res_ty
   = do { (binds', grhss') <- tcLocalBinds binds $
-                             mappM (wrapLocM (tcGRHS ctxt res_ty)) grhss
+                             mapM (wrapLocM (tcGRHS ctxt res_ty)) grhss
 
-       ; returnM (GRHSs grhss' binds') }
+       ; return (GRHSs grhss' binds') }
 
 -------------
 tcGRHS :: TcMatchCtxt -> (Refinement, BoxyRhoType) -> GRHS Name -> TcM (GRHS TcId)
@@ -215,29 +226,29 @@ tcDoStmts :: HsStmtContext Name
          -> BoxyRhoType
          -> TcM (HsExpr TcId)          -- Returns a HsDo
 tcDoStmts ListComp stmts body res_ty
-  = do { elt_ty <- boxySplitListTy res_ty
+  = do { (elt_ty, coi) <- boxySplitListTy res_ty
        ; (stmts', body') <- tcStmts ListComp (tcLcStmt listTyCon) stmts 
                                     (emptyRefinement,elt_ty) $
                             tcBody body
-       ; return (HsDo ListComp stmts' body' (mkListTy elt_ty)) }
+       ; return $ mkHsWrapCoI coi 
+                     (HsDo ListComp stmts' body' (mkListTy elt_ty)) }
 
 tcDoStmts PArrComp stmts body res_ty
-  = do { [elt_ty] <- boxySplitTyConApp parrTyCon res_ty
+  = do { (elt_ty, coi) <- boxySplitPArrTy res_ty
        ; (stmts', body') <- tcStmts PArrComp (tcLcStmt parrTyCon) stmts 
                                     (emptyRefinement, elt_ty) $
                             tcBody body
-       ; return (HsDo PArrComp stmts' body' (mkPArrTy elt_ty)) }
+       ; return $ mkHsWrapCoI coi 
+                     (HsDo PArrComp stmts' body' (mkPArrTy elt_ty)) }
 
 tcDoStmts DoExpr stmts body res_ty
-  = do { (m_ty, elt_ty) <- boxySplitAppTy res_ty
-       ; let res_ty' = mkAppTy m_ty elt_ty     -- The boxySplit consumes res_ty
-       ; (stmts', body') <- tcStmts DoExpr (tcDoStmt m_ty) stmts 
-                                    (emptyRefinement, res_ty') $
+  = do { (stmts', body') <- tcStmts DoExpr tcDoStmt stmts 
+                                    (emptyRefinement, res_ty) $
                             tcBody body
-       ; return (HsDo DoExpr stmts' body' res_ty') }
+       ; return (HsDo DoExpr stmts' body' res_ty) }
 
 tcDoStmts ctxt@(MDoExpr _) stmts body res_ty
-  = do { (m_ty, elt_ty) <- boxySplitAppTy res_ty
+  = do { ((m_ty, elt_ty), coi) <- boxySplitAppTy res_ty
        ; let res_ty' = mkAppTy m_ty elt_ty     -- The boxySplit consumes res_ty
              tc_rhs rhs = withBox liftedTypeKind $ \ pat_ty ->
                           tcMonoExpr rhs (mkAppTy m_ty pat_ty)
@@ -248,7 +259,9 @@ tcDoStmts ctxt@(MDoExpr _) stmts body res_ty
 
        ; let names = [mfixName, bindMName, thenMName, returnMName, failMName]
        ; insts <- mapM (newMethodFromName DoOrigin m_ty) names
-       ; return (HsDo (MDoExpr (names `zip` insts)) stmts' body' res_ty') }
+       ; return $ 
+            mkHsWrapCoI coi 
+              (HsDo (MDoExpr (names `zip` insts)) stmts' body' res_ty') }
 
 tcDoStmts ctxt stmts body res_ty = pprPanic "tcDoStmts" (pprStmtContext ctxt)
 
@@ -382,19 +395,80 @@ tcLcStmt m_tc ctxt (ParStmt bndr_stmts_s) elt_ty thing_inside
                      ; return (ids, pairs', thing) }
           ; return ( (stmts', ids) : pairs', thing ) }
 
+tcLcStmt m_tc ctxt (TransformStmt (stmts, binders) usingExpr maybeByExpr) elt_ty thing_inside = do
+    (stmts', (binders', usingExpr', maybeByExpr', thing)) <- 
+        tcStmts (TransformStmtCtxt ctxt) (tcLcStmt m_tc) stmts elt_ty $ \elt_ty' -> do
+            let alphaListTy = mkTyConApp m_tc [alphaTy]
+                    
+            (usingExpr', maybeByExpr') <- 
+                case maybeByExpr of
+                    Nothing -> do
+                        -- We must validate that usingExpr :: forall a. [a] -> [a]
+                        usingExpr' <- tcPolyExpr usingExpr (mkForAllTy alphaTyVar (alphaListTy `mkFunTy` alphaListTy))
+                        return (usingExpr', Nothing)
+                    Just byExpr -> do
+                        -- We must infer a type such that e :: t and then check that usingExpr :: forall a. (a -> t) -> [a] -> [a]
+                        (byExpr', tTy) <- tcInferRho byExpr
+                        usingExpr' <- tcPolyExpr usingExpr (mkForAllTy alphaTyVar ((alphaTy `mkFunTy` tTy) `mkFunTy` (alphaListTy `mkFunTy` alphaListTy)))
+                        return (usingExpr', Just byExpr')
+            
+            binders' <- tcLookupLocalIds binders
+            thing <- thing_inside elt_ty'
+            
+            return (binders', usingExpr', maybeByExpr', thing)
+
+    return (TransformStmt (stmts', binders') usingExpr' maybeByExpr', thing)
+
+tcLcStmt m_tc ctxt (GroupStmt (stmts, bindersMap) groupByClause) elt_ty thing_inside = do
+        (stmts', (bindersMap', groupByClause', thing)) <-
+            tcStmts (TransformStmtCtxt ctxt) (tcLcStmt m_tc) stmts elt_ty $ \elt_ty' -> do
+                let alphaListTy = mkTyConApp m_tc [alphaTy]
+                    alphaListListTy = mkTyConApp m_tc [alphaListTy]
+            
+                groupByClause' <- 
+                    case groupByClause of
+                        GroupByNothing usingExpr ->
+                            -- We must validate that usingExpr :: forall a. [a] -> [[a]]
+                            tcPolyExpr usingExpr (mkForAllTy alphaTyVar (alphaListTy `mkFunTy` alphaListListTy)) >>= (return . GroupByNothing)
+                        GroupBySomething eitherUsingExpr byExpr -> do
+                            -- We must infer a type such that byExpr :: t
+                            (byExpr', tTy) <- tcInferRho byExpr
+                            
+                            -- If it exists, we then check that usingExpr :: forall a. (a -> t) -> [a] -> [[a]]
+                            let expectedUsingType = mkForAllTy alphaTyVar ((alphaTy `mkFunTy` tTy) `mkFunTy` (alphaListTy `mkFunTy` alphaListListTy))
+                            eitherUsingExpr' <- 
+                                case eitherUsingExpr of
+                                    Left usingExpr  -> (tcPolyExpr usingExpr expectedUsingType) >>= (return . Left)
+                                    Right usingExpr -> (tcPolyExpr (noLoc usingExpr) expectedUsingType) >>= (return . Right . unLoc)
+                            return $ GroupBySomething eitherUsingExpr' byExpr'
+            
+                -- Find the IDs and types of all old binders
+                let (oldBinders, newBinders) = unzip bindersMap
+                oldBinders' <- tcLookupLocalIds oldBinders
+                
+                -- Ensure that every old binder of type b is linked up with its new binder which should have type [b]
+                let newBinders' = zipWith associateNewBinder oldBinders' newBinders
+            
+                -- Type check the thing in the environment with these new binders and return the result
+                thing <- tcExtendIdEnv newBinders' (thing_inside elt_ty')
+                return (zipEqual "tcLcStmt: Old and new binder lists were not of the same length" oldBinders' newBinders', groupByClause', thing)
+        
+        return (GroupStmt (stmts', bindersMap') groupByClause', thing)
+    where
+        associateNewBinder :: TcId -> Name -> TcId
+        associateNewBinder oldBinder newBinder = mkLocalId newBinder (mkTyConApp m_tc [idType oldBinder])
+    
 tcLcStmt m_tc ctxt stmt elt_ty thing_inside
   = pprPanic "tcLcStmt: unexpected Stmt" (ppr stmt)
-
+        
 --------------------------------
 --     Do-notation
 -- The main excitement here is dealing with rebindable syntax
 
-tcDoStmt :: TcType             -- Monad type,  m
-        -> TcStmtChecker
+tcDoStmt :: TcStmtChecker
 
-tcDoStmt m_ty ctxt (BindStmt pat rhs bind_op fail_op) reft_res_ty@(_,res_ty) thing_inside
-  = do { (rhs', pat_ty) <- withBox liftedTypeKind $ \ pat_ty -> 
-                           tcMonoExpr rhs (mkAppTy m_ty pat_ty)
+tcDoStmt ctxt (BindStmt pat rhs bind_op fail_op) (reft,res_ty) thing_inside
+  = do { (rhs', rhs_ty) <- tcInferRho rhs
                -- We should use type *inference* for the RHS computations, becuase of GADTs. 
                --      do { pat <- rhs; <rest> }
                -- is rather like
@@ -402,31 +476,45 @@ tcDoStmt m_ty ctxt (BindStmt pat rhs bind_op fail_op) reft_res_ty@(_,res_ty) thi
                -- We do inference on rhs, so that information about its type can be refined
                -- when type-checking the pattern. 
 
-       ; (pat', thing) <- tcLamPat pat pat_ty reft_res_ty thing_inside
+       -- Deal with rebindable syntax:
+       --       (>>=) :: rhs_ty -> (pat_ty -> new_res_ty) -> res_ty
+       -- This level of generality is needed for using do-notation
+       -- in full generality; see Trac #1537
+       ; ((bind_op', new_res_ty), pat_ty) <- 
+            withBox liftedTypeKind $ \ pat_ty ->
+            withBox liftedTypeKind $ \ new_res_ty ->
+            tcSyntaxOp DoOrigin bind_op 
+                       (mkFunTys [rhs_ty, mkFunTy pat_ty new_res_ty] res_ty)
 
-       -- Deal with rebindable syntax; (>>=) :: m a -> (a -> m b) -> m b
-       ; let bind_ty = mkFunTys [mkAppTy m_ty pat_ty, 
-                                 mkFunTy pat_ty res_ty] res_ty
-       ; bind_op' <- tcSyntaxOp DoOrigin bind_op bind_ty
                -- If (but only if) the pattern can fail, 
                -- typecheck the 'fail' operator
-       ; fail_op' <- if isIrrefutableHsPat pat' 
+       ; fail_op' <- if isIrrefutableHsPat pat 
                      then return noSyntaxExpr
-                     else tcSyntaxOp DoOrigin fail_op (mkFunTy stringTy res_ty)
+                     else tcSyntaxOp DoOrigin fail_op (mkFunTy stringTy new_res_ty)
+
+       ; (pat', thing) <- tcLamPat pat pat_ty (reft, new_res_ty) thing_inside
+
        ; return (BindStmt pat' rhs' bind_op' fail_op', thing) }
 
 
-tcDoStmt m_ty ctxt (ExprStmt rhs then_op _) reft_res_ty@(_,res_ty) thing_inside
-  = do {       -- Deal with rebindable syntax; (>>) :: m a -> m b -> m b
-         a_ty <- newFlexiTyVarTy liftedTypeKind
-       ; let rhs_ty  = mkAppTy m_ty a_ty
-             then_ty = mkFunTys [rhs_ty, res_ty] res_ty
-       ; then_op' <- tcSyntaxOp DoOrigin then_op then_ty
-       ; rhs' <- tcPolyExpr rhs rhs_ty
-       ; thing <- thing_inside reft_res_ty
+tcDoStmt ctxt (ExprStmt rhs then_op _) (reft,res_ty) thing_inside
+  = do { (rhs', rhs_ty) <- tcInferRho rhs
+
+       -- Deal with rebindable syntax; (>>) :: rhs_ty -> new_res_ty -> res_ty
+       ; (then_op', new_res_ty) <-
+               withBox liftedTypeKind $ \ new_res_ty ->
+               tcSyntaxOp DoOrigin then_op 
+                          (mkFunTys [rhs_ty, new_res_ty] res_ty)
+
+       ; thing <- thing_inside (reft, new_res_ty)
        ; return (ExprStmt rhs' then_op' rhs_ty, thing) }
 
-tcDoStmt m_ty ctxt stmt res_ty thing_inside
+tcDoStmt ctxt (RecStmt {}) res_ty thing_inside
+  = failWithTc (ptext SLIT("Illegal 'rec' stmt in") <+> pprStmtContext ctxt)
+       -- This case can't be caught in the renamer
+       -- see RnExpr.checkRecStmt
+
+tcDoStmt ctxt stmt res_ty thing_inside
   = pprPanic "tcDoStmt: unexpected Stmt" (ppr stmt)
 
 --------------------------------
@@ -473,7 +561,7 @@ tcMDoStmt tc_rhs ctxt (RecStmt stmts laterNames recNames _ _) res_ty thing_insid
        = do { poly_id <- tcLookupId rec_name
                -- poly_id may have a polymorphic type
                -- but mono_ty is just a monomorphic type variable
-            ; co_fn <- tcSubExp (idType poly_id) mono_ty
+            ; co_fn <- tcSubExp DoOrigin (idType poly_id) mono_ty
             ; return (mkHsWrap co_fn (HsVar poly_id)) }
 
 tcMDoStmt tc_rhs ctxt stmt res_ty thing_inside