+
+--------------------------------
+-- Monad comprehensions
+
+tcMcStmt :: TcStmtChecker
+
+-- Generators for monad comprehensions ( pat <- rhs )
+--
+-- [ body | q <- gen ] -> gen :: m a
+-- q :: a
+--
+tcMcStmt ctxt (BindStmt pat rhs bind_op fail_op) res_ty thing_inside
+ = do { rhs_ty <- newFlexiTyVarTy liftedTypeKind
+ ; pat_ty <- newFlexiTyVarTy liftedTypeKind
+ ; new_res_ty <- newFlexiTyVarTy liftedTypeKind
+ ; bind_op' <- tcSyntaxOp MCompOrigin bind_op
+ (mkFunTys [rhs_ty, mkFunTy pat_ty new_res_ty] res_ty)
+
+ -- If (but only if) the pattern can fail,
+ -- typecheck the 'fail' operator
+ ; fail_op' <- if isIrrefutableHsPat pat
+ then return noSyntaxExpr
+ else tcSyntaxOp MCompOrigin fail_op (mkFunTy stringTy new_res_ty)
+
+ ; rhs' <- tcMonoExprNC rhs rhs_ty
+ ; (pat', thing) <- tcPat (StmtCtxt ctxt) pat pat_ty $
+ thing_inside new_res_ty
+
+ ; return (BindStmt pat' rhs' bind_op' fail_op', thing) }
+
+-- Boolean expressions.
+--
+-- [ body | stmts, expr ] -> expr :: m Bool
+--
+tcMcStmt _ (ExprStmt rhs then_op guard_op _) res_ty thing_inside
+ = do { -- Deal with rebindable syntax:
+ -- guard_op :: test_ty -> rhs_ty
+ -- then_op :: rhs_ty -> new_res_ty -> res_ty
+ -- Where test_ty is, for example, Bool
+ test_ty <- newFlexiTyVarTy liftedTypeKind
+ ; rhs_ty <- newFlexiTyVarTy liftedTypeKind
+ ; new_res_ty <- newFlexiTyVarTy liftedTypeKind
+ ; rhs' <- tcMonoExpr rhs test_ty
+ ; guard_op' <- tcSyntaxOp MCompOrigin guard_op
+ (mkFunTy test_ty rhs_ty)
+ ; then_op' <- tcSyntaxOp MCompOrigin then_op
+ (mkFunTys [rhs_ty, new_res_ty] res_ty)
+ ; thing <- thing_inside new_res_ty
+ ; return (ExprStmt rhs' then_op' guard_op' rhs_ty, thing) }
+
+-- Transform statements.
+--
+-- [ body | stmts, then f ] -> f :: forall a. m a -> m a
+-- [ body | stmts, then f by e ] -> f :: forall a. (a -> t) -> m a -> m a
+--
+tcMcStmt ctxt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op) elt_ty thing_inside
+ = do {
+ -- We don't know the types of binders yet, so we use this dummy and
+ -- later unify this type with the `m_bndr_ty`
+ ty_dummy <- newFlexiTyVarTy liftedTypeKind
+
+ ; (stmts', (binders', usingExpr', maybeByExpr', return_op', bind_op', thing)) <-
+ tcStmts (TransformStmtCtxt ctxt) tcMcStmt stmts ty_dummy $ \elt_ty' -> do
+ { (_, (m_ty, _)) <- matchExpectedAppTy elt_ty'
+ ; (usingExpr', maybeByExpr') <-
+ case maybeByExpr of
+ Nothing -> do
+ -- We must validate that usingExpr :: forall a. m a -> m a
+ let using_ty = mkForAllTy alphaTyVar $
+ (m_ty `mkAppTy` alphaTy)
+ `mkFunTy`
+ (m_ty `mkAppTy` alphaTy)
+ usingExpr' <- tcPolyExpr usingExpr using_ty
+ return (usingExpr', Nothing)
+ Just byExpr -> do
+ -- We must infer a type such that e :: t and then check that
+ -- usingExpr :: forall a. (a -> t) -> m a -> m a
+ (byExpr', tTy) <- tcInferRhoNC byExpr
+ let using_ty = mkForAllTy alphaTyVar $
+ (alphaTy `mkFunTy` tTy)
+ `mkFunTy`
+ (m_ty `mkAppTy` alphaTy)
+ `mkFunTy`
+ (m_ty `mkAppTy` alphaTy)
+ usingExpr' <- tcPolyExpr usingExpr using_ty
+ return (usingExpr', Just byExpr')
+
+ ; bndr_ids <- tcLookupLocalIds binders
+
+ -- `return` and `>>=` are used to pass around/modify our
+ -- binders, so we know their types:
+ --
+ -- return :: (a,b,c,..) -> m (a,b,c,..)
+ -- (>>=) :: m (a,b,c,..)
+ -- -> ( (a,b,c,..) -> m (a,b,c,..) )
+ -- -> m (a,b,c,..)
+ --
+ ; let bndr_ty = mkChunkified mkBoxedTupleTy $ map idType bndr_ids
+ m_bndr_ty = m_ty `mkAppTy` bndr_ty
+
+ ; return_op' <- tcSyntaxOp MCompOrigin return_op
+ (bndr_ty `mkFunTy` m_bndr_ty)
+
+ ; bind_op' <- tcSyntaxOp MCompOrigin bind_op $
+ m_bndr_ty `mkFunTy` (bndr_ty `mkFunTy` elt_ty)
+ `mkFunTy` elt_ty
+
+ -- Unify types of the inner comprehension and the binders type
+ ; _ <- unifyType elt_ty' m_bndr_ty
+
+ -- Typecheck the `thing` with out old type (which is the type
+ -- of the final result of our comprehension)
+ ; thing <- thing_inside elt_ty
+
+ ; return (bndr_ids, usingExpr', maybeByExpr', return_op', bind_op', thing) }
+
+ ; return (TransformStmt stmts' binders' usingExpr' maybeByExpr' return_op' bind_op', thing) }
+
+-- Grouping statements
+--
+-- [ body | stmts, then group by e ]
+-- -> e :: t
+-- [ body | stmts, then group by e using f ]
+-- -> e :: t
+-- f :: forall a. (a -> t) -> m a -> m (m a)
+-- [ body | stmts, then group using f ]
+-- -> f :: forall a. m a -> m (m a)
+--
+tcMcStmt ctxt (GroupStmt stmts bindersMap by using return_op bind_op liftM_op) elt_ty thing_inside
+ = do { let (bndr_names, m_bndr_names) = unzip bindersMap
+
+ ; (_,(m_ty,_)) <- matchExpectedAppTy elt_ty
+ ; let alphaMTy = m_ty `mkAppTy` alphaTy
+ alphaMMTy = m_ty `mkAppTy` alphaMTy
+
+ -- We don't know the type of the bindings yet. It's not elt_ty!
+ ; bndr_ty_dummy <- newFlexiTyVarTy liftedTypeKind
+
+ ; (stmts', (bndr_ids, by', using_ty, return_op', bind_op')) <-
+ tcStmts (TransformStmtCtxt ctxt) tcMcStmt stmts bndr_ty_dummy $ \elt_ty' -> do
+ { (by', using_ty) <-
+ case by of
+ Nothing -> -- check that using :: forall a. m a -> m (m a)
+ return (Nothing, mkForAllTy alphaTyVar $
+ alphaMTy `mkFunTy` alphaMMTy)
+
+ Just by_e -> -- check that using :: forall a. (a -> t) -> m a -> m (m a)
+ -- where by :: t
+ do { (by_e', t_ty) <- tcInferRhoNC by_e
+ ; return (Just by_e', mkForAllTy alphaTyVar $
+ (alphaTy `mkFunTy` t_ty)
+ `mkFunTy` alphaMTy
+ `mkFunTy` alphaMMTy) }
+
+
+ -- Find the Ids (and hence types) of all old binders
+ ; bndr_ids <- tcLookupLocalIds bndr_names
+
+ -- 'return' is only used for the binders, so we know its type.
+ --
+ -- return :: (a,b,c,..) -> m (a,b,c,..)
+ --
+ ; let bndr_ty = mkChunkified mkBoxedTupleTy $ map idType bndr_ids
+ m_bndr_ty = m_ty `mkAppTy` bndr_ty
+ ; return_op' <- tcSyntaxOp MCompOrigin return_op $ bndr_ty `mkFunTy` m_bndr_ty
+
+ -- '>>=' is used to pass the grouped binders to the rest of the
+ -- comprehension.
+ --
+ -- (>>=) :: m (m a, m b, m c, ..)
+ -- -> ( (m a, m b, m c, ..) -> new_elt_ty )
+ -- -> elt_ty
+ --
+ ; let bndr_m_ty = mkChunkified mkBoxedTupleTy $ map (mkAppTy m_ty . idType) bndr_ids
+ m_bndr_m_ty = m_ty `mkAppTy` bndr_m_ty
+ ; new_elt_ty <- newFlexiTyVarTy liftedTypeKind
+ ; bind_op' <- tcSyntaxOp MCompOrigin bind_op $
+ m_bndr_m_ty `mkFunTy` (bndr_m_ty `mkFunTy` new_elt_ty)
+ `mkFunTy` elt_ty
+
+ -- Finally make sure the type of the inner comprehension
+ -- represents the types of our binders
+ ; _ <- unifyType elt_ty' m_bndr_ty
+
+ ; return (bndr_ids, by', using_ty, return_op', bind_op') }
+
+ ; let mk_m_bndr :: Name -> TcId -> TcId
+ mk_m_bndr m_bndr_name bndr_id =
+ mkLocalId m_bndr_name (m_ty `mkAppTy` idType bndr_id)
+
+ -- Ensure that every old binder of type `b` is linked up with its
+ -- new binder which should have type `m b`
+ m_bndr_ids = zipWith mk_m_bndr m_bndr_names bndr_ids
+ bindersMap' = bndr_ids `zip` m_bndr_ids
+
+ -- See Note [GroupStmt binder map] in HsExpr
+
+ ; using' <- case using of
+ Left e -> do { e' <- tcPolyExpr e using_ty; return (Left e') }
+ Right e -> do { e' <- tcPolyExpr (noLoc e) using_ty; return (Right (unLoc e')) }
+
+ -- Type check 'liftM' with 'forall a b. (a -> b) -> m_ty a -> m_ty b'
+ ; liftM_op' <- fmap unLoc . tcPolyExpr (noLoc liftM_op) $
+ mkForAllTy alphaTyVar $ mkForAllTy betaTyVar $
+ (alphaTy `mkFunTy` betaTy)
+ `mkFunTy`
+ (m_ty `mkAppTy` alphaTy)
+ `mkFunTy`
+ (m_ty `mkAppTy` betaTy)
+
+ -- Type check the thing in the environment with these new binders and
+ -- return the result
+ ; thing <- tcExtendIdEnv m_bndr_ids (thing_inside elt_ty)
+
+ ; return (GroupStmt stmts' bindersMap' by' using' return_op' bind_op' liftM_op', thing) }
+
+-- Typecheck `ParStmt`. See `tcLcStmt` for more informations about typechecking
+-- of `ParStmt`s.
+--
+-- Note: The `mzip` function will get typechecked via:
+--
+-- ParStmt [st1::t1, st2::t2, st3::t3]
+--
+-- mzip :: m st1
+-- -> (m st2 -> m st3 -> m (st2, st3)) -- recursive call
+-- -> m (st1, (st2, st3))
+--
+tcMcStmt ctxt (ParStmt bndr_stmts_s mzip_op bind_op return_op) elt_ty thing_inside
+ = do { (_,(m_ty,_)) <- matchExpectedAppTy elt_ty
+ ; (pairs', thing) <- loop m_ty bndr_stmts_s
+
+ ; let mzip_ty = mkForAllTys [alphaTyVar, betaTyVar] $
+ (m_ty `mkAppTy` alphaTy)
+ `mkFunTy`
+ (m_ty `mkAppTy` betaTy)
+ `mkFunTy`
+ (m_ty `mkAppTy` mkBoxedTupleTy [alphaTy, betaTy])
+ ; mzip_op' <- unLoc `fmap` tcPolyExpr (noLoc mzip_op) mzip_ty
+
+ -- Typecheck bind:
+ ; let tys = map (mkChunkified mkBoxedTupleTy . map idType . snd) pairs'
+ tuple_ty = mk_tuple_ty tys
+
+ ; bind_op' <- tcSyntaxOp MCompOrigin bind_op $
+ (m_ty `mkAppTy` tuple_ty)
+ `mkFunTy`
+ (tuple_ty `mkFunTy` elt_ty)
+ `mkFunTy`
+ elt_ty
+
+ ; return_op' <- fmap unLoc . tcPolyExpr (noLoc return_op) $
+ mkForAllTy alphaTyVar $
+ alphaTy `mkFunTy` (m_ty `mkAppTy` alphaTy)
+ ; return (ParStmt pairs' mzip_op' bind_op' return_op', thing) }
+
+ where mk_tuple_ty tys = foldr (\tn tm -> mkBoxedTupleTy [tn, tm]) (last tys) (init tys)
+
+ -- loop :: Type -- m_ty
+ -- -> [([LStmt Name], [Name])]
+ -- -> TcM ([([LStmt TcId], [TcId])], thing)
+ loop _ [] = do { thing <- thing_inside elt_ty
+ ; return ([], thing) } -- matching in the branches
+
+ loop m_ty ((stmts, names) : pairs)
+ = do { -- type dummy since we don't know all binder types yet
+ ty_dummy <- newFlexiTyVarTy liftedTypeKind
+ ; (stmts', (ids, pairs', thing))
+ <- tcStmts ctxt tcMcStmt stmts ty_dummy $ \elt_ty' ->
+ do { ids <- tcLookupLocalIds names
+ ; _ <- unifyType elt_ty' (m_ty `mkAppTy` (mkChunkified mkBoxedTupleTy) (map idType ids))
+ ; (pairs', thing) <- loop m_ty pairs
+ ; return (ids, pairs', thing) }
+ ; return ( (stmts', ids) : pairs', thing ) }
+
+tcMcStmt _ stmt _ _
+ = pprPanic "tcMcStmt: unexpected Stmt" (ppr stmt)
+
+-- Typecheck 'body' with type 'a' instead of 'm a' like the rest of the
+-- statements, ignore the second type argument coming from the tcStmts loop
+tcMcBody :: LHsExpr Name
+ -> SyntaxExpr Name
+ -> TcRhoType
+ -> TcM (LHsExpr TcId, SyntaxExpr TcId)
+tcMcBody body return_op res_ty
+ = do { (_, (_, a_ty)) <- matchExpectedAppTy res_ty
+ ; body' <- tcMonoExpr body a_ty
+ ; return_op' <- tcSyntaxOp MCompOrigin return_op
+ (a_ty `mkFunTy` res_ty)
+ ; return (body', return_op')
+ }
+
+