Preliminary monad-comprehension patch (Trac #4370)
[ghc-hetmet.git] / compiler / parser / RdrHsSyn.lhs
index 47abf23..0e22c69 100644 (file)
@@ -42,6 +42,7 @@ module RdrHsSyn (
        checkPatterns,        -- SrcLoc -> [HsExp] -> P [HsPat]
        checkDo,              -- [Stmt] -> P [Stmt]
        checkMDo,             -- [Stmt] -> P [Stmt]
+       checkMonadComp,       -- P (HsStmtContext RdrName)
        checkValDef,          -- (SrcLoc, HsExp, HsRhs, [HsDecl]) -> P HsDecl
        checkValSig,          -- (SrcLoc, HsExp, HsRhs, [HsDecl]) -> P HsDecl
        checkDoAndIfThenElse,
@@ -54,6 +55,7 @@ import Class            ( FunDep )
 import TypeRep          ( Kind )
 import RdrName         ( RdrName, isRdrTyVar, isRdrTc, mkUnqual, rdrNameOcc, 
                          isRdrDataCon, isUnqual, getRdrName, setRdrNameSpace )
+import Name             ( Name )
 import BasicTypes      ( maxPrecedence, Activation(..), RuleMatchInfo,
                           InlinePragma(..), InlineSpec(..) )
 import Lexer
@@ -629,8 +631,8 @@ checkDoMDo _   nm loc []   = parseErrorSDoc loc (text ("Empty " ++ nm ++ " const
 checkDoMDo pre nm _   ss   = do
   check ss
   where 
-       check  []                     = panic "RdrHsSyn:checkDoMDo"
-       check  [L _ (ExprStmt e _ _)] = return ([], e)
+       check  []                       = panic "RdrHsSyn:checkDoMDo"
+       check  [L _ (ExprStmt e _ _ _)] = return ([], e)
        check  [L l e] = parseErrorSDoc l
                          (text ("The last statement in " ++ pre ++ nm ++
                                                    " construct must be an expression:")
@@ -912,6 +914,20 @@ isFunLhs e = go e []
                 _ -> return Nothing }
    go _ _ = return Nothing
 
+
+---------------------------------------------------------------------------
+-- Check for monad comprehensions
+--
+-- If the flag MonadComprehensions is set, return a `MonadComp' context,
+-- otherwise use the usual `ListComp' context
+
+checkMonadComp :: P (HsStmtContext Name)
+checkMonadComp = do
+    pState <- getPState
+    return $ if xopt Opt_MonadComprehensions (dflags pState)
+                then MonadComp
+                else ListComp
+
 ---------------------------------------------------------------------------
 -- Miscellaneous utilities