+
+--------------------------------
+-- Monad comprehensions
+
+tcMcStmt :: TcStmtChecker
+
+tcMcStmt _ (LastStmt body return_op) res_ty thing_inside
+ = do { a_ty <- newFlexiTyVarTy liftedTypeKind
+ ; return_op' <- tcSyntaxOp MCompOrigin return_op
+ (a_ty `mkFunTy` res_ty)
+ ; body' <- tcMonoExprNC body a_ty
+ ; thing <- thing_inside (panic "tcMcStmt: thing_inside")
+ ; return (LastStmt body' return_op', thing) }
+
+-- Generators for monad comprehensions ( pat <- rhs )
+--
+-- [ body | q <- gen ] -> gen :: m a
+-- q :: a
+--
+
+tcMcStmt ctxt (BindStmt pat rhs bind_op fail_op) res_ty thing_inside
+ = do { rhs_ty <- newFlexiTyVarTy liftedTypeKind
+ ; pat_ty <- newFlexiTyVarTy liftedTypeKind
+ ; new_res_ty <- newFlexiTyVarTy liftedTypeKind
+
+ -- (>>=) :: rhs_ty -> (pat_ty -> new_res_ty) -> res_ty
+ ; bind_op' <- tcSyntaxOp MCompOrigin bind_op
+ (mkFunTys [rhs_ty, mkFunTy pat_ty new_res_ty] res_ty)
+
+ -- If (but only if) the pattern can fail, typecheck the 'fail' operator
+ ; fail_op' <- if isIrrefutableHsPat pat
+ then return noSyntaxExpr
+ else tcSyntaxOp MCompOrigin fail_op (mkFunTy stringTy new_res_ty)
+
+ ; rhs' <- tcMonoExprNC rhs rhs_ty
+ ; (pat', thing) <- tcPat (StmtCtxt ctxt) pat pat_ty $
+ thing_inside new_res_ty
+
+ ; return (BindStmt pat' rhs' bind_op' fail_op', thing) }
+
+-- Boolean expressions.
+--
+-- [ body | stmts, expr ] -> expr :: m Bool
+--
+tcMcStmt _ (ExprStmt rhs then_op guard_op _) res_ty thing_inside
+ = do { -- Deal with rebindable syntax:
+ -- guard_op :: test_ty -> rhs_ty
+ -- then_op :: rhs_ty -> new_res_ty -> res_ty
+ -- Where test_ty is, for example, Bool
+ test_ty <- newFlexiTyVarTy liftedTypeKind
+ ; rhs_ty <- newFlexiTyVarTy liftedTypeKind
+ ; new_res_ty <- newFlexiTyVarTy liftedTypeKind
+ ; rhs' <- tcMonoExpr rhs test_ty
+ ; guard_op' <- tcSyntaxOp MCompOrigin guard_op
+ (mkFunTy test_ty rhs_ty)
+ ; then_op' <- tcSyntaxOp MCompOrigin then_op
+ (mkFunTys [rhs_ty, new_res_ty] res_ty)
+ ; thing <- thing_inside new_res_ty
+ ; return (ExprStmt rhs' then_op' guard_op' rhs_ty, thing) }
+
+-- Transform statements.
+--
+-- [ body | stmts, then f ] -> f :: forall a. m a -> m a
+-- [ body | stmts, then f by e ] -> f :: forall a. (a -> t) -> m a -> m a
+--
+tcMcStmt ctxt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op) res_ty thing_inside
+ = do { let star_star_kind = liftedTypeKind `mkArrowKind` liftedTypeKind
+ ; m1_ty <- newFlexiTyVarTy star_star_kind
+ ; m2_ty <- newFlexiTyVarTy star_star_kind
+ ; n_ty <- newFlexiTyVarTy star_star_kind
+ ; tup_ty_var <- newFlexiTyVarTy liftedTypeKind
+ ; new_res_ty <- newFlexiTyVarTy liftedTypeKind
+ ; let m1_tup_ty = m1_ty `mkAppTy` tup_ty_var
+
+ -- 'stmts' returns a result of type (m1_ty tuple_ty),
+ -- typically something like [(Int,Bool,Int)]
+ -- We don't know what tuple_ty is yet, so we use a variable
+ ; (stmts', (binders', usingExpr', maybeByExpr', return_op', bind_op', thing)) <-
+ tcStmtsAndThen (TransformStmtCtxt ctxt) tcMcStmt stmts m1_tup_ty $ \res_ty' -> do
+ { (usingExpr', maybeByExpr') <-
+ case maybeByExpr of
+ Nothing -> do
+ -- We must validate that usingExpr :: forall a. m a -> m a
+ let using_ty = mkForAllTy alphaTyVar $
+ (m_ty `mkAppTy` alphaTy)
+ `mkFunTy`
+ (m_ty `mkAppTy` alphaTy)
+ usingExpr' <- tcPolyExpr usingExpr using_ty
+ return (usingExpr', Nothing)
+ Just byExpr -> do
+ -- We must infer a type such that e :: t and then check that
+ -- usingExpr :: forall a. (a -> t) -> m a -> m a
+ (byExpr', tTy) <- tcInferRhoNC byExpr
+ let using_ty = mkForAllTy alphaTyVar $
+ (alphaTy `mkFunTy` tTy)
+ `mkFunTy`
+ (m_ty `mkAppTy` alphaTy)
+ `mkFunTy`
+ (m_ty `mkAppTy` alphaTy)
+ usingExpr' <- tcPolyExpr usingExpr using_ty
+ return (usingExpr', Just byExpr')
+
+ ; bndr_ids <- tcLookupLocalIds binders
+
+ -- `return` and `>>=` are used to pass around/modify our
+ -- binders, so we know their types:
+ --
+ -- return :: (a,b,c,..) -> m (a,b,c,..)
+ -- (>>=) :: m (a,b,c,..)
+ -- -> ( (a,b,c,..) -> m (a,b,c,..) )
+ -- -> m (a,b,c,..)
+ --
+ ; let bndr_ty = mkBigCoreVarTupTy bndr_ids
+ m_bndr_ty = m_ty `mkAppTy` bndr_ty
+
+ ; return_op' <- tcSyntaxOp MCompOrigin return_op
+ (bndr_ty `mkFunTy` m_bndr_ty)
+
+ ; bind_op' <- tcSyntaxOp MCompOrigin bind_op $
+ m_bndr_ty `mkFunTy` (bndr_ty `mkFunTy` res_ty)
+ `mkFunTy` res_ty
+
+ -- Unify types of the inner comprehension and the binders type
+ ; _ <- unifyType res_ty' m_bndr_ty
+
+ -- Typecheck the `thing` with out old type (which is the type
+ -- of the final result of our comprehension)
+ ; thing <- thing_inside res_ty
+
+ ; return (bndr_ids, usingExpr', maybeByExpr', return_op', bind_op', thing) }
+
+ ; return (TransformStmt stmts' binders' usingExpr' maybeByExpr' return_op' bind_op', thing) }
+
+-- Grouping statements
+--
+-- [ body | stmts, then group by e ]
+-- -> e :: t
+-- [ body | stmts, then group by e using f ]
+-- -> e :: t
+-- f :: forall a. (a -> t) -> m a -> m (m a)
+-- [ body | stmts, then group using f ]
+-- -> f :: forall a. m a -> m (m a)
+--
+tcMcStmt ctxt (GroupStmt { grpS_stmts = stmts, grpS_bndrs = bindersMap
+ , grpS_by = by, grpS_using = using, grpS_explicit = explicit
+ , grpS_ret = return_op, grpS_bind = bind_op
+ , grpS_fmap = fmap_op }) res_ty thing_inside
+ = do { let star_star_kind = liftedTypeKind `mkArrowKind` liftedTypeKind
+ ; m1_ty <- newFlexiTyVarTy star_star_kind
+ ; m2_ty <- newFlexiTyVarTy star_star_kind
+ ; n_ty <- newFlexiTyVarTy star_star_kind
+ ; tup_ty_var <- newFlexiTyVarTy liftedTypeKind
+ ; new_res_ty <- newFlexiTyVarTy liftedTypeKind
+ ; let (bndr_names, n_bndr_names) = unzip bindersMap
+ m1_tup_ty = m1_ty `mkAppTy` tup_ty_var
+
+ -- 'stmts' returns a result of type (m1_ty tuple_ty),
+ -- typically something like [(Int,Bool,Int)]
+ -- We don't know what tuple_ty is yet, so we use a variable
+ ; (stmts', (bndr_ids, by_e_ty, return_op')) <-
+ tcStmtsAndThen (TransformStmtCtxt ctxt) tcMcStmt stmts m1_tup_ty $ \res_ty' -> do
+ { by_e_ty <- case by of
+ Nothing -> return Nothing
+ Just e -> do { e_ty <- tcInferRhoNC e; return (Just e_ty) }
+
+ -- Find the Ids (and hence types) of all old binders
+ ; bndr_ids <- tcLookupLocalIds bndr_names
+
+ -- 'return' is only used for the binders, so we know its type.
+ --
+ -- return :: (a,b,c,..) -> m (a,b,c,..)
+ ; return_op' <- tcSyntaxOp MCompOrigin return_op $
+ (mkBigCoreVarTupTy bndr_ids) `mkFunTy` res_ty'
+
+ ; return (bndr_ids, by_e_ty, return_op') }
+
+
+
+ ; let tup_ty = mkBigCoreVarTupTy bndr_ids -- (a,b,c)
+ using_arg_ty = m1_ty `mkAppTy` tup_ty -- m1 (a,b,c)
+ n_tup_ty = n_ty `mkAppTy` tup_ty -- n (a,b,c)
+ using_res_ty = m2_ty `mkAppTy` n_tup_ty -- m2 (n (a,b,c))
+ using_fun_ty = using_arg_ty `mkFunTy` using_arg_ty
+
+ -- (>>=) :: m2 (n (a,b,c)) -> ( n (a,b,c) -> new_res_ty ) -> res_ty
+ -- using :: ((a,b,c)->t) -> m1 (a,b,c) -> m2 (n (a,b,c))
+
+ --------------- Typecheck the 'bind' function -------------
+ ; bind_op' <- tcSyntaxOp MCompOrigin bind_op $
+ using_res_ty `mkFunTy` (n_tup_ty `mkFunTy` new_res_ty)
+ `mkFunTy` res_ty
+
+ --------------- Typecheck the 'using' function -------------
+ ; let poly_fun_ty = (m1_ty `mkAppTy` alphaTy) `mkFunTy`
+ (m2_ty `mkAppTy` (n_ty `mkAppTy` alphaTy))
+ using_poly_ty = case by_e_ty of
+ Nothing -> mkForAllTy alphaTyVar poly_fun_ty
+ -- using :: forall a. m1 a -> m2 (n a)
+
+ Just (_,t_ty) -> mkForAllTy alphaTyVar $
+ (alphaTy `mkFunTy` t_ty) `mkFunTy` poly_fun_ty
+ -- using :: forall a. (a->t) -> m1 a -> m2 (n a)
+ -- where by :: t
+
+ ; using' <- tcPolyExpr using using_poly_ty
+ ; coi <- unifyType (applyTy using_poly_ty tup_ty)
+ (case by_e_ty of
+ Nothing -> using_fun_ty
+ Just (_,t_ty) -> (tup_ty `mkFunTy` t_ty) `mkFunTy` using_fun_ty)
+ ; let final_using = fmap (mkHsWrapCoI coi . HsWrap (WpTyApp tup_ty)) using'
+
+ --------------- Typecheck the 'fmap' function -------------
+ ; fmap_op' <- fmap unLoc . tcPolyExpr (noLoc fmap_op) $
+ mkForAllTy alphaTyVar $ mkForAllTy betaTyVar $
+ (alphaTy `mkFunTy` betaTy)
+ `mkFunTy` (n_ty `mkAppTy` alphaTy)
+ `mkFunTy` (n_ty `mkAppTy` betaTy)
+
+ ; let mk_n_bndr :: Name -> TcId -> TcId
+ mk_n_bndr n_bndr_name bndr_id
+ = mkLocalId n_bndr_name (n_ty `mkAppTy` idType bndr_id)
+
+ -- Ensure that every old binder of type `b` is linked up with its
+ -- new binder which should have type `n b`
+ -- See Note [GroupStmt binder map] in HsExpr
+ n_bndr_ids = zipWith mk_n_bndr n_bndr_names bndr_ids
+ bindersMap' = bndr_ids `zip` n_bndr_ids
+
+ -- Type check the thing in the environment with these new binders and
+ -- return the result
+ ; thing <- tcExtendIdEnv n_bndr_ids (thing_inside res_ty)
+
+ ; return (GroupStmt { grpS_stmts = stmts', grpS_bndrs = bindersMap'
+ , grpS_by = fmap fst by_e_ty, grpS_using = final_using
+ , grpS_ret = return_op', grpS_bind = bind_op'
+ , grpS_fmap = fmap_op', grpS_explicit = explicit }, thing) }
+
+-- Typecheck `ParStmt`. See `tcLcStmt` for more informations about typechecking
+-- of `ParStmt`s.
+--
+-- Note: The `mzip` function will get typechecked via:
+--
+-- ParStmt [st1::t1, st2::t2, st3::t3]
+--
+-- mzip :: m st1
+-- -> (m st2 -> m st3 -> m (st2, st3)) -- recursive call
+-- -> m (st1, (st2, st3))
+--
+tcMcStmt ctxt (ParStmt bndr_stmts_s mzip_op bind_op return_op) res_ty thing_inside
+ = do { (_,(m_ty,_)) <- matchExpectedAppTy res_ty
+ -- ToDo: what if the coercion isn't the identity?
+
+ ; (pairs', thing) <- loop m_ty bndr_stmts_s
+
+ ; let mzip_ty = mkForAllTys [alphaTyVar, betaTyVar] $
+ (m_ty `mkAppTy` alphaTy)
+ `mkFunTy`
+ (m_ty `mkAppTy` betaTy)
+ `mkFunTy`
+ (m_ty `mkAppTy` mkBoxedTupleTy [alphaTy, betaTy])
+ ; mzip_op' <- unLoc `fmap` tcPolyExpr (noLoc mzip_op) mzip_ty
+
+ -- Typecheck bind:
+ ; let tys = map (mkBigCoreVarTupTy . snd) pairs'
+ tuple_ty = mk_tuple_ty tys
+
+ ; bind_op' <- tcSyntaxOp MCompOrigin bind_op $
+ (m_ty `mkAppTy` tuple_ty)
+ `mkFunTy`
+ (tuple_ty `mkFunTy` res_ty)
+ `mkFunTy`
+ res_ty
+
+ ; return_op' <- fmap unLoc . tcPolyExpr (noLoc return_op) $
+ mkForAllTy alphaTyVar $
+ alphaTy `mkFunTy` (m_ty `mkAppTy` alphaTy)
+
+ ; return (ParStmt pairs' mzip_op' bind_op' return_op', thing) }
+
+ where mk_tuple_ty tys = foldr1 (\tn tm -> mkBoxedTupleTy [tn, tm]) tys
+
+ -- loop :: Type -- m_ty
+ -- -> [([LStmt Name], [Name])]
+ -- -> TcM ([([LStmt TcId], [TcId])], thing)
+ loop _ [] = do { thing <- thing_inside res_ty
+ ; return ([], thing) } -- matching in the branches
+
+ loop m_ty ((stmts, names) : pairs)
+ = do { -- type dummy since we don't know all binder types yet
+ ty_dummy <- newFlexiTyVarTy liftedTypeKind
+ ; (stmts', (ids, pairs', thing))
+ <- tcStmtsAndThen ctxt tcMcStmt stmts ty_dummy $ \res_ty' ->
+ do { ids <- tcLookupLocalIds names
+ ; _ <- unifyType res_ty' (m_ty `mkAppTy` mkBigCoreVarTupTy ids)
+ ; (pairs', thing) <- loop m_ty pairs
+ ; return (ids, pairs', thing) }
+ ; return ( (stmts', ids) : pairs', thing ) }
+
+tcMcStmt _ stmt _ _
+ = pprPanic "tcMcStmt: unexpected Stmt" (ppr stmt)
+