From: Simon Peyton Jones Date: Thu, 28 Apr 2011 10:44:12 +0000 (+0100) Subject: Preliminary monad-comprehension patch (Trac #4370) X-Git-Url: http://git.megacz.com/?p=ghc-hetmet.git;a=commitdiff_plain;h=478e69b303eb2e653a2ebf5c888b5efdfef1fb9d Preliminary monad-comprehension patch (Trac #4370) This is the work of Nils Schweinsberg It adds the language extension -XMonadComprehensions, which generalises list comprehension syntax [ e | x <- xs] to work over arbitrary monads. --- diff --git a/compiler/deSugar/Coverage.lhs b/compiler/deSugar/Coverage.lhs index 0daa6be..e73c249 100644 --- a/compiler/deSugar/Coverage.lhs +++ b/compiler/deSugar/Coverage.lhs @@ -301,10 +301,11 @@ addTickHsExpr (HsLet binds e) = liftM2 HsLet (addTickHsLocalBinds binds) -- to think about: !patterns. (addTickLHsExprNeverOrAlways e) -addTickHsExpr (HsDo cxt stmts last_exp srcloc) = do +addTickHsExpr (HsDo cxt stmts last_exp return_exp srcloc) = do (stmts', last_exp') <- addTickLStmts' forQual stmts (addTickLHsExpr last_exp) - return (HsDo cxt stmts' last_exp' srcloc) + return_exp' <- addTickSyntaxExpr hpcSrcSpan return_exp + return (HsDo cxt stmts' last_exp' return_exp' srcloc) where forQual = case cxt of ListComp -> Just $ BinBox QualBinBox @@ -438,31 +439,38 @@ addTickStmt _isGuard (BindStmt pat e bind fail) = do (addTickLHsExprAlways e) (addTickSyntaxExpr hpcSrcSpan bind) (addTickSyntaxExpr hpcSrcSpan fail) -addTickStmt isGuard (ExprStmt e bind' ty) = do - liftM3 ExprStmt +addTickStmt isGuard (ExprStmt e bind' guard' ty) = do + liftM4 ExprStmt (addTick isGuard e) (addTickSyntaxExpr hpcSrcSpan bind') + (addTickSyntaxExpr hpcSrcSpan guard') (return ty) addTickStmt _isGuard (LetStmt binds) = do liftM LetStmt (addTickHsLocalBinds binds) -addTickStmt isGuard (ParStmt pairs) = do - liftM ParStmt +addTickStmt isGuard (ParStmt pairs mzipExpr bindExpr returnExpr) = do + liftM4 ParStmt (mapM (addTickStmtAndBinders isGuard) pairs) - -addTickStmt isGuard (TransformStmt stmts ids usingExpr maybeByExpr) = do - liftM4 TransformStmt - (addTickLStmts isGuard stmts) - (return ids) - (addTickLHsExprAlways usingExpr) - (addTickMaybeByLHsExpr maybeByExpr) - -addTickStmt isGuard (GroupStmt stmts binderMap by using) = do - liftM4 GroupStmt - (addTickLStmts isGuard stmts) - (return binderMap) - (fmapMaybeM addTickLHsExprAlways by) - (fmapEitherM addTickLHsExprAlways (addTickSyntaxExpr hpcSrcSpan) using) + (addTickSyntaxExpr hpcSrcSpan mzipExpr) + (addTickSyntaxExpr hpcSrcSpan bindExpr) + (addTickSyntaxExpr hpcSrcSpan returnExpr) + +addTickStmt isGuard (TransformStmt stmts ids usingExpr maybeByExpr returnExpr bindExpr) = do + t_s <- (addTickLStmts isGuard stmts) + t_u <- (addTickLHsExprAlways usingExpr) + t_m <- (addTickMaybeByLHsExpr maybeByExpr) + t_r <- (addTickSyntaxExpr hpcSrcSpan returnExpr) + t_b <- (addTickSyntaxExpr hpcSrcSpan bindExpr) + return $ TransformStmt t_s ids t_u t_m t_r t_b + +addTickStmt isGuard (GroupStmt stmts binderMap by using returnExpr bindExpr liftMExpr) = do + t_s <- (addTickLStmts isGuard stmts) + t_y <- (fmapMaybeM addTickLHsExprAlways by) + t_u <- (fmapEitherM addTickLHsExprAlways (addTickSyntaxExpr hpcSrcSpan) using) + t_f <- (addTickSyntaxExpr hpcSrcSpan returnExpr) + t_b <- (addTickSyntaxExpr hpcSrcSpan bindExpr) + t_m <- (addTickSyntaxExpr hpcSrcSpan liftMExpr) + return $ GroupStmt t_s binderMap t_y t_u t_b t_f t_m addTickStmt isGuard stmt@(RecStmt {}) = do { stmts' <- addTickLStmts isGuard (recS_stmts stmt) @@ -569,9 +577,10 @@ addTickHsCmd (HsLet binds c) = liftM2 HsLet (addTickHsLocalBinds binds) -- to think about: !patterns. (addTickLHsCmd c) -addTickHsCmd (HsDo cxt stmts last_exp srcloc) = do +addTickHsCmd (HsDo cxt stmts last_exp return_exp srcloc) = do (stmts', last_exp') <- addTickLCmdStmts' stmts (addTickLHsCmd last_exp) - return (HsDo cxt stmts' last_exp' srcloc) + return_exp' <- addTickSyntaxExpr hpcSrcSpan return_exp + return (HsDo cxt stmts' last_exp' return_exp' srcloc) addTickHsCmd (HsArrApp e1 e2 ty1 arr_ty lr) = liftM5 HsArrApp @@ -635,10 +644,11 @@ addTickCmdStmt (BindStmt pat c bind fail) = do (addTickLHsCmd c) (return bind) (return fail) -addTickCmdStmt (ExprStmt c bind' ty) = do - liftM3 ExprStmt +addTickCmdStmt (ExprStmt c bind' guard' ty) = do + liftM4 ExprStmt (addTickLHsCmd c) - (return bind') + (addTickSyntaxExpr hpcSrcSpan bind') + (addTickSyntaxExpr hpcSrcSpan guard') (return ty) addTickCmdStmt (LetStmt binds) = do liftM LetStmt diff --git a/compiler/deSugar/DsArrows.lhs b/compiler/deSugar/DsArrows.lhs index 58bf6b8..608f25e 100644 --- a/compiler/deSugar/DsArrows.lhs +++ b/compiler/deSugar/DsArrows.lhs @@ -541,7 +541,7 @@ dsCmd ids local_vars env_ids stack res_ty (HsLet binds body) = do core_body, exprFreeVars core_binds `intersectVarSet` local_vars) -dsCmd ids local_vars env_ids [] res_ty (HsDo _ctxt stmts body _) +dsCmd ids local_vars env_ids [] res_ty (HsDo _ctxt stmts body _ _) = dsCmdDo ids local_vars env_ids res_ty stmts body -- A |- e :: forall e. a1 (e*ts1) t1 -> ... an (e*tsn) tn -> a (e*ts) t @@ -674,7 +674,7 @@ dsCmdStmt -- ---> arr (\ (xs) -> ((xs1),(xs'))) >>> first c >>> -- arr snd >>> ss -dsCmdStmt ids local_vars env_ids out_ids (ExprStmt cmd _ c_ty) = do +dsCmdStmt ids local_vars env_ids out_ids (ExprStmt cmd _ _ c_ty) = do (core_cmd, fv_cmd, env_ids1) <- dsfixCmd ids local_vars [] c_ty cmd core_mux <- matchEnvStack env_ids [] (mkCorePairExpr (mkBigCoreVarTup env_ids1) (mkBigCoreVarTup out_ids)) diff --git a/compiler/deSugar/DsExpr.lhs b/compiler/deSugar/DsExpr.lhs index 1781aef..fb3f856 100644 --- a/compiler/deSugar/DsExpr.lhs +++ b/compiler/deSugar/DsExpr.lhs @@ -325,22 +325,25 @@ dsExpr (HsLet binds body) = do -- We need the `ListComp' form to use `deListComp' (rather than the "do" form) -- because the interpretation of `stmts' depends on what sort of thing it is. -- -dsExpr (HsDo ListComp stmts body result_ty) +dsExpr (HsDo ListComp stmts body _ result_ty) = -- Special case for list comprehensions dsListComp stmts body elt_ty where [elt_ty] = tcTyConAppArgs result_ty -dsExpr (HsDo DoExpr stmts body result_ty) +dsExpr (HsDo DoExpr stmts body _ result_ty) = dsDo stmts body result_ty -dsExpr (HsDo GhciStmt stmts body result_ty) +dsExpr (HsDo GhciStmt stmts body _ result_ty) = dsDo stmts body result_ty -dsExpr (HsDo MDoExpr stmts body result_ty) +dsExpr (HsDo MDoExpr stmts body _ result_ty) = dsDo stmts body result_ty -dsExpr (HsDo PArrComp stmts body result_ty) +dsExpr (HsDo MonadComp stmts body return_op result_ty) + = dsMonadComp stmts return_op body result_ty + +dsExpr (HsDo PArrComp stmts body _ result_ty) = -- Special case for array comprehensions dsPArrComp (map unLoc stmts) body elt_ty where @@ -722,7 +725,7 @@ dsDo stmts body result_ty goL [] = dsLExpr body goL ((L loc stmt):lstmts) = putSrcSpanDs loc (go loc stmt lstmts) - go _ (ExprStmt rhs then_expr _) stmts + go _ (ExprStmt rhs then_expr _ _) stmts = do { rhs2 <- dsLExpr rhs ; case tcSplitAppTy_maybe (exprType rhs2) of Just (container_ty, returning_ty) -> warnDiscardedDoBindings rhs container_ty returning_ty @@ -769,7 +772,7 @@ dsDo stmts body result_ty mfix_arg = noLoc $ HsLam (MatchGroup [mkSimpleMatch [mfix_pat] body] (mkFunTy tup_ty body_ty)) mfix_pat = noLoc $ LazyPat $ mkLHsPatTup rec_tup_pats - body = noLoc $ HsDo DoExpr rec_stmts return_app body_ty + body = noLoc $ HsDo DoExpr rec_stmts return_app noSyntaxExpr body_ty return_app = nlHsApp (noLoc return_op) (mkLHsTupleExpr rets) body_ty = mkAppTy m_ty tup_ty tup_ty = mkBoxedTupleTy (map idType tup_ids) -- Deals with singleton case @@ -869,7 +872,7 @@ dsMDo ctxt tbl stmts body result_ty rets = map nlHsVar later_ids' ++ map noLoc rec_rets mfix_pat = noLoc $ LazyPat $ mk_tup_pat rec_tup_pats - body = noLoc $ HsDo ctxt rec_stmts return_app body_ty + body = noLoc $ HsDo ctxt rec_stmts return_app noSyntaxExpr body_ty body_ty = mkAppTy m_ty tup_ty tup_ty = mkBoxedTupleTy (map idType (later_ids' ++ rec_ids)) -- Deals with singleton case @@ -888,7 +891,6 @@ dsMDo ctxt tbl stmts body result_ty -} \end{code} - %************************************************************************ %* * Warning about identities diff --git a/compiler/deSugar/DsGRHSs.lhs b/compiler/deSugar/DsGRHSs.lhs index a7260e2..d3fcf76 100644 --- a/compiler/deSugar/DsGRHSs.lhs +++ b/compiler/deSugar/DsGRHSs.lhs @@ -106,11 +106,11 @@ matchGuards [] _ rhs _ -- NB: The success of this clause depends on the typechecker not -- wrapping the 'otherwise' in empty HsTyApp or HsWrap constructors -- If it does, you'll get bogus overlap warnings -matchGuards (ExprStmt e _ _ : stmts) ctx rhs rhs_ty +matchGuards (ExprStmt e _ _ _ : stmts) ctx rhs rhs_ty | Just addTicks <- isTrueLHsExpr e = do match_result <- matchGuards stmts ctx rhs rhs_ty return (adjustMatchResultDs addTicks match_result) -matchGuards (ExprStmt expr _ _ : stmts) ctx rhs rhs_ty = do +matchGuards (ExprStmt expr _ _ _ : stmts) ctx rhs rhs_ty = do match_result <- matchGuards stmts ctx rhs rhs_ty pred_expr <- dsLExpr expr return (mkGuardedMatchResult pred_expr match_result) diff --git a/compiler/deSugar/DsListComp.lhs b/compiler/deSugar/DsListComp.lhs index cd22b8f..7fa7848 100644 --- a/compiler/deSugar/DsListComp.lhs +++ b/compiler/deSugar/DsListComp.lhs @@ -3,9 +3,10 @@ % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998 % -Desugaring list comprehensions and array comprehensions +Desugaring list comprehensions, monad comprehensions and array comprehensions \begin{code} +{-# LANGUAGE NamedFieldPuns #-} {-# OPTIONS -fno-warn-incomplete-patterns #-} -- The above warning supression flag is a temporary kludge. -- While working on this module you are encouraged to remove it and fix @@ -13,11 +14,11 @@ Desugaring list comprehensions and array comprehensions -- http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings -- for details -module DsListComp ( dsListComp, dsPArrComp ) where +module DsListComp ( dsListComp, dsPArrComp, dsMonadComp ) where #include "HsVersions.h" -import {-# SOURCE #-} DsExpr ( dsLExpr, dsLocalBinds ) +import {-# SOURCE #-} DsExpr ( dsExpr, dsLExpr, dsLocalBinds ) import HsSyn import TcHsSyn @@ -37,6 +38,7 @@ import PrelNames import SrcLoc import Outputable import FastString +import TcType \end{code} List comprehensions may be desugared in one of two ways: ``ordinary'' @@ -72,8 +74,8 @@ dsListComp lquals body elt_ty = do -- mix of possibly a single element in length, so we do this to leave the possibility open isParallelComp = any isParallelStmt - isParallelStmt (ParStmt _) = True - isParallelStmt _ = False + isParallelStmt (ParStmt _ _ _ _) = True + isParallelStmt _ = False -- This function lets you desugar a inner list comprehension and a list of the binders @@ -92,7 +94,7 @@ dsInnerListComp (stmts, bndrs) = do -- Given such a statement it gives you back an expression representing how to compute the transformed -- list and the tuple that you need to bind from that list in order to proceed with your desugaring dsTransformStmt :: Stmt Id -> DsM (CoreExpr, LPat Id) -dsTransformStmt (TransformStmt stmts binders usingExpr maybeByExpr) +dsTransformStmt (TransformStmt stmts binders usingExpr maybeByExpr _ _) = do { (expr, binders_tuple_type) <- dsInnerListComp (stmts, binders) ; usingExpr' <- dsLExpr usingExpr @@ -116,7 +118,7 @@ dsTransformStmt (TransformStmt stmts binders usingExpr maybeByExpr) -- Given such a statement it gives you back an expression representing how to compute the transformed -- list and the tuple that you need to bind from that list in order to proceed with your desugaring dsGroupStmt :: Stmt Id -> DsM (CoreExpr, LPat Id) -dsGroupStmt (GroupStmt stmts binderMap by using) = do +dsGroupStmt (GroupStmt stmts binderMap by using _ _ _) = do let (fromBinders, toBinders) = unzip binderMap fromBindersTypes = map idType fromBinders @@ -228,7 +230,7 @@ with the Unboxed variety. deListComp :: [Stmt Id] -> LHsExpr Id -> CoreExpr -> DsM CoreExpr -deListComp (ParStmt stmtss_w_bndrs : quals) body list +deListComp (ParStmt stmtss_w_bndrs _ _ _ : quals) body list = do exps_and_qual_tys <- mapM dsInnerListComp stmtss_w_bndrs let (exps, qual_tys) = unzip exps_and_qual_tys @@ -252,7 +254,7 @@ deListComp [] body list = do -- Figure 7.4, SLPJ, p 135, rule C above return (mkConsExpr (exprType core_body) core_body list) -- Non-last: must be a guard -deListComp (ExprStmt guard _ _ : quals) body list = do -- rule B above +deListComp (ExprStmt guard _ _ _ : quals) body list = do -- rule B above core_guard <- dsLExpr guard core_rest <- deListComp quals body list return (mkIfThenElse core_guard core_rest list) @@ -344,7 +346,7 @@ dfListComp c_id n_id [] body = do return (mkApps (Var c_id) [core_body, Var n_id]) -- Non-last: must be a guard -dfListComp c_id n_id (ExprStmt guard _ _ : quals) body = do +dfListComp c_id n_id (ExprStmt guard _ _ _ : quals) body = do core_guard <- dsLExpr guard core_rest <- dfListComp c_id n_id quals body return (mkIfThenElse core_guard core_rest (Var n_id)) @@ -501,7 +503,7 @@ dsPArrComp :: [Stmt Id] -> LHsExpr Id -> Type -- Don't use; called with `undefined' below -> DsM CoreExpr -dsPArrComp [ParStmt qss] body _ = -- parallel comprehension +dsPArrComp [ParStmt qss _ _ _] body _ = -- parallel comprehension dePArrParComp qss body -- Special case for simple generators: @@ -550,7 +552,7 @@ dePArrComp [] e' pa cea = do -- -- <<[:e' | b, qs:]>> pa ea = <<[:e' | qs:]>> pa (filterP (\pa -> b) ea) -- -dePArrComp (ExprStmt b _ _ : qs) body pa cea = do +dePArrComp (ExprStmt b _ _ _ : qs) body pa cea = do filterP <- dsLookupDPHId filterPName let ty = parrElemType cea (clam,_) <- deLambda ty pa b @@ -616,7 +618,7 @@ dePArrComp (LetStmt ds : qs) body pa cea = do -- singeltons qualifier lists, which we already special case in the caller. -- So, encountering one here is a bug. -- -dePArrComp (ParStmt _ : _) _ _ _ = +dePArrComp (ParStmt _ _ _ _ : _) _ _ _ = panic "DsListComp.dePArrComp: malformed comprehension AST" -- <<[:e' | qs | qss:]>> pa ea = @@ -682,3 +684,341 @@ parrElemType e = _ -> panic "DsListComp.parrElemType: not a parallel array type" \end{code} + +Translation for monad comprehensions + +\begin{code} + +-- | Keep the "context" of a monad comprehension in a small data type to avoid +-- some boilerplate... +data DsMonadComp = DsMonadComp + { mc_return :: Either (SyntaxExpr Id) (Expr CoreBndr) + , mc_body :: LHsExpr Id + , mc_m_ty :: Type + } + +-- +-- Entry point for monad comprehension desugaring +-- +dsMonadComp :: [LStmt Id] -- the statements + -> SyntaxExpr Id -- the "return" function + -> LHsExpr Id -- the body + -> Type -- the final type + -> DsM CoreExpr +dsMonadComp stmts return_op body res_ty + = dsMcStmts stmts (DsMonadComp (Left return_op) body m_ty) + where + (m_ty, _) = tcSplitAppTy res_ty + + +dsMcStmts :: [LStmt Id] + -> DsMonadComp + -> DsM CoreExpr + +-- No statements left for desugaring. Desugar the body after calling "return" +-- on it. +dsMcStmts [] DsMonadComp { mc_return, mc_body } + = case mc_return of + Left ret -> dsLExpr $ noLoc ret `nlHsApp` mc_body + Right ret' -> do + { body' <- dsLExpr mc_body + ; return $ mkApps ret' [body'] } + +-- Otherwise desugar each statement step by step +dsMcStmts ((L loc stmt) : lstmts) mc + = putSrcSpanDs loc (dsMcStmt stmt lstmts mc) + + +dsMcStmt :: Stmt Id + -> [LStmt Id] + -> DsMonadComp + -> DsM CoreExpr + +-- [ .. | let binds, stmts ] +dsMcStmt (LetStmt binds) stmts mc + = do { rest <- dsMcStmts stmts mc + ; dsLocalBinds binds rest } + +-- [ .. | a <- m, stmts ] +dsMcStmt (BindStmt pat rhs bind_op fail_op) stmts mc + = do { rhs' <- dsLExpr rhs + ; dsMcBindStmt pat rhs' bind_op fail_op stmts mc } + +-- Apply `guard` to the `exp` expression +-- +-- [ .. | exp, stmts ] +-- +dsMcStmt (ExprStmt exp then_exp guard_exp _) stmts mc + = do { exp' <- dsLExpr exp + ; guard_exp' <- dsExpr guard_exp + ; then_exp' <- dsExpr then_exp + ; rest <- dsMcStmts stmts mc + ; return $ mkApps then_exp' [ mkApps guard_exp' [exp'] + , rest ] } + +-- Transform statements desugar like this: +-- +-- [ .. | qs, then f by e ] -> f (\q_v -> e) [| qs |] +-- +-- where [| qs |] is the desugared inner monad comprehenion generated by the +-- statements `qs`. +dsMcStmt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op) stmts_rest mc + = do { (expr, _) <- dsInnerMonadComp (stmts, binders) (mc { mc_return = Left return_op }) + ; let binders_tuple_type = mkBigCoreTupTy $ map idType binders + ; usingExpr' <- dsLExpr usingExpr + ; using_args <- case maybeByExpr of + Nothing -> return [expr] + Just byExpr -> do + byExpr' <- dsLExpr byExpr + us <- newUniqueSupply + tuple_binder <- newSysLocalDs binders_tuple_type + let byExprWrapper = mkTupleCase us binders byExpr' tuple_binder (Var tuple_binder) + return [Lam tuple_binder byExprWrapper, expr] + + ; let pat = mkBigLHsVarPatTup binders + rhs = mkApps usingExpr' ((Type binders_tuple_type) : using_args) + + ; dsMcBindStmt pat rhs bind_op noSyntaxExpr stmts_rest mc } + +-- Group statements desugar like this: +-- +-- [| q, then group by e using f |] -> (f (\q_v -> e) [| q |]) >>= (return . (unzip q_v)) +-- +-- which is equal to +-- +-- [| q, then group by e using f |] -> liftM (unzip q_v) (f (\q_v -> e) [| q |]) +-- +-- where unzip is of the form +-- +-- unzip :: m (a,b,c,..) -> (m a,m b,m c,..) +-- unzip m_tuple = ( liftM selN1 m_tuple +-- , liftM selN2 m_tuple +-- , .. ) +-- where selN1 (a,b,c,..) = a +-- selN2 (a,b,c,..) = b +-- .. +-- +dsMcStmt (GroupStmt stmts binderMap by using return_op bind_op liftM_op) stmts_rest mc + = do { let (fromBinders, toBinders) = unzip binderMap + fromBindersTypes = map idType fromBinders + fromBindersTupleTy = mkBigCoreTupTy fromBindersTypes + toBindersTypes = map idType toBinders + toBindersTupleTy = mkBigCoreTupTy toBindersTypes + m_ty = mc_m_ty mc + + -- Desugar an inner comprehension which outputs a list of tuples of the "from" binders + ; (expr, _) <- dsInnerMonadComp (stmts, fromBinders) (mc { mc_return = Left return_op }) + + -- Work out what arguments should be supplied to that expression: i.e. is an extraction + -- function required? If so, create that desugared function and add to arguments + ; usingExpr' <- dsLExpr (either id noLoc using) + ; usingArgs <- case by of + Nothing -> return [expr] + Just by_e -> do { by_e' <- dsLExpr by_e + ; us <- newUniqueSupply + ; from_tup_id <- newSysLocalDs fromBindersTupleTy + ; let by_wrap = mkTupleCase us fromBinders by_e' + from_tup_id (Var from_tup_id) + ; return [Lam from_tup_id by_wrap, expr] } + + -- Create an unzip function for the appropriate arity and element types + ; liftM_op' <- dsExpr liftM_op + ; (unzip_fn, unzip_rhs) <- mkMcUnzipM liftM_op' m_ty fromBindersTypes + + -- Generate the expressions to build the grouped list + + ; let -- First we apply the grouping function to the inner monad + inner_monad_expr = mkApps usingExpr' ((Type fromBindersTupleTy) : usingArgs) + -- Then we map our "unzip" across it to turn the "monad of tuples" into "tuples of monads" + -- We make sure we instantiate the type variable "a" to be a "monad of 'from' tuples" and + -- the "b" to be a "tuple of 'to' monads"! + unzipped_inner_monad_expr = mkApps liftM_op' -- ! + -- Types: + [ Type (m_ty `mkAppTy` fromBindersTupleTy), Type toBindersTupleTy + -- And arguments: + , Var unzip_fn, inner_monad_expr ] + -- Then finally we bind the unzip function around that expression + bound_unzipped_inner_monad_expr = Let (Rec [(unzip_fn, unzip_rhs)]) unzipped_inner_monad_expr + + -- Build a pattern that ensures the consumer binds into the NEW binders, which hold monads + -- rather than single values + ; let pat = mkBigLHsVarPatTup toBinders + rhs = bound_unzipped_inner_monad_expr + + ; dsMcBindStmt pat rhs bind_op noSyntaxExpr stmts_rest mc } + +-- Parallel statements. Use `Control.Monad.Zip.mzip` to zip parallel +-- statements, for example: +-- +-- [ body | qs1 | qs2 | qs3 ] +-- -> [ body | (bndrs1, (bndrs2, bndrs3)) <- mzip qs1 (mzip qs2 qs3) ] +-- +-- where `mzip` is of the form +-- +-- mzip :: m a -> m b -> m (a,b) +-- +dsMcStmt (ParStmt pairs mzip_op bind_op return_op) stmts_rest mc + = do { -- Get types for `return` + return_op' <- dsExpr return_op + ; let pairs_with_return = map (\tp@(_,b) -> (mkReturn b,tp)) pairs + mkReturn bndrs = mkApps return_op' [Type (mkBigCoreTupTy (map idType bndrs))] + + ; pairs' <- mapM (\(r,tp) -> dsInnerMonadComp tp mc{mc_return = Right r}) + pairs_with_return + + ; let (exps, _qual_tys) = unzip pairs' + -- Types of our `Id`s are getting messed up by `dsInnerMonadComp` + -- so we construct them by hand: + qual_tys = map (mkBigCoreTupTy . map idType . snd) pairs + + ; mzip_op' <- dsExpr mzip_op + ; (zip_fn, zip_rhs) <- mkMcZipM mzip_op' (mc_m_ty mc) qual_tys + + ; let -- The pattern variables + vars = map (mkBigLHsVarPatTup . snd) pairs + -- Pattern with tuples of variables + -- [v1,v2,v3] => (v1, (v2, v3)) + pat = foldr (\tn tm -> mkBigLHsPatTup [tn, tm]) (last vars) (init vars) + rhs = Let (Rec [(zip_fn, zip_rhs)]) (mkApps (Var zip_fn) exps) + + ; dsMcBindStmt pat rhs bind_op noSyntaxExpr stmts_rest mc } + +dsMcStmt stmt _ _ = pprPanic "dsMcStmt: unexpected stmt" (ppr stmt) + + +-- general `rhs' >>= \pat -> stmts` desugaring where `rhs'` is already a +-- desugared `CoreExpr` +dsMcBindStmt :: LPat Id + -> CoreExpr -- ^ the desugared rhs of the bind statement + -> SyntaxExpr Id + -> SyntaxExpr Id + -> [LStmt Id] + -> DsMonadComp + -> DsM CoreExpr +dsMcBindStmt pat rhs' bind_op fail_op stmts mc + = do { body <- dsMcStmts stmts mc + ; bind_op' <- dsExpr bind_op + ; var <- selectSimpleMatchVarL pat + ; let bind_ty = exprType bind_op' -- rhs -> (pat -> res1) -> res2 + res1_ty = funResultTy (funArgTy (funResultTy bind_ty)) + ; match <- matchSinglePat (Var var) (StmtCtxt DoExpr) pat + res1_ty (cantFailMatchResult body) + ; match_code <- handle_failure pat match fail_op + ; return (mkApps bind_op' [rhs', Lam var match_code]) } + + where + -- In a monad comprehension expression, pattern-match failure just calls + -- the monadic `fail` rather than throwing an exception + handle_failure pat match fail_op + | matchCanFail match + = do { fail_op' <- dsExpr fail_op + ; fail_msg <- mkStringExpr (mk_fail_msg pat) + ; extractMatchResult match (App fail_op' fail_msg) } + | otherwise + = extractMatchResult match (error "It can't fail") + + mk_fail_msg :: Located e -> String + mk_fail_msg pat = "Pattern match failure in monad comprehension at " ++ + showSDoc (ppr (getLoc pat)) + +-- Desugar nested monad comprehensions, for example in `then..` constructs +dsInnerMonadComp :: ([LStmt Id], [Id]) + -> DsMonadComp + -> DsM (CoreExpr, Type) +dsInnerMonadComp (stmts, bndrs) DsMonadComp{ mc_return, mc_m_ty } + = do { expr <- dsMcStmts stmts mc' + ; return (expr, bndrs_tuple_type) } + where + bndrs_types = map idType bndrs + bndrs_tuple_type = mkAppTy mc_m_ty $ mkBigCoreTupTy bndrs_types + mc' = DsMonadComp mc_return (mkBigLHsVarTup bndrs) mc_m_ty + +-- The `unzip` function for `GroupStmt` in a monad comprehensions +-- +-- unzip :: m (a,b,..) -> (m a,m b,..) +-- unzip m_tuple = ( liftM selN1 m_tuple +-- , liftM selN2 m_tuple +-- , .. ) +-- +-- mkMcUnzipM m [t1, t2] +-- = (unzip_fn, \ys :: m (t1, t2) -> +-- ( liftM (selN1 :: (t1, t2) -> t1) ys +-- , liftM (selN2 :: (t1, t2) -> t2) ys +-- )) +-- +mkMcUnzipM :: CoreExpr + -> Type -- m + -> [Type] -- [a,b,c,..] + -> DsM (Id, CoreExpr) +mkMcUnzipM liftM_op m_ty elt_tys + = do { ys <- newSysLocalDs monad_tuple_ty + ; xs <- mapM newSysLocalDs elt_tys + ; scrut <- newSysLocalDs tuple_tys + + ; unzip_fn <- newSysLocalDs unzip_fn_ty + + ; let -- Select one Id from our tuple + selectExpr n = mkLams [scrut] $ mkTupleSelector xs (xs !! n) scrut (Var scrut) + -- Apply 'selectVar' and 'ys' to 'liftM' + tupleElem n = mkApps liftM_op + -- Types (m is figured out by the type checker): + -- liftM :: forall a b. (a -> b) -> m a -> m b + [ Type tuple_tys, Type (elt_tys !! n) + -- Arguments: + , selectExpr n, Var ys ] + -- The final expression with the big tuple + unzip_body = mkBigCoreTup [ tupleElem n | n <- [0..length elt_tys - 1] ] + + ; return (unzip_fn, mkLams [ys] unzip_body) } + where monad_tys = map (m_ty `mkAppTy`) elt_tys -- [m a,m b,m c,..] + tuple_monad_tys = mkBigCoreTupTy monad_tys -- (m a,m b,m c,..) + tuple_tys = mkBigCoreTupTy elt_tys -- (a,b,c,..) + monad_tuple_ty = m_ty `mkAppTy` tuple_tys -- m (a,b,c,..) + unzip_fn_ty = monad_tuple_ty `mkFunTy` tuple_monad_tys -- m (a,b,c,..) -> (m a,m b,m c,..) + +-- Generate the `mzip` function for `ParStmt` in monad comprehensions, for +-- example: +-- +-- mzip :: m t1 +-- -> (m t2 -> m t3 -> m (t2, t3)) +-- -> m (t1, (t2, t3)) +-- +-- mkMcZipM m [t1, t2, t3] +-- = (zip_fn, \(q1::t1) (q2::t2) (q3::t3) -> +-- mzip q1 (mzip q2 q3)) +-- +mkMcZipM :: CoreExpr + -> Type + -> [Type] + -> DsM (Id, CoreExpr) + +mkMcZipM mzip_op m_ty tys@(_:_:_) -- min. 2 types + = do { (ids, t1, tuple_ty, zip_body) <- loop tys + ; zip_fn <- newSysLocalDs $ + (m_ty `mkAppTy` t1) + `mkFunTy` + (m_ty `mkAppTy` tuple_ty) + `mkFunTy` + (m_ty `mkAppTy` mkBigCoreTupTy [t1, tuple_ty]) + ; return (zip_fn, mkLams ids zip_body) } + + where + -- loop :: [Type] -> DsM ([Id], Type, [Type], CoreExpr) + loop [t1, t2] = do -- last run of the `loop` + { ids@[a,b] <- newSysLocalsDs (map (m_ty `mkAppTy`) [t1,t2]) + ; let zip_body = mkApps mzip_op [ Type t1, Type t2 , Var a, Var b ] + ; return (ids, t1, t2, zip_body) } + + loop (t1:tr) = do + { -- Get ty, ids etc from the "inner" zip + (ids', t1', t2', zip_body') <- loop tr + + ; a <- newSysLocalDs $ m_ty `mkAppTy` t1 + ; let tuple_ty' = mkBigCoreTupTy [t1', t2'] + zip_body = mkApps mzip_op [ Type t1, Type tuple_ty', Var a, zip_body' ] + ; return ((a:ids'), t1, tuple_ty', zip_body) } + +-- This case should never happen: +mkMcZipM _ _ tys = pprPanic "mkMcZipM: unexpected argument" (ppr tys) + +\end{code} diff --git a/compiler/deSugar/DsMeta.hs b/compiler/deSugar/DsMeta.hs index e34c696..2c1939f 100644 --- a/compiler/deSugar/DsMeta.hs +++ b/compiler/deSugar/DsMeta.hs @@ -721,7 +721,7 @@ repE (HsLet bs e) = do { (ss,ds) <- repBinds bs ; wrapGenSyms ss z } -- FIXME: I haven't got the types here right yet -repE e@(HsDo ctxt sts body _) +repE e@(HsDo ctxt sts body _ _) | case ctxt of { DoExpr -> True; GhciStmt -> True; _ -> False } = do { (ss,zs) <- repLSts sts; body' <- addBinds ss $ repLE body; @@ -737,7 +737,7 @@ repE e@(HsDo ctxt sts body _) wrapGenSyms ss e' } | otherwise - = notHandled "mdo and [: :]" (ppr e) + = notHandled "mdo, monad comprehension and [: :]" (ppr e) repE (ExplicitList _ es) = do { xs <- repLEs es; repListExp xs } repE e@(ExplicitPArr _ _) = notHandled "Parallel arrays" (ppr e) @@ -817,7 +817,7 @@ repGuards other wrapGenSyms (concat xs) gd } where process :: LGRHS Name -> DsM ([GenSymBind], (Core (TH.Q (TH.Guard, TH.Exp)))) - process (L _ (GRHS [L _ (ExprStmt e1 _ _)] e2)) + process (L _ (GRHS [L _ (ExprStmt e1 _ _ _)] e2)) = do { x <- repLNormalGE e1 e2; return ([], x) } process (L _ (GRHS ss rhs)) @@ -876,7 +876,7 @@ repSts (LetStmt bs : ss) = ; z <- repLetSt ds ; (ss2,zs) <- addBinds ss1 (repSts ss) ; return (ss1++ss2, z : zs) } -repSts (ExprStmt e _ _ : ss) = +repSts (ExprStmt e _ _ _ : ss) = do { e2 <- repLE e ; z <- repNoBindSt e2 ; (ss2,zs) <- repSts ss diff --git a/compiler/hsSyn/Convert.lhs b/compiler/hsSyn/Convert.lhs index b5e6c41..c9cbfef 100644 --- a/compiler/hsSyn/Convert.lhs +++ b/compiler/hsSyn/Convert.lhs @@ -523,9 +523,9 @@ cvtHsDo do_or_lc stmts | otherwise = do { stmts' <- cvtStmts stmts ; body <- case last stmts' of - L _ (ExprStmt body _ _) -> return body + L _ (ExprStmt body _ _ _) -> return body stmt' -> failWith (bad_last stmt') - ; return $ HsDo do_or_lc (init stmts') body void } + ; return $ HsDo do_or_lc (init stmts') body noSyntaxExpr void } where bad_last stmt = vcat [ ptext (sLit "Illegal last statement of") <+> pprStmtContext do_or_lc <> colon , nest 2 $ Outputable.ppr stmt @@ -539,7 +539,7 @@ cvtStmt (NoBindS e) = do { e' <- cvtl e; returnL $ mkExprStmt e' } cvtStmt (TH.BindS p e) = do { p' <- cvtPat p; e' <- cvtl e; returnL $ mkBindStmt p' e' } cvtStmt (TH.LetS ds) = do { ds' <- cvtLocalDecs (ptext (sLit "a let binding")) ds ; returnL $ LetStmt ds' } -cvtStmt (TH.ParS dss) = do { dss' <- mapM cvt_one dss; returnL $ ParStmt dss' } +cvtStmt (TH.ParS dss) = do { dss' <- mapM cvt_one dss; returnL $ ParStmt dss' noSyntaxExpr noSyntaxExpr noSyntaxExpr } where cvt_one ds = do { ds' <- cvtStmts ds; return (ds', undefined) } diff --git a/compiler/hsSyn/HsExpr.lhs b/compiler/hsSyn/HsExpr.lhs index 06616f1..e367af5 100644 --- a/compiler/hsSyn/HsExpr.lhs +++ b/compiler/hsSyn/HsExpr.lhs @@ -148,6 +148,8 @@ data HsExpr id [LStmt id] -- "do":one or more stmts (LHsExpr id) -- The body; the last expression in the -- 'do' of [ body | ... ] in a list comp + (SyntaxExpr id) -- The 'return' function, see Note + -- [Monad Comprehensions] PostTcType -- Type of the whole expression | ExplicitList -- syntactic list @@ -439,7 +441,7 @@ ppr_expr (HsLet binds expr) = sep [hang (ptext (sLit "let")) 2 (pprBinds binds), hang (ptext (sLit "in")) 2 (ppr expr)] -ppr_expr (HsDo do_or_list_comp stmts body _) = pprDo do_or_list_comp stmts body +ppr_expr (HsDo do_or_list_comp stmts body _ _) = pprDo do_or_list_comp stmts body ppr_expr (ExplicitList _ exprs) = brackets (pprDeeperList fsep (punctuate comma (map ppr_lexpr exprs))) @@ -575,7 +577,7 @@ pprParendExpr expr HsPar {} -> pp_as_was HsBracket {} -> pp_as_was HsBracketOut _ [] -> pp_as_was - HsDo sc _ _ _ + HsDo sc _ _ _ _ | isListCompExpr sc -> pp_as_was _ -> parens pp_as_was @@ -830,8 +832,8 @@ type LStmtLR idL idR = Located (StmtLR idL idR) type Stmt id = StmtLR id id --- The SyntaxExprs in here are used *only* for do-notation, which --- has rebindable syntax. Otherwise they are unused. +-- The SyntaxExprs in here are used *only* for do-notation and monad +-- comprehensions, which have rebindable syntax. Otherwise they are unused. data StmtLR idL idR = BindStmt (LPat idL) (LHsExpr idR) @@ -842,17 +844,24 @@ data StmtLR idL idR | ExprStmt (LHsExpr idR) -- See Note [ExprStmt] (SyntaxExpr idR) -- The (>>) operator + (SyntaxExpr idR) -- The `guard` operator + -- See notes [Monad Comprehensions] PostTcType -- Element type of the RHS (used for arrows) | LetStmt (HsLocalBindsLR idL idR) - -- ParStmts only occur in a list comprehension + -- ParStmts only occur in a list/monad comprehension | ParStmt [([LStmt idL], [idR])] + (SyntaxExpr idR) -- polymorphic `mzip` for monad comprehensions + (SyntaxExpr idR) -- The `>>=` operator + (SyntaxExpr idR) -- polymorphic `return` operator + -- See notes [Monad Comprehensions] + -- After renaming, the ids are the binders bound by the stmts and used -- after them - -- "qs, then f by e" ==> TransformStmt qs binders f (Just e) - -- "qs, then f" ==> TransformStmt qs binders f Nothing + -- "qs, then f by e" ==> TransformStmt qs binders f (Just e) (return) (>>=) + -- "qs, then f" ==> TransformStmt qs binders f Nothing (return) (>>=) | TransformStmt [LStmt idL] -- Stmts are the ones to the left of the 'then' @@ -863,6 +872,11 @@ data StmtLR idL idR (Maybe (LHsExpr idR)) -- "by e" (optional) + (SyntaxExpr idR) -- The 'return' function for inner monad + -- comprehensions + (SyntaxExpr idR) -- The '(>>=)' operator. + -- See Note [Monad Comprehensions] + | GroupStmt [LStmt idL] -- Stmts to the *left* of the 'group' -- which generates the tuples to be grouped @@ -874,7 +888,14 @@ data StmtLR idL idR (Either -- "using f" (LHsExpr idR) -- Left f => explicit "using f" (SyntaxExpr idR)) -- Right f => implicit; filled in with 'groupWith' - + -- (list comprehensions) or 'groupM' (monad + -- comprehensions) + + (SyntaxExpr idR) -- The 'return' function for inner monad + -- comprehensions + (SyntaxExpr idR) -- The '(>>=)' operator + (SyntaxExpr idR) -- The 'liftM' function from Control.Monad for desugaring + -- See Note [Monad Comprehensions] -- Recursive statement (see Note [How RecStmt works] below) | RecStmt @@ -952,6 +973,12 @@ depends on the context. Consider the following contexts: E :: Bool Translation: if E then fail else ... + A monad comprehension of type (m res_ty) + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * ExprStmt E Bool: [ .. | .... E ] + E :: Bool + Translation: guard E >> ... + Array comprehensions are handled like list comprehensions -=chak Note [How RecStmt works] @@ -993,6 +1020,45 @@ A (RecStmt stmts) types as if you had written where v1..vn are the later_ids r1..rm are the rec_ids +Note [Monad Comprehensions] +~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Monad comprehensions require seperate functions like 'return' and '>>=' for +desugaring. These functions are stored in the 'HsDo' expression and the +statements used in monad comprehensions. For example, the 'return' of the +'HsDo' expression is used to lift the body of the monad comprehension: + + [ body | stmts ] + => + stmts >>= \bndrs -> return body + +In transform and grouping statements ('then ..' and 'then group ..') the +'return' function is required for nested monad comprehensions, for example: + + [ body | stmts, then f, rest ] + => + f [ env | stmts ] >>= \bndrs -> [ body | rest ] + +Normal expressions require the 'Control.Monad.guard' function for boolean +expressions: + + [ body | exp, stmts ] + => + guard exp >> [ body | stmts ] + +Grouping/parallel statements require the 'Control.Monad.Group.groupM' and +'Control.Monad.Zip.mzip' functions: + + [ body | stmts, then group by e, rest] + => + groupM [ body | stmts ] >>= \bndrs -> [ body | rest ] + + [ body | stmts1 | stmts2 | .. ] + => + mzip stmts1 (mzip stmts2 (..)) >>= \(bndrs1, (bndrs2, ..)) -> return body + +In any other context than 'MonadComp', the fields for most of these +'SyntaxExpr's stay bottom. + \begin{code} instance (OutputableBndr idL, OutputableBndr idR) => Outputable (StmtLR idL idR) where @@ -1001,14 +1067,14 @@ instance (OutputableBndr idL, OutputableBndr idR) => Outputable (StmtLR idL idR) pprStmt :: (OutputableBndr idL, OutputableBndr idR) => (StmtLR idL idR) -> SDoc pprStmt (BindStmt pat expr _ _) = hsep [ppr pat, ptext (sLit "<-"), ppr expr] pprStmt (LetStmt binds) = hsep [ptext (sLit "let"), pprBinds binds] -pprStmt (ExprStmt expr _ _) = ppr expr -pprStmt (ParStmt stmtss) = hsep (map doStmts stmtss) +pprStmt (ExprStmt expr _ _ _) = ppr expr +pprStmt (ParStmt stmtss _ _ _) = hsep (map doStmts stmtss) where doStmts stmts = ptext (sLit "| ") <> ppr stmts -pprStmt (TransformStmt stmts bndrs using by) +pprStmt (TransformStmt stmts bndrs using by _ _) = sep (ppr_lc_stmts stmts ++ [pprTransformStmt bndrs using by]) -pprStmt (GroupStmt stmts _ by using) +pprStmt (GroupStmt stmts _ by using _ _ _) = sep (ppr_lc_stmts stmts ++ [pprGroupStmt by using]) pprStmt (RecStmt { recS_stmts = segment, recS_rec_ids = rec_ids @@ -1043,6 +1109,7 @@ pprDo GhciStmt stmts body = ptext (sLit "do") <+> ppr_do_stmts stmts body pprDo MDoExpr stmts body = ptext (sLit "mdo") <+> ppr_do_stmts stmts body pprDo ListComp stmts body = brackets $ pprComp stmts body pprDo PArrComp stmts body = pa_brackets $ pprComp stmts body +pprDo MonadComp stmts body = brackets $ pprComp stmts body pprDo _ _ _ = panic "pprDo" -- PatGuard, ParStmtCxt ppr_do_stmts :: OutputableBndr id => [LStmt id] -> LHsExpr id -> SDoc @@ -1178,6 +1245,7 @@ data HsStmtContext id | DoExpr | GhciStmt -- A command-line Stmt in GHCi pat <- rhs | MDoExpr -- Recursive do-expression + | MonadComp | PArrComp -- Parallel array comprehension | PatGuard (HsMatchContext id) -- Pattern guard for specified thing | ParStmtCtxt (HsStmtContext id) -- A branch of a parallel stmt @@ -1192,9 +1260,16 @@ isDoExpr MDoExpr = True isDoExpr _ = False isListCompExpr :: HsStmtContext id -> Bool -isListCompExpr ListComp = True -isListCompExpr PArrComp = True -isListCompExpr _ = False +isListCompExpr ListComp = True +isListCompExpr PArrComp = True +isListCompExpr MonadComp = True +isListCompExpr _ = False + +isMonadCompExpr :: HsStmtContext id -> Bool +isMonadCompExpr MonadComp = True +isMonadCompExpr (ParStmtCtxt ctxt) = isMonadCompExpr ctxt +isMonadCompExpr (TransformStmtCtxt ctxt) = isMonadCompExpr ctxt +isMonadCompExpr _ = False \end{code} \begin{code} @@ -1242,6 +1317,7 @@ pprStmtContext GhciStmt = ptext (sLit "an interactive GHCi command") pprStmtContext DoExpr = ptext (sLit "a 'do' expression") pprStmtContext MDoExpr = ptext (sLit "an 'mdo' expression") pprStmtContext ListComp = ptext (sLit "a list comprehension") +pprStmtContext MonadComp = ptext (sLit "a monad comprehension") pprStmtContext PArrComp = ptext (sLit "an array comprehension") {- @@ -1275,6 +1351,7 @@ matchContextErrString (StmtCtxt GhciStmt) = ptext (sLit "interactive GHCi matchContextErrString (StmtCtxt DoExpr) = ptext (sLit "'do' expression") matchContextErrString (StmtCtxt MDoExpr) = ptext (sLit "'mdo' expression") matchContextErrString (StmtCtxt ListComp) = ptext (sLit "list comprehension") +matchContextErrString (StmtCtxt MonadComp) = ptext (sLit "monad comprehension") matchContextErrString (StmtCtxt PArrComp) = ptext (sLit "array comprehension") \end{code} @@ -1290,7 +1367,7 @@ pprStmtInCtxt ctxt stmt = hang (ptext (sLit "In a stmt of") <+> pprStmtContext c 4 (ppr_stmt stmt) where -- For Group and Transform Stmts, don't print the nested stmts! - ppr_stmt (GroupStmt _ _ by using) = pprGroupStmt by using - ppr_stmt (TransformStmt _ bndrs using by) = pprTransformStmt bndrs using by - ppr_stmt stmt = pprStmt stmt + ppr_stmt (GroupStmt _ _ by using _ _ _) = pprGroupStmt by using + ppr_stmt (TransformStmt _ bndrs using by _ _) = pprTransformStmt bndrs using by + ppr_stmt stmt = pprStmt stmt \end{code} diff --git a/compiler/hsSyn/HsLit.lhs b/compiler/hsSyn/HsLit.lhs index 0874dda..c29083c 100644 --- a/compiler/hsSyn/HsLit.lhs +++ b/compiler/hsSyn/HsLit.lhs @@ -63,8 +63,7 @@ instance Eq HsLit where data HsOverLit id -- An overloaded literal = OverLit { ol_val :: OverLitVal, - ol_rebindable :: Bool, -- True <=> rebindable syntax - -- False <=> standard syntax + ol_rebindable :: Bool, -- ol_witness :: SyntaxExpr id, -- Note [Overloaded literal witnesses] ol_type :: PostTcType } deriving (Data, Typeable) @@ -79,6 +78,19 @@ overLitType :: HsOverLit a -> Type overLitType = ol_type \end{code} +Note [ol_rebindable] +~~~~~~~~~~~~~~~~~~~~ +The ol_rebindable field is True if this literal is actually +using rebindable syntax. Specifically: + + False iff ol_witness is the standard one + True iff ol_witness is non-standard + +Equivalently it's True if + a) RebindableSyntax is on + b) the witness for fromInteger/fromRational/fromString + that happens to be in scope isn't the standard one + Note [Overloaded literal witnesses] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ *Before* type checking, the SyntaxExpr in an HsOverLit is the diff --git a/compiler/hsSyn/HsPat.lhs b/compiler/hsSyn/HsPat.lhs index 78b5887..3efcd59 100644 --- a/compiler/hsSyn/HsPat.lhs +++ b/compiler/hsSyn/HsPat.lhs @@ -122,7 +122,9 @@ data Pat id | LitPat HsLit -- Used for *non-overloaded* literal patterns: -- Int#, Char#, Int, Char, String, etc. - | NPat (HsOverLit id) -- ALWAYS positive + | NPat -- Used for all overloaded literals, + -- including overloaded strings with -XOverloadedStrings + (HsOverLit id) -- ALWAYS positive (Maybe (SyntaxExpr id)) -- Just (Name of 'negate') for negative -- patterns, Nothing otherwise (SyntaxExpr id) -- Equality checker, of type t->t->Bool diff --git a/compiler/hsSyn/HsUtils.lhs b/compiler/hsSyn/HsUtils.lhs index 13f3cd7..44e3a32 100644 --- a/compiler/hsSyn/HsUtils.lhs +++ b/compiler/hsSyn/HsUtils.lhs @@ -212,7 +212,7 @@ mkHsIsString s = OverLit (HsIsString s) noRebindableInfo noSyntaxExpr noRebindableInfo :: Bool noRebindableInfo = error "noRebindableInfo" -- Just another placeholder; -mkHsDo ctxt stmts body = HsDo ctxt stmts body placeHolderType +mkHsDo ctxt stmts body = HsDo ctxt stmts body noSyntaxExpr placeHolderType mkHsIf :: LHsExpr id -> LHsExpr id -> LHsExpr id -> HsExpr id mkHsIf c a b = HsIf (Just noSyntaxExpr) c a b @@ -220,18 +220,18 @@ mkHsIf c a b = HsIf (Just noSyntaxExpr) c a b mkNPat lit neg = NPat lit neg noSyntaxExpr mkNPlusKPat id lit = NPlusKPat id lit noSyntaxExpr noSyntaxExpr -mkTransformStmt stmts usingExpr = TransformStmt stmts [] usingExpr Nothing -mkTransformByStmt stmts usingExpr byExpr = TransformStmt stmts [] usingExpr (Just byExpr) +mkTransformStmt stmts usingExpr = TransformStmt stmts [] usingExpr Nothing noSyntaxExpr noSyntaxExpr +mkTransformByStmt stmts usingExpr byExpr = TransformStmt stmts [] usingExpr (Just byExpr) noSyntaxExpr noSyntaxExpr mkGroupUsingStmt :: [LStmt idL] -> LHsExpr idR -> StmtLR idL idR mkGroupByStmt :: [LStmt idL] -> LHsExpr idR -> StmtLR idL idR mkGroupByUsingStmt :: [LStmt idL] -> LHsExpr idR -> LHsExpr idR -> StmtLR idL idR -mkGroupUsingStmt stmts usingExpr = GroupStmt stmts [] Nothing (Left usingExpr) -mkGroupByStmt stmts byExpr = GroupStmt stmts [] (Just byExpr) (Right noSyntaxExpr) -mkGroupByUsingStmt stmts byExpr usingExpr = GroupStmt stmts [] (Just byExpr) (Left usingExpr) +mkGroupUsingStmt stmts usingExpr = GroupStmt stmts [] Nothing (Left usingExpr) noSyntaxExpr noSyntaxExpr noSyntaxExpr +mkGroupByStmt stmts byExpr = GroupStmt stmts [] (Just byExpr) (Right noSyntaxExpr) noSyntaxExpr noSyntaxExpr noSyntaxExpr +mkGroupByUsingStmt stmts byExpr usingExpr = GroupStmt stmts [] (Just byExpr) (Left usingExpr) noSyntaxExpr noSyntaxExpr noSyntaxExpr -mkExprStmt expr = ExprStmt expr noSyntaxExpr placeHolderType +mkExprStmt expr = ExprStmt expr noSyntaxExpr noSyntaxExpr placeHolderType mkBindStmt pat expr = BindStmt pat expr noSyntaxExpr noSyntaxExpr emptyRecStmt = RecStmt { recS_stmts = [], recS_later_ids = [], recS_rec_ids = [] @@ -496,12 +496,12 @@ collectStmtBinders :: StmtLR idL idR -> [idL] -- Id Binders for a Stmt... [but what about pattern-sig type vars]? collectStmtBinders (BindStmt pat _ _ _) = collectPatBinders pat collectStmtBinders (LetStmt binds) = collectLocalBinders binds -collectStmtBinders (ExprStmt _ _ _) = [] -collectStmtBinders (ParStmt xs) = collectLStmtsBinders +collectStmtBinders (ExprStmt _ _ _ _) = [] +collectStmtBinders (ParStmt xs _ _ _) = collectLStmtsBinders $ concatMap fst xs -collectStmtBinders (TransformStmt stmts _ _ _) = collectLStmtsBinders stmts -collectStmtBinders (GroupStmt stmts _ _ _) = collectLStmtsBinders stmts -collectStmtBinders (RecStmt { recS_stmts = ss }) = collectLStmtsBinders ss +collectStmtBinders (TransformStmt stmts _ _ _ _ _) = collectLStmtsBinders stmts +collectStmtBinders (GroupStmt stmts _ _ _ _ _ _) = collectLStmtsBinders stmts +collectStmtBinders (RecStmt { recS_stmts = ss }) = collectLStmtsBinders ss ----------------- Patterns -------------------------- @@ -642,11 +642,11 @@ lStmtsImplicits = hs_lstmts hs_stmt (BindStmt pat _ _ _) = lPatImplicits pat hs_stmt (LetStmt binds) = hs_local_binds binds - hs_stmt (ExprStmt _ _ _) = emptyNameSet - hs_stmt (ParStmt xs) = hs_lstmts $ concatMap fst xs + hs_stmt (ExprStmt _ _ _ _) = emptyNameSet + hs_stmt (ParStmt xs _ _ _) = hs_lstmts $ concatMap fst xs - hs_stmt (TransformStmt stmts _ _ _) = hs_lstmts stmts - hs_stmt (GroupStmt stmts _ _ _) = hs_lstmts stmts + hs_stmt (TransformStmt stmts _ _ _ _ _) = hs_lstmts stmts + hs_stmt (GroupStmt stmts _ _ _ _ _ _) = hs_lstmts stmts hs_stmt (RecStmt { recS_stmts = ss }) = hs_lstmts ss hs_local_binds (HsValBinds val_binds) = hsValBindsImplicits val_binds diff --git a/compiler/main/DynFlags.hs b/compiler/main/DynFlags.hs index fa05195..0914c32 100644 --- a/compiler/main/DynFlags.hs +++ b/compiler/main/DynFlags.hs @@ -351,6 +351,7 @@ data ExtensionFlag | Opt_KindSignatures | Opt_ParallelListComp | Opt_TransformListComp + | Opt_MonadComprehensions | Opt_GeneralizedNewtypeDeriving | Opt_RecursiveDo | Opt_DoRec @@ -1569,6 +1570,7 @@ xFlags = [ ( "EmptyDataDecls", Opt_EmptyDataDecls, nop ), ( "ParallelListComp", Opt_ParallelListComp, nop ), ( "TransformListComp", Opt_TransformListComp, nop ), + ( "MonadComprehensions", Opt_MonadComprehensions, nop), ( "ForeignFunctionInterface", Opt_ForeignFunctionInterface, nop ), ( "UnliftedFFITypes", Opt_UnliftedFFITypes, nop ), ( "GHCForeignImportPrim", Opt_GHCForeignImportPrim, nop ), diff --git a/compiler/main/HscMain.lhs b/compiler/main/HscMain.lhs index 70ddd6a..f0c1111 100644 --- a/compiler/main/HscMain.lhs +++ b/compiler/main/HscMain.lhs @@ -1132,7 +1132,7 @@ hscTcExpr -- Typecheck an expression (but don't run it) hscTcExpr hsc_env expr = runHsc hsc_env $ do maybe_stmt <- hscParseStmt expr case maybe_stmt of - Just (L _ (ExprStmt expr _ _)) -> + Just (L _ (ExprStmt expr _ _ _)) -> ioMsgMaybe $ tcRnExpr hsc_env (hsc_IC hsc_env) expr _ -> liftIO $ throwIO $ mkSrcErr $ unitBag $ diff --git a/compiler/parser/Lexer.x b/compiler/parser/Lexer.x index 5c41d72..61019b3 100644 --- a/compiler/parser/Lexer.x +++ b/compiler/parser/Lexer.x @@ -1893,6 +1893,7 @@ mkPState flags buf loc = .|. unboxedTuplesBit `setBitIf` xopt Opt_UnboxedTuples flags .|. datatypeContextsBit `setBitIf` xopt Opt_DatatypeContexts flags .|. transformComprehensionsBit `setBitIf` xopt Opt_TransformListComp flags + .|. transformComprehensionsBit `setBitIf` xopt Opt_MonadComprehensions flags .|. rawTokenStreamBit `setBitIf` dopt Opt_KeepRawTokenStream flags .|. alternativeLayoutRuleBit `setBitIf` xopt Opt_AlternativeLayoutRule flags .|. relaxedLayoutBit `setBitIf` xopt Opt_RelaxedLayout flags diff --git a/compiler/parser/Parser.y.pp b/compiler/parser/Parser.y.pp index bfadfba..ec8d3ff 100644 --- a/compiler/parser/Parser.y.pp +++ b/compiler/parser/Parser.y.pp @@ -1465,7 +1465,8 @@ list :: { LHsExpr RdrName } | texp ',' exp '..' { LL $ ArithSeq noPostTcExpr (FromThen $1 $3) } | texp '..' exp { LL $ ArithSeq noPostTcExpr (FromTo $1 $3) } | texp ',' exp '..' exp { LL $ ArithSeq noPostTcExpr (FromThenTo $1 $3 $5) } - | texp '|' flattenedpquals { sL (comb2 $1 $>) $ mkHsDo ListComp (unLoc $3) $1 } + | texp '|' flattenedpquals {% checkMonadComp >>= \ ctxt -> + return (sL (comb2 $1 $>) $ mkHsDo ctxt (unLoc $3) $1) } lexps :: { Located [LHsExpr RdrName] } : lexps ',' texp { LL (((:) $! $3) $! unLoc $1) } @@ -1480,7 +1481,7 @@ flattenedpquals :: { Located [LStmt RdrName] } -- We just had one thing in our "parallel" list so -- we simply return that thing directly - qss -> L1 [L1 $ ParStmt [(qs, undefined) | qs <- qss]] + qss -> L1 [L1 $ ParStmt [(qs, undefined) | qs <- qss] noSyntaxExpr noSyntaxExpr noSyntaxExpr] -- We actually found some actual parallel lists so -- we wrap them into as a ParStmt } diff --git a/compiler/parser/RdrHsSyn.lhs b/compiler/parser/RdrHsSyn.lhs index 47abf23..0e22c69 100644 --- a/compiler/parser/RdrHsSyn.lhs +++ b/compiler/parser/RdrHsSyn.lhs @@ -42,6 +42,7 @@ module RdrHsSyn ( checkPatterns, -- SrcLoc -> [HsExp] -> P [HsPat] checkDo, -- [Stmt] -> P [Stmt] checkMDo, -- [Stmt] -> P [Stmt] + checkMonadComp, -- P (HsStmtContext RdrName) checkValDef, -- (SrcLoc, HsExp, HsRhs, [HsDecl]) -> P HsDecl checkValSig, -- (SrcLoc, HsExp, HsRhs, [HsDecl]) -> P HsDecl checkDoAndIfThenElse, @@ -54,6 +55,7 @@ import Class ( FunDep ) import TypeRep ( Kind ) import RdrName ( RdrName, isRdrTyVar, isRdrTc, mkUnqual, rdrNameOcc, isRdrDataCon, isUnqual, getRdrName, setRdrNameSpace ) +import Name ( Name ) import BasicTypes ( maxPrecedence, Activation(..), RuleMatchInfo, InlinePragma(..), InlineSpec(..) ) import Lexer @@ -629,8 +631,8 @@ checkDoMDo _ nm loc [] = parseErrorSDoc loc (text ("Empty " ++ nm ++ " const checkDoMDo pre nm _ ss = do check ss where - check [] = panic "RdrHsSyn:checkDoMDo" - check [L _ (ExprStmt e _ _)] = return ([], e) + check [] = panic "RdrHsSyn:checkDoMDo" + check [L _ (ExprStmt e _ _ _)] = return ([], e) check [L l e] = parseErrorSDoc l (text ("The last statement in " ++ pre ++ nm ++ " construct must be an expression:") @@ -912,6 +914,20 @@ isFunLhs e = go e [] _ -> return Nothing } go _ _ = return Nothing + +--------------------------------------------------------------------------- +-- Check for monad comprehensions +-- +-- If the flag MonadComprehensions is set, return a `MonadComp' context, +-- otherwise use the usual `ListComp' context + +checkMonadComp :: P (HsStmtContext Name) +checkMonadComp = do + pState <- getPState + return $ if xopt Opt_MonadComprehensions (dflags pState) + then MonadComp + else ListComp + --------------------------------------------------------------------------- -- Miscellaneous utilities diff --git a/compiler/prelude/PrelNames.lhs b/compiler/prelude/PrelNames.lhs index 24756d5..9b59f5d 100644 --- a/compiler/prelude/PrelNames.lhs +++ b/compiler/prelude/PrelNames.lhs @@ -221,6 +221,12 @@ basicKnownKeyNames -- dotnet interop , objectTyConName, marshalObjectName, unmarshalObjectName , marshalStringName, unmarshalStringName, checkDotnetResName + + -- Monad comprehensions + , guardMName + , liftMName + , groupMName + , mzipName ] genericTyConNames :: [Name] @@ -262,8 +268,9 @@ gHC_PRIM, gHC_TYPES, gHC_UNIT, gHC_ORDERING, gHC_GENERICS, gHC_PACK, gHC_CONC, gHC_IO, gHC_IO_Exception, gHC_ST, gHC_ARR, gHC_STABLE, gHC_ADDR, gHC_PTR, gHC_ERR, gHC_REAL, gHC_FLOAT, gHC_TOP_HANDLER, sYSTEM_IO, dYNAMIC, tYPEABLE, gENERICS, - dOTNET, rEAD_PREC, lEX, gHC_INT, gHC_WORD, mONAD, mONAD_FIX, aRROW, cONTROL_APPLICATIVE, - gHC_DESUGAR, rANDOM, gHC_EXTS, cONTROL_EXCEPTION_BASE :: Module + dOTNET, rEAD_PREC, lEX, gHC_INT, gHC_WORD, mONAD, mONAD_FIX, mONAD_GROUP, mONAD_ZIP, + aRROW, cONTROL_APPLICATIVE, gHC_DESUGAR, rANDOM, gHC_EXTS, + cONTROL_EXCEPTION_BASE :: Module gHC_PRIM = mkPrimModule (fsLit "GHC.Prim") -- Primitive types and values gHC_TYPES = mkPrimModule (fsLit "GHC.Types") @@ -311,6 +318,8 @@ gHC_INT = mkBaseModule (fsLit "GHC.Int") gHC_WORD = mkBaseModule (fsLit "GHC.Word") mONAD = mkBaseModule (fsLit "Control.Monad") mONAD_FIX = mkBaseModule (fsLit "Control.Monad.Fix") +mONAD_GROUP = mkBaseModule (fsLit "Control.Monad.Group") +mONAD_ZIP = mkBaseModule (fsLit "Control.Monad.Zip") aRROW = mkBaseModule (fsLit "Control.Arrow") cONTROL_APPLICATIVE = mkBaseModule (fsLit "Control.Applicative") gHC_DESUGAR = mkBaseModule (fsLit "GHC.Desugar") @@ -834,6 +843,14 @@ appAName = varQual aRROW (fsLit "app") appAIdKey choiceAName = varQual aRROW (fsLit "|||") choiceAIdKey loopAName = varQual aRROW (fsLit "loop") loopAIdKey +-- Monad comprehensions +guardMName, liftMName, groupMName, mzipName :: Name +guardMName = varQual mONAD (fsLit "guard") guardMIdKey +liftMName = varQual mONAD (fsLit "liftM") liftMIdKey +groupMName = varQual mONAD_GROUP (fsLit "mgroupWith") groupMIdKey +mzipName = varQual mONAD_ZIP (fsLit "mzip") mzipIdKey + + -- Annotation type checking toAnnotationWrapperName :: Name toAnnotationWrapperName = varQual gHC_DESUGAR (fsLit "toAnnotationWrapper") toAnnotationWrapperIdKey @@ -1325,6 +1342,14 @@ realToFracIdKey = mkPreludeMiscIdUnique 128 toIntegerClassOpKey = mkPreludeMiscIdUnique 129 toRationalClassOpKey = mkPreludeMiscIdUnique 130 +-- Monad comprehensions +guardMIdKey, liftMIdKey, groupMIdKey, mzipIdKey :: Unique +guardMIdKey = mkPreludeMiscIdUnique 131 +liftMIdKey = mkPreludeMiscIdUnique 132 +groupMIdKey = mkPreludeMiscIdUnique 133 +mzipIdKey = mkPreludeMiscIdUnique 134 + + ---------------- Template Haskell ------------------- -- USES IdUniques 200-499 ----------------------------------------------------- diff --git a/compiler/rename/RnBinds.lhs b/compiler/rename/RnBinds.lhs index df3b12d..dc7ea96 100644 --- a/compiler/rename/RnBinds.lhs +++ b/compiler/rename/RnBinds.lhs @@ -789,9 +789,9 @@ rnGRHS' ctxt (GRHS guards rhs) -- Standard Haskell 1.4 guards are just a single boolean -- expression, rather than a list of qualifiers as in the -- Glasgow extension - is_standard_guard [] = True - is_standard_guard [L _ (ExprStmt _ _ _)] = True - is_standard_guard _ = False + is_standard_guard [] = True + is_standard_guard [L _ (ExprStmt _ _ _ _)] = True + is_standard_guard _ = False \end{code} %************************************************************************ diff --git a/compiler/rename/RnExpr.lhs b/compiler/rename/RnExpr.lhs index d11249a..425cb40 100644 --- a/compiler/rename/RnExpr.lhs +++ b/compiler/rename/RnExpr.lhs @@ -224,10 +224,16 @@ rnExpr (HsLet binds expr) rnLExpr expr `thenM` \ (expr',fvExpr) -> return (HsLet binds' expr', fvExpr) -rnExpr (HsDo do_or_lc stmts body _) - = do { ((stmts', body'), fvs) <- rnStmts do_or_lc stmts $ \ _ -> - rnLExpr body - ; return (HsDo do_or_lc stmts' body' placeHolderType, fvs) } +rnExpr (HsDo do_or_lc stmts body _ _) + = do { ((stmts', body'), fvs1) <- rnStmts do_or_lc stmts $ \ _ -> + rnLExpr body + ; (return_op, fvs2) <- + if isMonadCompExpr do_or_lc + then lookupSyntaxName returnMName + else return (noSyntaxExpr, emptyFVs) + + ; return ( HsDo do_or_lc stmts' body' return_op placeHolderType + , fvs1 `plusFV` fvs2 ) } rnExpr (ExplicitList _ exps) = rnExprs exps `thenM` \ (exps', fvs) -> @@ -441,9 +447,10 @@ convertOpFormsCmd (HsIf f exp c1 c2) convertOpFormsCmd (HsLet binds cmd) = HsLet binds (convertOpFormsLCmd cmd) -convertOpFormsCmd (HsDo ctxt stmts body ty) +convertOpFormsCmd (HsDo ctxt stmts body return_op ty) = HsDo ctxt (map (fmap convertOpFormsStmt) stmts) - (convertOpFormsLCmd body) ty + (convertOpFormsLCmd body) + (convertOpFormsCmd return_op) ty -- Anything else is unchanged. This includes HsArrForm (already done), -- things with no sub-commands, and illegal commands (which will be @@ -453,8 +460,8 @@ convertOpFormsCmd c = c convertOpFormsStmt :: StmtLR id id -> StmtLR id id convertOpFormsStmt (BindStmt pat cmd _ _) = BindStmt pat (convertOpFormsLCmd cmd) noSyntaxExpr noSyntaxExpr -convertOpFormsStmt (ExprStmt cmd _ _) - = ExprStmt (convertOpFormsLCmd cmd) noSyntaxExpr placeHolderType +convertOpFormsStmt (ExprStmt cmd _ _ _) + = ExprStmt (convertOpFormsLCmd cmd) noSyntaxExpr noSyntaxExpr placeHolderType convertOpFormsStmt stmt@(RecStmt { recS_stmts = stmts }) = stmt { recS_stmts = map (fmap convertOpFormsStmt) stmts } convertOpFormsStmt stmt = stmt @@ -497,7 +504,7 @@ methodNamesCmd (HsIf _ _ c1 c2) methodNamesCmd (HsLet _ c) = methodNamesLCmd c -methodNamesCmd (HsDo _ stmts body _) +methodNamesCmd (HsDo _ stmts body _ _) = methodNamesStmts stmts `plusFV` methodNamesLCmd body methodNamesCmd (HsApp c _) = methodNamesLCmd c @@ -538,11 +545,11 @@ methodNamesLStmt :: Located (StmtLR Name Name) -> FreeVars methodNamesLStmt = methodNamesStmt . unLoc methodNamesStmt :: StmtLR Name Name -> FreeVars -methodNamesStmt (ExprStmt cmd _ _) = methodNamesLCmd cmd +methodNamesStmt (ExprStmt cmd _ _ _) = methodNamesLCmd cmd methodNamesStmt (BindStmt _ cmd _ _) = methodNamesLCmd cmd methodNamesStmt (RecStmt { recS_stmts = stmts }) = methodNamesStmts stmts `addOneFV` loopAName methodNamesStmt (LetStmt _) = emptyFVs -methodNamesStmt (ParStmt _) = emptyFVs +methodNamesStmt (ParStmt _ _ _ _) = emptyFVs methodNamesStmt (TransformStmt {}) = emptyFVs methodNamesStmt (GroupStmt {}) = emptyFVs -- ParStmt, TransformStmt and GroupStmt can't occur in commands, but it's not convenient to error @@ -665,12 +672,15 @@ rnStmt :: HsStmtContext Name -> LStmt RdrName -- Variables bound by the Stmt, and mentioned in thing_inside, -- do not appear in the result FreeVars -rnStmt _ (L loc (ExprStmt expr _ _)) thing_inside +rnStmt ctxt (L loc (ExprStmt expr _ _ _)) thing_inside = do { (expr', fv_expr) <- rnLExpr expr ; (then_op, fvs1) <- lookupSyntaxName thenMName - ; (thing, fvs2) <- thing_inside [] - ; return (([L loc (ExprStmt expr' then_op placeHolderType)], thing), - fv_expr `plusFV` fvs1 `plusFV` fvs2) } + ; (guard_op, fvs2) <- if isMonadCompExpr ctxt + then lookupSyntaxName guardMName + else return (noSyntaxExpr, emptyFVs) + ; (thing, fvs3) <- thing_inside [] + ; return (([L loc (ExprStmt expr' then_op guard_op placeHolderType)], thing), + fv_expr `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) } rnStmt ctxt (L loc (BindStmt pat expr _ _)) thing_inside = do { (expr', fv_expr) <- rnLExpr expr @@ -734,12 +744,20 @@ rnStmt ctxt (L _ (RecStmt { recS_stmts = rec_stmts })) thing_inside ; return ((rec_stmts', thing), fvs `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) } } -rnStmt ctxt (L loc (ParStmt segs)) thing_inside +rnStmt ctxt (L loc (ParStmt segs _ _ _)) thing_inside = do { checkParStmt ctxt - ; ((segs', thing), fvs) <- rnParallelStmts (ParStmtCtxt ctxt) segs thing_inside - ; return (([L loc (ParStmt segs')], thing), fvs) } - -rnStmt ctxt (L loc (TransformStmt stmts _ using by)) thing_inside + ; ((mzip_op, fvs1), (bind_op, fvs2), (return_op, fvs3)) <- if isMonadCompExpr ctxt + then (,,) <$> lookupSyntaxName mzipName + <*> lookupSyntaxName bindMName + <*> lookupSyntaxName returnMName + else return ( (noSyntaxExpr, emptyFVs) + , (noSyntaxExpr, emptyFVs) + , (noSyntaxExpr, emptyFVs) ) + ; ((segs', thing), fvs4) <- rnParallelStmts (ParStmtCtxt ctxt) segs thing_inside + ; return ( ([L loc (ParStmt segs' mzip_op bind_op return_op)], thing) + , fvs1 `plusFV` fvs2 `plusFV` fvs3 `plusFV` fvs4) } + +rnStmt ctxt (L loc (TransformStmt stmts _ using by _ _)) thing_inside = do { checkTransformStmt ctxt ; (using', fvs1) <- rnLExpr using @@ -756,17 +774,30 @@ rnStmt ctxt (L loc (TransformStmt stmts _ using by)) thing_inside -- the "thing inside", **or of the by-expression**, as used ; return ((by', used_bndrs, thing), fvs) } - ; return (([L loc (TransformStmt stmts' used_bndrs using' by')], thing), - fvs1 `plusFV` fvs2) } + -- Lookup `(>>=)` and `fail` for monad comprehensions + ; ((return_op, fvs3), (bind_op, fvs4)) <- + if isMonadCompExpr ctxt + then (,) <$> lookupSyntaxName returnMName + <*> lookupSyntaxName bindMName + else return ( (noSyntaxExpr, emptyFVs) + , (noSyntaxExpr, emptyFVs) ) + + ; return (([L loc (TransformStmt stmts' used_bndrs using' by' return_op bind_op)], thing), + fvs1 `plusFV` fvs2 `plusFV` fvs3 `plusFV` fvs4) } -rnStmt ctxt (L loc (GroupStmt stmts _ by using)) thing_inside +rnStmt ctxt (L loc (GroupStmt stmts _ by using _ _ _)) thing_inside = do { checkTransformStmt ctxt -- Rename the 'using' expression in the context before the transform is begun ; (using', fvs1) <- case using of Left e -> do { (e', fvs) <- rnLExpr e; return (Left e', fvs) } - Right _ -> do { (e', fvs) <- lookupSyntaxName groupWithName - ; return (Right e', fvs) } + Right _ + | isMonadCompExpr ctxt -> + do { (e', fvs) <- lookupSyntaxName groupMName + ; return (Right e', fvs) } + | otherwise -> + do { (e', fvs) <- lookupSyntaxName groupWithName + ; return (Right e', fvs) } -- Rename the stmts and the 'by' expression -- Keep track of the variables mentioned in the 'by' expression @@ -778,13 +809,23 @@ rnStmt ctxt (L loc (GroupStmt stmts _ by using)) thing_inside used_bndrs = filter (`elemNameSet` fvs) bndrs ; return ((by', used_bndrs, thing), fvs) } - ; let all_fvs = fvs1 `plusFV` fvs2 + -- Lookup `return`, `(>>=)` and `liftM` for monad comprehensions + ; ((return_op, fvs3), (bind_op, fvs4), (liftM_op, fvs5)) <- + if isMonadCompExpr ctxt + then (,,) <$> lookupSyntaxName returnMName + <*> lookupSyntaxName bindMName + <*> lookupSyntaxName liftMName + else return ( (noSyntaxExpr, emptyFVs) + , (noSyntaxExpr, emptyFVs) + , (noSyntaxExpr, emptyFVs) ) + + ; let all_fvs = fvs1 `plusFV` fvs2 `plusFV` fvs3 `plusFV` fvs4 + `plusFV` fvs5 bndr_map = used_bndrs `zip` used_bndrs -- See Note [GroupStmt binder map] in HsExpr ; traceRn (text "rnStmt: implicitly rebound these used binders:" <+> ppr bndr_map) - ; return (([L loc (GroupStmt stmts' bndr_map by' using')], thing), all_fvs) } - + ; return (([L loc (GroupStmt stmts' bndr_map by' using' return_op bind_op liftM_op)], thing), all_fvs) } type ParSeg id = ([LStmt id], [id]) -- The Names are bound by the Stmts @@ -901,9 +942,9 @@ rn_rec_stmt_lhs :: MiniFixityEnv -- so we don't bother to compute it accurately in the other cases -> RnM [(LStmtLR Name RdrName, FreeVars)] -rn_rec_stmt_lhs _ (L loc (ExprStmt expr a b)) = return [(L loc (ExprStmt expr a b), - -- this is actually correct - emptyFVs)] +rn_rec_stmt_lhs _ (L loc (ExprStmt expr a b c)) = return [(L loc (ExprStmt expr a b c), + -- this is actually correct + emptyFVs)] rn_rec_stmt_lhs fix_env (L loc (BindStmt pat expr a b)) = do @@ -926,7 +967,7 @@ rn_rec_stmt_lhs fix_env (L loc (LetStmt (HsValBinds binds))) rn_rec_stmt_lhs fix_env (L _ (RecStmt { recS_stmts = stmts })) -- Flatten Rec inside Rec = rn_rec_stmts_lhs fix_env stmts -rn_rec_stmt_lhs _ stmt@(L _ (ParStmt _)) -- Syntactically illegal in mdo +rn_rec_stmt_lhs _ stmt@(L _ (ParStmt _ _ _ _)) -- Syntactically illegal in mdo = pprPanic "rn_rec_stmt" (ppr stmt) rn_rec_stmt_lhs _ stmt@(L _ (TransformStmt {})) -- Syntactically illegal in mdo @@ -957,11 +998,11 @@ rn_rec_stmt :: [Name] -> LStmtLR Name RdrName -> FreeVars -> RnM [Segment (LStmt -- Rename a Stmt that is inside a RecStmt (or mdo) -- Assumes all binders are already in scope -- Turns each stmt into a singleton Stmt -rn_rec_stmt _ (L loc (ExprStmt expr _ _)) _ +rn_rec_stmt _ (L loc (ExprStmt expr _ _ _)) _ = rnLExpr expr `thenM` \ (expr', fvs) -> lookupSyntaxName thenMName `thenM` \ (then_op, fvs1) -> return [(emptyNameSet, fvs `plusFV` fvs1, emptyNameSet, - L loc (ExprStmt expr' then_op placeHolderType))] + L loc (ExprStmt expr' then_op noSyntaxExpr placeHolderType))] rn_rec_stmt _ (L loc (BindStmt pat' expr _ _)) fv_pat = rnLExpr expr `thenM` \ (expr', fv_expr) -> @@ -1161,10 +1202,13 @@ checkRecStmt ctxt = addErr msg --------- checkParStmt :: HsStmtContext Name -> RnM () checkParStmt _ - = do { parallel_list_comp <- xoptM Opt_ParallelListComp - ; checkErr parallel_list_comp msg } + = do { monad_comp <- xoptM Opt_MonadComprehensions + ; unless monad_comp $ do + { parallel_list_comp <- xoptM Opt_ParallelListComp + ; checkErr parallel_list_comp msg } + } where - msg = ptext (sLit "Illegal parallel list comprehension: use -XParallelListComp") + msg = ptext (sLit "Illegal parallel list comprehension: use -XParallelListComp or -XMonadComprehensions") --------- checkTransformStmt :: HsStmtContext Name -> RnM () @@ -1173,7 +1217,10 @@ checkTransformStmt ListComp -- Ensure we are really within a list comprehension = do { transform_list_comp <- xoptM Opt_TransformListComp ; checkErr transform_list_comp msg } where - msg = ptext (sLit "Illegal transform or grouping list comprehension: use -XTransformListComp") + msg = ptext (sLit "Illegal transform or grouping list comprehension: use -XTransformListComp or -XMonadComprehensions") +checkTransformStmt MonadComp -- Monad comprehensions are always fine, since the + -- MonadComprehensions flag will already be turned on + = do { return () } checkTransformStmt (ParStmtCtxt ctxt) = checkTransformStmt ctxt -- Ok to nest inside a parallel comprehension checkTransformStmt (TransformStmtCtxt ctxt) = checkTransformStmt ctxt -- Ok to nest inside a parallel comprehension checkTransformStmt ctxt = addErr msg diff --git a/compiler/typecheck/TcArrows.lhs b/compiler/typecheck/TcArrows.lhs index ae4a1e8..8fdb47c 100644 --- a/compiler/typecheck/TcArrows.lhs +++ b/compiler/typecheck/TcArrows.lhs @@ -213,11 +213,11 @@ tc_cmd env cmd@(HsLam (MatchGroup [L mtch_loc (match@(Match pats _maybe_rhs_sig ------------------------------------------- -- Do notation -tc_cmd env cmd@(HsDo do_or_lc stmts body _ty) (cmd_stk, res_ty) +tc_cmd env cmd@(HsDo do_or_lc stmts body _ _ty) (cmd_stk, res_ty) = do { checkTc (null cmd_stk) (nonEmptyCmdStkErr cmd) ; (stmts', body') <- tcStmts do_or_lc (tcMDoStmt tc_rhs) stmts res_ty $ tcGuardedCmd env body [] - ; return (HsDo do_or_lc stmts' body' res_ty) } + ; return (HsDo do_or_lc stmts' body' noSyntaxExpr res_ty) } where tc_rhs rhs = do { ty <- newFlexiTyVarTy liftedTypeKind ; rhs' <- tcCmd env rhs ([], ty) diff --git a/compiler/typecheck/TcExpr.lhs b/compiler/typecheck/TcExpr.lhs index 6bb0820..a821f25 100644 --- a/compiler/typecheck/TcExpr.lhs +++ b/compiler/typecheck/TcExpr.lhs @@ -415,8 +415,8 @@ tcExpr (HsIf (Just fun) pred b1 b2) res_ty -- Note [Rebindable syntax for if] -- and it maintains uniformity with other rebindable syntax ; return (HsIf (Just fun') pred' b1' b2') } -tcExpr (HsDo do_or_lc stmts body _) res_ty - = tcDoStmts do_or_lc stmts body res_ty +tcExpr (HsDo do_or_lc stmts body return_op _) res_ty + = tcDoStmts do_or_lc stmts body return_op res_ty tcExpr (HsProc pat cmd) res_ty = do { (pat', cmd', coi) <- tcProc pat cmd res_ty diff --git a/compiler/typecheck/TcHsSyn.lhs b/compiler/typecheck/TcHsSyn.lhs index 122b743..357db73 100644 --- a/compiler/typecheck/TcHsSyn.lhs +++ b/compiler/typecheck/TcHsSyn.lhs @@ -578,11 +578,12 @@ zonkExpr env (HsLet binds expr) zonkLExpr new_env expr `thenM` \ new_expr -> returnM (HsLet new_binds new_expr) -zonkExpr env (HsDo do_or_lc stmts body ty) +zonkExpr env (HsDo do_or_lc stmts body return_op ty) = zonkStmts env stmts `thenM` \ (new_env, new_stmts) -> zonkLExpr new_env body `thenM` \ new_body -> + zonkExpr new_env return_op `thenM` \ new_return -> zonkTcTypeToType env ty `thenM` \ new_ty -> - returnM (HsDo do_or_lc new_stmts new_body new_ty) + returnM (HsDo do_or_lc new_stmts new_body new_return new_ty) zonkExpr env (ExplicitList ty exprs) = zonkTcTypeToType env ty `thenM` \ new_ty -> @@ -728,13 +729,16 @@ zonkStmts env (s:ss) = do { (env1, s') <- wrapLocSndM (zonkStmt env) s ; return (env2, s' : ss') } zonkStmt :: ZonkEnv -> Stmt TcId -> TcM (ZonkEnv, Stmt Id) -zonkStmt env (ParStmt stmts_w_bndrs) +zonkStmt env (ParStmt stmts_w_bndrs mzip_op bind_op return_op) = mappM zonk_branch stmts_w_bndrs `thenM` \ new_stmts_w_bndrs -> let new_binders = concat (map snd new_stmts_w_bndrs) env1 = extendZonkEnv env new_binders in - return (env1, ParStmt new_stmts_w_bndrs) + zonkExpr env1 mzip_op `thenM` \ new_mzip -> + zonkExpr env1 bind_op `thenM` \ new_bind -> + zonkExpr env1 return_op `thenM` \ new_return -> + return (env1, ParStmt new_stmts_w_bndrs new_mzip new_bind new_return) where zonk_branch (stmts, bndrs) = zonkStmts env stmts `thenM` \ (env1, new_stmts) -> returnM (new_stmts, zonkIdOccs env1 bndrs) @@ -758,26 +762,32 @@ zonkStmt env (RecStmt { recS_stmts = segStmts, recS_later_ids = lvs, recS_rec_id , recS_mfix_fn = new_mfix_id, recS_bind_fn = new_bind_id , recS_rec_rets = new_rets }) } -zonkStmt env (ExprStmt expr then_op ty) +zonkStmt env (ExprStmt expr then_op guard_op ty) = zonkLExpr env expr `thenM` \ new_expr -> zonkExpr env then_op `thenM` \ new_then -> + zonkExpr env guard_op `thenM` \ new_guard -> zonkTcTypeToType env ty `thenM` \ new_ty -> - returnM (env, ExprStmt new_expr new_then new_ty) + returnM (env, ExprStmt new_expr new_then new_guard new_ty) -zonkStmt env (TransformStmt stmts binders usingExpr maybeByExpr) +zonkStmt env (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op) = do { (env', stmts') <- zonkStmts env stmts ; let binders' = zonkIdOccs env' binders ; usingExpr' <- zonkLExpr env' usingExpr ; maybeByExpr' <- zonkMaybeLExpr env' maybeByExpr - ; return (env', TransformStmt stmts' binders' usingExpr' maybeByExpr') } + ; return_op' <- zonkExpr env' return_op + ; bind_op' <- zonkExpr env' bind_op + ; return (env', TransformStmt stmts' binders' usingExpr' maybeByExpr' return_op' bind_op') } -zonkStmt env (GroupStmt stmts binderMap by using) +zonkStmt env (GroupStmt stmts binderMap by using return_op bind_op liftM_op) = do { (env', stmts') <- zonkStmts env stmts ; binderMap' <- mappM (zonkBinderMapEntry env') binderMap ; by' <- fmapMaybeM (zonkLExpr env') by ; using' <- fmapEitherM (zonkLExpr env) (zonkExpr env) using + ; return_op' <- zonkExpr env' return_op + ; bind_op' <- zonkExpr env' bind_op + ; liftM_op' <- zonkExpr env' liftM_op ; let env'' = extendZonkEnv env' (map snd binderMap') - ; return (env'', GroupStmt stmts' binderMap' by' using') } + ; return (env'', GroupStmt stmts' binderMap' by' using' return_op' bind_op' liftM_op') } where zonkBinderMapEntry env (oldBinder, newBinder) = do let oldBinder' = zonkIdOcc env oldBinder @@ -1112,4 +1122,4 @@ zonkTypeZapping ty zonk_unbound_tyvar tv = do { let ty = anyTypeOfKind (tyVarKind tv) ; writeMetaTyVar tv ty ; return ty } -\end{code} \ No newline at end of file +\end{code} diff --git a/compiler/typecheck/TcMatches.lhs b/compiler/typecheck/TcMatches.lhs index 860a6db..31aa555 100644 --- a/compiler/typecheck/TcMatches.lhs +++ b/compiler/typecheck/TcMatches.lhs @@ -16,6 +16,7 @@ import {-# SOURCE #-} TcExpr( tcSyntaxOp, tcInferRhoNC, tcCheckId, tcMonoExpr, tcMonoExprNC, tcPolyExpr ) import HsSyn +import BasicTypes import TcRnMonad import TcEnv import TcPat @@ -30,11 +31,13 @@ import TyCon import TysPrim import Coercion ( mkSymCoI ) import Outputable -import BasicTypes ( Arity ) import Util import SrcLoc import FastString +-- Create chunkified tuple tybes for monad comprehensions +import MkCore + import Control.Monad #include "HsVersions.h" @@ -239,35 +242,42 @@ tcGRHS ctxt res_ty (GRHS guards rhs) tcDoStmts :: HsStmtContext Name -> [LStmt Name] -> LHsExpr Name + -> SyntaxExpr Name -- 'return' function for monad + -- comprehensions -> TcRhoType -> TcM (HsExpr TcId) -- Returns a HsDo -tcDoStmts ListComp stmts body res_ty +tcDoStmts ListComp stmts body _ res_ty = do { (coi, elt_ty) <- matchExpectedListTy res_ty ; (stmts', body') <- tcStmts ListComp (tcLcStmt listTyCon) stmts elt_ty $ tcBody body ; return $ mkHsWrapCoI coi - (HsDo ListComp stmts' body' (mkListTy elt_ty)) } + (HsDo ListComp stmts' body' noSyntaxExpr (mkListTy elt_ty)) } -tcDoStmts PArrComp stmts body res_ty +tcDoStmts PArrComp stmts body _ res_ty = do { (coi, elt_ty) <- matchExpectedPArrTy res_ty ; (stmts', body') <- tcStmts PArrComp (tcLcStmt parrTyCon) stmts elt_ty $ tcBody body ; return $ mkHsWrapCoI coi - (HsDo PArrComp stmts' body' (mkPArrTy elt_ty)) } + (HsDo PArrComp stmts' body' noSyntaxExpr (mkPArrTy elt_ty)) } -tcDoStmts DoExpr stmts body res_ty +tcDoStmts DoExpr stmts body _ res_ty = do { (stmts', body') <- tcStmts DoExpr tcDoStmt stmts res_ty $ tcBody body - ; return (HsDo DoExpr stmts' body' res_ty) } + ; return (HsDo DoExpr stmts' body' noSyntaxExpr res_ty) } -tcDoStmts MDoExpr stmts body res_ty +tcDoStmts MDoExpr stmts body _ res_ty = do { (stmts', body') <- tcStmts MDoExpr tcDoStmt stmts res_ty $ tcBody body - ; return (HsDo MDoExpr stmts' body' res_ty) } + ; return (HsDo MDoExpr stmts' body' noSyntaxExpr res_ty) } + +tcDoStmts MonadComp stmts body return_op res_ty + = do { (stmts', (body', return_op')) <- tcStmts MonadComp tcMcStmt stmts res_ty $ + tcMcBody body return_op + ; return $ HsDo MonadComp stmts' body' return_op' res_ty } -tcDoStmts ctxt _ _ _ = pprPanic "tcDoStmts" (pprStmtContext ctxt) +tcDoStmts ctxt _ _ _ _ = pprPanic "tcDoStmts" (pprStmtContext ctxt) tcBody :: LHsExpr Name -> TcRhoType -> TcM (LHsExpr TcId) tcBody body res_ty @@ -326,10 +336,10 @@ tcStmts ctxt stmt_chk (L loc stmt : stmts) res_ty thing_inside -------------------------------- -- Pattern guards tcGuardStmt :: TcStmtChecker -tcGuardStmt _ (ExprStmt guard _ _) res_ty thing_inside +tcGuardStmt _ (ExprStmt guard _ _ _) res_ty thing_inside = do { guard' <- tcMonoExpr guard boolTy ; thing <- thing_inside res_ty - ; return (ExprStmt guard' noSyntaxExpr boolTy, thing) } + ; return (ExprStmt guard' noSyntaxExpr noSyntaxExpr boolTy, thing) } tcGuardStmt ctxt (BindStmt pat rhs _ _) res_ty thing_inside = do { (rhs', rhs_ty) <- tcInferRhoNC rhs -- Stmt has a context already @@ -356,10 +366,10 @@ tcLcStmt m_tc ctxt (BindStmt pat rhs _ _) res_ty thing_inside ; return (BindStmt pat' rhs' noSyntaxExpr noSyntaxExpr, thing) } -- A boolean guard -tcLcStmt _ _ (ExprStmt rhs _ _) res_ty thing_inside +tcLcStmt _ _ (ExprStmt rhs _ _ _) res_ty thing_inside = do { rhs' <- tcMonoExpr rhs boolTy ; thing <- thing_inside res_ty - ; return (ExprStmt rhs' noSyntaxExpr boolTy, thing) } + ; return (ExprStmt rhs' noSyntaxExpr noSyntaxExpr boolTy, thing) } -- A parallel set of comprehensions -- [ (g x, h x) | ... ; let g v = ... @@ -382,9 +392,9 @@ tcLcStmt _ _ (ExprStmt rhs _ _) res_ty thing_inside -- So the binders of the first parallel group will be in scope in the second -- group. But that's fine; there's no shadowing to worry about. -tcLcStmt m_tc ctxt (ParStmt bndr_stmts_s) elt_ty thing_inside +tcLcStmt m_tc ctxt (ParStmt bndr_stmts_s _ _ _) elt_ty thing_inside = do { (pairs', thing) <- loop bndr_stmts_s - ; return (ParStmt pairs', thing) } + ; return (ParStmt pairs' noSyntaxExpr noSyntaxExpr noSyntaxExpr, thing) } where -- loop :: [([LStmt Name], [Name])] -> TcM ([([LStmt TcId], [TcId])], thing) loop [] = do { thing <- thing_inside elt_ty @@ -398,7 +408,7 @@ tcLcStmt m_tc ctxt (ParStmt bndr_stmts_s) elt_ty thing_inside ; return (ids, pairs', thing) } ; return ( (stmts', ids) : pairs', thing ) } -tcLcStmt m_tc ctxt (TransformStmt stmts binders usingExpr maybeByExpr) elt_ty thing_inside = do +tcLcStmt m_tc ctxt (TransformStmt stmts binders usingExpr maybeByExpr _ _) elt_ty thing_inside = do (stmts', (binders', usingExpr', maybeByExpr', thing)) <- tcStmts (TransformStmtCtxt ctxt) (tcLcStmt m_tc) stmts elt_ty $ \elt_ty' -> do let alphaListTy = mkTyConApp m_tc [alphaTy] @@ -425,9 +435,9 @@ tcLcStmt m_tc ctxt (TransformStmt stmts binders usingExpr maybeByExpr) elt_ty th return (binders', usingExpr', maybeByExpr', thing) - return (TransformStmt stmts' binders' usingExpr' maybeByExpr', thing) + return (TransformStmt stmts' binders' usingExpr' maybeByExpr' noSyntaxExpr noSyntaxExpr, thing) -tcLcStmt m_tc ctxt (GroupStmt stmts bindersMap by using) elt_ty thing_inside +tcLcStmt m_tc ctxt (GroupStmt stmts bindersMap by using _ _ _) elt_ty thing_inside = do { let (bndr_names, list_bndr_names) = unzip bindersMap ; (stmts', (bndr_ids, by', using_ty, elt_ty')) <- @@ -463,7 +473,7 @@ tcLcStmt m_tc ctxt (GroupStmt stmts bindersMap by using) elt_ty thing_inside -- Type check the thing in the environment with -- these new binders and return the result ; thing <- tcExtendIdEnv list_bndr_ids (thing_inside elt_ty') - ; return (GroupStmt stmts' bindersMap' by' using', thing) } + ; return (GroupStmt stmts' bindersMap' by' using' noSyntaxExpr noSyntaxExpr noSyntaxExpr, thing) } where alphaListTy = mkTyConApp m_tc [alphaTy] alphaListListTy = mkTyConApp m_tc [alphaListTy] @@ -475,6 +485,298 @@ tcLcStmt m_tc ctxt (GroupStmt stmts bindersMap by using) elt_ty thing_inside tcLcStmt _ _ stmt _ _ = pprPanic "tcLcStmt: unexpected Stmt" (ppr stmt) + +-------------------------------- +-- Monad comprehensions + +tcMcStmt :: TcStmtChecker + +-- Generators for monad comprehensions ( pat <- rhs ) +-- +-- [ body | q <- gen ] -> gen :: m a +-- q :: a +-- +tcMcStmt ctxt (BindStmt pat rhs bind_op fail_op) res_ty thing_inside + = do { rhs_ty <- newFlexiTyVarTy liftedTypeKind + ; pat_ty <- newFlexiTyVarTy liftedTypeKind + ; new_res_ty <- newFlexiTyVarTy liftedTypeKind + ; bind_op' <- tcSyntaxOp MCompOrigin bind_op + (mkFunTys [rhs_ty, mkFunTy pat_ty new_res_ty] res_ty) + + -- If (but only if) the pattern can fail, + -- typecheck the 'fail' operator + ; fail_op' <- if isIrrefutableHsPat pat + then return noSyntaxExpr + else tcSyntaxOp MCompOrigin fail_op (mkFunTy stringTy new_res_ty) + + ; rhs' <- tcMonoExprNC rhs rhs_ty + ; (pat', thing) <- tcPat (StmtCtxt ctxt) pat pat_ty $ + thing_inside new_res_ty + + ; return (BindStmt pat' rhs' bind_op' fail_op', thing) } + +-- Boolean expressions. +-- +-- [ body | stmts, expr ] -> expr :: m Bool +-- +tcMcStmt _ (ExprStmt rhs then_op guard_op _) res_ty thing_inside + = do { -- Deal with rebindable syntax: + -- guard_op :: test_ty -> rhs_ty + -- then_op :: rhs_ty -> new_res_ty -> res_ty + -- Where test_ty is, for example, Bool + test_ty <- newFlexiTyVarTy liftedTypeKind + ; rhs_ty <- newFlexiTyVarTy liftedTypeKind + ; new_res_ty <- newFlexiTyVarTy liftedTypeKind + ; rhs' <- tcMonoExpr rhs test_ty + ; guard_op' <- tcSyntaxOp MCompOrigin guard_op + (mkFunTy test_ty rhs_ty) + ; then_op' <- tcSyntaxOp MCompOrigin then_op + (mkFunTys [rhs_ty, new_res_ty] res_ty) + ; thing <- thing_inside new_res_ty + ; return (ExprStmt rhs' then_op' guard_op' rhs_ty, thing) } + +-- Transform statements. +-- +-- [ body | stmts, then f ] -> f :: forall a. m a -> m a +-- [ body | stmts, then f by e ] -> f :: forall a. (a -> t) -> m a -> m a +-- +tcMcStmt ctxt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op) elt_ty thing_inside + = do { + -- We don't know the types of binders yet, so we use this dummy and + -- later unify this type with the `m_bndr_ty` + ty_dummy <- newFlexiTyVarTy liftedTypeKind + + ; (stmts', (binders', usingExpr', maybeByExpr', return_op', bind_op', thing)) <- + tcStmts (TransformStmtCtxt ctxt) tcMcStmt stmts ty_dummy $ \elt_ty' -> do + { (_, (m_ty, _)) <- matchExpectedAppTy elt_ty' + ; (usingExpr', maybeByExpr') <- + case maybeByExpr of + Nothing -> do + -- We must validate that usingExpr :: forall a. m a -> m a + let using_ty = mkForAllTy alphaTyVar $ + (m_ty `mkAppTy` alphaTy) + `mkFunTy` + (m_ty `mkAppTy` alphaTy) + usingExpr' <- tcPolyExpr usingExpr using_ty + return (usingExpr', Nothing) + Just byExpr -> do + -- We must infer a type such that e :: t and then check that + -- usingExpr :: forall a. (a -> t) -> m a -> m a + (byExpr', tTy) <- tcInferRhoNC byExpr + let using_ty = mkForAllTy alphaTyVar $ + (alphaTy `mkFunTy` tTy) + `mkFunTy` + (m_ty `mkAppTy` alphaTy) + `mkFunTy` + (m_ty `mkAppTy` alphaTy) + usingExpr' <- tcPolyExpr usingExpr using_ty + return (usingExpr', Just byExpr') + + ; bndr_ids <- tcLookupLocalIds binders + + -- `return` and `>>=` are used to pass around/modify our + -- binders, so we know their types: + -- + -- return :: (a,b,c,..) -> m (a,b,c,..) + -- (>>=) :: m (a,b,c,..) + -- -> ( (a,b,c,..) -> m (a,b,c,..) ) + -- -> m (a,b,c,..) + -- + ; let bndr_ty = mkChunkified mkBoxedTupleTy $ map idType bndr_ids + m_bndr_ty = m_ty `mkAppTy` bndr_ty + + ; return_op' <- tcSyntaxOp MCompOrigin return_op + (bndr_ty `mkFunTy` m_bndr_ty) + + ; bind_op' <- tcSyntaxOp MCompOrigin bind_op $ + m_bndr_ty `mkFunTy` (bndr_ty `mkFunTy` elt_ty) + `mkFunTy` elt_ty + + -- Unify types of the inner comprehension and the binders type + ; _ <- unifyType elt_ty' m_bndr_ty + + -- Typecheck the `thing` with out old type (which is the type + -- of the final result of our comprehension) + ; thing <- thing_inside elt_ty + + ; return (bndr_ids, usingExpr', maybeByExpr', return_op', bind_op', thing) } + + ; return (TransformStmt stmts' binders' usingExpr' maybeByExpr' return_op' bind_op', thing) } + +-- Grouping statements +-- +-- [ body | stmts, then group by e ] +-- -> e :: t +-- [ body | stmts, then group by e using f ] +-- -> e :: t +-- f :: forall a. (a -> t) -> m a -> m (m a) +-- [ body | stmts, then group using f ] +-- -> f :: forall a. m a -> m (m a) +-- +tcMcStmt ctxt (GroupStmt stmts bindersMap by using return_op bind_op liftM_op) elt_ty thing_inside + = do { let (bndr_names, m_bndr_names) = unzip bindersMap + + ; (_,(m_ty,_)) <- matchExpectedAppTy elt_ty + ; let alphaMTy = m_ty `mkAppTy` alphaTy + alphaMMTy = m_ty `mkAppTy` alphaMTy + + -- We don't know the type of the bindings yet. It's not elt_ty! + ; bndr_ty_dummy <- newFlexiTyVarTy liftedTypeKind + + ; (stmts', (bndr_ids, by', using_ty, return_op', bind_op')) <- + tcStmts (TransformStmtCtxt ctxt) tcMcStmt stmts bndr_ty_dummy $ \elt_ty' -> do + { (by', using_ty) <- + case by of + Nothing -> -- check that using :: forall a. m a -> m (m a) + return (Nothing, mkForAllTy alphaTyVar $ + alphaMTy `mkFunTy` alphaMMTy) + + Just by_e -> -- check that using :: forall a. (a -> t) -> m a -> m (m a) + -- where by :: t + do { (by_e', t_ty) <- tcInferRhoNC by_e + ; return (Just by_e', mkForAllTy alphaTyVar $ + (alphaTy `mkFunTy` t_ty) + `mkFunTy` alphaMTy + `mkFunTy` alphaMMTy) } + + + -- Find the Ids (and hence types) of all old binders + ; bndr_ids <- tcLookupLocalIds bndr_names + + -- 'return' is only used for the binders, so we know its type. + -- + -- return :: (a,b,c,..) -> m (a,b,c,..) + -- + ; let bndr_ty = mkChunkified mkBoxedTupleTy $ map idType bndr_ids + m_bndr_ty = m_ty `mkAppTy` bndr_ty + ; return_op' <- tcSyntaxOp MCompOrigin return_op $ bndr_ty `mkFunTy` m_bndr_ty + + -- '>>=' is used to pass the grouped binders to the rest of the + -- comprehension. + -- + -- (>>=) :: m (m a, m b, m c, ..) + -- -> ( (m a, m b, m c, ..) -> new_elt_ty ) + -- -> elt_ty + -- + ; let bndr_m_ty = mkChunkified mkBoxedTupleTy $ map (mkAppTy m_ty . idType) bndr_ids + m_bndr_m_ty = m_ty `mkAppTy` bndr_m_ty + ; new_elt_ty <- newFlexiTyVarTy liftedTypeKind + ; bind_op' <- tcSyntaxOp MCompOrigin bind_op $ + m_bndr_m_ty `mkFunTy` (bndr_m_ty `mkFunTy` new_elt_ty) + `mkFunTy` elt_ty + + -- Finally make sure the type of the inner comprehension + -- represents the types of our binders + ; _ <- unifyType elt_ty' m_bndr_ty + + ; return (bndr_ids, by', using_ty, return_op', bind_op') } + + ; let mk_m_bndr :: Name -> TcId -> TcId + mk_m_bndr m_bndr_name bndr_id = + mkLocalId m_bndr_name (m_ty `mkAppTy` idType bndr_id) + + -- Ensure that every old binder of type `b` is linked up with its + -- new binder which should have type `m b` + m_bndr_ids = zipWith mk_m_bndr m_bndr_names bndr_ids + bindersMap' = bndr_ids `zip` m_bndr_ids + + -- See Note [GroupStmt binder map] in HsExpr + + ; using' <- case using of + Left e -> do { e' <- tcPolyExpr e using_ty; return (Left e') } + Right e -> do { e' <- tcPolyExpr (noLoc e) using_ty; return (Right (unLoc e')) } + + -- Type check 'liftM' with 'forall a b. (a -> b) -> m_ty a -> m_ty b' + ; liftM_op' <- fmap unLoc . tcPolyExpr (noLoc liftM_op) $ + mkForAllTy alphaTyVar $ mkForAllTy betaTyVar $ + (alphaTy `mkFunTy` betaTy) + `mkFunTy` + (m_ty `mkAppTy` alphaTy) + `mkFunTy` + (m_ty `mkAppTy` betaTy) + + -- Type check the thing in the environment with these new binders and + -- return the result + ; thing <- tcExtendIdEnv m_bndr_ids (thing_inside elt_ty) + + ; return (GroupStmt stmts' bindersMap' by' using' return_op' bind_op' liftM_op', thing) } + +-- Typecheck `ParStmt`. See `tcLcStmt` for more informations about typechecking +-- of `ParStmt`s. +-- +-- Note: The `mzip` function will get typechecked via: +-- +-- ParStmt [st1::t1, st2::t2, st3::t3] +-- +-- mzip :: m st1 +-- -> (m st2 -> m st3 -> m (st2, st3)) -- recursive call +-- -> m (st1, (st2, st3)) +-- +tcMcStmt ctxt (ParStmt bndr_stmts_s mzip_op bind_op return_op) elt_ty thing_inside + = do { (_,(m_ty,_)) <- matchExpectedAppTy elt_ty + ; (pairs', thing) <- loop m_ty bndr_stmts_s + + ; let mzip_ty = mkForAllTys [alphaTyVar, betaTyVar] $ + (m_ty `mkAppTy` alphaTy) + `mkFunTy` + (m_ty `mkAppTy` betaTy) + `mkFunTy` + (m_ty `mkAppTy` mkBoxedTupleTy [alphaTy, betaTy]) + ; mzip_op' <- unLoc `fmap` tcPolyExpr (noLoc mzip_op) mzip_ty + + -- Typecheck bind: + ; let tys = map (mkChunkified mkBoxedTupleTy . map idType . snd) pairs' + tuple_ty = mk_tuple_ty tys + + ; bind_op' <- tcSyntaxOp MCompOrigin bind_op $ + (m_ty `mkAppTy` tuple_ty) + `mkFunTy` + (tuple_ty `mkFunTy` elt_ty) + `mkFunTy` + elt_ty + + ; return_op' <- fmap unLoc . tcPolyExpr (noLoc return_op) $ + mkForAllTy alphaTyVar $ + alphaTy `mkFunTy` (m_ty `mkAppTy` alphaTy) + ; return (ParStmt pairs' mzip_op' bind_op' return_op', thing) } + + where mk_tuple_ty tys = foldr (\tn tm -> mkBoxedTupleTy [tn, tm]) (last tys) (init tys) + + -- loop :: Type -- m_ty + -- -> [([LStmt Name], [Name])] + -- -> TcM ([([LStmt TcId], [TcId])], thing) + loop _ [] = do { thing <- thing_inside elt_ty + ; return ([], thing) } -- matching in the branches + + loop m_ty ((stmts, names) : pairs) + = do { -- type dummy since we don't know all binder types yet + ty_dummy <- newFlexiTyVarTy liftedTypeKind + ; (stmts', (ids, pairs', thing)) + <- tcStmts ctxt tcMcStmt stmts ty_dummy $ \elt_ty' -> + do { ids <- tcLookupLocalIds names + ; _ <- unifyType elt_ty' (m_ty `mkAppTy` (mkChunkified mkBoxedTupleTy) (map idType ids)) + ; (pairs', thing) <- loop m_ty pairs + ; return (ids, pairs', thing) } + ; return ( (stmts', ids) : pairs', thing ) } + +tcMcStmt _ stmt _ _ + = pprPanic "tcMcStmt: unexpected Stmt" (ppr stmt) + +-- Typecheck 'body' with type 'a' instead of 'm a' like the rest of the +-- statements, ignore the second type argument coming from the tcStmts loop +tcMcBody :: LHsExpr Name + -> SyntaxExpr Name + -> TcRhoType + -> TcM (LHsExpr TcId, SyntaxExpr TcId) +tcMcBody body return_op res_ty + = do { (_, (_, a_ty)) <- matchExpectedAppTy res_ty + ; body' <- tcMonoExpr body a_ty + ; return_op' <- tcSyntaxOp MCompOrigin return_op + (a_ty `mkFunTy` res_ty) + ; return (body', return_op') + } + + -------------------------------- -- Do-notation -- The main excitement here is dealing with rebindable syntax @@ -510,7 +812,7 @@ tcDoStmt ctxt (BindStmt pat rhs bind_op fail_op) res_ty thing_inside ; return (BindStmt pat' rhs' bind_op' fail_op', thing) } -tcDoStmt _ (ExprStmt rhs then_op _) res_ty thing_inside +tcDoStmt _ (ExprStmt rhs then_op _ _) res_ty thing_inside = do { -- Deal with rebindable syntax; -- (>>) :: rhs_ty -> new_res_ty -> res_ty -- See also Note [Treat rebindable syntax first] @@ -521,7 +823,7 @@ tcDoStmt _ (ExprStmt rhs then_op _) res_ty thing_inside ; rhs' <- tcMonoExprNC rhs rhs_ty ; thing <- thing_inside new_res_ty - ; return (ExprStmt rhs' then_op' rhs_ty, thing) } + ; return (ExprStmt rhs' then_op' noSyntaxExpr rhs_ty, thing) } tcDoStmt ctxt (RecStmt { recS_stmts = stmts, recS_later_ids = later_names , recS_rec_ids = rec_names, recS_ret_fn = ret_op @@ -592,10 +894,10 @@ tcMDoStmt tc_rhs ctxt (BindStmt pat rhs _ _) res_ty thing_inside thing_inside res_ty ; return (BindStmt pat' rhs' noSyntaxExpr noSyntaxExpr, thing) } -tcMDoStmt tc_rhs _ (ExprStmt rhs _ _) res_ty thing_inside +tcMDoStmt tc_rhs _ (ExprStmt rhs _ _ _) res_ty thing_inside = do { (rhs', elt_ty) <- tc_rhs rhs ; thing <- thing_inside res_ty - ; return (ExprStmt rhs' noSyntaxExpr elt_ty, thing) } + ; return (ExprStmt rhs' noSyntaxExpr noSyntaxExpr elt_ty, thing) } tcMDoStmt tc_rhs ctxt (RecStmt { recS_stmts = stmts, recS_later_ids = laterNames , recS_rec_ids = recNames }) res_ty thing_inside @@ -620,6 +922,7 @@ tcMDoStmt tc_rhs ctxt (RecStmt { recS_stmts = stmts, recS_later_ids = laterNames tcMDoStmt _ _ stmt _ _ = pprPanic "tcMDoStmt: unexpected Stmt" (ppr stmt) + \end{code} diff --git a/compiler/typecheck/TcRnDriver.lhs b/compiler/typecheck/TcRnDriver.lhs index 23c2e67..b9f7913 100644 --- a/compiler/typecheck/TcRnDriver.lhs +++ b/compiler/typecheck/TcRnDriver.lhs @@ -1205,7 +1205,7 @@ runPlans (p:ps) = tryTcLIE_ (runPlans ps) p -------------------- mkPlan :: LStmt Name -> TcM PlanResult -mkPlan (L loc (ExprStmt expr _ _)) -- An expression typed at the prompt +mkPlan (L loc (ExprStmt expr _ _ _)) -- An expression typed at the prompt = do { uniq <- newUnique -- is treated very specially ; let fresh_it = itName uniq the_bind = L loc $ mkFunBind (L loc fresh_it) matches @@ -1214,7 +1214,7 @@ mkPlan (L loc (ExprStmt expr _ _)) -- An expression typed at the prompt bind_stmt = L loc $ BindStmt (nlVarPat fresh_it) expr (HsVar bindIOName) noSyntaxExpr print_it = L loc $ ExprStmt (nlHsApp (nlHsVar printName) (nlHsVar fresh_it)) - (HsVar thenIOName) placeHolderType + (HsVar thenIOName) noSyntaxExpr placeHolderType -- The plans are: -- [it <- e; print it] but not if it::() @@ -1242,7 +1242,7 @@ mkPlan (L loc (ExprStmt expr _ _)) -- An expression typed at the prompt mkPlan stmt@(L loc (BindStmt {})) | [v] <- collectLStmtBinders stmt -- One binder, for a bind stmt = do { let print_v = L loc $ ExprStmt (nlHsApp (nlHsVar printName) (nlHsVar v)) - (HsVar thenIOName) placeHolderType + (HsVar thenIOName) noSyntaxExpr placeHolderType ; print_bind_result <- doptM Opt_PrintBindResult ; let print_plan = do @@ -1304,7 +1304,7 @@ tcGhciStmts stmts traceTc "TcRnDriver.tcGhciStmts: done" empty ; return (ids, mkHsDictLet (EvBinds const_binds) $ - noLoc (HsDo GhciStmt tc_stmts (mk_return ids) io_ret_ty)) + noLoc (HsDo GhciStmt tc_stmts (mk_return ids) noSyntaxExpr io_ret_ty)) } \end{code} diff --git a/compiler/typecheck/TcRnTypes.lhs b/compiler/typecheck/TcRnTypes.lhs index 8858c13..4b174e5 100644 --- a/compiler/typecheck/TcRnTypes.lhs +++ b/compiler/typecheck/TcRnTypes.lhs @@ -1112,6 +1112,7 @@ data CtOrigin | StandAloneDerivOrigin -- Typechecking stand-alone deriving | DefaultOrigin -- Typechecking a default decl | DoOrigin -- Arising from a do expression + | MCompOrigin -- Arising from a monad comprehension | IfOrigin -- Arising from an if statement | ProcOrigin -- Arising from a proc expression | AnnOrigin -- An annotation @@ -1147,6 +1148,7 @@ pprO DerivOrigin = ptext (sLit "the 'deriving' clause of a data type declarat pprO StandAloneDerivOrigin = ptext (sLit "a 'deriving' declaration") pprO DefaultOrigin = ptext (sLit "a 'default' declaration") pprO DoOrigin = ptext (sLit "a do statement") +pprO MCompOrigin = ptext (sLit "a statement in a monad comprehension") pprO ProcOrigin = ptext (sLit "a proc expression") pprO (TypeEqOrigin eq) = ptext (sLit "an equality") <+> ppr eq pprO AnnOrigin = ptext (sLit "an annotation") diff --git a/docs/users_guide/flags.xml b/docs/users_guide/flags.xml index 26ab9eb..add2f5e 100644 --- a/docs/users_guide/flags.xml +++ b/docs/users_guide/flags.xml @@ -898,6 +898,12 @@ dynamic + + + Enable monad comprehensions. + dynamic + + Enable unlifted FFI types. diff --git a/docs/users_guide/glasgow_exts.xml b/docs/users_guide/glasgow_exts.xml index 9ea3332..54a4833 100644 --- a/docs/users_guide/glasgow_exts.xml +++ b/docs/users_guide/glasgow_exts.xml @@ -1201,6 +1201,168 @@ output = [ x + + + + Monad comprehensions + monad comprehensions + + + Monad comprehesions generalise the list comprehension notation to work + for any monad. + + + Monad comprehensions support: + + + + + Bindings: + + + +[ x + y | x <- Just 1, y <- Just 2 ] + + + + Bindings are translated with the (>>=) and + return functions to the usual do-notation: + + + +do x <- Just 1 + y <- Just 2 + return (x+y) + + + + + + Guards: + + + +[ x | x <- [1..10], x <= 5 ] + + + + Guards are translated with the guard function, + which requires a MonadPlus instance: + + + +do x <- [1..10] + guard (x <= 5) + return x + + + + + + Transform statements (as with -XTransformListComp): + + + +[ x+y | x <- [1..10], y <- [1..x], then take 2 ] + + + + This translates to: + + + +do (x,y) <- take 2 (do x <- [1..10] + y <- [1..x] + return (x,y)) + return (x+y) + + + + + + Group statements (as with -XTransformListComp): + + + +[ x | x <- [1,1,2,2,3], then group by x ] +[ x | x <- [1,1,2,2,3], then group by x using GHC.Exts.groupWith ] +[ x | x <- [1,1,2,2,3], then group using myGroup ] + + + + The basic then group by e statement is + translated using the mgroupWith function, which + requires a MonadGroup instance, defined in + Control.Monad.Group: + + + +do x <- mgroupWith (do x <- [1,1,2,2,3] + return x) + return x + + + + Note that the type of x is changed by the + grouping statement. + + + + The grouping function can also be defined with the + using keyword. + + + + + + Parallel statements (as with -XParallelListComp): + + + +[ (x+y) | x <- [1..10] + | y <- [11..20] + ] + + + + Parallel statements are translated using the + mzip function, which requires a + MonadZip instance defined in + Control.Monad.Zip: + + + +do (x,y) <- mzip (do x <- [1..10] + return x) + (do y <- [11..20] + return y) + return (x+y) + + + + + + + All these features are enabled by default if the + MonadComprehensions extension is enabled. The types + and more detailed examples on how to use comprehensions are explained + in the previous chapters and . In general you just have + to replace the type [a] with the type + Monad m => m a for monad comprehensions. + + + + Note: Even though most of these examples are using the list monad, + monad comprehensions work for any monad. + The base package offers all necessary instances for + lists, which make MonadComprehensions backward + compatible to built-in, transform and parallel list comprehensions. + + + +