More on monad-comp; an intermediate state, so don't pull
[ghc-hetmet.git] / compiler / rename / RnExpr.lhs
index d1dd222..11d44e3 100644 (file)
@@ -648,32 +648,22 @@ rnStmts MDoExpr stmts thing_inside    -- Deal with mdo
   = -- Behave like do { rec { ...all but last... }; last }
     do { ((stmts1, (stmts2, thing)), fvs) 
           <- rnStmt MDoExpr (noLoc $ mkRecStmt all_but_last) $ \ _ ->
-             do { checkStmt MDoExpr True last_stmt
-                ; rnStmt MDoExpr last_stmt thing_inside }
+             do { last_stmt' <- checkLastStmt MDoExpr last_stmt
+                ; rnStmt MDoExpr last_stmt' thing_inside }
        ; return (((stmts1 ++ stmts2), thing), fvs) }
   where
     Just (all_but_last, last_stmt) = snocView stmts
 
-rnStmts ctxt (lstmt@(L loc stmt) : lstmts) thing_inside
+rnStmts ctxt (lstmt@(L loc _) : lstmts) thing_inside
   | null lstmts
   = setSrcSpan loc $
-    do { -- Turn a final ExprStmt into a LastStmt
-         -- This is the first place it's convenient to do this
-        -- (In principle the parser could do it, but it's 
-        --  just not very convenient to do so.)
-         let stmt' | okEmpty ctxt 
-                   = lstmt
-                   | otherwise    
-                   = case stmt of 
-                       ExprStmt e _ _ _ -> L loc (mkLastStmt e)
-                      _                -> lstmt
-       ; checkStmt ctxt True {- last stmt -} stmt'
-       ; rnStmt ctxt stmt' thing_inside }
+    do { lstmt' <- checkLastStmt ctxt lstmt
+       ; rnStmt ctxt lstmt' thing_inside }
 
   | otherwise
   = do { ((stmts1, (stmts2, thing)), fvs) 
             <- setSrcSpan loc                         $
-               do { checkStmt ctxt False {- Not last -} lstmt
+               do { checkStmt ctxt lstmt
                   ; rnStmt ctxt lstmt    $ \ bndrs1 ->
                     rnStmts ctxt lstmts  $ \ bndrs2 ->
                     thing_inside (bndrs1 ++ bndrs2) }
@@ -1211,7 +1201,7 @@ checkEmptyStmts :: HsStmtContext Name -> RnM ()
 checkEmptyStmts ctxt 
   = unless (okEmpty ctxt) (addErr (emptyErr ctxt))
 
-okEmpty :: HsStmtContext Name -> Bool
+okEmpty :: HsStmtContext a -> Bool
 okEmpty (PatGuard {}) = True
 okEmpty _             = False
 
@@ -1221,14 +1211,42 @@ emptyErr (TransformStmtCtxt {}) = ptext (sLit "Empty statement group preceding '
 emptyErr ctxt                   = ptext (sLit "Empty") <+> pprStmtContext ctxt
 
 ---------------------- 
+checkLastStmt :: HsStmtContext Name
+              -> LStmt RdrName 
+              -> RnM (LStmt RdrName)
+checkLastStmt ctxt lstmt@(L loc stmt)
+  = case ctxt of 
+      ListComp  -> check_comp
+      MonadComp -> check_comp
+      PArrComp  -> check_comp
+      DoExpr   -> check_do
+      MDoExpr   -> check_do
+      _         -> check_other
+  where
+    check_do   -- Expect ExprStmt, and change it to LastStmt
+      = case stmt of 
+          ExprStmt e _ _ _ -> return (L loc (mkLastStmt e))
+          LastStmt {}      -> return lstmt   -- "Deriving" clauses may generate a
+                                            -- LastStmt directly (unlike the parser)
+         _                -> do { addErr (hang last_error 2 (ppr stmt)); return lstmt }
+    last_error = (ptext (sLit "The last statement in") <+> pprAStmtContext ctxt
+                  <+> ptext (sLit "must be an expression"))
+
+    check_comp -- Expect LastStmt; this should be enforced by the parser!
+      = case stmt of 
+          LastStmt {} -> return lstmt
+          _           -> pprPanic "checkLastStmt" (ppr lstmt)
+
+    check_other        -- Behave just as if this wasn't the last stmt
+      = do { checkStmt ctxt lstmt; return lstmt }
+
 -- Checking when a particular Stmt is ok
 checkStmt :: HsStmtContext Name
-          -> Bool                      -- True <=> this is the last Stmt in the sequence
           -> LStmt RdrName 
           -> RnM ()
-checkStmt ctxt is_last (L _ stmt)
+checkStmt ctxt (L _ stmt)
   = do { dflags <- getDOpts
-       ; case okStmt dflags ctxt is_last stmt of 
+       ; case okStmt dflags ctxt stmt of 
            Nothing    -> return ()
            Just extra -> addErr (msg $$ extra) }
   where
@@ -1250,42 +1268,32 @@ isOK, notOK :: Maybe SDoc
 isOK  = Nothing
 notOK = Just empty
 
-okStmt, okDoStmt, okCompStmt :: DynFlags -> HsStmtContext Name -> Bool 
+okStmt, okDoStmt, okCompStmt :: DynFlags -> HsStmtContext Name
                              -> Stmt RdrName -> Maybe SDoc
 -- Return Nothing if OK, (Just extra) if not ok
 -- The "extra" is an SDoc that is appended to an generic error message
-okStmt _ (PatGuard {}) _ stmt
+okStmt _ (PatGuard {}) stmt
   = case stmt of
       ExprStmt {} -> isOK
       BindStmt {} -> isOK
       LetStmt {}  -> isOK
       _           -> notOK
 
-okStmt dflags (ParStmtCtxt ctxt) _ stmt
+okStmt dflags (ParStmtCtxt ctxt) stmt
   = case stmt of
       LetStmt (HsIPBinds {}) -> notOK
-      _                      -> okStmt dflags ctxt False stmt
-                               -- NB: is_last=False in recursive
-                               -- call; the branches of of a Par
-                               -- not finish with a LastStmt
+      _                      -> okStmt dflags ctxt stmt
 
-okStmt dflags (TransformStmtCtxt ctxt) _ stmt 
-  = okStmt dflags ctxt False stmt
+okStmt dflags (TransformStmtCtxt ctxt) stmt 
+  = okStmt dflags ctxt stmt
 
-okStmt dflags ctxt is_last stmt 
-  | isDoExpr       ctxt = okDoStmt   dflags ctxt is_last stmt
-  | isListCompExpr ctxt = okCompStmt dflags ctxt is_last stmt
+okStmt dflags ctxt stmt 
+  | isDoExpr       ctxt = okDoStmt   dflags ctxt stmt
+  | isListCompExpr ctxt = okCompStmt dflags ctxt stmt
   | otherwise           = pprPanic "okStmt" (pprStmtContext ctxt)
 
 ----------------
-okDoStmt dflags ctxt is_last stmt
-  | is_last
-  = case stmt of 
-      LastStmt {} -> isOK
-      _ -> Just (ptext (sLit "The last statement in") <+> pprAStmtContext ctxt
-                 <+> ptext (sLit "must be an expression"))
-
-  | otherwise
+okDoStmt dflags _ stmt
   = case stmt of
        RecStmt {} 
          | Opt_DoRec `xopt` dflags -> isOK
@@ -1297,13 +1305,7 @@ okDoStmt dflags ctxt is_last stmt
 
 
 ----------------
-okCompStmt dflags _ is_last stmt
-  | is_last
-  = case stmt of
-      LastStmt {} -> Nothing
-      _ -> pprPanic "Unexpected stmt" (ppr stmt)  -- Not a user error
-
-  | otherwise
+okCompStmt dflags _ stmt
   = case stmt of
        BindStmt {} -> isOK
        LetStmt {}  -> isOK