Fix Trac #2111: improve error handling for 'rec' in do-notation
authorsimonpj@microsoft.com <unknown>
Tue, 26 Feb 2008 17:56:35 +0000 (17:56 +0000)
committersimonpj@microsoft.com <unknown>
Tue, 26 Feb 2008 17:56:35 +0000 (17:56 +0000)
We were not dealing correctly with all the combinations of
do notation
mdo notation
arrow notation
in combination with 'rec' Stmts.

I think this patch sorts it out.

compiler/rename/RnExpr.lhs
compiler/typecheck/TcMatches.lhs

index a73d1a8..8c96a5f 100644 (file)
@@ -412,7 +412,6 @@ convertOpFormsCmd (OpApp c1 op fixity c2)
 
 convertOpFormsCmd (HsPar c) = HsPar (convertOpFormsLCmd c)
 
 
 convertOpFormsCmd (HsPar c) = HsPar (convertOpFormsLCmd c)
 
--- gaw 2004
 convertOpFormsCmd (HsCase exp matches)
   = HsCase exp (convertOpFormsMatch matches)
 
 convertOpFormsCmd (HsCase exp matches)
   = HsCase exp (convertOpFormsMatch matches)
 
@@ -659,35 +658,32 @@ rnStmt ctxt (BindStmt pat expr _ _) thing_inside
        -- fv_expr shouldn't really be filtered by the rnPatsAndThen
        -- but it does not matter because the names are unique
 
        -- fv_expr shouldn't really be filtered by the rnPatsAndThen
        -- but it does not matter because the names are unique
 
-rnStmt ctxt (LetStmt binds) thing_inside = do
-    checkErr (ok ctxt binds) (badIpBinds (ptext SLIT("a parallel list comprehension:")) binds)
-    rnLocalBindsAndThen binds $ \binds' -> do
-        (thing, fvs) <- thing_inside
-        return ((LetStmt binds', thing), fvs)
-  where
-       -- We do not allow implicit-parameter bindings in a parallel
-       -- list comprehension.  I'm not sure what it might mean.
-    ok (ParStmtCtxt _) (HsIPBinds _) = False
-    ok _              _             = True
+rnStmt ctxt (LetStmt binds) thing_inside 
+  = do { checkLetStmt ctxt binds
+       ; rnLocalBindsAndThen binds $ \binds' -> do
+       { (thing, fvs) <- thing_inside
+        ; return ((LetStmt binds', thing), fvs) }  }
 
 rnStmt ctxt (RecStmt rec_stmts _ _ _ _) thing_inside
 
 rnStmt ctxt (RecStmt rec_stmts _ _ _ _) thing_inside
-  = 
-    rn_rec_stmts_and_then rec_stmts    $ \ segs ->
-    thing_inside                       `thenM` \ (thing, fvs) ->
-    let
-       segs_w_fwd_refs          = addFwdRefs segs
-       (ds, us, fs, rec_stmts') = unzip4 segs_w_fwd_refs
-       later_vars = nameSetToList (plusFVs ds `intersectNameSet` fvs)
-       fwd_vars   = nameSetToList (plusFVs fs)
-       uses       = plusFVs us
-       rec_stmt   = RecStmt rec_stmts' later_vars fwd_vars [] emptyLHsBinds
-    in 
-    returnM ((rec_stmt, thing), uses `plusFV` fvs)
-  where
-    doc = text "In a recursive do statement"
+  = do { checkRecStmt ctxt
+       ; rn_rec_stmts_and_then rec_stmts       $ \ segs -> do
+       { (thing, fvs) <- thing_inside
+       ; let
+           segs_w_fwd_refs          = addFwdRefs segs
+           (ds, us, fs, rec_stmts') = unzip4 segs_w_fwd_refs
+           later_vars = nameSetToList (plusFVs ds `intersectNameSet` fvs)
+           fwd_vars   = nameSetToList (plusFVs fs)
+           uses       = plusFVs us
+           rec_stmt   = RecStmt rec_stmts' later_vars fwd_vars [] emptyLHsBinds
+       ; return ((rec_stmt, thing), uses `plusFV` fvs) } }
+
+rnStmt ctxt (ParStmt segs) thing_inside
+  = do { checkParStmt ctxt
+       ; ((segs', thing), fvs) <- rnParallelStmts (ParStmtCtxt ctxt) segs thing_inside
+       ; return ((ParStmt segs', thing), fvs) }
 
 rnStmt ctxt (TransformStmt (stmts, _) usingExpr maybeByExpr) thing_inside = do
 
 rnStmt ctxt (TransformStmt (stmts, _) usingExpr maybeByExpr) thing_inside = do
-    checkIsTransformableListComp ctxt
+    checkTransformStmt ctxt
     
     (usingExpr', fv_usingExpr) <- rnLExpr usingExpr
     ((stmts', binders, (maybeByExpr', thing)), fvs) <- 
     
     (usingExpr', fv_usingExpr) <- rnLExpr usingExpr
     ((stmts', binders, (maybeByExpr', thing)), fvs) <- 
@@ -705,7 +701,7 @@ rnStmt ctxt (TransformStmt (stmts, _) usingExpr maybeByExpr) thing_inside = do
         return (Just expr', fv_expr)
         
 rnStmt ctxt (GroupStmt (stmts, _) groupByClause) thing_inside = do
         return (Just expr', fv_expr)
         
 rnStmt ctxt (GroupStmt (stmts, _) groupByClause) thing_inside = do
-    checkIsTransformableListComp ctxt
+    checkTransformStmt ctxt
     
     -- We must rename the using expression in the context before the transform is begun
     groupByClauseAction <- 
     
     -- We must rename the using expression in the context before the transform is begun
     groupByClauseAction <- 
@@ -763,13 +759,6 @@ rnStmt ctxt (GroupStmt (stmts, _) groupByClause) thing_inside = do
     traceRn (text "rnStmt: implicitly rebound these used binders:" <+> ppr usedBinderMap)
     return ((GroupStmt (stmts', usedBinderMap) groupByClause', thing), fvs)
   
     traceRn (text "rnStmt: implicitly rebound these used binders:" <+> ppr usedBinderMap)
     return ((GroupStmt (stmts', usedBinderMap) groupByClause', thing), fvs)
   
-rnStmt ctxt (ParStmt segs) thing_inside
-  = do { parallel_list_comp <- doptM Opt_ParallelListComp
-       ; checkM parallel_list_comp parStmtErr
-       ; ((segs', thing), fvs) <- rnParallelStmts (ParStmtCtxt ctxt) segs thing_inside
-       ; return ((ParStmt segs', thing), fvs) }
-
-
 rnNormalStmtsAndFindUsedBinders :: HsStmtContext Name 
           -> [LStmt RdrName]
           -> ([Name] -> RnM (thing, FreeVars))
 rnNormalStmtsAndFindUsedBinders :: HsStmtContext Name 
           -> [LStmt RdrName]
           -> ([Name] -> RnM (thing, FreeVars))
@@ -828,21 +817,6 @@ rnParallelStmts ctxt segs thing_inside = do
         cmpByOcc n1 n2 = nameOccName n1 `compare` nameOccName n2
         dupErr vs = addErr (ptext SLIT("Duplicate binding in parallel list comprehension for:")
                     <+> quotes (ppr (head vs)))
         cmpByOcc n1 n2 = nameOccName n1 `compare` nameOccName n2
         dupErr vs = addErr (ptext SLIT("Duplicate binding in parallel list comprehension for:")
                     <+> quotes (ppr (head vs)))
-
-
-checkIsTransformableListComp :: HsStmtContext Name -> RnM ()
-checkIsTransformableListComp ctxt = do
-    -- Ensure we are really within a list comprehension because otherwise the
-    -- desugarer will break when we come to operate on a parallel array
-    checkM (notParallelArray ctxt) transformStmtOutsideListCompErr
-    
-    -- Ensure the user has turned the correct flag on
-    transform_list_comp <- doptM Opt_TransformListComp
-    checkM transform_list_comp transformStmtErr
-  where
-    notParallelArray PArrComp = False
-    notParallelArray _        = True
-    
 \end{code}
 
 
 \end{code}
 
 
@@ -1177,19 +1151,54 @@ mkAssertErrorExpr
 %************************************************************************
 
 \begin{code}
 %************************************************************************
 
 \begin{code}
-patSynErr e = do { addErr (sep [ptext SLIT("Pattern syntax in expression context:"),
-                               nest 4 (ppr e)])
-                ; return (EWildPat, emptyFVs) }
 
 
+---------------------- 
+-- Checking when a particular Stmt is ok
+checkLetStmt :: HsStmtContext Name -> HsLocalBinds RdrName -> RnM ()
+checkLetStmt (ParStmtCtxt _) (HsIPBinds binds) = addErr (badIpBinds (ptext SLIT("a parallel list comprehension:")) binds)
+checkLetStmt _ctxt          _binds            = return ()
+       -- We do not allow implicit-parameter bindings in a parallel
+       -- list comprehension.  I'm not sure what it might mean.
 
 
-parStmtErr = addErr (ptext SLIT("Illegal parallel list comprehension: use -XParallelListComp"))
+---------
+checkRecStmt :: HsStmtContext Name -> RnM ()
+checkRecStmt (MDoExpr {}) = return ()  -- Recursive stmt ok in 'mdo'
+checkRecStmt (DoExpr {})  = return ()  -- ..and in 'do' but only because of arrows:
+                                       --   proc x -> do { ...rec... }
+                                       -- We don't have enough context to distinguish this situation here
+                                       --      so we leave it to the type checker
+checkRecStmt ctxt        = addErr msg
+  where
+    msg = ptext SLIT("Illegal 'rec' stmt in") <+> pprStmtContext ctxt
 
 
-transformStmtErr = addErr (ptext SLIT("Illegal transform or grouping list comprehension: use -XTransformListComp"))
-transformStmtOutsideListCompErr = addErr (ptext SLIT("Currently you may only use transform or grouping comprehensions within list comprehensions, not parallel array comprehensions"))
+---------
+checkParStmt :: HsStmtContext Name -> RnM ()
+checkParStmt ctxt 
+  = do { parallel_list_comp <- doptM Opt_ParallelListComp
+       ; checkErr parallel_list_comp msg }
+  where
+    msg = ptext SLIT("Illegal parallel list comprehension: use -XParallelListComp")
+
+---------
+checkTransformStmt :: HsStmtContext Name -> RnM ()
+checkTransformStmt ListComp  -- Ensure we are really within a list comprehension because otherwise the
+                            -- desugarer will break when we come to operate on a parallel array
+  = do { transform_list_comp <- doptM Opt_TransformListComp
+       ; checkErr transform_list_comp msg }
+  where
+    msg = ptext SLIT("Illegal transform or grouping list comprehension: use -XTransformListComp")
+checkTransformStmt (ParStmtCtxt       ctxt) = checkTransformStmt ctxt  -- Ok to nest inside a parallel comprehension
+checkTransformStmt (TransformStmtCtxt ctxt) = checkTransformStmt ctxt  -- Ok to nest inside a parallel comprehension
+checkTransformStmt ctxt = addErr msg
+  where
+    msg = ptext SLIT("Illegal transform or grouping in") <+> pprStmtContext ctxt
+    
+---------
+patSynErr e = do { addErr (sep [ptext SLIT("Pattern syntax in expression context:"),
+                               nest 4 (ppr e)])
+                ; return (EWildPat, emptyFVs) }
 
 badIpBinds what binds
   = hang (ptext SLIT("Implicit-parameter bindings illegal in") <+> what)
         2 (ppr binds)
 \end{code}
 
 badIpBinds what binds
   = hang (ptext SLIT("Implicit-parameter bindings illegal in") <+> what)
         2 (ppr binds)
 \end{code}
-
-
index f02b74a..452bae7 100644 (file)
@@ -509,6 +509,11 @@ tcDoStmt ctxt (ExprStmt rhs then_op _) (reft,res_ty) thing_inside
        ; thing <- thing_inside (reft, new_res_ty)
        ; return (ExprStmt rhs' then_op' rhs_ty, thing) }
 
        ; thing <- thing_inside (reft, new_res_ty)
        ; return (ExprStmt rhs' then_op' rhs_ty, thing) }
 
+tcDoStmt ctxt (RecStmt {}) res_ty thing_inside
+  = failWithTc (ptext SLIT("Illegal 'rec' stmt in") <+> pprStmtContext ctxt)
+       -- This case can't be caught in the renamer
+       -- see RnExpr.checkRecStmt
+
 tcDoStmt ctxt stmt res_ty thing_inside
   = pprPanic "tcDoStmt: unexpected Stmt" (ppr stmt)
 
 tcDoStmt ctxt stmt res_ty thing_inside
   = pprPanic "tcDoStmt: unexpected Stmt" (ppr stmt)