Fix Trac #2111: improve error handling for 'rec' in do-notation
[ghc-hetmet.git] / compiler / typecheck / TcMatches.lhs
index da1d0e0..452bae7 100644 (file)
@@ -39,8 +39,12 @@ import TysWiredIn
 import PrelNames
 import Id
 import TyCon
+import TysPrim
 import Outputable
+import Util
 import SrcLoc
+
+import Control.Monad
 \end{code}
 
 %************************************************************************
@@ -193,9 +197,9 @@ tcGRHSs :: TcMatchCtxt -> GRHSs Name -> (Refinement, BoxyRhoType)
 
 tcGRHSs ctxt (GRHSs grhss binds) res_ty
   = do { (binds', grhss') <- tcLocalBinds binds $
-                             mappM (wrapLocM (tcGRHS ctxt res_ty)) grhss
+                             mapM (wrapLocM (tcGRHS ctxt res_ty)) grhss
 
-       ; returnM (GRHSs grhss' binds') }
+       ; return (GRHSs grhss' binds') }
 
 -------------
 tcGRHS :: TcMatchCtxt -> (Refinement, BoxyRhoType) -> GRHS Name -> TcM (GRHS TcId)
@@ -391,16 +395,79 @@ tcLcStmt m_tc ctxt (ParStmt bndr_stmts_s) elt_ty thing_inside
                      ; return (ids, pairs', thing) }
           ; return ( (stmts', ids) : pairs', thing ) }
 
+tcLcStmt m_tc ctxt (TransformStmt (stmts, binders) usingExpr maybeByExpr) elt_ty thing_inside = do
+    (stmts', (binders', usingExpr', maybeByExpr', thing)) <- 
+        tcStmts (TransformStmtCtxt ctxt) (tcLcStmt m_tc) stmts elt_ty $ \elt_ty' -> do
+            let alphaListTy = mkTyConApp m_tc [alphaTy]
+                    
+            (usingExpr', maybeByExpr') <- 
+                case maybeByExpr of
+                    Nothing -> do
+                        -- We must validate that usingExpr :: forall a. [a] -> [a]
+                        usingExpr' <- tcPolyExpr usingExpr (mkForAllTy alphaTyVar (alphaListTy `mkFunTy` alphaListTy))
+                        return (usingExpr', Nothing)
+                    Just byExpr -> do
+                        -- We must infer a type such that e :: t and then check that usingExpr :: forall a. (a -> t) -> [a] -> [a]
+                        (byExpr', tTy) <- tcInferRho byExpr
+                        usingExpr' <- tcPolyExpr usingExpr (mkForAllTy alphaTyVar ((alphaTy `mkFunTy` tTy) `mkFunTy` (alphaListTy `mkFunTy` alphaListTy)))
+                        return (usingExpr', Just byExpr')
+            
+            binders' <- tcLookupLocalIds binders
+            thing <- thing_inside elt_ty'
+            
+            return (binders', usingExpr', maybeByExpr', thing)
+
+    return (TransformStmt (stmts', binders') usingExpr' maybeByExpr', thing)
+
+tcLcStmt m_tc ctxt (GroupStmt (stmts, bindersMap) groupByClause) elt_ty thing_inside = do
+        (stmts', (bindersMap', groupByClause', thing)) <-
+            tcStmts (TransformStmtCtxt ctxt) (tcLcStmt m_tc) stmts elt_ty $ \elt_ty' -> do
+                let alphaListTy = mkTyConApp m_tc [alphaTy]
+                    alphaListListTy = mkTyConApp m_tc [alphaListTy]
+            
+                groupByClause' <- 
+                    case groupByClause of
+                        GroupByNothing usingExpr ->
+                            -- We must validate that usingExpr :: forall a. [a] -> [[a]]
+                            tcPolyExpr usingExpr (mkForAllTy alphaTyVar (alphaListTy `mkFunTy` alphaListListTy)) >>= (return . GroupByNothing)
+                        GroupBySomething eitherUsingExpr byExpr -> do
+                            -- We must infer a type such that byExpr :: t
+                            (byExpr', tTy) <- tcInferRho byExpr
+                            
+                            -- If it exists, we then check that usingExpr :: forall a. (a -> t) -> [a] -> [[a]]
+                            let expectedUsingType = mkForAllTy alphaTyVar ((alphaTy `mkFunTy` tTy) `mkFunTy` (alphaListTy `mkFunTy` alphaListListTy))
+                            eitherUsingExpr' <- 
+                                case eitherUsingExpr of
+                                    Left usingExpr  -> (tcPolyExpr usingExpr expectedUsingType) >>= (return . Left)
+                                    Right usingExpr -> (tcPolyExpr (noLoc usingExpr) expectedUsingType) >>= (return . Right . unLoc)
+                            return $ GroupBySomething eitherUsingExpr' byExpr'
+            
+                -- Find the IDs and types of all old binders
+                let (oldBinders, newBinders) = unzip bindersMap
+                oldBinders' <- tcLookupLocalIds oldBinders
+                
+                -- Ensure that every old binder of type b is linked up with its new binder which should have type [b]
+                let newBinders' = zipWith associateNewBinder oldBinders' newBinders
+            
+                -- Type check the thing in the environment with these new binders and return the result
+                thing <- tcExtendIdEnv newBinders' (thing_inside elt_ty')
+                return (zipEqual "tcLcStmt: Old and new binder lists were not of the same length" oldBinders' newBinders', groupByClause', thing)
+        
+        return (GroupStmt (stmts', bindersMap') groupByClause', thing)
+    where
+        associateNewBinder :: TcId -> Name -> TcId
+        associateNewBinder oldBinder newBinder = mkLocalId newBinder (mkTyConApp m_tc [idType oldBinder])
+    
 tcLcStmt m_tc ctxt stmt elt_ty thing_inside
   = pprPanic "tcLcStmt: unexpected Stmt" (ppr stmt)
-
+        
 --------------------------------
 --     Do-notation
 -- The main excitement here is dealing with rebindable syntax
 
 tcDoStmt :: TcStmtChecker
 
-tcDoStmt ctxt (BindStmt pat rhs bind_op fail_op) reft_res_ty@(_,res_ty) thing_inside
+tcDoStmt ctxt (BindStmt pat rhs bind_op fail_op) (reft,res_ty) thing_inside
   = do { (rhs', rhs_ty) <- tcInferRho rhs
                -- We should use type *inference* for the RHS computations, becuase of GADTs. 
                --      do { pat <- rhs; <rest> }
@@ -409,33 +476,44 @@ tcDoStmt ctxt (BindStmt pat rhs bind_op fail_op) reft_res_ty@(_,res_ty) thing_in
                -- We do inference on rhs, so that information about its type can be refined
                -- when type-checking the pattern. 
 
-       -- Deal with rebindable syntax; (>>=) :: rhs_ty -> (a -> res_ty) -> res_ty
-       ; (bind_op', pat_ty) <- 
+       -- Deal with rebindable syntax:
+       --       (>>=) :: rhs_ty -> (pat_ty -> new_res_ty) -> res_ty
+       -- This level of generality is needed for using do-notation
+       -- in full generality; see Trac #1537
+       ; ((bind_op', new_res_ty), pat_ty) <- 
             withBox liftedTypeKind $ \ pat_ty ->
+            withBox liftedTypeKind $ \ new_res_ty ->
             tcSyntaxOp DoOrigin bind_op 
-                       (mkFunTys [rhs_ty, mkFunTy pat_ty res_ty] res_ty)
+                       (mkFunTys [rhs_ty, mkFunTy pat_ty new_res_ty] res_ty)
 
                -- If (but only if) the pattern can fail, 
                -- typecheck the 'fail' operator
        ; fail_op' <- if isIrrefutableHsPat pat 
                      then return noSyntaxExpr
-                     else tcSyntaxOp DoOrigin fail_op (mkFunTy stringTy res_ty)
+                     else tcSyntaxOp DoOrigin fail_op (mkFunTy stringTy new_res_ty)
 
-       ; (pat', thing) <- tcLamPat pat pat_ty reft_res_ty thing_inside
+       ; (pat', thing) <- tcLamPat pat pat_ty (reft, new_res_ty) thing_inside
 
        ; return (BindStmt pat' rhs' bind_op' fail_op', thing) }
 
 
-tcDoStmt ctxt (ExprStmt rhs then_op _) reft_res_ty@(_,res_ty) thing_inside
+tcDoStmt ctxt (ExprStmt rhs then_op _) (reft,res_ty) thing_inside
   = do { (rhs', rhs_ty) <- tcInferRho rhs
 
-       -- Deal with rebindable syntax; (>>) :: rhs_ty -> res_ty -> res_ty
-       ; then_op' <- tcSyntaxOp DoOrigin then_op 
-                                (mkFunTys [rhs_ty, res_ty] res_ty)
+       -- Deal with rebindable syntax; (>>) :: rhs_ty -> new_res_ty -> res_ty
+       ; (then_op', new_res_ty) <-
+               withBox liftedTypeKind $ \ new_res_ty ->
+               tcSyntaxOp DoOrigin then_op 
+                          (mkFunTys [rhs_ty, new_res_ty] res_ty)
 
-       ; thing <- thing_inside reft_res_ty
+       ; thing <- thing_inside (reft, new_res_ty)
        ; return (ExprStmt rhs' then_op' rhs_ty, thing) }
 
+tcDoStmt ctxt (RecStmt {}) res_ty thing_inside
+  = failWithTc (ptext SLIT("Illegal 'rec' stmt in") <+> pprStmtContext ctxt)
+       -- This case can't be caught in the renamer
+       -- see RnExpr.checkRecStmt
+
 tcDoStmt ctxt stmt res_ty thing_inside
   = pprPanic "tcDoStmt: unexpected Stmt" (ppr stmt)