Preliminary monad-comprehension patch (Trac #4370)
[ghc-hetmet.git] / compiler / rename / RnExpr.lhs
index d11249a..425cb40 100644 (file)
@@ -224,10 +224,16 @@ rnExpr (HsLet binds expr)
     rnLExpr expr                        `thenM` \ (expr',fvExpr) ->
     return (HsLet binds' expr', fvExpr)
 
-rnExpr (HsDo do_or_lc stmts body _)
-  = do  { ((stmts', body'), fvs) <- rnStmts do_or_lc stmts $ \ _ ->
-                                   rnLExpr body
-       ; return (HsDo do_or_lc stmts' body' placeHolderType, fvs) }
+rnExpr (HsDo do_or_lc stmts body _ _)
+  = do         { ((stmts', body'), fvs1) <- rnStmts do_or_lc stmts $ \ _ ->
+                                    rnLExpr body
+        ; (return_op, fvs2) <-
+              if isMonadCompExpr do_or_lc
+                 then lookupSyntaxName returnMName
+                 else return (noSyntaxExpr, emptyFVs)
+
+       ; return ( HsDo do_or_lc stmts' body' return_op placeHolderType
+                 , fvs1 `plusFV` fvs2 ) }
 
 rnExpr (ExplicitList _ exps)
   = rnExprs exps                       `thenM` \ (exps', fvs) ->
@@ -441,9 +447,10 @@ convertOpFormsCmd (HsIf f exp c1 c2)
 convertOpFormsCmd (HsLet binds cmd)
   = HsLet binds (convertOpFormsLCmd cmd)
 
-convertOpFormsCmd (HsDo ctxt stmts body ty)
+convertOpFormsCmd (HsDo ctxt stmts body return_op ty)
   = HsDo ctxt (map (fmap convertOpFormsStmt) stmts)
-             (convertOpFormsLCmd body) ty
+             (convertOpFormsLCmd body)
+              (convertOpFormsCmd  return_op) ty
 
 -- Anything else is unchanged.  This includes HsArrForm (already done),
 -- things with no sub-commands, and illegal commands (which will be
@@ -453,8 +460,8 @@ convertOpFormsCmd c = c
 convertOpFormsStmt :: StmtLR id id -> StmtLR id id
 convertOpFormsStmt (BindStmt pat cmd _ _)
   = BindStmt pat (convertOpFormsLCmd cmd) noSyntaxExpr noSyntaxExpr
-convertOpFormsStmt (ExprStmt cmd _ _)
-  = ExprStmt (convertOpFormsLCmd cmd) noSyntaxExpr placeHolderType
+convertOpFormsStmt (ExprStmt cmd _ _ _)
+  = ExprStmt (convertOpFormsLCmd cmd) noSyntaxExpr noSyntaxExpr placeHolderType
 convertOpFormsStmt stmt@(RecStmt { recS_stmts = stmts })
   = stmt { recS_stmts = map (fmap convertOpFormsStmt) stmts }
 convertOpFormsStmt stmt = stmt
@@ -497,7 +504,7 @@ methodNamesCmd (HsIf _ _ c1 c2)
 
 methodNamesCmd (HsLet _ c) = methodNamesLCmd c
 
-methodNamesCmd (HsDo _ stmts body _) 
+methodNamesCmd (HsDo _ stmts body _ _) 
   = methodNamesStmts stmts `plusFV` methodNamesLCmd body
 
 methodNamesCmd (HsApp c _) = methodNamesLCmd c
@@ -538,11 +545,11 @@ methodNamesLStmt :: Located (StmtLR Name Name) -> FreeVars
 methodNamesLStmt = methodNamesStmt . unLoc
 
 methodNamesStmt :: StmtLR Name Name -> FreeVars
-methodNamesStmt (ExprStmt cmd _ _)               = methodNamesLCmd cmd
+methodNamesStmt (ExprStmt cmd _ _ _)             = methodNamesLCmd cmd
 methodNamesStmt (BindStmt _ cmd _ _)             = methodNamesLCmd cmd
 methodNamesStmt (RecStmt { recS_stmts = stmts }) = methodNamesStmts stmts `addOneFV` loopAName
 methodNamesStmt (LetStmt _)                      = emptyFVs
-methodNamesStmt (ParStmt _)                      = emptyFVs
+methodNamesStmt (ParStmt _ _ _ _)                = emptyFVs
 methodNamesStmt (TransformStmt {})               = emptyFVs
 methodNamesStmt (GroupStmt {})                   = emptyFVs
    -- ParStmt, TransformStmt and GroupStmt can't occur in commands, but it's not convenient to error 
@@ -665,12 +672,15 @@ rnStmt :: HsStmtContext Name -> LStmt RdrName
 -- Variables bound by the Stmt, and mentioned in thing_inside,
 -- do not appear in the result FreeVars
 
-rnStmt _ (L loc (ExprStmt expr _ _)) thing_inside
+rnStmt ctxt (L loc (ExprStmt expr _ _ _)) thing_inside
   = do { (expr', fv_expr) <- rnLExpr expr
        ; (then_op, fvs1)  <- lookupSyntaxName thenMName
-       ; (thing, fvs2)    <- thing_inside []
-       ; return (([L loc (ExprStmt expr' then_op placeHolderType)], thing),
-                 fv_expr `plusFV` fvs1 `plusFV` fvs2) }
+       ; (guard_op, fvs2) <- if isMonadCompExpr ctxt
+                                 then lookupSyntaxName guardMName
+                                 else return (noSyntaxExpr, emptyFVs)
+       ; (thing, fvs3)    <- thing_inside []
+       ; return (([L loc (ExprStmt expr' then_op guard_op placeHolderType)], thing),
+                 fv_expr `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) }
 
 rnStmt ctxt (L loc (BindStmt pat expr _ _)) thing_inside
   = do { (expr', fv_expr) <- rnLExpr expr
@@ -734,12 +744,20 @@ rnStmt ctxt (L _ (RecStmt { recS_stmts = rec_stmts })) thing_inside
 
        ; return ((rec_stmts', thing), fvs `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) } }
 
-rnStmt ctxt (L loc (ParStmt segs)) thing_inside
+rnStmt ctxt (L loc (ParStmt segs _ _ _)) thing_inside
   = do { checkParStmt ctxt
-       ; ((segs', thing), fvs) <- rnParallelStmts (ParStmtCtxt ctxt) segs thing_inside
-       ; return (([L loc (ParStmt segs')], thing), fvs) }
-
-rnStmt ctxt (L loc (TransformStmt stmts _ using by)) thing_inside
+        ; ((mzip_op, fvs1), (bind_op, fvs2), (return_op, fvs3)) <- if isMonadCompExpr ctxt
+              then (,,) <$> lookupSyntaxName mzipName
+                        <*> lookupSyntaxName bindMName
+                        <*> lookupSyntaxName returnMName
+              else return ( (noSyntaxExpr, emptyFVs)
+                          , (noSyntaxExpr, emptyFVs)
+                          , (noSyntaxExpr, emptyFVs) )
+       ; ((segs', thing), fvs4) <- rnParallelStmts (ParStmtCtxt ctxt) segs thing_inside
+       ; return ( ([L loc (ParStmt segs' mzip_op bind_op return_op)], thing)
+                 , fvs1 `plusFV` fvs2 `plusFV` fvs3 `plusFV` fvs4) }
+
+rnStmt ctxt (L loc (TransformStmt stmts _ using by _ _)) thing_inside
   = do { checkTransformStmt ctxt
     
        ; (using', fvs1) <- rnLExpr using
@@ -756,17 +774,30 @@ rnStmt ctxt (L loc (TransformStmt stmts _ using by)) thing_inside
                          -- the "thing inside", **or of the by-expression**, as used
                    ; return ((by', used_bndrs, thing), fvs) }
 
-       ; return (([L loc (TransformStmt stmts' used_bndrs using' by')], thing), 
-                 fvs1 `plusFV` fvs2) }
+       -- Lookup `(>>=)` and `fail` for monad comprehensions
+       ; ((return_op, fvs3), (bind_op, fvs4)) <-
+             if isMonadCompExpr ctxt
+                then (,) <$> lookupSyntaxName returnMName
+                         <*> lookupSyntaxName bindMName
+                else return ( (noSyntaxExpr, emptyFVs)
+                            , (noSyntaxExpr, emptyFVs) )
+
+       ; return (([L loc (TransformStmt stmts' used_bndrs using' by' return_op bind_op)], thing), 
+                 fvs1 `plusFV` fvs2 `plusFV` fvs3 `plusFV` fvs4) }
         
-rnStmt ctxt (L loc (GroupStmt stmts _ by using)) thing_inside
+rnStmt ctxt (L loc (GroupStmt stmts _ by using _ _ _)) thing_inside
   = do { checkTransformStmt ctxt
     
          -- Rename the 'using' expression in the context before the transform is begun
        ; (using', fvs1) <- case using of
                              Left e  -> do { (e', fvs) <- rnLExpr e; return (Left e', fvs) }
-                            Right _ -> do { (e', fvs) <- lookupSyntaxName groupWithName
-                                           ; return (Right e', fvs) }
+                            Right _
+                                | isMonadCompExpr ctxt ->
+                                  do { (e', fvs) <- lookupSyntaxName groupMName
+                                     ; return (Right e', fvs) }
+                                | otherwise ->
+                                  do { (e', fvs) <- lookupSyntaxName groupWithName
+                                     ; return (Right e', fvs) }
 
          -- Rename the stmts and the 'by' expression
         -- Keep track of the variables mentioned in the 'by' expression
@@ -778,13 +809,23 @@ rnStmt ctxt (L loc (GroupStmt stmts _ by using)) thing_inside
                          used_bndrs = filter (`elemNameSet` fvs) bndrs
                    ; return ((by', used_bndrs, thing), fvs) }
 
-       ; let all_fvs  = fvs1 `plusFV` fvs2 
+       -- Lookup `return`, `(>>=)` and `liftM` for monad comprehensions
+       ; ((return_op, fvs3), (bind_op, fvs4), (liftM_op, fvs5)) <-
+             if isMonadCompExpr ctxt
+                then (,,) <$> lookupSyntaxName returnMName
+                          <*> lookupSyntaxName bindMName
+                          <*> lookupSyntaxName liftMName
+                else return ( (noSyntaxExpr, emptyFVs)
+                            , (noSyntaxExpr, emptyFVs)
+                            , (noSyntaxExpr, emptyFVs) )
+
+       ; let all_fvs  = fvs1 `plusFV` fvs2 `plusFV` fvs3 `plusFV` fvs4
+                             `plusFV` fvs5
              bndr_map = used_bndrs `zip` used_bndrs
             -- See Note [GroupStmt binder map] in HsExpr
 
        ; traceRn (text "rnStmt: implicitly rebound these used binders:" <+> ppr bndr_map)
-       ; return (([L loc (GroupStmt stmts' bndr_map by' using')], thing), all_fvs) }
-
+       ; return (([L loc (GroupStmt stmts' bndr_map by' using' return_op bind_op liftM_op)], thing), all_fvs) }
 
 type ParSeg id = ([LStmt id], [id])       -- The Names are bound by the Stmts
 
@@ -901,9 +942,9 @@ rn_rec_stmt_lhs :: MiniFixityEnv
                    -- so we don't bother to compute it accurately in the other cases
                 -> RnM [(LStmtLR Name RdrName, FreeVars)]
 
-rn_rec_stmt_lhs _ (L loc (ExprStmt expr a b)) = return [(L loc (ExprStmt expr a b), 
-                                                       -- this is actually correct
-                                                       emptyFVs)]
+rn_rec_stmt_lhs _ (L loc (ExprStmt expr a b c)) = return [(L loc (ExprStmt expr a b c), 
+                                                         -- this is actually correct
+                                                         emptyFVs)]
 
 rn_rec_stmt_lhs fix_env (L loc (BindStmt pat expr a b)) 
   = do 
@@ -926,7 +967,7 @@ rn_rec_stmt_lhs fix_env (L loc (LetStmt (HsValBinds binds)))
 rn_rec_stmt_lhs fix_env (L _ (RecStmt { recS_stmts = stmts })) -- Flatten Rec inside Rec
     = rn_rec_stmts_lhs fix_env stmts
 
-rn_rec_stmt_lhs _ stmt@(L _ (ParStmt _))       -- Syntactically illegal in mdo
+rn_rec_stmt_lhs _ stmt@(L _ (ParStmt _ _ _ _)) -- Syntactically illegal in mdo
   = pprPanic "rn_rec_stmt" (ppr stmt)
   
 rn_rec_stmt_lhs _ stmt@(L _ (TransformStmt {}))        -- Syntactically illegal in mdo
@@ -957,11 +998,11 @@ rn_rec_stmt :: [Name] -> LStmtLR Name RdrName -> FreeVars -> RnM [Segment (LStmt
        -- Rename a Stmt that is inside a RecStmt (or mdo)
        -- Assumes all binders are already in scope
        -- Turns each stmt into a singleton Stmt
-rn_rec_stmt _ (L loc (ExprStmt expr _ _)) _
+rn_rec_stmt _ (L loc (ExprStmt expr _ _ _)) _
   = rnLExpr expr `thenM` \ (expr', fvs) ->
     lookupSyntaxName thenMName `thenM` \ (then_op, fvs1) ->
     return [(emptyNameSet, fvs `plusFV` fvs1, emptyNameSet,
-             L loc (ExprStmt expr' then_op placeHolderType))]
+             L loc (ExprStmt expr' then_op noSyntaxExpr placeHolderType))]
 
 rn_rec_stmt _ (L loc (BindStmt pat' expr _ _)) fv_pat
   = rnLExpr expr               `thenM` \ (expr', fv_expr) ->
@@ -1161,10 +1202,13 @@ checkRecStmt ctxt    = addErr msg
 ---------
 checkParStmt :: HsStmtContext Name -> RnM ()
 checkParStmt _
-  = do { parallel_list_comp <- xoptM Opt_ParallelListComp
-       ; checkErr parallel_list_comp msg }
+  = do { monad_comp <- xoptM Opt_MonadComprehensions
+        ; unless monad_comp $ do
+          { parallel_list_comp <- xoptM Opt_ParallelListComp
+         ; checkErr parallel_list_comp msg }
+        }
   where
-    msg = ptext (sLit "Illegal parallel list comprehension: use -XParallelListComp")
+    msg = ptext (sLit "Illegal parallel list comprehension: use -XParallelListComp or -XMonadComprehensions")
 
 ---------
 checkTransformStmt :: HsStmtContext Name -> RnM ()
@@ -1173,7 +1217,10 @@ checkTransformStmt ListComp  -- Ensure we are really within a list comprehension
   = do { transform_list_comp <- xoptM Opt_TransformListComp
        ; checkErr transform_list_comp msg }
   where
-    msg = ptext (sLit "Illegal transform or grouping list comprehension: use -XTransformListComp")
+    msg = ptext (sLit "Illegal transform or grouping list comprehension: use -XTransformListComp or -XMonadComprehensions")
+checkTransformStmt MonadComp  -- Monad comprehensions are always fine, since the
+                              -- MonadComprehensions flag will already be turned on
+  = do  { return () }
 checkTransformStmt (ParStmtCtxt       ctxt) = checkTransformStmt ctxt  -- Ok to nest inside a parallel comprehension
 checkTransformStmt (TransformStmtCtxt ctxt) = checkTransformStmt ctxt  -- Ok to nest inside a parallel comprehension
 checkTransformStmt ctxt = addErr msg