Simon's hacking on monad-comp; incomplete
[ghc-hetmet.git] / compiler / rename / RnExpr.lhs
index 425cb40..e3e92bc 100644 (file)
@@ -224,16 +224,9 @@ rnExpr (HsLet binds expr)
     rnLExpr expr                        `thenM` \ (expr',fvExpr) ->
     return (HsLet binds' expr', fvExpr)
 
-rnExpr (HsDo do_or_lc stmts body _ _)
-  = do         { ((stmts', body'), fvs1) <- rnStmts do_or_lc stmts $ \ _ ->
-                                    rnLExpr body
-        ; (return_op, fvs2) <-
-              if isMonadCompExpr do_or_lc
-                 then lookupSyntaxName returnMName
-                 else return (noSyntaxExpr, emptyFVs)
-
-       ; return ( HsDo do_or_lc stmts' body' return_op placeHolderType
-                 , fvs1 `plusFV` fvs2 ) }
+rnExpr (HsDo do_or_lc stmts _)
+  = do         { ((stmts', _), fvs) <- rnStmts do_or_lc stmts (\ _ -> return ())
+       ; return ( HsDo do_or_lc stmts' placeHolderType, fvs ) }
 
 rnExpr (ExplicitList _ exps)
   = rnExprs exps                       `thenM` \ (exps', fvs) ->
@@ -653,25 +646,53 @@ rnStmts :: HsStmtContext Name -> [LStmt RdrName]
 --
 -- Renaming a single RecStmt can give a sequence of smaller Stmts
 
-rnStmts _ [] thing_inside
-  = do { (res, fvs) <- thing_inside []
-       ; return (([], res), fvs) }
+rnStmts ctxt [] thing_inside
+  = do { addErr (ptext (sLit "Empty") <+> pprStmtContext ctxt)
+       ; (thing, fvs) <- thing_inside []
+       ; return (([], thing), fvs) }
+
+rnStmts MDoExpr stmts thing_inside    -- Deal with mdo
+  = -- Behave like do { rec { ...all but last... }; last }
+    do { ((stmts1, (stmts2, thing)), fvs) 
+          <- rnStmt MDoExpr (mkRecStmt all_but_last) $ \ bndrs ->
+             do { checkStmt MDoExpr True last_stmt
+                ; rnStmt MDoExpr last_stmt thing_inside }
+       ; return (((stmts1 ++ stmts2), thing), fvs) }
+  where
+    Just (all_but_last, last_stmt) = snocView stmts
 
 rnStmts ctxt (stmt@(L loc _) : stmts) thing_inside
+  | null stmts
+  = setSrcSpan loc $
+    do { let last_stmt = case stmt of 
+                           ExprStmt e _ _ _ -> LastStmt e noSyntaxExpr
+       ; checkStmt ctxt True {- last stmt -} stmt
+       ; rnStmt ctxt stmt thing_inside }
+
+  | otherwise
   = do { ((stmts1, (stmts2, thing)), fvs) 
-            <- setSrcSpan loc           $
-               rnStmt ctxt stmt         $ \ bndrs1 ->
-               rnStmts ctxt stmts $ \ bndrs2 ->
-               thing_inside (bndrs1 ++ bndrs2)
+            <- setSrcSpan loc                         $
+               do { checkStmt ctxt False {- Not last -} stmt
+                  ; rnStmt ctxt stmt    $ \ bndrs1 ->
+                    rnStmts ctxt stmts  $ \ bndrs2 ->
+                    thing_inside (bndrs1 ++ bndrs2) }
        ; return (((stmts1 ++ stmts2), thing), fvs) }
 
-
-rnStmt :: HsStmtContext Name -> LStmt RdrName
+----------------------
+rnStmt :: HsStmtContext Name 
+       -> LStmt RdrName
        -> ([Name] -> RnM (thing, FreeVars))
        -> RnM (([LStmt Name], thing), FreeVars)
 -- Variables bound by the Stmt, and mentioned in thing_inside,
 -- do not appear in the result FreeVars
 
+rnStmt ctxt (L loc (LastStmt expr _)) thing_inside
+  = do { (expr', fv_expr) <- rnLExpr expr
+       ; (ret_op, fvs1)   <- lookupSyntaxName returnMName
+       ; (thing, fvs3)    <- thing_inside []
+       ; return (([L loc (LastStmt expr' ret_op)], thing),
+                 fv_expr `plusFV` fvs1 `plusFV` fvs3) }
+
 rnStmt ctxt (L loc (ExprStmt expr _ _ _)) thing_inside
   = do { (expr', fv_expr) <- rnLExpr expr
        ; (then_op, fvs1)  <- lookupSyntaxName thenMName
@@ -683,7 +704,8 @@ rnStmt ctxt (L loc (ExprStmt expr _ _ _)) thing_inside
                  fv_expr `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) }
 
 rnStmt ctxt (L loc (BindStmt pat expr _ _)) thing_inside
-  = do { (expr', fv_expr) <- rnLExpr expr
+  = do { checkBindStmt ctxt is_last
+        ; (expr', fv_expr) <- rnLExpr expr
                -- The binders do not scope over the expression
        ; (bind_op, fvs1) <- lookupSyntaxName bindMName
        ; (fail_op, fvs2) <- lookupSyntaxName failMName
@@ -701,8 +723,7 @@ rnStmt ctxt (L loc (LetStmt binds)) thing_inside
         ; return (([L loc (LetStmt binds')], thing), fvs) }  }
 
 rnStmt ctxt (L _ (RecStmt { recS_stmts = rec_stmts })) thing_inside
-  = do { checkRecStmt ctxt
-
+  = do { 
        -- Step1: Bring all the binders of the mdo into scope
        -- (Remember that this also removes the binders from the
        -- finally-returned free-vars.)
@@ -745,8 +766,7 @@ rnStmt ctxt (L _ (RecStmt { recS_stmts = rec_stmts })) thing_inside
        ; return ((rec_stmts', thing), fvs `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) } }
 
 rnStmt ctxt (L loc (ParStmt segs _ _ _)) thing_inside
-  = do { checkParStmt ctxt
-        ; ((mzip_op, fvs1), (bind_op, fvs2), (return_op, fvs3)) <- if isMonadCompExpr ctxt
+  = do { ((mzip_op, fvs1), (bind_op, fvs2), (return_op, fvs3)) <- if isMonadCompExpr ctxt
               then (,,) <$> lookupSyntaxName mzipName
                         <*> lookupSyntaxName bindMName
                         <*> lookupSyntaxName returnMName
@@ -758,9 +778,7 @@ rnStmt ctxt (L loc (ParStmt segs _ _ _)) thing_inside
                  , fvs1 `plusFV` fvs2 `plusFV` fvs3 `plusFV` fvs4) }
 
 rnStmt ctxt (L loc (TransformStmt stmts _ using by _ _)) thing_inside
-  = do { checkTransformStmt ctxt
-    
-       ; (using', fvs1) <- rnLExpr using
+  = do { (using', fvs1) <- rnLExpr using
 
        ; ((stmts', (by', used_bndrs, thing)), fvs2)
              <- rnStmts (TransformStmtCtxt ctxt) stmts $ \ bndrs ->
@@ -786,9 +804,7 @@ rnStmt ctxt (L loc (TransformStmt stmts _ using by _ _)) thing_inside
                  fvs1 `plusFV` fvs2 `plusFV` fvs3 `plusFV` fvs4) }
         
 rnStmt ctxt (L loc (GroupStmt stmts _ by using _ _ _)) thing_inside
-  = do { checkTransformStmt ctxt
-    
-         -- Rename the 'using' expression in the context before the transform is begun
+  = do { -- Rename the 'using' expression in the context before the transform is begun
        ; (using', fvs1) <- case using of
                              Left e  -> do { (e', fvs) <- rnLExpr e; return (Left e', fvs) }
                             Right _
@@ -810,11 +826,11 @@ rnStmt ctxt (L loc (GroupStmt stmts _ by using _ _ _)) thing_inside
                    ; return ((by', used_bndrs, thing), fvs) }
 
        -- Lookup `return`, `(>>=)` and `liftM` for monad comprehensions
-       ; ((return_op, fvs3), (bind_op, fvs4), (liftM_op, fvs5)) <-
+       ; ((return_op, fvs3), (bind_op, fvs4), (fmap_op, fvs5)) <-
              if isMonadCompExpr ctxt
                 then (,,) <$> lookupSyntaxName returnMName
                           <*> lookupSyntaxName bindMName
-                          <*> lookupSyntaxName liftMName
+                          <*> lookupSyntaxName fmapName
                 else return ( (noSyntaxExpr, emptyFVs)
                             , (noSyntaxExpr, emptyFVs)
                             , (noSyntaxExpr, emptyFVs) )
@@ -825,7 +841,7 @@ rnStmt ctxt (L loc (GroupStmt stmts _ by using _ _ _)) thing_inside
             -- See Note [GroupStmt binder map] in HsExpr
 
        ; traceRn (text "rnStmt: implicitly rebound these used binders:" <+> ppr bndr_map)
-       ; return (([L loc (GroupStmt stmts' bndr_map by' using' return_op bind_op liftM_op)], thing), all_fvs) }
+       ; return (([L loc (GroupStmt stmts' bndr_map by' using' return_op bind_op fmap_op)], thing), all_fvs) }
 
 type ParSeg id = ([LStmt id], [id])       -- The Names are bound by the Stmts
 
@@ -1182,22 +1198,124 @@ program.
 %************************************************************************
 
 \begin{code}
-
 ---------------------- 
 -- Checking when a particular Stmt is ok
-checkLetStmt :: HsStmtContext Name -> HsLocalBinds RdrName -> RnM ()
-checkLetStmt (ParStmtCtxt _) (HsIPBinds binds) = addErr (badIpBinds (ptext (sLit "a parallel list comprehension:")) binds)
-checkLetStmt _ctxt          _binds            = return ()
+checkStmt :: HsStmtContext Name
+          -> Bool                      -- True <=> this is the last Stmt in the sequence
+          -> LStmt RdrName 
+          -> RnM ()
+checkStmt ctxt is_last (L _ stmt)
+  = do { dflags <- getDOpts
+       ; case okStmt dflags ctxt is_last stmt of 
+           Nothing   -> return ()
+           Just extr -> addErr (msg $$ extra) }
+  where
+   msg = ptext (sLit "Unexpected") <+> pprStmtCat stmt 
+         <+> ptext (sLit "statement in") <+> pprStmtContext ctxt
+
+pprStmtCat :: Stmt a -> SDoc
+pprStmtCat (TransformStmt {}) = ptext (sLit "transform")
+pprStmtCat (GroupStmt {})     = ptext (sLit "group")
+pprStmtCat (LastStmt {})      = ptext (sLit "return expression")
+pprStmtCat (ExprStmt {})      = ptext (sLit "exprssion")
+pprStmtCat (BindStmt {})      = ptext (sLit "binding")
+pprStmtCat (LetStmt {})       = ptext (sLit "let")
+pprStmtCat (RecStmt {})       = ptext (sLit "rec")
+pprStmtCat (ParStmt {})       = ptext (sLit "parallel")
+
+------------
+isOK, notOK :: Maybe SDoc
+isOK  = Nothing
+notOK = Just empty
+
+okStmt, okDoStmt, okCompStmt :: DynFlags -> HsStmtContext Name -> Bool 
+                             -> Stmt RdrName -> Maybe SDoc
+-- Return Nothing if OK, (Just extra) if not ok
+-- The "extra" is an SDoc that is appended to an generic error message
+okStmt dflags GhciStmt is_last stmt 
+  = case stmt of
+      ExprStmt {} -> isOK
+      BindStmt {} -> isOK
+      LetStmt {}  -> isOK
+      _           -> notOK
+
+okStmt dflags (PatGuard {}) is_last stmt
+  = case stmt of
+      ExprStmt {} -> isOK
+      BindStmt {} -> isOK
+      LetStmt {}  -> isOK
+      _           -> notOK
+
+okStmt dflags (ParStmtCtxt ctxt) is_last stmt
+  = case stmt of
+      LetStmt (HsIPBinds {}) -> notOK
+      _                      -> okStmt dflags ctxt is_last stmt
+
+okStmt dflags (TransformStmtCtxt ctxt) is_last stmt 
+  = okStmt dflags ctxt is_last stmt
+
+okStmt ctxt is_last stmt 
+  | isDoExpr   ctxt = okDoStmt   ctxt is_last stmt
+  | isCompExpr ctxt = okCompStmt ctxt is_last stmt
+  | otherwise       = pprPanic "okStmt" (pprStmtContext ctxt)
+
+----------------
+okDoStmt dflags ctxt is_last stmt
+  | is_last
+  = case stmt of 
+      LastStmt {} -> isOK
+      _ -> Just (ptext (sLit "The last statement in") <+> what <+> 
+                 ptext (sLIt "construct must be an expression"))
+        where
+          what = case ctxt of 
+                   DoExpr  -> ptext (sLit "a 'do'")
+                   MDoExpr -> ptext (sLit "an 'mdo'")
+                  _       -> panic "checkStmt"
+
+  | otherwise
+  = case stmt of
+       RecStmt {}  -> isOK     -- Shouldn't we test a flag?
+       BindStmt {} -> isOK
+       LetStmt {}  -> isOK
+       ExprStmt {} -> isOK
+       _           -> notOK
+
+
+----------------
+okCompStmt dflags ctxt is_last stmt
+  | is_last
+  = case stmt of
+      LastStmt {} -> Nothing
+      -> pprPanic "Unexpected stmt" (ppr stmt) -- Not a user error
+
+  | otherwise
+  = case stmt of
+       BindStmt {} -> isOK
+       LetStmt {}  -> isOK
+       ExprStmt {} -> isOK
+       RecStmt {}  -> notOK
+       ParStmt {} 
+         | dopt dflags Opt_ParallelListComp -> isOK
+         | otherwise -> Just (ptext (sLit "Use -XParallelListComp"))
+       TransformStmt {} 
+         | dopt dflags Opt_transformListComp -> isOK
+         | otherwise -> Just (ptext (sLit "Use -XTransformListComp"))
+       GroupStmt {} 
+         | dopt dflags Opt_transformListComp -> isOK
+         | otherwise -> Just (ptext (sLit "Use -XTransformListComp"))
+      
+
+checkStmt :: HsStmtContext Name -> Stmt RdrName -> Maybe SDoc
+-- Non-last stmt
+
+checkStmt (ParStmtCtxt _) (HsIPBinds binds) 
+  = Just (badIpBinds (ptext (sLit "a parallel list comprehension:")) binds)
        -- We do not allow implicit-parameter bindings in a parallel
        -- list comprehension.  I'm not sure what it might mean.
 
----------
-checkRecStmt :: HsStmtContext Name -> RnM ()
-checkRecStmt MDoExpr = return ()      -- Recursive stmt ok in 'mdo'
-checkRecStmt DoExpr  = return ()      -- and in 'do'
-checkRecStmt ctxt    = addErr msg
-  where
-    msg = ptext (sLit "Illegal 'rec' stmt in") <+> pprStmtContext ctxt
+checkStmt ctxt (RecStmt {})
+  | not (isDoExpr ctxt) 
+  = addErr (ptext (sLit "Illegal 'rec' stmt in") <+> pprStmtContext ctxt)
 
 ---------
 checkParStmt :: HsStmtContext Name -> RnM ()