Add 'rec' to stmts in a 'do', and deprecate 'mdo'
authorsimonpj@microsoft.com <unknown>
Wed, 28 Oct 2009 13:35:54 +0000 (13:35 +0000)
committersimonpj@microsoft.com <unknown>
Wed, 28 Oct 2009 13:35:54 +0000 (13:35 +0000)
The change is this (see Trac #2798).  Instead of writing

  mdo { a <- getChar
      ; b <- f c
      ; c <- g b
      ; putChar c
      ; return b }

you would write

  do { a <- getChar
     ; rec { b <- f c
           ; c <- g b }
     ; putChar c
     ; return b }

That is,
  * 'mdo' is eliminated
  * 'rec' is added, which groups a bunch of statements
    into a single recursive statement

This 'rec' thing is already present for the arrow notation, so it
makes the two more uniform.  Moreover, 'rec' lets you say more
precisely where the recursion is (if you want to), whereas 'mdo' just
says "there's recursion here somewhere".  Lastly, all this works with
rebindable syntax (which mdo does not).

Currently 'mdo' is enabled by -XRecursiveDo.  So we now deprecate this
flag, with another flag -XDoRec to enable the 'rec' keyword.

Implementation notes:
  * Some changes in Lexer.x
  * All uses of RecStmt now use record syntax

I'm still not really happy with the "rec_ids" and "later_ids" in the
RecStmt constructor, but I don't dare change it without consulting Ross
about the consequences for arrow syntax.

12 files changed:
compiler/deSugar/Coverage.lhs
compiler/deSugar/DsArrows.lhs
compiler/deSugar/DsExpr.lhs
compiler/hsSyn/HsExpr.lhs
compiler/hsSyn/HsUtils.lhs
compiler/main/DynFlags.hs
compiler/parser/Lexer.x
compiler/rename/RnExpr.lhs
compiler/typecheck/TcExpr.lhs-boot
compiler/typecheck/TcHsSyn.lhs
compiler/typecheck/TcMatches.lhs
docs/users_guide/glasgow_exts.xml

index dce7962..2136d01 100644 (file)
@@ -461,13 +461,15 @@ addTickStmt isGuard (GroupStmt (stmts, binderMap) groupByClause) = do
           case x of
             Left a -> f a >>= (return . Left)
             Right b -> g b >>= (return . Right)
-addTickStmt isGuard (RecStmt stmts ids1 ids2 tys dictbinds) = do
-       liftM5 RecStmt 
-               (addTickLStmts isGuard stmts)
-               (return ids1)
-               (return ids2)
-               (return tys)
-               (addTickDictBinds dictbinds)
+addTickStmt isGuard stmt@(RecStmt {})
+  = do { stmts' <- addTickLStmts isGuard (recS_stmts stmt)
+       ; ret'   <- addTickSyntaxExpr hpcSrcSpan (recS_ret_fn stmt)
+       ; mfix'  <- addTickSyntaxExpr hpcSrcSpan (recS_mfix_fn stmt)
+       ; bind'  <- addTickSyntaxExpr hpcSrcSpan (recS_bind_fn stmt)
+       ; dicts' <- addTickDictBinds (recS_dicts stmt)
+       ; return (stmt { recS_stmts = stmts', recS_ret_fn = ret'
+                      , recS_mfix_fn = mfix', recS_bind_fn = bind'
+                      , recS_dicts = dicts' }) }
 
 addTick :: Maybe (Bool -> BoxLabel) -> LHsExpr Id -> TM (LHsExpr Id)
 addTick isGuard e | Just fn <- isGuard = addBinTickLHsExpr fn e
index cead3dd..3ffda53 100644 (file)
@@ -779,7 +779,9 @@ dsCmdStmt ids local_vars env_ids out_ids (LetStmt binds) = do
 --                     first (loop (arr (\((ys1),~(ys2)) -> (ys)) >>> ss)) >>>
 --                     arr (\((xs1),(xs2)) -> (xs')) >>> ss'
 
-dsCmdStmt ids local_vars env_ids out_ids (RecStmt stmts later_ids rec_ids rhss _binds) = do
+dsCmdStmt ids local_vars env_ids out_ids 
+          (RecStmt { recS_stmts = stmts, recS_later_ids = later_ids, recS_rec_ids = rec_ids
+                   , recS_rec_rets = rhss, recS_dicts = _binds }) = do
     let         -- ToDo: ****** binds not desugared; ROSS PLEASE FIX ********
         env2_id_set = mkVarSet out_ids `minusVarSet` mkVarSet later_ids
         env2_ids = varSetElems env2_id_set
index 820bd9a..e89270c 100644 (file)
@@ -49,6 +49,7 @@ import DynFlags
 import StaticFlags
 import CostCentre
 import Id
+import Var
 import PrelInfo
 import DataCon
 import TysWiredIn
@@ -676,13 +677,16 @@ dsDo      :: [LStmt Id]
        -> Type                 -- Type of the whole expression
        -> DsM CoreExpr
 
-dsDo stmts body _result_ty
+dsDo stmts body result_ty
   = goL stmts
   where
+    -- result_ty must be of the form (m b)
+    (m_ty, _b_ty) = tcSplitAppTy result_ty
+
     goL [] = dsLExpr body
-    goL ((L loc stmt):lstmts) = putSrcSpanDs loc (go stmt lstmts)
+    goL ((L loc stmt):lstmts) = putSrcSpanDs loc (go loc stmt lstmts)
   
-    go (ExprStmt rhs then_expr _) stmts
+    go _ (ExprStmt rhs then_expr _) stmts
       = do { rhs2 <- dsLExpr rhs
            ; case tcSplitAppTy_maybe (exprType rhs2) of
                 Just (container_ty, returning_ty) -> warnDiscardedDoBindings rhs container_ty returning_ty
@@ -691,23 +695,52 @@ dsDo stmts body _result_ty
           ; rest <- goL stmts
           ; return (mkApps then_expr2 [rhs2, rest]) }
     
-    go (LetStmt binds) stmts
+    go _ (LetStmt binds) stmts
       = do { rest <- goL stmts
           ; dsLocalBinds binds rest }
 
-    go (BindStmt pat rhs bind_op fail_op) stmts
-      = 
-       do  { body     <- goL stmts
-           ; rhs'     <- dsLExpr rhs
-          ; bind_op' <- dsExpr bind_op
-          ; var   <- selectSimpleMatchVarL pat
-          ; let bind_ty = exprType bind_op'    -- rhs -> (pat -> res1) -> res2
-                res1_ty = funResultTy (funArgTy (funResultTy bind_ty))
-          ; match <- matchSinglePat (Var var) (StmtCtxt DoExpr) pat
-                                    res1_ty (cantFailMatchResult body)
-          ; match_code <- handle_failure pat match fail_op
-          ; return (mkApps bind_op' [rhs', Lam var match_code]) }
+    go _ (BindStmt pat rhs bind_op fail_op) stmts
+      = do  { body     <- goL stmts
+            ; rhs'     <- dsLExpr rhs
+           ; bind_op' <- dsExpr bind_op
+           ; var   <- selectSimpleMatchVarL pat
+           ; let bind_ty = exprType bind_op'   -- rhs -> (pat -> res1) -> res2
+                 res1_ty = funResultTy (funArgTy (funResultTy bind_ty))
+           ; match <- matchSinglePat (Var var) (StmtCtxt DoExpr) pat
+                                     res1_ty (cantFailMatchResult body)
+           ; match_code <- handle_failure pat match fail_op
+           ; return (mkApps bind_op' [rhs', Lam var match_code]) }
     
+    go loc (RecStmt { recS_stmts = rec_stmts, recS_later_ids = later_ids
+                    , recS_rec_ids = rec_ids, recS_ret_fn = return_op
+                    , recS_mfix_fn = mfix_op, recS_bind_fn = bind_op
+                    , recS_rec_rets = rec_rets, recS_dicts = binds }) stmts 
+      = ASSERT( length rec_ids > 0 )
+        goL (new_bind_stmt : let_stmt : stmts)
+      where
+        -- returnE <- dsExpr return_id
+        -- mfixE <- dsExpr mfix_id
+        new_bind_stmt = L loc $ BindStmt (mkLHsPatTup later_pats) mfix_app
+                                         bind_op 
+                                            noSyntaxExpr  -- Tuple cannot fail
+
+        let_stmt = L loc $ LetStmt (HsValBinds (ValBindsOut [(Recursive, binds)] []))
+
+        tup_ids      = rec_ids ++ filterOut (`elem` rec_ids) later_ids
+        rec_tup_pats = map nlVarPat tup_ids
+        later_pats   = rec_tup_pats
+        rets         = map noLoc rec_rets
+
+        mfix_app   = nlHsApp (noLoc mfix_op) mfix_arg
+        mfix_arg   = noLoc $ HsLam (MatchGroup [mkSimpleMatch [mfix_pat] body]
+                                             (mkFunTy tup_ty body_ty))
+        mfix_pat   = noLoc $ LazyPat $ mkLHsPatTup rec_tup_pats
+        body       = noLoc $ HsDo DoExpr rec_stmts return_app body_ty
+        return_app = nlHsApp (noLoc return_op) (mkLHsTupleExpr rets)
+       body_ty    = mkAppTy m_ty tup_ty
+        tup_ty     = mkCoreTupTy (map idType tup_ids)
+                  -- mkCoreTupTy deals with singleton case
+
     -- In a do expression, pattern-match failure just calls
     -- the monadic 'fail' rather than throwing an exception
     handle_failure pat match fail_op
@@ -774,10 +807,11 @@ dsMDo tbl stmts body result_ty
           ; return (mkApps (Var bind_id) [Type (hsLPatType pat), Type b_ty, 
                                             rhs', Lam var match_code]) }
     
-    go loc (RecStmt rec_stmts later_ids rec_ids rec_rets binds) stmts
+    go loc (RecStmt rec_stmts later_ids rec_ids _ _ _ rec_rets binds) stmts
       = ASSERT( length rec_ids > 0 )
         ASSERT( length rec_ids == length rec_rets )
-       goL (new_bind_stmt : let_stmt : stmts)
+        pprTrace "dsMDo" (ppr later_ids) $
+        goL (new_bind_stmt : let_stmt : stmts)
       where
         new_bind_stmt = L loc $ mkBindStmt (mk_tup_pat later_pats) mfix_app
        let_stmt = L loc $ LetStmt (HsValBinds (ValBindsOut [(Recursive, binds)] []))
index cdf7322..c3f38ca 100644 (file)
@@ -847,26 +847,41 @@ data StmtLR idL idR
   -- the names which they group over in statements
 
   -- Recursive statement (see Note [RecStmt] below)
-  | RecStmt  [LStmtLR idL idR]
-             --- The next two fields are only valid after renaming
-             [idR] -- The ids are a subset of the variables bound by the
-                   -- stmts that are used in stmts that follow the RecStmt
-
-             [idR] -- Ditto, but these variables are the "recursive" ones,
-                   -- that are used before they are bound in the stmts of
-                   -- the RecStmt. From a type-checking point of view,
-                   -- these ones have to be monomorphic
-
-             --- These fields are only valid after typechecking
-             [PostTcExpr]       -- These expressions correspond 1-to-1 with
-                                -- the "recursive" [id], and are the
-                                -- expressions that should be returned by
-                                -- the recursion.
-                                -- They may not quite be the Ids themselves,
-                                -- because the Id may be *polymorphic*, but
-                                -- the returned thing has to be *monomorphic*.
-             (DictBinds idR)    -- Method bindings of Ids bound by the
-                                -- RecStmt, and used afterwards
+  | RecStmt
+     { recS_stmts :: [LStmtLR idL idR]
+
+        -- The next two fields are only valid after renaming
+     , recS_later_ids :: [idR] -- The ids are a subset of the variables bound by the
+                              -- stmts that are used in stmts that follow the RecStmt
+
+     , recS_rec_ids :: [idR]   -- Ditto, but these variables are the "recursive" ones,
+                              -- that are used before they are bound in the stmts of
+                              -- the RecStmt. 
+
+       -- An Id can be in both groups
+       -- Both sets of Ids are (now) treated monomorphically
+       -- The only reason they are separate is becuase the DsArrows 
+       -- code uses them separately, and I don't understand it well
+       -- enough to change it
+
+       -- Rebindable syntax
+     , recS_bind_fn :: SyntaxExpr idR -- The bind function
+     , recS_ret_fn  :: SyntaxExpr idR -- The return function
+     , recS_mfix_fn :: SyntaxExpr idR -- The mfix function
+
+        -- These fields are only valid after typechecking
+     , recS_rec_rets :: [PostTcExpr] -- These expressions correspond 1-to-1 with
+                                     -- recS_rec_ids, and are the
+                                     -- expressions that should be returned by
+                                     -- the recursion.
+                                     -- They may not quite be the Ids themselves,
+                                     -- because the Id may be *polymorphic*, but
+                                     -- the returned thing has to be *monomorphic*, 
+                                    -- so they may be type applications
+
+      , recS_dicts :: DictBinds idR  -- Method bindings of Ids bound by the
+                                     -- RecStmt, and used afterwards
+      }
 \end{code}
 
 ExprStmts are a bit tricky, because what they mean
@@ -894,8 +909,8 @@ depends on the context.  Consider the following contexts:
 
 Array comprehensions are handled like list comprehensions -=chak
 
-Note [RecStmt]
-~~~~~~~~~~~~~~
+Note [How RecStmt works]
+~~~~~~~~~~~~~~~~~~~~~~~~
 Example:
         HsDo [ BindStmt x ex
 
@@ -917,6 +932,17 @@ Here, the RecStmt binds a,b,c; but
 Nota Bene: the two a's have different types, even though they
 have the same Name.
 
+Note [Typing a RecStmt]
+~~~~~~~~~~~~~~~~~~~~~~~
+A (RecStmt stmts) types as if you had written
+
+  (v1,..,vn, _, ..., _) <- mfix (\~(_, ..., _, r1, ..., rm) ->
+                                do { stmts 
+                                   ; return (v1,..vn, r1, ..., rm) })
+
+where v1..vn are the later_ids
+      r1..rm are the rec_ids
+
 
 \begin{code}
 instance (OutputableBndr idL, OutputableBndr idR) => Outputable (StmtLR idL idR) where
@@ -934,7 +960,11 @@ pprStmt (TransformStmt (stmts, _) usingExpr maybeByExpr)
         byExprDoc = maybe empty (\byExpr -> hsep [ptext (sLit "by"), ppr byExpr]) maybeByExpr
 pprStmt (GroupStmt (stmts, _) groupByClause) = (hsep [stmtsDoc, ptext (sLit "then group"), pprGroupByClause groupByClause])
   where stmtsDoc = interpp'SP stmts
-pprStmt (RecStmt segment _ _ _ _) = ptext (sLit "rec") <+> braces (vcat (map ppr segment))
+pprStmt (RecStmt { recS_stmts = segment, recS_rec_ids = rec_ids, recS_later_ids = later_ids })
+  = ptext (sLit "rec") <+> 
+    vcat [ braces (vcat (map ppr segment))
+         , ifPprDebug (vcat [ ptext (sLit "rec_ids=") <> ppr rec_ids
+                            , ptext (sLit "later_ids=") <> ppr later_ids])]
 
 pprGroupByClause :: (OutputableBndr id) => GroupByClause id -> SDoc
 pprGroupByClause (GroupByNothing usingExpr) = hsep [ptext (sLit "using"), ppr usingExpr]
index d793a3b..66d9ed3 100644 (file)
@@ -139,7 +139,9 @@ mkGroupByUsingStmt :: [LStmt idL] -> LHsExpr idR -> LHsExpr idR -> StmtLR idL id
 
 mkExprStmt :: LHsExpr idR -> StmtLR idL idR
 mkBindStmt :: LPat idL -> LHsExpr idR -> StmtLR idL idR
-mkRecStmt  :: [LStmtLR idL idR] -> StmtLR idL idR
+
+emptyRecStmt :: StmtLR idL idR
+mkRecStmt    :: [LStmtLR idL idR] -> StmtLR idL idR
 
 
 mkHsIntegral   i       = OverLit (HsIntegral   i)  noRebindableInfo noSyntaxExpr
@@ -163,7 +165,13 @@ mkGroupByUsingStmt stmts byExpr usingExpr = GroupStmt (stmts, []) (GroupBySometh
 
 mkExprStmt expr            = ExprStmt expr noSyntaxExpr placeHolderType
 mkBindStmt pat expr = BindStmt pat expr noSyntaxExpr noSyntaxExpr
-mkRecStmt stmts            = RecStmt stmts [] [] [] emptyLHsBinds
+
+emptyRecStmt = RecStmt { recS_stmts = [], recS_later_ids = [], recS_rec_ids = []
+                       , recS_ret_fn = noSyntaxExpr, recS_mfix_fn = noSyntaxExpr
+                      , recS_bind_fn = noSyntaxExpr
+                       , recS_rec_rets = [], recS_dicts = emptyLHsBinds }
+
+mkRecStmt stmts = emptyRecStmt { recS_stmts = stmts }
 
 -------------------------------
 --- A useful function for building @OpApps@.  The operator is always a
@@ -414,8 +422,8 @@ collectStmtBinders (ExprStmt _ _ _)     = []
 collectStmtBinders (ParStmt xs)         = collectLStmtsBinders
                                         $ concatMap fst xs
 collectStmtBinders (TransformStmt (stmts, _) _ _) = collectLStmtsBinders stmts
-collectStmtBinders (GroupStmt (stmts, _) _)     = collectLStmtsBinders stmts
-collectStmtBinders (RecStmt ss _ _ _ _) = collectLStmtsBinders ss
+collectStmtBinders (GroupStmt (stmts, _) _)       = collectLStmtsBinders stmts
+collectStmtBinders (RecStmt { recS_stmts = ss })  = collectLStmtsBinders ss
 \end{code}
 
 
index adee723..f7a5d4a 100644 (file)
@@ -246,6 +246,7 @@ data DynFlag
    | Opt_TransformListComp
    | Opt_GeneralizedNewtypeDeriving
    | Opt_RecursiveDo
+   | Opt_DoRec
    | Opt_PostfixOperators
    | Opt_TupleSections
    | Opt_PatternGuards
@@ -1650,7 +1651,7 @@ mkFlag turnOn flagPrefix f (name, dynflag, deprecated)
 
 deprecatedForLanguage :: String -> Bool -> Deprecated
 deprecatedForLanguage lang turn_on
-    = Deprecated ("use -X"  ++ flag ++ " or pragma {-# LANGUAGE " ++ flag ++ "#-} instead")
+    = Deprecated ("use -X"  ++ flag ++ " or pragma {-# LANGUAGE " ++ flag ++ " #-} instead")
     where 
       flag | turn_on    = lang
            | otherwise = "No"++lang
@@ -1801,7 +1802,9 @@ xFlags = [
   ( "RankNTypes",                       Opt_RankNTypes, const Supported ),
   ( "ImpredicativeTypes",               Opt_ImpredicativeTypes, const Supported ),
   ( "TypeOperators",                    Opt_TypeOperators, const Supported ),
-  ( "RecursiveDo",                      Opt_RecursiveDo, const Supported ),
+  ( "RecursiveDo",                      Opt_RecursiveDo,
+    deprecatedForLanguage "DoRec"),
+  ( "DoRec",                            Opt_DoRec, const Supported ),
   ( "Arrows",                           Opt_Arrows, const Supported ),
   ( "PArr",                             Opt_PArr, const Supported ),
   ( "TemplateHaskell",                  Opt_TemplateHaskell, const Supported ),
@@ -1911,7 +1914,7 @@ glasgowExtsFlags = [
            , Opt_LiberalTypeSynonyms
            , Opt_RankNTypes
            , Opt_TypeOperators
-           , Opt_RecursiveDo
+           , Opt_DoRec
            , Opt_ParallelListComp
            , Opt_EmptyDataDecls
            , Opt_KindSignatures
index fe5c693..3a93ba1 100644 (file)
@@ -662,7 +662,7 @@ reservedWordsFM = listToUFM $
        ( "ccall",      ITccallconv,     bit ffiBit),
        ( "prim",       ITprimcallconv,  bit ffiBit),
 
-       ( "rec",        ITrec,           bit arrowsBit),
+       ( "rec",        ITrec,           bit recBit),
        ( "proc",       ITproc,          bit arrowsBit)
      ]
 
@@ -1672,6 +1672,8 @@ rawTokenStreamBit :: Int
 rawTokenStreamBit = 20 -- producing a token stream with all comments included
 newQualOpsBit :: Int
 newQualOpsBit = 21 -- Haskell' qualified operator syntax, e.g. Prelude.(+)
+recBit :: Int
+recBit = 22 -- rec
 
 always :: Int -> Bool
 always           _     = True
@@ -1766,6 +1768,8 @@ mkPState buf loc flags  =
               .|. magicHashBit      `setBitIf` dopt Opt_MagicHash    flags
               .|. kindSigsBit       `setBitIf` dopt Opt_KindSignatures flags
               .|. recursiveDoBit    `setBitIf` dopt Opt_RecursiveDo flags
+              .|. recBit            `setBitIf` dopt Opt_DoRec  flags
+              .|. recBit            `setBitIf` dopt Opt_Arrows flags
               .|. unicodeSyntaxBit  `setBitIf` dopt Opt_UnicodeSyntax flags
               .|. unboxedTuplesBit  `setBitIf` dopt Opt_UnboxedTuples flags
               .|. standaloneDerivingBit `setBitIf` dopt Opt_StandaloneDeriving flags
index 4b263e2..4ce7182 100644 (file)
@@ -32,9 +32,7 @@ import RnTypes                ( rnHsTypeFVs, rnSplice, checkTH,
 import RnPat
 import DynFlags                ( DynFlag(..) )
 import BasicTypes      ( FixityDirection(..) )
-import PrelNames       ( hasKey, assertIdKey, assertErrorName,
-                         loopAName, choiceAName, appAName, arrAName, composeAName, firstAName,
-                         negateName, thenMName, bindMName, failMName, groupWithName )
+import PrelNames
 
 import Name
 import NameSet
@@ -454,8 +452,8 @@ convertOpFormsStmt (BindStmt pat cmd _ _)
   = BindStmt pat (convertOpFormsLCmd cmd) noSyntaxExpr noSyntaxExpr
 convertOpFormsStmt (ExprStmt cmd _ _)
   = ExprStmt (convertOpFormsLCmd cmd) noSyntaxExpr placeHolderType
-convertOpFormsStmt (RecStmt stmts lvs rvs es binds)
-  = RecStmt (map (fmap convertOpFormsStmt) stmts) lvs rvs es binds
+convertOpFormsStmt stmt@(RecStmt { recS_stmts = stmts })
+  = stmt { recS_stmts = map (fmap convertOpFormsStmt) stmts }
 convertOpFormsStmt stmt = stmt
 
 convertOpFormsMatch :: MatchGroup id -> MatchGroup id
@@ -537,14 +535,13 @@ methodNamesLStmt :: Located (StmtLR Name Name) -> FreeVars
 methodNamesLStmt = methodNamesStmt . unLoc
 
 methodNamesStmt :: StmtLR Name Name -> FreeVars
-methodNamesStmt (ExprStmt cmd _ _)     = methodNamesLCmd cmd
-methodNamesStmt (BindStmt _ cmd _ _) = methodNamesLCmd cmd
-methodNamesStmt (RecStmt stmts _ _ _ _)
-  = methodNamesStmts stmts `addOneFV` loopAName
-methodNamesStmt (LetStmt _)  = emptyFVs
-methodNamesStmt (ParStmt _) = emptyFVs
-methodNamesStmt (TransformStmt _ _ _) = emptyFVs
-methodNamesStmt (GroupStmt _ _) = emptyFVs
+methodNamesStmt (ExprStmt cmd _ _)               = methodNamesLCmd cmd
+methodNamesStmt (BindStmt _ cmd _ _)             = methodNamesLCmd cmd
+methodNamesStmt (RecStmt { recS_stmts = stmts }) = methodNamesStmts stmts `addOneFV` loopAName
+methodNamesStmt (LetStmt _)                      = emptyFVs
+methodNamesStmt (ParStmt _)                      = emptyFVs
+methodNamesStmt (TransformStmt _ _ _)            = emptyFVs
+methodNamesStmt (GroupStmt _ _)                  = emptyFVs
    -- ParStmt, TransformStmt and GroupStmt can't occur in commands, but it's not convenient to error 
    -- here so we just do what's convenient
 \end{code}
@@ -636,67 +633,95 @@ rnStmts ctxt        = rnNormalStmts ctxt
 rnNormalStmts :: HsStmtContext Name -> [LStmt RdrName]
              -> RnM (thing, FreeVars)
              -> RnM (([LStmt Name], thing), FreeVars)  
--- Used for cases *other* than recursive mdo
--- Implements nested scopes
-
 rnNormalStmts _ [] thing_inside 
   = do { (thing, fvs) <- thing_inside
        ; return (([],thing), fvs) } 
 
-rnNormalStmts ctxt (L loc stmt : stmts) thing_inside
-  = do { ((stmt', (stmts', thing)), fvs) <- rnStmt ctxt stmt $
-            rnNormalStmts ctxt stmts thing_inside
-       ; return (((L loc stmt' : stmts'), thing), fvs) }
+rnNormalStmts ctxt (stmt@(L loc _) : stmts) thing_inside
+  = do { ((stmts1, (stmts2, thing)), fvs) 
+            <- setSrcSpan loc $
+               rnStmt ctxt stmt $
+               rnNormalStmts ctxt stmts thing_inside
+       ; return (((stmts1 ++ stmts2), thing), fvs) }
 
 
-rnStmt :: HsStmtContext Name -> Stmt RdrName
+rnStmt :: HsStmtContext Name -> LStmt RdrName
        -> RnM (thing, FreeVars)
-       -> RnM ((Stmt Name, thing), FreeVars)
+       -> RnM (([LStmt Name], thing), FreeVars)
 
-rnStmt _ (ExprStmt expr _ _) thing_inside
+rnStmt _ (L loc (ExprStmt expr _ _)) thing_inside
   = do { (expr', fv_expr) <- rnLExpr expr
        ; (then_op, fvs1)  <- lookupSyntaxName thenMName
        ; (thing, fvs2)    <- thing_inside
-       ; return ((ExprStmt expr' then_op placeHolderType, thing),
+       ; return (([L loc (ExprStmt expr' then_op placeHolderType)], thing),
                  fv_expr `plusFV` fvs1 `plusFV` fvs2) }
 
-rnStmt ctxt (BindStmt pat expr _ _) thing_inside
+rnStmt ctxt (L loc (BindStmt pat expr _ _)) thing_inside
   = do { (expr', fv_expr) <- rnLExpr expr
                -- The binders do not scope over the expression
        ; (bind_op, fvs1) <- lookupSyntaxName bindMName
        ; (fail_op, fvs2) <- lookupSyntaxName failMName
        ; rnPats (StmtCtxt ctxt) [pat] $ \ [pat'] -> do
        { (thing, fvs3) <- thing_inside
-       ; return ((BindStmt pat' expr' bind_op fail_op, thing),
+       ; return (([L loc (BindStmt pat' expr' bind_op fail_op)], thing),
                  fv_expr `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) }}
        -- fv_expr shouldn't really be filtered by the rnPatsAndThen
        -- but it does not matter because the names are unique
 
-rnStmt ctxt (LetStmt binds) thing_inside 
+rnStmt ctxt (L loc (LetStmt binds)) thing_inside 
   = do { checkLetStmt ctxt binds
        ; rnLocalBindsAndThen binds $ \binds' -> do
        { (thing, fvs) <- thing_inside
-        ; return ((LetStmt binds', thing), fvs) }  }
+        ; return (([L loc (LetStmt binds')], thing), fvs) }  }
 
-rnStmt ctxt (RecStmt rec_stmts _ _ _ _) thing_inside
+rnStmt ctxt (L _ (RecStmt { recS_stmts = rec_stmts })) thing_inside
   = do { checkRecStmt ctxt
-       ; rn_rec_stmts_and_then rec_stmts       $ \ segs -> do
-       { (thing, fvs) <- thing_inside
+
+       -- Step1: Bring all the binders of the mdo into scope
+       -- (Remember that this also removes the binders from the
+       -- finally-returned free-vars.)
+       -- And rename each individual stmt, making a
+       -- singleton segment.  At this stage the FwdRefs field
+       -- isn't finished: it's empty for all except a BindStmt
+       -- for which it's the fwd refs within the bind itself
+       -- (This set may not be empty, because we're in a recursive 
+       -- context.)
+        ; rn_rec_stmts_and_then rec_stmts      $ \ segs -> do
+
+       { (thing, fvs_later) <- thing_inside
+       ; (return_op, fvs1)  <- lookupSyntaxName returnMName
+       ; (mfix_op,   fvs2)  <- lookupSyntaxName mfixName
+       ; (bind_op,   fvs3)  <- lookupSyntaxName bindMName
        ; let
+               -- Step 2: Fill in the fwd refs.
+               --         The segments are all singletons, but their fwd-ref
+               --         field mentions all the things used by the segment
+               --         that are bound after their use
            segs_w_fwd_refs          = addFwdRefs segs
-           (ds, us, fs, rec_stmts') = unzip4 segs_w_fwd_refs
-           later_vars = nameSetToList (plusFVs ds `intersectNameSet` fvs)
-           fwd_vars   = nameSetToList (plusFVs fs)
-           uses       = plusFVs us
-           rec_stmt   = RecStmt rec_stmts' later_vars fwd_vars [] emptyLHsBinds
-       ; return ((rec_stmt, thing), uses `plusFV` fvs) } }
-
-rnStmt ctxt (ParStmt segs) thing_inside
+
+               -- Step 3: Group together the segments to make bigger segments
+               --         Invariant: in the result, no segment uses a variable
+               --                    bound in a later segment
+           grouped_segs = glomSegments segs_w_fwd_refs
+
+               -- Step 4: Turn the segments into Stmts
+               --         Use RecStmt when and only when there are fwd refs
+               --         Also gather up the uses from the end towards the
+               --         start, so we can tell the RecStmt which things are
+               --         used 'after' the RecStmt
+           empty_rec_stmt = emptyRecStmt { recS_ret_fn  = return_op
+                                          , recS_mfix_fn = mfix_op
+                                          , recS_bind_fn = bind_op }
+           (rec_stmts', fvs) = segsToStmts empty_rec_stmt grouped_segs fvs_later
+
+       ; return ((rec_stmts', thing), fvs `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) } }
+
+rnStmt ctxt (L loc (ParStmt segs)) thing_inside
   = do { checkParStmt ctxt
        ; ((segs', thing), fvs) <- rnParallelStmts (ParStmtCtxt ctxt) segs thing_inside
-       ; return ((ParStmt segs', thing), fvs) }
+       ; return (([L loc (ParStmt segs')], thing), fvs) }
 
-rnStmt ctxt (TransformStmt (stmts, _) usingExpr maybeByExpr) thing_inside = do
+rnStmt ctxt (L loc (TransformStmt (stmts, _) usingExpr maybeByExpr)) thing_inside = do
     checkTransformStmt ctxt
     
     (usingExpr', fv_usingExpr) <- rnLExpr usingExpr
@@ -707,14 +732,15 @@ rnStmt ctxt (TransformStmt (stmts, _) usingExpr maybeByExpr) thing_inside = do
             
             return ((maybeByExpr', thing), fv_maybeByExpr `plusFV` fv_thing)
     
-    return ((TransformStmt (stmts', binders) usingExpr' maybeByExpr', thing), fv_usingExpr `plusFV` fvs)
+    return (([L loc (TransformStmt (stmts', binders) usingExpr' maybeByExpr')], thing), 
+             fv_usingExpr `plusFV` fvs)
   where
     rnMaybeLExpr Nothing = return (Nothing, emptyFVs)
     rnMaybeLExpr (Just expr) = do
         (expr', fv_expr) <- rnLExpr expr
         return (Just expr', fv_expr)
         
-rnStmt ctxt (GroupStmt (stmts, _) groupByClause) thing_inside = do
+rnStmt ctxt (L loc (GroupStmt (stmts, _) groupByClause)) thing_inside = do
     checkTransformStmt ctxt
     
     -- We must rename the using expression in the context before the transform is begun
@@ -771,7 +797,7 @@ rnStmt ctxt (GroupStmt (stmts, _) groupByClause) thing_inside = do
             return ((groupByClause', usedBinderMap, thing), fv_groupByClause `plusFV` real_fv_thing)
     
     traceRn (text "rnStmt: implicitly rebound these used binders:" <+> ppr usedBinderMap)
-    return ((GroupStmt (stmts', usedBinderMap) groupByClause', thing), fvs)
+    return (([L loc (GroupStmt (stmts', usedBinderMap) groupByClause')], thing), fvs)
   
 rnNormalStmtsAndFindUsedBinders :: HsStmtContext Name 
           -> [LStmt RdrName]
@@ -858,39 +884,12 @@ rnMDoStmts :: [LStmt RdrName]
           -> RnM (thing, FreeVars)
           -> RnM (([LStmt Name], thing), FreeVars)     
 rnMDoStmts stmts thing_inside
-  =    -- Step1: Bring all the binders of the mdo into scope
-       -- (Remember that this also removes the binders from the
-       -- finally-returned free-vars.)
-       -- And rename each individual stmt, making a
-       -- singleton segment.  At this stage the FwdRefs field
-       -- isn't finished: it's empty for all except a BindStmt
-       -- for which it's the fwd refs within the bind itself
-       -- (This set may not be empty, because we're in a recursive 
-       -- context.)
-     rn_rec_stmts_and_then stmts $ \ segs -> do {
-
-       ; (thing, fvs_later) <- thing_inside
-
-       ; let
-       -- Step 2: Fill in the fwd refs.
-       --         The segments are all singletons, but their fwd-ref
-       --         field mentions all the things used by the segment
-       --         that are bound after their use
-           segs_w_fwd_refs = addFwdRefs segs
-
-       -- Step 3: Group together the segments to make bigger segments
-       --         Invariant: in the result, no segment uses a variable
-       --                    bound in a later segment
+  = rn_rec_stmts_and_then stmts $ \ segs -> do
+    { (thing, fvs_later) <- thing_inside
+    ; let   segs_w_fwd_refs = addFwdRefs segs
            grouped_segs = glomSegments segs_w_fwd_refs
-
-       -- Step 4: Turn the segments into Stmts
-       --         Use RecStmt when and only when there are fwd refs
-       --         Also gather up the uses from the end towards the
-       --         start, so we can tell the RecStmt which things are
-       --         used 'after' the RecStmt
-           (stmts', fvs) = segsToStmts grouped_segs fvs_later
-
-       ; return ((stmts', thing), fvs) }
+           (stmts', fvs) = segsToStmts emptyRecStmt grouped_segs fvs_later
+    ; return ((stmts', thing), fvs) }
 
 ---------------------------------------------
 
@@ -957,7 +956,8 @@ rn_rec_stmt_lhs fix_env (L loc (LetStmt (HsValBinds binds)))
                  emptyFVs
                  )]
 
-rn_rec_stmt_lhs fix_env (L _ (RecStmt stmts _ _ _ _))  -- Flatten Rec inside Rec
+-- XXX Do we need to do something with the return and mfix names?
+rn_rec_stmt_lhs fix_env (L _ (RecStmt { recS_stmts = stmts })) -- Flatten Rec inside Rec
     = rn_rec_stmts_lhs fix_env stmts
 
 rn_rec_stmt_lhs _ stmt@(L _ (ParStmt _))       -- Syntactically illegal in mdo
@@ -1020,16 +1020,16 @@ rn_rec_stmt all_bndrs (L loc (LetStmt (HsValBinds binds'))) _ = do
            emptyNameSet, L loc (LetStmt (HsValBinds binds')))]
 
 -- no RecStmt case becuase they get flattened above when doing the LHSes
-rn_rec_stmt _ stmt@(L _ (RecStmt _ _ _ _ _)) _ 
+rn_rec_stmt _ stmt@(L _ (RecStmt {})) _
   = pprPanic "rn_rec_stmt: RecStmt" (ppr stmt)
 
-rn_rec_stmt _ stmt@(L _ (ParStmt _)) _ -- Syntactically illegal in mdo
+rn_rec_stmt _ stmt@(L _ (ParStmt {})) _        -- Syntactically illegal in mdo
   = pprPanic "rn_rec_stmt: ParStmt" (ppr stmt)
 
-rn_rec_stmt _ stmt@(L _ (TransformStmt _ _ _)) _       -- Syntactically illegal in mdo
+rn_rec_stmt _ stmt@(L _ (TransformStmt {})) _  -- Syntactically illegal in mdo
   = pprPanic "rn_rec_stmt: TransformStmt" (ppr stmt)
 
-rn_rec_stmt _ stmt@(L _ (GroupStmt _ _)) _     -- Syntactically illegal in mdo
+rn_rec_stmt _ stmt@(L _ (GroupStmt {})) _      -- Syntactically illegal in mdo
   = pprPanic "rn_rec_stmt: GroupStmt" (ppr stmt)
 
 rn_rec_stmt _ (L _ (LetStmt EmptyLocalBinds)) _
@@ -1120,23 +1120,24 @@ glomSegments ((defs,uses,fwds,stmt) : segs)
 
 
 ----------------------------------------------------
-segsToStmts :: [Segment [LStmt Name]] 
+segsToStmts :: Stmt Name               -- A RecStmt with the SyntaxOps filled in
+            -> [Segment [LStmt Name]] 
            -> FreeVars                 -- Free vars used 'later'
            -> ([LStmt Name], FreeVars)
 
-segsToStmts [] fvs_later = ([], fvs_later)
-segsToStmts ((defs, uses, fwds, ss) : segs) fvs_later
+segsToStmts _ [] fvs_later = ([], fvs_later)
+segsToStmts empty_rec_stmt ((defs, uses, fwds, ss) : segs) fvs_later
   = ASSERT( not (null ss) )
     (new_stmt : later_stmts, later_uses `plusFV` uses)
   where
-    (later_stmts, later_uses) = segsToStmts segs fvs_later
+    (later_stmts, later_uses) = segsToStmts empty_rec_stmt segs fvs_later
     new_stmt | non_rec  = head ss
-            | otherwise = L (getLoc (head ss)) $ 
-                          RecStmt ss (nameSetToList used_later) (nameSetToList fwds) 
-                                     [] emptyLHsBinds
-            where
-              non_rec    = isSingleton ss && isEmptyNameSet fwds
-              used_later = defs `intersectNameSet` later_uses
+            | otherwise = L (getLoc (head ss)) rec_stmt 
+    rec_stmt = empty_rec_stmt { recS_stmts     = ss
+                              , recS_later_ids = nameSetToList used_later
+                              , recS_rec_ids   = nameSetToList fwds }
+    non_rec    = isSingleton ss && isEmptyNameSet fwds
+    used_later = defs `intersectNameSet` later_uses
                                -- The ones needed after the RecStmt
 \end{code}
 
@@ -1187,10 +1188,7 @@ checkLetStmt _ctxt            _binds            = return ()
 ---------
 checkRecStmt :: HsStmtContext Name -> RnM ()
 checkRecStmt (MDoExpr {}) = return ()  -- Recursive stmt ok in 'mdo'
-checkRecStmt (DoExpr {})  = return ()  -- ..and in 'do' but only because of arrows:
-                                       --   proc x -> do { ...rec... }
-                                       -- We don't have enough context to distinguish this situation here
-                                       --      so we leave it to the type checker
+checkRecStmt (DoExpr {})  = return ()  -- and in 'do'
 checkRecStmt ctxt        = addErr msg
   where
     msg = ptext (sLit "Illegal 'rec' stmt in") <+> pprStmtContext ctxt
index ec36034..6a75a10 100644 (file)
@@ -10,7 +10,7 @@ tcPolyExpr ::
        -> BoxySigmaType
        -> TcM (LHsExpr TcId)
 
-tcMonoExpr :: 
+tcMonoExpr, tcMonoExprNC :: 
          LHsExpr Name
        -> BoxyRhoType
        -> TcM (LHsExpr TcId)
index de572ba..fbe3c9f 100644 (file)
@@ -682,21 +682,26 @@ zonkStmt env (ParStmt stmts_w_bndrs)
     zonk_branch (stmts, bndrs) = zonkStmts env stmts   `thenM` \ (env1, new_stmts) ->
                                 returnM (new_stmts, zonkIdOccs env1 bndrs)
 
-zonkStmt env (RecStmt segStmts lvs rvs rets binds)
-  = zonkIdBndrs env rvs                `thenM` \ new_rvs ->
-    let
-       env1 = extendZonkEnv env new_rvs
-    in
-    zonkStmts env1 segStmts    `thenM` \ (env2, new_segStmts) ->
+zonkStmt env (RecStmt { recS_stmts = segStmts, recS_later_ids = lvs, recS_rec_ids = rvs
+                      , recS_ret_fn = ret_id, recS_mfix_fn = mfix_id, recS_bind_fn = bind_id
+                      , recS_rec_rets = rets, recS_dicts = binds })
+  = do { new_rvs <- zonkIdBndrs env rvs
+       ; new_lvs <- zonkIdBndrs env lvs
+       ; new_ret_id  <- zonkExpr env ret_id
+       ; new_mfix_id <- zonkExpr env mfix_id
+       ; new_bind_id <- zonkExpr env bind_id
+       ; let env1 = extendZonkEnv env new_rvs
+       ; (env2, new_segStmts) <- zonkStmts env1 segStmts
        -- Zonk the ret-expressions in an envt that 
        -- has the polymorphic bindings in the envt
-    mapM (zonkExpr env2) rets  `thenM` \ new_rets ->
-    let
-       new_lvs = zonkIdOccs env2 lvs
-       env3 = extendZonkEnv env new_lvs        -- Only the lvs are needed
-    in
-    zonkRecMonoBinds env3 binds        `thenM` \ (env4, new_binds) ->
-    returnM (env4, RecStmt new_segStmts new_lvs new_rvs new_rets new_binds)
+       ; new_rets <- mapM (zonkExpr env2) rets
+       ; let env3 = extendZonkEnv env new_lvs  -- Only the lvs are needed
+       ; (env4, new_binds) <- zonkRecMonoBinds env3 binds
+       ; return (env4,
+                 RecStmt { recS_stmts = new_segStmts, recS_later_ids = new_lvs
+                         , recS_rec_ids = new_rvs, recS_ret_fn = new_ret_id
+                         , recS_mfix_fn = new_mfix_id, recS_bind_fn = new_bind_id
+                         , recS_rec_rets = new_rets, recS_dicts = new_binds }) }
 
 zonkStmt env (ExprStmt expr then_op ty)
   = zonkLExpr env expr         `thenM` \ new_expr ->
index 3e0e8c0..37b8cbe 100644 (file)
@@ -12,7 +12,8 @@ module TcMatches ( tcMatchesFun, tcGRHSsPat, tcMatchesCase, tcMatchLambda,
                   tcDoStmt, tcMDoStmt, tcGuardStmt
        ) where
 
-import {-# SOURCE #-}  TcExpr( tcSyntaxOp, tcInferRhoNC, tcMonoExpr, tcPolyExpr )
+import {-# SOURCE #-}  TcExpr( tcSyntaxOp, tcInferRhoNC, 
+                                tcMonoExpr, tcMonoExprNC, tcPolyExpr )
 
 import HsSyn
 import TcRnMonad
@@ -24,6 +25,7 @@ import TcType
 import TcBinds
 import TcUnify
 import TcSimplify
+import MkCore
 import Name
 import TysWiredIn
 import PrelNames
@@ -465,24 +467,22 @@ tcLcStmt _ _ stmt _ _
 tcDoStmt :: TcStmtChecker
 
 tcDoStmt ctxt (BindStmt pat rhs bind_op fail_op) res_ty thing_inside
-  = do { (rhs', rhs_ty) <- tcInferRhoNC rhs
-               -- We should use type *inference* for the RHS computations, 
-                -- becuase of GADTs. 
-               --      do { pat <- rhs; <rest> }
-               -- is rather like
-               --      case rhs of { pat -> <rest> }
-               -- We do inference on rhs, so that information about its type 
-                -- can be refined when type-checking the pattern. 
+  = do {       -- Deal with rebindable syntax:
+               --       (>>=) :: rhs_ty -> (pat_ty -> new_res_ty) -> res_ty
+               -- This level of generality is needed for using do-notation
+               -- in full generality; see Trac #1537
+
+               -- I'd like to put this *after* the tcSyntaxOp 
+                -- (see Note [Treat rebindable syntax first], but that breaks 
+               -- the rigidity info for GADTs.  When we move to the new story
+                -- for GADTs, we can move this after tcSyntaxOp
+          (rhs', rhs_ty) <- tcInferRhoNC rhs
 
-       -- Deal with rebindable syntax:
-       --       (>>=) :: rhs_ty -> (pat_ty -> new_res_ty) -> res_ty
-       -- This level of generality is needed for using do-notation
-       -- in full generality; see Trac #1537
        ; ((bind_op', new_res_ty), pat_ty) <- 
             withBox liftedTypeKind $ \ pat_ty ->
             withBox liftedTypeKind $ \ new_res_ty ->
             tcSyntaxOp DoOrigin bind_op 
-                       (mkFunTys [rhs_ty, mkFunTy pat_ty new_res_ty] res_ty)
+                            (mkFunTys [rhs_ty, mkFunTy pat_ty new_res_ty] res_ty)
 
                -- If (but only if) the pattern can fail, 
                -- typecheck the 'fail' operator
@@ -490,31 +490,94 @@ tcDoStmt ctxt (BindStmt pat rhs bind_op fail_op) res_ty thing_inside
                      then return noSyntaxExpr
                      else tcSyntaxOp DoOrigin fail_op (mkFunTy stringTy new_res_ty)
 
+               -- We should typecheck the RHS *before* the pattern,
+                -- because of GADTs. 
+               --      do { pat <- rhs; <rest> }
+               -- is rather like
+               --      case rhs of { pat -> <rest> }
+               -- We do inference on rhs, so that information about its type 
+                -- can be refined when type-checking the pattern. 
+
        ; (pat', thing) <- tcPat (StmtCtxt ctxt) pat pat_ty new_res_ty thing_inside
 
        ; return (BindStmt pat' rhs' bind_op' fail_op', thing) }
 
 
 tcDoStmt _ (ExprStmt rhs then_op _) res_ty thing_inside
-  = do { (rhs', rhs_ty) <- tcInferRhoNC rhs
-
-       -- Deal with rebindable syntax; (>>) :: rhs_ty -> new_res_ty -> res_ty
-       ; (then_op', new_res_ty) <-
+  = do {       -- Deal with rebindable syntax; 
+                --   (>>) :: rhs_ty -> new_res_ty -> res_ty
+               -- See also Note [Treat rebindable syntax first]
+         ((then_op', rhs_ty), new_res_ty) <-
                withBox liftedTypeKind $ \ new_res_ty ->
+               withBox liftedTypeKind $ \ rhs_ty ->
                tcSyntaxOp DoOrigin then_op 
                           (mkFunTys [rhs_ty, new_res_ty] res_ty)
 
+        ; rhs' <- tcMonoExprNC rhs rhs_ty
        ; thing <- thing_inside new_res_ty
        ; return (ExprStmt rhs' then_op' rhs_ty, thing) }
 
-tcDoStmt ctxt (RecStmt {}) _ _
-  = failWithTc (ptext (sLit "Illegal 'rec' stmt in") <+> pprStmtContext ctxt)
-       -- This case can't be caught in the renamer
-       -- see RnExpr.checkRecStmt
+tcDoStmt ctxt (RecStmt { recS_stmts = stmts, recS_later_ids = later_names
+                       , recS_rec_ids = rec_names, recS_ret_fn = ret_op
+                       , recS_mfix_fn = mfix_op, recS_bind_fn = bind_op }) 
+         res_ty thing_inside
+  = do  { let tup_names = rec_names ++ filterOut (`elem` rec_names) later_names
+        ; tup_elt_tys <- newFlexiTyVarTys (length tup_names) liftedTypeKind
+        ; let tup_ids = zipWith mkLocalId tup_names tup_elt_tys
+             tup_ty  = mkCoreTupTy tup_elt_tys
+
+        ; tcExtendIdEnv tup_ids $ do
+        { ((stmts', (ret_op', tup_rets)), stmts_ty)
+                <- withBox liftedTypeKind $ \ stmts_ty ->
+                   tcStmts ctxt tcDoStmt stmts stmts_ty   $ \ inner_res_ty ->
+                   do { tup_rets <- zipWithM tc_ret tup_names tup_elt_tys
+                     ; ret_op' <- tcSyntaxOp DoOrigin ret_op (mkFunTy tup_ty inner_res_ty)
+                      ; return (ret_op', tup_rets) }
+
+       ; (mfix_op', mfix_res_ty) <- withBox liftedTypeKind $ \ mfix_res_ty ->
+                                     tcSyntaxOp DoOrigin mfix_op
+                                        (mkFunTy (mkFunTy tup_ty stmts_ty) mfix_res_ty)
+
+       ; (bind_op', new_res_ty) <- withBox liftedTypeKind $ \ new_res_ty ->
+                                   tcSyntaxOp DoOrigin bind_op 
+                                       (mkFunTys [mfix_res_ty, mkFunTy tup_ty new_res_ty] res_ty)
+
+        ; (thing,lie) <- getLIE (thing_inside new_res_ty)
+        ; lie_binds <- bindInstsOfLocalFuns lie tup_ids
+  
+        ; let rec_ids = takeList rec_names tup_ids
+       ; later_ids <- tcLookupLocalIds later_names
+       ; traceTc (text "tcdo" <+> vcat [ppr rec_ids <+> ppr (map idType rec_ids),
+                                         ppr later_ids <+> ppr (map idType later_ids)])
+        ; return (RecStmt { recS_stmts = stmts', recS_later_ids = later_ids
+                          , recS_rec_ids = rec_ids, recS_ret_fn = ret_op' 
+                          , recS_mfix_fn = mfix_op', recS_bind_fn = bind_op'
+                          , recS_rec_rets = tup_rets, recS_dicts = lie_binds }, thing)
+        }}
+  where 
+    -- Unify the types of the "final" Ids with those of "knot-tied" Ids
+    tc_ret rec_name mono_ty
+        = do { poly_id <- tcLookupId rec_name
+                -- poly_id may have a polymorphic type
+                -- but mono_ty is just a monomorphic type variable
+             ; co_fn <- tcSubExp DoOrigin (idType poly_id) mono_ty
+             ; return (mkHsWrap co_fn (HsVar poly_id)) }
 
 tcDoStmt _ stmt _ _
   = pprPanic "tcDoStmt: unexpected Stmt" (ppr stmt)
+\end{code}
 
+Note [Treat rebindable syntax first]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+When typechecking
+       do { bar; ... } :: IO ()
+we want to typecheck 'bar' in the knowledge that it should be an IO thing,
+pushing info from the context into the RHS.  To do this, we check the
+rebindable syntax first, and push that information into (tcMonoExprNC rhs).
+Otherwise the error shows up when cheking the rebindable syntax, and
+the expected/inferred stuff is back to front (see Trac #3613).
+
+\begin{code}
 --------------------------------
 --     Mdo-notation
 -- The distinctive features here are
@@ -533,7 +596,7 @@ tcMDoStmt tc_rhs _ (ExprStmt rhs _ _) res_ty thing_inside
        ; thing          <- thing_inside res_ty
        ; return (ExprStmt rhs' noSyntaxExpr elt_ty, thing) }
 
-tcMDoStmt tc_rhs ctxt (RecStmt stmts laterNames recNames _ _) res_ty thing_inside
+tcMDoStmt tc_rhs ctxt (RecStmt stmts laterNames recNames _ _ _ _ _) res_ty thing_inside
   = do { rec_tys <- newFlexiTyVarTys (length recNames) liftedTypeKind
        ; let rec_ids = zipWith mkLocalId recNames rec_tys
        ; tcExtendIdEnv rec_ids                 $ do
@@ -551,7 +614,7 @@ tcMDoStmt tc_rhs ctxt (RecStmt stmts laterNames recNames _ _) res_ty thing_insid
                --      (see note [RecStmt] in HsExpr)
        ; lie_binds <- bindInstsOfLocalFuns lie later_ids
   
-       ; return (RecStmt stmts' later_ids rec_ids rec_rets lie_binds, thing)
+       ; return (RecStmt stmts' later_ids rec_ids noSyntaxExpr noSyntaxExpr noSyntaxExpr rec_rets lie_binds, thing)
        }}
   where 
     -- Unify the types of the "final" Ids with those of "knot-tied" Ids
index 71a0752..6046691 100644 (file)
@@ -82,7 +82,7 @@ documentation</ulink> describes all the libraries that come with GHC.
           <option>-XRankNTypes</option>,
           <option>-XImpredicativeTypes</option>,
           <option>-XTypeOperators</option>,
-          <option>-XRecursiveDo</option>,
+          <option>-XDoRec</option>,
           <option>-XParallelListComp</option>,
           <option>-XEmptyDataDecls</option>,
           <option>-XKindSignatures</option>,
@@ -860,33 +860,45 @@ it, you can use the <option>-XNoNPlusKPatterns</option> flag.
 <title>The recursive do-notation
 </title>
 
-<para> The recursive do-notation (also known as mdo-notation) is implemented as described in
-<ulink url="http://citeseer.ist.psu.edu/erk02recursive.html">A recursive do for Haskell</ulink>,
-by Levent Erkok, John Launchbury,
-Haskell Workshop 2002, pages: 29-37. Pittsburgh, Pennsylvania. 
-This paper is essential reading for anyone making non-trivial use of mdo-notation,
-and we do not repeat it here.
-</para>
 <para>
-The do-notation of Haskell does not allow <emphasis>recursive bindings</emphasis>,
+The do-notation of Haskell 98 does not allow <emphasis>recursive bindings</emphasis>,
 that is, the variables bound in a do-expression are visible only in the textually following 
 code block. Compare this to a let-expression, where bound variables are visible in the entire binding
 group. It turns out that several applications can benefit from recursive bindings in
-the do-notation, and this extension provides the necessary syntactic support.
+the do-notation.  The <option>-XDoRec</option> flag provides the necessary syntactic support.
 </para>
 <para>
-Here is a simple (yet contrived) example:
-</para>
+Here is a simple (albeit contrived) example:
 <programlisting>
+{-# LANGUAGE DoRec #-}
 import Control.Monad.Fix
 
-justOnes = mdo xs &lt;- Just (1:xs)
-               return xs
+justOnes = do { rec { xs &lt;- Just (1:xs) }
+              ; return (map negate xs) }
 </programlisting>
+The <literal>rec</literal>
+As you can guess <literal>justOnes</literal> will evaluate to <literal>Just [-1,-1,-1,...</literal>.
+</para>
 <para>
-As you can guess <literal>justOnes</literal> will evaluate to <literal>Just [1,1,1,...</literal>.
+The background and motivation for recusrive do-notation is described in
+<ulink url="http://citeseer.ist.psu.edu/erk02recursive.html">A recursive do for Haskell</ulink>,
+by Levent Erkok, John Launchbury,
+Haskell Workshop 2002, pages: 29-37. Pittsburgh, Pennsylvania. 
+This paper is essential reading for anyone making non-trivial use of mdo-notation,
+and we do not repeat it here.  However, note that GHC uses a different syntax than the one
+in the paper.
 </para>
 
+<sect3>
+<title>Details of recursive do-notation</title>
+<para>
+The recursive do-notation is enabled with the flag <option>-XDoRec</option> or, equivalently,
+the LANGUAGE pragma <option>DoRec</option>.  It introduces the single new keyword "<literal>rec</literal>",
+which wraps a mutually-recusrive group of monadic statements,
+producing a single statement.  Similar to a <literal>let</literal>
+statement, the variables bound in the <literal>rec</literal> are 
+visible throughout the <literal>rec</literal> group, and below it.
+</para>
 <para>
 The Control.Monad.Fix library introduces the <literal>MonadFix</literal> class.  Its definition is:
 </para>
@@ -899,30 +911,35 @@ The function <literal>mfix</literal>
 dictates how the required recursion operation should be performed.  For example, 
 <literal>justOnes</literal> desugars as follows:
 <programlisting>
-justOnes = mfix (\xs' -&gt; do { xs &lt;- Just (1:xs'); return xs }
+justOnes = do { xs &lt;- mfix (\xs' -&gt; do { xs &lt;- Just (1:xs'); return xs })
+              ; return (map negate xs) }
 </programlisting>
-For full details of the way in which mdo is typechecked and desugared, see 
-the paper <ulink url="http://citeseer.ist.psu.edu/erk02recursive.html">A recursive do for Haskell</ulink>.
-In particular, GHC implements the segmentation technique described in Section 3.2 of the paper.
-</para>
-<para>
-If recursive bindings are required for a monad,
-then that monad must be declared an instance of the <literal>MonadFix</literal> class.
-The following instances of <literal>MonadFix</literal> are automatically provided: List, Maybe, IO. 
-Furthermore, the Control.Monad.ST and Control.Monad.ST.Lazy modules provide the instances of the MonadFix class 
-for Haskell's internal state monad (strict and lazy, respectively).
+In general, a <literal>rec</literal> statment <literal>rec <replaceable>ss</replaceable></literal>
+is desugared to the statement
+<programlisting>
+  <replaceable>vs</replaceable> &lt;- mfix (\~<replaceable>vs</replaceable> -&gt; do { <replaceable>ss</replaceable>
+                                                                   ; return <replaceable>vs</replaceable> })
+</programlisting>
+where <replaceable>vs</replaceable> is a tuple of the varaibles bound by <replaceable>ss</replaceable>.
+Moreover, the original <literal>rec</literal> typechecks exactly 
+when the above desugared version would do so.  (For example, this means that 
+the variables <replaceable>vs</replaceable> are all monomorphic in the statements
+following the <literal>rec</literal>, because they are bound by a lambda.)
 </para>
 <para>
-Here are some important points in using the recursive-do notation:
+Here are some other important points in using the recursive-do notation:
 <itemizedlist>
 <listitem><para>
-The recursive version of the do-notation uses the keyword <literal>mdo</literal> (rather
-than <literal>do</literal>).
+It is enabled with the flag <literal>-XDoRec</literal>, which is in turn implied by
+<literal>-fglasgow-exts</literal>.
 </para></listitem>
 
 <listitem><para>
-It is enabled with the flag <literal>-XRecursiveDo</literal>, which is in turn implied by
-<literal>-fglasgow-exts</literal>.
+If recursive bindings are required for a monad,
+then that monad must be declared an instance of the <literal>MonadFix</literal> class.
+The following instances of <literal>MonadFix</literal> are automatically provided: List, Maybe, IO. 
+Furthermore, the Control.Monad.ST and Control.Monad.ST.Lazy modules provide the instances of the MonadFix class 
+for Haskell's internal state monad (strict and lazy, respectively).
 </para></listitem>
 
 <listitem><para>
@@ -932,20 +949,31 @@ be distinct (Section 3.3 of the paper).
 </para></listitem>
 
 <listitem><para>
-Variables bound by a <literal>let</literal> statement in an <literal>mdo</literal>
-are monomorphic in the <literal>mdo</literal> (Section 3.1 of the paper).  However
-GHC breaks the <literal>mdo</literal> into segments to enhance polymorphism,
-and improve termination (Section 3.2 of the paper).
+Similar to let-bindings, GHC implements the segmentation technique described in Section 3.2 of
+<ulink url="http://citeseer.ist.psu.edu/erk02recursive.html">A recursive do for Haskell</ulink>,
+to break up a single <literal>rec</literal> statement into a sequenc e of statements with
+<literal>rec</literal> groups of minimal size.  This 
+improves polymorphism, and reduces the size of the recursive "knot".
 </para></listitem>
 </itemizedlist>
 </para>
+</sect3>
 
+<sect3> <title Mdo-notation (deprecated) </title>
+
+<para> GHC used to support the flag <option>-XREecursiveDo</option>,
+which enabled the keyword <literal>mdo</literal>, precisely as described in
+<ulink url="http://citeseer.ist.psu.edu/erk02recursive.html">A recursive do for Haskell</ulink>,
+but this is now deprecated.  Instead of <literal>mdo { Q; e }</literal>, write
+<literal>do { rec Q; e }</literal>.
+</para>
 <para>
 Historical note: The old implementation of the mdo-notation (and most
 of the existing documents) used the name
 <literal>MonadRec</literal> for the class and the corresponding library.
 This name is not supported by GHC.
 </para>
+</sect3>
 
 </sect2>