From f6d254cccd3dc25fff9ff50c2e1bea52b10345e4 Mon Sep 17 00:00:00 2001 From: Simon Peyton Jones Date: Mon, 2 May 2011 13:56:37 +0100 Subject: [PATCH] More on monad-comp; an intermediate state, so don't pull --- compiler/deSugar/DsExpr.lhs | 8 +++- compiler/hsSyn/HsExpr.lhs | 18 ++++--- compiler/hsSyn/HsUtils.lhs | 8 +--- compiler/parser/Parser.y.pp | 2 +- compiler/rename/RnExpr.lhs | 96 +++++++++++++++++++------------------- compiler/typecheck/TcMatches.lhs | 27 ++++++----- 6 files changed, 85 insertions(+), 74 deletions(-) diff --git a/compiler/deSugar/DsExpr.lhs b/compiler/deSugar/DsExpr.lhs index 418bda5..4088e44 100644 --- a/compiler/deSugar/DsExpr.lhs +++ b/compiler/deSugar/DsExpr.lhs @@ -740,6 +740,7 @@ dsDo stmts noSyntaxExpr -- Tuple cannot fail tup_ids = rec_ids ++ filterOut (`elem` rec_ids) later_ids + tup_ty = mkBoxedTupleTy (map idType tup_ids) -- Deals with singleton case rec_tup_pats = map nlVarPat tup_ids later_pats = rec_tup_pats rets = map noLoc rec_rets @@ -748,8 +749,11 @@ dsDo stmts (mkFunTy tup_ty body_ty)) mfix_pat = noLoc $ LazyPat $ mkLHsPatTup rec_tup_pats body = noLoc $ HsDo DoExpr (rec_stmts ++ [ret_stmt]) body_ty - ret_stmt = noLoc $ LastStmt (mkLHsTupleExpr rets) return_op - tup_ty = mkBoxedTupleTy (map idType tup_ids) -- Deals with singleton case + ret_app = nlHsApp (noLoc return_op) (mkLHsTupleExpr rets) + ret_stmt = noLoc $ mkLastStmt ret_app + -- This LastStmt will be desugared with dsDo, + -- which ignores the return_op in the LastStmt, + -- so we must apply the return_op explicitly handle_failure :: LPat Id -> MatchResult -> SyntaxExpr Id -> DsM CoreExpr -- In a do expression, pattern-match failure just calls diff --git a/compiler/hsSyn/HsExpr.lhs b/compiler/hsSyn/HsExpr.lhs index cf9c0d7..fba270c 100644 --- a/compiler/hsSyn/HsExpr.lhs +++ b/compiler/hsSyn/HsExpr.lhs @@ -833,7 +833,8 @@ type Stmt id = StmtLR id id -- The SyntaxExprs in here are used *only* for do-notation and monad -- comprehensions, which have rebindable syntax. Otherwise they are unused. data StmtLR idL idR - = LastStmt -- Always the last Stmt in ListComp, MonadComp, PArrComp, DoExpr, MDoExpr + = LastStmt -- Always the last Stmt in ListComp, MonadComp, PArrComp, + -- and (after the renamer) DoExpr, MDoExpr -- Not used for GhciStmt, PatGuard, which scope over other stuff (LHsExpr idR) (SyntaxExpr idR) -- The return operator, used only for MonadComp @@ -1090,7 +1091,7 @@ instance (OutputableBndr idL, OutputableBndr idR) => Outputable (StmtLR idL idR) ppr stmt = pprStmt stmt pprStmt :: (OutputableBndr idL, OutputableBndr idR) => (StmtLR idL idR) -> SDoc -pprStmt (LastStmt expr _) = ppr expr +pprStmt (LastStmt expr _) = ifPprDebug (ptext (sLit "[last]")) <+> ppr expr pprStmt (BindStmt pat expr _ _) = hsep [ppr pat, ptext (sLit "<-"), ppr expr] pprStmt (LetStmt binds) = hsep [ptext (sLit "let"), pprBinds binds] pprStmt (ExprStmt expr _ _ _) = ppr expr @@ -1354,8 +1355,8 @@ pprAStmtContext ctxt = article <+> pprStmtContext ctxt ----------------- pprStmtContext GhciStmt = ptext (sLit "interactive GHCi command") -pprStmtContext DoExpr = ptext (sLit "'do' expression") -pprStmtContext MDoExpr = ptext (sLit "'mdo' expression") +pprStmtContext DoExpr = ptext (sLit "'do' block") +pprStmtContext MDoExpr = ptext (sLit "'mdo' block") pprStmtContext ListComp = ptext (sLit "list comprehension") pprStmtContext MonadComp = ptext (sLit "monad comprehension") pprStmtContext PArrComp = ptext (sLit "array comprehension") @@ -1402,8 +1403,13 @@ pprMatchInCtxt ctxt match = hang (ptext (sLit "In") <+> pprMatchContext ctxt <> pprStmtInCtxt :: (OutputableBndr idL, OutputableBndr idR) => HsStmtContext idL -> StmtLR idL idR -> SDoc -pprStmtInCtxt ctxt stmt = hang (ptext (sLit "In a stmt of") <+> pprAStmtContext ctxt <> colon) - 4 (ppr_stmt stmt) +pprStmtInCtxt ctxt (LastStmt e _) + | isListCompExpr ctxt -- For [ e | .. ], do not mutter about "stmts" + = hang (ptext (sLit "In the expression:")) 2 (ppr e) + +pprStmtInCtxt ctxt stmt + = hang (ptext (sLit "In a stmt of") <+> pprAStmtContext ctxt <> colon) + 2 (ppr_stmt stmt) where -- For Group and Transform Stmts, don't print the nested stmts! ppr_stmt (GroupStmt { grpS_by = by, grpS_using = using diff --git a/compiler/hsSyn/HsUtils.lhs b/compiler/hsSyn/HsUtils.lhs index de883f2..51a0de3 100644 --- a/compiler/hsSyn/HsUtils.lhs +++ b/compiler/hsSyn/HsUtils.lhs @@ -21,7 +21,7 @@ module HsUtils( mkMatchGroup, mkMatch, mkHsLam, mkHsIf, mkHsWrap, mkLHsWrap, mkHsWrapCoI, mkLHsWrapCoI, coiToHsWrapper, mkHsLams, mkHsDictLet, - mkHsOpApp, mkHsDo, mkHsComp, mkHsWrapPat, mkHsWrapPatCoI, mkDoStmts, + mkHsOpApp, mkHsDo, mkHsComp, mkHsWrapPat, mkHsWrapPatCoI, nlHsTyApp, nlHsVar, nlHsLit, nlHsApp, nlHsApps, nlHsIntLit, nlHsVarApps, nlHsDo, nlHsOpApp, nlHsLam, nlHsPar, nlHsIf, nlHsCase, nlList, @@ -192,7 +192,6 @@ mkHsFractional :: Rational -> PostTcType -> HsOverLit id mkHsIsString :: FastString -> PostTcType -> HsOverLit id mkHsDo :: HsStmtContext Name -> [LStmt id] -> HsExpr id mkHsComp :: HsStmtContext Name -> [LStmt id] -> LHsExpr id -> HsExpr id -mkDoStmts :: [LStmt id] -> [LStmt id] mkNPat :: HsOverLit id -> Maybe (SyntaxExpr id) -> Pat id mkNPlusKPat :: Located id -> HsOverLit id -> Pat id @@ -215,11 +214,6 @@ mkHsIsString s = OverLit (HsIsString s) noRebindableInfo noSyntaxExpr noRebindableInfo :: Bool noRebindableInfo = error "noRebindableInfo" -- Just another placeholder; --- mkDoStmts turns a trailing ExprStmt into a LastStmt -mkDoStmts [L loc (ExprStmt e _ _ _)] = [L loc (mkLastStmt e)] -mkDoStmts (s:ss) = s : mkDoStmts ss -mkDoStmts [] = [] - mkHsDo ctxt stmts = HsDo ctxt stmts placeHolderType mkHsComp ctxt stmts expr = mkHsDo ctxt (stmts ++ [last_stmt]) where diff --git a/compiler/parser/Parser.y.pp b/compiler/parser/Parser.y.pp index c42ea0c..aa20ea6 100644 --- a/compiler/parser/Parser.y.pp +++ b/compiler/parser/Parser.y.pp @@ -1602,7 +1602,7 @@ apats :: { [LPat RdrName] } -- Statement sequences stmtlist :: { Located [LStmt RdrName] } - : '{' stmts '}' { LL (mkDoStmts (unLoc $2)) } + : '{' stmts '}' { LL (unLoc $2) } | vocurly stmts close { $2 } -- do { ;; s ; s ; ; s ;; } diff --git a/compiler/rename/RnExpr.lhs b/compiler/rename/RnExpr.lhs index d1dd222..11d44e3 100644 --- a/compiler/rename/RnExpr.lhs +++ b/compiler/rename/RnExpr.lhs @@ -648,32 +648,22 @@ rnStmts MDoExpr stmts thing_inside -- Deal with mdo = -- Behave like do { rec { ...all but last... }; last } do { ((stmts1, (stmts2, thing)), fvs) <- rnStmt MDoExpr (noLoc $ mkRecStmt all_but_last) $ \ _ -> - do { checkStmt MDoExpr True last_stmt - ; rnStmt MDoExpr last_stmt thing_inside } + do { last_stmt' <- checkLastStmt MDoExpr last_stmt + ; rnStmt MDoExpr last_stmt' thing_inside } ; return (((stmts1 ++ stmts2), thing), fvs) } where Just (all_but_last, last_stmt) = snocView stmts -rnStmts ctxt (lstmt@(L loc stmt) : lstmts) thing_inside +rnStmts ctxt (lstmt@(L loc _) : lstmts) thing_inside | null lstmts = setSrcSpan loc $ - do { -- Turn a final ExprStmt into a LastStmt - -- This is the first place it's convenient to do this - -- (In principle the parser could do it, but it's - -- just not very convenient to do so.) - let stmt' | okEmpty ctxt - = lstmt - | otherwise - = case stmt of - ExprStmt e _ _ _ -> L loc (mkLastStmt e) - _ -> lstmt - ; checkStmt ctxt True {- last stmt -} stmt' - ; rnStmt ctxt stmt' thing_inside } + do { lstmt' <- checkLastStmt ctxt lstmt + ; rnStmt ctxt lstmt' thing_inside } | otherwise = do { ((stmts1, (stmts2, thing)), fvs) <- setSrcSpan loc $ - do { checkStmt ctxt False {- Not last -} lstmt + do { checkStmt ctxt lstmt ; rnStmt ctxt lstmt $ \ bndrs1 -> rnStmts ctxt lstmts $ \ bndrs2 -> thing_inside (bndrs1 ++ bndrs2) } @@ -1211,7 +1201,7 @@ checkEmptyStmts :: HsStmtContext Name -> RnM () checkEmptyStmts ctxt = unless (okEmpty ctxt) (addErr (emptyErr ctxt)) -okEmpty :: HsStmtContext Name -> Bool +okEmpty :: HsStmtContext a -> Bool okEmpty (PatGuard {}) = True okEmpty _ = False @@ -1221,14 +1211,42 @@ emptyErr (TransformStmtCtxt {}) = ptext (sLit "Empty statement group preceding ' emptyErr ctxt = ptext (sLit "Empty") <+> pprStmtContext ctxt ---------------------- +checkLastStmt :: HsStmtContext Name + -> LStmt RdrName + -> RnM (LStmt RdrName) +checkLastStmt ctxt lstmt@(L loc stmt) + = case ctxt of + ListComp -> check_comp + MonadComp -> check_comp + PArrComp -> check_comp + DoExpr -> check_do + MDoExpr -> check_do + _ -> check_other + where + check_do -- Expect ExprStmt, and change it to LastStmt + = case stmt of + ExprStmt e _ _ _ -> return (L loc (mkLastStmt e)) + LastStmt {} -> return lstmt -- "Deriving" clauses may generate a + -- LastStmt directly (unlike the parser) + _ -> do { addErr (hang last_error 2 (ppr stmt)); return lstmt } + last_error = (ptext (sLit "The last statement in") <+> pprAStmtContext ctxt + <+> ptext (sLit "must be an expression")) + + check_comp -- Expect LastStmt; this should be enforced by the parser! + = case stmt of + LastStmt {} -> return lstmt + _ -> pprPanic "checkLastStmt" (ppr lstmt) + + check_other -- Behave just as if this wasn't the last stmt + = do { checkStmt ctxt lstmt; return lstmt } + -- Checking when a particular Stmt is ok checkStmt :: HsStmtContext Name - -> Bool -- True <=> this is the last Stmt in the sequence -> LStmt RdrName -> RnM () -checkStmt ctxt is_last (L _ stmt) +checkStmt ctxt (L _ stmt) = do { dflags <- getDOpts - ; case okStmt dflags ctxt is_last stmt of + ; case okStmt dflags ctxt stmt of Nothing -> return () Just extra -> addErr (msg $$ extra) } where @@ -1250,42 +1268,32 @@ isOK, notOK :: Maybe SDoc isOK = Nothing notOK = Just empty -okStmt, okDoStmt, okCompStmt :: DynFlags -> HsStmtContext Name -> Bool +okStmt, okDoStmt, okCompStmt :: DynFlags -> HsStmtContext Name -> Stmt RdrName -> Maybe SDoc -- Return Nothing if OK, (Just extra) if not ok -- The "extra" is an SDoc that is appended to an generic error message -okStmt _ (PatGuard {}) _ stmt +okStmt _ (PatGuard {}) stmt = case stmt of ExprStmt {} -> isOK BindStmt {} -> isOK LetStmt {} -> isOK _ -> notOK -okStmt dflags (ParStmtCtxt ctxt) _ stmt +okStmt dflags (ParStmtCtxt ctxt) stmt = case stmt of LetStmt (HsIPBinds {}) -> notOK - _ -> okStmt dflags ctxt False stmt - -- NB: is_last=False in recursive - -- call; the branches of of a Par - -- not finish with a LastStmt + _ -> okStmt dflags ctxt stmt -okStmt dflags (TransformStmtCtxt ctxt) _ stmt - = okStmt dflags ctxt False stmt +okStmt dflags (TransformStmtCtxt ctxt) stmt + = okStmt dflags ctxt stmt -okStmt dflags ctxt is_last stmt - | isDoExpr ctxt = okDoStmt dflags ctxt is_last stmt - | isListCompExpr ctxt = okCompStmt dflags ctxt is_last stmt +okStmt dflags ctxt stmt + | isDoExpr ctxt = okDoStmt dflags ctxt stmt + | isListCompExpr ctxt = okCompStmt dflags ctxt stmt | otherwise = pprPanic "okStmt" (pprStmtContext ctxt) ---------------- -okDoStmt dflags ctxt is_last stmt - | is_last - = case stmt of - LastStmt {} -> isOK - _ -> Just (ptext (sLit "The last statement in") <+> pprAStmtContext ctxt - <+> ptext (sLit "must be an expression")) - - | otherwise +okDoStmt dflags _ stmt = case stmt of RecStmt {} | Opt_DoRec `xopt` dflags -> isOK @@ -1297,13 +1305,7 @@ okDoStmt dflags ctxt is_last stmt ---------------- -okCompStmt dflags _ is_last stmt - | is_last - = case stmt of - LastStmt {} -> Nothing - _ -> pprPanic "Unexpected stmt" (ppr stmt) -- Not a user error - - | otherwise +okCompStmt dflags _ stmt = case stmt of BindStmt {} -> isOK LetStmt {} -> isOK diff --git a/compiler/typecheck/TcMatches.lhs b/compiler/typecheck/TcMatches.lhs index 820e517..87449b6 100644 --- a/compiler/typecheck/TcMatches.lhs +++ b/compiler/typecheck/TcMatches.lhs @@ -358,7 +358,7 @@ tcLcStmt :: TyCon -- The list/Parray type constructor ([] or PArray) -> TcStmtChecker tcLcStmt _ _ (LastStmt body _) elt_ty thing_inside - = do { body' <- tcMonoExpr body elt_ty + = do { body' <- tcMonoExprNC body elt_ty ; thing <- thing_inside (panic "tcLcStmt: thing_inside") ; return (LastStmt body' noSyntaxExpr, thing) } @@ -502,7 +502,7 @@ tcMcStmt _ (LastStmt body return_op) res_ty thing_inside = do { a_ty <- newFlexiTyVarTy liftedTypeKind ; return_op' <- tcSyntaxOp MCompOrigin return_op (a_ty `mkFunTy` res_ty) - ; body' <- tcMonoExpr body a_ty + ; body' <- tcMonoExprNC body a_ty ; thing <- thing_inside (panic "tcMcStmt: thing_inside") ; return (LastStmt body' return_op', thing) } @@ -558,15 +558,20 @@ tcMcStmt _ (ExprStmt rhs then_op guard_op _) res_ty thing_inside -- [ body | stmts, then f by e ] -> f :: forall a. (a -> t) -> m a -> m a -- tcMcStmt ctxt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op) res_ty thing_inside - = do { - -- We don't know the types of binders yet, so we use this dummy and - -- later unify this type with the `m_bndr_ty` - ty_dummy <- newFlexiTyVarTy liftedTypeKind + = do { let star_star_kind = liftedTypeKind `mkArrowKind` liftedTypeKind + ; m1_ty <- newFlexiTyVarTy star_star_kind + ; m2_ty <- newFlexiTyVarTy star_star_kind + ; n_ty <- newFlexiTyVarTy star_star_kind + ; tup_ty_var <- newFlexiTyVarTy liftedTypeKind + ; new_res_ty <- newFlexiTyVarTy liftedTypeKind + ; let m1_tup_ty = m1_ty `mkAppTy` tup_ty_var + -- 'stmts' returns a result of type (m1_ty tuple_ty), + -- typically something like [(Int,Bool,Int)] + -- We don't know what tuple_ty is yet, so we use a variable ; (stmts', (binders', usingExpr', maybeByExpr', return_op', bind_op', thing)) <- - tcStmtsAndThen (TransformStmtCtxt ctxt) tcMcStmt stmts ty_dummy $ \res_ty' -> do - { (_, (m_ty, _)) <- matchExpectedAppTy res_ty' - ; (usingExpr', maybeByExpr') <- + tcStmtsAndThen (TransformStmtCtxt ctxt) tcMcStmt stmts m1_tup_ty $ \res_ty' -> do + { (usingExpr', maybeByExpr') <- case maybeByExpr of Nothing -> do -- We must validate that usingExpr :: forall a. m a -> m a @@ -671,8 +676,8 @@ tcMcStmt ctxt (GroupStmt { grpS_stmts = stmts, grpS_bndrs = bindersMap using_res_ty = m2_ty `mkAppTy` n_tup_ty -- m2 (n (a,b,c)) using_fun_ty = using_arg_ty `mkFunTy` using_arg_ty - -- (>>=) :: m2 (n (a,b,c)) -> ( n (a,b,c) -> new_res_ty ) -> res_ty - -- using :: ((a,b,c)->t) -> m1 (a,b,c) -> m2 (n (a,b,c)) + -- (>>=) :: m2 (n (a,b,c)) -> ( n (a,b,c) -> new_res_ty ) -> res_ty + -- using :: ((a,b,c)->t) -> m1 (a,b,c) -> m2 (n (a,b,c)) --------------- Typecheck the 'bind' function ------------- ; bind_op' <- tcSyntaxOp MCompOrigin bind_op $ -- 1.7.10.4