Preliminary monad-comprehension patch (Trac #4370)
[ghc-hetmet.git] / compiler / rename / RnExpr.lhs
index 310d075..425cb40 100644 (file)
@@ -131,8 +131,8 @@ rnExpr (HsApp fun arg)
     rnLExpr arg                `thenM` \ (arg',fvArg) ->
     return (HsApp fun' arg', fvFun `plusFV` fvArg)
 
-rnExpr (OpApp e1 (L op_loc (HsVar op_rdr)) _ e2) 
-  = do { (e1', fv_e1) <- rnLExpr e1
+rnExpr (OpApp e1 (L op_loc (HsVar op_rdr)) _ e2)
+  = do  { (e1', fv_e1) <- rnLExpr e1
        ; (e2', fv_e2) <- rnLExpr e2
        ; op_name <- setSrcSpan op_loc (lookupOccRn op_rdr)
        ; (op', fv_op) <- finishHsVar op_name
@@ -146,6 +146,10 @@ rnExpr (OpApp e1 (L op_loc (HsVar op_rdr)) _ e2)
        ; fixity <- lookupFixityRn op_name
        ; final_e <- mkOpAppRn e1' (L op_loc op') fixity e2'
        ; return (final_e, fv_e1 `plusFV` fv_op `plusFV` fv_e2) }
+rnExpr (OpApp _ other_op _ _)
+  = failWith (vcat [ hang (ptext (sLit "Operator application with a non-variable operator:"))
+                        2 (ppr other_op)
+                   , ptext (sLit "(Probably resulting from a Template Haskell splice)") ])
 
 rnExpr (NegApp e _)
   = rnLExpr e                  `thenM` \ (e', fv_e) ->
@@ -220,10 +224,16 @@ rnExpr (HsLet binds expr)
     rnLExpr expr                        `thenM` \ (expr',fvExpr) ->
     return (HsLet binds' expr', fvExpr)
 
-rnExpr (HsDo do_or_lc stmts body _)
-  = do         { ((stmts', body'), fvs) <- rnStmts do_or_lc stmts $
-                                   rnLExpr body
-       ; return (HsDo do_or_lc stmts' body' placeHolderType, fvs) }
+rnExpr (HsDo do_or_lc stmts body _ _)
+  = do         { ((stmts', body'), fvs1) <- rnStmts do_or_lc stmts $ \ _ ->
+                                    rnLExpr body
+        ; (return_op, fvs2) <-
+              if isMonadCompExpr do_or_lc
+                 then lookupSyntaxName returnMName
+                 else return (noSyntaxExpr, emptyFVs)
+
+       ; return ( HsDo do_or_lc stmts' body' return_op placeHolderType
+                 , fvs1 `plusFV` fvs2 ) }
 
 rnExpr (ExplicitList _ exps)
   = rnExprs exps                       `thenM` \ (exps', fvs) ->
@@ -264,13 +274,10 @@ rnExpr (ExprWithTySig expr pty)
 
 rnExpr (HsIf _ p b1 b2)
   = do { (p', fvP) <- rnLExpr p
-    ; (b1', fvB1) <- rnLExpr b1
-    ; (b2', fvB2) <- rnLExpr b2
-    ; rebind <- xoptM Opt_RebindableSyntax
-    ; if not rebind
-       then return (HsIf Nothing p' b1' b2', plusFVs [fvP, fvB1, fvB2])
-       else do { c <- liftM HsVar (lookupOccRn (mkVarUnqual (fsLit "ifThenElse")))
-               ; return (HsIf (Just c) p' b1' b2', plusFVs [fvP, fvB1, fvB2]) }}
+       ; (b1', fvB1) <- rnLExpr b1
+       ; (b2', fvB2) <- rnLExpr b2
+       ; (mb_ite, fvITE) <- lookupIfThenElse
+       ; return (HsIf mb_ite p' b1' b2', plusFVs [fvITE, fvP, fvB1, fvB2]) }
 
 rnExpr (HsType a)
   = rnHsTypeFVs doc a  `thenM` \ (t, fvT) -> 
@@ -440,9 +447,10 @@ convertOpFormsCmd (HsIf f exp c1 c2)
 convertOpFormsCmd (HsLet binds cmd)
   = HsLet binds (convertOpFormsLCmd cmd)
 
-convertOpFormsCmd (HsDo ctxt stmts body ty)
+convertOpFormsCmd (HsDo ctxt stmts body return_op ty)
   = HsDo ctxt (map (fmap convertOpFormsStmt) stmts)
-             (convertOpFormsLCmd body) ty
+             (convertOpFormsLCmd body)
+              (convertOpFormsCmd  return_op) ty
 
 -- Anything else is unchanged.  This includes HsArrForm (already done),
 -- things with no sub-commands, and illegal commands (which will be
@@ -452,8 +460,8 @@ convertOpFormsCmd c = c
 convertOpFormsStmt :: StmtLR id id -> StmtLR id id
 convertOpFormsStmt (BindStmt pat cmd _ _)
   = BindStmt pat (convertOpFormsLCmd cmd) noSyntaxExpr noSyntaxExpr
-convertOpFormsStmt (ExprStmt cmd _ _)
-  = ExprStmt (convertOpFormsLCmd cmd) noSyntaxExpr placeHolderType
+convertOpFormsStmt (ExprStmt cmd _ _ _)
+  = ExprStmt (convertOpFormsLCmd cmd) noSyntaxExpr noSyntaxExpr placeHolderType
 convertOpFormsStmt stmt@(RecStmt { recS_stmts = stmts })
   = stmt { recS_stmts = map (fmap convertOpFormsStmt) stmts }
 convertOpFormsStmt stmt = stmt
@@ -496,7 +504,7 @@ methodNamesCmd (HsIf _ _ c1 c2)
 
 methodNamesCmd (HsLet _ c) = methodNamesLCmd c
 
-methodNamesCmd (HsDo _ stmts body _) 
+methodNamesCmd (HsDo _ stmts body _ _) 
   = methodNamesStmts stmts `plusFV` methodNamesLCmd body
 
 methodNamesCmd (HsApp c _) = methodNamesLCmd c
@@ -537,11 +545,11 @@ methodNamesLStmt :: Located (StmtLR Name Name) -> FreeVars
 methodNamesLStmt = methodNamesStmt . unLoc
 
 methodNamesStmt :: StmtLR Name Name -> FreeVars
-methodNamesStmt (ExprStmt cmd _ _)               = methodNamesLCmd cmd
+methodNamesStmt (ExprStmt cmd _ _ _)             = methodNamesLCmd cmd
 methodNamesStmt (BindStmt _ cmd _ _)             = methodNamesLCmd cmd
 methodNamesStmt (RecStmt { recS_stmts = stmts }) = methodNamesStmts stmts `addOneFV` loopAName
 methodNamesStmt (LetStmt _)                      = emptyFVs
-methodNamesStmt (ParStmt _)                      = emptyFVs
+methodNamesStmt (ParStmt _ _ _ _)                = emptyFVs
 methodNamesStmt (TransformStmt {})               = emptyFVs
 methodNamesStmt (GroupStmt {})                   = emptyFVs
    -- ParStmt, TransformStmt and GroupStmt can't occur in commands, but it's not convenient to error 
@@ -637,16 +645,7 @@ rnBracket (DecBrG _) = panic "rnBracket: unexpected DecBrG"
 %************************************************************************
 
 \begin{code}
-rnStmts :: HsStmtContext Name -> [LStmt RdrName] 
-       -> RnM (thing, FreeVars)
-       -> RnM (([LStmt Name], thing), FreeVars)
--- Variables bound by the Stmts, and mentioned in thing_inside,
--- do not appear in the result FreeVars
-
-rnStmts (MDoExpr _) stmts thing_inside = rnMDoStmts    stmts thing_inside
-rnStmts ctxt        stmts thing_inside = rnNormalStmts ctxt stmts (\ _ -> thing_inside)
-
-rnNormalStmts :: HsStmtContext Name -> [LStmt RdrName]
+rnStmts :: HsStmtContext Name -> [LStmt RdrName]
              -> ([Name] -> RnM (thing, FreeVars))
              -> RnM (([LStmt Name], thing), FreeVars)  
 -- Variables bound by the Stmts, and mentioned in thing_inside,
@@ -654,15 +653,15 @@ rnNormalStmts :: HsStmtContext Name -> [LStmt RdrName]
 --
 -- Renaming a single RecStmt can give a sequence of smaller Stmts
 
-rnNormalStmts _ [] thing_inside 
+rnStmts _ [] thing_inside
   = do { (res, fvs) <- thing_inside []
        ; return (([], res), fvs) }
 
-rnNormalStmts ctxt (stmt@(L loc _) : stmts) thing_inside
+rnStmts ctxt (stmt@(L loc _) : stmts) thing_inside
   = do { ((stmts1, (stmts2, thing)), fvs) 
             <- setSrcSpan loc           $
                rnStmt ctxt stmt         $ \ bndrs1 ->
-               rnNormalStmts ctxt stmts $ \ bndrs2 ->
+               rnStmts ctxt stmts $ \ bndrs2 ->
                thing_inside (bndrs1 ++ bndrs2)
        ; return (((stmts1 ++ stmts2), thing), fvs) }
 
@@ -673,12 +672,15 @@ rnStmt :: HsStmtContext Name -> LStmt RdrName
 -- Variables bound by the Stmt, and mentioned in thing_inside,
 -- do not appear in the result FreeVars
 
-rnStmt _ (L loc (ExprStmt expr _ _)) thing_inside
+rnStmt ctxt (L loc (ExprStmt expr _ _ _)) thing_inside
   = do { (expr', fv_expr) <- rnLExpr expr
        ; (then_op, fvs1)  <- lookupSyntaxName thenMName
-       ; (thing, fvs2)    <- thing_inside []
-       ; return (([L loc (ExprStmt expr' then_op placeHolderType)], thing),
-                 fv_expr `plusFV` fvs1 `plusFV` fvs2) }
+       ; (guard_op, fvs2) <- if isMonadCompExpr ctxt
+                                 then lookupSyntaxName guardMName
+                                 else return (noSyntaxExpr, emptyFVs)
+       ; (thing, fvs3)    <- thing_inside []
+       ; return (([L loc (ExprStmt expr' then_op guard_op placeHolderType)], thing),
+                 fv_expr `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) }
 
 rnStmt ctxt (L loc (BindStmt pat expr _ _)) thing_inside
   = do { (expr', fv_expr) <- rnLExpr expr
@@ -710,7 +712,7 @@ rnStmt ctxt (L _ (RecStmt { recS_stmts = rec_stmts })) thing_inside
        -- for which it's the fwd refs within the bind itself
        -- (This set may not be empty, because we're in a recursive 
        -- context.)
-        ; rn_rec_stmts_and_then rec_stmts      $ \ segs -> do
+        ; rnRecStmtsAndThen rec_stmts   $ \ segs -> do
 
        { let bndrs = nameSetToList $ foldr (unionNameSets . (\(ds,_,_,_) -> ds)) 
                                             emptyNameSet segs
@@ -742,18 +744,26 @@ rnStmt ctxt (L _ (RecStmt { recS_stmts = rec_stmts })) thing_inside
 
        ; return ((rec_stmts', thing), fvs `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) } }
 
-rnStmt ctxt (L loc (ParStmt segs)) thing_inside
+rnStmt ctxt (L loc (ParStmt segs _ _ _)) thing_inside
   = do { checkParStmt ctxt
-       ; ((segs', thing), fvs) <- rnParallelStmts (ParStmtCtxt ctxt) segs thing_inside
-       ; return (([L loc (ParStmt segs')], thing), fvs) }
-
-rnStmt ctxt (L loc (TransformStmt stmts _ using by)) thing_inside
+        ; ((mzip_op, fvs1), (bind_op, fvs2), (return_op, fvs3)) <- if isMonadCompExpr ctxt
+              then (,,) <$> lookupSyntaxName mzipName
+                        <*> lookupSyntaxName bindMName
+                        <*> lookupSyntaxName returnMName
+              else return ( (noSyntaxExpr, emptyFVs)
+                          , (noSyntaxExpr, emptyFVs)
+                          , (noSyntaxExpr, emptyFVs) )
+       ; ((segs', thing), fvs4) <- rnParallelStmts (ParStmtCtxt ctxt) segs thing_inside
+       ; return ( ([L loc (ParStmt segs' mzip_op bind_op return_op)], thing)
+                 , fvs1 `plusFV` fvs2 `plusFV` fvs3 `plusFV` fvs4) }
+
+rnStmt ctxt (L loc (TransformStmt stmts _ using by _ _)) thing_inside
   = do { checkTransformStmt ctxt
     
        ; (using', fvs1) <- rnLExpr using
 
        ; ((stmts', (by', used_bndrs, thing)), fvs2)
-             <- rnNormalStmts (TransformStmtCtxt ctxt) stmts $ \ bndrs ->
+             <- rnStmts (TransformStmtCtxt ctxt) stmts $ \ bndrs ->
                 do { (by', fvs_by) <- case by of
                                         Nothing -> return (Nothing, emptyFVs)
                                         Just e  -> do { (e', fvs) <- rnLExpr e; return (Just e', fvs) }
@@ -764,35 +774,58 @@ rnStmt ctxt (L loc (TransformStmt stmts _ using by)) thing_inside
                          -- the "thing inside", **or of the by-expression**, as used
                    ; return ((by', used_bndrs, thing), fvs) }
 
-       ; return (([L loc (TransformStmt stmts' used_bndrs using' by')], thing), 
-                 fvs1 `plusFV` fvs2) }
+       -- Lookup `(>>=)` and `fail` for monad comprehensions
+       ; ((return_op, fvs3), (bind_op, fvs4)) <-
+             if isMonadCompExpr ctxt
+                then (,) <$> lookupSyntaxName returnMName
+                         <*> lookupSyntaxName bindMName
+                else return ( (noSyntaxExpr, emptyFVs)
+                            , (noSyntaxExpr, emptyFVs) )
+
+       ; return (([L loc (TransformStmt stmts' used_bndrs using' by' return_op bind_op)], thing), 
+                 fvs1 `plusFV` fvs2 `plusFV` fvs3 `plusFV` fvs4) }
         
-rnStmt ctxt (L loc (GroupStmt stmts _ by using)) thing_inside
+rnStmt ctxt (L loc (GroupStmt stmts _ by using _ _ _)) thing_inside
   = do { checkTransformStmt ctxt
     
          -- Rename the 'using' expression in the context before the transform is begun
        ; (using', fvs1) <- case using of
                              Left e  -> do { (e', fvs) <- rnLExpr e; return (Left e', fvs) }
-                            Right _ -> do { (e', fvs) <- lookupSyntaxName groupWithName
-                                           ; return (Right e', fvs) }
+                            Right _
+                                | isMonadCompExpr ctxt ->
+                                  do { (e', fvs) <- lookupSyntaxName groupMName
+                                     ; return (Right e', fvs) }
+                                | otherwise ->
+                                  do { (e', fvs) <- lookupSyntaxName groupWithName
+                                     ; return (Right e', fvs) }
 
          -- Rename the stmts and the 'by' expression
         -- Keep track of the variables mentioned in the 'by' expression
        ; ((stmts', (by', used_bndrs, thing)), fvs2) 
-             <- rnNormalStmts (TransformStmtCtxt ctxt) stmts $ \ bndrs ->
+             <- rnStmts (TransformStmtCtxt ctxt) stmts $ \ bndrs ->
                 do { (by',   fvs_by) <- mapMaybeFvRn rnLExpr by
                    ; (thing, fvs_thing) <- thing_inside bndrs
                    ; let fvs = fvs_by `plusFV` fvs_thing
                          used_bndrs = filter (`elemNameSet` fvs) bndrs
                    ; return ((by', used_bndrs, thing), fvs) }
 
-       ; let all_fvs  = fvs1 `plusFV` fvs2 
+       -- Lookup `return`, `(>>=)` and `liftM` for monad comprehensions
+       ; ((return_op, fvs3), (bind_op, fvs4), (liftM_op, fvs5)) <-
+             if isMonadCompExpr ctxt
+                then (,,) <$> lookupSyntaxName returnMName
+                          <*> lookupSyntaxName bindMName
+                          <*> lookupSyntaxName liftMName
+                else return ( (noSyntaxExpr, emptyFVs)
+                            , (noSyntaxExpr, emptyFVs)
+                            , (noSyntaxExpr, emptyFVs) )
+
+       ; let all_fvs  = fvs1 `plusFV` fvs2 `plusFV` fvs3 `plusFV` fvs4
+                             `plusFV` fvs5
              bndr_map = used_bndrs `zip` used_bndrs
             -- See Note [GroupStmt binder map] in HsExpr
 
        ; traceRn (text "rnStmt: implicitly rebound these used binders:" <+> ppr bndr_map)
-       ; return (([L loc (GroupStmt stmts' bndr_map by' using')], thing), all_fvs) }
-
+       ; return (([L loc (GroupStmt stmts' bndr_map by' using' return_op bind_op liftM_op)], thing), all_fvs) }
 
 type ParSeg id = ([LStmt id], [id])       -- The Names are bound by the Stmts
 
@@ -816,7 +849,7 @@ rnParallelStmts ctxt segs thing_inside
 
     rn_segs env bndrs_so_far ((stmts,_) : segs) 
       = do { ((stmts', (used_bndrs, segs', thing)), fvs)
-                    <- rnNormalStmts ctxt stmts $ \ bndrs ->
+                    <- rnStmts ctxt stmts $ \ bndrs ->
                        setLocalRdrEnv env       $ do
                        { ((segs', thing), fvs) <- rn_segs env (bndrs ++ bndrs_so_far) segs
                       ; let used_bndrs = filter (`elemNameSet` fvs) bndrs
@@ -864,28 +897,13 @@ type Segment stmts = (Defs,
                      stmts)    -- Either Stmt or [Stmt]
 
 
-----------------------------------------------------
-
-rnMDoStmts :: [LStmt RdrName]
-          -> RnM (thing, FreeVars)
-          -> RnM (([LStmt Name], thing), FreeVars)     
-rnMDoStmts stmts thing_inside
-  = rn_rec_stmts_and_then stmts $ \ segs -> do
-    { (thing, fvs_later) <- thing_inside
-    ; let   segs_w_fwd_refs = addFwdRefs segs
-           grouped_segs = glomSegments segs_w_fwd_refs
-           (stmts', fvs) = segsToStmts emptyRecStmt grouped_segs fvs_later
-    ; return ((stmts', thing), fvs) }
-
----------------------------------------------
-
 -- wrapper that does both the left- and right-hand sides
-rn_rec_stmts_and_then :: [LStmt RdrName]
+rnRecStmtsAndThen :: [LStmt RdrName]
                          -- assumes that the FreeVars returned includes
                          -- the FreeVars of the Segments
                       -> ([Segment (LStmt Name)] -> RnM (a, FreeVars))
                       -> RnM (a, FreeVars)
-rn_rec_stmts_and_then s cont
+rnRecStmtsAndThen s cont
   = do { -- (A) Make the mini fixity env for all of the stmts
          fix_env <- makeMiniFixityEnv (collectRecStmtsFixities s)
 
@@ -894,13 +912,15 @@ rn_rec_stmts_and_then s cont
 
          --    ...bring them and their fixities into scope
        ; let bound_names = collectLStmtsBinders (map fst new_lhs_and_fv)
+             -- Fake uses of variables introduced implicitly (warning suppression, see #4404)
+             implicit_uses = lStmtsImplicits (map fst new_lhs_and_fv)
        ; bindLocalNamesFV bound_names $
           addLocalFixities fix_env bound_names $ do
 
          -- (C) do the right-hand-sides and thing-inside
        { segs <- rn_rec_stmts bound_names new_lhs_and_fv
        ; (res, fvs) <- cont segs 
-       ; warnUnusedLocalBinds bound_names fvs
+       ; warnUnusedLocalBinds bound_names (fvs `unionNameSets` implicit_uses)
        ; return (res, fvs) }}
 
 -- get all the fixity decls in any Let stmt
@@ -922,9 +942,9 @@ rn_rec_stmt_lhs :: MiniFixityEnv
                    -- so we don't bother to compute it accurately in the other cases
                 -> RnM [(LStmtLR Name RdrName, FreeVars)]
 
-rn_rec_stmt_lhs _ (L loc (ExprStmt expr a b)) = return [(L loc (ExprStmt expr a b), 
-                                                       -- this is actually correct
-                                                       emptyFVs)]
+rn_rec_stmt_lhs _ (L loc (ExprStmt expr a b c)) = return [(L loc (ExprStmt expr a b c), 
+                                                         -- this is actually correct
+                                                         emptyFVs)]
 
 rn_rec_stmt_lhs fix_env (L loc (BindStmt pat expr a b)) 
   = do 
@@ -947,7 +967,7 @@ rn_rec_stmt_lhs fix_env (L loc (LetStmt (HsValBinds binds)))
 rn_rec_stmt_lhs fix_env (L _ (RecStmt { recS_stmts = stmts })) -- Flatten Rec inside Rec
     = rn_rec_stmts_lhs fix_env stmts
 
-rn_rec_stmt_lhs _ stmt@(L _ (ParStmt _))       -- Syntactically illegal in mdo
+rn_rec_stmt_lhs _ stmt@(L _ (ParStmt _ _ _ _)) -- Syntactically illegal in mdo
   = pprPanic "rn_rec_stmt" (ppr stmt)
   
 rn_rec_stmt_lhs _ stmt@(L _ (TransformStmt {}))        -- Syntactically illegal in mdo
@@ -978,11 +998,11 @@ rn_rec_stmt :: [Name] -> LStmtLR Name RdrName -> FreeVars -> RnM [Segment (LStmt
        -- Rename a Stmt that is inside a RecStmt (or mdo)
        -- Assumes all binders are already in scope
        -- Turns each stmt into a singleton Stmt
-rn_rec_stmt _ (L loc (ExprStmt expr _ _)) _
+rn_rec_stmt _ (L loc (ExprStmt expr _ _ _)) _
   = rnLExpr expr `thenM` \ (expr', fvs) ->
     lookupSyntaxName thenMName `thenM` \ (then_op, fvs1) ->
     return [(emptyNameSet, fvs `plusFV` fvs1, emptyNameSet,
-             L loc (ExprStmt expr' then_op placeHolderType))]
+             L loc (ExprStmt expr' then_op noSyntaxExpr placeHolderType))]
 
 rn_rec_stmt _ (L loc (BindStmt pat' expr _ _)) fv_pat
   = rnLExpr expr               `thenM` \ (expr', fv_expr) ->
@@ -1000,7 +1020,7 @@ rn_rec_stmt _ (L _ (LetStmt binds@(HsIPBinds _))) _
 
 rn_rec_stmt all_bndrs (L loc (LetStmt (HsValBinds binds'))) _ = do 
   (binds', du_binds) <- 
-      -- fixities and unused are handled above in rn_rec_stmts_and_then
+      -- fixities and unused are handled above in rnRecStmtsAndThen
       rnLocalValBindsRHS (mkNameSet all_bndrs) binds'
   return [(duDefs du_binds, allUses du_binds, 
           emptyNameSet, L loc (LetStmt (HsValBinds binds')))]
@@ -1173,19 +1193,22 @@ checkLetStmt _ctxt           _binds            = return ()
 
 ---------
 checkRecStmt :: HsStmtContext Name -> RnM ()
-checkRecStmt (MDoExpr {}) = return ()  -- Recursive stmt ok in 'mdo'
-checkRecStmt (DoExpr {})  = return ()  -- and in 'do'
-checkRecStmt ctxt        = addErr msg
+checkRecStmt MDoExpr = return ()      -- Recursive stmt ok in 'mdo'
+checkRecStmt DoExpr  = return ()      -- and in 'do'
+checkRecStmt ctxt    = addErr msg
   where
     msg = ptext (sLit "Illegal 'rec' stmt in") <+> pprStmtContext ctxt
 
 ---------
 checkParStmt :: HsStmtContext Name -> RnM ()
 checkParStmt _
-  = do { parallel_list_comp <- xoptM Opt_ParallelListComp
-       ; checkErr parallel_list_comp msg }
+  = do { monad_comp <- xoptM Opt_MonadComprehensions
+        ; unless monad_comp $ do
+          { parallel_list_comp <- xoptM Opt_ParallelListComp
+         ; checkErr parallel_list_comp msg }
+        }
   where
-    msg = ptext (sLit "Illegal parallel list comprehension: use -XParallelListComp")
+    msg = ptext (sLit "Illegal parallel list comprehension: use -XParallelListComp or -XMonadComprehensions")
 
 ---------
 checkTransformStmt :: HsStmtContext Name -> RnM ()
@@ -1194,7 +1217,10 @@ checkTransformStmt ListComp  -- Ensure we are really within a list comprehension
   = do { transform_list_comp <- xoptM Opt_TransformListComp
        ; checkErr transform_list_comp msg }
   where
-    msg = ptext (sLit "Illegal transform or grouping list comprehension: use -XTransformListComp")
+    msg = ptext (sLit "Illegal transform or grouping list comprehension: use -XTransformListComp or -XMonadComprehensions")
+checkTransformStmt MonadComp  -- Monad comprehensions are always fine, since the
+                              -- MonadComprehensions flag will already be turned on
+  = do  { return () }
 checkTransformStmt (ParStmtCtxt       ctxt) = checkTransformStmt ctxt  -- Ok to nest inside a parallel comprehension
 checkTransformStmt (TransformStmtCtxt ctxt) = checkTransformStmt ctxt  -- Ok to nest inside a parallel comprehension
 checkTransformStmt ctxt = addErr msg