From: Simon Peyton Jones Date: Fri, 29 Apr 2011 17:06:03 +0000 (+0100) Subject: Simon's hacking on monad-comp; incomplete X-Git-Url: http://git.megacz.com/?p=ghc-hetmet.git;a=commitdiff_plain;h=4ac2bb39dffb4b825ece73b349ff0d56d79092d7 Simon's hacking on monad-comp; incomplete --- diff --git a/compiler/deSugar/Coverage.lhs b/compiler/deSugar/Coverage.lhs index e73c249..711f66e 100644 --- a/compiler/deSugar/Coverage.lhs +++ b/compiler/deSugar/Coverage.lhs @@ -301,11 +301,9 @@ addTickHsExpr (HsLet binds e) = liftM2 HsLet (addTickHsLocalBinds binds) -- to think about: !patterns. (addTickLHsExprNeverOrAlways e) -addTickHsExpr (HsDo cxt stmts last_exp return_exp srcloc) = do - (stmts', last_exp') <- addTickLStmts' forQual stmts - (addTickLHsExpr last_exp) - return_exp' <- addTickSyntaxExpr hpcSrcSpan return_exp - return (HsDo cxt stmts' last_exp' return_exp' srcloc) +addTickHsExpr (HsDo cxt stmts srcloc) + = do { (stmts', _) <- addTickLStmts' forQual stmts (return ()) + ; return (HsDo cxt stmts' srcloc) } where forQual = case cxt of ListComp -> Just $ BinBox QualBinBox @@ -425,14 +423,16 @@ addTickLStmts isGuard stmts = do addTickLStmts' :: (Maybe (Bool -> BoxLabel)) -> [LStmt Id] -> TM a -> TM ([LStmt Id], a) addTickLStmts' isGuard lstmts res - = bindLocals binders $ do - lstmts' <- mapM (liftL (addTickStmt isGuard)) lstmts - a <- res - return (lstmts', a) - where - binders = collectLStmtsBinders lstmts + = bindLocals (collectLStmtsBinders lstmts) $ + do { lstmts' <- mapM (liftL (addTickStmt isGuard)) lstmts + ; a <- res + ; return (lstmts', a) } addTickStmt :: (Maybe (Bool -> BoxLabel)) -> Stmt Id -> TM (Stmt Id) +addTickStmt _isGuard (LastStmt e ret) = do + liftM2 LastStmt + (addTickLHsExprAlways e) + (addTickSyntaxExpr hpcSrcSpan ret) addTickStmt _isGuard (BindStmt pat e bind fail) = do liftM4 BindStmt (addTickLPat pat) @@ -577,10 +577,9 @@ addTickHsCmd (HsLet binds c) = liftM2 HsLet (addTickHsLocalBinds binds) -- to think about: !patterns. (addTickLHsCmd c) -addTickHsCmd (HsDo cxt stmts last_exp return_exp srcloc) = do - (stmts', last_exp') <- addTickLCmdStmts' stmts (addTickLHsCmd last_exp) - return_exp' <- addTickSyntaxExpr hpcSrcSpan return_exp - return (HsDo cxt stmts' last_exp' return_exp' srcloc) +addTickHsCmd (HsDo cxt stmts srcloc) + = do { (stmts', _) <- addTickLCmdStmts' stmts (return ()) + ; return (HsDo cxt stmts' srcloc) } addTickHsCmd (HsArrApp e1 e2 ty1 arr_ty lr) = liftM5 HsArrApp diff --git a/compiler/deSugar/DsArrows.lhs b/compiler/deSugar/DsArrows.lhs index 608f25e..a5bf2b6 100644 --- a/compiler/deSugar/DsArrows.lhs +++ b/compiler/deSugar/DsArrows.lhs @@ -541,8 +541,8 @@ dsCmd ids local_vars env_ids stack res_ty (HsLet binds body) = do core_body, exprFreeVars core_binds `intersectVarSet` local_vars) -dsCmd ids local_vars env_ids [] res_ty (HsDo _ctxt stmts body _ _) - = dsCmdDo ids local_vars env_ids res_ty stmts body +dsCmd ids local_vars env_ids [] res_ty (HsDo _ctxt stmts _) + = dsCmdDo ids local_vars env_ids res_ty stmts -- A |- e :: forall e. a1 (e*ts1) t1 -> ... an (e*tsn) tn -> a (e*ts) t -- A | xs |- ci :: [tsi] ti @@ -618,7 +618,6 @@ dsCmdDo :: DsCmdEnv -- arrow combinators -- so don't pull on it too early -> Type -- return type of the statement -> [LStmt Id] -- statements to desugar - -> LHsExpr Id -- body -> DsM (CoreExpr, -- desugared expression IdSet) -- set of local vars that occur free @@ -626,15 +625,17 @@ dsCmdDo :: DsCmdEnv -- arrow combinators -- -------------------------- -- A | xs |- do { c } :: [] t -dsCmdDo ids local_vars env_ids res_ty [] body +dsCmdDo _ _ _ _ [] = panic "dsCmdDo" + +dsCmdDo ids local_vars env_ids res_ty [L _ (LastStmt body _)] = dsLCmd ids local_vars env_ids [] res_ty body -dsCmdDo ids local_vars env_ids res_ty (stmt:stmts) body = do +dsCmdDo ids local_vars env_ids res_ty (stmt:stmts) = do let bound_vars = mkVarSet (collectLStmtBinders stmt) local_vars' = local_vars `unionVarSet` bound_vars (core_stmts, _, env_ids') <- fixDs (\ ~(_,_,env_ids') -> do - (core_stmts, fv_stmts) <- dsCmdDo ids local_vars' env_ids' res_ty stmts body + (core_stmts, fv_stmts) <- dsCmdDo ids local_vars' env_ids' res_ty stmts return (core_stmts, fv_stmts, varSetElems fv_stmts)) (core_stmt, fv_stmt) <- dsCmdLStmt ids local_vars env_ids env_ids' stmt return (do_compose ids diff --git a/compiler/deSugar/DsExpr.lhs b/compiler/deSugar/DsExpr.lhs index fb3f856..c55c2d4 100644 --- a/compiler/deSugar/DsExpr.lhs +++ b/compiler/deSugar/DsExpr.lhs @@ -325,29 +325,12 @@ dsExpr (HsLet binds body) = do -- We need the `ListComp' form to use `deListComp' (rather than the "do" form) -- because the interpretation of `stmts' depends on what sort of thing it is. -- -dsExpr (HsDo ListComp stmts body _ result_ty) - = -- Special case for list comprehensions - dsListComp stmts body elt_ty - where - [elt_ty] = tcTyConAppArgs result_ty - -dsExpr (HsDo DoExpr stmts body _ result_ty) - = dsDo stmts body result_ty - -dsExpr (HsDo GhciStmt stmts body _ result_ty) - = dsDo stmts body result_ty - -dsExpr (HsDo MDoExpr stmts body _ result_ty) - = dsDo stmts body result_ty - -dsExpr (HsDo MonadComp stmts body return_op result_ty) - = dsMonadComp stmts return_op body result_ty - -dsExpr (HsDo PArrComp stmts body _ result_ty) - = -- Special case for array comprehensions - dsPArrComp (map unLoc stmts) body elt_ty - where - [elt_ty] = tcTyConAppArgs result_ty +dsExpr (HsDo ListComp stmts res_ty) = dsListComp stmts res_ty +dsExpr (HsDo PArrComp stmts _) = dsPArrComp (map unLoc stmts) +dsExpr (HsDo DoExpr stmts res_ty) = dsDo stmts res_ty +dsExpr (HsDo GhciStmt stmts res_ty) = dsDo stmts res_ty +dsExpr (HsDo MDoExpr stmts res_ty) = dsDo stmts res_ty +dsExpr (HsDo MonadComp stmts res_ty) = dsMonadComp stmts res_ty dsExpr (HsIf mb_fun guard_expr then_expr else_expr) = do { pred <- dsLExpr guard_expr @@ -712,24 +695,24 @@ Haskell 98 report: \begin{code} dsDo :: [LStmt Id] - -> LHsExpr Id -> Type -- Type of the whole expression -> DsM CoreExpr -dsDo stmts body result_ty +dsDo stmts result_ty = goL stmts where - -- result_ty must be of the form (m b) - (m_ty, _b_ty) = tcSplitAppTy result_ty - - goL [] = dsLExpr body - goL ((L loc stmt):lstmts) = putSrcSpanDs loc (go loc stmt lstmts) + goL [] = panic "dsDo" + goL (L loc stmt:lstmts) = putSrcSpanDs loc (go loc stmt lstmts) + go _ (LastStmt body ret_op) stmts + = ASSERT( null stmts ) + do { body' <- dsLExpr body + ; ret_op' <- dsExpr ret_op + ; return (App ret_op' body') } + go _ (ExprStmt rhs then_expr _ _) stmts = do { rhs2 <- dsLExpr rhs - ; case tcSplitAppTy_maybe (exprType rhs2) of - Just (container_ty, returning_ty) -> warnDiscardedDoBindings rhs container_ty returning_ty - _ -> return () + ; warnDiscardedDoBindings rhs (exprType rhs2) ; then_expr2 <- dsExpr then_expr ; rest <- goL stmts ; return (mkApps then_expr2 [rhs2, rest]) } @@ -753,29 +736,25 @@ dsDo stmts body result_ty go loc (RecStmt { recS_stmts = rec_stmts, recS_later_ids = later_ids , recS_rec_ids = rec_ids, recS_ret_fn = return_op , recS_mfix_fn = mfix_op, recS_bind_fn = bind_op - , recS_rec_rets = rec_rets }) stmts + , recS_rec_rets = rec_rets, recS_ret_ty = body_ty }) stmts = ASSERT( length rec_ids > 0 ) goL (new_bind_stmt : stmts) where - -- returnE <- dsExpr return_id - -- mfixE <- dsExpr mfix_id - new_bind_stmt = L loc $ BindStmt (mkLHsPatTup later_pats) mfix_app - bind_op + new_bind_stmt = L loc $ BindStmt (mkLHsPatTup later_pats) + mfix_app bind_op noSyntaxExpr -- Tuple cannot fail tup_ids = rec_ids ++ filterOut (`elem` rec_ids) later_ids rec_tup_pats = map nlVarPat tup_ids later_pats = rec_tup_pats rets = map noLoc rec_rets - - mfix_app = nlHsApp (noLoc mfix_op) mfix_arg - mfix_arg = noLoc $ HsLam (MatchGroup [mkSimpleMatch [mfix_pat] body] - (mkFunTy tup_ty body_ty)) - mfix_pat = noLoc $ LazyPat $ mkLHsPatTup rec_tup_pats - body = noLoc $ HsDo DoExpr rec_stmts return_app noSyntaxExpr body_ty - return_app = nlHsApp (noLoc return_op) (mkLHsTupleExpr rets) - body_ty = mkAppTy m_ty tup_ty - tup_ty = mkBoxedTupleTy (map idType tup_ids) -- Deals with singleton case + mfix_app = nlHsApp (noLoc mfix_op) mfix_arg + mfix_arg = noLoc $ HsLam (MatchGroup [mkSimpleMatch [mfix_pat] body] + (mkFunTy tup_ty body_ty)) + mfix_pat = noLoc $ LazyPat $ mkLHsPatTup rec_tup_pats + body = noLoc $ HsDo DoExpr (rec_stmts ++ [ret_stmt]) body_ty + ret_stmt = noLoc $ LastStmt return_op (mkLHsTupleExpr rets) + tup_ty = mkBoxedTupleTy (map idType tup_ids) -- Deals with singleton case handle_failure :: LPat Id -> MatchResult -> SyntaxExpr Id -> DsM CoreExpr -- In a do expression, pattern-match failure just calls @@ -793,103 +772,6 @@ mk_fail_msg pat = "Pattern match failure in do expression at " ++ showSDoc (ppr (getLoc pat)) \end{code} -Translation for RecStmt's: ------------------------------ -We turn (RecStmt [v1,..vn] stmts) into: - - (v1,..,vn) <- mfix (\~(v1,..vn). do stmts - return (v1,..vn)) - -\begin{code} -{- -dsMDo :: HsStmtContext Name - -> [(Name,Id)] - -> [LStmt Id] - -> LHsExpr Id - -> Type -- Type of the whole expression - -> DsM CoreExpr - -dsMDo ctxt tbl stmts body result_ty - = goL stmts - where - goL [] = dsLExpr body - goL ((L loc stmt):lstmts) = putSrcSpanDs loc (go loc stmt lstmts) - - (m_ty, b_ty) = tcSplitAppTy result_ty -- result_ty must be of the form (m b) - return_id = lookupEvidence tbl returnMName - bind_id = lookupEvidence tbl bindMName - then_id = lookupEvidence tbl thenMName - fail_id = lookupEvidence tbl failMName - - go _ (LetStmt binds) stmts - = do { rest <- goL stmts - ; dsLocalBinds binds rest } - - go _ (ExprStmt rhs then_expr rhs_ty) stmts - = do { rhs2 <- dsLExpr rhs - ; warnDiscardedDoBindings rhs m_ty rhs_ty - ; then_expr2 <- dsExpr then_expr - ; rest <- goL stmts - ; return (mkApps then_expr2 [rhs2, rest]) } - - go _ (BindStmt pat rhs bind_op _) stmts - = do { body <- goL stmts - ; rhs' <- dsLExpr rhs - ; bind_op' <- dsExpr bind_op - ; var <- selectSimpleMatchVarL pat - ; match <- matchSinglePat (Var var) (StmtCtxt ctxt) pat - result_ty (cantFailMatchResult body) - ; match_code <- handle_failure pat match fail_op - ; return (mkApps bind_op [rhs', Lam var match_code]) } - - go loc (RecStmt { recS_stmts = rec_stmts, recS_later_ids = later_ids - , recS_rec_ids = rec_ids, recS_rec_rets = rec_rets - , recS_mfix_fn = mfix_op, recS_bind_fn = bind_op }) stmts - = ASSERT( length rec_ids > 0 ) - ASSERT( length rec_ids == length rec_rets ) - ASSERT( isEmptyTcEvBinds _ev_binds ) - pprTrace "dsMDo" (ppr later_ids) $ - goL (new_bind_stmt : stmts) - where - new_bind_stmt = L loc $ BindStmt (mk_tup_pat later_pats) mfix_app - bind_op noSyntaxExpr - - -- Remove the later_ids that appear (without fancy coercions) - -- in rec_rets, because there's no need to knot-tie them separately - -- See Note [RecStmt] in HsExpr - later_ids' = filter (`notElem` mono_rec_ids) later_ids - mono_rec_ids = [ id | HsVar id <- rec_rets ] - - mfix_app = nlHsApp (noLoc mfix_op) mfix_arg - mfix_arg = noLoc $ HsLam (MatchGroup [mkSimpleMatch [mfix_pat] body] - (mkFunTy tup_ty body_ty)) - - -- The rec_tup_pat must bind the rec_ids only; remember that the - -- trimmed_laters may share the same Names - -- Meanwhile, the later_pats must bind the later_vars - rec_tup_pats = map mk_wild_pat later_ids' ++ map nlVarPat rec_ids - later_pats = map nlVarPat later_ids' ++ map mk_later_pat rec_ids - rets = map nlHsVar later_ids' ++ map noLoc rec_rets - - mfix_pat = noLoc $ LazyPat $ mk_tup_pat rec_tup_pats - body = noLoc $ HsDo ctxt rec_stmts return_app noSyntaxExpr body_ty - body_ty = mkAppTy m_ty tup_ty - tup_ty = mkBoxedTupleTy (map idType (later_ids' ++ rec_ids)) -- Deals with singleton case - - return_app = nlHsApp (noLoc return_op) (mkLHsTupleExpr rets) - - mk_wild_pat :: Id -> LPat Id - mk_wild_pat v = noLoc $ WildPat $ idType v - - mk_later_pat :: Id -> LPat Id - mk_later_pat v | v `elem` later_ids' = mk_wild_pat v - | otherwise = nlVarPat v - - mk_tup_pat :: [LPat Id] -> LPat Id - mk_tup_pat [p] = p - mk_tup_pat ps = noLoc $ mkVanillaTuplePat ps Boxed --} -\end{code} %************************************************************************ %* * @@ -929,30 +811,34 @@ conversionNames \begin{code} -- Warn about certain types of values discarded in monadic bindings (#3263) -warnDiscardedDoBindings :: LHsExpr Id -> Type -> Type -> DsM () -warnDiscardedDoBindings rhs container_ty returning_ty = do { - -- Warn about discarding non-() things in 'monadic' binding - ; warn_unused <- doptDs Opt_WarnUnusedDoBind - ; if warn_unused && not (returning_ty `tcEqType` unitTy) - then warnDs (unusedMonadBind rhs returning_ty) - else do { - -- Warn about discarding m a things in 'monadic' binding of the same type, - -- but only if we didn't already warn due to Opt_WarnUnusedDoBind - ; warn_wrong <- doptDs Opt_WarnWrongDoBind - ; case tcSplitAppTy_maybe returning_ty of - Just (returning_container_ty, _) -> when (warn_wrong && container_ty `tcEqType` returning_container_ty) $ - warnDs (wrongMonadBind rhs returning_ty) - _ -> return () } } +warnDiscardedDoBindings :: LHsExpr Id -> Type -> DsM () +warnDiscardedDoBindings rhs rhs_ty + | Just (m_ty, elt_ty) <- tcSplitAppTy_maybe rhs_ty + = do { -- Warn about discarding non-() things in 'monadic' binding + ; warn_unused <- doptDs Opt_WarnUnusedDoBind + ; if warn_unused && not (isUnitTy elt_ty) + then warnDs (unusedMonadBind rhs elt_ty) + else + -- Warn about discarding m a things in 'monadic' binding of the same type, + -- but only if we didn't already warn due to Opt_WarnUnusedDoBind + do { warn_wrong <- doptDs Opt_WarnWrongDoBind + ; case tcSplitAppTy_maybe elt_ty of + Just (elt_m_ty, _) | warn_wrong, m_ty `tcEqType` elt_m_ty + -> warnDs (wrongMonadBind rhs elt_ty) + _ -> return () } } + + | otherwise -- RHS does have type of form (m ty), which is wierd + = return () -- but at lesat this warning is irrelevant unusedMonadBind :: LHsExpr Id -> Type -> SDoc -unusedMonadBind rhs returning_ty - = ptext (sLit "A do-notation statement discarded a result of type") <+> ppr returning_ty <> dot $$ +unusedMonadBind rhs elt_ty + = ptext (sLit "A do-notation statement discarded a result of type") <+> ppr elt_ty <> dot $$ ptext (sLit "Suppress this warning by saying \"_ <- ") <> ppr rhs <> ptext (sLit "\",") $$ ptext (sLit "or by using the flag -fno-warn-unused-do-bind") wrongMonadBind :: LHsExpr Id -> Type -> SDoc -wrongMonadBind rhs returning_ty - = ptext (sLit "A do-notation statement discarded a result of type") <+> ppr returning_ty <> dot $$ +wrongMonadBind rhs elt_ty + = ptext (sLit "A do-notation statement discarded a result of type") <+> ppr elt_ty <> dot $$ ptext (sLit "Suppress this warning by saying \"_ <- ") <> ppr rhs <> ptext (sLit "\",") $$ ptext (sLit "or by using the flag -fno-warn-wrong-do-bind") \end{code} diff --git a/compiler/deSugar/DsListComp.lhs b/compiler/deSugar/DsListComp.lhs index 7fa7848..1ecab67 100644 --- a/compiler/deSugar/DsListComp.lhs +++ b/compiler/deSugar/DsListComp.lhs @@ -49,12 +49,12 @@ There will be at least one ``qualifier'' in the input. \begin{code} dsListComp :: [LStmt Id] - -> LHsExpr Id - -> Type -- Type of list elements + -> Type -- Type of entire list -> DsM CoreExpr -dsListComp lquals body elt_ty = do +dsListComp lquals res_ty = do dflags <- getDOptsDs let quals = map unLoc lquals + [elt_ty] = tcTyConAppArgs res_ty if not (dopt Opt_EnableRewriteRules dflags) || dopt Opt_IgnoreInterfacePragmas dflags -- Either rules are switched off, or we are ignoring what there are; @@ -62,8 +62,8 @@ dsListComp lquals body elt_ty = do -- Wadler-style desugaring || isParallelComp quals -- Foldr-style desugaring can't handle parallel list comprehensions - then deListComp quals body (mkNilExpr elt_ty) - else mkBuildExpr elt_ty (\(c, _) (n, _) -> dfListComp c n quals body) + then deListComp quals (mkNilExpr elt_ty) + else mkBuildExpr elt_ty (\(c, _) (n, _) -> dfListComp c n quals) -- Foldr/build should be enabled, so desugar -- into foldrs and builds @@ -83,12 +83,11 @@ dsListComp lquals body elt_ty = do -- and the type of the elements that it outputs (tuples of binders) dsInnerListComp :: ([LStmt Id], [Id]) -> DsM (CoreExpr, Type) dsInnerListComp (stmts, bndrs) = do - expr <- dsListComp stmts (mkBigLHsVarTup bndrs) bndrs_tuple_type - return (expr, bndrs_tuple_type) - where - bndrs_types = map idType bndrs - bndrs_tuple_type = mkBigCoreTupTy bndrs_types - + = do { expr <- dsListComp (stmts ++ [noLoc $ mkLastStmt (mkBigLHsVarTup bndrs)]) + bndrs_tuple_type + ; return (expr, bndrs_tuple_type) } + where + bndrs_tuple_type = mkBigCoreVarTupTy bndrs -- This function factors out commonality between the desugaring strategies for TransformStmt. -- Given such a statement it gives you back an expression representing how to compute the transformed @@ -228,9 +227,40 @@ with the Unboxed variety. \begin{code} -deListComp :: [Stmt Id] -> LHsExpr Id -> CoreExpr -> DsM CoreExpr +deListComp :: [Stmt Id] -> CoreExpr -> DsM CoreExpr + +deListComp [] _ = panic "deListComp" -deListComp (ParStmt stmtss_w_bndrs _ _ _ : quals) body list +deListComp (LastStmt body _ : quals) list + = -- Figure 7.4, SLPJ, p 135, rule C above + ASSERT( null quals ) + do { core_body <- dsLExpr body + ; return (mkConsExpr (exprType core_body) core_body list) } + + -- Non-last: must be a guard +deListComp (ExprStmt guard _ _ _ : quals) list = do -- rule B above + core_guard <- dsLExpr guard + core_rest <- deListComp quals list + return (mkIfThenElse core_guard core_rest list) + +-- [e | let B, qs] = let B in [e | qs] +deListComp (LetStmt binds : quals) list = do + core_rest <- deListComp quals list + dsLocalBinds binds core_rest + +deListComp (stmt@(TransformStmt {}) : quals) list = do + (inner_list_expr, pat) <- dsTransformStmt stmt + deBindComp pat inner_list_expr quals list + +deListComp (stmt@(GroupStmt {}) : quals) list = do + (inner_list_expr, pat) <- dsGroupStmt stmt + deBindComp pat inner_list_expr quals list + +deListComp (BindStmt pat list1 _ _ : quals) core_list2 = do -- rule A' above + core_list1 <- dsLExpr list1 + deBindComp pat core_list1 quals core_list2 + +deListComp (ParStmt stmtss_w_bndrs _ _ _ : quals) list = do exps_and_qual_tys <- mapM dsInnerListComp stmtss_w_bndrs let (exps, qual_tys) = unzip exps_and_qual_tys @@ -239,7 +269,7 @@ deListComp (ParStmt stmtss_w_bndrs _ _ _ : quals) body list -- Deal with [e | pat <- zip l1 .. ln] in example above deBindComp pat (Let (Rec [(zip_fn, zip_rhs)]) (mkApps (Var zip_fn) exps)) - quals body list + quals list where bndrs_s = map snd stmtss_w_bndrs @@ -247,34 +277,6 @@ deListComp (ParStmt stmtss_w_bndrs _ _ _ : quals) body list -- pat is the pattern ((x1,..,xn), (y1,..,ym)) in the example above pat = mkBigLHsPatTup pats pats = map mkBigLHsVarPatTup bndrs_s - - -- Last: the one to return -deListComp [] body list = do -- Figure 7.4, SLPJ, p 135, rule C above - core_body <- dsLExpr body - return (mkConsExpr (exprType core_body) core_body list) - - -- Non-last: must be a guard -deListComp (ExprStmt guard _ _ _ : quals) body list = do -- rule B above - core_guard <- dsLExpr guard - core_rest <- deListComp quals body list - return (mkIfThenElse core_guard core_rest list) - --- [e | let B, qs] = let B in [e | qs] -deListComp (LetStmt binds : quals) body list = do - core_rest <- deListComp quals body list - dsLocalBinds binds core_rest - -deListComp (stmt@(TransformStmt {}) : quals) body list = do - (inner_list_expr, pat) <- dsTransformStmt stmt - deBindComp pat inner_list_expr quals body list - -deListComp (stmt@(GroupStmt {}) : quals) body list = do - (inner_list_expr, pat) <- dsGroupStmt stmt - deBindComp pat inner_list_expr quals body list - -deListComp (BindStmt pat list1 _ _ : quals) body core_list2 = do -- rule A' above - core_list1 <- dsLExpr list1 - deBindComp pat core_list1 quals body core_list2 \end{code} @@ -282,10 +284,9 @@ deListComp (BindStmt pat list1 _ _ : quals) body core_list2 = do -- rule A' abov deBindComp :: OutPat Id -> CoreExpr -> [Stmt Id] - -> LHsExpr Id -> CoreExpr -> DsM (Expr Id) -deBindComp pat core_list1 quals body core_list2 = do +deBindComp pat core_list1 quals core_list2 = do let u3_ty@u1_ty = exprType core_list1 -- two names, same thing @@ -302,7 +303,7 @@ deBindComp pat core_list1 quals body core_list2 = do core_fail = App (Var h) (Var u3) letrec_body = App (Var h) core_list1 - rest_expr <- deListComp quals body core_fail + rest_expr <- deListComp quals core_fail core_match <- matchSimply (Var u2) (StmtCtxt ListComp) pat rest_expr core_fail let @@ -337,48 +338,48 @@ TE[ e | p <- l , q ] c n = let \begin{code} dfListComp :: Id -> Id -- 'c' and 'n' -> [Stmt Id] -- the rest of the qual's - -> LHsExpr Id -> DsM CoreExpr - -- Last: the one to return -dfListComp c_id n_id [] body = do - core_body <- dsLExpr body - return (mkApps (Var c_id) [core_body, Var n_id]) +dfListComp _ _ [] = panic "dfListComp" + +dfListComp c_id n_id (LastStmt body _ : quals) + = ASSERT( null quals ) + do { core_body <- dsLExpr body + ; return (mkApps (Var c_id) [core_body, Var n_id]) } -- Non-last: must be a guard -dfListComp c_id n_id (ExprStmt guard _ _ _ : quals) body = do +dfListComp c_id n_id (ExprStmt guard _ _ _ : quals) = do core_guard <- dsLExpr guard - core_rest <- dfListComp c_id n_id quals body + core_rest <- dfListComp c_id n_id quals return (mkIfThenElse core_guard core_rest (Var n_id)) -dfListComp c_id n_id (LetStmt binds : quals) body = do +dfListComp c_id n_id (LetStmt binds : quals) = do -- new in 1.3, local bindings - core_rest <- dfListComp c_id n_id quals body + core_rest <- dfListComp c_id n_id quals dsLocalBinds binds core_rest -dfListComp c_id n_id (stmt@(TransformStmt {}) : quals) body = do +dfListComp c_id n_id (stmt@(TransformStmt {}) : quals) = do (inner_list_expr, pat) <- dsTransformStmt stmt -- Anyway, we bind the newly transformed list via the generic binding function - dfBindComp c_id n_id (pat, inner_list_expr) quals body + dfBindComp c_id n_id (pat, inner_list_expr) quals -dfListComp c_id n_id (stmt@(GroupStmt {}) : quals) body = do +dfListComp c_id n_id (stmt@(GroupStmt {}) : quals) = do (inner_list_expr, pat) <- dsGroupStmt stmt -- Anyway, we bind the newly grouped list via the generic binding function - dfBindComp c_id n_id (pat, inner_list_expr) quals body + dfBindComp c_id n_id (pat, inner_list_expr) quals -dfListComp c_id n_id (BindStmt pat list1 _ _ : quals) body = do +dfListComp c_id n_id (BindStmt pat list1 _ _ : quals) = do -- evaluate the two lists core_list1 <- dsLExpr list1 -- Do the rest of the work in the generic binding builder - dfBindComp c_id n_id (pat, core_list1) quals body + dfBindComp c_id n_id (pat, core_list1) quals dfBindComp :: Id -> Id -- 'c' and 'n' -> (LPat Id, CoreExpr) -> [Stmt Id] -- the rest of the qual's - -> LHsExpr Id -> DsM CoreExpr -dfBindComp c_id n_id (pat, core_list1) quals body = do +dfBindComp c_id n_id (pat, core_list1) quals = do -- find the required type let x_ty = hsLPatType pat b_ty = idType n_id @@ -387,7 +388,7 @@ dfBindComp c_id n_id (pat, core_list1) quals body = do [b, x] <- newSysLocalsDs [b_ty, x_ty] -- build rest of the comprehesion - core_rest <- dfListComp c_id b quals body + core_rest <- dfListComp c_id b quals -- build the pattern match core_expr <- matchSimply (Var x) (StmtCtxt ListComp) @@ -482,9 +483,6 @@ mkUnzipBind elt_tys = do unzip_fn_ty = elt_tuple_list_ty `mkFunTy` elt_list_tuple_ty mkConcatExpression (list_element_ty, head, tail) = mkConsExpr list_element_ty head tail - - - \end{code} %************************************************************************ @@ -500,11 +498,10 @@ mkUnzipBind elt_tys = do -- [:e | qss:] = <<[:e | qss:]>> () [:():] -- dsPArrComp :: [Stmt Id] - -> LHsExpr Id - -> Type -- Don't use; called with `undefined' below -> DsM CoreExpr -dsPArrComp [ParStmt qss _ _ _] body _ = -- parallel comprehension - dePArrParComp qss body + +-- Special case for parallel comprehension +dsPArrComp (ParStmt qss _ _ _ : quals) = dePArrParComp qss quals -- Special case for simple generators: -- @@ -515,7 +512,7 @@ dsPArrComp [ParStmt qss _ _ _] body _ = -- parallel comprehension -- <<[:e' | p <- e, qs:]>> = -- <<[:e' | qs:]>> p (filterP (\x -> case x of {p -> True; _ -> False}) e) -- -dsPArrComp (BindStmt p e _ _ : qs) body _ = do +dsPArrComp (BindStmt p e _ _ : qs) = do filterP <- dsLookupDPHId filterPName ce <- dsLExpr e let ety'ce = parrElemType ce @@ -525,38 +522,41 @@ dsPArrComp (BindStmt p e _ _ : qs) body _ = do pred <- matchSimply (Var v) (StmtCtxt PArrComp) p true false let gen | isIrrefutableHsPat p = ce | otherwise = mkApps (Var filterP) [Type ety'ce, mkLams [v] pred, ce] - dePArrComp qs body p gen + dePArrComp qs p gen -dsPArrComp qs body _ = do -- no ParStmt in `qs' +dsPArrComp qs = do -- no ParStmt in `qs' sglP <- dsLookupDPHId singletonPName let unitArray = mkApps (Var sglP) [Type unitTy, mkCoreTup []] - dePArrComp qs body (noLoc $ WildPat unitTy) unitArray + dePArrComp qs (noLoc $ WildPat unitTy) unitArray -- the work horse -- dePArrComp :: [Stmt Id] - -> LHsExpr Id -> LPat Id -- the current generator pattern -> CoreExpr -- the current generator expression -> DsM CoreExpr + +dePArrComp [] _ _ = panic "dePArrComp" + -- -- <<[:e' | :]>> pa ea = mapP (\pa -> e') ea -- -dePArrComp [] e' pa cea = do - mapP <- dsLookupDPHId mapPName - let ty = parrElemType cea - (clam, ty'e') <- deLambda ty pa e' - return $ mkApps (Var mapP) [Type ty, Type ty'e', clam, cea] +dePArrComp (LastStmt e' _ : quals) pa cea + = ASSERT( null quals ) + do { mapP <- dsLookupDPHId mapPName + ; let ty = parrElemType cea + ; (clam, ty'e') <- deLambda ty pa e' + ; return $ mkApps (Var mapP) [Type ty, Type ty'e', clam, cea] } -- -- <<[:e' | b, qs:]>> pa ea = <<[:e' | qs:]>> pa (filterP (\pa -> b) ea) -- -dePArrComp (ExprStmt b _ _ _ : qs) body pa cea = do +dePArrComp (ExprStmt b _ _ _ : qs) pa cea = do filterP <- dsLookupDPHId filterPName let ty = parrElemType cea (clam,_) <- deLambda ty pa b - dePArrComp qs body pa (mkApps (Var filterP) [Type ty, clam, cea]) + dePArrComp qs pa (mkApps (Var filterP) [Type ty, clam, cea]) -- -- <<[:e' | p <- e, qs:]>> pa ea = @@ -571,7 +571,7 @@ dePArrComp (ExprStmt b _ _ _ : qs) body pa cea = do -- in -- <<[:e' | qs:]>> (pa, p) (crossMapP ea ef) -- -dePArrComp (BindStmt p e _ _ : qs) body pa cea = do +dePArrComp (BindStmt p e _ _ : qs) pa cea = do filterP <- dsLookupDPHId filterPName crossMapP <- dsLookupDPHId crossMapPName ce <- dsLExpr e @@ -587,7 +587,7 @@ dePArrComp (BindStmt p e _ _ : qs) body pa cea = do let ety'cef = ety'ce -- filter doesn't change the element type pa' = mkLHsPatTup [pa, p] - dePArrComp qs body pa' (mkApps (Var crossMapP) + dePArrComp qs pa' (mkApps (Var crossMapP) [Type ety'cea, Type ety'cef, cea, clam]) -- -- <<[:e' | let ds, qs:]>> pa ea = @@ -596,7 +596,7 @@ dePArrComp (BindStmt p e _ _ : qs) body pa cea = do -- where -- {x_1, ..., x_n} = DV (ds) -- Defined Variables -- -dePArrComp (LetStmt ds : qs) body pa cea = do +dePArrComp (LetStmt ds : qs) pa cea = do mapP <- dsLookupDPHId mapPName let xs = collectLocalBinders ds ty'cea = parrElemType cea @@ -611,14 +611,14 @@ dePArrComp (LetStmt ds : qs) body pa cea = do ccase <- matchSimply (Var v) (StmtCtxt PArrComp) pa projBody cerr let pa' = mkLHsPatTup [pa, mkLHsPatTup (map nlVarPat xs)] proj = mkLams [v] ccase - dePArrComp qs body pa' (mkApps (Var mapP) + dePArrComp qs pa' (mkApps (Var mapP) [Type ty'cea, Type errTy, proj, cea]) -- -- The parser guarantees that parallel comprehensions can only appear as -- singeltons qualifier lists, which we already special case in the caller. -- So, encountering one here is a bug. -- -dePArrComp (ParStmt _ _ _ _ : _) _ _ _ = +dePArrComp (ParStmt _ _ _ _ : _) _ _ = panic "DsListComp.dePArrComp: malformed comprehension AST" -- <<[:e' | qs | qss:]>> pa ea = @@ -627,17 +627,17 @@ dePArrComp (ParStmt _ _ _ _ : _) _ _ _ = -- where -- {x_1, ..., x_n} = DV (qs) -- -dePArrParComp :: [([LStmt Id], [Id])] -> LHsExpr Id -> DsM CoreExpr -dePArrParComp qss body = do +dePArrParComp :: [([LStmt Id], [Id])] -> [Stmt Id] -> DsM CoreExpr +dePArrParComp qss quals = do (pQss, ceQss) <- deParStmt qss - dePArrComp [] body pQss ceQss + dePArrComp quals pQss ceQss where deParStmt [] = -- empty parallel statement lists have no source representation panic "DsListComp.dePArrComp: Empty parallel list comprehension" deParStmt ((qs, xs):qss) = do -- first statement let res_expr = mkLHsVarTuple xs - cqs <- dsPArrComp (map unLoc qs) res_expr undefined + cqs <- dsPArrComp (map unLoc qs ++ [mkLastStmt res_expr]) parStmts qss (mkLHsVarPatTup xs) cqs --- parStmts [] pa cea = return (pa, cea) @@ -646,7 +646,7 @@ dePArrParComp qss body = do let pa' = mkLHsPatTup [pa, mkLHsVarPatTup xs] ty'cea = parrElemType cea res_expr = mkLHsVarTuple xs - cqs <- dsPArrComp (map unLoc qs) res_expr undefined + cqs <- dsPArrComp (map unLoc qs ++ [mkLastStmt res_expr]) let ty'cqs = parrElemType cqs cea' = mkApps (Var zipP) [Type ty'cea, Type ty'cqs, cea, cqs] parStmts qss pa' cea' @@ -701,11 +701,9 @@ data DsMonadComp = DsMonadComp -- Entry point for monad comprehension desugaring -- dsMonadComp :: [LStmt Id] -- the statements - -> SyntaxExpr Id -- the "return" function - -> LHsExpr Id -- the body -> Type -- the final type -> DsM CoreExpr -dsMonadComp stmts return_op body res_ty +dsMonadComp stmts res_ty = dsMcStmts stmts (DsMonadComp (Left return_op) body m_ty) where (m_ty, _) = tcSplitAppTy res_ty @@ -729,30 +727,33 @@ dsMcStmts ((L loc stmt) : lstmts) mc = putSrcSpanDs loc (dsMcStmt stmt lstmts mc) -dsMcStmt :: Stmt Id - -> [LStmt Id] - -> DsMonadComp - -> DsM CoreExpr +dsMcStmt :: Stmt Id -> [LStmt Id] -> DsM CoreExpr + +dsMcStmt (LastStmt body ret_op) stmts + = ASSERT( null stmts ) + do { body' <- dsLExpr body + ; ret_op' <- dsExpr ret_op + ; return (App ret_op' body') } -- [ .. | let binds, stmts ] -dsMcStmt (LetStmt binds) stmts mc - = do { rest <- dsMcStmts stmts mc +dsMcStmt (LetStmt binds) stmts + = do { rest <- dsMcStmts stmts ; dsLocalBinds binds rest } -- [ .. | a <- m, stmts ] -dsMcStmt (BindStmt pat rhs bind_op fail_op) stmts mc - = do { rhs' <- dsLExpr rhs - ; dsMcBindStmt pat rhs' bind_op fail_op stmts mc } +dsMcStmt (BindStmt pat rhs bind_op fail_op) stmts + = do { rhs' <- dsLExpr rhs + ; dsMcBindStmt pat rhs' bind_op fail_op stmts } -- Apply `guard` to the `exp` expression -- -- [ .. | exp, stmts ] -- -dsMcStmt (ExprStmt exp then_exp guard_exp _) stmts mc +dsMcStmt (ExprStmt exp then_exp guard_exp _) stmts = do { exp' <- dsLExpr exp ; guard_exp' <- dsExpr guard_exp ; then_exp' <- dsExpr then_exp - ; rest <- dsMcStmts stmts mc + ; rest <- dsMcStmts stmts ; return $ mkApps then_exp' [ mkApps guard_exp' [exp'] , rest ] } @@ -762,26 +763,38 @@ dsMcStmt (ExprStmt exp then_exp guard_exp _) stmts mc -- -- where [| qs |] is the desugared inner monad comprehenion generated by the -- statements `qs`. -dsMcStmt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op) stmts_rest mc - = do { (expr, _) <- dsInnerMonadComp (stmts, binders) (mc { mc_return = Left return_op }) - ; let binders_tuple_type = mkBigCoreTupTy $ map idType binders +dsMcStmt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op) stmts_rest + = do { expr <- dsInnerMonadComp stmts binders return_op + ; let binders_tup_type = mkBigCoreTupTy $ map idType binders ; usingExpr' <- dsLExpr usingExpr ; using_args <- case maybeByExpr of Nothing -> return [expr] Just byExpr -> do byExpr' <- dsLExpr byExpr us <- newUniqueSupply - tuple_binder <- newSysLocalDs binders_tuple_type - let byExprWrapper = mkTupleCase us binders byExpr' tuple_binder (Var tuple_binder) - return [Lam tuple_binder byExprWrapper, expr] + tup_binder <- newSysLocalDs binders_tup_type + let byExprWrapper = mkTupleCase us binders byExpr' tup_binder (Var tup_binder) + return [Lam tup_binder byExprWrapper, expr] ; let pat = mkBigLHsVarPatTup binders - rhs = mkApps usingExpr' ((Type binders_tuple_type) : using_args) + rhs = mkApps usingExpr' ((Type binders_tup_type) : using_args) - ; dsMcBindStmt pat rhs bind_op noSyntaxExpr stmts_rest mc } + ; dsMcBindStmt pat rhs bind_op noSyntaxExpr stmts_rest } -- Group statements desugar like this: -- +-- [| (q, then group by e using f); rest |] +-- ---> f {qt} (\qv -> e) [| q; return qv |] >>= \ n_tup -> +-- case unzip n_tup of qv -> [| rest |] +-- +-- where variables (v1:t1, ..., vk:tk) are bound by q +-- qv = (v1, ..., vk) +-- qt = (t1, ..., tk) +-- (>>=) :: m2 a -> (a -> m3 b) -> m3 b +-- f :: forall a. (a -> t) -> m1 a -> m2 (n a) +-- n_tup :: n qt +-- unzip :: n qt -> (n t1, ..., n tk) (needs Functor n) +-- -- [| q, then group by e using f |] -> (f (\q_v -> e) [| q |]) >>= (return . (unzip q_v)) -- -- which is equal to @@ -790,24 +803,23 @@ dsMcStmt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op) s -- -- where unzip is of the form -- --- unzip :: m (a,b,c,..) -> (m a,m b,m c,..) --- unzip m_tuple = ( liftM selN1 m_tuple --- , liftM selN2 m_tuple +-- unzip :: n (a,b,c,..) -> (n a,n b,n c,..) +-- unzip m_tuple = ( fmap selN1 m_tuple +-- , fmap selN2 m_tuple -- , .. ) -- where selN1 (a,b,c,..) = a -- selN2 (a,b,c,..) = b -- .. -- -dsMcStmt (GroupStmt stmts binderMap by using return_op bind_op liftM_op) stmts_rest mc +dsMcStmt (GroupStmt stmts binderMap by using return_op bind_op fmap_op) stmts_rest = do { let (fromBinders, toBinders) = unzip binderMap - fromBindersTypes = map idType fromBinders + fromBindersTypes = map idType fromBinders -- Types ty fromBindersTupleTy = mkBigCoreTupTy fromBindersTypes - toBindersTypes = map idType toBinders + toBindersTypes = map idType toBinders -- Types (n ty) toBindersTupleTy = mkBigCoreTupTy toBindersTypes - m_ty = mc_m_ty mc -- Desugar an inner comprehension which outputs a list of tuples of the "from" binders - ; (expr, _) <- dsInnerMonadComp (stmts, fromBinders) (mc { mc_return = Left return_op }) + ; expr <- dsInnerMonadComp stmts fromBinders return_op -- Work out what arguments should be supplied to that expression: i.e. is an extraction -- function required? If so, create that desugared function and add to arguments @@ -815,62 +827,45 @@ dsMcStmt (GroupStmt stmts binderMap by using return_op bind_op liftM_op) stmts_r ; usingArgs <- case by of Nothing -> return [expr] Just by_e -> do { by_e' <- dsLExpr by_e - ; us <- newUniqueSupply - ; from_tup_id <- newSysLocalDs fromBindersTupleTy - ; let by_wrap = mkTupleCase us fromBinders by_e' - from_tup_id (Var from_tup_id) - ; return [Lam from_tup_id by_wrap, expr] } + ; lam <- matchTuple fromBinders by_e' + ; return [lam, expr] } -- Create an unzip function for the appropriate arity and element types - ; liftM_op' <- dsExpr liftM_op - ; (unzip_fn, unzip_rhs) <- mkMcUnzipM liftM_op' m_ty fromBindersTypes + ; fmap_op' <- dsExpr fmap_op + ; (unzip_fn, unzip_rhs) <- mkMcUnzipM fmap_op' m_ty fromBindersTypes -- Generate the expressions to build the grouped list - - ; let -- First we apply the grouping function to the inner monad - inner_monad_expr = mkApps usingExpr' ((Type fromBindersTupleTy) : usingArgs) - -- Then we map our "unzip" across it to turn the "monad of tuples" into "tuples of monads" - -- We make sure we instantiate the type variable "a" to be a "monad of 'from' tuples" and - -- the "b" to be a "tuple of 'to' monads"! - unzipped_inner_monad_expr = mkApps liftM_op' -- ! - -- Types: - [ Type (m_ty `mkAppTy` fromBindersTupleTy), Type toBindersTupleTy - -- And arguments: - , Var unzip_fn, inner_monad_expr ] - -- Then finally we bind the unzip function around that expression - bound_unzipped_inner_monad_expr = Let (Rec [(unzip_fn, unzip_rhs)]) unzipped_inner_monad_expr - - -- Build a pattern that ensures the consumer binds into the NEW binders, which hold monads - -- rather than single values - ; let pat = mkBigLHsVarPatTup toBinders - rhs = bound_unzipped_inner_monad_expr - - ; dsMcBindStmt pat rhs bind_op noSyntaxExpr stmts_rest mc } + -- Build a pattern that ensures the consumer binds into the NEW binders, + -- which hold monads rather than single values + ; bind_op' <- dsExpr bind_op + ; let bind_ty = exprType bind_op' -- m2 (n (a,b,c)) -> (n (a,b,c) -> r1) -> r2 + n_tup_ty = funArgTy $ funArgTy $ funResultTy bind_ty + + ; body <- dsMcStmts stmts_rest + ; n_tup_var <- newSysLocalDs n_tup_ty + ; tup_n_var <- newSysLocalDs (mkBigCoreVarTupTy toBinders) + ; us <- newUniqueSupply + ; let unzip_n_tup = Let (Rec [(unzip_fn, unzip_rhs)]) $ + App (Var unzip_fn) (Var n_tup_var) + -- unzip_n_tup :: (n a, n b, n c) + body' = mkTupleCase us toBinders body unzip_n_tup (Var tup_n_var) + + ; return (mkApps bind_op' [rhs', Lam n_tup_var body']) } -- Parallel statements. Use `Control.Monad.Zip.mzip` to zip parallel -- statements, for example: -- -- [ body | qs1 | qs2 | qs3 ] --- -> [ body | (bndrs1, (bndrs2, bndrs3)) <- mzip qs1 (mzip qs2 qs3) ] --- --- where `mzip` is of the form +-- -> [ body | (bndrs1, (bndrs2, bndrs3)) +-- <- [bndrs1 | qs1] `mzip` ([bndrs2 | qs2] `mzip` [bndrs3 | qs3]) ] -- --- mzip :: m a -> m b -> m (a,b) --- -dsMcStmt (ParStmt pairs mzip_op bind_op return_op) stmts_rest mc - = do { -- Get types for `return` - return_op' <- dsExpr return_op - ; let pairs_with_return = map (\tp@(_,b) -> (mkReturn b,tp)) pairs - mkReturn bndrs = mkApps return_op' [Type (mkBigCoreTupTy (map idType bndrs))] - - ; pairs' <- mapM (\(r,tp) -> dsInnerMonadComp tp mc{mc_return = Right r}) - pairs_with_return - - ; let (exps, _qual_tys) = unzip pairs' - -- Types of our `Id`s are getting messed up by `dsInnerMonadComp` - -- so we construct them by hand: - qual_tys = map (mkBigCoreTupTy . map idType . snd) pairs +-- where `mzip` has type +-- mzip :: forall a b. m a -> m b -> m (a,b) +-- NB: we need a polymorphic mzip because we call it several times +dsMcStmt (ParStmt pairs mzip_op bind_op return_op) stmts_rest + = do { exps <- mapM ds_inner pairs + ; let qual_tys = map (mkBigCoreVarTupTy . snd) pairs ; mzip_op' <- dsExpr mzip_op ; (zip_fn, zip_rhs) <- mkMcZipM mzip_op' (mc_m_ty mc) qual_tys @@ -881,9 +876,23 @@ dsMcStmt (ParStmt pairs mzip_op bind_op return_op) stmts_rest mc pat = foldr (\tn tm -> mkBigLHsPatTup [tn, tm]) (last vars) (init vars) rhs = Let (Rec [(zip_fn, zip_rhs)]) (mkApps (Var zip_fn) exps) - ; dsMcBindStmt pat rhs bind_op noSyntaxExpr stmts_rest mc } + ; dsMcBindStmt pat rhs bind_op noSyntaxExpr stmts_rest } + where + ds_inner (stmts, bndrs) = dsInnerMonadComp stmts bndrs mono_ret_op + where + mono_ret_op = HsWrap (WpTyApp (mkBigCoreVarTupTy bndrs)) return_op -dsMcStmt stmt _ _ = pprPanic "dsMcStmt: unexpected stmt" (ppr stmt) +dsMcStmt stmt _ = pprPanic "dsMcStmt: unexpected stmt" (ppr stmt) + + +matchTuple :: [Id] -> CoreExpr -> DsM CoreExpr +-- (matchTuple [a,b,c] body) +-- returns the Core term +-- \x. case x of (a,b,c) -> body +matchTuple ids body + = do { us <- newUniqueSupply + ; tup_id <- newSysLocalDs (mkBigLHsVarPatTup ids) + ; return (Lam tup_id $ mkTupleCase us ids body tup_id (Var tup_id)) } -- general `rhs' >>= \pat -> stmts` desugaring where `rhs'` is already a @@ -893,10 +902,9 @@ dsMcBindStmt :: LPat Id -> SyntaxExpr Id -> SyntaxExpr Id -> [LStmt Id] - -> DsMonadComp -> DsM CoreExpr -dsMcBindStmt pat rhs' bind_op fail_op stmts mc - = do { body <- dsMcStmts stmts mc +dsMcBindStmt pat rhs' bind_op fail_op stmts + = do { body <- dsMcStmts stmts ; bind_op' <- dsExpr bind_op ; var <- selectSimpleMatchVarL pat ; let bind_ty = exprType bind_op' -- rhs -> (pat -> res1) -> res2 @@ -922,16 +930,16 @@ dsMcBindStmt pat rhs' bind_op fail_op stmts mc showSDoc (ppr (getLoc pat)) -- Desugar nested monad comprehensions, for example in `then..` constructs -dsInnerMonadComp :: ([LStmt Id], [Id]) - -> DsMonadComp - -> DsM (CoreExpr, Type) -dsInnerMonadComp (stmts, bndrs) DsMonadComp{ mc_return, mc_m_ty } - = do { expr <- dsMcStmts stmts mc' - ; return (expr, bndrs_tuple_type) } - where - bndrs_types = map idType bndrs - bndrs_tuple_type = mkAppTy mc_m_ty $ mkBigCoreTupTy bndrs_types - mc' = DsMonadComp mc_return (mkBigLHsVarTup bndrs) mc_m_ty +-- dsInnerMonadComp quals [a,b,c] ret_op +-- returns the desugaring of +-- [ (a,b,c) | quals ] + +dsInnerMonadComp :: [LStmt Id] + -> [Id] -- Return a tuple of these variables + -> LHsExpr Id -- The monomorphic "return" operator + -> DsM CoreExpr +dsInnerMonadComp stmts bndrs ret_op + = dsMcStmts (stmts ++ [noLoc (ReturnStmt (mkBigLHsVarTup bndrs) ret_op)]) -- The `unzip` function for `GroupStmt` in a monad comprehensions -- diff --git a/compiler/hsSyn/HsExpr.lhs b/compiler/hsSyn/HsExpr.lhs index e367af5..f7b693f 100644 --- a/compiler/hsSyn/HsExpr.lhs +++ b/compiler/hsSyn/HsExpr.lhs @@ -23,6 +23,7 @@ import Name import BasicTypes import DataCon import SrcLoc +import Util( dropTail ) import Outputable import FastString @@ -146,10 +147,6 @@ data HsExpr id -- because in this context we never use -- the PatGuard or ParStmt variant [LStmt id] -- "do":one or more stmts - (LHsExpr id) -- The body; the last expression in the - -- 'do' of [ body | ... ] in a list comp - (SyntaxExpr id) -- The 'return' function, see Note - -- [Monad Comprehensions] PostTcType -- Type of the whole expression | ExplicitList -- syntactic list @@ -441,7 +438,7 @@ ppr_expr (HsLet binds expr) = sep [hang (ptext (sLit "let")) 2 (pprBinds binds), hang (ptext (sLit "in")) 2 (ppr expr)] -ppr_expr (HsDo do_or_list_comp stmts body _ _) = pprDo do_or_list_comp stmts body +ppr_expr (HsDo do_or_list_comp stmts _) = pprDo do_or_list_comp stmts ppr_expr (ExplicitList _ exprs) = brackets (pprDeeperList fsep (punctuate comma (map ppr_lexpr exprs))) @@ -577,7 +574,7 @@ pprParendExpr expr HsPar {} -> pp_as_was HsBracket {} -> pp_as_was HsBracketOut _ [] -> pp_as_was - HsDo sc _ _ _ _ + HsDo sc _ _ | isListCompExpr sc -> pp_as_was _ -> parens pp_as_was @@ -835,7 +832,12 @@ type Stmt id = StmtLR id id -- The SyntaxExprs in here are used *only* for do-notation and monad -- comprehensions, which have rebindable syntax. Otherwise they are unused. data StmtLR idL idR - = BindStmt (LPat idL) + = LastStmt -- Always the last Stmt in ListComp, MonadComp, PArrComp, DoExpr, MDoExpr + -- Not used for GhciStmt, PatGuard, which scope over other stuff + (LHsExpr idR) + (SyntaxExpr idR) -- The return operator, used only for MonadComp + -- See Note [Monad Comprehensions] + | BindStmt (LPat idL) (LHsExpr idR) (SyntaxExpr idR) -- The (>>=) operator (SyntaxExpr idR) -- The fail operator @@ -852,9 +854,10 @@ data StmtLR idL idR -- ParStmts only occur in a list/monad comprehension | ParStmt [([LStmt idL], [idR])] - (SyntaxExpr idR) -- polymorphic `mzip` for monad comprehensions + (SyntaxExpr idR) -- Polymorphic `mzip` for monad comprehensions (SyntaxExpr idR) -- The `>>=` operator - (SyntaxExpr idR) -- polymorphic `return` operator + (SyntaxExpr idR) -- Polymorphic `return` operator + -- with type (forall a. a -> m a) -- See notes [Monad Comprehensions] -- After renaming, the ids are the binders bound by the stmts and used @@ -926,6 +929,10 @@ data StmtLR idL idR -- because the Id may be *polymorphic*, but -- the returned thing has to be *monomorphic*, -- so they may be type applications + + , recS_ret_ty :: PostTcType -- The type of of do { stmts; return (a,b,c) } + -- With rebindable syntax the type might not + -- be quite as simple as (m (tya, tyb, tyc)). } deriving (Data, Typeable) \end{code} @@ -1022,10 +1029,10 @@ where v1..vn are the later_ids Note [Monad Comprehensions] ~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Monad comprehensions require seperate functions like 'return' and '>>=' for -desugaring. These functions are stored in the 'HsDo' expression and the -statements used in monad comprehensions. For example, the 'return' of the -'HsDo' expression is used to lift the body of the monad comprehension: +Monad comprehensions require separate functions like 'return' and +'>>=' for desugaring. These functions are stored in the statements +used in monad comprehensions. For example, the 'return' of the 'LastStmt' +expression is used to lift the body of the monad comprehension: [ body | stmts ] => @@ -1065,6 +1072,7 @@ instance (OutputableBndr idL, OutputableBndr idR) => Outputable (StmtLR idL idR) ppr stmt = pprStmt stmt pprStmt :: (OutputableBndr idL, OutputableBndr idR) => (StmtLR idL idR) -> SDoc +pprStmt (LastStmt expr _) = ppr expr pprStmt (BindStmt pat expr _ _) = hsep [ppr pat, ptext (sLit "<-"), ppr expr] pprStmt (LetStmt binds) = hsep [ptext (sLit "let"), pprBinds binds] pprStmt (ExprStmt expr _ _ _) = ppr expr @@ -1103,28 +1111,32 @@ pprBy :: OutputableBndr id => Maybe (LHsExpr id) -> SDoc pprBy Nothing = empty pprBy (Just e) = ptext (sLit "by") <+> ppr e -pprDo :: OutputableBndr id => HsStmtContext any -> [LStmt id] -> LHsExpr id -> SDoc -pprDo DoExpr stmts body = ptext (sLit "do") <+> ppr_do_stmts stmts body -pprDo GhciStmt stmts body = ptext (sLit "do") <+> ppr_do_stmts stmts body -pprDo MDoExpr stmts body = ptext (sLit "mdo") <+> ppr_do_stmts stmts body -pprDo ListComp stmts body = brackets $ pprComp stmts body -pprDo PArrComp stmts body = pa_brackets $ pprComp stmts body -pprDo MonadComp stmts body = brackets $ pprComp stmts body -pprDo _ _ _ = panic "pprDo" -- PatGuard, ParStmtCxt +pprDo :: OutputableBndr id => HsStmtContext any -> [LStmt id] -> SDoc +pprDo DoExpr stmts = ptext (sLit "do") <+> ppr_do_stmts stmts +pprDo GhciStmt stmts = ptext (sLit "do") <+> ppr_do_stmts stmts +pprDo MDoExpr stmts = ptext (sLit "mdo") <+> ppr_do_stmts stmts +pprDo ListComp stmts = brackets $ pprComp stmts +pprDo PArrComp stmts = pa_brackets $ pprComp stmts +pprDo MonadComp stmts = brackets $ pprComp stmts +pprDo _ _ = panic "pprDo" -- PatGuard, ParStmtCxt -ppr_do_stmts :: OutputableBndr id => [LStmt id] -> LHsExpr id -> SDoc +ppr_do_stmts :: OutputableBndr id => [LStmt id] -> SDoc -- Print a bunch of do stmts, with explicit braces and semicolons, -- so that we are not vulnerable to layout bugs -ppr_do_stmts stmts body - = lbrace <+> pprDeeperList vcat ([ppr s <> semi | s <- stmts] ++ [ppr body]) +ppr_do_stmts stmts + = lbrace <+> pprDeeperList vcat ([ppr s <> semi | s <- stmts]) <+> rbrace ppr_lc_stmts :: OutputableBndr id => [LStmt id] -> [SDoc] ppr_lc_stmts stmts = [ppr s <> comma | s <- stmts] -pprComp :: OutputableBndr id => [LStmt id] -> LHsExpr id -> SDoc -pprComp quals body -- Prints: body | qual1, ..., qualn - = hang (ppr body <+> char '|') 2 (interpp'SP quals) +pprComp :: OutputableBndr id => [LStmt id] -> SDoc +pprComp quals -- Prints: body | qual1, ..., qualn + | not (null quals) + , L _ (LastStmt body _) <- last quals + = hang (ppr body <+> char '|') 2 (interpp'SP (dropTail 1 quals)) + | otherwise + = pprPanic "pprComp" (interpp'SP quals) \end{code} %************************************************************************ @@ -1242,11 +1254,13 @@ data HsMatchContext id -- Context of a Match data HsStmtContext id = ListComp - | DoExpr - | GhciStmt -- A command-line Stmt in GHCi pat <- rhs - | MDoExpr -- Recursive do-expression | MonadComp | PArrComp -- Parallel array comprehension + + | DoExpr -- do { ... } + | MDoExpr -- mdo { ... } ie recursive do-expression + + | GhciStmt -- A command-line Stmt in GHCi pat <- rhs | PatGuard (HsMatchContext id) -- Pattern guard for specified thing | ParStmtCtxt (HsStmtContext id) -- A branch of a parallel stmt | TransformStmtCtxt (HsStmtContext id) -- A branch of a transform stmt diff --git a/compiler/hsSyn/HsUtils.lhs b/compiler/hsSyn/HsUtils.lhs index 44e3a32..0d91e9f 100644 --- a/compiler/hsSyn/HsUtils.lhs +++ b/compiler/hsSyn/HsUtils.lhs @@ -21,7 +21,7 @@ module HsUtils( mkMatchGroup, mkMatch, mkHsLam, mkHsIf, mkHsWrap, mkLHsWrap, mkHsWrapCoI, mkLHsWrapCoI, coiToHsWrapper, mkHsLams, mkHsDictLet, - mkHsOpApp, mkHsDo, mkHsWrapPat, mkHsWrapPatCoI, + mkHsOpApp, mkHsDo, mkHsComp, mkHsWrapPat, mkHsWrapPatCoI, mkDoStmts, nlHsTyApp, nlHsVar, nlHsLit, nlHsApp, nlHsApps, nlHsIntLit, nlHsVarApps, nlHsDo, nlHsOpApp, nlHsLam, nlHsPar, nlHsIf, nlHsCase, nlList, @@ -42,7 +42,7 @@ module HsUtils( nlHsAppTy, nlHsTyVar, nlHsFunTy, nlHsTyConApp, -- Stmts - mkTransformStmt, mkTransformByStmt, mkExprStmt, mkBindStmt, + mkTransformStmt, mkTransformByStmt, mkExprStmt, mkBindStmt, mkLastStmt, mkGroupUsingStmt, mkGroupByStmt, mkGroupByUsingStmt, emptyRecStmt, mkRecStmt, @@ -190,7 +190,9 @@ mkSimpleHsAlt pat expr mkHsIntegral :: Integer -> PostTcType -> HsOverLit id mkHsFractional :: Rational -> PostTcType -> HsOverLit id mkHsIsString :: FastString -> PostTcType -> HsOverLit id -mkHsDo :: HsStmtContext Name -> [LStmt id] -> LHsExpr id -> HsExpr id +mkHsDo :: HsStmtContext Name -> [LStmt id] -> HsExpr id +mkHsComp :: HsStmtContext Name -> [LStmt id] -> LHsExpr id -> HsExpr id +mkDoStmts :: [LStmt id] -> [LStmt id] mkNPat :: HsOverLit id -> Maybe (SyntaxExpr id) -> Pat id mkNPlusKPat :: Located id -> HsOverLit id -> Pat id @@ -198,6 +200,7 @@ mkNPlusKPat :: Located id -> HsOverLit id -> Pat id mkTransformStmt :: [LStmt idL] -> LHsExpr idR -> StmtLR idL idR mkTransformByStmt :: [LStmt idL] -> LHsExpr idR -> LHsExpr idR -> StmtLR idL idR +mkLastStmt :: LHsExpr idR -> StmtLR idL idR mkExprStmt :: LHsExpr idR -> StmtLR idL idR mkBindStmt :: LPat idL -> LHsExpr idR -> StmtLR idL idR @@ -212,7 +215,15 @@ mkHsIsString s = OverLit (HsIsString s) noRebindableInfo noSyntaxExpr noRebindableInfo :: Bool noRebindableInfo = error "noRebindableInfo" -- Just another placeholder; -mkHsDo ctxt stmts body = HsDo ctxt stmts body noSyntaxExpr placeHolderType +-- mkDoStmts turns a trailing ExprStmt into a LastStmt +mkDoStmts [L loc (ExprStmt e _ _ _)] = [L loc (mkLastStmt e)] +mkDoStmts (s:ss) = s : mkDoStmts ss +mkDoStmts [] = [] + +mkHsDo ctxt stmts = HsDo ctxt stmts placeHolderType +mkHsComp ctxt stmts expr = mkHsDo ctxt (stmts ++ [last_stmt]) + where + last_stmt = L (getLoc expr) $ mkLastStmt expr mkHsIf :: LHsExpr id -> LHsExpr id -> LHsExpr id -> HsExpr id mkHsIf c a b = HsIf (Just noSyntaxExpr) c a b @@ -231,13 +242,14 @@ mkGroupUsingStmt stmts usingExpr = GroupStmt stmts [] Nothing (Le mkGroupByStmt stmts byExpr = GroupStmt stmts [] (Just byExpr) (Right noSyntaxExpr) noSyntaxExpr noSyntaxExpr noSyntaxExpr mkGroupByUsingStmt stmts byExpr usingExpr = GroupStmt stmts [] (Just byExpr) (Left usingExpr) noSyntaxExpr noSyntaxExpr noSyntaxExpr +mkLastStmt expr = LastStmt expr noSyntaxExpr mkExprStmt expr = ExprStmt expr noSyntaxExpr noSyntaxExpr placeHolderType mkBindStmt pat expr = BindStmt pat expr noSyntaxExpr noSyntaxExpr emptyRecStmt = RecStmt { recS_stmts = [], recS_later_ids = [], recS_rec_ids = [] , recS_ret_fn = noSyntaxExpr, recS_mfix_fn = noSyntaxExpr , recS_bind_fn = noSyntaxExpr - , recS_rec_rets = [] } + , recS_rec_rets = [], recS_ret_ty = placeHolderType } mkRecStmt stmts = emptyRecStmt { recS_stmts = stmts } @@ -327,8 +339,8 @@ nlWildConPat con = noLoc (ConPatIn (noLoc (getRdrName con)) nlWildPat :: LPat id nlWildPat = noLoc (WildPat placeHolderType) -- Pre-typechecking -nlHsDo :: HsStmtContext Name -> [LStmt id] -> LHsExpr id -> LHsExpr id -nlHsDo ctxt stmts body = noLoc (mkHsDo ctxt stmts body) +nlHsDo :: HsStmtContext Name -> [LStmt id] -> LHsExpr id +nlHsDo ctxt stmts = noLoc (mkHsDo ctxt stmts) nlHsOpApp :: LHsExpr id -> id -> LHsExpr id -> LHsExpr id nlHsOpApp e1 op e2 = noLoc (mkHsOpApp e1 op e2) @@ -496,7 +508,8 @@ collectStmtBinders :: StmtLR idL idR -> [idL] -- Id Binders for a Stmt... [but what about pattern-sig type vars]? collectStmtBinders (BindStmt pat _ _ _) = collectPatBinders pat collectStmtBinders (LetStmt binds) = collectLocalBinders binds -collectStmtBinders (ExprStmt _ _ _ _) = [] +collectStmtBinders (ExprStmt {}) = [] +collectStmtBinders (LastStmt {}) = [] collectStmtBinders (ParStmt xs _ _ _) = collectLStmtsBinders $ concatMap fst xs collectStmtBinders (TransformStmt stmts _ _ _ _ _) = collectLStmtsBinders stmts @@ -642,7 +655,8 @@ lStmtsImplicits = hs_lstmts hs_stmt (BindStmt pat _ _ _) = lPatImplicits pat hs_stmt (LetStmt binds) = hs_local_binds binds - hs_stmt (ExprStmt _ _ _ _) = emptyNameSet + hs_stmt (ExprStmt {}) = emptyNameSet + hs_stmt (LastStmt {}) = emptyNameSet hs_stmt (ParStmt xs _ _ _) = hs_lstmts $ concatMap fst xs hs_stmt (TransformStmt stmts _ _ _ _ _) = hs_lstmts stmts diff --git a/compiler/parser/Parser.y.pp b/compiler/parser/Parser.y.pp index ec8d3ff..c42ea0c 100644 --- a/compiler/parser/Parser.y.pp +++ b/compiler/parser/Parser.y.pp @@ -1283,14 +1283,9 @@ exp10 :: { LHsExpr RdrName } | 'case' exp 'of' altslist { LL $ HsCase $2 (mkMatchGroup (unLoc $4)) } | '-' fexp { LL $ NegApp $2 noSyntaxExpr } - | 'do' stmtlist {% let loc = comb2 $1 $2 in - checkDo loc (unLoc $2) >>= \ (stmts,body) -> - return (L loc (mkHsDo DoExpr stmts body)) } - | 'mdo' stmtlist {% let loc = comb2 $1 $2 in - checkDo loc (unLoc $2) >>= \ (stmts,body) -> - return (L loc (mkHsDo MDoExpr - [L loc (mkRecStmt stmts)] - body)) } + | 'do' stmtlist { L (comb2 $1 $2) (mkHsDo DoExpr (unLoc $2)) } + | 'mdo' stmtlist { L (comb2 $1 $2) (mkHsDo MDoExpr (unLoc $2)) } + | scc_annot exp { LL $ if opt_SccProfilingOn then HsSCC (unLoc $1) $2 else HsPar $2 } @@ -1465,8 +1460,10 @@ list :: { LHsExpr RdrName } | texp ',' exp '..' { LL $ ArithSeq noPostTcExpr (FromThen $1 $3) } | texp '..' exp { LL $ ArithSeq noPostTcExpr (FromTo $1 $3) } | texp ',' exp '..' exp { LL $ ArithSeq noPostTcExpr (FromThenTo $1 $3 $5) } - | texp '|' flattenedpquals {% checkMonadComp >>= \ ctxt -> - return (sL (comb2 $1 $>) $ mkHsDo ctxt (unLoc $3) $1) } + | texp '|' flattenedpquals + {% checkMonadComp >>= \ ctxt -> + return (sL (comb2 $1 $>) $ + mkHsComp ctxt (unLoc $3) $1) } lexps :: { Located [LHsExpr RdrName] } : lexps ',' texp { LL (((:) $! $3) $! unLoc $1) } @@ -1538,7 +1535,7 @@ parr :: { LHsExpr RdrName } (reverse (unLoc $1)) } | texp '..' exp { LL $ PArrSeq noPostTcExpr (FromTo $1 $3) } | texp ',' exp '..' exp { LL $ PArrSeq noPostTcExpr (FromThenTo $1 $3 $5) } - | texp '|' flattenedpquals { LL $ mkHsDo PArrComp (unLoc $3) $1 } + | texp '|' flattenedpquals { LL $ mkHsComp PArrComp (unLoc $3) $1 } -- We are reusing `lexps' and `flattenedpquals' from the list case. @@ -1605,7 +1602,7 @@ apats :: { [LPat RdrName] } -- Statement sequences stmtlist :: { Located [LStmt RdrName] } - : '{' stmts '}' { LL (unLoc $2) } + : '{' stmts '}' { LL (mkDoStmts (unLoc $2)) } | vocurly stmts close { $2 } -- do { ;; s ; s ; ; s ;; } diff --git a/compiler/parser/RdrHsSyn.lhs b/compiler/parser/RdrHsSyn.lhs index 0e22c69..3b14990 100644 --- a/compiler/parser/RdrHsSyn.lhs +++ b/compiler/parser/RdrHsSyn.lhs @@ -40,8 +40,6 @@ module RdrHsSyn ( checkPattern, -- HsExp -> P HsPat bang_RDR, checkPatterns, -- SrcLoc -> [HsExp] -> P [HsPat] - checkDo, -- [Stmt] -> P [Stmt] - checkMDo, -- [Stmt] -> P [Stmt] checkMonadComp, -- P (HsStmtContext RdrName) checkValDef, -- (SrcLoc, HsExp, HsRhs, [HsDecl]) -> P HsDecl checkValSig, -- (SrcLoc, HsExp, HsRhs, [HsDecl]) -> P HsDecl @@ -613,34 +611,6 @@ checkPred (L spn ty) check loc _ _ = parseErrorSDoc loc (text "malformed class assertion:" <+> ppr ty) ---------------------------------------------------------------------------- --- Checking statements in a do-expression --- We parse do { e1 ; e2 ; } --- as [ExprStmt e1, ExprStmt e2] --- checkDo (a) checks that the last thing is an ExprStmt --- (b) returns it separately --- same comments apply for mdo as well - -checkDo, checkMDo :: SrcSpan -> [LStmt RdrName] -> P ([LStmt RdrName], LHsExpr RdrName) - -checkDo = checkDoMDo "a " "'do'" -checkMDo = checkDoMDo "an " "'mdo'" - -checkDoMDo :: String -> String -> SrcSpan -> [LStmt RdrName] -> P ([LStmt RdrName], LHsExpr RdrName) -checkDoMDo _ nm loc [] = parseErrorSDoc loc (text ("Empty " ++ nm ++ " construct")) -checkDoMDo pre nm _ ss = do - check ss - where - check [] = panic "RdrHsSyn:checkDoMDo" - check [L _ (ExprStmt e _ _ _)] = return ([], e) - check [L l e] = parseErrorSDoc l - (text ("The last statement in " ++ pre ++ nm ++ - " construct must be an expression:") - $$ ppr e) - check (s:ss) = do - (ss',e') <- check ss - return ((s:ss'),e') - -- ------------------------------------------------------------------------- -- Checking Patterns. diff --git a/compiler/prelude/PrelNames.lhs b/compiler/prelude/PrelNames.lhs index 9b59f5d..421ec45 100644 --- a/compiler/prelude/PrelNames.lhs +++ b/compiler/prelude/PrelNames.lhs @@ -160,6 +160,7 @@ basicKnownKeyNames -- Monad stuff thenIOName, bindIOName, returnIOName, failIOName, failMName, bindMName, thenMName, returnMName, + fmapName, -- MonadRec stuff mfixName, @@ -612,6 +613,7 @@ eqName = methName gHC_CLASSES (fsLit "==") eqClassOpKey ordClassName = clsQual gHC_CLASSES (fsLit "Ord") ordClassKey geName = methName gHC_CLASSES (fsLit ">=") geClassOpKey functorClassName = clsQual gHC_BASE (fsLit "Functor") functorClassKey +fmapName = methName gHC_BASE (fsLit "fmap") fmapClassOpKey -- Class Monad monadClassName, thenMName, bindMName, returnMName, failMName :: Name @@ -1312,6 +1314,7 @@ negateClassOpKey = mkPreludeMiscIdUnique 111 failMClassOpKey = mkPreludeMiscIdUnique 112 bindMClassOpKey = mkPreludeMiscIdUnique 113 -- (>>=) thenMClassOpKey = mkPreludeMiscIdUnique 114 -- (>>) +fmapClassOpKey = mkPreludeMiscIdUnique 115 returnMClassOpKey = mkPreludeMiscIdUnique 117 -- Recursive do notation diff --git a/compiler/rename/RnExpr.lhs b/compiler/rename/RnExpr.lhs index 425cb40..e3e92bc 100644 --- a/compiler/rename/RnExpr.lhs +++ b/compiler/rename/RnExpr.lhs @@ -224,16 +224,9 @@ rnExpr (HsLet binds expr) rnLExpr expr `thenM` \ (expr',fvExpr) -> return (HsLet binds' expr', fvExpr) -rnExpr (HsDo do_or_lc stmts body _ _) - = do { ((stmts', body'), fvs1) <- rnStmts do_or_lc stmts $ \ _ -> - rnLExpr body - ; (return_op, fvs2) <- - if isMonadCompExpr do_or_lc - then lookupSyntaxName returnMName - else return (noSyntaxExpr, emptyFVs) - - ; return ( HsDo do_or_lc stmts' body' return_op placeHolderType - , fvs1 `plusFV` fvs2 ) } +rnExpr (HsDo do_or_lc stmts _) + = do { ((stmts', _), fvs) <- rnStmts do_or_lc stmts (\ _ -> return ()) + ; return ( HsDo do_or_lc stmts' placeHolderType, fvs ) } rnExpr (ExplicitList _ exps) = rnExprs exps `thenM` \ (exps', fvs) -> @@ -653,25 +646,53 @@ rnStmts :: HsStmtContext Name -> [LStmt RdrName] -- -- Renaming a single RecStmt can give a sequence of smaller Stmts -rnStmts _ [] thing_inside - = do { (res, fvs) <- thing_inside [] - ; return (([], res), fvs) } +rnStmts ctxt [] thing_inside + = do { addErr (ptext (sLit "Empty") <+> pprStmtContext ctxt) + ; (thing, fvs) <- thing_inside [] + ; return (([], thing), fvs) } + +rnStmts MDoExpr stmts thing_inside -- Deal with mdo + = -- Behave like do { rec { ...all but last... }; last } + do { ((stmts1, (stmts2, thing)), fvs) + <- rnStmt MDoExpr (mkRecStmt all_but_last) $ \ bndrs -> + do { checkStmt MDoExpr True last_stmt + ; rnStmt MDoExpr last_stmt thing_inside } + ; return (((stmts1 ++ stmts2), thing), fvs) } + where + Just (all_but_last, last_stmt) = snocView stmts rnStmts ctxt (stmt@(L loc _) : stmts) thing_inside + | null stmts + = setSrcSpan loc $ + do { let last_stmt = case stmt of + ExprStmt e _ _ _ -> LastStmt e noSyntaxExpr + ; checkStmt ctxt True {- last stmt -} stmt + ; rnStmt ctxt stmt thing_inside } + + | otherwise = do { ((stmts1, (stmts2, thing)), fvs) - <- setSrcSpan loc $ - rnStmt ctxt stmt $ \ bndrs1 -> - rnStmts ctxt stmts $ \ bndrs2 -> - thing_inside (bndrs1 ++ bndrs2) + <- setSrcSpan loc $ + do { checkStmt ctxt False {- Not last -} stmt + ; rnStmt ctxt stmt $ \ bndrs1 -> + rnStmts ctxt stmts $ \ bndrs2 -> + thing_inside (bndrs1 ++ bndrs2) } ; return (((stmts1 ++ stmts2), thing), fvs) } - -rnStmt :: HsStmtContext Name -> LStmt RdrName +---------------------- +rnStmt :: HsStmtContext Name + -> LStmt RdrName -> ([Name] -> RnM (thing, FreeVars)) -> RnM (([LStmt Name], thing), FreeVars) -- Variables bound by the Stmt, and mentioned in thing_inside, -- do not appear in the result FreeVars +rnStmt ctxt (L loc (LastStmt expr _)) thing_inside + = do { (expr', fv_expr) <- rnLExpr expr + ; (ret_op, fvs1) <- lookupSyntaxName returnMName + ; (thing, fvs3) <- thing_inside [] + ; return (([L loc (LastStmt expr' ret_op)], thing), + fv_expr `plusFV` fvs1 `plusFV` fvs3) } + rnStmt ctxt (L loc (ExprStmt expr _ _ _)) thing_inside = do { (expr', fv_expr) <- rnLExpr expr ; (then_op, fvs1) <- lookupSyntaxName thenMName @@ -683,7 +704,8 @@ rnStmt ctxt (L loc (ExprStmt expr _ _ _)) thing_inside fv_expr `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) } rnStmt ctxt (L loc (BindStmt pat expr _ _)) thing_inside - = do { (expr', fv_expr) <- rnLExpr expr + = do { checkBindStmt ctxt is_last + ; (expr', fv_expr) <- rnLExpr expr -- The binders do not scope over the expression ; (bind_op, fvs1) <- lookupSyntaxName bindMName ; (fail_op, fvs2) <- lookupSyntaxName failMName @@ -701,8 +723,7 @@ rnStmt ctxt (L loc (LetStmt binds)) thing_inside ; return (([L loc (LetStmt binds')], thing), fvs) } } rnStmt ctxt (L _ (RecStmt { recS_stmts = rec_stmts })) thing_inside - = do { checkRecStmt ctxt - + = do { -- Step1: Bring all the binders of the mdo into scope -- (Remember that this also removes the binders from the -- finally-returned free-vars.) @@ -745,8 +766,7 @@ rnStmt ctxt (L _ (RecStmt { recS_stmts = rec_stmts })) thing_inside ; return ((rec_stmts', thing), fvs `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) } } rnStmt ctxt (L loc (ParStmt segs _ _ _)) thing_inside - = do { checkParStmt ctxt - ; ((mzip_op, fvs1), (bind_op, fvs2), (return_op, fvs3)) <- if isMonadCompExpr ctxt + = do { ((mzip_op, fvs1), (bind_op, fvs2), (return_op, fvs3)) <- if isMonadCompExpr ctxt then (,,) <$> lookupSyntaxName mzipName <*> lookupSyntaxName bindMName <*> lookupSyntaxName returnMName @@ -758,9 +778,7 @@ rnStmt ctxt (L loc (ParStmt segs _ _ _)) thing_inside , fvs1 `plusFV` fvs2 `plusFV` fvs3 `plusFV` fvs4) } rnStmt ctxt (L loc (TransformStmt stmts _ using by _ _)) thing_inside - = do { checkTransformStmt ctxt - - ; (using', fvs1) <- rnLExpr using + = do { (using', fvs1) <- rnLExpr using ; ((stmts', (by', used_bndrs, thing)), fvs2) <- rnStmts (TransformStmtCtxt ctxt) stmts $ \ bndrs -> @@ -786,9 +804,7 @@ rnStmt ctxt (L loc (TransformStmt stmts _ using by _ _)) thing_inside fvs1 `plusFV` fvs2 `plusFV` fvs3 `plusFV` fvs4) } rnStmt ctxt (L loc (GroupStmt stmts _ by using _ _ _)) thing_inside - = do { checkTransformStmt ctxt - - -- Rename the 'using' expression in the context before the transform is begun + = do { -- Rename the 'using' expression in the context before the transform is begun ; (using', fvs1) <- case using of Left e -> do { (e', fvs) <- rnLExpr e; return (Left e', fvs) } Right _ @@ -810,11 +826,11 @@ rnStmt ctxt (L loc (GroupStmt stmts _ by using _ _ _)) thing_inside ; return ((by', used_bndrs, thing), fvs) } -- Lookup `return`, `(>>=)` and `liftM` for monad comprehensions - ; ((return_op, fvs3), (bind_op, fvs4), (liftM_op, fvs5)) <- + ; ((return_op, fvs3), (bind_op, fvs4), (fmap_op, fvs5)) <- if isMonadCompExpr ctxt then (,,) <$> lookupSyntaxName returnMName <*> lookupSyntaxName bindMName - <*> lookupSyntaxName liftMName + <*> lookupSyntaxName fmapName else return ( (noSyntaxExpr, emptyFVs) , (noSyntaxExpr, emptyFVs) , (noSyntaxExpr, emptyFVs) ) @@ -825,7 +841,7 @@ rnStmt ctxt (L loc (GroupStmt stmts _ by using _ _ _)) thing_inside -- See Note [GroupStmt binder map] in HsExpr ; traceRn (text "rnStmt: implicitly rebound these used binders:" <+> ppr bndr_map) - ; return (([L loc (GroupStmt stmts' bndr_map by' using' return_op bind_op liftM_op)], thing), all_fvs) } + ; return (([L loc (GroupStmt stmts' bndr_map by' using' return_op bind_op fmap_op)], thing), all_fvs) } type ParSeg id = ([LStmt id], [id]) -- The Names are bound by the Stmts @@ -1182,22 +1198,124 @@ program. %************************************************************************ \begin{code} - ---------------------- -- Checking when a particular Stmt is ok -checkLetStmt :: HsStmtContext Name -> HsLocalBinds RdrName -> RnM () -checkLetStmt (ParStmtCtxt _) (HsIPBinds binds) = addErr (badIpBinds (ptext (sLit "a parallel list comprehension:")) binds) -checkLetStmt _ctxt _binds = return () +checkStmt :: HsStmtContext Name + -> Bool -- True <=> this is the last Stmt in the sequence + -> LStmt RdrName + -> RnM () +checkStmt ctxt is_last (L _ stmt) + = do { dflags <- getDOpts + ; case okStmt dflags ctxt is_last stmt of + Nothing -> return () + Just extr -> addErr (msg $$ extra) } + where + msg = ptext (sLit "Unexpected") <+> pprStmtCat stmt + <+> ptext (sLit "statement in") <+> pprStmtContext ctxt + +pprStmtCat :: Stmt a -> SDoc +pprStmtCat (TransformStmt {}) = ptext (sLit "transform") +pprStmtCat (GroupStmt {}) = ptext (sLit "group") +pprStmtCat (LastStmt {}) = ptext (sLit "return expression") +pprStmtCat (ExprStmt {}) = ptext (sLit "exprssion") +pprStmtCat (BindStmt {}) = ptext (sLit "binding") +pprStmtCat (LetStmt {}) = ptext (sLit "let") +pprStmtCat (RecStmt {}) = ptext (sLit "rec") +pprStmtCat (ParStmt {}) = ptext (sLit "parallel") + +------------ +isOK, notOK :: Maybe SDoc +isOK = Nothing +notOK = Just empty + +okStmt, okDoStmt, okCompStmt :: DynFlags -> HsStmtContext Name -> Bool + -> Stmt RdrName -> Maybe SDoc +-- Return Nothing if OK, (Just extra) if not ok +-- The "extra" is an SDoc that is appended to an generic error message +okStmt dflags GhciStmt is_last stmt + = case stmt of + ExprStmt {} -> isOK + BindStmt {} -> isOK + LetStmt {} -> isOK + _ -> notOK + +okStmt dflags (PatGuard {}) is_last stmt + = case stmt of + ExprStmt {} -> isOK + BindStmt {} -> isOK + LetStmt {} -> isOK + _ -> notOK + +okStmt dflags (ParStmtCtxt ctxt) is_last stmt + = case stmt of + LetStmt (HsIPBinds {}) -> notOK + _ -> okStmt dflags ctxt is_last stmt + +okStmt dflags (TransformStmtCtxt ctxt) is_last stmt + = okStmt dflags ctxt is_last stmt + +okStmt ctxt is_last stmt + | isDoExpr ctxt = okDoStmt ctxt is_last stmt + | isCompExpr ctxt = okCompStmt ctxt is_last stmt + | otherwise = pprPanic "okStmt" (pprStmtContext ctxt) + +---------------- +okDoStmt dflags ctxt is_last stmt + | is_last + = case stmt of + LastStmt {} -> isOK + _ -> Just (ptext (sLit "The last statement in") <+> what <+> + ptext (sLIt "construct must be an expression")) + where + what = case ctxt of + DoExpr -> ptext (sLit "a 'do'") + MDoExpr -> ptext (sLit "an 'mdo'") + _ -> panic "checkStmt" + + | otherwise + = case stmt of + RecStmt {} -> isOK -- Shouldn't we test a flag? + BindStmt {} -> isOK + LetStmt {} -> isOK + ExprStmt {} -> isOK + _ -> notOK + + +---------------- +okCompStmt dflags ctxt is_last stmt + | is_last + = case stmt of + LastStmt {} -> Nothing + -> pprPanic "Unexpected stmt" (ppr stmt) -- Not a user error + + | otherwise + = case stmt of + BindStmt {} -> isOK + LetStmt {} -> isOK + ExprStmt {} -> isOK + RecStmt {} -> notOK + ParStmt {} + | dopt dflags Opt_ParallelListComp -> isOK + | otherwise -> Just (ptext (sLit "Use -XParallelListComp")) + TransformStmt {} + | dopt dflags Opt_transformListComp -> isOK + | otherwise -> Just (ptext (sLit "Use -XTransformListComp")) + GroupStmt {} + | dopt dflags Opt_transformListComp -> isOK + | otherwise -> Just (ptext (sLit "Use -XTransformListComp")) + + +checkStmt :: HsStmtContext Name -> Stmt RdrName -> Maybe SDoc +-- Non-last stmt + +checkStmt (ParStmtCtxt _) (HsIPBinds binds) + = Just (badIpBinds (ptext (sLit "a parallel list comprehension:")) binds) -- We do not allow implicit-parameter bindings in a parallel -- list comprehension. I'm not sure what it might mean. ---------- -checkRecStmt :: HsStmtContext Name -> RnM () -checkRecStmt MDoExpr = return () -- Recursive stmt ok in 'mdo' -checkRecStmt DoExpr = return () -- and in 'do' -checkRecStmt ctxt = addErr msg - where - msg = ptext (sLit "Illegal 'rec' stmt in") <+> pprStmtContext ctxt +checkStmt ctxt (RecStmt {}) + | not (isDoExpr ctxt) + = addErr (ptext (sLit "Illegal 'rec' stmt in") <+> pprStmtContext ctxt) --------- checkParStmt :: HsStmtContext Name -> RnM () diff --git a/compiler/typecheck/TcExpr.lhs b/compiler/typecheck/TcExpr.lhs index a821f25..d24ebbe 100644 --- a/compiler/typecheck/TcExpr.lhs +++ b/compiler/typecheck/TcExpr.lhs @@ -415,8 +415,8 @@ tcExpr (HsIf (Just fun) pred b1 b2) res_ty -- Note [Rebindable syntax for if] -- and it maintains uniformity with other rebindable syntax ; return (HsIf (Just fun') pred' b1' b2') } -tcExpr (HsDo do_or_lc stmts body return_op _) res_ty - = tcDoStmts do_or_lc stmts body return_op res_ty +tcExpr (HsDo do_or_lc stmts _) res_ty + = tcDoStmts do_or_lc stmts res_ty tcExpr (HsProc pat cmd) res_ty = do { (pat', cmd', coi) <- tcProc pat cmd res_ty diff --git a/compiler/typecheck/TcGenDeriv.lhs b/compiler/typecheck/TcGenDeriv.lhs index efacac2..f7e5d39 100644 --- a/compiler/typecheck/TcGenDeriv.lhs +++ b/compiler/typecheck/TcGenDeriv.lhs @@ -779,7 +779,7 @@ gen_Ix_binds loc tycon single_con_range = mk_easy_FunBind loc range_RDR [nlTuplePat [con_pat as_needed, con_pat bs_needed] Boxed] $ - nlHsDo ListComp stmts con_expr + noLoc (mkHsComp ListComp stmts con_expr) where stmts = zipWith3Equal "single_con_range" mk_qual as_needed bs_needed cs_needed @@ -893,7 +893,7 @@ gen_Read_binds get_fixity loc tycon read_nullary_cons = case nullary_cons of [] -> [] - [con] -> [nlHsDo DoExpr (match_con con) (result_expr con [])] + [con] -> [nlHsDo DoExpr (match_con con ++ [mkExprStmt (result_expr con [])])] _ -> [nlHsApp (nlHsVar choose_RDR) (nlList (map mk_pair nullary_cons))] -- NB For operators the parens around (:=:) are matched by the @@ -965,11 +965,12 @@ gen_Read_binds get_fixity loc tycon ------------------------------------------------------------------------ -- Helpers ------------------------------------------------------------------------ - mk_alt e1 e2 = genOpApp e1 alt_RDR e2 -- e1 +++ e2 - mk_parser p ss b = nlHsApps prec_RDR [nlHsIntLit p, nlHsDo DoExpr ss b] -- prec p (do { ss ; b }) - bindLex pat = noLoc (mkBindStmt pat (nlHsVar lexP_RDR)) -- pat <- lexP - con_app con as = nlHsVarApps (getRdrName con) as -- con as - result_expr con as = nlHsApp (nlHsVar returnM_RDR) (con_app con as) -- return (con as) + mk_alt e1 e2 = genOpApp e1 alt_RDR e2 -- e1 +++ e2 + mk_parser p ss b = nlHsApps prec_RDR [nlHsIntLit p -- prec p (do { ss ; b }) + , nlHsDo DoExpr (ss ++ [mkExprStmt b])] + bindLex pat = noLoc (mkBindStmt pat (nlHsVar lexP_RDR)) -- pat <- lexP + con_app con as = nlHsVarApps (getRdrName con) as -- con as + result_expr con as = nlHsApp (nlHsVar returnM_RDR) (con_app con as) -- return (con as) punc_pat s = nlConPat punc_RDR [nlLitPat (mkHsString s)] -- Punc 'c' diff --git a/compiler/typecheck/TcHsSyn.lhs b/compiler/typecheck/TcHsSyn.lhs index 357db73..518582f 100644 --- a/compiler/typecheck/TcHsSyn.lhs +++ b/compiler/typecheck/TcHsSyn.lhs @@ -578,12 +578,10 @@ zonkExpr env (HsLet binds expr) zonkLExpr new_env expr `thenM` \ new_expr -> returnM (HsLet new_binds new_expr) -zonkExpr env (HsDo do_or_lc stmts body return_op ty) - = zonkStmts env stmts `thenM` \ (new_env, new_stmts) -> - zonkLExpr new_env body `thenM` \ new_body -> - zonkExpr new_env return_op `thenM` \ new_return -> +zonkExpr env (HsDo do_or_lc stmts ty) + = zonkStmts env stmts `thenM` \ (_, new_stmts) -> zonkTcTypeToType env ty `thenM` \ new_ty -> - returnM (HsDo do_or_lc new_stmts new_body new_return new_ty) + returnM (HsDo do_or_lc new_stmts new_ty) zonkExpr env (ExplicitList ty exprs) = zonkTcTypeToType env ty `thenM` \ new_ty -> @@ -745,9 +743,10 @@ zonkStmt env (ParStmt stmts_w_bndrs mzip_op bind_op return_op) zonkStmt env (RecStmt { recS_stmts = segStmts, recS_later_ids = lvs, recS_rec_ids = rvs , recS_ret_fn = ret_id, recS_mfix_fn = mfix_id, recS_bind_fn = bind_id - , recS_rec_rets = rets }) + , recS_rec_rets = rets, redS_ret_ty = ret_ty }) = do { new_rvs <- zonkIdBndrs env rvs ; new_lvs <- zonkIdBndrs env lvs + ; new_ret_ty <- zonkTcTypeToType env ret_ty ; new_ret_id <- zonkExpr env ret_id ; new_mfix_id <- zonkExpr env mfix_id ; new_bind_id <- zonkExpr env bind_id @@ -760,7 +759,7 @@ zonkStmt env (RecStmt { recS_stmts = segStmts, recS_later_ids = lvs, recS_rec_id RecStmt { recS_stmts = new_segStmts, recS_later_ids = new_lvs , recS_rec_ids = new_rvs, recS_ret_fn = new_ret_id , recS_mfix_fn = new_mfix_id, recS_bind_fn = new_bind_id - , recS_rec_rets = new_rets }) } + , recS_rec_rets = new_rets, recS_ret_ty = new_ret_ty }) } zonkStmt env (ExprStmt expr then_op guard_op ty) = zonkLExpr env expr `thenM` \ new_expr -> @@ -769,6 +768,11 @@ zonkStmt env (ExprStmt expr then_op guard_op ty) zonkTcTypeToType env ty `thenM` \ new_ty -> returnM (env, ExprStmt new_expr new_then new_guard new_ty) +zonkStmt env (LastStmt expr ret_op) + = zonkLExpr env expr `thenM` \ new_expr -> + zonkExpr env ret_op `thenM` \ new_ret -> + returnM (env, LastStmt new_expr new_ret) + zonkStmt env (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op) = do { (env', stmts') <- zonkStmts env stmts ; let binders' = zonkIdOccs env' binders diff --git a/compiler/typecheck/TcMatches.lhs b/compiler/typecheck/TcMatches.lhs index 31aa555..60bf7e2 100644 --- a/compiler/typecheck/TcMatches.lhs +++ b/compiler/typecheck/TcMatches.lhs @@ -241,41 +241,31 @@ tcGRHS ctxt res_ty (GRHS guards rhs) \begin{code} tcDoStmts :: HsStmtContext Name -> [LStmt Name] - -> LHsExpr Name - -> SyntaxExpr Name -- 'return' function for monad - -- comprehensions -> TcRhoType -> TcM (HsExpr TcId) -- Returns a HsDo -tcDoStmts ListComp stmts body _ res_ty +tcDoStmts ListComp stmts res_ty = do { (coi, elt_ty) <- matchExpectedListTy res_ty - ; (stmts', body') <- tcStmts ListComp (tcLcStmt listTyCon) stmts - elt_ty $ - tcBody body + ; stmts' <- tcStmts ListComp (tcLcStmt listTyCon) stmts res_ty ; return $ mkHsWrapCoI coi - (HsDo ListComp stmts' body' noSyntaxExpr (mkListTy elt_ty)) } + (HsDo ListComp stmts' (mkListTy elt_ty)) } -tcDoStmts PArrComp stmts body _ res_ty +tcDoStmts PArrComp stmts res_ty = do { (coi, elt_ty) <- matchExpectedPArrTy res_ty - ; (stmts', body') <- tcStmts PArrComp (tcLcStmt parrTyCon) stmts - elt_ty $ - tcBody body + ; stmts' <- tcStmts PArrComp (tcLcStmt parrTyCon) stmts elt_ty ; return $ mkHsWrapCoI coi - (HsDo PArrComp stmts' body' noSyntaxExpr (mkPArrTy elt_ty)) } + (HsDo PArrComp stmts' (mkPArrTy elt_ty)) } -tcDoStmts DoExpr stmts body _ res_ty - = do { (stmts', body') <- tcStmts DoExpr tcDoStmt stmts res_ty $ - tcBody body - ; return (HsDo DoExpr stmts' body' noSyntaxExpr res_ty) } +tcDoStmts DoExpr stmts res_ty + = do { stmts' <- tcStmts DoExpr tcDoStmt stmts res_ty + ; return (HsDo DoExpr stmts' res_ty) } -tcDoStmts MDoExpr stmts body _ res_ty - = do { (stmts', body') <- tcStmts MDoExpr tcDoStmt stmts res_ty $ - tcBody body - ; return (HsDo MDoExpr stmts' body' noSyntaxExpr res_ty) } +tcDoStmts MDoExpr stmts res_ty + = do { stmts' <- tcStmts MDoExpr tcDoStmt stmts res_ty + ; return (HsDo MDoExpr stmts' res_ty) } -tcDoStmts MonadComp stmts body return_op res_ty - = do { (stmts', (body', return_op')) <- tcStmts MonadComp tcMcStmt stmts res_ty $ - tcMcBody body return_op - ; return $ HsDo MonadComp stmts' body' return_op' res_ty } +tcDoStmts MonadComp stmts res_ty + = do { stmts' <- tcStmts MonadComp tcMcStmt stmts res_ty + ; return (HsDo MonadComp stmts' res_ty) } tcDoStmts ctxt _ _ _ _ = pprPanic "tcDoStmts" (pprStmtContext ctxt) @@ -306,30 +296,40 @@ tcStmts :: HsStmtContext Name -> TcStmtChecker -- NB: higher-rank type -> [LStmt Name] -> TcRhoType - -> (TcRhoType -> TcM thing) - -> TcM ([LStmt TcId], thing) + -> TcM [LStmt TcId] +tcStmts ctxt stmt_chk stmts res_ty + = do { (stmts', _) <- tcStmtsAndThen ctxt stmt_check stmts res_ty $ + const (return ()) + ; return stmts' } + +tcStmtsAndThen :: HsStmtContext Name + -> TcStmtChecker -- NB: higher-rank type + -> [LStmt Name] + -> TcRhoType + -> (TcRhoType -> TcM thing) + -> TcM ([LStmt TcId], thing) -- Note the higher-rank type. stmt_chk is applied at different -- types in the equations for tcStmts -tcStmts _ _ [] res_ty thing_inside +tcStmtsAndThen _ _ [] res_ty thing_inside = do { thing <- thing_inside res_ty ; return ([], thing) } -- LetStmts are handled uniformly, regardless of context -tcStmts ctxt stmt_chk (L loc (LetStmt binds) : stmts) res_ty thing_inside +tcStmtsAndThen ctxt stmt_chk (L loc (LetStmt binds) : stmts) res_ty thing_inside = do { (binds', (stmts',thing)) <- tcLocalBinds binds $ - tcStmts ctxt stmt_chk stmts res_ty thing_inside + tcStmtsAndThen ctxt stmt_chk stmts res_ty thing_inside ; return (L loc (LetStmt binds') : stmts', thing) } -- For the vanilla case, handle the location-setting part -tcStmts ctxt stmt_chk (L loc stmt : stmts) res_ty thing_inside +tcStmtsAndThen ctxt stmt_chk (L loc stmt : stmts) res_ty thing_inside = do { (stmt', (stmts', thing)) <- - setSrcSpan loc $ - addErrCtxt (pprStmtInCtxt ctxt stmt) $ - stmt_chk ctxt stmt res_ty $ \ res_ty' -> - popErrCtxt $ - tcStmts ctxt stmt_chk stmts res_ty' $ + setSrcSpan loc $ + addErrCtxt (pprStmtInCtxt ctxt stmt) $ + stmt_chk ctxt stmt res_ty $ \ res_ty' -> + popErrCtxt $ + tcStmtsAndThen ctxt stmt_chk stmts res_ty' $ thing_inside ; return (L loc stmt' : stmts', thing) } @@ -357,18 +357,23 @@ tcGuardStmt _ stmt _ _ tcLcStmt :: TyCon -- The list/Parray type constructor ([] or PArray) -> TcStmtChecker +tcLcStmt m_tc ctxt (LastStmt body _) elt_ty thing_inside + = do { body' <- tcMonoExpr body elt_ty + ; thing <- thing_inside elt_ty + ; return (LastStmt body' noSyntaxExpr, thing) } + -- A generator, pat <- rhs -tcLcStmt m_tc ctxt (BindStmt pat rhs _ _) res_ty thing_inside +tcLcStmt m_tc ctxt (BindStmt pat rhs _ _) elt_ty thing_inside = do { pat_ty <- newFlexiTyVarTy liftedTypeKind ; rhs' <- tcMonoExpr rhs (mkTyConApp m_tc [pat_ty]) ; (pat', thing) <- tcPat (StmtCtxt ctxt) pat pat_ty $ - thing_inside res_ty + thing_inside elt_ty ; return (BindStmt pat' rhs' noSyntaxExpr noSyntaxExpr, thing) } -- A boolean guard -tcLcStmt _ _ (ExprStmt rhs _ _ _) res_ty thing_inside +tcLcStmt _ _ (ExprStmt rhs _ _ _) elt_ty thing_inside = do { rhs' <- tcMonoExpr rhs boolTy - ; thing <- thing_inside res_ty + ; thing <- thing_inside elt_ty ; return (ExprStmt rhs' noSyntaxExpr noSyntaxExpr boolTy, thing) } -- A parallel set of comprehensions @@ -491,20 +496,29 @@ tcLcStmt _ _ stmt _ _ tcMcStmt :: TcStmtChecker +tcMcStmt ctxt (LastStmt body return_op) res_ty thing_inside + = do { a_ty <- newFlexiTyVarTy liftedTypeKind + ; return_op' <- tcSyntaxOp MCompOrigin return_op + (a_ty `mkFunTy` res_ty) + ; body' <- tcMonoExpr body a_ty + ; return (body', return_op') } + -- Generators for monad comprehensions ( pat <- rhs ) -- -- [ body | q <- gen ] -> gen :: m a -- q :: a -- + tcMcStmt ctxt (BindStmt pat rhs bind_op fail_op) res_ty thing_inside = do { rhs_ty <- newFlexiTyVarTy liftedTypeKind ; pat_ty <- newFlexiTyVarTy liftedTypeKind ; new_res_ty <- newFlexiTyVarTy liftedTypeKind + + -- (>>=) :: rhs_ty -> (pat_ty -> new_res_ty) -> res_ty ; bind_op' <- tcSyntaxOp MCompOrigin bind_op (mkFunTys [rhs_ty, mkFunTy pat_ty new_res_ty] res_ty) - -- If (but only if) the pattern can fail, - -- typecheck the 'fail' operator + -- If (but only if) the pattern can fail, typecheck the 'fail' operator ; fail_op' <- if isIrrefutableHsPat pat then return noSyntaxExpr else tcSyntaxOp MCompOrigin fail_op (mkFunTy stringTy new_res_ty) @@ -540,15 +554,15 @@ tcMcStmt _ (ExprStmt rhs then_op guard_op _) res_ty thing_inside -- [ body | stmts, then f ] -> f :: forall a. m a -> m a -- [ body | stmts, then f by e ] -> f :: forall a. (a -> t) -> m a -> m a -- -tcMcStmt ctxt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op) elt_ty thing_inside +tcMcStmt ctxt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op) res_ty thing_inside = do { -- We don't know the types of binders yet, so we use this dummy and -- later unify this type with the `m_bndr_ty` ty_dummy <- newFlexiTyVarTy liftedTypeKind ; (stmts', (binders', usingExpr', maybeByExpr', return_op', bind_op', thing)) <- - tcStmts (TransformStmtCtxt ctxt) tcMcStmt stmts ty_dummy $ \elt_ty' -> do - { (_, (m_ty, _)) <- matchExpectedAppTy elt_ty' + tcStmts (TransformStmtCtxt ctxt) tcMcStmt stmts ty_dummy $ \res_ty' -> do + { (_, (m_ty, _)) <- matchExpectedAppTy res_ty' ; (usingExpr', maybeByExpr') <- case maybeByExpr of Nothing -> do @@ -582,22 +596,22 @@ tcMcStmt ctxt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_ -- -> ( (a,b,c,..) -> m (a,b,c,..) ) -- -> m (a,b,c,..) -- - ; let bndr_ty = mkChunkified mkBoxedTupleTy $ map idType bndr_ids + ; let bndr_ty = mkBigCoreVarTupTy bndr_ids m_bndr_ty = m_ty `mkAppTy` bndr_ty ; return_op' <- tcSyntaxOp MCompOrigin return_op (bndr_ty `mkFunTy` m_bndr_ty) ; bind_op' <- tcSyntaxOp MCompOrigin bind_op $ - m_bndr_ty `mkFunTy` (bndr_ty `mkFunTy` elt_ty) - `mkFunTy` elt_ty + m_bndr_ty `mkFunTy` (bndr_ty `mkFunTy` res_ty) + `mkFunTy` res_ty -- Unify types of the inner comprehension and the binders type - ; _ <- unifyType elt_ty' m_bndr_ty + ; _ <- unifyType res_ty' m_bndr_ty -- Typecheck the `thing` with out old type (which is the type -- of the final result of our comprehension) - ; thing <- thing_inside elt_ty + ; thing <- thing_inside res_ty ; return (bndr_ids, usingExpr', maybeByExpr', return_op', bind_op', thing) } @@ -613,32 +627,21 @@ tcMcStmt ctxt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_ -- [ body | stmts, then group using f ] -- -> f :: forall a. m a -> m (m a) -- -tcMcStmt ctxt (GroupStmt stmts bindersMap by using return_op bind_op liftM_op) elt_ty thing_inside - = do { let (bndr_names, m_bndr_names) = unzip bindersMap - - ; (_,(m_ty,_)) <- matchExpectedAppTy elt_ty - ; let alphaMTy = m_ty `mkAppTy` alphaTy - alphaMMTy = m_ty `mkAppTy` alphaMTy - - -- We don't know the type of the bindings yet. It's not elt_ty! - ; bndr_ty_dummy <- newFlexiTyVarTy liftedTypeKind - - ; (stmts', (bndr_ids, by', using_ty, return_op', bind_op')) <- - tcStmts (TransformStmtCtxt ctxt) tcMcStmt stmts bndr_ty_dummy $ \elt_ty' -> do - { (by', using_ty) <- - case by of - Nothing -> -- check that using :: forall a. m a -> m (m a) - return (Nothing, mkForAllTy alphaTyVar $ - alphaMTy `mkFunTy` alphaMMTy) - - Just by_e -> -- check that using :: forall a. (a -> t) -> m a -> m (m a) - -- where by :: t - do { (by_e', t_ty) <- tcInferRhoNC by_e - ; return (Just by_e', mkForAllTy alphaTyVar $ - (alphaTy `mkFunTy` t_ty) - `mkFunTy` alphaMTy - `mkFunTy` alphaMMTy) } - +tcMcStmt ctxt (GroupStmt stmts bindersMap by using return_op bind_op fmap_op) res_ty thing_inside + = do { m1_ty <- newFlexiTyVarTy liftedTypeKind + ; m2_ty <- newFlexiTyVarTy liftedTypeKind + ; n_ty <- newFlexiTyVarTy liftedTypeKind + ; tup_ty_var <- newFlexiTyVarTy liftedTypeKind + ; new_res_ty <- newFlexiTyVarTy liftedTypeKind + ; let (bndr_names, n_bndr_names) = unzip bindersMap + m1_tup_ty = m1_ty `mkAppTy` tup_ty_var + + -- 'stmts' returns a result of type (m1_ty tuple_ty), + -- typically something like [(Int,Bool,Int)] + -- We don't know what tuple_ty is yet, so we use a variable + ; (stmts', (bndr_ids, by_e_ty, return_op')) <- + tcStmts (TransformStmtCtxt ctxt) tcMcStmt stmts m1_tup_ty $ \res_ty' -> do + { by_e_ty <- mapM tcInferRhoNC by_e -- Find the Ids (and hence types) of all old binders ; bndr_ids <- tcLookupLocalIds bndr_names @@ -646,48 +649,52 @@ tcMcStmt ctxt (GroupStmt stmts bindersMap by using return_op bind_op liftM_op) e -- 'return' is only used for the binders, so we know its type. -- -- return :: (a,b,c,..) -> m (a,b,c,..) - -- - ; let bndr_ty = mkChunkified mkBoxedTupleTy $ map idType bndr_ids - m_bndr_ty = m_ty `mkAppTy` bndr_ty - ; return_op' <- tcSyntaxOp MCompOrigin return_op $ bndr_ty `mkFunTy` m_bndr_ty + ; return_op' <- tcSyntaxOp MCompOrigin return_op $ + (mkBigCoreVarTupTy bndr_ids) `mkFunTy` res_ty' - -- '>>=' is used to pass the grouped binders to the rest of the - -- comprehension. - -- - -- (>>=) :: m (m a, m b, m c, ..) - -- -> ( (m a, m b, m c, ..) -> new_elt_ty ) - -- -> elt_ty - -- - ; let bndr_m_ty = mkChunkified mkBoxedTupleTy $ map (mkAppTy m_ty . idType) bndr_ids - m_bndr_m_ty = m_ty `mkAppTy` bndr_m_ty - ; new_elt_ty <- newFlexiTyVarTy liftedTypeKind - ; bind_op' <- tcSyntaxOp MCompOrigin bind_op $ - m_bndr_m_ty `mkFunTy` (bndr_m_ty `mkFunTy` new_elt_ty) - `mkFunTy` elt_ty + ; return (bndr_ids, by_e_ty, return_op') } - -- Finally make sure the type of the inner comprehension - -- represents the types of our binders - ; _ <- unifyType elt_ty' m_bndr_ty - ; return (bndr_ids, by', using_ty, return_op', bind_op') } - ; let mk_m_bndr :: Name -> TcId -> TcId - mk_m_bndr m_bndr_name bndr_id = - mkLocalId m_bndr_name (m_ty `mkAppTy` idType bndr_id) + ; let tup_ty = mkBigCoreVarTupTy bndr_ids -- (a,b,c) + using_arg_ty = m1_ty `mkAppTy` tup_ty -- m1 (a,b,c) + n_tup_ty = n_ty `mkAppTy` tup_ty -- n (a,b,c) + using_res_ty = m2_ty `mkAppTy` n_tup_ty -- m2 (n (a,b,c)) + using_fun_ty = using_arg_ty `mkFunTy` using_arg_ty + + -- (>>=) :: m2 (n (a,b,c)) -> ( n (a,b,c) -> new_res_ty ) -> res_ty + -- using :: ((a,b,c)->t) -> m1 (a,b,c) -> m2 (n (a,b,c)) - -- Ensure that every old binder of type `b` is linked up with its - -- new binder which should have type `m b` - m_bndr_ids = zipWith mk_m_bndr m_bndr_names bndr_ids - bindersMap' = bndr_ids `zip` m_bndr_ids + --------------- Typecheck the 'bind' function ------------- + ; bind_op' <- tcSyntaxOp MCompOrigin bind_op $ + using_res_ty `mkFunTy` (n_tup_ty `mkFunTy` new_res_ty) + `mkFunTy` res_ty - -- See Note [GroupStmt binder map] in HsExpr + --------------- Typecheck the 'using' function ------------- + ; let using_fun_ty = (m1_ty `mkAppTy` alphaTy) `mkFunTy` + (m2_ty `mkAppTy` (n_ty `mkAppTy` alphaTy)) + using_poly_ty = case by_e_ty of + Nothing -> mkForAllTy alphaTyVar using_fun_ty + -- using :: forall a. m1 a -> m2 (n a) - ; using' <- case using of - Left e -> do { e' <- tcPolyExpr e using_ty; return (Left e') } - Right e -> do { e' <- tcPolyExpr (noLoc e) using_ty; return (Right (unLoc e')) } + Just (_,t_ty) -> mkForAllTy alphaTyVar $ + (alphaTy `mkFunTy` t_ty) `mkFunTy` using_fun_ty + -- using :: forall a. (a->t) -> m1 a -> m2 (n a) + -- where by :: t - -- Type check 'liftM' with 'forall a b. (a -> b) -> m_ty a -> m_ty b' - ; liftM_op' <- fmap unLoc . tcPolyExpr (noLoc liftM_op) $ + ; using' <- case using of + Left e -> do { e' <- tcPolyExpr e using_poly_ty + ; return (Left e') } + Right e -> do { e' <- tcPolyExpr (noLoc e) using_poly_ty + ; return (Right (unLoc e')) } + ; coi <- unifyType (applyTy using_poly_ty tup_ty) + (case by_e_ty of + Nothing -> using_fun_ty + Just (_,t_ty) -> (tup_ty `mkFunTy` t_ty) `mkFunTy` using_fun_ty) + ; let final_using = mkHsWrapCoI coi (HsWrap (WpTyApp tup_ty) using') + + --------------- Typecheck the 'fmap' function ------------- + ; fmap_op' <- fmap unLoc . tcPolyExpr (noLoc fmap_op) $ mkForAllTy alphaTyVar $ mkForAllTy betaTyVar $ (alphaTy `mkFunTy` betaTy) `mkFunTy` @@ -695,11 +702,23 @@ tcMcStmt ctxt (GroupStmt stmts bindersMap by using return_op bind_op liftM_op) e `mkFunTy` (m_ty `mkAppTy` betaTy) + ; let mk_n_bndr :: Name -> TcId -> TcId + mk_n_bndr n_bndr_name bndr_id + = mkLocalId bndr_name (n_ty `mkAppTy` idType bndr_id) + + -- Ensure that every old binder of type `b` is linked up with its + -- new binder which should have type `n b` + -- See Note [GroupStmt binder map] in HsExpr + n_bndr_ids = zipWith mk_n_bndr n_bndr_names bndr_ids + bindersMap' = bndr_ids `zip` n_bndr_ids + -- Type check the thing in the environment with these new binders and -- return the result - ; thing <- tcExtendIdEnv m_bndr_ids (thing_inside elt_ty) + ; thing <- tcExtendIdEnv n_bndr_ids (thing_inside res_ty) - ; return (GroupStmt stmts' bindersMap' by' using' return_op' bind_op' liftM_op', thing) } + ; return (GroupStmt stmts' bindersMap' + (fmap fst by_e_ty) final_using + return_op' bind_op' fmap_op', thing) } -- Typecheck `ParStmt`. See `tcLcStmt` for more informations about typechecking -- of `ParStmt`s. @@ -712,8 +731,8 @@ tcMcStmt ctxt (GroupStmt stmts bindersMap by using return_op bind_op liftM_op) e -- -> (m st2 -> m st3 -> m (st2, st3)) -- recursive call -- -> m (st1, (st2, st3)) -- -tcMcStmt ctxt (ParStmt bndr_stmts_s mzip_op bind_op return_op) elt_ty thing_inside - = do { (_,(m_ty,_)) <- matchExpectedAppTy elt_ty +tcMcStmt ctxt (ParStmt bndr_stmts_s mzip_op bind_op return_op) res_ty thing_inside + = do { (_,(m_ty,_)) <- matchExpectedAppTy res_ty ; (pairs', thing) <- loop m_ty bndr_stmts_s ; let mzip_ty = mkForAllTys [alphaTyVar, betaTyVar] $ @@ -725,19 +744,22 @@ tcMcStmt ctxt (ParStmt bndr_stmts_s mzip_op bind_op return_op) elt_ty thing_insi ; mzip_op' <- unLoc `fmap` tcPolyExpr (noLoc mzip_op) mzip_ty -- Typecheck bind: - ; let tys = map (mkChunkified mkBoxedTupleTy . map idType . snd) pairs' + ; let tys = map (mkBigCoreVarTupTy . snd) pairs' tuple_ty = mk_tuple_ty tys ; bind_op' <- tcSyntaxOp MCompOrigin bind_op $ (m_ty `mkAppTy` tuple_ty) `mkFunTy` - (tuple_ty `mkFunTy` elt_ty) + (tuple_ty `mkFunTy` res_ty) `mkFunTy` - elt_ty + res_ty ; return_op' <- fmap unLoc . tcPolyExpr (noLoc return_op) $ mkForAllTy alphaTyVar $ alphaTy `mkFunTy` (m_ty `mkAppTy` alphaTy) + ; return_op' <- tcSyntaxOp MCompOrigin return_op + (bndr_ty `mkFunTy` m_bndr_ty) + ; return (ParStmt pairs' mzip_op' bind_op' return_op', thing) } where mk_tuple_ty tys = foldr (\tn tm -> mkBoxedTupleTy [tn, tm]) (last tys) (init tys) @@ -745,16 +767,16 @@ tcMcStmt ctxt (ParStmt bndr_stmts_s mzip_op bind_op return_op) elt_ty thing_insi -- loop :: Type -- m_ty -- -> [([LStmt Name], [Name])] -- -> TcM ([([LStmt TcId], [TcId])], thing) - loop _ [] = do { thing <- thing_inside elt_ty + loop _ [] = do { thing <- thing_inside res_ty ; return ([], thing) } -- matching in the branches loop m_ty ((stmts, names) : pairs) = do { -- type dummy since we don't know all binder types yet ty_dummy <- newFlexiTyVarTy liftedTypeKind ; (stmts', (ids, pairs', thing)) - <- tcStmts ctxt tcMcStmt stmts ty_dummy $ \elt_ty' -> + <- tcStmts ctxt tcMcStmt stmts ty_dummy $ \res_ty' -> do { ids <- tcLookupLocalIds names - ; _ <- unifyType elt_ty' (m_ty `mkAppTy` (mkChunkified mkBoxedTupleTy) (map idType ids)) + ; _ <- unifyType res_ty' (m_ty `mkAppTy` mkBigCoreVarTupTy ids) ; (pairs', thing) <- loop m_ty pairs ; return (ids, pairs', thing) } ; return ( (stmts', ids) : pairs', thing ) } @@ -762,27 +784,17 @@ tcMcStmt ctxt (ParStmt bndr_stmts_s mzip_op bind_op return_op) elt_ty thing_insi tcMcStmt _ stmt _ _ = pprPanic "tcMcStmt: unexpected Stmt" (ppr stmt) --- Typecheck 'body' with type 'a' instead of 'm a' like the rest of the --- statements, ignore the second type argument coming from the tcStmts loop -tcMcBody :: LHsExpr Name - -> SyntaxExpr Name - -> TcRhoType - -> TcM (LHsExpr TcId, SyntaxExpr TcId) -tcMcBody body return_op res_ty - = do { (_, (_, a_ty)) <- matchExpectedAppTy res_ty - ; body' <- tcMonoExpr body a_ty - ; return_op' <- tcSyntaxOp MCompOrigin return_op - (a_ty `mkFunTy` res_ty) - ; return (body', return_op') - } - - -------------------------------- -- Do-notation -- The main excitement here is dealing with rebindable syntax tcDoStmt :: TcStmtChecker +tcDoStmt ctxt (LastStmt body _) res_ty thing_inside + = do { body' <- tcMonoExpr body res_ty + ; thing <- thing_inside body_ty + ; return (LastStmt body' noSyntaxExpr, thing) } + tcDoStmt ctxt (BindStmt pat rhs bind_op fail_op) res_ty thing_inside = do { -- Deal with rebindable syntax: -- (>>=) :: rhs_ty -> (pat_ty -> new_res_ty) -> res_ty @@ -862,7 +874,7 @@ tcDoStmt ctxt (RecStmt { recS_stmts = stmts, recS_later_ids = later_names ; return (RecStmt { recS_stmts = stmts', recS_later_ids = later_ids , recS_rec_ids = rec_ids, recS_ret_fn = ret_op' , recS_mfix_fn = mfix_op', recS_bind_fn = bind_op' - , recS_rec_rets = tup_rets }, thing) + , recS_rec_rets = tup_rets, recS_ret_ty = stmts_ty }, thing) }} tcDoStmt _ stmt _ _ @@ -888,6 +900,7 @@ the expected/inferred stuff is back to front (see Trac #3613). tcMDoStmt :: (LHsExpr Name -> TcM (LHsExpr TcId, TcType)) -- RHS inference -> TcStmtChecker +-- Used only by TcArrows... should be gotten rid of tcMDoStmt tc_rhs ctxt (BindStmt pat rhs _ _) res_ty thing_inside = do { (rhs', pat_ty) <- tc_rhs rhs ; (pat', thing) <- tcPat (StmtCtxt ctxt) pat pat_ty $