Preliminary monad-comprehension patch (Trac #4370)
[ghc-hetmet.git] / compiler / hsSyn / HsExpr.lhs
index 06616f1..e367af5 100644 (file)
@@ -148,6 +148,8 @@ data HsExpr id
                 [LStmt id]           -- "do":one or more stmts
                 (LHsExpr id)         -- The body; the last expression in the
                                      -- 'do' of [ body | ... ] in a list comp
+                (SyntaxExpr id)      -- The 'return' function, see Note
+                                     -- [Monad Comprehensions]
                 PostTcType           -- Type of the whole expression
 
   | ExplicitList                -- syntactic list
@@ -439,7 +441,7 @@ ppr_expr (HsLet binds expr)
   = sep [hang (ptext (sLit "let")) 2 (pprBinds binds),
          hang (ptext (sLit "in"))  2 (ppr expr)]
 
-ppr_expr (HsDo do_or_list_comp stmts body _) = pprDo do_or_list_comp stmts body
+ppr_expr (HsDo do_or_list_comp stmts body _ _) = pprDo do_or_list_comp stmts body
 
 ppr_expr (ExplicitList _ exprs)
   = brackets (pprDeeperList fsep (punctuate comma (map ppr_lexpr exprs)))
@@ -575,7 +577,7 @@ pprParendExpr expr
       HsPar {}          -> pp_as_was
       HsBracket {}      -> pp_as_was
       HsBracketOut _ [] -> pp_as_was
-      HsDo sc _ _ _
+      HsDo sc _ _ _ _
        | isListCompExpr sc -> pp_as_was
       _                    -> parens pp_as_was
 
@@ -830,8 +832,8 @@ type LStmtLR idL idR = Located (StmtLR idL idR)
 
 type Stmt id = StmtLR id id
 
--- The SyntaxExprs in here are used *only* for do-notation, which
--- has rebindable syntax.  Otherwise they are unused.
+-- The SyntaxExprs in here are used *only* for do-notation and monad
+-- comprehensions, which have rebindable syntax. Otherwise they are unused.
 data StmtLR idL idR
   = BindStmt (LPat idL)
              (LHsExpr idR)
@@ -842,17 +844,24 @@ data StmtLR idL idR
 
   | ExprStmt (LHsExpr idR)     -- See Note [ExprStmt]
              (SyntaxExpr idR) -- The (>>) operator
+             (SyntaxExpr idR) -- The `guard` operator
+                              -- See notes [Monad Comprehensions]
              PostTcType       -- Element type of the RHS (used for arrows)
 
   | LetStmt  (HsLocalBindsLR idL idR)
 
-  -- ParStmts only occur in a list comprehension
+  -- ParStmts only occur in a list/monad comprehension
   | ParStmt  [([LStmt idL], [idR])]
+             (SyntaxExpr idR)           -- polymorphic `mzip` for monad comprehensions
+             (SyntaxExpr idR)           -- The `>>=` operator
+             (SyntaxExpr idR)           -- polymorphic `return` operator
+                                        -- See notes [Monad Comprehensions]
+
   -- After renaming, the ids are the binders bound by the stmts and used
   -- after them
 
-  -- "qs, then f by e" ==> TransformStmt qs binders f (Just e)
-  -- "qs, then f"      ==> TransformStmt qs binders f Nothing
+  -- "qs, then f by e" ==> TransformStmt qs binders f (Just e) (return) (>>=)
+  -- "qs, then f"      ==> TransformStmt qs binders f Nothing  (return) (>>=)
   | TransformStmt 
          [LStmt idL]   -- Stmts are the ones to the left of the 'then'
 
@@ -863,6 +872,11 @@ data StmtLR idL idR
 
          (Maybe (LHsExpr idR)) -- "by e" (optional)
 
+         (SyntaxExpr idR)       -- The 'return' function for inner monad
+                                -- comprehensions
+         (SyntaxExpr idR)       -- The '(>>=)' operator.
+                                -- See Note [Monad Comprehensions]
+
   | GroupStmt 
          [LStmt idL]      -- Stmts to the *left* of the 'group'
                          -- which generates the tuples to be grouped
@@ -874,7 +888,14 @@ data StmtLR idL idR
          (Either               -- "using f"
              (LHsExpr idR)     --   Left f  => explicit "using f"
              (SyntaxExpr idR)) --   Right f => implicit; filled in with 'groupWith'
-                                                       
+                                --     (list comprehensions) or 'groupM' (monad
+                                --     comprehensions)
+
+         (SyntaxExpr idR)       -- The 'return' function for inner monad
+                                -- comprehensions
+         (SyntaxExpr idR)       -- The '(>>=)' operator
+         (SyntaxExpr idR)       -- The 'liftM' function from Control.Monad for desugaring
+                                -- See Note [Monad Comprehensions]
 
   -- Recursive statement (see Note [How RecStmt works] below)
   | RecStmt
@@ -952,6 +973,12 @@ depends on the context.  Consider the following contexts:
                 E :: Bool
           Translation: if E then fail else ...
 
+        A monad comprehension of type (m res_ty)
+        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+        * ExprStmt E Bool:   [ .. | .... E ]
+                E :: Bool
+          Translation: guard E >> ...
+
 Array comprehensions are handled like list comprehensions -=chak
 
 Note [How RecStmt works]
@@ -993,6 +1020,45 @@ A (RecStmt stmts) types as if you had written
 where v1..vn are the later_ids
       r1..rm are the rec_ids
 
+Note [Monad Comprehensions]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Monad comprehensions require seperate functions like 'return' and '>>=' for
+desugaring. These functions are stored in the 'HsDo' expression and the
+statements used in monad comprehensions. For example, the 'return' of the
+'HsDo' expression is used to lift the body of the monad comprehension:
+
+  [ body | stmts ]
+   =>
+  stmts >>= \bndrs -> return body
+
+In transform and grouping statements ('then ..' and 'then group ..') the
+'return' function is required for nested monad comprehensions, for example:
+
+  [ body | stmts, then f, rest ]
+   =>
+  f [ env | stmts ] >>= \bndrs -> [ body | rest ]
+
+Normal expressions require the 'Control.Monad.guard' function for boolean
+expressions:
+
+  [ body | exp, stmts ]
+   =>
+  guard exp >> [ body | stmts ]
+
+Grouping/parallel statements require the 'Control.Monad.Group.groupM' and
+'Control.Monad.Zip.mzip' functions:
+
+  [ body | stmts, then group by e, rest]
+   =>
+  groupM [ body | stmts ] >>= \bndrs -> [ body | rest ]
+
+  [ body | stmts1 | stmts2 | .. ]
+   =>
+  mzip stmts1 (mzip stmts2 (..)) >>= \(bndrs1, (bndrs2, ..)) -> return body
+
+In any other context than 'MonadComp', the fields for most of these
+'SyntaxExpr's stay bottom.
+
 
 \begin{code}
 instance (OutputableBndr idL, OutputableBndr idR) => Outputable (StmtLR idL idR) where
@@ -1001,14 +1067,14 @@ instance (OutputableBndr idL, OutputableBndr idR) => Outputable (StmtLR idL idR)
 pprStmt :: (OutputableBndr idL, OutputableBndr idR) => (StmtLR idL idR) -> SDoc
 pprStmt (BindStmt pat expr _ _)   = hsep [ppr pat, ptext (sLit "<-"), ppr expr]
 pprStmt (LetStmt binds)           = hsep [ptext (sLit "let"), pprBinds binds]
-pprStmt (ExprStmt expr _ _)       = ppr expr
-pprStmt (ParStmt stmtss)          = hsep (map doStmts stmtss)
+pprStmt (ExprStmt expr _ _ _)     = ppr expr
+pprStmt (ParStmt stmtss _ _ _)    = hsep (map doStmts stmtss)
   where doStmts stmts = ptext (sLit "| ") <> ppr stmts
 
-pprStmt (TransformStmt stmts bndrs using by)
+pprStmt (TransformStmt stmts bndrs using by _ _)
   = sep (ppr_lc_stmts stmts ++ [pprTransformStmt bndrs using by])
 
-pprStmt (GroupStmt stmts _ by using) 
+pprStmt (GroupStmt stmts _ by using _ _ _) 
   = sep (ppr_lc_stmts stmts ++ [pprGroupStmt by using])
 
 pprStmt (RecStmt { recS_stmts = segment, recS_rec_ids = rec_ids
@@ -1043,6 +1109,7 @@ pprDo GhciStmt    stmts body = ptext (sLit "do")  <+> ppr_do_stmts stmts body
 pprDo MDoExpr     stmts body = ptext (sLit "mdo") <+> ppr_do_stmts stmts body
 pprDo ListComp    stmts body = brackets    $ pprComp stmts body
 pprDo PArrComp    stmts body = pa_brackets $ pprComp stmts body
+pprDo MonadComp   stmts body = brackets    $ pprComp stmts body
 pprDo _           _     _    = panic "pprDo" -- PatGuard, ParStmtCxt
 
 ppr_do_stmts :: OutputableBndr id => [LStmt id] -> LHsExpr id -> SDoc
@@ -1178,6 +1245,7 @@ data HsStmtContext id
   | DoExpr
   | GhciStmt                            -- A command-line Stmt in GHCi pat <- rhs
   | MDoExpr                              -- Recursive do-expression
+  | MonadComp
   | PArrComp                             -- Parallel array comprehension
   | PatGuard (HsMatchContext id)         -- Pattern guard for specified thing
   | ParStmtCtxt (HsStmtContext id)       -- A branch of a parallel stmt
@@ -1192,9 +1260,16 @@ isDoExpr MDoExpr = True
 isDoExpr _       = False
 
 isListCompExpr :: HsStmtContext id -> Bool
-isListCompExpr ListComp = True
-isListCompExpr PArrComp = True
-isListCompExpr _        = False
+isListCompExpr ListComp  = True
+isListCompExpr PArrComp  = True
+isListCompExpr MonadComp = True
+isListCompExpr _         = False
+
+isMonadCompExpr :: HsStmtContext id -> Bool
+isMonadCompExpr MonadComp                = True
+isMonadCompExpr (ParStmtCtxt ctxt)       = isMonadCompExpr ctxt
+isMonadCompExpr (TransformStmtCtxt ctxt) = isMonadCompExpr ctxt
+isMonadCompExpr _                        = False
 \end{code}
 
 \begin{code}
@@ -1242,6 +1317,7 @@ pprStmtContext GhciStmt        = ptext (sLit "an interactive GHCi command")
 pprStmtContext DoExpr          = ptext (sLit "a 'do' expression")
 pprStmtContext MDoExpr         = ptext (sLit "an 'mdo' expression")
 pprStmtContext ListComp        = ptext (sLit "a list comprehension")
+pprStmtContext MonadComp       = ptext (sLit "a monad comprehension")
 pprStmtContext PArrComp        = ptext (sLit "an array comprehension")
 
 {-
@@ -1275,6 +1351,7 @@ matchContextErrString (StmtCtxt GhciStmt)        = ptext (sLit "interactive GHCi
 matchContextErrString (StmtCtxt DoExpr)          = ptext (sLit "'do' expression")
 matchContextErrString (StmtCtxt MDoExpr)         = ptext (sLit "'mdo' expression")
 matchContextErrString (StmtCtxt ListComp)        = ptext (sLit "list comprehension")
+matchContextErrString (StmtCtxt MonadComp)       = ptext (sLit "monad comprehension")
 matchContextErrString (StmtCtxt PArrComp)        = ptext (sLit "array comprehension")
 \end{code}
 
@@ -1290,7 +1367,7 @@ pprStmtInCtxt ctxt stmt = hang (ptext (sLit "In a stmt of") <+> pprStmtContext c
                          4 (ppr_stmt stmt)
   where
     -- For Group and Transform Stmts, don't print the nested stmts!
-    ppr_stmt (GroupStmt _ _ by using)         = pprGroupStmt by using
-    ppr_stmt (TransformStmt _ bndrs using by) = pprTransformStmt bndrs using by
-    ppr_stmt stmt                             = pprStmt stmt
+    ppr_stmt (GroupStmt _ _ by using _ _ _)       = pprGroupStmt by using
+    ppr_stmt (TransformStmt _ bndrs using by _ _) = pprTransformStmt bndrs using by
+    ppr_stmt stmt                                 = pprStmt stmt
 \end{code}