Merge branch monad-comp onto master
authorSimon Peyton Jones <simonpj@microsoft.com>
Wed, 4 May 2011 15:37:08 +0000 (16:37 +0100)
committerSimon Peyton Jones <simonpj@microsoft.com>
Wed, 4 May 2011 15:37:08 +0000 (16:37 +0100)
This patch implements monad comprehensions, Trac #4370.
Thanks to Nils Schweinsberg for doing most of the heavy lifting.

I did quite a lot of related refactoring as well.  Notably:

* Combined TransformStmt and GroupStmt into a single
  constructor TransStmt; they share a lot of code.
  I also made TransStmt into a record; it has a lot of fields.

* Remove the "result expression" field of HsDo, and instead
  implement LastStmt, which is expected to be at the end
  of a list of Stmts

* Generalise and tidy up the typechecking of monad comprehensions

* Do-notation in arrows is marked with HsStmtContext = ArrowExpr

* tcMDoStmt (which was only used for arrows) is moved
  to TcArrows, and renamed tcArrDoStmt

* Improved documentation in the user manual

* Lots of other minor changes

33 files changed:
compiler/deSugar/Check.lhs
compiler/deSugar/Coverage.lhs
compiler/deSugar/DsArrows.lhs
compiler/deSugar/DsExpr.lhs
compiler/deSugar/DsGRHSs.lhs
compiler/deSugar/DsListComp.lhs
compiler/deSugar/DsMeta.hs
compiler/hsSyn/Convert.lhs
compiler/hsSyn/HsExpr.lhs
compiler/hsSyn/HsLit.lhs
compiler/hsSyn/HsPat.lhs
compiler/hsSyn/HsUtils.lhs
compiler/main/DynFlags.hs
compiler/main/HscMain.lhs
compiler/parser/Lexer.x
compiler/parser/Parser.y.pp
compiler/parser/RdrHsSyn.lhs
compiler/prelude/PrelNames.lhs
compiler/rename/RnBinds.lhs
compiler/rename/RnExpr.lhs
compiler/typecheck/TcArrows.lhs
compiler/typecheck/TcExpr.lhs
compiler/typecheck/TcGenDeriv.lhs
compiler/typecheck/TcHsSyn.lhs
compiler/typecheck/TcMatches.lhs
compiler/typecheck/TcPat.lhs
compiler/typecheck/TcRnDriver.lhs
compiler/typecheck/TcRnMonad.lhs
compiler/typecheck/TcRnTypes.lhs
compiler/typecheck/TcSMonad.lhs
compiler/typecheck/TcUnify.lhs
docs/users_guide/flags.xml
docs/users_guide/glasgow_exts.xml

index 3d3aa4f..fa85a1d 100644 (file)
@@ -110,9 +110,11 @@ type EqnSet = UniqSet EqnNo
 check :: [EquationInfo] -> ([ExhaustivePat], [EquationInfo])
   -- Second result is the shadowed equations
   -- if there are view patterns, just give up - don't know what the function is
-check qs = (untidy_warns, shadowed_eqns)
+check qs = pprTrace "check" (ppr tidy_qs) $
+           (untidy_warns, shadowed_eqns)
       where
-       (warns, used_nos) = check' ([1..] `zip` map tidy_eqn qs)
+        tidy_qs = map tidy_eqn qs
+       (warns, used_nos) = check' ([1..] `zip` tidy_qs)
        untidy_warns = map untidy_exhaustive warns 
        shadowed_eqns = [eqn | (eqn,i) <- qs `zip` [1..], 
                                not (i `elementOfUniqSet` used_nos)]
index 0daa6be..8071da7 100644 (file)
@@ -301,10 +301,9 @@ addTickHsExpr (HsLet binds e) =
        liftM2 HsLet
                (addTickHsLocalBinds binds) -- to think about: !patterns.
                 (addTickLHsExprNeverOrAlways e)
-addTickHsExpr (HsDo cxt stmts last_exp srcloc) = do
-        (stmts', last_exp') <- addTickLStmts' forQual stmts 
-                                     (addTickLHsExpr last_exp)
-       return (HsDo cxt stmts' last_exp' srcloc)
+addTickHsExpr (HsDo cxt stmts srcloc) 
+  = do { (stmts', _) <- addTickLStmts' forQual stmts (return ())
+       ; return (HsDo cxt stmts' srcloc) }
   where
        forQual = case cxt of
                    ListComp -> Just $ BinBox QualBinBox
@@ -424,45 +423,50 @@ addTickLStmts isGuard stmts = do
 addTickLStmts' :: (Maybe (Bool -> BoxLabel)) -> [LStmt Id] -> TM a 
                -> TM ([LStmt Id], a)
 addTickLStmts' isGuard lstmts res
-  = bindLocals binders $ do
-        lstmts' <- mapM (liftL (addTickStmt isGuard)) lstmts
-        a <- res
-        return (lstmts', a)
-  where
-        binders = collectLStmtsBinders lstmts
+  = bindLocals (collectLStmtsBinders lstmts) $ 
+    do { lstmts' <- mapM (liftL (addTickStmt isGuard)) lstmts
+       ; a <- res
+       ; return (lstmts', a) }
 
 addTickStmt :: (Maybe (Bool -> BoxLabel)) -> Stmt Id -> TM (Stmt Id)
+addTickStmt _isGuard (LastStmt e ret) = do
+       liftM2 LastStmt
+               (addTickLHsExpr e)
+               (addTickSyntaxExpr hpcSrcSpan ret)
 addTickStmt _isGuard (BindStmt pat e bind fail) = do
        liftM4 BindStmt
                (addTickLPat pat)
                (addTickLHsExprAlways e)
                (addTickSyntaxExpr hpcSrcSpan bind)
                (addTickSyntaxExpr hpcSrcSpan fail)
-addTickStmt isGuard (ExprStmt e bind' ty) = do
-       liftM3 ExprStmt
+addTickStmt isGuard (ExprStmt e bind' guard' ty) = do
+       liftM4 ExprStmt
                (addTick isGuard e)
                (addTickSyntaxExpr hpcSrcSpan bind')
+               (addTickSyntaxExpr hpcSrcSpan guard')
                (return ty)
 addTickStmt _isGuard (LetStmt binds) = do
        liftM LetStmt
                (addTickHsLocalBinds binds)
-addTickStmt isGuard (ParStmt pairs) = do
-    liftM ParStmt 
+addTickStmt isGuard (ParStmt pairs mzipExpr bindExpr returnExpr) = do
+    liftM4 ParStmt 
         (mapM (addTickStmtAndBinders isGuard) pairs)
-
-addTickStmt isGuard (TransformStmt stmts ids usingExpr maybeByExpr) = do
-    liftM4 TransformStmt 
-        (addTickLStmts isGuard stmts)
-        (return ids)
-        (addTickLHsExprAlways usingExpr)
-        (addTickMaybeByLHsExpr maybeByExpr)
-
-addTickStmt isGuard (GroupStmt stmts binderMap by using) = do
-    liftM4 GroupStmt 
-        (addTickLStmts isGuard stmts)
-        (return binderMap)
-        (fmapMaybeM  addTickLHsExprAlways by)
-       (fmapEitherM addTickLHsExprAlways (addTickSyntaxExpr hpcSrcSpan) using)
+        (addTickSyntaxExpr hpcSrcSpan mzipExpr)
+        (addTickSyntaxExpr hpcSrcSpan bindExpr)
+        (addTickSyntaxExpr hpcSrcSpan returnExpr)
+
+addTickStmt isGuard stmt@(TransStmt { trS_stmts = stmts
+                                    , trS_by = by, trS_using = using
+                                    , trS_ret = returnExpr, trS_bind = bindExpr
+                                    , trS_fmap = liftMExpr }) = do
+    t_s <- addTickLStmts isGuard stmts
+    t_y <- fmapMaybeM  addTickLHsExprAlways by
+    t_u <- addTickLHsExprAlways using
+    t_f <- addTickSyntaxExpr hpcSrcSpan returnExpr
+    t_b <- addTickSyntaxExpr hpcSrcSpan bindExpr
+    t_m <- addTickSyntaxExpr hpcSrcSpan liftMExpr
+    return $ stmt { trS_stmts = t_s, trS_by = t_y, trS_using = t_u
+                  , trS_ret = t_f, trS_bind = t_b, trS_fmap = t_m }
 
 addTickStmt isGuard stmt@(RecStmt {})
   = do { stmts' <- addTickLStmts isGuard (recS_stmts stmt)
@@ -483,12 +487,6 @@ addTickStmtAndBinders isGuard (stmts, ids) =
         (addTickLStmts isGuard stmts)
         (return ids)
 
-addTickMaybeByLHsExpr :: Maybe (LHsExpr Id) -> TM (Maybe (LHsExpr Id))
-addTickMaybeByLHsExpr maybeByExpr = 
-    case maybeByExpr of
-        Nothing -> return Nothing
-        Just byExpr -> addTickLHsExprAlways byExpr >>= (return . Just)
-
 addTickHsLocalBinds :: HsLocalBinds Id -> TM (HsLocalBinds Id)
 addTickHsLocalBinds (HsValBinds binds) = 
        liftM HsValBinds 
@@ -569,9 +567,9 @@ addTickHsCmd (HsLet binds c) =
        liftM2 HsLet
                (addTickHsLocalBinds binds) -- to think about: !patterns.
                 (addTickLHsCmd c)
-addTickHsCmd (HsDo cxt stmts last_exp srcloc) = do
-        (stmts', last_exp') <- addTickLCmdStmts' stmts (addTickLHsCmd last_exp)
-       return (HsDo cxt stmts' last_exp' srcloc)
+addTickHsCmd (HsDo cxt stmts srcloc)
+  = do { (stmts', _) <- addTickLCmdStmts' stmts (return ())
+       ; return (HsDo cxt stmts' srcloc) }
 
 addTickHsCmd (HsArrApp  e1 e2 ty1 arr_ty lr) = 
         liftM5 HsArrApp
@@ -635,10 +633,15 @@ addTickCmdStmt (BindStmt pat c bind fail) = do
                (addTickLHsCmd c)
                (return bind)
                (return fail)
-addTickCmdStmt (ExprStmt c bind' ty) = do
-       liftM3 ExprStmt
+addTickCmdStmt (LastStmt c ret) = do
+       liftM2 LastStmt
+               (addTickLHsCmd c)
+               (addTickSyntaxExpr hpcSrcSpan ret)
+addTickCmdStmt (ExprStmt c bind' guard' ty) = do
+       liftM4 ExprStmt
                (addTickLHsCmd c)
-               (return bind')
+               (addTickSyntaxExpr hpcSrcSpan bind')
+                (addTickSyntaxExpr hpcSrcSpan guard')
                (return ty)
 addTickCmdStmt (LetStmt binds) = do
        liftM LetStmt
index 58bf6b8..a5bf2b6 100644 (file)
@@ -541,8 +541,8 @@ dsCmd ids local_vars env_ids stack res_ty (HsLet binds body) = do
                         core_body,
         exprFreeVars core_binds `intersectVarSet` local_vars)
 
-dsCmd ids local_vars env_ids [] res_ty (HsDo _ctxt stmts body _)
-  = dsCmdDo ids local_vars env_ids res_ty stmts body
+dsCmd ids local_vars env_ids [] res_ty (HsDo _ctxt stmts _)
+  = dsCmdDo ids local_vars env_ids res_ty stmts 
 
 --     A |- e :: forall e. a1 (e*ts1) t1 -> ... an (e*tsn) tn -> a (e*ts) t
 --     A | xs |- ci :: [tsi] ti
@@ -618,7 +618,6 @@ dsCmdDo :: DsCmdEnv         -- arrow combinators
                                -- so don't pull on it too early
        -> Type                 -- return type of the statement
        -> [LStmt Id]           -- statements to desugar
-       -> LHsExpr Id           -- body
        -> DsM (CoreExpr,       -- desugared expression
                IdSet)          -- set of local vars that occur free
 
@@ -626,15 +625,17 @@ dsCmdDo :: DsCmdEnv               -- arrow combinators
 --     --------------------------
 --     A | xs |- do { c } :: [] t
 
-dsCmdDo ids local_vars env_ids res_ty [] body
+dsCmdDo _ _ _ _ [] = panic "dsCmdDo"
+
+dsCmdDo ids local_vars env_ids res_ty [L _ (LastStmt body _)]
   = dsLCmd ids local_vars env_ids [] res_ty body
 
-dsCmdDo ids local_vars env_ids res_ty (stmt:stmts) body = do
+dsCmdDo ids local_vars env_ids res_ty (stmt:stmts) = do
     let
         bound_vars = mkVarSet (collectLStmtBinders stmt)
         local_vars' = local_vars `unionVarSet` bound_vars
     (core_stmts, _, env_ids') <- fixDs (\ ~(_,_,env_ids') -> do
-        (core_stmts, fv_stmts) <- dsCmdDo ids local_vars' env_ids' res_ty stmts body
+        (core_stmts, fv_stmts) <- dsCmdDo ids local_vars' env_ids' res_ty stmts 
         return (core_stmts, fv_stmts, varSetElems fv_stmts))
     (core_stmt, fv_stmt) <- dsCmdLStmt ids local_vars env_ids env_ids' stmt
     return (do_compose ids
@@ -674,7 +675,7 @@ dsCmdStmt
 --             ---> arr (\ (xs) -> ((xs1),(xs'))) >>> first c >>>
 --                     arr snd >>> ss
 
-dsCmdStmt ids local_vars env_ids out_ids (ExprStmt cmd _ c_ty) = do
+dsCmdStmt ids local_vars env_ids out_ids (ExprStmt cmd _ _ c_ty) = do
     (core_cmd, fv_cmd, env_ids1) <- dsfixCmd ids local_vars [] c_ty cmd
     core_mux <- matchEnvStack env_ids []
         (mkCorePairExpr (mkBigCoreVarTup env_ids1) (mkBigCoreVarTup out_ids))
index 1781aef..4088e44 100644 (file)
@@ -325,26 +325,12 @@ dsExpr (HsLet binds body) = do
 -- We need the `ListComp' form to use `deListComp' (rather than the "do" form)
 -- because the interpretation of `stmts' depends on what sort of thing it is.
 --
-dsExpr (HsDo ListComp stmts body result_ty)
-  =    -- Special case for list comprehensions
-    dsListComp stmts body elt_ty
-  where
-    [elt_ty] = tcTyConAppArgs result_ty
-
-dsExpr (HsDo DoExpr stmts body result_ty)
-  = dsDo stmts body result_ty
-
-dsExpr (HsDo GhciStmt stmts body result_ty)
-  = dsDo stmts body result_ty
-
-dsExpr (HsDo MDoExpr stmts body result_ty)
-  = dsDo stmts body result_ty
-
-dsExpr (HsDo PArrComp stmts body result_ty)
-  =    -- Special case for array comprehensions
-    dsPArrComp (map unLoc stmts) body elt_ty
-  where
-    [elt_ty] = tcTyConAppArgs result_ty
+dsExpr (HsDo ListComp  stmts res_ty) = dsListComp stmts res_ty
+dsExpr (HsDo PArrComp  stmts _)      = dsPArrComp (map unLoc stmts)
+dsExpr (HsDo DoExpr    stmts _)      = dsDo stmts 
+dsExpr (HsDo GhciStmt  stmts _)      = dsDo stmts 
+dsExpr (HsDo MDoExpr   stmts _)      = dsDo stmts 
+dsExpr (HsDo MonadComp stmts _)      = dsMonadComp stmts
 
 dsExpr (HsIf mb_fun guard_expr then_expr else_expr)
   = do { pred <- dsLExpr guard_expr
@@ -708,25 +694,20 @@ handled in DsListComp).  Basically does the translation given in the
 Haskell 98 report:
 
 \begin{code}
-dsDo   :: [LStmt Id]
-       -> LHsExpr Id
-       -> Type                 -- Type of the whole expression
-       -> DsM CoreExpr
-
-dsDo stmts body result_ty
+dsDo :: [LStmt Id] -> DsM CoreExpr
+dsDo stmts
   = goL stmts
   where
-    -- result_ty must be of the form (m b)
-    (m_ty, _b_ty) = tcSplitAppTy result_ty
-
-    goL [] = dsLExpr body
-    goL ((L loc stmt):lstmts) = putSrcSpanDs loc (go loc stmt lstmts)
+    goL [] = panic "dsDo"
+    goL (L loc stmt:lstmts) = putSrcSpanDs loc (go loc stmt lstmts)
   
-    go _ (ExprStmt rhs then_expr _) stmts
+    go _ (LastStmt body _) stmts
+      = ASSERT( null stmts ) dsLExpr body
+        -- The 'return' op isn't used for 'do' expressions
+
+    go _ (ExprStmt rhs then_expr _ _) stmts
       = do { rhs2 <- dsLExpr rhs
-           ; case tcSplitAppTy_maybe (exprType rhs2) of
-                Just (container_ty, returning_ty) -> warnDiscardedDoBindings rhs container_ty returning_ty
-                _                                 -> return ()
+           ; warnDiscardedDoBindings rhs (exprType rhs2) 
            ; then_expr2 <- dsExpr then_expr
           ; rest <- goL stmts
           ; return (mkApps then_expr2 [rhs2, rest]) }
@@ -750,29 +731,29 @@ dsDo stmts body result_ty
     go loc (RecStmt { recS_stmts = rec_stmts, recS_later_ids = later_ids
                     , recS_rec_ids = rec_ids, recS_ret_fn = return_op
                     , recS_mfix_fn = mfix_op, recS_bind_fn = bind_op
-                    , recS_rec_rets = rec_rets }) stmts
+                    , recS_rec_rets = rec_rets, recS_ret_ty = body_ty }) stmts
       = ASSERT( length rec_ids > 0 )
         goL (new_bind_stmt : stmts)
       where
-        -- returnE <- dsExpr return_id
-        -- mfixE <- dsExpr mfix_id
-        new_bind_stmt = L loc $ BindStmt (mkLHsPatTup later_pats) mfix_app
-                                         bind_op 
+        new_bind_stmt = L loc $ BindStmt (mkLHsPatTup later_pats)
+                                         mfix_app bind_op 
                                          noSyntaxExpr  -- Tuple cannot fail
 
         tup_ids      = rec_ids ++ filterOut (`elem` rec_ids) later_ids
+        tup_ty       = mkBoxedTupleTy (map idType tup_ids) -- Deals with singleton case
         rec_tup_pats = map nlVarPat tup_ids
         later_pats   = rec_tup_pats
         rets         = map noLoc rec_rets
-
-        mfix_app   = nlHsApp (noLoc mfix_op) mfix_arg
-        mfix_arg   = noLoc $ HsLam (MatchGroup [mkSimpleMatch [mfix_pat] body]
-                                             (mkFunTy tup_ty body_ty))
-        mfix_pat   = noLoc $ LazyPat $ mkLHsPatTup rec_tup_pats
-        body       = noLoc $ HsDo DoExpr rec_stmts return_app body_ty
-        return_app = nlHsApp (noLoc return_op) (mkLHsTupleExpr rets)
-       body_ty    = mkAppTy m_ty tup_ty
-        tup_ty     = mkBoxedTupleTy (map idType tup_ids) -- Deals with singleton case
+        mfix_app     = nlHsApp (noLoc mfix_op) mfix_arg
+        mfix_arg     = noLoc $ HsLam (MatchGroup [mkSimpleMatch [mfix_pat] body]
+                                                 (mkFunTy tup_ty body_ty))
+        mfix_pat     = noLoc $ LazyPat $ mkLHsPatTup rec_tup_pats
+        body         = noLoc $ HsDo DoExpr (rec_stmts ++ [ret_stmt]) body_ty
+        ret_app      = nlHsApp (noLoc return_op) (mkLHsTupleExpr rets)
+        ret_stmt     = noLoc $ mkLastStmt ret_app
+                    -- This LastStmt will be desugared with dsDo, 
+                    -- which ignores the return_op in the LastStmt,
+                    -- so we must apply the return_op explicitly 
 
 handle_failure :: LPat Id -> MatchResult -> SyntaxExpr Id -> DsM CoreExpr
     -- In a do expression, pattern-match failure just calls
@@ -790,104 +771,6 @@ mk_fail_msg pat = "Pattern match failure in do expression at " ++
                  showSDoc (ppr (getLoc pat))
 \end{code}
 
-Translation for RecStmt's: 
------------------------------
-We turn (RecStmt [v1,..vn] stmts) into:
-  
-  (v1,..,vn) <- mfix (\~(v1,..vn). do stmts
-                                     return (v1,..vn))
-
-\begin{code}
-{-
-dsMDo   :: HsStmtContext Name
-        -> [(Name,Id)]
-       -> [LStmt Id]
-       -> LHsExpr Id
-       -> Type                 -- Type of the whole expression
-       -> DsM CoreExpr
-
-dsMDo ctxt tbl stmts body result_ty
-  = goL stmts
-  where
-    goL [] = dsLExpr body
-    goL ((L loc stmt):lstmts) = putSrcSpanDs loc (go loc stmt lstmts)
-  
-    (m_ty, b_ty) = tcSplitAppTy result_ty      -- result_ty must be of the form (m b)
-    return_id = lookupEvidence tbl returnMName
-    bind_id   = lookupEvidence tbl bindMName
-    then_id   = lookupEvidence tbl thenMName
-    fail_id   = lookupEvidence tbl failMName
-
-    go _ (LetStmt binds) stmts
-      = do { rest <- goL stmts
-          ; dsLocalBinds binds rest }
-
-    go _ (ExprStmt rhs then_expr rhs_ty) stmts
-      = do { rhs2 <- dsLExpr rhs
-          ; warnDiscardedDoBindings rhs m_ty rhs_ty
-           ; then_expr2 <- dsExpr then_expr
-           ; rest <- goL stmts
-           ; return (mkApps then_expr2 [rhs2, rest]) }
-    
-    go _ (BindStmt pat rhs bind_op _) stmts
-      = do { body     <- goL stmts
-           ; rhs'     <- dsLExpr rhs
-           ; bind_op' <- dsExpr bind_op
-           ; var   <- selectSimpleMatchVarL pat
-          ; match <- matchSinglePat (Var var) (StmtCtxt ctxt) pat
-                                     result_ty (cantFailMatchResult body)
-           ; match_code <- handle_failure pat match fail_op
-           ; return (mkApps bind_op [rhs', Lam var match_code]) }
-    
-    go loc (RecStmt { recS_stmts = rec_stmts, recS_later_ids = later_ids
-                    , recS_rec_ids = rec_ids, recS_rec_rets = rec_rets
-                    , recS_mfix_fn = mfix_op, recS_bind_fn = bind_op }) stmts
-      = ASSERT( length rec_ids > 0 )
-        ASSERT( length rec_ids == length rec_rets )
-        ASSERT( isEmptyTcEvBinds _ev_binds )
-        pprTrace "dsMDo" (ppr later_ids) $
-        goL (new_bind_stmt : stmts)
-      where
-        new_bind_stmt = L loc $ BindStmt (mk_tup_pat later_pats) mfix_app
-                                         bind_op noSyntaxExpr
-       
-               -- Remove the later_ids that appear (without fancy coercions) 
-               -- in rec_rets, because there's no need to knot-tie them separately
-               -- See Note [RecStmt] in HsExpr
-       later_ids'   = filter (`notElem` mono_rec_ids) later_ids
-       mono_rec_ids = [ id | HsVar id <- rec_rets ]
-    
-        mfix_app = nlHsApp (noLoc mfix_op) mfix_arg
-       mfix_arg = noLoc $ HsLam (MatchGroup [mkSimpleMatch [mfix_pat] body]
-                                            (mkFunTy tup_ty body_ty))
-
-       -- The rec_tup_pat must bind the rec_ids only; remember that the 
-       --      trimmed_laters may share the same Names
-       -- Meanwhile, the later_pats must bind the later_vars
-       rec_tup_pats = map mk_wild_pat later_ids' ++ map nlVarPat rec_ids
-       later_pats   = map nlVarPat    later_ids' ++ map mk_later_pat rec_ids
-       rets         = map nlHsVar     later_ids' ++ map noLoc rec_rets
-
-       mfix_pat = noLoc $ LazyPat $ mk_tup_pat rec_tup_pats
-       body     = noLoc $ HsDo ctxt rec_stmts return_app body_ty
-       body_ty = mkAppTy m_ty tup_ty
-       tup_ty  = mkBoxedTupleTy (map idType (later_ids' ++ rec_ids))  -- Deals with singleton case
-
-        return_app  = nlHsApp (noLoc return_op) (mkLHsTupleExpr rets)
-
-       mk_wild_pat :: Id -> LPat Id 
-       mk_wild_pat v = noLoc $ WildPat $ idType v
-
-       mk_later_pat :: Id -> LPat Id
-       mk_later_pat v | v `elem` later_ids' = mk_wild_pat v
-                      | otherwise           = nlVarPat v
-
-       mk_tup_pat :: [LPat Id] -> LPat Id
-       mk_tup_pat [p] = p
-       mk_tup_pat ps  = noLoc $ mkVanillaTuplePat ps Boxed
--}
-\end{code}
-
 
 %************************************************************************
 %*                                                                     *
@@ -927,30 +810,34 @@ conversionNames
 
 \begin{code}
 -- Warn about certain types of values discarded in monadic bindings (#3263)
-warnDiscardedDoBindings :: LHsExpr Id -> Type -> Type -> DsM ()
-warnDiscardedDoBindings rhs container_ty returning_ty = do {
-          -- Warn about discarding non-() things in 'monadic' binding
-        ; warn_unused <- doptDs Opt_WarnUnusedDoBind
-        ; if warn_unused && not (returning_ty `tcEqType` unitTy)
-           then warnDs (unusedMonadBind rhs returning_ty)
-           else do {
-          -- Warn about discarding m a things in 'monadic' binding of the same type,
-          -- but only if we didn't already warn due to Opt_WarnUnusedDoBind
-        ; warn_wrong <- doptDs Opt_WarnWrongDoBind
-        ; case tcSplitAppTy_maybe returning_ty of
-                  Just (returning_container_ty, _) -> when (warn_wrong && container_ty `tcEqType` returning_container_ty) $
-                                                            warnDs (wrongMonadBind rhs returning_ty)
-                  _ -> return () } }
+warnDiscardedDoBindings :: LHsExpr Id -> Type -> DsM ()
+warnDiscardedDoBindings rhs rhs_ty
+  | Just (m_ty, elt_ty) <- tcSplitAppTy_maybe rhs_ty
+  = do {  -- Warn about discarding non-() things in 'monadic' binding
+       ; warn_unused <- doptDs Opt_WarnUnusedDoBind
+       ; if warn_unused && not (isUnitTy elt_ty)
+         then warnDs (unusedMonadBind rhs elt_ty)
+         else 
+         -- Warn about discarding m a things in 'monadic' binding of the same type,
+         -- but only if we didn't already warn due to Opt_WarnUnusedDoBind
+    do { warn_wrong <- doptDs Opt_WarnWrongDoBind
+       ; case tcSplitAppTy_maybe elt_ty of
+           Just (elt_m_ty, _) | warn_wrong, m_ty `tcEqType` elt_m_ty
+                              -> warnDs (wrongMonadBind rhs elt_ty)
+           _ -> return () } }
+
+  | otherwise  -- RHS does have type of form (m ty), which is wierd
+  = return ()   -- but at lesat this warning is irrelevant
 
 unusedMonadBind :: LHsExpr Id -> Type -> SDoc
-unusedMonadBind rhs returning_ty
-  = ptext (sLit "A do-notation statement discarded a result of type") <+> ppr returning_ty <> dot $$
+unusedMonadBind rhs elt_ty
+  = ptext (sLit "A do-notation statement discarded a result of type") <+> ppr elt_ty <> dot $$
     ptext (sLit "Suppress this warning by saying \"_ <- ") <> ppr rhs <> ptext (sLit "\",") $$
     ptext (sLit "or by using the flag -fno-warn-unused-do-bind")
 
 wrongMonadBind :: LHsExpr Id -> Type -> SDoc
-wrongMonadBind rhs returning_ty
-  = ptext (sLit "A do-notation statement discarded a result of type") <+> ppr returning_ty <> dot $$
+wrongMonadBind rhs elt_ty
+  = ptext (sLit "A do-notation statement discarded a result of type") <+> ppr elt_ty <> dot $$
     ptext (sLit "Suppress this warning by saying \"_ <- ") <> ppr rhs <> ptext (sLit "\",") $$
     ptext (sLit "or by using the flag -fno-warn-wrong-do-bind")
 \end{code}
index a7260e2..d3fcf76 100644 (file)
@@ -106,11 +106,11 @@ matchGuards [] _ rhs _
        -- NB:  The success of this clause depends on the typechecker not
        --      wrapping the 'otherwise' in empty HsTyApp or HsWrap constructors
        --      If it does, you'll get bogus overlap warnings
-matchGuards (ExprStmt e _ _ : stmts) ctx rhs rhs_ty
+matchGuards (ExprStmt e _ _ _ : stmts) ctx rhs rhs_ty
   | Just addTicks <- isTrueLHsExpr e = do
     match_result <- matchGuards stmts ctx rhs rhs_ty
     return (adjustMatchResultDs addTicks match_result)
-matchGuards (ExprStmt expr _ _ : stmts) ctx rhs rhs_ty = do
+matchGuards (ExprStmt expr _ _ _ : stmts) ctx rhs rhs_ty = do
     match_result <- matchGuards stmts ctx rhs rhs_ty
     pred_expr <- dsLExpr expr
     return (mkGuardedMatchResult pred_expr match_result)
index cd22b8f..aabd6b0 100644 (file)
@@ -3,9 +3,10 @@
 % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
 %
 
-Desugaring list comprehensions and array comprehensions
+Desugaring list comprehensions, monad comprehensions and array comprehensions
 
 \begin{code}
+{-# LANGUAGE NamedFieldPuns #-}
 {-# OPTIONS -fno-warn-incomplete-patterns #-}
 -- The above warning supression flag is a temporary kludge.
 -- While working on this module you are encouraged to remove it and fix
@@ -13,11 +14,11 @@ Desugaring list comprehensions and array comprehensions
 --     http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings
 -- for details
 
-module DsListComp ( dsListComp, dsPArrComp ) where
+module DsListComp ( dsListComp, dsPArrComp, dsMonadComp ) where
 
 #include "HsVersions.h"
 
-import {-# SOURCE #-} DsExpr ( dsLExpr, dsLocalBinds )
+import {-# SOURCE #-} DsExpr ( dsExpr, dsLExpr, dsLocalBinds )
 
 import HsSyn
 import TcHsSyn
@@ -37,6 +38,7 @@ import PrelNames
 import SrcLoc
 import Outputable
 import FastString
+import TcType
 \end{code}
 
 List comprehensions may be desugared in one of two ways: ``ordinary''
@@ -47,12 +49,14 @@ There will be at least one ``qualifier'' in the input.
 
 \begin{code}
 dsListComp :: [LStmt Id] 
-          -> LHsExpr Id
-          -> Type              -- Type of list elements
+          -> Type              -- Type of entire list 
           -> DsM CoreExpr
-dsListComp lquals body elt_ty = do 
+dsListComp lquals res_ty = do 
     dflags <- getDOptsDs
     let quals = map unLoc lquals
+        elt_ty = case tcTyConAppArgs res_ty of
+                   [elt_ty] -> elt_ty
+                   _ -> pprPanic "dsListComp" (ppr res_ty $$ ppr lquals)
     
     if not (dopt Opt_EnableRewriteRules dflags) || dopt Opt_IgnoreInterfacePragmas dflags
        -- Either rules are switched off, or we are ignoring what there are;
@@ -60,8 +64,8 @@ dsListComp lquals body elt_ty = do
        -- Wadler-style desugaring
        || isParallelComp quals
        -- Foldr-style desugaring can't handle parallel list comprehensions
-        then deListComp quals body (mkNilExpr elt_ty)
-        else mkBuildExpr elt_ty (\(c, _) (n, _) -> dfListComp c n quals body) 
+        then deListComp quals (mkNilExpr elt_ty)
+        else mkBuildExpr elt_ty (\(c, _) (n, _) -> dfListComp c n quals) 
              -- Foldr/build should be enabled, so desugar 
              -- into foldrs and builds
 
@@ -72,92 +76,69 @@ dsListComp lquals body elt_ty = do
     -- mix of possibly a single element in length, so we do this to leave the possibility open
     isParallelComp = any isParallelStmt
   
-    isParallelStmt (ParStmt _) = True
-    isParallelStmt _           = False
+    isParallelStmt (ParStmt _ _ _ _) = True
+    isParallelStmt _                 = False
     
     
 -- This function lets you desugar a inner list comprehension and a list of the binders
 -- of that comprehension that we need in the outer comprehension into such an expression
 -- and the type of the elements that it outputs (tuples of binders)
 dsInnerListComp :: ([LStmt Id], [Id]) -> DsM (CoreExpr, Type)
-dsInnerListComp (stmts, bndrs) = do
-        expr <- dsListComp stmts (mkBigLHsVarTup bndrs) bndrs_tuple_type
-        return (expr, bndrs_tuple_type)
-    where
-        bndrs_types = map idType bndrs
-        bndrs_tuple_type = mkBigCoreTupTy bndrs_types
-        
+dsInnerListComp (stmts, bndrs)
+  = do { expr <- dsListComp (stmts ++ [noLoc $ mkLastStmt (mkBigLHsVarTup bndrs)]) 
+                            (mkListTy bndrs_tuple_type)
+       ; return (expr, bndrs_tuple_type) }
+  where
+    bndrs_tuple_type = mkBigCoreVarTupTy bndrs
         
--- This function factors out commonality between the desugaring strategies for TransformStmt.
--- Given such a statement it gives you back an expression representing how to compute the transformed
--- list and the tuple that you need to bind from that list in order to proceed with your desugaring
-dsTransformStmt :: Stmt Id -> DsM (CoreExpr, LPat Id)
-dsTransformStmt (TransformStmt stmts binders usingExpr maybeByExpr)
- = do { (expr, binders_tuple_type) <- dsInnerListComp (stmts, binders)
-      ; usingExpr' <- dsLExpr usingExpr
-    
-      ; using_args <-
-          case maybeByExpr of
-            Nothing -> return [expr]
-            Just byExpr -> do
-                byExpr' <- dsLExpr byExpr
-                
-                us <- newUniqueSupply
-                [tuple_binder] <- newSysLocalsDs [binders_tuple_type]
-                let byExprWrapper = mkTupleCase us binders byExpr' tuple_binder (Var tuple_binder)
-                
-                return [Lam tuple_binder byExprWrapper, expr]
-
-      ; let inner_list_expr = mkApps usingExpr' ((Type binders_tuple_type) : using_args)
-            pat = mkBigLHsVarPatTup binders
-      ; return (inner_list_expr, pat) }
-    
 -- This function factors out commonality between the desugaring strategies for GroupStmt.
 -- Given such a statement it gives you back an expression representing how to compute the transformed
 -- list and the tuple that you need to bind from that list in order to proceed with your desugaring
-dsGroupStmt :: Stmt Id -> DsM (CoreExpr, LPat Id)
-dsGroupStmt (GroupStmt stmts binderMap by using) = do
-    let (fromBinders, toBinders) = unzip binderMap
-        
-        fromBindersTypes = map idType fromBinders
-        toBindersTypes = map idType toBinders
-        
-        toBindersTupleType = mkBigCoreTupTy toBindersTypes
+dsTransStmt :: Stmt Id -> DsM (CoreExpr, LPat Id)
+dsTransStmt (TransStmt { trS_form = form, trS_stmts = stmts, trS_bndrs = binderMap
+                       , trS_by = by, trS_using = using }) = do
+    let (from_bndrs, to_bndrs) = unzip binderMap
+        from_bndrs_tys  = map idType from_bndrs
+        to_bndrs_tys    = map idType to_bndrs
+        to_bndrs_tup_ty = mkBigCoreTupTy to_bndrs_tys
     
     -- Desugar an inner comprehension which outputs a list of tuples of the "from" binders
-    (expr, from_tup_ty) <- dsInnerListComp (stmts, fromBinders)
+    (expr, from_tup_ty) <- dsInnerListComp (stmts, from_bndrs)
     
     -- Work out what arguments should be supplied to that expression: i.e. is an extraction
     -- function required? If so, create that desugared function and add to arguments
-    usingExpr' <- dsLExpr (either id noLoc using)
+    usingExpr' <- dsLExpr using
     usingArgs <- case by of
                    Nothing   -> return [expr]
                   Just by_e -> do { by_e' <- dsLExpr by_e
-                                   ; us <- newUniqueSupply
-                                   ; [from_tup_id] <- newSysLocalsDs [from_tup_ty]
-                                   ; let by_wrap = mkTupleCase us fromBinders by_e' 
-                                                   from_tup_id (Var from_tup_id)
-                                   ; return [Lam from_tup_id by_wrap, expr] }
+                                   ; lam <- matchTuple from_bndrs by_e'
+                                   ; return [lam, expr] }
     
     -- Create an unzip function for the appropriate arity and element types and find "map"
-    (unzip_fn, unzip_rhs) <- mkUnzipBind fromBindersTypes
+    unzip_stuff <- mkUnzipBind form from_bndrs_tys
     map_id <- dsLookupGlobalId mapName
 
     -- Generate the expressions to build the grouped list
     let -- First we apply the grouping function to the inner list
-        inner_list_expr = mkApps usingExpr' ((Type from_tup_ty) : usingArgs)
+        inner_list_expr = mkApps usingExpr' usingArgs
         -- Then we map our "unzip" across it to turn the lists of tuples into tuples of lists
         -- We make sure we instantiate the type variable "a" to be a list of "from" tuples and
         -- the "b" to be a tuple of "to" lists!
-        unzipped_inner_list_expr = mkApps (Var map_id) 
-            [Type (mkListTy from_tup_ty), Type toBindersTupleType, Var unzip_fn, inner_list_expr]
         -- Then finally we bind the unzip function around that expression
-        bound_unzipped_inner_list_expr = Let (Rec [(unzip_fn, unzip_rhs)]) unzipped_inner_list_expr
-    
-    -- Build a pattern that ensures the consumer binds into the NEW binders, which hold lists rather than single values
-    let pat = mkBigLHsVarPatTup toBinders
+        bound_unzipped_inner_list_expr 
+          = case unzip_stuff of
+              Nothing -> inner_list_expr
+              Just (unzip_fn, unzip_rhs) -> Let (Rec [(unzip_fn, unzip_rhs)]) $
+                                            mkApps (Var map_id) $
+                                            [ Type (mkListTy from_tup_ty)
+                                            , Type to_bndrs_tup_ty
+                                            , Var unzip_fn
+                                            , inner_list_expr]
+
+    -- Build a pattern that ensures the consumer binds into the NEW binders, 
+    -- which hold lists rather than single values
+    let pat = mkBigLHsVarPatTup to_bndrs
     return (bound_unzipped_inner_list_expr, pat)
-    
 \end{code}
 
 %************************************************************************
@@ -226,53 +207,50 @@ with the Unboxed variety.
 
 \begin{code}
 
-deListComp :: [Stmt Id] -> LHsExpr Id -> CoreExpr -> DsM CoreExpr
-
-deListComp (ParStmt stmtss_w_bndrs : quals) body list
-  = do
-    exps_and_qual_tys <- mapM dsInnerListComp stmtss_w_bndrs
-    let (exps, qual_tys) = unzip exps_and_qual_tys
-    
-    (zip_fn, zip_rhs) <- mkZipBind qual_tys
+deListComp :: [Stmt Id] -> CoreExpr -> DsM CoreExpr
 
-       -- Deal with [e | pat <- zip l1 .. ln] in example above
-    deBindComp pat (Let (Rec [(zip_fn, zip_rhs)]) (mkApps (Var zip_fn) exps)) 
-                  quals body list
+deListComp [] _ = panic "deListComp"
 
-  where 
-       bndrs_s = map snd stmtss_w_bndrs
-
-       -- pat is the pattern ((x1,..,xn), (y1,..,ym)) in the example above
-       pat  = mkBigLHsPatTup pats
-       pats = map mkBigLHsVarPatTup bndrs_s
-
-       -- Last: the one to return
-deListComp [] body list = do    -- Figure 7.4, SLPJ, p 135, rule C above
-    core_body <- dsLExpr body
-    return (mkConsExpr (exprType core_body) core_body list)
+deListComp (LastStmt body _ : quals) list 
+  =     -- Figure 7.4, SLPJ, p 135, rule C above
+    ASSERT( null quals )
+    do { core_body <- dsLExpr body
+       ; return (mkConsExpr (exprType core_body) core_body list) }
 
        -- Non-last: must be a guard
-deListComp (ExprStmt guard _ _ : quals) body list = do  -- rule B above
+deListComp (ExprStmt guard _ _ _ : quals) list = do  -- rule B above
     core_guard <- dsLExpr guard
-    core_rest <- deListComp quals body list
+    core_rest <- deListComp quals list
     return (mkIfThenElse core_guard core_rest list)
 
 -- [e | let B, qs] = let B in [e | qs]
-deListComp (LetStmt binds : quals) body list = do
-    core_rest <- deListComp quals body list
+deListComp (LetStmt binds : quals) list = do
+    core_rest <- deListComp quals list
     dsLocalBinds binds core_rest
 
-deListComp (stmt@(TransformStmt {}) : quals) body list = do
-    (inner_list_expr, pat) <- dsTransformStmt stmt
-    deBindComp pat inner_list_expr quals body list
+deListComp (stmt@(TransStmt {}) : quals) list = do
+    (inner_list_expr, pat) <- dsTransStmt stmt
+    deBindComp pat inner_list_expr quals list
 
-deListComp (stmt@(GroupStmt {}) : quals) body list = do
-    (inner_list_expr, pat) <- dsGroupStmt stmt
-    deBindComp pat inner_list_expr quals body list
-
-deListComp (BindStmt pat list1 _ _ : quals) body core_list2 = do -- rule A' above
+deListComp (BindStmt pat list1 _ _ : quals) core_list2 = do -- rule A' above
     core_list1 <- dsLExpr list1
-    deBindComp pat core_list1 quals body core_list2
+    deBindComp pat core_list1 quals core_list2
+
+deListComp (ParStmt stmtss_w_bndrs _ _ _ : quals) list
+  = do { exps_and_qual_tys <- mapM dsInnerListComp stmtss_w_bndrs
+       ; let (exps, qual_tys) = unzip exps_and_qual_tys
+    
+       ; (zip_fn, zip_rhs) <- mkZipBind qual_tys
+
+       -- Deal with [e | pat <- zip l1 .. ln] in example above
+       ; deBindComp pat (Let (Rec [(zip_fn, zip_rhs)]) (mkApps (Var zip_fn) exps)) 
+                   quals list }
+  where 
+       bndrs_s = map snd stmtss_w_bndrs
+
+       -- pat is the pattern ((x1,..,xn), (y1,..,ym)) in the example above
+       pat  = mkBigLHsPatTup pats
+       pats = map mkBigLHsVarPatTup bndrs_s
 \end{code}
 
 
@@ -280,10 +258,9 @@ deListComp (BindStmt pat list1 _ _ : quals) body core_list2 = do -- rule A' abov
 deBindComp :: OutPat Id
            -> CoreExpr
            -> [Stmt Id]
-           -> LHsExpr Id
            -> CoreExpr
            -> DsM (Expr Id)
-deBindComp pat core_list1 quals body core_list2 = do
+deBindComp pat core_list1 quals core_list2 = do
     let
         u3_ty@u1_ty = exprType core_list1      -- two names, same thing
 
@@ -300,7 +277,7 @@ deBindComp pat core_list1 quals body core_list2 = do
         core_fail   = App (Var h) (Var u3)
         letrec_body = App (Var h) core_list1
         
-    rest_expr <- deListComp quals body core_fail
+    rest_expr <- deListComp quals core_fail
     core_match <- matchSimply (Var u2) (StmtCtxt ListComp) pat rest_expr core_fail     
     
     let
@@ -335,48 +312,43 @@ TE[ e | p <- l , q ] c n = let
 \begin{code}
 dfListComp :: Id -> Id -- 'c' and 'n'
         -> [Stmt Id]   -- the rest of the qual's
-        -> LHsExpr Id
         -> DsM CoreExpr
 
-       -- Last: the one to return
-dfListComp c_id n_id [] body = do
-    core_body <- dsLExpr body
-    return (mkApps (Var c_id) [core_body, Var n_id])
+dfListComp _ _ [] = panic "dfListComp"
+
+dfListComp c_id n_id (LastStmt body _ : quals) 
+  = ASSERT( null quals )
+    do { core_body <- dsLExpr body
+       ; return (mkApps (Var c_id) [core_body, Var n_id]) }
 
        -- Non-last: must be a guard
-dfListComp c_id n_id (ExprStmt guard _ _  : quals) body = do
+dfListComp c_id n_id (ExprStmt guard _ _ _  : quals) = do
     core_guard <- dsLExpr guard
-    core_rest <- dfListComp c_id n_id quals body
+    core_rest <- dfListComp c_id n_id quals
     return (mkIfThenElse core_guard core_rest (Var n_id))
 
-dfListComp c_id n_id (LetStmt binds : quals) body = do
+dfListComp c_id n_id (LetStmt binds : quals) = do
     -- new in 1.3, local bindings
-    core_rest <- dfListComp c_id n_id quals body
+    core_rest <- dfListComp c_id n_id quals
     dsLocalBinds binds core_rest
 
-dfListComp c_id n_id (stmt@(TransformStmt {}) : quals) body = do
-    (inner_list_expr, pat) <- dsTransformStmt stmt
-    -- Anyway, we bind the newly transformed list via the generic binding function
-    dfBindComp c_id n_id (pat, inner_list_expr) quals body
-
-dfListComp c_id n_id (stmt@(GroupStmt {}) : quals) body = do
-    (inner_list_expr, pat) <- dsGroupStmt stmt
+dfListComp c_id n_id (stmt@(TransStmt {}) : quals) = do
+    (inner_list_expr, pat) <- dsTransStmt stmt
     -- Anyway, we bind the newly grouped list via the generic binding function
-    dfBindComp c_id n_id (pat, inner_list_expr) quals body
+    dfBindComp c_id n_id (pat, inner_list_expr) quals 
     
-dfListComp c_id n_id (BindStmt pat list1 _ _ : quals) body = do
+dfListComp c_id n_id (BindStmt pat list1 _ _ : quals) = do
     -- evaluate the two lists
     core_list1 <- dsLExpr list1
     
     -- Do the rest of the work in the generic binding builder
-    dfBindComp c_id n_id (pat, core_list1) quals body
+    dfBindComp c_id n_id (pat, core_list1) quals
                
 dfBindComp :: Id -> Id         -- 'c' and 'n'
        -> (LPat Id, CoreExpr)
           -> [Stmt Id]                 -- the rest of the qual's
-          -> LHsExpr Id
           -> DsM CoreExpr
-dfBindComp c_id n_id (pat, core_list1) quals body = do
+dfBindComp c_id n_id (pat, core_list1) quals = do
     -- find the required type
     let x_ty   = hsLPatType pat
         b_ty   = idType n_id
@@ -385,7 +357,7 @@ dfBindComp c_id n_id (pat, core_list1) quals body = do
     [b, x] <- newSysLocalsDs [b_ty, x_ty]
 
     -- build rest of the comprehesion
-    core_rest <- dfListComp c_id b quals body
+    core_rest <- dfListComp c_id b quals
 
     -- build the pattern match
     core_expr <- matchSimply (Var x) (StmtCtxt ListComp)
@@ -439,7 +411,7 @@ mkZipBind elt_tys = do
                        -- Increasing order of tag
             
             
-mkUnzipBind :: [Type] -> DsM (Id, CoreExpr)
+mkUnzipBind :: TransForm -> [Type] -> DsM (Maybe (Id, CoreExpr))
 -- mkUnzipBind [t1, t2] 
 -- = (unzip, \ys :: [(t1, t2)] -> foldr (\ax :: (t1, t2) axs :: ([t1], [t2])
 --     -> case ax of
@@ -449,28 +421,29 @@ mkUnzipBind :: [Type] -> DsM (Id, CoreExpr)
 --      ys)
 -- 
 -- We use foldr here in all cases, even if rules are turned off, because we may as well!
-mkUnzipBind elt_tys = do
-    ax  <- newSysLocalDs elt_tuple_ty
-    axs <- newSysLocalDs elt_list_tuple_ty
-    ys  <- newSysLocalDs elt_tuple_list_ty
-    xs  <- mapM newSysLocalDs elt_tys
-    xss <- mapM newSysLocalDs elt_list_tys
+mkUnzipBind ThenForm _
+ = return Nothing    -- No unzipping for ThenForm
+mkUnzipBind _ elt_tys 
+  = do { ax  <- newSysLocalDs elt_tuple_ty
+       ; axs <- newSysLocalDs elt_list_tuple_ty
+       ; ys  <- newSysLocalDs elt_tuple_list_ty
+       ; xs  <- mapM newSysLocalDs elt_tys
+       ; xss <- mapM newSysLocalDs elt_list_tys
     
-    unzip_fn <- newSysLocalDs unzip_fn_ty
-
-    [us1, us2] <- sequence [newUniqueSupply, newUniqueSupply]
-
-    let nil_tuple = mkBigCoreTup (map mkNilExpr elt_tys)
-        
-        concat_expressions = map mkConcatExpression (zip3 elt_tys (map Var xs) (map Var xss))
-        tupled_concat_expression = mkBigCoreTup concat_expressions
-        
-        folder_body_inner_case = mkTupleCase us1 xss tupled_concat_expression axs (Var axs)
-        folder_body_outer_case = mkTupleCase us2 xs folder_body_inner_case ax (Var ax)
-        folder_body = mkLams [ax, axs] folder_body_outer_case
-        
-    unzip_body <- mkFoldrExpr elt_tuple_ty elt_list_tuple_ty folder_body nil_tuple (Var ys)
-    return (unzip_fn, mkLams [ys] unzip_body)
+       ; unzip_fn <- newSysLocalDs unzip_fn_ty
+
+       ; [us1, us2] <- sequence [newUniqueSupply, newUniqueSupply]
+
+       ; let nil_tuple = mkBigCoreTup (map mkNilExpr elt_tys)
+            concat_expressions = map mkConcatExpression (zip3 elt_tys (map Var xs) (map Var xss))
+            tupled_concat_expression = mkBigCoreTup concat_expressions
+           
+            folder_body_inner_case = mkTupleCase us1 xss tupled_concat_expression axs (Var axs)
+            folder_body_outer_case = mkTupleCase us2 xs folder_body_inner_case ax (Var ax)
+            folder_body = mkLams [ax, axs] folder_body_outer_case
+           
+       ; unzip_body <- mkFoldrExpr elt_tuple_ty elt_list_tuple_ty folder_body nil_tuple (Var ys)
+       ; return (Just (unzip_fn, mkLams [ys] unzip_body)) }
   where
     elt_tuple_ty       = mkBigCoreTupTy elt_tys
     elt_tuple_list_ty  = mkListTy elt_tuple_ty
@@ -480,9 +453,6 @@ mkUnzipBind elt_tys = do
     unzip_fn_ty        = elt_tuple_list_ty `mkFunTy` elt_list_tuple_ty
             
     mkConcatExpression (list_element_ty, head, tail) = mkConsExpr list_element_ty head tail
-            
-            
-
 \end{code}
 
 %************************************************************************
@@ -498,11 +468,10 @@ mkUnzipBind elt_tys = do
 --   [:e | qss:] = <<[:e | qss:]>> () [:():]
 --
 dsPArrComp :: [Stmt Id] 
-            -> LHsExpr Id
-            -> Type                -- Don't use; called with `undefined' below
             -> DsM CoreExpr
-dsPArrComp [ParStmt qss] body _  =  -- parallel comprehension
-  dePArrParComp qss body
+
+-- Special case for parallel comprehension
+dsPArrComp (ParStmt qss _ _ _ : quals) = dePArrParComp qss quals
 
 -- Special case for simple generators:
 --
@@ -513,7 +482,7 @@ dsPArrComp [ParStmt qss] body _  =  -- parallel comprehension
 --  <<[:e' | p <- e, qs:]>> = 
 --    <<[:e' | qs:]>> p (filterP (\x -> case x of {p -> True; _ -> False}) e)
 --
-dsPArrComp (BindStmt p e _ _ : qs) body _ = do
+dsPArrComp (BindStmt p e _ _ : qs) = do
     filterP <- dsLookupDPHId filterPName
     ce <- dsLExpr e
     let ety'ce  = parrElemType ce
@@ -523,38 +492,41 @@ dsPArrComp (BindStmt p e _ _ : qs) body _ = do
     pred <- matchSimply (Var v) (StmtCtxt PArrComp) p true false
     let gen | isIrrefutableHsPat p = ce
             | otherwise            = mkApps (Var filterP) [Type ety'ce, mkLams [v] pred, ce]
-    dePArrComp qs body p gen
+    dePArrComp qs p gen
 
-dsPArrComp qs            body _  = do -- no ParStmt in `qs'
+dsPArrComp qs = do -- no ParStmt in `qs'
     sglP <- dsLookupDPHId singletonPName
     let unitArray = mkApps (Var sglP) [Type unitTy, mkCoreTup []]
-    dePArrComp qs body (noLoc $ WildPat unitTy) unitArray
+    dePArrComp qs (noLoc $ WildPat unitTy) unitArray
 
 
 
 -- the work horse
 --
 dePArrComp :: [Stmt Id] 
-          -> LHsExpr Id
           -> LPat Id           -- the current generator pattern
           -> CoreExpr          -- the current generator expression
           -> DsM CoreExpr
+
+dePArrComp [] _ _ = panic "dePArrComp"
+
 --
 --  <<[:e' | :]>> pa ea = mapP (\pa -> e') ea
 --
-dePArrComp [] e' pa cea = do
-    mapP <- dsLookupDPHId mapPName
-    let ty = parrElemType cea
-    (clam, ty'e') <- deLambda ty pa e'
-    return $ mkApps (Var mapP) [Type ty, Type ty'e', clam, cea]
+dePArrComp (LastStmt e' _ : quals) pa cea
+  = ASSERT( null quals )
+    do { mapP <- dsLookupDPHId mapPName
+       ; let ty = parrElemType cea
+       ; (clam, ty'e') <- deLambda ty pa e'
+       ; return $ mkApps (Var mapP) [Type ty, Type ty'e', clam, cea] }
 --
 --  <<[:e' | b, qs:]>> pa ea = <<[:e' | qs:]>> pa (filterP (\pa -> b) ea)
 --
-dePArrComp (ExprStmt b _ _ : qs) body pa cea = do
+dePArrComp (ExprStmt b _ _ _ : qs) pa cea = do
     filterP <- dsLookupDPHId filterPName
     let ty = parrElemType cea
     (clam,_) <- deLambda ty pa b
-    dePArrComp qs body pa (mkApps (Var filterP) [Type ty, clam, cea])
+    dePArrComp qs pa (mkApps (Var filterP) [Type ty, clam, cea])
 
 --
 --  <<[:e' | p <- e, qs:]>> pa ea =
@@ -569,7 +541,7 @@ dePArrComp (ExprStmt b _ _ : qs) body pa cea = do
 --    in
 --    <<[:e' | qs:]>> (pa, p) (crossMapP ea ef)
 --
-dePArrComp (BindStmt p e _ _ : qs) body pa cea = do
+dePArrComp (BindStmt p e _ _ : qs) pa cea = do
     filterP <- dsLookupDPHId filterPName
     crossMapP <- dsLookupDPHId crossMapPName
     ce <- dsLExpr e
@@ -585,7 +557,7 @@ dePArrComp (BindStmt p e _ _ : qs) body pa cea = do
     let ety'cef = ety'ce                   -- filter doesn't change the element type
         pa'     = mkLHsPatTup [pa, p]
 
-    dePArrComp qs body pa' (mkApps (Var crossMapP) 
+    dePArrComp qs pa' (mkApps (Var crossMapP) 
                                  [Type ety'cea, Type ety'cef, cea, clam])
 --
 --  <<[:e' | let ds, qs:]>> pa ea = 
@@ -594,7 +566,7 @@ dePArrComp (BindStmt p e _ _ : qs) body pa cea = do
 --  where
 --    {x_1, ..., x_n} = DV (ds)                -- Defined Variables
 --
-dePArrComp (LetStmt ds : qs) body pa cea = do
+dePArrComp (LetStmt ds : qs) pa cea = do
     mapP <- dsLookupDPHId mapPName
     let xs     = collectLocalBinders ds
         ty'cea = parrElemType cea
@@ -609,14 +581,14 @@ dePArrComp (LetStmt ds : qs) body pa cea = do
     ccase <- matchSimply (Var v) (StmtCtxt PArrComp) pa projBody cerr
     let pa'    = mkLHsPatTup [pa, mkLHsPatTup (map nlVarPat xs)]
         proj   = mkLams [v] ccase
-    dePArrComp qs body pa' (mkApps (Var mapP) 
+    dePArrComp qs pa' (mkApps (Var mapP) 
                                    [Type ty'cea, Type errTy, proj, cea])
 --
 -- The parser guarantees that parallel comprehensions can only appear as
 -- singeltons qualifier lists, which we already special case in the caller.
 -- So, encountering one here is a bug.
 --
-dePArrComp (ParStmt _ : _) _ _ _ = 
+dePArrComp (ParStmt _ _ _ _ : _) _ _ = 
   panic "DsListComp.dePArrComp: malformed comprehension AST"
 
 --  <<[:e' | qs | qss:]>> pa ea = 
@@ -625,17 +597,17 @@ dePArrComp (ParStmt _ : _) _ _ _ =
 --    where
 --      {x_1, ..., x_n} = DV (qs)
 --
-dePArrParComp :: [([LStmt Id], [Id])] -> LHsExpr Id -> DsM CoreExpr
-dePArrParComp qss body = do
+dePArrParComp :: [([LStmt Id], [Id])] -> [Stmt Id] -> DsM CoreExpr
+dePArrParComp qss quals = do
     (pQss, ceQss) <- deParStmt qss
-    dePArrComp [] body pQss ceQss
+    dePArrComp quals pQss ceQss
   where
     deParStmt []             =
       -- empty parallel statement lists have no source representation
       panic "DsListComp.dePArrComp: Empty parallel list comprehension"
     deParStmt ((qs, xs):qss) = do        -- first statement
       let res_expr = mkLHsVarTuple xs
-      cqs <- dsPArrComp (map unLoc qs) res_expr undefined
+      cqs <- dsPArrComp (map unLoc qs ++ [mkLastStmt res_expr])
       parStmts qss (mkLHsVarPatTup xs) cqs
     ---
     parStmts []             pa cea = return (pa, cea)
@@ -644,7 +616,7 @@ dePArrParComp qss body = do
       let pa'      = mkLHsPatTup [pa, mkLHsVarPatTup xs]
           ty'cea   = parrElemType cea
           res_expr = mkLHsVarTuple xs
-      cqs <- dsPArrComp (map unLoc qs) res_expr undefined
+      cqs <- dsPArrComp (map unLoc qs ++ [mkLastStmt res_expr])
       let ty'cqs = parrElemType cqs
           cea'   = mkApps (Var zipP) [Type ty'cea, Type ty'cqs, cea, cqs]
       parStmts qss pa' cea'
@@ -682,3 +654,222 @@ parrElemType e  =
     _                                                    -> panic
       "DsListComp.parrElemType: not a parallel array type"
 \end{code}
+
+Translation for monad comprehensions
+
+\begin{code}
+-- Entry point for monad comprehension desugaring
+dsMonadComp :: [LStmt Id] -> DsM CoreExpr
+dsMonadComp stmts = dsMcStmts stmts
+
+dsMcStmts :: [LStmt Id] -> DsM CoreExpr
+dsMcStmts []                    = panic "dsMcStmts"
+dsMcStmts (L loc stmt : lstmts) = putSrcSpanDs loc (dsMcStmt stmt lstmts)
+
+---------------
+dsMcStmt :: Stmt Id -> [LStmt Id] -> DsM CoreExpr
+
+dsMcStmt (LastStmt body ret_op) stmts
+  = ASSERT( null stmts )
+    do { body' <- dsLExpr body
+       ; ret_op' <- dsExpr ret_op
+       ; return (App ret_op' body') }
+
+--   [ .. | let binds, stmts ]
+dsMcStmt (LetStmt binds) stmts 
+  = do { rest <- dsMcStmts stmts
+       ; dsLocalBinds binds rest }
+
+--   [ .. | a <- m, stmts ]
+dsMcStmt (BindStmt pat rhs bind_op fail_op) stmts
+  = do { rhs' <- dsLExpr rhs
+       ; dsMcBindStmt pat rhs' bind_op fail_op stmts }
+
+-- Apply `guard` to the `exp` expression
+--
+--   [ .. | exp, stmts ]
+--
+dsMcStmt (ExprStmt exp then_exp guard_exp _) stmts 
+  = do { exp'       <- dsLExpr exp
+       ; guard_exp' <- dsExpr guard_exp
+       ; then_exp'  <- dsExpr then_exp
+       ; rest       <- dsMcStmts stmts
+       ; return $ mkApps then_exp' [ mkApps guard_exp' [exp']
+                                   , rest ] }
+
+-- Group statements desugar like this:
+--
+--   [| (q, then group by e using f); rest |]
+--   --->  f {qt} (\qv -> e) [| q; return qv |] >>= \ n_tup -> 
+--         case unzip n_tup of qv' -> [| rest |]
+--
+-- where   variables (v1:t1, ..., vk:tk) are bound by q
+--         qv = (v1, ..., vk)
+--         qt = (t1, ..., tk)
+--         (>>=) :: m2 a -> (a -> m3 b) -> m3 b
+--         f :: forall a. (a -> t) -> m1 a -> m2 (n a)
+--         n_tup :: n qt
+--         unzip :: n qt -> (n t1, ..., n tk)    (needs Functor n)
+
+dsMcStmt (TransStmt { trS_stmts = stmts, trS_bndrs = bndrs
+                    , trS_by = by, trS_using = using
+                    , trS_ret = return_op, trS_bind = bind_op
+                    , trS_fmap = fmap_op, trS_form = form }) stmts_rest
+  = do { let (from_bndrs, to_bndrs) = unzip bndrs
+             from_bndr_tys          = map idType from_bndrs    -- Types ty
+
+       -- Desugar an inner comprehension which outputs a list of tuples of the "from" binders
+       ; expr <- dsInnerMonadComp stmts from_bndrs return_op
+
+       -- Work out what arguments should be supplied to that expression: i.e. is an extraction
+       -- function required? If so, create that desugared function and add to arguments
+       ; usingExpr' <- dsLExpr using
+       ; usingArgs <- case by of
+                        Nothing   -> return [expr]
+                        Just by_e -> do { by_e' <- dsLExpr by_e
+                                        ; lam <- matchTuple from_bndrs by_e'
+                                        ; return [lam, expr] }
+
+       -- Generate the expressions to build the grouped list
+       -- Build a pattern that ensures the consumer binds into the NEW binders, 
+       -- which hold monads rather than single values
+       ; bind_op' <- dsExpr bind_op
+       ; let bind_ty  = exprType bind_op'    -- m2 (n (a,b,c)) -> (n (a,b,c) -> r1) -> r2
+             n_tup_ty = funArgTy $ funArgTy $ funResultTy bind_ty   -- n (a,b,c)
+             tup_n_ty = mkBigCoreVarTupTy to_bndrs
+
+       ; body       <- dsMcStmts stmts_rest
+       ; n_tup_var  <- newSysLocalDs n_tup_ty
+       ; tup_n_var  <- newSysLocalDs tup_n_ty
+       ; tup_n_expr <- mkMcUnzipM form fmap_op n_tup_var from_bndr_tys
+       ; us         <- newUniqueSupply
+       ; let rhs'  = mkApps usingExpr' usingArgs
+             body' = mkTupleCase us to_bndrs body tup_n_var tup_n_expr
+                  
+       ; return (mkApps bind_op' [rhs', Lam n_tup_var body']) }
+
+-- Parallel statements. Use `Control.Monad.Zip.mzip` to zip parallel
+-- statements, for example:
+--
+--   [ body | qs1 | qs2 | qs3 ]
+--     ->  [ body | (bndrs1, (bndrs2, bndrs3)) 
+--                     <- [bndrs1 | qs1] `mzip` ([bndrs2 | qs2] `mzip` [bndrs3 | qs3]) ]
+--
+-- where `mzip` has type
+--   mzip :: forall a b. m a -> m b -> m (a,b)
+-- NB: we need a polymorphic mzip because we call it several times
+
+dsMcStmt (ParStmt pairs mzip_op bind_op return_op) stmts_rest
+ = do  { exps_w_tys  <- mapM ds_inner pairs   -- Pairs (exp :: m ty, ty)
+       ; mzip_op'    <- dsExpr mzip_op
+
+       ; let -- The pattern variables
+             pats = map (mkBigLHsVarPatTup . snd) pairs
+             -- Pattern with tuples of variables
+             -- [v1,v2,v3]  =>  (v1, (v2, v3))
+             pat = foldr1 (\p1 p2 -> mkLHsPatTup [p1, p2]) pats
+            (rhs, _) = foldr1 (\(e1,t1) (e2,t2) -> 
+                                 (mkApps mzip_op' [Type t1, Type t2, e1, e2],
+                                  mkBoxedTupleTy [t1,t2])) 
+                               exps_w_tys
+
+       ; dsMcBindStmt pat rhs bind_op noSyntaxExpr stmts_rest }
+  where
+    ds_inner (stmts, bndrs) = do { exp <- dsInnerMonadComp stmts bndrs mono_ret_op
+                                 ; return (exp, tup_ty) }
+       where 
+         mono_ret_op = HsWrap (WpTyApp tup_ty) return_op
+         tup_ty      = mkBigCoreVarTupTy bndrs
+
+dsMcStmt stmt _ = pprPanic "dsMcStmt: unexpected stmt" (ppr stmt)
+
+
+matchTuple :: [Id] -> CoreExpr -> DsM CoreExpr
+-- (matchTuple [a,b,c] body)
+--       returns the Core term
+--  \x. case x of (a,b,c) -> body 
+matchTuple ids body
+  = do { us <- newUniqueSupply
+       ; tup_id <- newSysLocalDs (mkBigCoreVarTupTy ids)
+       ; return (Lam tup_id $ mkTupleCase us ids body tup_id (Var tup_id)) }
+
+-- general `rhs' >>= \pat -> stmts` desugaring where `rhs'` is already a
+-- desugared `CoreExpr`
+dsMcBindStmt :: LPat Id
+             -> CoreExpr        -- ^ the desugared rhs of the bind statement
+             -> SyntaxExpr Id
+             -> SyntaxExpr Id
+             -> [LStmt Id]
+             -> DsM CoreExpr
+dsMcBindStmt pat rhs' bind_op fail_op stmts
+  = do  { body     <- dsMcStmts stmts 
+        ; bind_op' <- dsExpr bind_op
+        ; var      <- selectSimpleMatchVarL pat
+        ; let bind_ty = exprType bind_op'      -- rhs -> (pat -> res1) -> res2
+              res1_ty = funResultTy (funArgTy (funResultTy bind_ty))
+        ; match <- matchSinglePat (Var var) (StmtCtxt DoExpr) pat
+                                  res1_ty (cantFailMatchResult body)
+        ; match_code <- handle_failure pat match fail_op
+        ; return (mkApps bind_op' [rhs', Lam var match_code]) }
+
+  where
+    -- In a monad comprehension expression, pattern-match failure just calls
+    -- the monadic `fail` rather than throwing an exception
+    handle_failure pat match fail_op
+      | matchCanFail match
+        = do { fail_op' <- dsExpr fail_op
+             ; fail_msg <- mkStringExpr (mk_fail_msg pat)
+             ; extractMatchResult match (App fail_op' fail_msg) }
+      | otherwise
+        = extractMatchResult match (error "It can't fail") 
+
+    mk_fail_msg :: Located e -> String
+    mk_fail_msg pat = "Pattern match failure in monad comprehension at " ++ 
+                      showSDoc (ppr (getLoc pat))
+
+-- Desugar nested monad comprehensions, for example in `then..` constructs
+--    dsInnerMonadComp quals [a,b,c] ret_op
+-- returns the desugaring of 
+--       [ (a,b,c) | quals ]
+
+dsInnerMonadComp :: [LStmt Id]
+                 -> [Id]       -- Return a tuple of these variables
+                 -> HsExpr Id  -- The monomorphic "return" operator
+                 -> DsM CoreExpr
+dsInnerMonadComp stmts bndrs ret_op
+  = dsMcStmts (stmts ++ [noLoc (LastStmt (mkBigLHsVarTup bndrs) ret_op)])
+
+-- The `unzip` function for `GroupStmt` in a monad comprehensions
+--
+--   unzip :: m (a,b,..) -> (m a,m b,..)
+--   unzip m_tuple = ( liftM selN1 m_tuple
+--                   , liftM selN2 m_tuple
+--                   , .. )
+--
+--   mkMcUnzipM fmap ys [t1, t2]
+--     = ( fmap (selN1 :: (t1, t2) -> t1) ys
+--       , fmap (selN2 :: (t1, t2) -> t2) ys )
+
+mkMcUnzipM :: TransForm
+           -> SyntaxExpr TcId  -- fmap
+          -> Id                -- Of type n (a,b,c)
+          -> [Type]            -- [a,b,c]
+          -> DsM CoreExpr      -- Of type (n a, n b, n c)
+mkMcUnzipM ThenForm _ ys _     
+  = return (Var ys) -- No unzipping to do
+
+mkMcUnzipM _ fmap_op ys elt_tys
+  = do { fmap_op' <- dsExpr fmap_op
+       ; xs       <- mapM newSysLocalDs elt_tys
+       ; let tup_ty = mkBigCoreTupTy elt_tys
+       ; tup_xs   <- newSysLocalDs tup_ty
+       ; let mk_elt i = mkApps fmap_op'  -- fmap :: forall a b. (a -> b) -> n a -> n b
+                           [ Type tup_ty, Type (elt_tys !! i)
+                           , mk_sel i, Var ys]
+
+             mk_sel n = Lam tup_xs $ 
+                        mkTupleSelector xs (xs !! n) tup_xs (Var tup_xs)
+
+       ; return (mkBigCoreTup (map mk_elt [0..length elt_tys - 1])) }
+\end{code}
index e34c696..e68173a 100644 (file)
@@ -721,23 +721,19 @@ repE (HsLet bs e)         = do { (ss,ds) <- repBinds bs
                               ; wrapGenSyms ss z }
 
 -- FIXME: I haven't got the types here right yet
-repE e@(HsDo ctxt sts body _) 
+repE e@(HsDo ctxt sts _) 
  | case ctxt of { DoExpr -> True; GhciStmt -> True; _ -> False }
  = do { (ss,zs) <- repLSts sts; 
-       body'   <- addBinds ss $ repLE body;
-       ret     <- repNoBindSt body';   
-        e'      <- repDoE (nonEmptyCoreList (zs ++ [ret]));
+        e'      <- repDoE (nonEmptyCoreList zs);
         wrapGenSyms ss e' }
 
  | ListComp <- ctxt
  = do { (ss,zs) <- repLSts sts; 
-       body'   <- addBinds ss $ repLE body;
-       ret     <- repNoBindSt body';   
-        e'      <- repComp (nonEmptyCoreList (zs ++ [ret]));
+        e'      <- repComp (nonEmptyCoreList zs);
         wrapGenSyms ss e' }
 
   | otherwise
-  = notHandled "mdo and [: :]" (ppr e)
+  = notHandled "mdo, monad comprehension and [: :]" (ppr e)
 
 repE (ExplicitList _ es) = do { xs <- repLEs es; repListExp xs }
 repE e@(ExplicitPArr _ _) = notHandled "Parallel arrays" (ppr e)
@@ -817,7 +813,7 @@ repGuards other
      wrapGenSyms (concat xs) gd }
   where 
     process :: LGRHS Name -> DsM ([GenSymBind], (Core (TH.Q (TH.Guard, TH.Exp))))
-    process (L _ (GRHS [L _ (ExprStmt e1 _ _)] e2))
+    process (L _ (GRHS [L _ (ExprStmt e1 _ _ _)] e2))
            = do { x <- repLNormalGE e1 e2;
                   return ([], x) }
     process (L _ (GRHS ss rhs))
@@ -876,7 +872,7 @@ repSts (LetStmt bs : ss) =
       ; z <- repLetSt ds
       ; (ss2,zs) <- addBinds ss1 (repSts ss)
       ; return (ss1++ss2, z : zs) } 
-repSts (ExprStmt e _ _ : ss) =       
+repSts (ExprStmt e _ _ _ : ss) =       
    do { e2 <- repLE e
       ; z <- repNoBindSt e2 
       ; (ss2,zs) <- repSts ss
index b5e6c41..5933e9d 100644 (file)
@@ -522,12 +522,15 @@ cvtHsDo do_or_lc stmts
   | null stmts = failWith (ptext (sLit "Empty stmt list in do-block"))
   | otherwise
   = do { stmts' <- cvtStmts stmts
-       ; body <- case last stmts' of
-                   L _ (ExprStmt body _ _) -> return body
-                    stmt' -> failWith (bad_last stmt')
-       ; return $ HsDo do_or_lc (init stmts') body void }
+        ; let Just (stmts'', last') = snocView stmts'
+        
+       ; last'' <- case last' of
+                     L loc (ExprStmt body _ _ _) -> return (L loc (mkLastStmt body))
+                      _ -> failWith (bad_last last')
+
+       ; return $ HsDo do_or_lc (stmts'' ++ [last'']) void }
   where
-    bad_last stmt = vcat [ ptext (sLit "Illegal last statement of") <+> pprStmtContext do_or_lc <> colon
+    bad_last stmt = vcat [ ptext (sLit "Illegal last statement of") <+> pprAStmtContext do_or_lc <> colon
                          , nest 2 $ Outputable.ppr stmt
                         , ptext (sLit "(It should be an expression.)") ]
                
@@ -539,7 +542,7 @@ cvtStmt (NoBindS e)    = do { e' <- cvtl e; returnL $ mkExprStmt e' }
 cvtStmt (TH.BindS p e) = do { p' <- cvtPat p; e' <- cvtl e; returnL $ mkBindStmt p' e' }
 cvtStmt (TH.LetS ds)   = do { ds' <- cvtLocalDecs (ptext (sLit "a let binding")) ds
                             ; returnL $ LetStmt ds' }
-cvtStmt (TH.ParS dss)  = do { dss' <- mapM cvt_one dss; returnL $ ParStmt dss' }
+cvtStmt (TH.ParS dss)  = do { dss' <- mapM cvt_one dss; returnL $ ParStmt dss' noSyntaxExpr noSyntaxExpr noSyntaxExpr }
                       where
                         cvt_one ds = do { ds' <- cvtStmts ds; return (ds', undefined) }
 
index 06616f1..9c88783 100644 (file)
@@ -23,6 +23,8 @@ import Name
 import BasicTypes
 import DataCon
 import SrcLoc
+import Util( dropTail )
+import StaticFlags( opt_PprStyle_Debug )
 import Outputable
 import FastString
 
@@ -146,8 +148,6 @@ data HsExpr id
                                      -- because in this context we never use
                                      -- the PatGuard or ParStmt variant
                 [LStmt id]           -- "do":one or more stmts
-                (LHsExpr id)         -- The body; the last expression in the
-                                     -- 'do' of [ body | ... ] in a list comp
                 PostTcType           -- Type of the whole expression
 
   | ExplicitList                -- syntactic list
@@ -439,7 +439,7 @@ ppr_expr (HsLet binds expr)
   = sep [hang (ptext (sLit "let")) 2 (pprBinds binds),
          hang (ptext (sLit "in"))  2 (ppr expr)]
 
-ppr_expr (HsDo do_or_list_comp stmts body _) = pprDo do_or_list_comp stmts body
+ppr_expr (HsDo do_or_list_comp stmts _) = pprDo do_or_list_comp stmts
 
 ppr_expr (ExplicitList _ exprs)
   = brackets (pprDeeperList fsep (punctuate comma (map ppr_lexpr exprs)))
@@ -575,7 +575,7 @@ pprParendExpr expr
       HsPar {}          -> pp_as_was
       HsBracket {}      -> pp_as_was
       HsBracketOut _ [] -> pp_as_was
-      HsDo sc _ _ _
+      HsDo sc _ _
        | isListCompExpr sc -> pp_as_was
       _                    -> parens pp_as_was
 
@@ -830,51 +830,59 @@ type LStmtLR idL idR = Located (StmtLR idL idR)
 
 type Stmt id = StmtLR id id
 
--- The SyntaxExprs in here are used *only* for do-notation, which
--- has rebindable syntax.  Otherwise they are unused.
+-- The SyntaxExprs in here are used *only* for do-notation and monad
+-- comprehensions, which have rebindable syntax. Otherwise they are unused.
 data StmtLR idL idR
-  = BindStmt (LPat idL)
+  = LastStmt  -- Always the last Stmt in ListComp, MonadComp, PArrComp, 
+             -- and (after the renamer) DoExpr, MDoExpr
+              -- Not used for GhciStmt, PatGuard, which scope over other stuff
+               (LHsExpr idR)
+               (SyntaxExpr idR)   -- The return operator, used only for MonadComp
+                                 -- For ListComp, PArrComp, we use the baked-in 'return'
+                                 -- For DoExpr, MDoExpr, we don't appply a 'return' at all
+                                 -- See Note [Monad Comprehensions]
+  | BindStmt (LPat idL)
              (LHsExpr idR)
-             (SyntaxExpr idR) -- The (>>=) operator
+             (SyntaxExpr idR) -- The (>>=) operator; see Note [The type of bind]
              (SyntaxExpr idR) -- The fail operator
              -- The fail operator is noSyntaxExpr
              -- if the pattern match can't fail
 
   | ExprStmt (LHsExpr idR)     -- See Note [ExprStmt]
              (SyntaxExpr idR) -- The (>>) operator
+             (SyntaxExpr idR) -- The `guard` operator; used only in MonadComp
+                              -- See notes [Monad Comprehensions]
              PostTcType       -- Element type of the RHS (used for arrows)
 
   | LetStmt  (HsLocalBindsLR idL idR)
 
-  -- ParStmts only occur in a list comprehension
+  -- ParStmts only occur in a list/monad comprehension
   | ParStmt  [([LStmt idL], [idR])]
-  -- After renaming, the ids are the binders bound by the stmts and used
-  -- after them
-
-  -- "qs, then f by e" ==> TransformStmt qs binders f (Just e)
-  -- "qs, then f"      ==> TransformStmt qs binders f Nothing
-  | TransformStmt 
-         [LStmt idL]   -- Stmts are the ones to the left of the 'then'
-
-         [idR]                 -- After renaming, the IDs are the binders occurring 
-                       -- within this transform statement that are used after it
-
-         (LHsExpr idR)         -- "then f"
-
-         (Maybe (LHsExpr idR)) -- "by e" (optional)
-
-  | GroupStmt 
-         [LStmt idL]      -- Stmts to the *left* of the 'group'
-                         -- which generates the tuples to be grouped
-
-         [(idR, idR)]    -- See Note [GroupStmt binder map]
+             (SyntaxExpr idR)           -- Polymorphic `mzip` for monad comprehensions
+             (SyntaxExpr idR)           -- The `>>=` operator
+             (SyntaxExpr idR)           -- Polymorphic `return` operator
+                                       -- with type (forall a. a -> m a)
+                                        -- See notes [Monad Comprehensions]
+           -- After renaming, the ids are the binders 
+           -- bound by the stmts and used after themp
+
+  | TransStmt {
+      trS_form  :: TransForm,
+      trS_stmts :: [LStmt idL],      -- Stmts to the *left* of the 'group'
+                                     -- which generates the tuples to be grouped
+
+      trS_bndrs :: [(idR, idR)],     -- See Note [TransStmt binder map]
                                
-         (Maybe (LHsExpr idR))         -- "by e" (optional)
+      trS_using :: LHsExpr idR,
+      trS_by :: Maybe (LHsExpr idR),   -- "by e" (optional)
+       -- Invariant: if trS_form = GroupBy, then grp_by = Just e
 
-         (Either               -- "using f"
-             (LHsExpr idR)     --   Left f  => explicit "using f"
-             (SyntaxExpr idR)) --   Right f => implicit; filled in with 'groupWith'
-                                                       
+      trS_ret :: SyntaxExpr idR,      -- The monomorphic 'return' function for 
+                                       -- the inner monad comprehensions
+      trS_bind :: SyntaxExpr idR,     -- The '(>>=)' operator
+      trS_fmap :: SyntaxExpr idR      -- The polymorphic 'fmap' function for desugaring
+                                      -- Only for 'group' forms
+    }                                  -- See Note [Monad Comprehensions]
 
   -- Recursive statement (see Note [How RecStmt works] below)
   | RecStmt
@@ -905,20 +913,44 @@ data StmtLR idL idR
                                      -- because the Id may be *polymorphic*, but
                                      -- the returned thing has to be *monomorphic*, 
                                     -- so they may be type applications
+
+      , recS_ret_ty :: PostTcType    -- The type of of do { stmts; return (a,b,c) }
+                                    -- With rebindable syntax the type might not
+                                    -- be quite as simple as (m (tya, tyb, tyc)).
       }
   deriving (Data, Typeable)
+
+data TransForm         -- The 'f' below is the 'using' function, 'e' is the by function
+  = ThenForm           -- then f          or    then f by e
+  | GroupFormU         -- group using f   or    group using f by e
+  | GroupFormB         -- group by e  
+      -- In the GroupByFormB, trS_using is filled in with
+      --    'groupWith' (list comprehensions) or 
+      --    'groupM' (monad comprehensions)
+  deriving (Data, Typeable)
 \end{code}
 
-Note [GroupStmt binder map]
+Note [The type of bind in Stmts]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Some Stmts, notably BindStmt, keep the (>>=) bind operator.  
+We do NOT assume that it has type  
+    (>>=) :: m a -> (a -> m b) -> m b
+In some cases (see Trac #303, #1537) it might have a more 
+exotic type, such as
+    (>>=) :: m i j a -> (a -> m j k b) -> m i k b
+So we must be careful not to make assumptions about the type.
+In particular, the monad may not be uniform throughout.
+
+Note [TransStmt binder map]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~
-The [(idR,idR)] in a GroupStmt behaves as follows:
+The [(idR,idR)] in a TransStmt behaves as follows:
 
   * Before renaming: []
 
   * After renaming: 
          [ (x27,x27), ..., (z35,z35) ]
     These are the variables 
-        bound by the stmts to the left of the 'group'
+       bound by the stmts to the left of the 'group'
        and used either in the 'by' clause, 
                 or     in the stmts following the 'group'
     Each item is a pair of identical variables.
@@ -952,7 +984,13 @@ depends on the context.  Consider the following contexts:
                 E :: Bool
           Translation: if E then fail else ...
 
-Array comprehensions are handled like list comprehensions -=chak
+        A monad comprehension of type (m res_ty)
+        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+        * ExprStmt E Bool:   [ .. | .... E ]
+                E :: Bool
+          Translation: guard E >> ...
+
+Array comprehensions are handled like list comprehensions.
 
 Note [How RecStmt works]
 ~~~~~~~~~~~~~~~~~~~~~~~~
@@ -993,23 +1031,60 @@ A (RecStmt stmts) types as if you had written
 where v1..vn are the later_ids
       r1..rm are the rec_ids
 
+Note [Monad Comprehensions]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Monad comprehensions require separate functions like 'return' and
+'>>=' for desugaring. These functions are stored in the statements
+used in monad comprehensions. For example, the 'return' of the 'LastStmt'
+expression is used to lift the body of the monad comprehension:
+
+  [ body | stmts ]
+   =>
+  stmts >>= \bndrs -> return body
+
+In transform and grouping statements ('then ..' and 'then group ..') the
+'return' function is required for nested monad comprehensions, for example:
+
+  [ body | stmts, then f, rest ]
+   =>
+  f [ env | stmts ] >>= \bndrs -> [ body | rest ]
+
+ExprStmts require the 'Control.Monad.guard' function for boolean
+expressions:
+
+  [ body | exp, stmts ]
+   =>
+  guard exp >> [ body | stmts ]
+
+Grouping/parallel statements require the 'Control.Monad.Group.groupM' and
+'Control.Monad.Zip.mzip' functions:
+
+  [ body | stmts, then group by e, rest]
+   =>
+  groupM [ body | stmts ] >>= \bndrs -> [ body | rest ]
+
+  [ body | stmts1 | stmts2 | .. ]
+   =>
+  mzip stmts1 (mzip stmts2 (..)) >>= \(bndrs1, (bndrs2, ..)) -> return body
+
+In any other context than 'MonadComp', the fields for most of these
+'SyntaxExpr's stay bottom.
+
 
 \begin{code}
 instance (OutputableBndr idL, OutputableBndr idR) => Outputable (StmtLR idL idR) where
     ppr stmt = pprStmt stmt
 
 pprStmt :: (OutputableBndr idL, OutputableBndr idR) => (StmtLR idL idR) -> SDoc
+pprStmt (LastStmt expr _)         = ifPprDebug (ptext (sLit "[last]")) <+> ppr expr
 pprStmt (BindStmt pat expr _ _)   = hsep [ppr pat, ptext (sLit "<-"), ppr expr]
 pprStmt (LetStmt binds)           = hsep [ptext (sLit "let"), pprBinds binds]
-pprStmt (ExprStmt expr _ _)       = ppr expr
-pprStmt (ParStmt stmtss)          = hsep (map doStmts stmtss)
+pprStmt (ExprStmt expr _ _ _)     = ppr expr
+pprStmt (ParStmt stmtss _ _ _)    = hsep (map doStmts stmtss)
   where doStmts stmts = ptext (sLit "| ") <> ppr stmts
 
-pprStmt (TransformStmt stmts bndrs using by)
-  = sep (ppr_lc_stmts stmts ++ [pprTransformStmt bndrs using by])
-
-pprStmt (GroupStmt stmts _ by using) 
-  = sep (ppr_lc_stmts stmts ++ [pprGroupStmt by using])
+pprStmt (TransStmt { trS_stmts = stmts, trS_by = by, trS_using = using, trS_form = form })
+  = sep (ppr_lc_stmts stmts ++ [pprTransStmt by using form])
 
 pprStmt (RecStmt { recS_stmts = segment, recS_rec_ids = rec_ids
                  , recS_later_ids = later_ids })
@@ -1024,40 +1099,47 @@ pprTransformStmt bndrs using by
         , nest 2 (ppr using)
         , nest 2 (pprBy by)]
 
-pprGroupStmt :: OutputableBndr id => Maybe (LHsExpr id)
-                                  -> Either (LHsExpr id) (SyntaxExpr is)
+pprTransStmt :: OutputableBndr id => Maybe (LHsExpr id)
+                                  -> LHsExpr id -> TransForm
                                  -> SDoc
-pprGroupStmt by using 
-  = sep [ ptext (sLit "then group"), nest 2 (pprBy by), nest 2 (ppr_using using)]
-  where
-    ppr_using (Right _) = empty
-    ppr_using (Left e)  = ptext (sLit "using") <+> ppr e
+pprTransStmt by using ThenForm
+  = sep [ ptext (sLit "then"), nest 2 (ppr using), nest 2 (pprBy by)]
+pprTransStmt by _ GroupFormB
+  = sep [ ptext (sLit "then group"), nest 2 (pprBy by) ]
+pprTransStmt by using GroupFormU
+  = sep [ ptext (sLit "then group"), nest 2 (pprBy by), nest 2 (ptext (sLit "using") <+> ppr using)]
 
 pprBy :: OutputableBndr id => Maybe (LHsExpr id) -> SDoc
 pprBy Nothing  = empty
 pprBy (Just e) = ptext (sLit "by") <+> ppr e
 
-pprDo :: OutputableBndr id => HsStmtContext any -> [LStmt id] -> LHsExpr id -> SDoc
-pprDo DoExpr      stmts body = ptext (sLit "do")  <+> ppr_do_stmts stmts body
-pprDo GhciStmt    stmts body = ptext (sLit "do")  <+> ppr_do_stmts stmts body
-pprDo MDoExpr     stmts body = ptext (sLit "mdo") <+> ppr_do_stmts stmts body
-pprDo ListComp    stmts body = brackets    $ pprComp stmts body
-pprDo PArrComp    stmts body = pa_brackets $ pprComp stmts body
-pprDo _           _     _    = panic "pprDo" -- PatGuard, ParStmtCxt
-
-ppr_do_stmts :: OutputableBndr id => [LStmt id] -> LHsExpr id -> SDoc
+pprDo :: OutputableBndr id => HsStmtContext any -> [LStmt id] -> SDoc
+pprDo DoExpr      stmts = ptext (sLit "do")  <+> ppr_do_stmts stmts
+pprDo GhciStmt    stmts = ptext (sLit "do")  <+> ppr_do_stmts stmts
+pprDo ArrowExpr   stmts = ptext (sLit "do")  <+> ppr_do_stmts stmts
+pprDo MDoExpr     stmts = ptext (sLit "mdo") <+> ppr_do_stmts stmts
+pprDo ListComp    stmts = brackets    $ pprComp stmts
+pprDo PArrComp    stmts = pa_brackets $ pprComp stmts
+pprDo MonadComp   stmts = brackets    $ pprComp stmts
+pprDo _           _     = panic "pprDo" -- PatGuard, ParStmtCxt
+
+ppr_do_stmts :: OutputableBndr id => [LStmt id] -> SDoc
 -- Print a bunch of do stmts, with explicit braces and semicolons,
 -- so that we are not vulnerable to layout bugs
-ppr_do_stmts stmts body
-  = lbrace <+> pprDeeperList vcat ([ppr s <> semi | s <- stmts] ++ [ppr body])
+ppr_do_stmts stmts 
+  = lbrace <+> pprDeeperList vcat (punctuate semi (map ppr stmts))
            <+> rbrace
 
 ppr_lc_stmts :: OutputableBndr id => [LStmt id] -> [SDoc]
 ppr_lc_stmts stmts = [ppr s <> comma | s <- stmts]
 
-pprComp :: OutputableBndr id => [LStmt id] -> LHsExpr id -> SDoc
-pprComp quals body       -- Prints:  body | qual1, ..., qualn 
-  = hang (ppr body <+> char '|') 2 (interpp'SP quals)
+pprComp :: OutputableBndr id => [LStmt id] -> SDoc
+pprComp quals    -- Prints:  body | qual1, ..., qualn 
+  | not (null quals)
+  , L _ (LastStmt body _) <- last quals
+  = hang (ppr body <+> char '|') 2 (interpp'SP (dropTail 1 quals))
+  | otherwise
+  = pprPanic "pprComp" (interpp'SP quals)
 \end{code}
 
 %************************************************************************
@@ -1175,26 +1257,33 @@ data HsMatchContext id  -- Context of a Match
 
 data HsStmtContext id
   = ListComp
-  | DoExpr
-  | GhciStmt                            -- A command-line Stmt in GHCi pat <- rhs
-  | MDoExpr                              -- Recursive do-expression
+  | MonadComp
   | PArrComp                             -- Parallel array comprehension
+
+  | DoExpr                              -- do { ... }
+  | MDoExpr                              -- mdo { ... }  ie recursive do-expression 
+  | ArrowExpr                           -- do-notation in an arrow-command context
+
+  | GhciStmt                            -- A command-line Stmt in GHCi pat <- rhs
   | PatGuard (HsMatchContext id)         -- Pattern guard for specified thing
   | ParStmtCtxt (HsStmtContext id)       -- A branch of a parallel stmt
-  | TransformStmtCtxt (HsStmtContext id) -- A branch of a transform stmt
+  | TransStmtCtxt (HsStmtContext id)     -- A branch of a transform stmt
   deriving (Data, Typeable)
 \end{code}
 
 \begin{code}
-isDoExpr :: HsStmtContext id -> Bool
-isDoExpr DoExpr  = True
-isDoExpr MDoExpr = True
-isDoExpr _       = False
-
 isListCompExpr :: HsStmtContext id -> Bool
-isListCompExpr ListComp = True
-isListCompExpr PArrComp = True
-isListCompExpr _        = False
+-- Uses syntax [ e | quals ]
+isListCompExpr ListComp  = True
+isListCompExpr PArrComp  = True
+isListCompExpr MonadComp = True
+isListCompExpr _         = False
+
+isMonadCompExpr :: HsStmtContext id -> Bool
+isMonadCompExpr MonadComp            = True
+isMonadCompExpr (ParStmtCtxt ctxt)   = isMonadCompExpr ctxt
+isMonadCompExpr (TransStmtCtxt ctxt) = isMonadCompExpr ctxt
+isMonadCompExpr _                    = False
 \end{code}
 
 \begin{code}
@@ -1231,33 +1320,41 @@ pprMatchContextNoun ProcExpr        = ptext (sLit "arrow abstraction")
 pprMatchContextNoun (StmtCtxt ctxt) = ptext (sLit "pattern binding in")
                                       $$ pprStmtContext ctxt
 
-pprStmtContext :: Outputable id => HsStmtContext id -> SDoc
+-----------------
+pprAStmtContext, pprStmtContext :: Outputable id => HsStmtContext id -> SDoc
+pprAStmtContext ctxt = article <+> pprStmtContext ctxt
+  where
+    pp_an = ptext (sLit "an")
+    pp_a  = ptext (sLit "a")
+    article = case ctxt of
+                  MDoExpr  -> pp_an
+                  PArrComp -> pp_an
+                 GhciStmt -> pp_an
+                  _        -> pp_a
+
+
+-----------------
+pprStmtContext GhciStmt        = ptext (sLit "interactive GHCi command")
+pprStmtContext DoExpr          = ptext (sLit "'do' block")
+pprStmtContext MDoExpr         = ptext (sLit "'mdo' block")
+pprStmtContext ArrowExpr       = ptext (sLit "'do' block in an arrow command")
+pprStmtContext ListComp        = ptext (sLit "list comprehension")
+pprStmtContext MonadComp       = ptext (sLit "monad comprehension")
+pprStmtContext PArrComp        = ptext (sLit "array comprehension")
+pprStmtContext (PatGuard ctxt) = ptext (sLit "pattern guard for") $$ pprMatchContext ctxt
+
+-- Drop the inner contexts when reporting errors, else we get
+--     Unexpected transform statement
+--     in a transformed branch of
+--          transformed branch of
+--          transformed branch of monad comprehension
 pprStmtContext (ParStmtCtxt c)
- = sep [ptext (sLit "a parallel branch of"), pprStmtContext c]
-pprStmtContext (TransformStmtCtxt c)
- = sep [ptext (sLit "a transformed branch of"), pprStmtContext c]
-pprStmtContext (PatGuard ctxt)
- = ptext (sLit "a pattern guard for") $$ pprMatchContext ctxt
-pprStmtContext GhciStmt        = ptext (sLit "an interactive GHCi command")
-pprStmtContext DoExpr          = ptext (sLit "a 'do' expression")
-pprStmtContext MDoExpr         = ptext (sLit "an 'mdo' expression")
-pprStmtContext ListComp        = ptext (sLit "a list comprehension")
-pprStmtContext PArrComp        = ptext (sLit "an array comprehension")
-
-{-
-pprMatchRhsContext (FunRhs fun) = ptext (sLit "a right-hand side of function") <+> quotes (ppr fun)
-pprMatchRhsContext CaseAlt      = ptext (sLit "the body of a case alternative")
-pprMatchRhsContext PatBindRhs   = ptext (sLit "the right-hand side of a pattern binding")
-pprMatchRhsContext LambdaExpr   = ptext (sLit "the body of a lambda")
-pprMatchRhsContext ProcExpr     = ptext (sLit "the body of a proc")
-pprMatchRhsContext other        = panic "pprMatchRhsContext"    -- RecUpd, StmtCtxt
-
--- Used for the result statement of comprehension
--- e.g. the 'e' in      [ e | ... ]
---      or the 'r' in   f x = r
-pprStmtResultContext (PatGuard ctxt) = pprMatchRhsContext ctxt
-pprStmtResultContext other           = ptext (sLit "the result of") <+> pprStmtContext other
--}
+ | opt_PprStyle_Debug = sep [ptext (sLit "parallel branch of"), pprAStmtContext c]
+ | otherwise          = pprStmtContext c
+pprStmtContext (TransStmtCtxt c)
+ | opt_PprStyle_Debug = sep [ptext (sLit "transformed branch of"), pprAStmtContext c]
+ | otherwise          = pprStmtContext c
+
 
 -- Used to generate the string for a *runtime* error message
 matchContextErrString :: Outputable id => HsMatchContext id -> SDoc
@@ -1268,14 +1365,16 @@ matchContextErrString RecUpd                     = ptext (sLit "record update")
 matchContextErrString LambdaExpr                 = ptext (sLit "lambda")
 matchContextErrString ProcExpr                   = ptext (sLit "proc")
 matchContextErrString ThPatQuote                 = panic "matchContextErrString"  -- Not used at runtime
-matchContextErrString (StmtCtxt (ParStmtCtxt c)) = matchContextErrString (StmtCtxt c)
-matchContextErrString (StmtCtxt (TransformStmtCtxt c)) = matchContextErrString (StmtCtxt c)
-matchContextErrString (StmtCtxt (PatGuard _))    = ptext (sLit "pattern guard")
-matchContextErrString (StmtCtxt GhciStmt)        = ptext (sLit "interactive GHCi command")
-matchContextErrString (StmtCtxt DoExpr)          = ptext (sLit "'do' expression")
-matchContextErrString (StmtCtxt MDoExpr)         = ptext (sLit "'mdo' expression")
-matchContextErrString (StmtCtxt ListComp)        = ptext (sLit "list comprehension")
-matchContextErrString (StmtCtxt PArrComp)        = ptext (sLit "array comprehension")
+matchContextErrString (StmtCtxt (ParStmtCtxt c))   = matchContextErrString (StmtCtxt c)
+matchContextErrString (StmtCtxt (TransStmtCtxt c)) = matchContextErrString (StmtCtxt c)
+matchContextErrString (StmtCtxt (PatGuard _))      = ptext (sLit "pattern guard")
+matchContextErrString (StmtCtxt GhciStmt)          = ptext (sLit "interactive GHCi command")
+matchContextErrString (StmtCtxt DoExpr)            = ptext (sLit "'do' block")
+matchContextErrString (StmtCtxt ArrowExpr)         = ptext (sLit "'do' block")
+matchContextErrString (StmtCtxt MDoExpr)           = ptext (sLit "'mdo' block")
+matchContextErrString (StmtCtxt ListComp)          = ptext (sLit "list comprehension")
+matchContextErrString (StmtCtxt MonadComp)         = ptext (sLit "monad comprehension")
+matchContextErrString (StmtCtxt PArrComp)          = ptext (sLit "array comprehension")
 \end{code}
 
 \begin{code}
@@ -1286,11 +1385,16 @@ pprMatchInCtxt ctxt match  = hang (ptext (sLit "In") <+> pprMatchContext ctxt <>
 
 pprStmtInCtxt :: (OutputableBndr idL, OutputableBndr idR)
               => HsStmtContext idL -> StmtLR idL idR -> SDoc
-pprStmtInCtxt ctxt stmt = hang (ptext (sLit "In a stmt of") <+> pprStmtContext ctxt <> colon)
-                         4 (ppr_stmt stmt)
+pprStmtInCtxt ctxt (LastStmt e _)
+  | isListCompExpr ctxt      -- For [ e | .. ], do not mutter about "stmts"
+  = hang (ptext (sLit "In the expression:")) 2 (ppr e)
+
+pprStmtInCtxt ctxt stmt 
+  = hang (ptext (sLit "In a stmt of") <+> pprAStmtContext ctxt <> colon)
+       2 (ppr_stmt stmt)
   where
     -- For Group and Transform Stmts, don't print the nested stmts!
-    ppr_stmt (GroupStmt _ _ by using)         = pprGroupStmt by using
-    ppr_stmt (TransformStmt _ bndrs using by) = pprTransformStmt bndrs using by
-    ppr_stmt stmt                             = pprStmt stmt
+    ppr_stmt (TransStmt { trS_by = by, trS_using = using
+                        , trS_form = form }) = pprTransStmt by using form
+    ppr_stmt stmt = pprStmt stmt
 \end{code}
index 0874dda..4a565ff 100644 (file)
@@ -63,8 +63,7 @@ instance Eq HsLit where
 data HsOverLit id      -- An overloaded literal
   = OverLit {
        ol_val :: OverLitVal, 
-       ol_rebindable :: Bool,          -- True <=> rebindable syntax
-                                       -- False <=> standard syntax
+       ol_rebindable :: Bool,          -- Note [ol_rebindable]
        ol_witness :: SyntaxExpr id,    -- Note [Overloaded literal witnesses]
        ol_type :: PostTcType }
   deriving (Data, Typeable)
@@ -79,6 +78,19 @@ overLitType :: HsOverLit a -> Type
 overLitType = ol_type
 \end{code}
 
+Note [ol_rebindable]
+~~~~~~~~~~~~~~~~~~~~
+The ol_rebindable field is True if this literal is actually 
+using rebindable syntax.  Specifically:
+
+  False iff ol_witness is the standard one
+  True  iff ol_witness is non-standard
+
+Equivalently it's True if
+  a) RebindableSyntax is on
+  b) the witness for fromInteger/fromRational/fromString
+     that happens to be in scope isn't the standard one
+
 Note [Overloaded literal witnesses]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 *Before* type checking, the SyntaxExpr in an HsOverLit is the
@@ -89,7 +101,7 @@ This witness should replace the literal.
 
 This dual role is unusual, because we're replacing 'fromInteger' with 
 a call to fromInteger.  Reason: it allows commoning up of the fromInteger
-calls, which wouldn't be possible if the desguarar made the application
+calls, which wouldn't be possible if the desguarar made the application.
 
 The PostTcType in each branch records the type the overload literal is
 found to have.
index 78b5887..3efcd59 100644 (file)
@@ -122,7 +122,9 @@ data Pat id
   | LitPat         HsLit               -- Used for *non-overloaded* literal patterns:
                                        -- Int#, Char#, Int, Char, String, etc.
 
-  | NPat           (HsOverLit id)              -- ALWAYS positive
+  | NPat               -- Used for all overloaded literals, 
+                       -- including overloaded strings with -XOverloadedStrings
+                    (HsOverLit id)             -- ALWAYS positive
                    (Maybe (SyntaxExpr id))     -- Just (Name of 'negate') for negative
                                                -- patterns, Nothing otherwise
                    (SyntaxExpr id)             -- Equality checker, of type t->t->Bool
index 13f3cd7..5e8dda3 100644 (file)
@@ -21,7 +21,7 @@ module HsUtils(
   mkMatchGroup, mkMatch, mkHsLam, mkHsIf,
   mkHsWrap, mkLHsWrap, mkHsWrapCoI, mkLHsWrapCoI,
   coiToHsWrapper, mkHsLams, mkHsDictLet,
-  mkHsOpApp, mkHsDo, mkHsWrapPat, mkHsWrapPatCoI,
+  mkHsOpApp, mkHsDo, mkHsComp, mkHsWrapPat, mkHsWrapPatCoI, 
 
   nlHsTyApp, nlHsVar, nlHsLit, nlHsApp, nlHsApps, nlHsIntLit, nlHsVarApps, 
   nlHsDo, nlHsOpApp, nlHsLam, nlHsPar, nlHsIf, nlHsCase, nlList,
@@ -42,8 +42,8 @@ module HsUtils(
   nlHsAppTy, nlHsTyVar, nlHsFunTy, nlHsTyConApp, 
 
   -- Stmts
-  mkTransformStmt, mkTransformByStmt, mkExprStmt, mkBindStmt,
-  mkGroupUsingStmt, mkGroupByStmt, mkGroupByUsingStmt, 
+  mkTransformStmt, mkTransformByStmt, mkExprStmt, mkBindStmt, mkLastStmt,
+  emptyTransStmt, mkGroupUsingStmt, mkGroupByStmt, mkGroupByUsingStmt, 
   emptyRecStmt, mkRecStmt, 
 
   -- Template Haskell
@@ -190,14 +190,13 @@ mkSimpleHsAlt pat expr
 mkHsIntegral   :: Integer -> PostTcType -> HsOverLit id
 mkHsFractional :: Rational -> PostTcType -> HsOverLit id
 mkHsIsString   :: FastString -> PostTcType -> HsOverLit id
-mkHsDo         :: HsStmtContext Name -> [LStmt id] -> LHsExpr id -> HsExpr id
+mkHsDo         :: HsStmtContext Name -> [LStmt id] -> HsExpr id
+mkHsComp       :: HsStmtContext Name -> [LStmt id] -> LHsExpr id -> HsExpr id
 
 mkNPat      :: HsOverLit id -> Maybe (SyntaxExpr id) -> Pat id
 mkNPlusKPat :: Located id -> HsOverLit id -> Pat id
 
-mkTransformStmt   :: [LStmt idL] -> LHsExpr idR                -> StmtLR idL idR
-mkTransformByStmt :: [LStmt idL] -> LHsExpr idR -> LHsExpr idR -> StmtLR idL idR
-
+mkLastStmt :: LHsExpr idR -> StmtLR idL idR
 mkExprStmt :: LHsExpr idR -> StmtLR idL idR
 mkBindStmt :: LPat idL -> LHsExpr idR -> StmtLR idL idR
 
@@ -212,7 +211,10 @@ mkHsIsString   s       = OverLit (HsIsString   s)  noRebindableInfo noSyntaxExpr
 noRebindableInfo :: Bool
 noRebindableInfo = error "noRebindableInfo"    -- Just another placeholder; 
 
-mkHsDo ctxt stmts body = HsDo ctxt stmts body placeHolderType
+mkHsDo ctxt stmts = HsDo ctxt stmts placeHolderType
+mkHsComp ctxt stmts expr = mkHsDo ctxt (stmts ++ [last_stmt])
+  where
+    last_stmt = L (getLoc expr) $ mkLastStmt expr
 
 mkHsIf :: LHsExpr id -> LHsExpr id -> LHsExpr id -> HsExpr id
 mkHsIf c a b = HsIf (Just noSyntaxExpr) c a b
@@ -220,24 +222,32 @@ mkHsIf c a b = HsIf (Just noSyntaxExpr) c a b
 mkNPat lit neg     = NPat lit neg noSyntaxExpr
 mkNPlusKPat id lit = NPlusKPat id lit noSyntaxExpr noSyntaxExpr
 
-mkTransformStmt   stmts usingExpr        = TransformStmt stmts [] usingExpr Nothing
-mkTransformByStmt stmts usingExpr byExpr = TransformStmt stmts [] usingExpr (Just byExpr)
-
+mkTransformStmt   :: [LStmt idL] -> LHsExpr idR                -> StmtLR idL idR
+mkTransformByStmt :: [LStmt idL] -> LHsExpr idR -> LHsExpr idR -> StmtLR idL idR
 mkGroupUsingStmt   :: [LStmt idL]                -> LHsExpr idR -> StmtLR idL idR
 mkGroupByStmt      :: [LStmt idL] -> LHsExpr idR                -> StmtLR idL idR
 mkGroupByUsingStmt :: [LStmt idL] -> LHsExpr idR -> LHsExpr idR -> StmtLR idL idR
 
-mkGroupUsingStmt   stmts usingExpr        = GroupStmt stmts [] Nothing       (Left usingExpr)    
-mkGroupByStmt      stmts byExpr           = GroupStmt stmts [] (Just byExpr) (Right noSyntaxExpr)
-mkGroupByUsingStmt stmts byExpr usingExpr = GroupStmt stmts [] (Just byExpr) (Left usingExpr)    
-
-mkExprStmt expr            = ExprStmt expr noSyntaxExpr placeHolderType
+emptyTransStmt :: StmtLR idL idR
+emptyTransStmt = TransStmt { trS_form = undefined, trS_stmts = [], trS_bndrs = [] 
+                           , trS_by = Nothing, trS_using = noLoc noSyntaxExpr
+                           , trS_ret = noSyntaxExpr, trS_bind = noSyntaxExpr
+                           , trS_fmap = noSyntaxExpr }
+mkTransformStmt   ss u    = emptyTransStmt { trS_form = ThenForm, trS_stmts = ss, trS_using = u }
+mkTransformByStmt ss u b  = emptyTransStmt { trS_form = ThenForm, trS_stmts = ss, trS_using = u, trS_by = Just b }
+mkGroupByStmt      ss b   = emptyTransStmt { trS_form = GroupFormB, trS_stmts = ss, trS_by = Just b }
+mkGroupUsingStmt   ss u   = emptyTransStmt { trS_form = GroupFormU, trS_stmts = ss, trS_using = u }
+mkGroupByUsingStmt ss b u = emptyTransStmt { trS_form = GroupFormU, trS_stmts = ss
+                                           , trS_by = Just b, trS_using = u }
+
+mkLastStmt expr            = LastStmt expr noSyntaxExpr
+mkExprStmt expr            = ExprStmt expr noSyntaxExpr noSyntaxExpr placeHolderType
 mkBindStmt pat expr = BindStmt pat expr noSyntaxExpr noSyntaxExpr
 
 emptyRecStmt = RecStmt { recS_stmts = [], recS_later_ids = [], recS_rec_ids = []
                        , recS_ret_fn = noSyntaxExpr, recS_mfix_fn = noSyntaxExpr
                       , recS_bind_fn = noSyntaxExpr
-                       , recS_rec_rets = [] }
+                       , recS_rec_rets = [], recS_ret_ty = placeHolderType }
 
 mkRecStmt stmts = emptyRecStmt { recS_stmts = stmts }
 
@@ -327,8 +337,8 @@ nlWildConPat con = noLoc (ConPatIn (noLoc (getRdrName con))
 nlWildPat :: LPat id
 nlWildPat  = noLoc (WildPat placeHolderType)   -- Pre-typechecking
 
-nlHsDo :: HsStmtContext Name -> [LStmt id] -> LHsExpr id -> LHsExpr id
-nlHsDo ctxt stmts body = noLoc (mkHsDo ctxt stmts body)
+nlHsDo :: HsStmtContext Name -> [LStmt id] -> LHsExpr id
+nlHsDo ctxt stmts = noLoc (mkHsDo ctxt stmts)
 
 nlHsOpApp :: LHsExpr id -> id -> LHsExpr id -> LHsExpr id
 nlHsOpApp e1 op e2 = noLoc (mkHsOpApp e1 op e2)
@@ -496,12 +506,12 @@ collectStmtBinders :: StmtLR idL idR -> [idL]
   -- Id Binders for a Stmt... [but what about pattern-sig type vars]?
 collectStmtBinders (BindStmt pat _ _ _) = collectPatBinders pat
 collectStmtBinders (LetStmt binds)      = collectLocalBinders binds
-collectStmtBinders (ExprStmt _ _ _)     = []
-collectStmtBinders (ParStmt xs)         = collectLStmtsBinders
+collectStmtBinders (ExprStmt {})        = []
+collectStmtBinders (LastStmt {})        = []
+collectStmtBinders (ParStmt xs _ _ _)   = collectLStmtsBinders
                                         $ concatMap fst xs
-collectStmtBinders (TransformStmt stmts _ _ _)   = collectLStmtsBinders stmts
-collectStmtBinders (GroupStmt     stmts _ _ _)   = collectLStmtsBinders stmts
-collectStmtBinders (RecStmt { recS_stmts = ss }) = collectLStmtsBinders ss
+collectStmtBinders (TransStmt { trS_stmts = stmts }) = collectLStmtsBinders stmts
+collectStmtBinders (RecStmt { recS_stmts = ss })     = collectLStmtsBinders ss
 
 
 ----------------- Patterns --------------------------
@@ -642,12 +652,12 @@ lStmtsImplicits = hs_lstmts
     
     hs_stmt (BindStmt pat _ _ _) = lPatImplicits pat
     hs_stmt (LetStmt binds)      = hs_local_binds binds
-    hs_stmt (ExprStmt _ _ _)     = emptyNameSet
-    hs_stmt (ParStmt xs)         = hs_lstmts $ concatMap fst xs
+    hs_stmt (ExprStmt {})        = emptyNameSet
+    hs_stmt (LastStmt {})        = emptyNameSet
+    hs_stmt (ParStmt xs _ _ _)   = hs_lstmts $ concatMap fst xs
     
-    hs_stmt (TransformStmt stmts _ _ _)   = hs_lstmts stmts
-    hs_stmt (GroupStmt     stmts _ _ _)   = hs_lstmts stmts
-    hs_stmt (RecStmt { recS_stmts = ss }) = hs_lstmts ss
+    hs_stmt (TransStmt { trS_stmts = stmts }) = hs_lstmts stmts
+    hs_stmt (RecStmt { recS_stmts = ss })     = hs_lstmts ss
     
     hs_local_binds (HsValBinds val_binds) = hsValBindsImplicits val_binds
     hs_local_binds (HsIPBinds _)         = emptyNameSet
index 7e15aa4..1d2d1f5 100644 (file)
@@ -358,6 +358,7 @@ data ExtensionFlag
    | Opt_KindSignatures
    | Opt_ParallelListComp
    | Opt_TransformListComp
+   | Opt_MonadComprehensions
    | Opt_GeneralizedNewtypeDeriving
    | Opt_RecursiveDo
    | Opt_DoRec
@@ -1620,6 +1621,7 @@ xFlags = [
   ( "EmptyDataDecls",                   Opt_EmptyDataDecls, nop ),
   ( "ParallelListComp",                 Opt_ParallelListComp, nop ),
   ( "TransformListComp",                Opt_TransformListComp, nop ),
+  ( "MonadComprehensions",              Opt_MonadComprehensions, nop),
   ( "ForeignFunctionInterface",         Opt_ForeignFunctionInterface, nop ),
   ( "UnliftedFFITypes",                 Opt_UnliftedFFITypes, nop ),
   ( "GHCForeignImportPrim",             Opt_GHCForeignImportPrim, nop ),
@@ -1628,9 +1630,9 @@ xFlags = [
   ( "RankNTypes",                       Opt_RankNTypes, nop ),
   ( "ImpredicativeTypes",               Opt_ImpredicativeTypes, nop), 
   ( "TypeOperators",                    Opt_TypeOperators, nop ),
-  ( "RecursiveDo",                      Opt_RecursiveDo,
+  ( "RecursiveDo",                      Opt_RecursiveDo,     -- Enables 'mdo'
     deprecatedForExtension "DoRec"),
-  ( "DoRec",                            Opt_DoRec, nop ),
+  ( "DoRec",                            Opt_DoRec, nop ),    -- Enables 'rec' keyword 
   ( "Arrows",                           Opt_Arrows, nop ),
   ( "ParallelArrays",                   Opt_ParallelArrays, nop ),
   ( "TemplateHaskell",                  Opt_TemplateHaskell, checkTemplateHaskellOk ),
index 36e53a8..6a5552f 100644 (file)
@@ -1132,7 +1132,7 @@ hscTcExpr -- Typecheck an expression (but don't run it)
 hscTcExpr hsc_env expr = runHsc hsc_env $ do
     maybe_stmt <- hscParseStmt expr
     case maybe_stmt of
-        Just (L _ (ExprStmt expr _ _)) ->
+        Just (L _ (ExprStmt expr _ _ _)) ->
             ioMsgMaybe $ tcRnExpr hsc_env (hsc_IC hsc_env) expr
         _ ->
             liftIO $ throwIO $ mkSrcErr $ unitBag $ mkPlainErrMsg noSrcSpan
index a2d2276..46f7488 100644 (file)
@@ -1893,6 +1893,7 @@ mkPState flags buf loc =
                .|. unboxedTuplesBit  `setBitIf` xopt Opt_UnboxedTuples   flags
                .|. datatypeContextsBit `setBitIf` xopt Opt_DatatypeContexts flags
                .|. transformComprehensionsBit `setBitIf` xopt Opt_TransformListComp flags
+               .|. transformComprehensionsBit `setBitIf` xopt Opt_MonadComprehensions flags
                .|. rawTokenStreamBit `setBitIf` dopt Opt_KeepRawTokenStream flags
                .|. alternativeLayoutRuleBit `setBitIf` xopt Opt_AlternativeLayoutRule flags
                .|. relaxedLayoutBit  `setBitIf` xopt Opt_RelaxedLayout flags
index bfadfba..aa20ea6 100644 (file)
@@ -1283,14 +1283,9 @@ exp10 :: { LHsExpr RdrName }
        | 'case' exp 'of' altslist              { LL $ HsCase $2 (mkMatchGroup (unLoc $4)) }
        | '-' fexp                              { LL $ NegApp $2 noSyntaxExpr }
 
-       | 'do' stmtlist                 {% let loc = comb2 $1 $2 in
-                                          checkDo loc (unLoc $2)  >>= \ (stmts,body) ->
-                                          return (L loc (mkHsDo DoExpr stmts body)) }
-       | 'mdo' stmtlist                {% let loc = comb2 $1 $2 in
-                                          checkDo loc (unLoc $2)  >>= \ (stmts,body) ->
-                                           return (L loc (mkHsDo MDoExpr
-                                                                 [L loc (mkRecStmt stmts)]
-                                                                 body)) }
+       | 'do' stmtlist                 { L (comb2 $1 $2) (mkHsDo DoExpr  (unLoc $2)) }
+       | 'mdo' stmtlist                { L (comb2 $1 $2) (mkHsDo MDoExpr (unLoc $2)) }
+
         | scc_annot exp                                { LL $ if opt_SccProfilingOn
                                                        then HsSCC (unLoc $1) $2
                                                        else HsPar $2 }
@@ -1465,7 +1460,10 @@ list :: { LHsExpr RdrName }
        | texp ',' exp '..'     { LL $ ArithSeq noPostTcExpr (FromThen $1 $3) }
        | texp '..' exp         { LL $ ArithSeq noPostTcExpr (FromTo $1 $3) }
        | texp ',' exp '..' exp { LL $ ArithSeq noPostTcExpr (FromThenTo $1 $3 $5) }
-       | texp '|' flattenedpquals      { sL (comb2 $1 $>) $ mkHsDo ListComp (unLoc $3) $1 }
+       | texp '|' flattenedpquals      
+             {% checkMonadComp >>= \ ctxt ->
+               return (sL (comb2 $1 $>) $ 
+                        mkHsComp ctxt (unLoc $3) $1) }
 
 lexps :: { Located [LHsExpr RdrName] }
        : lexps ',' texp                { LL (((:) $! $3) $! unLoc $1) }
@@ -1480,7 +1478,7 @@ flattenedpquals :: { Located [LStmt RdrName] }
                     -- We just had one thing in our "parallel" list so 
                     -- we simply return that thing directly
                     
-                    qss -> L1 [L1 $ ParStmt [(qs, undefined) | qs <- qss]]
+                    qss -> L1 [L1 $ ParStmt [(qs, undefined) | qs <- qss] noSyntaxExpr noSyntaxExpr noSyntaxExpr]
                     -- We actually found some actual parallel lists so
                     -- we wrap them into as a ParStmt
                 }
@@ -1537,7 +1535,7 @@ parr :: { LHsExpr RdrName }
                                                       (reverse (unLoc $1)) }
        | texp '..' exp                 { LL $ PArrSeq noPostTcExpr (FromTo $1 $3) }
        | texp ',' exp '..' exp         { LL $ PArrSeq noPostTcExpr (FromThenTo $1 $3 $5) }
-       | texp '|' flattenedpquals      { LL $ mkHsDo PArrComp (unLoc $3) $1 }
+       | texp '|' flattenedpquals      { LL $ mkHsComp PArrComp (unLoc $3) $1 }
 
 -- We are reusing `lexps' and `flattenedpquals' from the list case.
 
index 47abf23..3b14990 100644 (file)
@@ -40,8 +40,7 @@ module RdrHsSyn (
        checkPattern,         -- HsExp -> P HsPat
        bang_RDR,
        checkPatterns,        -- SrcLoc -> [HsExp] -> P [HsPat]
-       checkDo,              -- [Stmt] -> P [Stmt]
-       checkMDo,             -- [Stmt] -> P [Stmt]
+       checkMonadComp,       -- P (HsStmtContext RdrName)
        checkValDef,          -- (SrcLoc, HsExp, HsRhs, [HsDecl]) -> P HsDecl
        checkValSig,          -- (SrcLoc, HsExp, HsRhs, [HsDecl]) -> P HsDecl
        checkDoAndIfThenElse,
@@ -54,6 +53,7 @@ import Class            ( FunDep )
 import TypeRep          ( Kind )
 import RdrName         ( RdrName, isRdrTyVar, isRdrTc, mkUnqual, rdrNameOcc, 
                          isRdrDataCon, isUnqual, getRdrName, setRdrNameSpace )
+import Name             ( Name )
 import BasicTypes      ( maxPrecedence, Activation(..), RuleMatchInfo,
                           InlinePragma(..), InlineSpec(..) )
 import Lexer
@@ -611,34 +611,6 @@ checkPred (L spn ty)
     check loc _                        _    = parseErrorSDoc loc
                                 (text "malformed class assertion:" <+> ppr ty)
 
----------------------------------------------------------------------------
--- Checking statements in a do-expression
---     We parse   do { e1 ; e2 ; }
---     as [ExprStmt e1, ExprStmt e2]
--- checkDo (a) checks that the last thing is an ExprStmt
---        (b) returns it separately
--- same comments apply for mdo as well
-
-checkDo, checkMDo :: SrcSpan -> [LStmt RdrName] -> P ([LStmt RdrName], LHsExpr RdrName)
-
-checkDo         = checkDoMDo "a " "'do'"
-checkMDo = checkDoMDo "an " "'mdo'"
-
-checkDoMDo :: String -> String -> SrcSpan -> [LStmt RdrName] -> P ([LStmt RdrName], LHsExpr RdrName)
-checkDoMDo _   nm loc []   = parseErrorSDoc loc (text ("Empty " ++ nm ++ " construct"))
-checkDoMDo pre nm _   ss   = do
-  check ss
-  where 
-       check  []                     = panic "RdrHsSyn:checkDoMDo"
-       check  [L _ (ExprStmt e _ _)] = return ([], e)
-       check  [L l e] = parseErrorSDoc l
-                         (text ("The last statement in " ++ pre ++ nm ++
-                                                   " construct must be an expression:")
-                       $$ ppr e)
-       check (s:ss) = do
-         (ss',e') <-  check ss
-         return ((s:ss'),e')
-
 -- -------------------------------------------------------------------------
 -- Checking Patterns.
 
@@ -912,6 +884,20 @@ isFunLhs e = go e []
                 _ -> return Nothing }
    go _ _ = return Nothing
 
+
+---------------------------------------------------------------------------
+-- Check for monad comprehensions
+--
+-- If the flag MonadComprehensions is set, return a `MonadComp' context,
+-- otherwise use the usual `ListComp' context
+
+checkMonadComp :: P (HsStmtContext Name)
+checkMonadComp = do
+    pState <- getPState
+    return $ if xopt Opt_MonadComprehensions (dflags pState)
+                then MonadComp
+                else ListComp
+
 ---------------------------------------------------------------------------
 -- Miscellaneous utilities
 
index 24756d5..e1d287a 100644 (file)
@@ -160,6 +160,7 @@ basicKnownKeyNames
        -- Monad stuff
        thenIOName, bindIOName, returnIOName, failIOName,
        failMName, bindMName, thenMName, returnMName,
+        fmapName,
 
        -- MonadRec stuff
        mfixName,
@@ -221,6 +222,12 @@ basicKnownKeyNames
        -- dotnet interop
        , objectTyConName, marshalObjectName, unmarshalObjectName
        , marshalStringName, unmarshalStringName, checkDotnetResName
+
+        -- Monad comprehensions
+        , guardMName
+        , liftMName
+        , groupMName
+        , mzipName
     ]
 
 genericTyConNames :: [Name]
@@ -262,8 +269,9 @@ gHC_PRIM, gHC_TYPES, gHC_UNIT, gHC_ORDERING, gHC_GENERICS,
     gHC_PACK, gHC_CONC, gHC_IO, gHC_IO_Exception,
     gHC_ST, gHC_ARR, gHC_STABLE, gHC_ADDR, gHC_PTR, gHC_ERR, gHC_REAL,
     gHC_FLOAT, gHC_TOP_HANDLER, sYSTEM_IO, dYNAMIC, tYPEABLE, gENERICS,
-    dOTNET, rEAD_PREC, lEX, gHC_INT, gHC_WORD, mONAD, mONAD_FIX, aRROW, cONTROL_APPLICATIVE,
-    gHC_DESUGAR, rANDOM, gHC_EXTS, cONTROL_EXCEPTION_BASE :: Module
+    dOTNET, rEAD_PREC, lEX, gHC_INT, gHC_WORD, mONAD, mONAD_FIX, mONAD_GROUP, mONAD_ZIP,
+    aRROW, cONTROL_APPLICATIVE, gHC_DESUGAR, rANDOM, gHC_EXTS,
+    cONTROL_EXCEPTION_BASE :: Module
 
 gHC_PRIM       = mkPrimModule (fsLit "GHC.Prim")   -- Primitive types and values
 gHC_TYPES       = mkPrimModule (fsLit "GHC.Types")
@@ -311,6 +319,8 @@ gHC_INT             = mkBaseModule (fsLit "GHC.Int")
 gHC_WORD       = mkBaseModule (fsLit "GHC.Word")
 mONAD          = mkBaseModule (fsLit "Control.Monad")
 mONAD_FIX      = mkBaseModule (fsLit "Control.Monad.Fix")
+mONAD_GROUP     = mkBaseModule (fsLit "Control.Monad.Group")
+mONAD_ZIP       = mkBaseModule (fsLit "Control.Monad.Zip")
 aRROW          = mkBaseModule (fsLit "Control.Arrow")
 cONTROL_APPLICATIVE = mkBaseModule (fsLit "Control.Applicative")
 gHC_DESUGAR = mkBaseModule (fsLit "GHC.Desugar")
@@ -597,12 +607,13 @@ inlineIdName :: Name
 inlineIdName           = varQual gHC_MAGIC (fsLit "inline") inlineIdKey
 
 -- Base classes (Eq, Ord, Functor)
-eqClassName, eqName, ordClassName, geName, functorClassName :: Name
+fmapName, eqClassName, eqName, ordClassName, geName, functorClassName :: Name
 eqClassName      = clsQual  gHC_CLASSES (fsLit "Eq")      eqClassKey
 eqName           = methName gHC_CLASSES (fsLit "==")      eqClassOpKey
 ordClassName     = clsQual  gHC_CLASSES (fsLit "Ord")     ordClassKey
 geName           = methName gHC_CLASSES (fsLit ">=")      geClassOpKey
 functorClassName  = clsQual  gHC_BASE (fsLit "Functor") functorClassKey
+fmapName          = methName gHC_BASE (fsLit "fmap")    fmapClassOpKey
 
 -- Class Monad
 monadClassName, thenMName, bindMName, returnMName, failMName :: Name
@@ -834,6 +845,14 @@ appAName      = varQual aRROW (fsLit "app")          appAIdKey
 choiceAName       = varQual aRROW (fsLit "|||")          choiceAIdKey
 loopAName         = varQual aRROW (fsLit "loop")  loopAIdKey
 
+-- Monad comprehensions
+guardMName, liftMName, groupMName, mzipName :: Name
+guardMName         = varQual mONAD (fsLit "guard") guardMIdKey
+liftMName          = varQual mONAD (fsLit "liftM") liftMIdKey
+groupMName         = varQual mONAD_GROUP (fsLit "mgroupWith") groupMIdKey
+mzipName           = varQual mONAD_ZIP (fsLit "mzip") mzipIdKey
+
+
 -- Annotation type checking
 toAnnotationWrapperName :: Name
 toAnnotationWrapperName = varQual gHC_DESUGAR (fsLit "toAnnotationWrapper") toAnnotationWrapperIdKey
@@ -1280,7 +1299,8 @@ unboundKey                      = mkPreludeMiscIdUnique 101
 fromIntegerClassOpKey, minusClassOpKey, fromRationalClassOpKey,
     enumFromClassOpKey, enumFromThenClassOpKey, enumFromToClassOpKey,
     enumFromThenToClassOpKey, eqClassOpKey, geClassOpKey, negateClassOpKey,
-    failMClassOpKey, bindMClassOpKey, thenMClassOpKey, returnMClassOpKey
+    failMClassOpKey, bindMClassOpKey, thenMClassOpKey, returnMClassOpKey,
+    fmapClassOpKey
     :: Unique
 fromIntegerClassOpKey        = mkPreludeMiscIdUnique 102
 minusClassOpKey                      = mkPreludeMiscIdUnique 103
@@ -1295,6 +1315,7 @@ negateClassOpKey        = mkPreludeMiscIdUnique 111
 failMClassOpKey                      = mkPreludeMiscIdUnique 112
 bindMClassOpKey                      = mkPreludeMiscIdUnique 113 -- (>>=)
 thenMClassOpKey                      = mkPreludeMiscIdUnique 114 -- (>>)
+fmapClassOpKey                = mkPreludeMiscIdUnique 115
 returnMClassOpKey            = mkPreludeMiscIdUnique 117
 
 -- Recursive do notation
@@ -1325,6 +1346,14 @@ realToFracIdKey      = mkPreludeMiscIdUnique 128
 toIntegerClassOpKey  = mkPreludeMiscIdUnique 129
 toRationalClassOpKey = mkPreludeMiscIdUnique 130
 
+-- Monad comprehensions
+guardMIdKey, liftMIdKey, groupMIdKey, mzipIdKey :: Unique
+guardMIdKey     = mkPreludeMiscIdUnique 131
+liftMIdKey      = mkPreludeMiscIdUnique 132
+groupMIdKey     = mkPreludeMiscIdUnique 133
+mzipIdKey       = mkPreludeMiscIdUnique 134
+
+
 ---------------- Template Haskell -------------------
 --     USES IdUniques 200-499
 -----------------------------------------------------
index df3b12d..dc7ea96 100644 (file)
@@ -789,9 +789,9 @@ rnGRHS' ctxt (GRHS guards rhs)
        -- Standard Haskell 1.4 guards are just a single boolean
        -- expression, rather than a list of qualifiers as in the
        -- Glasgow extension
-    is_standard_guard []                     = True
-    is_standard_guard [L _ (ExprStmt _ _ _)] = True
-    is_standard_guard _                      = False
+    is_standard_guard []                       = True
+    is_standard_guard [L _ (ExprStmt _ _ _ _)] = True
+    is_standard_guard _                        = False
 \end{code}
 
 %************************************************************************
index d11249a..b3458db 100644 (file)
@@ -40,7 +40,7 @@ import RdrName
 import LoadIface       ( loadInterfaceForName )
 import UniqSet
 import Data.List
-import Util            ( isSingleton )
+import Util            ( isSingleton, snocView )
 import ListSetOps      ( removeDups )
 import Outputable
 import SrcLoc
@@ -224,10 +224,9 @@ rnExpr (HsLet binds expr)
     rnLExpr expr                        `thenM` \ (expr',fvExpr) ->
     return (HsLet binds' expr', fvExpr)
 
-rnExpr (HsDo do_or_lc stmts body _)
-  = do  { ((stmts', body'), fvs) <- rnStmts do_or_lc stmts $ \ _ ->
-                                   rnLExpr body
-       ; return (HsDo do_or_lc stmts' body' placeHolderType, fvs) }
+rnExpr (HsDo do_or_lc stmts _)
+  = do         { ((stmts', _), fvs) <- rnStmts do_or_lc stmts (\ _ -> return ((), emptyFVs))
+       ; return ( HsDo do_or_lc stmts' placeHolderType, fvs ) }
 
 rnExpr (ExplicitList _ exps)
   = rnExprs exps                       `thenM` \ (exps', fvs) ->
@@ -441,9 +440,9 @@ convertOpFormsCmd (HsIf f exp c1 c2)
 convertOpFormsCmd (HsLet binds cmd)
   = HsLet binds (convertOpFormsLCmd cmd)
 
-convertOpFormsCmd (HsDo ctxt stmts body ty)
-  = HsDo ctxt (map (fmap convertOpFormsStmt) stmts)
-             (convertOpFormsLCmd body) ty
+convertOpFormsCmd (HsDo DoExpr stmts ty)
+  = HsDo ArrowExpr (map (fmap convertOpFormsStmt) stmts) ty
+    -- Mark the HsDo as begin the body of an arrow command
 
 -- Anything else is unchanged.  This includes HsArrForm (already done),
 -- things with no sub-commands, and illegal commands (which will be
@@ -453,8 +452,8 @@ convertOpFormsCmd c = c
 convertOpFormsStmt :: StmtLR id id -> StmtLR id id
 convertOpFormsStmt (BindStmt pat cmd _ _)
   = BindStmt pat (convertOpFormsLCmd cmd) noSyntaxExpr noSyntaxExpr
-convertOpFormsStmt (ExprStmt cmd _ _)
-  = ExprStmt (convertOpFormsLCmd cmd) noSyntaxExpr placeHolderType
+convertOpFormsStmt (ExprStmt cmd _ _ _)
+  = ExprStmt (convertOpFormsLCmd cmd) noSyntaxExpr noSyntaxExpr placeHolderType
 convertOpFormsStmt stmt@(RecStmt { recS_stmts = stmts })
   = stmt { recS_stmts = map (fmap convertOpFormsStmt) stmts }
 convertOpFormsStmt stmt = stmt
@@ -495,14 +494,10 @@ methodNamesCmd (HsPar c) = methodNamesLCmd c
 methodNamesCmd (HsIf _ _ c1 c2)
   = methodNamesLCmd c1 `plusFV` methodNamesLCmd c2 `addOneFV` choiceAName
 
-methodNamesCmd (HsLet _ c) = methodNamesLCmd c
-
-methodNamesCmd (HsDo _ stmts body _) 
-  = methodNamesStmts stmts `plusFV` methodNamesLCmd body
-
-methodNamesCmd (HsApp c _) = methodNamesLCmd c
-
-methodNamesCmd (HsLam match) = methodNamesMatch match
+methodNamesCmd (HsLet _ c)      = methodNamesLCmd c
+methodNamesCmd (HsDo _ stmts _) = methodNamesStmts stmts 
+methodNamesCmd (HsApp c _)      = methodNamesLCmd c
+methodNamesCmd (HsLam match)    = methodNamesMatch match
 
 methodNamesCmd (HsCase _ matches)
   = methodNamesMatch matches `addOneFV` choiceAName
@@ -538,14 +533,14 @@ methodNamesLStmt :: Located (StmtLR Name Name) -> FreeVars
 methodNamesLStmt = methodNamesStmt . unLoc
 
 methodNamesStmt :: StmtLR Name Name -> FreeVars
-methodNamesStmt (ExprStmt cmd _ _)               = methodNamesLCmd cmd
+methodNamesStmt (LastStmt cmd _)                 = methodNamesLCmd cmd
+methodNamesStmt (ExprStmt cmd _ _ _)             = methodNamesLCmd cmd
 methodNamesStmt (BindStmt _ cmd _ _)             = methodNamesLCmd cmd
 methodNamesStmt (RecStmt { recS_stmts = stmts }) = methodNamesStmts stmts `addOneFV` loopAName
 methodNamesStmt (LetStmt _)                      = emptyFVs
-methodNamesStmt (ParStmt _)                      = emptyFVs
-methodNamesStmt (TransformStmt {})               = emptyFVs
-methodNamesStmt (GroupStmt {})                   = emptyFVs
-   -- ParStmt, TransformStmt and GroupStmt can't occur in commands, but it's not convenient to error 
+methodNamesStmt (ParStmt _ _ _ _)                = emptyFVs
+methodNamesStmt (TransStmt {})                   = emptyFVs
+   -- ParStmt and TransStmt can't occur in commands, but it's not convenient to error 
    -- here so we just do what's convenient
 \end{code}
 
@@ -588,14 +583,16 @@ rnArithSeq (FromThenTo expr1 expr2 expr3)
 
 \begin{code}
 rnBracket :: HsBracket RdrName -> RnM (HsBracket Name, FreeVars)
-rnBracket (VarBr n) = do { name <- lookupOccRn n
-                        ; this_mod <- getModule
-                        ; unless (nameIsLocalOrFrom this_mod name) $   -- Reason: deprecation checking asumes the
-                          do { _ <- loadInterfaceForName msg name      -- home interface is loaded, and this is the
-                             ; return () }                             -- only way that is going to happen
-                        ; return (VarBr name, unitFV name) }
-                   where
-                     msg = ptext (sLit "Need interface for Template Haskell quoted Name")
+rnBracket (VarBr n) 
+  = do { name <- lookupOccRn n
+       ; this_mod <- getModule
+       ; unless (nameIsLocalOrFrom this_mod name) $  -- Reason: deprecation checking assumes
+         do { _ <- loadInterfaceForName msg name     -- the home interface is loaded, and
+            ; return () }                           -- this is the only way that is going
+                                                    -- to happen
+       ; return (VarBr name, unitFV name) }
+  where
+    msg = ptext (sLit "Need interface for Template Haskell quoted Name")
 
 rnBracket (ExpBr e) = do { (e', fvs) <- rnLExpr e
                         ; return (ExpBr e', fvs) }
@@ -625,7 +622,8 @@ rnBracket (DecBrL decls)
                              rnSrcDecls group      
 
              -- Discard the tcg_env; it contains only extra info about fixity
-        ; traceRn (text "rnBracket dec" <+> (ppr (tcg_dus tcg_env) $$ ppr (duUses (tcg_dus tcg_env))))
+        ; traceRn (text "rnBracket dec" <+> (ppr (tcg_dus tcg_env) $$ 
+                   ppr (duUses (tcg_dus tcg_env))))
        ; return (DecBrG group', duUses (tcg_dus tcg_env)) }
 
 rnBracket (DecBrG _) = panic "rnBracket: unexpected DecBrG"
@@ -639,44 +637,72 @@ rnBracket (DecBrG _) = panic "rnBracket: unexpected DecBrG"
 
 \begin{code}
 rnStmts :: HsStmtContext Name -> [LStmt RdrName]
-             -> ([Name] -> RnM (thing, FreeVars))
-             -> RnM (([LStmt Name], thing), FreeVars)  
+       -> ([Name] -> RnM (thing, FreeVars))
+       -> RnM (([LStmt Name], thing), FreeVars)        
 -- Variables bound by the Stmts, and mentioned in thing_inside,
 -- do not appear in the result FreeVars
---
--- Renaming a single RecStmt can give a sequence of smaller Stmts
 
-rnStmts _ [] thing_inside
-  = do { (res, fvs) <- thing_inside []
-       ; return (([], res), fvs) }
+rnStmts ctxt [] thing_inside
+  = do { checkEmptyStmts ctxt
+       ; (thing, fvs) <- thing_inside []
+       ; return (([], thing), fvs) }
+
+rnStmts MDoExpr stmts thing_inside    -- Deal with mdo
+  = -- Behave like do { rec { ...all but last... }; last }
+    do { ((stmts1, (stmts2, thing)), fvs) 
+          <- rnStmt MDoExpr (noLoc $ mkRecStmt all_but_last) $ \ _ ->
+             do { last_stmt' <- checkLastStmt MDoExpr last_stmt
+                ; rnStmt MDoExpr last_stmt' thing_inside }
+       ; return (((stmts1 ++ stmts2), thing), fvs) }
+  where
+    Just (all_but_last, last_stmt) = snocView stmts
 
-rnStmts ctxt (stmt@(L loc _) : stmts) thing_inside
+rnStmts ctxt (lstmt@(L loc _) : lstmts) thing_inside
+  | null lstmts
+  = setSrcSpan loc $
+    do { lstmt' <- checkLastStmt ctxt lstmt
+       ; rnStmt ctxt lstmt' thing_inside }
+
+  | otherwise
   = do { ((stmts1, (stmts2, thing)), fvs) 
-            <- setSrcSpan loc           $
-               rnStmt ctxt stmt         $ \ bndrs1 ->
-               rnStmts ctxt stmts $ \ bndrs2 ->
-               thing_inside (bndrs1 ++ bndrs2)
+            <- setSrcSpan loc                         $
+               do { checkStmt ctxt lstmt
+                  ; rnStmt ctxt lstmt    $ \ bndrs1 ->
+                    rnStmts ctxt lstmts  $ \ bndrs2 ->
+                    thing_inside (bndrs1 ++ bndrs2) }
        ; return (((stmts1 ++ stmts2), thing), fvs) }
 
-
-rnStmt :: HsStmtContext Name -> LStmt RdrName
+----------------------
+rnStmt :: HsStmtContext Name 
+       -> LStmt RdrName
        -> ([Name] -> RnM (thing, FreeVars))
        -> RnM (([LStmt Name], thing), FreeVars)
 -- Variables bound by the Stmt, and mentioned in thing_inside,
 -- do not appear in the result FreeVars
 
-rnStmt _ (L loc (ExprStmt expr _ _)) thing_inside
+rnStmt ctxt (L loc (LastStmt expr _)) thing_inside
+  = do { (expr', fv_expr) <- rnLExpr expr
+       ; (ret_op, fvs1)   <- lookupStmtName ctxt returnMName
+       ; (thing,  fvs3)   <- thing_inside []
+       ; return (([L loc (LastStmt expr' ret_op)], thing),
+                 fv_expr `plusFV` fvs1 `plusFV` fvs3) }
+
+rnStmt ctxt (L loc (ExprStmt expr _ _ _)) thing_inside
   = do { (expr', fv_expr) <- rnLExpr expr
-       ; (then_op, fvs1)  <- lookupSyntaxName thenMName
-       ; (thing, fvs2)    <- thing_inside []
-       ; return (([L loc (ExprStmt expr' then_op placeHolderType)], thing),
-                 fv_expr `plusFV` fvs1 `plusFV` fvs2) }
+       ; (then_op, fvs1)  <- lookupStmtName ctxt thenMName
+       ; (guard_op, fvs2) <- if isListCompExpr ctxt
+                              then lookupStmtName ctxt guardMName
+                             else return (noSyntaxExpr, emptyFVs)
+                             -- Only list/parr/monad comprehensions use 'guard'
+       ; (thing, fvs3)    <- thing_inside []
+       ; return (([L loc (ExprStmt expr' then_op guard_op placeHolderType)], thing),
+                 fv_expr `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) }
 
 rnStmt ctxt (L loc (BindStmt pat expr _ _)) thing_inside
   = do { (expr', fv_expr) <- rnLExpr expr
                -- The binders do not scope over the expression
-       ; (bind_op, fvs1) <- lookupSyntaxName bindMName
-       ; (fail_op, fvs2) <- lookupSyntaxName failMName
+       ; (bind_op, fvs1) <- lookupStmtName ctxt bindMName
+       ; (fail_op, fvs2) <- lookupStmtName ctxt failMName
        ; rnPat (StmtCtxt ctxt) pat $ \ pat' -> do
        { (thing, fvs3) <- thing_inside (collectPatBinders pat')
        ; return (([L loc (BindStmt pat' expr' bind_op fail_op)], thing),
@@ -684,15 +710,13 @@ rnStmt ctxt (L loc (BindStmt pat expr _ _)) thing_inside
        -- fv_expr shouldn't really be filtered by the rnPatsAndThen
        -- but it does not matter because the names are unique
 
-rnStmt ctxt (L loc (LetStmt binds)) thing_inside 
-  = do { checkLetStmt ctxt binds
-       ; rnLocalBindsAndThen binds $ \binds' -> do
+rnStmt _ (L loc (LetStmt binds)) thing_inside 
+  = do { rnLocalBindsAndThen binds $ \binds' -> do
        { (thing, fvs) <- thing_inside (collectLocalBinders binds')
         ; return (([L loc (LetStmt binds')], thing), fvs) }  }
 
 rnStmt ctxt (L _ (RecStmt { recS_stmts = rec_stmts })) thing_inside
-  = do { checkRecStmt ctxt
-
+  = do { 
        -- Step1: Bring all the binders of the mdo into scope
        -- (Remember that this also removes the binders from the
        -- finally-returned free-vars.)
@@ -707,9 +731,9 @@ rnStmt ctxt (L _ (RecStmt { recS_stmts = rec_stmts })) thing_inside
        { let bndrs = nameSetToList $ foldr (unionNameSets . (\(ds,_,_,_) -> ds)) 
                                             emptyNameSet segs
         ; (thing, fvs_later) <- thing_inside bndrs
-       ; (return_op, fvs1)  <- lookupSyntaxName returnMName
-       ; (mfix_op,   fvs2)  <- lookupSyntaxName mfixName
-       ; (bind_op,   fvs3)  <- lookupSyntaxName bindMName
+       ; (return_op, fvs1)  <- lookupStmtName ctxt returnMName
+       ; (mfix_op,   fvs2)  <- lookupStmtName ctxt mfixName
+       ; (bind_op,   fvs3)  <- lookupStmtName ctxt bindMName
        ; let
                -- Step 2: Fill in the fwd refs.
                --         The segments are all singletons, but their fwd-ref
@@ -734,57 +758,51 @@ rnStmt ctxt (L _ (RecStmt { recS_stmts = rec_stmts })) thing_inside
 
        ; return ((rec_stmts', thing), fvs `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) } }
 
-rnStmt ctxt (L loc (ParStmt segs)) thing_inside
-  = do { checkParStmt ctxt
-       ; ((segs', thing), fvs) <- rnParallelStmts (ParStmtCtxt ctxt) segs thing_inside
-       ; return (([L loc (ParStmt segs')], thing), fvs) }
-
-rnStmt ctxt (L loc (TransformStmt stmts _ using by)) thing_inside
-  = do { checkTransformStmt ctxt
-    
-       ; (using', fvs1) <- rnLExpr using
-
-       ; ((stmts', (by', used_bndrs, thing)), fvs2)
-             <- rnStmts (TransformStmtCtxt ctxt) stmts $ \ bndrs ->
-                do { (by', fvs_by) <- case by of
-                                        Nothing -> return (Nothing, emptyFVs)
-                                        Just e  -> do { (e', fvs) <- rnLExpr e; return (Just e', fvs) }
-                   ; (thing, fvs_thing) <- thing_inside bndrs
-                   ; let fvs        = fvs_by `plusFV` fvs_thing
-                         used_bndrs = filter (`elemNameSet` fvs) bndrs
-                         -- The paper (Fig 5) has a bug here; we must treat any free varaible of
-                         -- the "thing inside", **or of the by-expression**, as used
-                   ; return ((by', used_bndrs, thing), fvs) }
-
-       ; return (([L loc (TransformStmt stmts' used_bndrs using' by')], thing), 
-                 fvs1 `plusFV` fvs2) }
-        
-rnStmt ctxt (L loc (GroupStmt stmts _ by using)) thing_inside
-  = do { checkTransformStmt ctxt
-    
-         -- Rename the 'using' expression in the context before the transform is begun
-       ; (using', fvs1) <- case using of
-                             Left e  -> do { (e', fvs) <- rnLExpr e; return (Left e', fvs) }
-                            Right _ -> do { (e', fvs) <- lookupSyntaxName groupWithName
-                                           ; return (Right e', fvs) }
+rnStmt ctxt (L loc (ParStmt segs _ _ _)) thing_inside
+  = do { (mzip_op, fvs1)   <- lookupStmtName ctxt mzipName
+        ; (bind_op, fvs2)   <- lookupStmtName ctxt bindMName
+        ; (return_op, fvs3) <- lookupStmtName ctxt returnMName
+       ; ((segs', thing), fvs4) <- rnParallelStmts (ParStmtCtxt ctxt) segs thing_inside
+       ; return ( ([L loc (ParStmt segs' mzip_op bind_op return_op)], thing)
+                 , fvs1 `plusFV` fvs2 `plusFV` fvs3 `plusFV` fvs4) }
+
+rnStmt ctxt (L loc (TransStmt { trS_stmts = stmts, trS_by = by, trS_form = form
+                              , trS_using = using })) thing_inside
+  = do { -- Rename the 'using' expression in the context before the transform is begun
+         (using', fvs1) <- case form of
+                             GroupFormB -> do { (e,fvs) <- lookupStmtName ctxt groupMName
+                                              ; return (noLoc e, fvs) }
+                            _          -> rnLExpr using
 
          -- Rename the stmts and the 'by' expression
         -- Keep track of the variables mentioned in the 'by' expression
        ; ((stmts', (by', used_bndrs, thing)), fvs2) 
-             <- rnStmts (TransformStmtCtxt ctxt) stmts $ \ bndrs ->
+             <- rnStmts (TransStmtCtxt ctxt) stmts $ \ bndrs ->
                 do { (by',   fvs_by) <- mapMaybeFvRn rnLExpr by
                    ; (thing, fvs_thing) <- thing_inside bndrs
                    ; let fvs = fvs_by `plusFV` fvs_thing
                          used_bndrs = filter (`elemNameSet` fvs) bndrs
+                         -- The paper (Fig 5) has a bug here; we must treat any free varaible
+                         -- of the "thing inside", **or of the by-expression**, as used
                    ; return ((by', used_bndrs, thing), fvs) }
 
-       ; let all_fvs  = fvs1 `plusFV` fvs2 
+       -- Lookup `return`, `(>>=)` and `liftM` for monad comprehensions
+       ; (return_op, fvs3) <- lookupStmtName ctxt returnMName
+       ; (bind_op,   fvs4) <- lookupStmtName ctxt bindMName
+       ; (fmap_op,   fvs5) <- case form of
+                                ThenForm -> return (noSyntaxExpr, emptyFVs)
+                                _        -> lookupStmtName ctxt fmapName
+
+       ; let all_fvs  = fvs1 `plusFV` fvs2 `plusFV` fvs3 
+                             `plusFV` fvs4 `plusFV` fvs5
              bndr_map = used_bndrs `zip` used_bndrs
-            -- See Note [GroupStmt binder map] in HsExpr
+            -- See Note [TransStmt binder map] in HsExpr
 
        ; traceRn (text "rnStmt: implicitly rebound these used binders:" <+> ppr bndr_map)
-       ; return (([L loc (GroupStmt stmts' bndr_map by' using')], thing), all_fvs) }
-
+       ; return (([L loc (TransStmt { trS_stmts = stmts', trS_bndrs = bndr_map
+                                    , trS_by = by', trS_using = using', trS_form = form
+                                    , trS_ret = return_op, trS_bind = bind_op
+                                    , trS_fmap = fmap_op })], thing), all_fvs) }
 
 type ParSeg id = ([LStmt id], [id])       -- The Names are bound by the Stmts
 
@@ -820,6 +838,12 @@ rnParallelStmts ctxt segs thing_inside
     cmpByOcc n1 n2 = nameOccName n1 `compare` nameOccName n2
     dupErr vs = addErr (ptext (sLit "Duplicate binding in parallel list comprehension for:")
                     <+> quotes (ppr (head vs)))
+
+lookupStmtName :: HsStmtContext Name -> Name -> RnM (HsExpr Name, FreeVars)
+-- Like lookupSyntaxName, but ListComp/PArrComp are never rebindable
+lookupStmtName ListComp n = return (HsVar n, emptyFVs)
+lookupStmtName PArrComp n = return (HsVar n, emptyFVs)
+lookupStmtName _        n = lookupSyntaxName n
 \end{code}
 
 Note [Renaming parallel Stmts]
@@ -901,9 +925,11 @@ rn_rec_stmt_lhs :: MiniFixityEnv
                    -- so we don't bother to compute it accurately in the other cases
                 -> RnM [(LStmtLR Name RdrName, FreeVars)]
 
-rn_rec_stmt_lhs _ (L loc (ExprStmt expr a b)) = return [(L loc (ExprStmt expr a b), 
-                                                       -- this is actually correct
-                                                       emptyFVs)]
+rn_rec_stmt_lhs _ (L loc (ExprStmt expr a b c)) 
+  = return [(L loc (ExprStmt expr a b c), emptyFVs)]
+
+rn_rec_stmt_lhs _ (L loc (LastStmt expr a)) 
+  = return [(L loc (LastStmt expr a), emptyFVs)]
 
 rn_rec_stmt_lhs fix_env (L loc (BindStmt pat expr a b)) 
   = do 
@@ -926,13 +952,10 @@ rn_rec_stmt_lhs fix_env (L loc (LetStmt (HsValBinds binds)))
 rn_rec_stmt_lhs fix_env (L _ (RecStmt { recS_stmts = stmts })) -- Flatten Rec inside Rec
     = rn_rec_stmts_lhs fix_env stmts
 
-rn_rec_stmt_lhs _ stmt@(L _ (ParStmt _))       -- Syntactically illegal in mdo
-  = pprPanic "rn_rec_stmt" (ppr stmt)
-  
-rn_rec_stmt_lhs _ stmt@(L _ (TransformStmt {}))        -- Syntactically illegal in mdo
+rn_rec_stmt_lhs _ stmt@(L _ (ParStmt _ _ _ _)) -- Syntactically illegal in mdo
   = pprPanic "rn_rec_stmt" (ppr stmt)
   
-rn_rec_stmt_lhs _ stmt@(L _ (GroupStmt {}))    -- Syntactically illegal in mdo
+rn_rec_stmt_lhs _ stmt@(L _ (TransStmt {}))    -- Syntactically illegal in mdo
   = pprPanic "rn_rec_stmt" (ppr stmt)
 
 rn_rec_stmt_lhs _ (L _ (LetStmt EmptyLocalBinds))
@@ -957,11 +980,17 @@ rn_rec_stmt :: [Name] -> LStmtLR Name RdrName -> FreeVars -> RnM [Segment (LStmt
        -- Rename a Stmt that is inside a RecStmt (or mdo)
        -- Assumes all binders are already in scope
        -- Turns each stmt into a singleton Stmt
-rn_rec_stmt _ (L loc (ExprStmt expr _ _)) _
+rn_rec_stmt _ (L loc (LastStmt expr _)) _
+  = do { (expr', fv_expr) <- rnLExpr expr
+       ; (ret_op, fvs1)   <- lookupSyntaxName returnMName
+       ; return [(emptyNameSet, fv_expr `plusFV` fvs1, emptyNameSet,
+                   L loc (LastStmt expr' ret_op))] }
+
+rn_rec_stmt _ (L loc (ExprStmt expr _ _ _)) _
   = rnLExpr expr `thenM` \ (expr', fvs) ->
     lookupSyntaxName thenMName `thenM` \ (then_op, fvs1) ->
     return [(emptyNameSet, fvs `plusFV` fvs1, emptyNameSet,
-             L loc (ExprStmt expr' then_op placeHolderType))]
+             L loc (ExprStmt expr' then_op noSyntaxExpr placeHolderType))]
 
 rn_rec_stmt _ (L loc (BindStmt pat' expr _ _)) fv_pat
   = rnLExpr expr               `thenM` \ (expr', fv_expr) ->
@@ -991,11 +1020,8 @@ rn_rec_stmt _ stmt@(L _ (RecStmt {})) _
 rn_rec_stmt _ stmt@(L _ (ParStmt {})) _        -- Syntactically illegal in mdo
   = pprPanic "rn_rec_stmt: ParStmt" (ppr stmt)
 
-rn_rec_stmt _ stmt@(L _ (TransformStmt {})) _  -- Syntactically illegal in mdo
-  = pprPanic "rn_rec_stmt: TransformStmt" (ppr stmt)
-
-rn_rec_stmt _ stmt@(L _ (GroupStmt {})) _      -- Syntactically illegal in mdo
-  = pprPanic "rn_rec_stmt: GroupStmt" (ppr stmt)
+rn_rec_stmt _ stmt@(L _ (TransStmt {})) _      -- Syntactically illegal in mdo
+  = pprPanic "rn_rec_stmt: TransStmt" (ppr stmt)
 
 rn_rec_stmt _ (L _ (LetStmt EmptyLocalBinds)) _
   = panic "rn_rec_stmt: LetStmt EmptyLocalBinds"
@@ -1141,44 +1167,151 @@ program.
 %************************************************************************
 
 \begin{code}
+checkEmptyStmts :: HsStmtContext Name -> RnM ()
+-- We've seen an empty sequence of Stmts... is that ok?
+checkEmptyStmts ctxt 
+  = unless (okEmpty ctxt) (addErr (emptyErr ctxt))
 
----------------------- 
--- Checking when a particular Stmt is ok
-checkLetStmt :: HsStmtContext Name -> HsLocalBinds RdrName -> RnM ()
-checkLetStmt (ParStmtCtxt _) (HsIPBinds binds) = addErr (badIpBinds (ptext (sLit "a parallel list comprehension:")) binds)
-checkLetStmt _ctxt          _binds            = return ()
-       -- We do not allow implicit-parameter bindings in a parallel
-       -- list comprehension.  I'm not sure what it might mean.
+okEmpty :: HsStmtContext a -> Bool
+okEmpty (PatGuard {}) = True
+okEmpty _             = False
 
----------
-checkRecStmt :: HsStmtContext Name -> RnM ()
-checkRecStmt MDoExpr = return ()      -- Recursive stmt ok in 'mdo'
-checkRecStmt DoExpr  = return ()      -- and in 'do'
-checkRecStmt ctxt    = addErr msg
-  where
-    msg = ptext (sLit "Illegal 'rec' stmt in") <+> pprStmtContext ctxt
+emptyErr :: HsStmtContext Name -> SDoc
+emptyErr (ParStmtCtxt {})   = ptext (sLit "Empty statement group in parallel comprehension")
+emptyErr (TransStmtCtxt {}) = ptext (sLit "Empty statement group preceding 'group' or 'then'")
+emptyErr ctxt               = ptext (sLit "Empty") <+> pprStmtContext ctxt
 
----------
-checkParStmt :: HsStmtContext Name -> RnM ()
-checkParStmt _
-  = do { parallel_list_comp <- xoptM Opt_ParallelListComp
-       ; checkErr parallel_list_comp msg }
+---------------------- 
+checkLastStmt :: HsStmtContext Name
+              -> LStmt RdrName 
+              -> RnM (LStmt RdrName)
+checkLastStmt ctxt lstmt@(L loc stmt)
+  = case ctxt of 
+      ListComp  -> check_comp
+      MonadComp -> check_comp
+      PArrComp  -> check_comp
+      ArrowExpr        -> check_do
+      DoExpr   -> check_do
+      MDoExpr   -> check_do
+      _         -> check_other
   where
-    msg = ptext (sLit "Illegal parallel list comprehension: use -XParallelListComp")
+    check_do   -- Expect ExprStmt, and change it to LastStmt
+      = case stmt of 
+          ExprStmt e _ _ _ -> return (L loc (mkLastStmt e))
+          LastStmt {}      -> return lstmt   -- "Deriving" clauses may generate a
+                                            -- LastStmt directly (unlike the parser)
+         _                -> do { addErr (hang last_error 2 (ppr stmt)); return lstmt }
+    last_error = (ptext (sLit "The last statement in") <+> pprAStmtContext ctxt
+                  <+> ptext (sLit "must be an expression"))
+
+    check_comp -- Expect LastStmt; this should be enforced by the parser!
+      = case stmt of 
+          LastStmt {} -> return lstmt
+          _           -> pprPanic "checkLastStmt" (ppr lstmt)
+
+    check_other        -- Behave just as if this wasn't the last stmt
+      = do { checkStmt ctxt lstmt; return lstmt }
 
----------
-checkTransformStmt :: HsStmtContext Name -> RnM ()
-checkTransformStmt ListComp  -- Ensure we are really within a list comprehension because otherwise the
-                            -- desugarer will break when we come to operate on a parallel array
-  = do { transform_list_comp <- xoptM Opt_TransformListComp
-       ; checkErr transform_list_comp msg }
-  where
-    msg = ptext (sLit "Illegal transform or grouping list comprehension: use -XTransformListComp")
-checkTransformStmt (ParStmtCtxt       ctxt) = checkTransformStmt ctxt  -- Ok to nest inside a parallel comprehension
-checkTransformStmt (TransformStmtCtxt ctxt) = checkTransformStmt ctxt  -- Ok to nest inside a parallel comprehension
-checkTransformStmt ctxt = addErr msg
+-- Checking when a particular Stmt is ok
+checkStmt :: HsStmtContext Name
+          -> LStmt RdrName 
+          -> RnM ()
+checkStmt ctxt (L _ stmt)
+  = do { dflags <- getDOpts
+       ; case okStmt dflags ctxt stmt of 
+           Nothing    -> return ()
+           Just extra -> addErr (msg $$ extra) }
   where
-    msg = ptext (sLit "Illegal transform or grouping in") <+> pprStmtContext ctxt
+   msg = sep [ ptext (sLit "Unexpected") <+> pprStmtCat stmt <+> ptext (sLit "statement")
+             , ptext (sLit "in") <+> pprAStmtContext ctxt ]
+
+pprStmtCat :: Stmt a -> SDoc
+pprStmtCat (TransStmt {})     = ptext (sLit "transform")
+pprStmtCat (LastStmt {})      = ptext (sLit "return expression")
+pprStmtCat (ExprStmt {})      = ptext (sLit "exprssion")
+pprStmtCat (BindStmt {})      = ptext (sLit "binding")
+pprStmtCat (LetStmt {})       = ptext (sLit "let")
+pprStmtCat (RecStmt {})       = ptext (sLit "rec")
+pprStmtCat (ParStmt {})       = ptext (sLit "parallel")
+
+------------
+isOK, notOK :: Maybe SDoc
+isOK  = Nothing
+notOK = Just empty
+
+okStmt, okDoStmt, okCompStmt, okParStmt, okPArrStmt
+   :: DynFlags -> HsStmtContext Name
+   -> Stmt RdrName -> Maybe SDoc
+-- Return Nothing if OK, (Just extra) if not ok
+-- The "extra" is an SDoc that is appended to an generic error message
+
+okStmt dflags ctxt stmt 
+  = case ctxt of
+      PatGuard {}               -> okPatGuardStmt stmt
+      ParStmtCtxt ctxt          -> okParStmt  dflags ctxt stmt
+      DoExpr                    -> okDoStmt   dflags ctxt stmt
+      MDoExpr                   -> okDoStmt   dflags ctxt stmt
+      ArrowExpr                 -> okDoStmt   dflags ctxt stmt
+      GhciStmt                  -> okDoStmt   dflags ctxt stmt
+      ListComp                  -> okCompStmt dflags ctxt stmt
+      MonadComp                 -> okCompStmt dflags ctxt stmt
+      PArrComp                  -> okPArrStmt dflags ctxt stmt
+      TransStmtCtxt ctxt -> okStmt dflags ctxt stmt
+
+-------------
+okPatGuardStmt :: Stmt RdrName -> Maybe SDoc
+okPatGuardStmt stmt
+  = case stmt of
+      ExprStmt {} -> isOK
+      BindStmt {} -> isOK
+      LetStmt {}  -> isOK
+      _           -> notOK
+
+-------------
+okParStmt dflags ctxt stmt
+  = case stmt of
+      LetStmt (HsIPBinds {}) -> notOK
+      _                      -> okStmt dflags ctxt stmt
+
+----------------
+okDoStmt dflags ctxt stmt
+  = case stmt of
+       RecStmt {}
+         | Opt_DoRec `xopt` dflags -> isOK
+         | ArrowExpr <- ctxt       -> isOK     -- Arrows allows 'rec'
+         | otherwise               -> Just (ptext (sLit "Use -XDoRec"))
+       BindStmt {} -> isOK
+       LetStmt {}  -> isOK
+       ExprStmt {} -> isOK
+       _           -> notOK
+
+----------------
+okCompStmt dflags _ stmt
+  = case stmt of
+       BindStmt {} -> isOK
+       LetStmt {}  -> isOK
+       ExprStmt {} -> isOK
+       ParStmt {} 
+         | Opt_ParallelListComp `xopt` dflags -> isOK
+         | otherwise -> Just (ptext (sLit "Use -XParallelListComp"))
+       TransStmt {} 
+         | Opt_TransformListComp `xopt` dflags -> isOK
+         | otherwise -> Just (ptext (sLit "Use -XTransformListComp"))
+       RecStmt {}  -> notOK
+       LastStmt {} -> notOK  -- Should not happen (dealt with by checkLastStmt)
+
+----------------
+okPArrStmt dflags _ stmt
+  = case stmt of
+       BindStmt {} -> isOK
+       LetStmt {}  -> isOK
+       ExprStmt {} -> isOK
+       ParStmt {} 
+         | Opt_ParallelListComp `xopt` dflags -> isOK
+         | otherwise -> Just (ptext (sLit "Use -XParallelListComp"))
+       TransStmt {} -> notOK
+       RecStmt {}   -> notOK
+       LastStmt {}  -> notOK  -- Should not happen (dealt with by checkLastStmt)
 
 ---------
 checkTupleSection :: [HsTupArg RdrName] -> RnM ()
index ae4a1e8..cfbdf35 100644 (file)
@@ -7,7 +7,7 @@ Typecheck arrow notation
 \begin{code}
 module TcArrows ( tcProc ) where
 
-import {-# SOURCE #-}  TcExpr( tcMonoExpr, tcInferRho, tcSyntaxOp )
+import {-# SOURCE #-}  TcExpr( tcMonoExpr, tcInferRho, tcSyntaxOp, tcCheckId )
 
 import HsSyn
 import TcMatches
@@ -17,7 +17,9 @@ import TcBinds
 import TcPat
 import TcUnify
 import TcRnMonad
+import TcEnv
 import Coercion
+import Id( mkLocalId )
 import Inst
 import Name
 import TysWiredIn
@@ -83,20 +85,12 @@ tcCmdTop :: CmdEnv
 
 tcCmdTop env (L loc (HsCmdTop cmd _ _ names)) cmd_stk res_ty
   = setSrcSpan loc $
-    do { cmd'   <- tcGuardedCmd env cmd cmd_stk res_ty
+    do { cmd'   <- tcCmd env cmd (cmd_stk, res_ty)
        ; names' <- mapM (tcSyntaxName ProcOrigin (cmd_arr env)) names
        ; return (L loc $ HsCmdTop cmd' cmd_stk res_ty names') }
 
 
 ----------------------------------------
-tcGuardedCmd :: CmdEnv -> LHsExpr Name -> CmdStack
-            -> TcTauType -> TcM (LHsExpr TcId)
--- A wrapper that deals with the refinement (if any)
-tcGuardedCmd env expr stk res_ty
-  = do { body <- tcCmd env expr (stk, res_ty)
-       ; return body 
-        }
-
 tcCmd :: CmdEnv -> LHsExpr Name -> (CmdStack, TcTauType) -> TcM (LHsExpr TcId)
        -- The main recursive function
 tcCmd env (L loc expr) res_ty
@@ -123,7 +117,7 @@ tc_cmd env in_cmd@(HsCase scrut matches) (stk, res_ty)
   where
     match_ctxt = MC { mc_what = CaseAlt,
                       mc_body = mc_body }
-    mc_body body res_ty' = tcGuardedCmd env body stk res_ty'
+    mc_body body res_ty' = tcCmd env body (stk, res_ty')
 
 tc_cmd env (HsIf mb_fun pred b1 b2) (stack_ty,res_ty)
   = do         { pred_ty <- newFlexiTyVarTy openTypeKind
@@ -206,22 +200,18 @@ tc_cmd env cmd@(HsLam (MatchGroup [L mtch_loc (match@(Match pats _maybe_rhs_sig
             ; return (GRHSs grhss' binds') }
 
     tc_grhs res_ty (GRHS guards body)
-       = do { (guards', rhs') <- tcStmts pg_ctxt tcGuardStmt guards res_ty $
-                                 tcGuardedCmd env body stk'
+       = do { (guards', rhs') <- tcStmtsAndThen pg_ctxt tcGuardStmt guards res_ty $
+                                 \ res_ty -> tcCmd env body (stk', res_ty)
             ; return (GRHS guards' rhs') }
 
 -------------------------------------------
 --             Do notation
 
-tc_cmd env cmd@(HsDo do_or_lc stmts body _ty) (cmd_stk, res_ty)
+tc_cmd env cmd@(HsDo do_or_lc stmts _) (cmd_stk, res_ty)
   = do         { checkTc (null cmd_stk) (nonEmptyCmdStkErr cmd)
-       ; (stmts', body') <- tcStmts do_or_lc (tcMDoStmt tc_rhs) stmts res_ty $
-                            tcGuardedCmd env body []
-       ; return (HsDo do_or_lc stmts' body' res_ty) }
+       ; stmts' <- tcStmts do_or_lc (tcArrDoStmt env) stmts res_ty 
+       ; return (HsDo do_or_lc stmts' res_ty) }
   where
-    tc_rhs rhs = do { ty <- newFlexiTyVarTy liftedTypeKind
-                   ; rhs' <- tcCmd env rhs ([], ty)
-                   ; return (rhs', ty) }
 
 
 -----------------------------------------------------------------
@@ -307,6 +297,69 @@ tc_cmd _ cmd _
 
 %************************************************************************
 %*                                                                     *
+               Stmts
+%*                                                                     *
+%************************************************************************
+
+\begin{code}
+--------------------------------
+--     Mdo-notation
+-- The distinctive features here are
+--     (a) RecStmts, and
+--     (b) no rebindable syntax
+
+tcArrDoStmt :: CmdEnv -> TcStmtChecker
+tcArrDoStmt env _ (LastStmt rhs _) res_ty thing_inside
+  = do { rhs' <- tcCmd env rhs ([], res_ty)
+       ; thing <- thing_inside (panic "tcArrDoStmt")
+       ; return (LastStmt rhs' noSyntaxExpr, thing) }
+
+tcArrDoStmt env _ (ExprStmt rhs _ _ _) res_ty thing_inside
+  = do { (rhs', elt_ty) <- tc_arr_rhs env rhs
+       ; thing          <- thing_inside res_ty
+       ; return (ExprStmt rhs' noSyntaxExpr noSyntaxExpr elt_ty, thing) }
+
+tcArrDoStmt env ctxt (BindStmt pat rhs _ _) res_ty thing_inside
+  = do { (rhs', pat_ty) <- tc_arr_rhs env rhs
+       ; (pat', thing)  <- tcPat (StmtCtxt ctxt) pat pat_ty $
+                            thing_inside res_ty
+       ; return (BindStmt pat' rhs' noSyntaxExpr noSyntaxExpr, thing) }
+
+tcArrDoStmt env ctxt (RecStmt { recS_stmts = stmts, recS_later_ids = laterNames
+                            , recS_rec_ids = recNames }) res_ty thing_inside
+  = do { rec_tys <- newFlexiTyVarTys (length recNames) liftedTypeKind
+       ; let rec_ids = zipWith mkLocalId recNames rec_tys
+       ; tcExtendIdEnv rec_ids $ do
+       { (stmts', (later_ids, rec_rets))
+               <- tcStmtsAndThen ctxt (tcArrDoStmt env) stmts res_ty   $ \ _res_ty' ->
+                       -- ToDo: res_ty not really right
+                  do { rec_rets <- zipWithM tcCheckId recNames rec_tys
+                     ; later_ids <- tcLookupLocalIds laterNames
+                     ; return (later_ids, rec_rets) }
+
+       ; thing <- tcExtendIdEnv later_ids (thing_inside res_ty)
+               -- NB:  The rec_ids for the recursive things 
+               --      already scope over this part. This binding may shadow
+               --      some of them with polymorphic things with the same Name
+               --      (see note [RecStmt] in HsExpr)
+
+        ; return (emptyRecStmt { recS_stmts = stmts', recS_later_ids = later_ids
+                               , recS_rec_ids = rec_ids, recS_rec_rets = rec_rets
+                               , recS_ret_ty = res_ty }, thing)
+       }}
+
+tcArrDoStmt _ _ stmt _ _
+  = pprPanic "tcArrDoStmt: unexpected Stmt" (ppr stmt)
+
+tc_arr_rhs :: CmdEnv -> LHsExpr Name -> TcM (LHsExpr TcId, TcType)
+tc_arr_rhs env rhs = do { ty <- newFlexiTyVarTy liftedTypeKind
+                       ; rhs' <- tcCmd env rhs ([], ty)
+                       ; return (rhs', ty) }
+\end{code}
+
+
+%************************************************************************
+%*                                                                     *
                Helpers
 %*                                                                     *
 %************************************************************************
index 6bb0820..79b097e 100644 (file)
@@ -45,6 +45,7 @@ import Type
 import Coercion
 import Var
 import VarSet
+import VarEnv
 import TysWiredIn
 import TysPrim( intPrimTy )
 import PrimOp( tagToEnumKey )
@@ -55,6 +56,7 @@ import SrcLoc
 import Util
 import ListSetOps
 import Maybes
+import ErrUtils
 import Outputable
 import FastString
 import Control.Monad
@@ -415,8 +417,8 @@ tcExpr (HsIf (Just fun) pred b1 b2) res_ty   -- Note [Rebindable syntax for if]
        -- and it maintains uniformity with other rebindable syntax
        ; return (HsIf (Just fun') pred' b1' b2') }
 
-tcExpr (HsDo do_or_lc stmts body _) res_ty
-  = tcDoStmts do_or_lc stmts body res_ty
+tcExpr (HsDo do_or_lc stmts _) res_ty
+  = tcDoStmts do_or_lc stmts res_ty
 
 tcExpr (HsProc pat cmd) res_ty
   = do { (pat', cmd', coi) <- tcProc pat cmd res_ty
@@ -820,7 +822,7 @@ tcApp fun args res_ty
        -- Typecheck the result, thereby propagating 
         -- info (if any) from result into the argument types
         -- Both actual_res_ty and res_ty are deeply skolemised
-        ; co_res <- addErrCtxt (funResCtxt fun) $
+        ; co_res <- addErrCtxtM (funResCtxt fun actual_res_ty res_ty) $
                     unifyType actual_res_ty res_ty
 
        -- Typecheck the arguments
@@ -1386,9 +1388,23 @@ funAppCtxt fun arg arg_no
                    quotes (ppr fun) <> text ", namely"])
        2 (quotes (ppr arg))
 
-funResCtxt :: LHsExpr Name -> SDoc
-funResCtxt fun
-  = ptext (sLit "In the return type of a call of") <+> quotes (ppr fun)
+funResCtxt :: LHsExpr Name -> TcType -> TcType 
+           -> TidyEnv -> TcM (TidyEnv, Message)
+-- When we have a mis-match in the return type of a function
+-- try to give a helpful message about too many/few arguments
+funResCtxt fun fun_res_ty res_ty env0
+  = do { fun_res' <- zonkTcType fun_res_ty
+       ; res'     <- zonkTcType res_ty
+       ; let n_fun = length (fst (tcSplitFunTys fun_res'))
+             n_res = length (fst (tcSplitFunTys res'))
+             what  | n_fun > n_res = ptext (sLit "few")
+                   | otherwise     = ptext (sLit "many")
+             extra | n_fun == n_res = empty
+                   | otherwise = ptext (sLit "Probable cause:") <+> quotes (ppr fun)
+                                 <+> ptext (sLit "is applied to too") <+> what 
+                                 <+> ptext (sLit "arguments") 
+             msg = ptext (sLit "In the return type of a call of") <+> quotes (ppr fun)
+       ; return (env0, msg $$ extra) }
 
 badFieldTypes :: [(Name,TcType)] -> SDoc
 badFieldTypes prs
index efacac2..dba87d2 100644 (file)
@@ -779,7 +779,7 @@ gen_Ix_binds loc tycon
     single_con_range
       = mk_easy_FunBind loc range_RDR 
          [nlTuplePat [con_pat as_needed, con_pat bs_needed] Boxed] $
-       nlHsDo ListComp stmts con_expr
+       noLoc (mkHsComp ListComp stmts con_expr)
       where
        stmts = zipWith3Equal "single_con_range" mk_qual as_needed bs_needed cs_needed
 
@@ -893,7 +893,7 @@ gen_Read_binds get_fixity loc tycon
     read_nullary_cons 
       = case nullary_cons of
            []    -> []
-           [con] -> [nlHsDo DoExpr (match_con con) (result_expr con [])]
+           [con] -> [nlHsDo DoExpr (match_con con ++ [noLoc $ mkLastStmt (result_expr con [])])]
             _     -> [nlHsApp (nlHsVar choose_RDR) 
                              (nlList (map mk_pair nullary_cons))]
         -- NB For operators the parens around (:=:) are matched by the
@@ -965,11 +965,12 @@ gen_Read_binds get_fixity loc tycon
     ------------------------------------------------------------------------
     --         Helpers
     ------------------------------------------------------------------------
-    mk_alt e1 e2       = genOpApp e1 alt_RDR e2                                        -- e1 +++ e2
-    mk_parser p ss b   = nlHsApps prec_RDR [nlHsIntLit p, nlHsDo DoExpr ss b]  -- prec p (do { ss ; b })
-    bindLex pat               = noLoc (mkBindStmt pat (nlHsVar lexP_RDR))              -- pat <- lexP
-    con_app con as     = nlHsVarApps (getRdrName con) as                       -- con as
-    result_expr con as = nlHsApp (nlHsVar returnM_RDR) (con_app con as)                -- return (con as)
+    mk_alt e1 e2       = genOpApp e1 alt_RDR e2                                -- e1 +++ e2
+    mk_parser p ss b   = nlHsApps prec_RDR [nlHsIntLit p               -- prec p (do { ss ; b })
+                                           , nlHsDo DoExpr (ss ++ [noLoc $ mkLastStmt b])]
+    bindLex pat               = noLoc (mkBindStmt pat (nlHsVar lexP_RDR))      -- pat <- lexP
+    con_app con as     = nlHsVarApps (getRdrName con) as               -- con as
+    result_expr con as = nlHsApp (nlHsVar returnM_RDR) (con_app con as) -- return (con as)
     
     punc_pat s   = nlConPat punc_RDR   [nlLitPat (mkHsString s)]  -- Punc 'c'
 
index 122b743..d179a0e 100644 (file)
@@ -578,11 +578,10 @@ zonkExpr env (HsLet binds expr)
     zonkLExpr new_env expr     `thenM` \ new_expr ->
     returnM (HsLet new_binds new_expr)
 
-zonkExpr env (HsDo do_or_lc stmts body ty)
-  = zonkStmts env stmts        `thenM` \ (new_env, new_stmts) ->
-    zonkLExpr new_env body     `thenM` \ new_body ->
+zonkExpr env (HsDo do_or_lc stmts ty)
+  = zonkStmts env stmts        `thenM` \ (_, new_stmts) ->
     zonkTcTypeToType env ty    `thenM` \ new_ty   ->
-    returnM (HsDo do_or_lc new_stmts new_body new_ty)
+    returnM (HsDo do_or_lc new_stmts new_ty)
 
 zonkExpr env (ExplicitList ty exprs)
   = zonkTcTypeToType env ty    `thenM` \ new_ty ->
@@ -728,22 +727,26 @@ zonkStmts env (s:ss) = do { (env1, s')  <- wrapLocSndM (zonkStmt env) s
                          ; return (env2, s' : ss') }
 
 zonkStmt :: ZonkEnv -> Stmt TcId -> TcM (ZonkEnv, Stmt Id)
-zonkStmt env (ParStmt stmts_w_bndrs)
+zonkStmt env (ParStmt stmts_w_bndrs mzip_op bind_op return_op)
   = mappM zonk_branch stmts_w_bndrs    `thenM` \ new_stmts_w_bndrs ->
     let 
        new_binders = concat (map snd new_stmts_w_bndrs)
        env1 = extendZonkEnv env new_binders
     in
-    return (env1, ParStmt new_stmts_w_bndrs)
+    zonkExpr env1 mzip_op   `thenM` \ new_mzip ->
+    zonkExpr env1 bind_op   `thenM` \ new_bind ->
+    zonkExpr env1 return_op `thenM` \ new_return ->
+    return (env1, ParStmt new_stmts_w_bndrs new_mzip new_bind new_return)
   where
     zonk_branch (stmts, bndrs) = zonkStmts env stmts   `thenM` \ (env1, new_stmts) ->
                                 returnM (new_stmts, zonkIdOccs env1 bndrs)
 
 zonkStmt env (RecStmt { recS_stmts = segStmts, recS_later_ids = lvs, recS_rec_ids = rvs
                       , recS_ret_fn = ret_id, recS_mfix_fn = mfix_id, recS_bind_fn = bind_id
-                      , recS_rec_rets = rets })
+                      , recS_rec_rets = rets, recS_ret_ty = ret_ty })
   = do { new_rvs <- zonkIdBndrs env rvs
        ; new_lvs <- zonkIdBndrs env lvs
+       ; new_ret_ty  <- zonkTcTypeToType env ret_ty
        ; new_ret_id  <- zonkExpr env ret_id
        ; new_mfix_id <- zonkExpr env mfix_id
        ; new_bind_id <- zonkExpr env bind_id
@@ -756,28 +759,34 @@ zonkStmt env (RecStmt { recS_stmts = segStmts, recS_later_ids = lvs, recS_rec_id
                  RecStmt { recS_stmts = new_segStmts, recS_later_ids = new_lvs
                          , recS_rec_ids = new_rvs, recS_ret_fn = new_ret_id
                          , recS_mfix_fn = new_mfix_id, recS_bind_fn = new_bind_id
-                         , recS_rec_rets = new_rets }) }
+                         , recS_rec_rets = new_rets, recS_ret_ty = new_ret_ty }) }
 
-zonkStmt env (ExprStmt expr then_op ty)
+zonkStmt env (ExprStmt expr then_op guard_op ty)
   = zonkLExpr env expr         `thenM` \ new_expr ->
     zonkExpr env then_op       `thenM` \ new_then ->
+    zonkExpr env guard_op      `thenM` \ new_guard ->
     zonkTcTypeToType env ty    `thenM` \ new_ty ->
-    returnM (env, ExprStmt new_expr new_then new_ty)
+    returnM (env, ExprStmt new_expr new_then new_guard new_ty)
 
-zonkStmt env (TransformStmt stmts binders usingExpr maybeByExpr)
-  = do { (env', stmts') <- zonkStmts env stmts 
-    ; let binders' = zonkIdOccs env' binders
-    ; usingExpr' <- zonkLExpr env' usingExpr
-    ; maybeByExpr' <- zonkMaybeLExpr env' maybeByExpr
-    ; return (env', TransformStmt stmts' binders' usingExpr' maybeByExpr') }
-    
-zonkStmt env (GroupStmt stmts binderMap by using)
+zonkStmt env (LastStmt expr ret_op)
+  = zonkLExpr env expr         `thenM` \ new_expr ->
+    zonkExpr env ret_op                `thenM` \ new_ret ->
+    returnM (env, LastStmt new_expr new_ret)
+
+zonkStmt env (TransStmt { trS_stmts = stmts, trS_bndrs = binderMap
+                        , trS_by = by, trS_form = form, trS_using = using
+                        , trS_ret = return_op, trS_bind = bind_op, trS_fmap = liftM_op })
   = do { (env', stmts') <- zonkStmts env stmts 
     ; binderMap' <- mappM (zonkBinderMapEntry env') binderMap
-    ; by' <- fmapMaybeM (zonkLExpr env') by
-    ; using' <- fmapEitherM (zonkLExpr env) (zonkExpr env) using
+    ; by'        <- fmapMaybeM (zonkLExpr env') by
+    ; using'     <- zonkLExpr env using
+    ; return_op' <- zonkExpr env' return_op
+    ; bind_op'   <- zonkExpr env' bind_op
+    ; liftM_op'  <- zonkExpr env' liftM_op
     ; let env'' = extendZonkEnv env' (map snd binderMap')
-    ; return (env'', GroupStmt stmts' binderMap' by' using') }
+    ; return (env'', TransStmt { trS_stmts = stmts', trS_bndrs = binderMap'
+                               , trS_by = by', trS_form = form, trS_using = using'
+                               , trS_ret = return_op', trS_bind = bind_op', trS_fmap = liftM_op' }) }
   where
     zonkBinderMapEntry env (oldBinder, newBinder) = do 
         let oldBinder' = zonkIdOcc env oldBinder
@@ -795,11 +804,6 @@ zonkStmt env (BindStmt pat expr bind_op fail_op)
        ; new_fail <- zonkExpr env fail_op
        ; return (env1, BindStmt new_pat new_expr new_bind new_fail) }
 
-zonkMaybeLExpr :: ZonkEnv -> Maybe (LHsExpr TcId) -> TcM (Maybe (LHsExpr Id))
-zonkMaybeLExpr _   Nothing  = return Nothing
-zonkMaybeLExpr env (Just e) = (zonkLExpr env e) >>= (return . Just)
-
-
 -------------------------------------------------------------------------
 zonkRecFields :: ZonkEnv -> HsRecordBinds TcId -> TcM (HsRecordBinds TcId)
 zonkRecFields env (HsRecFields flds dd)
@@ -1112,4 +1116,4 @@ zonkTypeZapping ty
     zonk_unbound_tyvar tv = do { let ty = anyTypeOfKind (tyVarKind tv)
                               ; writeMetaTyVar tv ty
                               ; return ty }
-\end{code}
\ No newline at end of file
+\end{code}
index 860a6db..48fdf77 100644 (file)
@@ -6,16 +6,18 @@
 TcMatches: Typecheck some @Matches@
 
 \begin{code}
+{-# OPTIONS_GHC -w #-}   -- debugging
 module TcMatches ( tcMatchesFun, tcGRHSsPat, tcMatchesCase, tcMatchLambda,
-                  TcMatchCtxt(..), 
-                  tcStmts, tcDoStmts, tcBody,
-                  tcDoStmt, tcMDoStmt, tcGuardStmt
+                  TcMatchCtxt(..), TcStmtChecker,
+                  tcStmts, tcStmtsAndThen, tcDoStmts, tcBody,
+                  tcDoStmt, tcGuardStmt
        ) where
 
-import {-# SOURCE #-}  TcExpr( tcSyntaxOp, tcInferRhoNC, tcCheckId,
+import {-# SOURCE #-}  TcExpr( tcSyntaxOp, tcInferRhoNC, tcInferRho, tcCheckId,
                                 tcMonoExpr, tcMonoExprNC, tcPolyExpr )
 
 import HsSyn
+import BasicTypes
 import TcRnMonad
 import TcEnv
 import TcPat
@@ -28,13 +30,15 @@ import TysWiredIn
 import Id
 import TyCon
 import TysPrim
-import Coercion                ( mkSymCoI )
+import Coercion                ( isIdentityCoI, mkSymCoI )
 import Outputable
-import BasicTypes      ( Arity )
 import Util
 import SrcLoc
 import FastString
 
+-- Create chunkified tuple tybes for monad comprehensions
+import MkCore
+
 import Control.Monad
 
 #include "HsVersions.h"
@@ -221,7 +225,7 @@ tcGRHSs ctxt (GRHSs grhss binds) res_ty
 tcGRHS :: TcMatchCtxt -> TcRhoType -> GRHS Name -> TcM (GRHS TcId)
 
 tcGRHS ctxt res_ty (GRHS guards rhs)
-  = do  { (guards', rhs') <- tcStmts stmt_ctxt tcGuardStmt guards res_ty $
+  = do  { (guards', rhs') <- tcStmtsAndThen stmt_ctxt tcGuardStmt guards res_ty $
                             mc_body ctxt rhs
        ; return (GRHS guards' rhs') }
   where
@@ -238,36 +242,33 @@ tcGRHS ctxt res_ty (GRHS guards rhs)
 \begin{code}
 tcDoStmts :: HsStmtContext Name 
          -> [LStmt Name]
-         -> LHsExpr Name
          -> TcRhoType
          -> TcM (HsExpr TcId)          -- Returns a HsDo
-tcDoStmts ListComp stmts body res_ty
+tcDoStmts ListComp stmts res_ty
   = do { (coi, elt_ty) <- matchExpectedListTy res_ty
-       ; (stmts', body') <- tcStmts ListComp (tcLcStmt listTyCon) stmts 
-                                    elt_ty $
-                            tcBody body
-       ; return $ mkHsWrapCoI coi 
-                     (HsDo ListComp stmts' body' (mkListTy elt_ty)) }
+        ; let list_ty = mkListTy elt_ty
+       ; stmts' <- tcStmts ListComp (tcLcStmt listTyCon) stmts elt_ty
+       ; return $ mkHsWrapCoI coi (HsDo ListComp stmts' list_ty) }
 
-tcDoStmts PArrComp stmts body res_ty
+tcDoStmts PArrComp stmts res_ty
   = do { (coi, elt_ty) <- matchExpectedPArrTy res_ty
-       ; (stmts', body') <- tcStmts PArrComp (tcLcStmt parrTyCon) stmts 
-                                    elt_ty $
-                            tcBody body
-       ; return $ mkHsWrapCoI coi 
-                     (HsDo PArrComp stmts' body' (mkPArrTy elt_ty)) }
+        ; let parr_ty = mkPArrTy elt_ty
+       ; stmts' <- tcStmts PArrComp (tcLcStmt parrTyCon) stmts elt_ty
+       ; return $ mkHsWrapCoI coi (HsDo PArrComp stmts' parr_ty) }
+
+tcDoStmts DoExpr stmts res_ty
+  = do { stmts' <- tcStmts DoExpr tcDoStmt stmts res_ty
+       ; return (HsDo DoExpr stmts' res_ty) }
 
-tcDoStmts DoExpr stmts body res_ty
-  = do { (stmts', body') <- tcStmts DoExpr tcDoStmt stmts res_ty $
-                            tcBody body
-       ; return (HsDo DoExpr stmts' body' res_ty) }
+tcDoStmts MDoExpr stmts res_ty
+  = do  { stmts' <- tcStmts MDoExpr tcDoStmt stmts res_ty
+        ; return (HsDo MDoExpr stmts' res_ty) }
 
-tcDoStmts MDoExpr stmts body res_ty
-  = do  { (stmts', body') <- tcStmts MDoExpr tcDoStmt stmts res_ty $
-                            tcBody body
-        ; return (HsDo MDoExpr stmts' body' res_ty) }
+tcDoStmts MonadComp stmts res_ty
+  = do  { stmts' <- tcStmts MonadComp tcMcStmt stmts res_ty 
+        ; return (HsDo MonadComp stmts' res_ty) }
 
-tcDoStmts ctxt _ _ _ = pprPanic "tcDoStmts" (pprStmtContext ctxt)
+tcDoStmts ctxt _ _ = pprPanic "tcDoStmts" (pprStmtContext ctxt)
 
 tcBody :: LHsExpr Name -> TcRhoType -> TcM (LHsExpr TcId)
 tcBody body res_ty
@@ -296,40 +297,52 @@ tcStmts :: HsStmtContext Name
        -> TcStmtChecker        -- NB: higher-rank type
         -> [LStmt Name]
        -> TcRhoType
-       -> (TcRhoType -> TcM thing)
-        -> TcM ([LStmt TcId], thing)
+        -> TcM [LStmt TcId]
+tcStmts ctxt stmt_chk stmts res_ty
+  = do { (stmts', _) <- tcStmtsAndThen ctxt stmt_chk stmts res_ty $
+                        const (return ())
+       ; return stmts' }
+
+tcStmtsAndThen :: HsStmtContext Name
+              -> TcStmtChecker -- NB: higher-rank type
+               -> [LStmt Name]
+              -> TcRhoType
+              -> (TcRhoType -> TcM thing)
+               -> TcM ([LStmt TcId], thing)
 
 -- Note the higher-rank type.  stmt_chk is applied at different
 -- types in the equations for tcStmts
 
-tcStmts _ _ [] res_ty thing_inside
+tcStmtsAndThen _ _ [] res_ty thing_inside
   = do { thing <- thing_inside res_ty
        ; return ([], thing) }
 
 -- LetStmts are handled uniformly, regardless of context
-tcStmts ctxt stmt_chk (L loc (LetStmt binds) : stmts) res_ty thing_inside
+tcStmtsAndThen ctxt stmt_chk (L loc (LetStmt binds) : stmts) res_ty thing_inside
   = do { (binds', (stmts',thing)) <- tcLocalBinds binds $
-                                     tcStmts ctxt stmt_chk stmts res_ty thing_inside
+                                     tcStmtsAndThen ctxt stmt_chk stmts res_ty thing_inside
        ; return (L loc (LetStmt binds') : stmts', thing) }
 
 -- For the vanilla case, handle the location-setting part
-tcStmts ctxt stmt_chk (L loc stmt : stmts) res_ty thing_inside
+tcStmtsAndThen ctxt stmt_chk (L loc stmt : stmts) res_ty thing_inside
   = do         { (stmt', (stmts', thing)) <- 
-               setSrcSpan loc                          $
-               addErrCtxt (pprStmtInCtxt ctxt stmt)    $
-               stmt_chk ctxt stmt res_ty               $ \ res_ty' ->
-               popErrCtxt                              $
-               tcStmts ctxt stmt_chk stmts res_ty'     $
+               setSrcSpan loc                              $
+               addErrCtxt (pprStmtInCtxt ctxt stmt)        $
+               stmt_chk ctxt stmt res_ty                   $ \ res_ty' ->
+               popErrCtxt                                  $
+               tcStmtsAndThen ctxt stmt_chk stmts res_ty'  $
                thing_inside
        ; return (L loc stmt' : stmts', thing) }
 
---------------------------------
---     Pattern guards
+---------------------------------------------------
+--             Pattern guards
+---------------------------------------------------
+
 tcGuardStmt :: TcStmtChecker
-tcGuardStmt _ (ExprStmt guard _ _) res_ty thing_inside
+tcGuardStmt _ (ExprStmt guard _ _ _) res_ty thing_inside
   = do { guard' <- tcMonoExpr guard boolTy
        ; thing  <- thing_inside res_ty
-       ; return (ExprStmt guard' noSyntaxExpr boolTy, thing) }
+       ; return (ExprStmt guard' noSyntaxExpr noSyntaxExpr boolTy, thing) }
 
 tcGuardStmt ctxt (BindStmt pat rhs _ _) res_ty thing_inside
   = do { (rhs', rhs_ty) <- tcInferRhoNC rhs    -- Stmt has a context already
@@ -341,25 +354,292 @@ tcGuardStmt _ stmt _ _
   = pprPanic "tcGuardStmt: unexpected Stmt" (ppr stmt)
 
 
---------------------------------
---     List comprehensions and PArrays
+---------------------------------------------------
+--          List comprehensions and PArrays
+--              (no rebindable syntax)
+---------------------------------------------------
+
+-- Dealt with separately, rather than by tcMcStmt, because
+--   a) PArr isn't (yet) an instance of Monad, so the generality seems overkill
+--   b) We have special desugaring rules for list comprehensions,
+--      which avoid creating intermediate lists.  They in turn 
+--      assume that the bind/return operations are the regular
+--      polymorphic ones, and in particular don't have any
+--      coercion matching stuff in them.  It's hard to avoid the
+--      potential for non-trivial coercions in tcMcStmt
 
 tcLcStmt :: TyCon      -- The list/Parray type constructor ([] or PArray)
         -> TcStmtChecker
 
+tcLcStmt _ _ (LastStmt body _) elt_ty thing_inside
+  = do { body' <- tcMonoExprNC body elt_ty
+       ; thing <- thing_inside (panic "tcLcStmt: thing_inside")
+       ; return (LastStmt body' noSyntaxExpr, thing) }
+
 -- A generator, pat <- rhs
-tcLcStmt m_tc ctxt (BindStmt pat rhs _ _) res_ty thing_inside
+tcLcStmt m_tc ctxt (BindStmt pat rhs _ _) elt_ty thing_inside
  = do  { pat_ty <- newFlexiTyVarTy liftedTypeKind
         ; rhs'   <- tcMonoExpr rhs (mkTyConApp m_tc [pat_ty])
        ; (pat', thing)  <- tcPat (StmtCtxt ctxt) pat pat_ty $
-                            thing_inside res_ty
+                            thing_inside elt_ty
        ; return (BindStmt pat' rhs' noSyntaxExpr noSyntaxExpr, thing) }
 
 -- A boolean guard
-tcLcStmt _ _ (ExprStmt rhs _ _) res_ty thing_inside
+tcLcStmt _ _ (ExprStmt rhs _ _ _) elt_ty thing_inside
   = do { rhs'  <- tcMonoExpr rhs boolTy
-       ; thing <- thing_inside res_ty
-       ; return (ExprStmt rhs' noSyntaxExpr boolTy, thing) }
+       ; thing <- thing_inside elt_ty
+       ; return (ExprStmt rhs' noSyntaxExpr noSyntaxExpr boolTy, thing) }
+
+-- ParStmt: See notes with tcMcStmt
+tcLcStmt m_tc ctxt (ParStmt bndr_stmts_s _ _ _) elt_ty thing_inside
+  = do { (pairs', thing) <- loop bndr_stmts_s
+       ; return (ParStmt pairs' noSyntaxExpr noSyntaxExpr noSyntaxExpr, thing) }
+  where
+    -- loop :: [([LStmt Name], [Name])] -> TcM ([([LStmt TcId], [TcId])], thing)
+    loop [] = do { thing <- thing_inside elt_ty
+                ; return ([], thing) }         -- matching in the branches
+
+    loop ((stmts, names) : pairs)
+      = do { (stmts', (ids, pairs', thing))
+               <- tcStmtsAndThen ctxt (tcLcStmt m_tc) stmts elt_ty $ \ _elt_ty' ->
+                  do { ids <- tcLookupLocalIds names
+                     ; (pairs', thing) <- loop pairs
+                     ; return (ids, pairs', thing) }
+          ; return ( (stmts', ids) : pairs', thing ) }
+
+tcLcStmt m_tc ctxt (TransStmt { trS_form = form, trS_stmts = stmts
+                              , trS_bndrs =  bindersMap
+                              , trS_by = by, trS_using = using }) elt_ty thing_inside
+  = do { let (bndr_names, n_bndr_names) = unzip bindersMap
+             unused_ty = pprPanic "tcLcStmt: inner ty" (ppr bindersMap)
+                    -- The inner 'stmts' lack a LastStmt, so the element type
+            --  passed in to tcStmtsAndThen is never looked at
+       ; (stmts', (bndr_ids, by'))
+            <- tcStmtsAndThen (TransStmtCtxt ctxt) (tcLcStmt m_tc) stmts unused_ty $ \_ -> do
+              { by' <- case by of
+                           Nothing -> return Nothing
+                           Just e  -> do { e_ty <- tcInferRho e; return (Just e_ty) }
+               ; bndr_ids <- tcLookupLocalIds bndr_names
+               ; return (bndr_ids, by') }
+
+       ; let m_app ty = mkTyConApp m_tc [ty]
+
+       --------------- Typecheck the 'using' function -------------
+       -- using :: ((a,b,c)->t) -> m (a,b,c) -> m (a,b,c)m      (ThenForm)
+       --       :: ((a,b,c)->t) -> m (a,b,c) -> m (m (a,b,c)))  (GroupForm)
+
+         -- n_app :: Type -> Type   -- Wraps a 'ty' into '[ty]' for GroupForm
+       ; let n_app = case form of
+                       ThenForm -> (\ty -> ty)
+                      _        -> m_app
+
+             by_arrow :: Type -> Type     -- Wraps 'ty' to '(a->t) -> ty' if the By is present
+             by_arrow = case by' of
+                          Nothing       -> \ty -> ty
+                          Just (_,e_ty) -> \ty -> (alphaTy `mkFunTy` e_ty) `mkFunTy` ty
+
+             tup_ty        = mkBigCoreVarTupTy bndr_ids
+             poly_arg_ty   = m_app alphaTy
+            poly_res_ty   = m_app (n_app alphaTy)
+            using_poly_ty = mkForAllTy alphaTyVar $ by_arrow $ 
+                             poly_arg_ty `mkFunTy` poly_res_ty
+
+       ; using' <- tcPolyExpr using using_poly_ty
+       ; let final_using = fmap (HsWrap (WpTyApp tup_ty)) using' 
+
+            -- 'stmts' returns a result of type (m1_ty tuple_ty),
+            -- typically something like [(Int,Bool,Int)]
+            -- We don't know what tuple_ty is yet, so we use a variable
+       ; let mk_n_bndr :: Name -> TcId -> TcId
+             mk_n_bndr n_bndr_name bndr_id = mkLocalId n_bndr_name (n_app (idType bndr_id))
+
+             -- Ensure that every old binder of type `b` is linked up with its
+             -- new binder which should have type `n b`
+            -- See Note [GroupStmt binder map] in HsExpr
+             n_bndr_ids  = zipWith mk_n_bndr n_bndr_names bndr_ids
+             bindersMap' = bndr_ids `zip` n_bndr_ids
+
+       -- Type check the thing in the environment with 
+       -- these new binders and return the result
+       ; thing <- tcExtendIdEnv n_bndr_ids (thing_inside elt_ty)
+
+       ; return (emptyTransStmt { trS_stmts = stmts', trS_bndrs = bindersMap' 
+                                , trS_by = fmap fst by', trS_using = final_using 
+                                , trS_form = form }, thing) }
+    
+tcLcStmt _ _ stmt _ _
+  = pprPanic "tcLcStmt: unexpected Stmt" (ppr stmt)
+
+
+---------------------------------------------------
+--          Monad comprehensions 
+--       (supports rebindable syntax)
+---------------------------------------------------
+
+tcMcStmt :: TcStmtChecker
+
+tcMcStmt _ (LastStmt body return_op) res_ty thing_inside
+  = do  { a_ty       <- newFlexiTyVarTy liftedTypeKind
+        ; return_op' <- tcSyntaxOp MCompOrigin return_op
+                                   (a_ty `mkFunTy` res_ty)
+        ; body'      <- tcMonoExprNC body a_ty
+        ; thing      <- thing_inside (panic "tcMcStmt: thing_inside")
+        ; return (LastStmt body' return_op', thing) } 
+
+-- Generators for monad comprehensions ( pat <- rhs )
+--
+--   [ body | q <- gen ]  ->  gen :: m a
+--                            q   ::   a
+--
+
+tcMcStmt ctxt (BindStmt pat rhs bind_op fail_op) res_ty thing_inside
+ = do   { rhs_ty     <- newFlexiTyVarTy liftedTypeKind
+        ; pat_ty     <- newFlexiTyVarTy liftedTypeKind
+        ; new_res_ty <- newFlexiTyVarTy liftedTypeKind
+
+          -- (>>=) :: rhs_ty -> (pat_ty -> new_res_ty) -> res_ty
+        ; bind_op'   <- tcSyntaxOp MCompOrigin bind_op 
+                             (mkFunTys [rhs_ty, mkFunTy pat_ty new_res_ty] res_ty)
+
+           -- If (but only if) the pattern can fail, typecheck the 'fail' operator
+        ; fail_op' <- if isIrrefutableHsPat pat 
+                      then return noSyntaxExpr
+                      else tcSyntaxOp MCompOrigin fail_op (mkFunTy stringTy new_res_ty)
+
+        ; rhs' <- tcMonoExprNC rhs rhs_ty
+        ; (pat', thing) <- tcPat (StmtCtxt ctxt) pat pat_ty $
+                           thing_inside new_res_ty
+
+        ; return (BindStmt pat' rhs' bind_op' fail_op', thing) }
+
+-- Boolean expressions.
+--
+--   [ body | stmts, expr ]  ->  expr :: m Bool
+--
+tcMcStmt _ (ExprStmt rhs then_op guard_op _) res_ty thing_inside
+  = do { -- Deal with rebindable syntax:
+          --    guard_op :: test_ty -> rhs_ty
+          --    then_op  :: rhs_ty -> new_res_ty -> res_ty
+          -- Where test_ty is, for example, Bool
+          test_ty    <- newFlexiTyVarTy liftedTypeKind
+        ; rhs_ty     <- newFlexiTyVarTy liftedTypeKind
+        ; new_res_ty <- newFlexiTyVarTy liftedTypeKind
+        ; rhs'       <- tcMonoExpr rhs test_ty
+        ; guard_op'  <- tcSyntaxOp MCompOrigin guard_op
+                                   (mkFunTy test_ty rhs_ty)
+        ; then_op'   <- tcSyntaxOp MCompOrigin then_op
+                                  (mkFunTys [rhs_ty, new_res_ty] res_ty)
+       ; thing      <- thing_inside new_res_ty
+       ; return (ExprStmt rhs' then_op' guard_op' rhs_ty, thing) }
+
+-- Grouping statements
+--
+--   [ body | stmts, then group by e ]
+--     ->  e :: t
+--   [ body | stmts, then group by e using f ]
+--     ->  e :: t
+--         f :: forall a. (a -> t) -> m a -> m (m a)
+--   [ body | stmts, then group using f ]
+--     ->  f :: forall a. m a -> m (m a)
+
+-- We type [ body | (stmts, group by e using f), ... ]
+--     f <optional by> [ (a,b,c) | stmts ] >>= \(a,b,c) -> ...body....
+--
+-- We type the functions as follows:
+--     f <optional by> :: m1 (a,b,c) -> m2 (a,b,c)             (ThenForm)
+--                            :: m1 (a,b,c) -> m2 (n (a,b,c))          (GroupForm)
+--     (>>=) :: m2 (a,b,c)     -> ((a,b,c)   -> res) -> res    (ThenForm)
+--           :: m2 (n (a,b,c)) -> (n (a,b,c) -> res) -> res    (GroupForm)
+-- 
+tcMcStmt ctxt (TransStmt { trS_stmts = stmts, trS_bndrs = bindersMap
+                         , trS_by = by, trS_using = using, trS_form = form
+                         , trS_ret = return_op, trS_bind = bind_op 
+                         , trS_fmap = fmap_op }) res_ty thing_inside
+  = do { let star_star_kind = liftedTypeKind `mkArrowKind` liftedTypeKind
+       ; m1_ty   <- newFlexiTyVarTy star_star_kind
+       ; m2_ty   <- newFlexiTyVarTy star_star_kind
+       ; tup_ty  <- newFlexiTyVarTy liftedTypeKind
+       ; by_e_ty <- newFlexiTyVarTy liftedTypeKind  -- The type of the 'by' expression (if any)
+
+         -- n_app :: Type -> Type   -- Wraps a 'ty' into '(n ty)' for GroupForm
+       ; n_app <- case form of
+                    ThenForm -> return (\ty -> ty)
+                   _        -> do { n_ty <- newFlexiTyVarTy star_star_kind
+                                  ; return (n_ty `mkAppTy`) }
+       ; let by_arrow :: Type -> Type     
+             -- (by_arrow res) produces ((alpha->e_ty) -> res)     ('by' present)
+             --                          or res                    ('by' absent) 
+             by_arrow = case by of
+                          Nothing -> \res -> res
+                          Just {} -> \res -> (alphaTy `mkFunTy` by_e_ty) `mkFunTy` res
+
+             poly_arg_ty  = m1_ty `mkAppTy` alphaTy
+             using_arg_ty = m1_ty `mkAppTy` tup_ty
+            poly_res_ty  = m2_ty `mkAppTy` n_app alphaTy
+            using_res_ty = m2_ty `mkAppTy` n_app tup_ty
+            using_poly_ty = mkForAllTy alphaTyVar $ by_arrow $ 
+                             poly_arg_ty `mkFunTy` poly_res_ty
+
+            -- 'stmts' returns a result of type (m1_ty tuple_ty),
+            -- typically something like [(Int,Bool,Int)]
+            -- We don't know what tuple_ty is yet, so we use a variable
+       ; let (bndr_names, n_bndr_names) = unzip bindersMap
+       ; (stmts', (bndr_ids, by', return_op')) <-
+            tcStmtsAndThen (TransStmtCtxt ctxt) tcMcStmt stmts using_arg_ty $ \res_ty' -> do
+               { by' <- case by of
+                           Nothing -> return Nothing
+                           Just e  -> do { e' <- tcMonoExpr e by_e_ty; return (Just e') }
+
+                -- Find the Ids (and hence types) of all old binders
+                ; bndr_ids <- tcLookupLocalIds bndr_names
+
+                -- 'return' is only used for the binders, so we know its type.
+                --   return :: (a,b,c,..) -> m (a,b,c,..)
+                ; return_op' <- tcSyntaxOp MCompOrigin return_op $ 
+                                (mkBigCoreVarTupTy bndr_ids) `mkFunTy` res_ty'
+
+                ; return (bndr_ids, by', return_op') }
+
+       --------------- Typecheck the 'bind' function -------------
+       -- (>>=) :: m2 (n (a,b,c)) -> ( n (a,b,c) -> new_res_ty ) -> res_ty
+       ; new_res_ty <- newFlexiTyVarTy liftedTypeKind
+       ; bind_op' <- tcSyntaxOp MCompOrigin bind_op $
+                                using_res_ty `mkFunTy` (n_app tup_ty `mkFunTy` new_res_ty)
+                                             `mkFunTy` res_ty
+
+       --------------- Typecheck the 'fmap' function -------------
+       ; fmap_op' <- case form of
+                       ThenForm -> return noSyntaxExpr
+                       _ -> fmap unLoc . tcPolyExpr (noLoc fmap_op) $
+                            mkForAllTy alphaTyVar $ mkForAllTy betaTyVar $
+                            (alphaTy `mkFunTy` betaTy)
+                            `mkFunTy` (n_app alphaTy)
+                            `mkFunTy` (n_app betaTy)
+
+       --------------- Typecheck the 'using' function -------------
+       -- using :: ((a,b,c)->t) -> m1 (a,b,c) -> m2 (n (a,b,c))
+
+       ; using' <- tcPolyExpr using using_poly_ty
+       ; let final_using = fmap (HsWrap (WpTyApp tup_ty)) using' 
+
+       --------------- Bulding the bindersMap ----------------
+       ; let mk_n_bndr :: Name -> TcId -> TcId
+             mk_n_bndr n_bndr_name bndr_id = mkLocalId n_bndr_name (n_app (idType bndr_id))
+
+             -- Ensure that every old binder of type `b` is linked up with its
+             -- new binder which should have type `n b`
+            -- See Note [GroupStmt binder map] in HsExpr
+             n_bndr_ids = zipWith mk_n_bndr n_bndr_names bndr_ids
+             bindersMap' = bndr_ids `zip` n_bndr_ids
+
+       -- Type check the thing in the environment with 
+       -- these new binders and return the result
+       ; thing <- tcExtendIdEnv n_bndr_ids (thing_inside new_res_ty)
+
+       ; return (TransStmt { trS_stmts = stmts', trS_bndrs = bindersMap' 
+                           , trS_by = by', trS_using = final_using 
+                           , trS_ret = return_op', trS_bind = bind_op'
+                           , trS_fmap = fmap_op', trS_form = form }, thing) }
 
 -- A parallel set of comprehensions
 --     [ (g x, h x) | ... ; let g v = ...
@@ -381,106 +661,95 @@ tcLcStmt _ _ (ExprStmt rhs _ _) res_ty thing_inside
 -- ensure that g,h and x,y don't duplicate, and simply grow the environment.
 -- So the binders of the first parallel group will be in scope in the second
 -- group.  But that's fine; there's no shadowing to worry about.
-
-tcLcStmt m_tc ctxt (ParStmt bndr_stmts_s) elt_ty thing_inside
-  = do { (pairs', thing) <- loop bndr_stmts_s
-       ; return (ParStmt pairs', thing) }
-  where
-    -- loop :: [([LStmt Name], [Name])] -> TcM ([([LStmt TcId], [TcId])], thing)
-    loop [] = do { thing <- thing_inside elt_ty
-                ; return ([], thing) }         -- matching in the branches
-
-    loop ((stmts, names) : pairs)
-      = do { (stmts', (ids, pairs', thing))
-               <- tcStmts ctxt (tcLcStmt m_tc) stmts elt_ty $ \ _elt_ty' ->
-                  do { ids <- tcLookupLocalIds names
-                     ; (pairs', thing) <- loop pairs
-                     ; return (ids, pairs', thing) }
-          ; return ( (stmts', ids) : pairs', thing ) }
-
-tcLcStmt m_tc ctxt (TransformStmt stmts binders usingExpr maybeByExpr) elt_ty thing_inside = do
-    (stmts', (binders', usingExpr', maybeByExpr', thing)) <- 
-        tcStmts (TransformStmtCtxt ctxt) (tcLcStmt m_tc) stmts elt_ty $ \elt_ty' -> do
-            let alphaListTy = mkTyConApp m_tc [alphaTy]
-                    
-            (usingExpr', maybeByExpr') <- 
-                case maybeByExpr of
-                    Nothing -> do
-                        -- We must validate that usingExpr :: forall a. [a] -> [a]
-                        let using_ty = mkForAllTy alphaTyVar (alphaListTy `mkFunTy` alphaListTy)
-                        usingExpr' <- tcPolyExpr usingExpr using_ty
-                        return (usingExpr', Nothing)
-                    Just byExpr -> do
-                        -- We must infer a type such that e :: t and then check that 
-                       -- usingExpr :: forall a. (a -> t) -> [a] -> [a]
-                        (byExpr', tTy) <- tcInferRhoNC byExpr
-                        let using_ty = mkForAllTy alphaTyVar $ 
-                                       (alphaTy `mkFunTy` tTy)
-                                       `mkFunTy` alphaListTy `mkFunTy` alphaListTy
-                        usingExpr' <- tcPolyExpr usingExpr using_ty
-                        return (usingExpr', Just byExpr')
-            
-            binders' <- tcLookupLocalIds binders
-            thing <- thing_inside elt_ty'
-            
-            return (binders', usingExpr', maybeByExpr', thing)
-
-    return (TransformStmt stmts' binders' usingExpr' maybeByExpr', thing)
-
-tcLcStmt m_tc ctxt (GroupStmt stmts bindersMap by using) elt_ty thing_inside
-  = do { let (bndr_names, list_bndr_names) = unzip bindersMap
-
-       ; (stmts', (bndr_ids, by', using_ty, elt_ty')) <-
-            tcStmts (TransformStmtCtxt ctxt) (tcLcStmt m_tc) stmts elt_ty $ \elt_ty' -> do
-               (by', using_ty) <- 
-                   case by of
-                     Nothing   -> -- check that using :: forall a. [a] -> [[a]]
-                                  return (Nothing, mkForAllTy alphaTyVar $
-                                                   alphaListTy `mkFunTy` alphaListListTy)
-                                       
-                    Just by_e -> -- check that using :: forall a. (a -> t) -> [a] -> [[a]]
-                                 -- where by :: t
-                                  do { (by_e', t_ty) <- tcInferRhoNC by_e
-                                     ; return (Just by_e', mkForAllTy alphaTyVar $
-                                                           (alphaTy `mkFunTy` t_ty) 
-                                                           `mkFunTy` alphaListTy 
-                                                           `mkFunTy` alphaListListTy) }
-                -- Find the Ids (and hence types) of all old binders
-                bndr_ids <- tcLookupLocalIds bndr_names
-                
-                return (bndr_ids, by', using_ty, elt_ty')
-        
-                -- Ensure that every old binder of type b is linked up with
-               -- its new binder which should have type [b]
-       ; let list_bndr_ids = zipWith mk_list_bndr list_bndr_names bndr_ids
-             bindersMap' = bndr_ids `zip` list_bndr_ids
-            -- See Note [GroupStmt binder map] in HsExpr
-            
-       ; using' <- case using of
-                     Left  e -> do { e' <- tcPolyExpr e         using_ty; return (Left  e') }
-                     Right e -> do { e' <- tcPolyExpr (noLoc e) using_ty; return (Right (unLoc e')) }
-
-             -- Type check the thing in the environment with 
-            -- these new binders and return the result
-       ; thing <- tcExtendIdEnv list_bndr_ids (thing_inside elt_ty')
-       ; return (GroupStmt stmts' bindersMap' by' using', thing) }
-  where
-    alphaListTy = mkTyConApp m_tc [alphaTy]
-    alphaListListTy = mkTyConApp m_tc [alphaListTy]
-            
-    mk_list_bndr :: Name -> TcId -> TcId
-    mk_list_bndr list_bndr_name bndr_id 
-      = mkLocalId list_bndr_name (mkTyConApp m_tc [idType bndr_id])
-    
-tcLcStmt _ _ stmt _ _
-  = pprPanic "tcLcStmt: unexpected Stmt" (ppr stmt)
-        
---------------------------------
---     Do-notation
--- The main excitement here is dealing with rebindable syntax
+--
+-- Note: The `mzip` function will get typechecked via:
+--
+--   ParStmt [st1::t1, st2::t2, st3::t3]
+--   
+--   mzip :: m st1
+--        -> (m st2 -> m st3 -> m (st2, st3))   -- recursive call
+--        -> m (st1, (st2, st3))
+--
+tcMcStmt ctxt (ParStmt bndr_stmts_s mzip_op bind_op return_op) res_ty thing_inside
+  = do { let star_star_kind = liftedTypeKind `mkArrowKind` liftedTypeKind
+       ; m_ty   <- newFlexiTyVarTy star_star_kind
+
+       ; let mzip_ty  = mkForAllTys [alphaTyVar, betaTyVar] $
+                        (m_ty `mkAppTy` alphaTy)
+                        `mkFunTy`
+                        (m_ty `mkAppTy` betaTy)
+                        `mkFunTy`
+                        (m_ty `mkAppTy` mkBoxedTupleTy [alphaTy, betaTy])
+       ; mzip_op' <- unLoc `fmap` tcPolyExpr (noLoc mzip_op) mzip_ty
+
+       ; return_op' <- fmap unLoc . tcPolyExpr (noLoc return_op) $
+                       mkForAllTy alphaTyVar $
+                       alphaTy `mkFunTy` (m_ty `mkAppTy` alphaTy)
+
+       ; (pairs', thing) <- loop m_ty bndr_stmts_s
+
+       -- Typecheck bind:
+       ; let tys      = map (mkBigCoreVarTupTy . snd) pairs'
+             tuple_ty = mk_tuple_ty tys
+
+       ; bind_op' <- tcSyntaxOp MCompOrigin bind_op $
+                        (m_ty `mkAppTy` tuple_ty)
+                        `mkFunTy` (tuple_ty `mkFunTy` res_ty)
+                        `mkFunTy` res_ty
+
+       ; return (ParStmt pairs' mzip_op' bind_op' return_op', thing) }
+
+  where 
+    mk_tuple_ty tys = foldr1 (\tn tm -> mkBoxedTupleTy [tn, tm]) tys
+
+       -- loop :: Type                                  -- m_ty
+       --      -> [([LStmt Name], [Name])]
+       --      -> TcM ([([LStmt TcId], [TcId])], thing)
+    loop _ [] = do { thing <- thing_inside res_ty
+                   ; return ([], thing) }           -- matching in the branches
+
+    loop m_ty ((stmts, names) : pairs)
+      = do { -- type dummy since we don't know all binder types yet
+             ty_dummy <- newFlexiTyVarTy liftedTypeKind
+           ; (stmts', (ids, pairs', thing))
+                <- tcStmtsAndThen ctxt tcMcStmt stmts ty_dummy $ \res_ty' ->
+                   do { ids <- tcLookupLocalIds names
+                     ; let m_tup_ty = m_ty `mkAppTy` mkBigCoreVarTupTy ids
+
+                     ; check_same m_tup_ty res_ty'
+                     ; check_same m_tup_ty ty_dummy
+                                                        
+                      ; (pairs', thing) <- loop m_ty pairs
+                      ; return (ids, pairs', thing) }
+           ; return ( (stmts', ids) : pairs', thing ) }
+
+       -- Check that the types match up.
+       -- This is a grevious hack.  They always *will* match 
+       -- If (>>=) and (>>) are polymorpic in the return type,
+       -- but we don't have any good way to incorporate the coercion
+       -- so for now we just check that it's the identity
+    check_same actual expected
+      = do { coi <- unifyType actual expected
+          ; unless (isIdentityCoI coi) $
+             failWithMisMatch [UnifyOrigin { uo_expected = expected
+                                           , uo_actual = actual }] }
+
+tcMcStmt _ stmt _ _
+  = pprPanic "tcMcStmt: unexpected Stmt" (ppr stmt)
+
+
+---------------------------------------------------
+--          Do-notation
+--       (supports rebindable syntax)
+---------------------------------------------------
 
 tcDoStmt :: TcStmtChecker
 
+tcDoStmt _ (LastStmt body _) res_ty thing_inside
+  = do { body' <- tcMonoExprNC body res_ty
+       ; thing <- thing_inside (panic "tcDoStmt: thing_inside")
+       ; return (LastStmt body' noSyntaxExpr, thing) }
+
 tcDoStmt ctxt (BindStmt pat rhs bind_op fail_op) res_ty thing_inside
   = do {       -- Deal with rebindable syntax:
                --       (>>=) :: rhs_ty -> (pat_ty -> new_res_ty) -> res_ty
@@ -510,7 +779,7 @@ tcDoStmt ctxt (BindStmt pat rhs bind_op fail_op) res_ty thing_inside
        ; return (BindStmt pat' rhs' bind_op' fail_op', thing) }
 
 
-tcDoStmt _ (ExprStmt rhs then_op _) res_ty thing_inside
+tcDoStmt _ (ExprStmt rhs then_op _ _) res_ty thing_inside
   = do {       -- Deal with rebindable syntax; 
                 --   (>>) :: rhs_ty -> new_res_ty -> res_ty
                -- See also Note [Treat rebindable syntax first]
@@ -521,7 +790,7 @@ tcDoStmt _ (ExprStmt rhs then_op _) res_ty thing_inside
 
         ; rhs' <- tcMonoExprNC rhs rhs_ty
        ; thing <- thing_inside new_res_ty
-       ; return (ExprStmt rhs' then_op' rhs_ty, thing) }
+       ; return (ExprStmt rhs' then_op' noSyntaxExpr rhs_ty, thing) }
 
 tcDoStmt ctxt (RecStmt { recS_stmts = stmts, recS_later_ids = later_names
                        , recS_rec_ids = rec_names, recS_ret_fn = ret_op
@@ -535,7 +804,7 @@ tcDoStmt ctxt (RecStmt { recS_stmts = stmts, recS_later_ids = later_names
         ; tcExtendIdEnv tup_ids $ do
         { stmts_ty <- newFlexiTyVarTy liftedTypeKind
         ; (stmts', (ret_op', tup_rets))
-                <- tcStmts ctxt tcDoStmt stmts stmts_ty   $ \ inner_res_ty ->
+                <- tcStmtsAndThen ctxt tcDoStmt stmts stmts_ty   $ \ inner_res_ty ->
                    do { tup_rets <- zipWithM tcCheckId tup_names tup_elt_tys
                              -- Unify the types of the "final" Ids (which may 
                              -- be polymorphic) with those of "knot-tied" Ids
@@ -551,7 +820,6 @@ tcDoStmt ctxt (RecStmt { recS_stmts = stmts, recS_later_ids = later_names
                                 (mkFunTys [mfix_res_ty, mkFunTy tup_ty new_res_ty] res_ty)
 
         ; thing <- thing_inside new_res_ty
---         ; lie_binds <- bindLocalMethods lie tup_ids
   
         ; let rec_ids = takeList rec_names tup_ids
        ; later_ids <- tcLookupLocalIds later_names
@@ -560,7 +828,7 @@ tcDoStmt ctxt (RecStmt { recS_stmts = stmts, recS_later_ids = later_names
         ; return (RecStmt { recS_stmts = stmts', recS_later_ids = later_ids
                           , recS_rec_ids = rec_ids, recS_ret_fn = ret_op' 
                           , recS_mfix_fn = mfix_op', recS_bind_fn = bind_op'
-                          , recS_rec_rets = tup_rets }, thing)
+                          , recS_rec_rets = tup_rets, recS_ret_ty = stmts_ty }, thing)
         }}
 
 tcDoStmt _ stmt _ _
@@ -577,51 +845,6 @@ rebindable syntax first, and push that information into (tcMonoExprNC rhs).
 Otherwise the error shows up when cheking the rebindable syntax, and
 the expected/inferred stuff is back to front (see Trac #3613).
 
-\begin{code}
---------------------------------
---     Mdo-notation
--- The distinctive features here are
---     (a) RecStmts, and
---     (b) no rebindable syntax
-
-tcMDoStmt :: (LHsExpr Name -> TcM (LHsExpr TcId, TcType))      -- RHS inference
-         -> TcStmtChecker
-tcMDoStmt tc_rhs ctxt (BindStmt pat rhs _ _) res_ty thing_inside
-  = do { (rhs', pat_ty) <- tc_rhs rhs
-       ; (pat', thing)  <- tcPat (StmtCtxt ctxt) pat pat_ty $
-                            thing_inside res_ty
-       ; return (BindStmt pat' rhs' noSyntaxExpr noSyntaxExpr, thing) }
-
-tcMDoStmt tc_rhs _ (ExprStmt rhs _ _) res_ty thing_inside
-  = do { (rhs', elt_ty) <- tc_rhs rhs
-       ; thing          <- thing_inside res_ty
-       ; return (ExprStmt rhs' noSyntaxExpr elt_ty, thing) }
-
-tcMDoStmt tc_rhs ctxt (RecStmt { recS_stmts = stmts, recS_later_ids = laterNames
-                               , recS_rec_ids = recNames }) res_ty thing_inside
-  = do { rec_tys <- newFlexiTyVarTys (length recNames) liftedTypeKind
-       ; let rec_ids = zipWith mkLocalId recNames rec_tys
-       ; tcExtendIdEnv rec_ids                 $ do
-       { (stmts', (later_ids, rec_rets))
-               <- tcStmts ctxt (tcMDoStmt tc_rhs) stmts res_ty $ \ _res_ty' ->
-                       -- ToDo: res_ty not really right
-                  do { rec_rets <- zipWithM tcCheckId recNames rec_tys
-                     ; later_ids <- tcLookupLocalIds laterNames
-                     ; return (later_ids, rec_rets) }
-
-       ; thing <- tcExtendIdEnv later_ids (thing_inside res_ty)
-               -- NB:  The rec_ids for the recursive things 
-               --      already scope over this part. This binding may shadow
-               --      some of them with polymorphic things with the same Name
-               --      (see note [RecStmt] in HsExpr)
-
-        ; return (RecStmt stmts' later_ids rec_ids noSyntaxExpr noSyntaxExpr noSyntaxExpr rec_rets, thing)
-       }}
-
-tcMDoStmt _ _ stmt _ _
-  = pprPanic "tcMDoStmt: unexpected Stmt" (ppr stmt)
-\end{code}
-
 
 %************************************************************************
 %*                                                                     *
index d28e901..39594f0 100644 (file)
@@ -36,7 +36,6 @@ import PrelNames
 import BasicTypes hiding (SuccessFlag(..))
 import DynFlags
 import SrcLoc
-import ErrUtils
 import Util
 import Outputable
 import FastString
@@ -348,9 +347,9 @@ tc_lpat :: LPat Name
        -> TcM a
        -> TcM (LPat TcId, a)
 tc_lpat (L span pat) pat_ty penv thing_inside
-  = setSrcSpan span              $
-    maybeAddErrCtxt (patCtxt pat) $
-    do { (pat', res) <- tc_pat penv pat pat_ty thing_inside
+  = setSrcSpan span $
+    do { (pat', res) <- maybeWrapPatCtxt pat (tc_pat penv pat pat_ty)
+                                          thing_inside
        ; return (L span pat', res) }
 
 tc_lpats :: PatEnv
@@ -774,7 +773,6 @@ matchExpectedConTy data_tc pat_ty
                     -- coi : T tys ~ pat_ty
 \end{code}
 
-Noate [
 Note [Matching constructor patterns]
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 Suppose (coi, tys) = matchExpectedConType data_tc pat_ty
@@ -1006,12 +1004,18 @@ sigPatCtxt pats bound_tvs pat_tys body_ty tidy_env
 -}
 
 \begin{code}
-patCtxt :: Pat Name -> Maybe Message   -- Not all patterns are worth pushing a context
-patCtxt (VarPat _)  = Nothing
-patCtxt (ParPat _)  = Nothing
-patCtxt (AsPat _ _) = Nothing
-patCtxt pat        = Just (hang (ptext (sLit "In the pattern:")) 
-                         2 (ppr pat))
+maybeWrapPatCtxt :: Pat Name -> (TcM a -> TcM b) -> TcM a -> TcM b
+-- Not all patterns are worth pushing a context
+maybeWrapPatCtxt pat tcm thing_inside 
+  | not (worth_wrapping pat) = tcm thing_inside
+  | otherwise                = addErrCtxt msg $ tcm $ popErrCtxt thing_inside
+                              -- Remember to pop before doing thing_inside
+  where
+   worth_wrapping (VarPat {}) = False
+   worth_wrapping (ParPat {}) = False
+   worth_wrapping (AsPat {})  = False
+   worth_wrapping _          = True
+   msg = hang (ptext (sLit "In the pattern:")) 2 (ppr pat)
 
 -----------------------------------------------
 checkExistentials :: [TyVar] -> PatEnv -> TcM ()
index 23c2e67..7b1d5a6 100644 (file)
@@ -1205,7 +1205,7 @@ runPlans (p:ps) = tryTcLIE_ (runPlans ps) p
 
 --------------------
 mkPlan :: LStmt Name -> TcM PlanResult
-mkPlan (L loc (ExprStmt expr _ _))     -- An expression typed at the prompt 
+mkPlan (L loc (ExprStmt expr _ _ _))   -- An expression typed at the prompt 
   = do { uniq <- newUnique             -- is treated very specially
        ; let fresh_it  = itName uniq
              the_bind  = L loc $ mkFunBind (L loc fresh_it) matches
@@ -1214,7 +1214,7 @@ mkPlan (L loc (ExprStmt expr _ _))        -- An expression typed at the prompt
              bind_stmt = L loc $ BindStmt (nlVarPat fresh_it) expr
                                           (HsVar bindIOName) noSyntaxExpr 
              print_it  = L loc $ ExprStmt (nlHsApp (nlHsVar printName) (nlHsVar fresh_it))
-                                          (HsVar thenIOName) placeHolderType
+                                          (HsVar thenIOName) noSyntaxExpr placeHolderType
 
        -- The plans are:
        --      [it <- e; print it]     but not if it::()
@@ -1242,7 +1242,7 @@ mkPlan (L loc (ExprStmt expr _ _))        -- An expression typed at the prompt
 mkPlan stmt@(L loc (BindStmt {}))
   | [v] <- collectLStmtBinders stmt            -- One binder, for a bind stmt 
   = do { let print_v  = L loc $ ExprStmt (nlHsApp (nlHsVar printName) (nlHsVar v))
-                                          (HsVar thenIOName) placeHolderType
+                                         (HsVar thenIOName) noSyntaxExpr placeHolderType
 
        ; print_bind_result <- doptM Opt_PrintBindResult
        ; let print_plan = do
@@ -1269,11 +1269,25 @@ tcGhciStmts stmts
        let {
            ret_ty    = mkListTy unitTy ;
            io_ret_ty = mkTyConApp ioTyCon [ret_ty] ;
-           tc_io_stmts stmts = tcStmts GhciStmt tcDoStmt stmts io_ret_ty ;
-
+           tc_io_stmts stmts = tcStmtsAndThen GhciStmt tcDoStmt stmts io_ret_ty ;
            names = collectLStmtsBinders stmts ;
+        } ;
+
+       -- OK, we're ready to typecheck the stmts
+       traceTc "TcRnDriver.tcGhciStmts: tc stmts" empty ;
+       ((tc_stmts, ids), lie) <- captureConstraints $ 
+                                  tc_io_stmts stmts  $ \ _ ->
+                                 mapM tcLookupId names  ;
+                       -- Look up the names right in the middle,
+                       -- where they will all be in scope
 
-               -- mk_return builds the expression
+       -- Simplify the context
+       traceTc "TcRnDriver.tcGhciStmts: simplify ctxt" empty ;
+       const_binds <- checkNoErrs (simplifyInteractive lie) ;
+               -- checkNoErrs ensures that the plan fails if context redn fails
+
+       traceTc "TcRnDriver.tcGhciStmts: done" empty ;
+        let {   -- mk_return builds the expression
                --      returnIO @ [()] [coerce () x, ..,  coerce () z]
                --
                -- Despite the inconvenience of building the type applications etc,
@@ -1284,27 +1298,14 @@ tcGhciStmts stmts
                -- then the type checker would instantiate x..z, and we wouldn't
                -- get their *polymorphic* values.  (And we'd get ambiguity errs
                -- if they were overloaded, since they aren't applied to anything.)
-           mk_return ids = nlHsApp (nlHsTyApp ret_id [ret_ty]) 
-                                   (noLoc $ ExplicitList unitTy (map mk_item ids)) ;
+           ret_expr = nlHsApp (nlHsTyApp ret_id [ret_ty]) 
+                      (noLoc $ ExplicitList unitTy (map mk_item ids)) ;
            mk_item id = nlHsApp (nlHsTyApp unsafeCoerceId [idType id, unitTy])
-                                (nlHsVar id) 
-        } ;
-
-       -- OK, we're ready to typecheck the stmts
-       traceTc "TcRnDriver.tcGhciStmts: tc stmts" empty ;
-       ((tc_stmts, ids), lie) <- captureConstraints $ tc_io_stmts stmts $ \ _ ->
-                                          mapM tcLookupId names ;
-                                       -- Look up the names right in the middle,
-                                       -- where they will all be in scope
-
-       -- Simplify the context
-       traceTc "TcRnDriver.tcGhciStmts: simplify ctxt" empty ;
-       const_binds <- checkNoErrs (simplifyInteractive lie) ;
-               -- checkNoErrs ensures that the plan fails if context redn fails
-
-       traceTc "TcRnDriver.tcGhciStmts: done" empty ;
+                                (nlHsVar id) ;
+           stmts = tc_stmts ++ [noLoc (mkLastStmt ret_expr)]
+        } ;
        return (ids, mkHsDictLet (EvBinds const_binds) $
-                    noLoc (HsDo GhciStmt tc_stmts (mk_return ids) io_ret_ty))
+                    noLoc (HsDo GhciStmt stmts io_ret_ty))
     }
 \end{code}
 
index ad2405b..826c09b 100644 (file)
@@ -781,11 +781,6 @@ updCtxt :: ([ErrCtxt] -> [ErrCtxt]) -> TcM a -> TcM a
 updCtxt upd = updLclEnv (\ env@(TcLclEnv { tcl_ctxt = ctxt }) -> 
                           env { tcl_ctxt = upd ctxt })
 
--- Conditionally add an error context
-maybeAddErrCtxt :: Maybe Message -> TcM a -> TcM a
-maybeAddErrCtxt (Just msg) thing_inside = addErrCtxt msg thing_inside
-maybeAddErrCtxt Nothing    thing_inside = thing_inside
-
 popErrCtxt :: TcM a -> TcM a
 popErrCtxt = updCtxt (\ msgs -> case msgs of { [] -> []; (_ : ms) -> ms })
 
index fc82729..e511532 100644 (file)
@@ -1112,6 +1112,7 @@ data CtOrigin
   | StandAloneDerivOrigin -- Typechecking stand-alone deriving
   | DefaultOrigin      -- Typechecking a default decl
   | DoOrigin           -- Arising from a do expression
+  | MCompOrigin         -- Arising from a monad comprehension
   | IfOrigin            -- Arising from an if statement
   | ProcOrigin         -- Arising from a proc expression
   | AnnOrigin           -- An annotation
@@ -1147,6 +1148,7 @@ pprO DerivOrigin     = ptext (sLit "the 'deriving' clause of a data type declarat
 pprO StandAloneDerivOrigin = ptext (sLit "a 'deriving' declaration")
 pprO DefaultOrigin        = ptext (sLit "a 'default' declaration")
 pprO DoOrigin             = ptext (sLit "a do statement")
+pprO MCompOrigin           = ptext (sLit "a statement in a monad comprehension")
 pprO ProcOrigin                   = ptext (sLit "a proc expression")
 pprO (TypeEqOrigin eq)     = ptext (sLit "an equality") <+> ppr eq
 pprO AnnOrigin             = ptext (sLit "an annotation")
index 647f22f..414c63a 100644 (file)
@@ -102,6 +102,7 @@ import FastString
 import HsBinds               -- for TcEvBinds stuff 
 import Id 
 
+import StaticFlags( opt_PprStyle_Debug )
 import TcRnTypes
 #ifdef DEBUG
 import Control.Monad( when )
@@ -527,7 +528,7 @@ runTcS context untouch tcs
 
 #ifdef DEBUG
        ; count <- TcM.readTcRef step_count
-       ; when (count > 0) $
+       ; when (opt_PprStyle_Debug && count > 0) $
          TcM.debugDumpTcRn (ptext (sLit "Constraint solver steps =") 
                             <+> int count <+> ppr context)
 #endif
index 31352e1..e229b8b 100644 (file)
@@ -20,7 +20,7 @@ module TcUnify (
   matchExpectedListTy, matchExpectedPArrTy, 
   matchExpectedTyConApp, matchExpectedAppTy, 
   matchExpectedFunTys, matchExpectedFunKind,
-  wrapFunResCoercion
+  wrapFunResCoercion, failWithMisMatch
   ) where
 
 #include "HsVersions.h"
index 73faae7..4a502b4 100644 (file)
              <entry>dynamic</entry>
              <entry><option>-XNoTransformListComp</option></entry>
            </row>
+        <row>
+             <entry><option>-XMonadComprehensions</option></entry>
+             <entry>Enable <link linkend="monad-comprehensions">monad comprehensions</link>.</entry>
+             <entry>dynamic</entry>
+             <entry><option>-XNoMonadComprehensions</option></entry>
+           </row>
            <row>
              <entry><option>-XUnliftedFFITypes</option></entry>
              <entry>Enable unlifted FFI types.</entry>
index 9ea3332..89198c4 100644 (file)
@@ -1201,6 +1201,234 @@ output = [ x
 </para>
   </sect2>
 
+   <!-- ===================== MONAD COMPREHENSIONS ===================== -->
+
+<sect2 id="monad-comprehensions">
+    <title>Monad comprehensions</title>
+    <indexterm><primary>monad comprehensions</primary></indexterm>
+
+    <para>
+        Monad comprehesions generalise the list comprehension notation,
+        including parallel comprehensions 
+        (<xref linkend="parallel-list-comprehensions"/>) and 
+        transform comprenensions (<xref linkend="generalised-list-comprehensions"/>) 
+        to work for any monad.
+    </para>
+
+    <para>Monad comprehensions support:</para>
+
+    <itemizedlist>
+        <listitem>
+            <para>
+                Bindings:
+            </para>
+
+<programlisting>
+[ x + y | x &lt;- Just 1, y &lt;- Just 2 ]
+</programlisting>
+
+            <para>
+                Bindings are translated with the <literal>(&gt;&gt;=)</literal> and
+                <literal>return</literal> functions to the usual do-notation:
+            </para>
+
+<programlisting>
+do x &lt;- Just 1
+   y &lt;- Just 2
+   return (x+y)
+</programlisting>
+
+        </listitem>
+        <listitem>
+            <para>
+                Guards:
+            </para>
+
+<programlisting>
+[ x | x &lt;- [1..10], x &lt;= 5 ]
+</programlisting>
+
+            <para>
+                Guards are translated with the <literal>guard</literal> function,
+                which requires a <literal>MonadPlus</literal> instance:
+            </para>
+
+<programlisting>
+do x &lt;- [1..10]
+   guard (x &lt;= 5)
+   return x
+</programlisting>
+
+        </listitem>
+        <listitem>
+            <para>
+                Transform statements (as with <literal>-XTransformListComp</literal>):
+            </para>
+
+<programlisting>
+[ x+y | x &lt;- [1..10], y &lt;- [1..x], then take 2 ]
+</programlisting>
+
+            <para>
+                This translates to:
+            </para>
+
+<programlisting>
+do (x,y) &lt;- take 2 (do x &lt;- [1..10]
+                       y &lt;- [1..x]
+                       return (x,y))
+   return (x+y)
+</programlisting>
+
+        </listitem>
+        <listitem>
+            <para>
+                Group statements (as with <literal>-XTransformListComp</literal>):
+            </para>
+
+<programlisting>
+[ x | x &lt;- [1,1,2,2,3], then group by x ]
+[ x | x &lt;- [1,1,2,2,3], then group by x using GHC.Exts.groupWith ]
+[ x | x &lt;- [1,1,2,2,3], then group using myGroup ]
+</programlisting>
+
+            <para>
+                The basic <literal>then group by e</literal> statement is
+                translated using the <literal>mgroupWith</literal> function, which
+                requires a <literal>MonadGroup</literal> instance, defined in
+                <ulink url="&libraryBaseLocation;/Control-Monad-Group.html"><literal>Control.Monad.Group</literal></ulink>:
+            </para>
+
+<programlisting>
+do x &lt;- mgroupWith (do x &lt;- [1,1,2,2,3]
+                       return x)
+   return x
+</programlisting>
+
+            <para>
+                Note that the type of <literal>x</literal> is changed by the
+                grouping statement.
+            </para>
+
+            <para>
+                The grouping function can also be defined with the
+                <literal>using</literal> keyword.
+            </para>
+
+        </listitem>
+        <listitem>
+            <para>
+                Parallel statements (as with <literal>-XParallelListComp</literal>):
+            </para>
+
+<programlisting>
+[ (x+y) | x &lt;- [1..10]
+        | y &lt;- [11..20]
+        ]
+</programlisting>
+
+            <para>
+                Parallel statements are translated using the
+                <literal>mzip</literal> function, which requires a
+                <literal>MonadZip</literal> instance defined in
+                <ulink url="&libraryBaseLocation;/Control-Monad-Zip.html"><literal>Control.Monad.Zip</literal></ulink>:
+            </para>
+
+<programlisting>
+do (x,y) &lt;- mzip (do x &lt;- [1..10]
+                     return x)
+                 (do y &lt;- [11..20]
+                     return y)
+   return (x+y)
+</programlisting>
+
+        </listitem>
+    </itemizedlist>
+
+    <para>
+        All these features are enabled by default if the
+        <literal>MonadComprehensions</literal> extension is enabled. The types
+        and more detailed examples on how to use comprehensions are explained
+        in the previous chapters <xref
+            linkend="generalised-list-comprehensions"/> and <xref
+            linkend="parallel-list-comprehensions"/>. In general you just have
+        to replace the type <literal>[a]</literal> with the type
+        <literal>Monad m => m a</literal> for monad comprehensions.
+    </para>
+
+    <para>
+        Note: Even though most of these examples are using the list monad,
+        monad comprehensions work for any monad.
+        The <literal>base</literal> package offers all necessary instances for
+        lists, which make <literal>MonadComprehensions</literal> backward
+        compatible to built-in, transform and parallel list comprehensions.
+    </para>
+<para> More formally, the desugaring is as follows.  We write <literal>D[ e | Q]</literal>
+to mean the desugaring of the monad comprehension <literal>[ e | Q]</literal>: 
+<programlisting>
+Expressions: e
+Declarations: d
+Lists of qualifiers: Q,R,S  
+
+-- Basic forms
+D[ e | ]               = return e
+D[ e | p &lt;- e, Q ]     = e &gt;&gt;= \p -&gt; D[ e | Q ]
+D[ e | e, Q ]          = guard e &gt;&gt; \p -&gt; D[ e | Q ]
+D[ e | let d, Q ]      = let d in D[ e | Q ]
+
+-- Parallel comprehensions (iterate for multiple parallel branches)
+D[ e | (Q | R), S ]    = mzip D[ Qv | Q ] D[ Rv | R ] &gt;&gt;= \(Qv,Rv) -&gt; D[ e | S ]
+
+-- Transform comprehensions
+D[ e | Q then f, R ]                  = f D[ Qv | Q ] &gt;&gt;= \Qv -&gt; D[ e | R ]
+
+D[ e | Q then f by b, R ]             = f b D[ Qv | Q ] &gt;&gt;= \Qv -&gt; D[ e | R ]
+
+D[ e | Q then group using f, R ]      = f D[ Qv | Q ] &gt;&gt;= \ys -&gt; 
+                                        case (fmap selQv1 ys, ..., fmap selQvn ys) of
+                                            Qv -&gt; D[ e | R ]
+
+D[ e | Q then group by b using f, R ] = f b D[ Qv | Q ] &gt;&gt;= \ys -&gt; 
+                                        case (fmap selQv1 ys, ..., fmap selQvn ys) of
+                                           Qv -&gt; D[ e | R ]
+
+where  Qv is the tuple of variables bound by Q (and used subsequently)
+       selQvi is a selector mapping Qv to the ith component of Qv
+
+Operator     Standard binding       Expected type
+--------------------------------------------------------------------
+return       GHC.Base               t1 -&gt; m t2
+(&gt;&gt;=)        GHC.Base               m1 t1 -&gt; (t2 -&gt; m2 t3) -&gt; m3 t3
+(&gt;&gt;)         GHC.Base               m1 t1 -&gt; m2 t2         -&gt; m3 t3
+guard        Control.Monad          t1 -&gt; m t2
+fmap         GHC.Base               forall a b. (a-&gt;b) -&gt; n a -&gt; n b
+mgroupWith   Control.Monad.Group    forall a. (a -&gt; t) -&gt; m1 a -&gt; m2 (n a)
+mzip         Control.Monad.Zip      forall a b. m a -&gt; m b -&gt; m (a,b)
+</programlisting>                                          
+The comprehension should typecheck when its desugaring would typecheck. 
+</para>
+<para>
+Monad comprehensions support rebindable syntax (<xref linkend="rebindable-syntax"/>).  
+Without rebindable
+syntax, the operators from the "standard binding" module are used; with
+rebindable syntax, the operators are looked up in the current lexical scope.
+For example, parallel comprehensions will be typechecked and desugared
+using whatever "<literal>mzip</literal>" is in scope.
+</para>
+<para>
+The rebindable operators must have the "Expected type" given in the 
+table above.  These types are surprisingly general.  For example, you can
+use a bind operator with the type
+<programlisting>
+(>>=) :: T x y a -> (a -> T y z b) -> T x z b
+</programlisting>
+In the case of transform comprehensions, notice that the groups are
+parameterised over some arbitrary type <literal>n</literal> (provided it
+has an <literal>fmap</literal>, as well as
+the comprehension being over an arbitrary monad.
+</para>
+</sect2>
+
    <!-- ===================== REBINDABLE SYNTAX ===================  -->
 
 <sect2 id="rebindable-syntax">