More on monad-comp; an intermediate state, so don't pull
authorSimon Peyton Jones <simonpj@microsoft.com>
Mon, 2 May 2011 12:56:37 +0000 (13:56 +0100)
committerSimon Peyton Jones <simonpj@microsoft.com>
Mon, 2 May 2011 12:56:37 +0000 (13:56 +0100)
compiler/deSugar/DsExpr.lhs
compiler/hsSyn/HsExpr.lhs
compiler/hsSyn/HsUtils.lhs
compiler/parser/Parser.y.pp
compiler/rename/RnExpr.lhs
compiler/typecheck/TcMatches.lhs

index 418bda5..4088e44 100644 (file)
@@ -740,6 +740,7 @@ dsDo stmts
                                          noSyntaxExpr  -- Tuple cannot fail
 
         tup_ids      = rec_ids ++ filterOut (`elem` rec_ids) later_ids
+        tup_ty       = mkBoxedTupleTy (map idType tup_ids) -- Deals with singleton case
         rec_tup_pats = map nlVarPat tup_ids
         later_pats   = rec_tup_pats
         rets         = map noLoc rec_rets
@@ -748,8 +749,11 @@ dsDo stmts
                                                  (mkFunTy tup_ty body_ty))
         mfix_pat     = noLoc $ LazyPat $ mkLHsPatTup rec_tup_pats
         body         = noLoc $ HsDo DoExpr (rec_stmts ++ [ret_stmt]) body_ty
-        ret_stmt     = noLoc $ LastStmt (mkLHsTupleExpr rets) return_op
-        tup_ty       = mkBoxedTupleTy (map idType tup_ids) -- Deals with singleton case
+        ret_app      = nlHsApp (noLoc return_op) (mkLHsTupleExpr rets)
+        ret_stmt     = noLoc $ mkLastStmt ret_app
+                    -- This LastStmt will be desugared with dsDo, 
+                    -- which ignores the return_op in the LastStmt,
+                    -- so we must apply the return_op explicitly 
 
 handle_failure :: LPat Id -> MatchResult -> SyntaxExpr Id -> DsM CoreExpr
     -- In a do expression, pattern-match failure just calls
index cf9c0d7..fba270c 100644 (file)
@@ -833,7 +833,8 @@ type Stmt id = StmtLR id id
 -- The SyntaxExprs in here are used *only* for do-notation and monad
 -- comprehensions, which have rebindable syntax. Otherwise they are unused.
 data StmtLR idL idR
-  = LastStmt  -- Always the last Stmt in ListComp, MonadComp, PArrComp, DoExpr, MDoExpr
+  = LastStmt  -- Always the last Stmt in ListComp, MonadComp, PArrComp, 
+             -- and (after the renamer) DoExpr, MDoExpr
               -- Not used for GhciStmt, PatGuard, which scope over other stuff
                (LHsExpr idR)
                (SyntaxExpr idR)   -- The return operator, used only for MonadComp
@@ -1090,7 +1091,7 @@ instance (OutputableBndr idL, OutputableBndr idR) => Outputable (StmtLR idL idR)
     ppr stmt = pprStmt stmt
 
 pprStmt :: (OutputableBndr idL, OutputableBndr idR) => (StmtLR idL idR) -> SDoc
-pprStmt (LastStmt expr _)         = ppr expr
+pprStmt (LastStmt expr _)         = ifPprDebug (ptext (sLit "[last]")) <+> ppr expr
 pprStmt (BindStmt pat expr _ _)   = hsep [ppr pat, ptext (sLit "<-"), ppr expr]
 pprStmt (LetStmt binds)           = hsep [ptext (sLit "let"), pprBinds binds]
 pprStmt (ExprStmt expr _ _ _)     = ppr expr
@@ -1354,8 +1355,8 @@ pprAStmtContext ctxt = article <+> pprStmtContext ctxt
 
 -----------------
 pprStmtContext GhciStmt        = ptext (sLit "interactive GHCi command")
-pprStmtContext DoExpr          = ptext (sLit "'do' expression")
-pprStmtContext MDoExpr         = ptext (sLit "'mdo' expression")
+pprStmtContext DoExpr          = ptext (sLit "'do' block")
+pprStmtContext MDoExpr         = ptext (sLit "'mdo' block")
 pprStmtContext ListComp        = ptext (sLit "list comprehension")
 pprStmtContext MonadComp       = ptext (sLit "monad comprehension")
 pprStmtContext PArrComp        = ptext (sLit "array comprehension")
@@ -1402,8 +1403,13 @@ pprMatchInCtxt ctxt match  = hang (ptext (sLit "In") <+> pprMatchContext ctxt <>
 
 pprStmtInCtxt :: (OutputableBndr idL, OutputableBndr idR)
               => HsStmtContext idL -> StmtLR idL idR -> SDoc
-pprStmtInCtxt ctxt stmt = hang (ptext (sLit "In a stmt of") <+> pprAStmtContext ctxt <> colon)
-                         4 (ppr_stmt stmt)
+pprStmtInCtxt ctxt (LastStmt e _)
+  | isListCompExpr ctxt      -- For [ e | .. ], do not mutter about "stmts"
+  = hang (ptext (sLit "In the expression:")) 2 (ppr e)
+
+pprStmtInCtxt ctxt stmt 
+  = hang (ptext (sLit "In a stmt of") <+> pprAStmtContext ctxt <> colon)
+       2 (ppr_stmt stmt)
   where
     -- For Group and Transform Stmts, don't print the nested stmts!
     ppr_stmt (GroupStmt { grpS_by = by, grpS_using = using
index de883f2..51a0de3 100644 (file)
@@ -21,7 +21,7 @@ module HsUtils(
   mkMatchGroup, mkMatch, mkHsLam, mkHsIf,
   mkHsWrap, mkLHsWrap, mkHsWrapCoI, mkLHsWrapCoI,
   coiToHsWrapper, mkHsLams, mkHsDictLet,
-  mkHsOpApp, mkHsDo, mkHsComp, mkHsWrapPat, mkHsWrapPatCoI, mkDoStmts,
+  mkHsOpApp, mkHsDo, mkHsComp, mkHsWrapPat, mkHsWrapPatCoI, 
 
   nlHsTyApp, nlHsVar, nlHsLit, nlHsApp, nlHsApps, nlHsIntLit, nlHsVarApps, 
   nlHsDo, nlHsOpApp, nlHsLam, nlHsPar, nlHsIf, nlHsCase, nlList,
@@ -192,7 +192,6 @@ mkHsFractional :: Rational -> PostTcType -> HsOverLit id
 mkHsIsString   :: FastString -> PostTcType -> HsOverLit id
 mkHsDo         :: HsStmtContext Name -> [LStmt id] -> HsExpr id
 mkHsComp       :: HsStmtContext Name -> [LStmt id] -> LHsExpr id -> HsExpr id
-mkDoStmts      :: [LStmt id] -> [LStmt id] 
 
 mkNPat      :: HsOverLit id -> Maybe (SyntaxExpr id) -> Pat id
 mkNPlusKPat :: Located id -> HsOverLit id -> Pat id
@@ -215,11 +214,6 @@ mkHsIsString   s       = OverLit (HsIsString   s)  noRebindableInfo noSyntaxExpr
 noRebindableInfo :: Bool
 noRebindableInfo = error "noRebindableInfo"    -- Just another placeholder; 
 
--- mkDoStmts turns a trailing ExprStmt into a LastStmt
-mkDoStmts [L loc (ExprStmt e _ _ _)] = [L loc (mkLastStmt e)]
-mkDoStmts (s:ss)                    = s : mkDoStmts ss
-mkDoStmts []                        = []
-
 mkHsDo ctxt stmts = HsDo ctxt stmts placeHolderType
 mkHsComp ctxt stmts expr = mkHsDo ctxt (stmts ++ [last_stmt])
   where
index c42ea0c..aa20ea6 100644 (file)
@@ -1602,7 +1602,7 @@ apats  :: { [LPat RdrName] }
 -- Statement sequences
 
 stmtlist :: { Located [LStmt RdrName] }
-       : '{'           stmts '}'       { LL (mkDoStmts (unLoc $2)) }
+       : '{'           stmts '}'       { LL (unLoc $2) }
        |     vocurly   stmts close     { $2 }
 
 --     do { ;; s ; s ; ; s ;; }
index d1dd222..11d44e3 100644 (file)
@@ -648,32 +648,22 @@ rnStmts MDoExpr stmts thing_inside    -- Deal with mdo
   = -- Behave like do { rec { ...all but last... }; last }
     do { ((stmts1, (stmts2, thing)), fvs) 
           <- rnStmt MDoExpr (noLoc $ mkRecStmt all_but_last) $ \ _ ->
-             do { checkStmt MDoExpr True last_stmt
-                ; rnStmt MDoExpr last_stmt thing_inside }
+             do { last_stmt' <- checkLastStmt MDoExpr last_stmt
+                ; rnStmt MDoExpr last_stmt' thing_inside }
        ; return (((stmts1 ++ stmts2), thing), fvs) }
   where
     Just (all_but_last, last_stmt) = snocView stmts
 
-rnStmts ctxt (lstmt@(L loc stmt) : lstmts) thing_inside
+rnStmts ctxt (lstmt@(L loc _) : lstmts) thing_inside
   | null lstmts
   = setSrcSpan loc $
-    do { -- Turn a final ExprStmt into a LastStmt
-         -- This is the first place it's convenient to do this
-        -- (In principle the parser could do it, but it's 
-        --  just not very convenient to do so.)
-         let stmt' | okEmpty ctxt 
-                   = lstmt
-                   | otherwise    
-                   = case stmt of 
-                       ExprStmt e _ _ _ -> L loc (mkLastStmt e)
-                      _                -> lstmt
-       ; checkStmt ctxt True {- last stmt -} stmt'
-       ; rnStmt ctxt stmt' thing_inside }
+    do { lstmt' <- checkLastStmt ctxt lstmt
+       ; rnStmt ctxt lstmt' thing_inside }
 
   | otherwise
   = do { ((stmts1, (stmts2, thing)), fvs) 
             <- setSrcSpan loc                         $
-               do { checkStmt ctxt False {- Not last -} lstmt
+               do { checkStmt ctxt lstmt
                   ; rnStmt ctxt lstmt    $ \ bndrs1 ->
                     rnStmts ctxt lstmts  $ \ bndrs2 ->
                     thing_inside (bndrs1 ++ bndrs2) }
@@ -1211,7 +1201,7 @@ checkEmptyStmts :: HsStmtContext Name -> RnM ()
 checkEmptyStmts ctxt 
   = unless (okEmpty ctxt) (addErr (emptyErr ctxt))
 
-okEmpty :: HsStmtContext Name -> Bool
+okEmpty :: HsStmtContext a -> Bool
 okEmpty (PatGuard {}) = True
 okEmpty _             = False
 
@@ -1221,14 +1211,42 @@ emptyErr (TransformStmtCtxt {}) = ptext (sLit "Empty statement group preceding '
 emptyErr ctxt                   = ptext (sLit "Empty") <+> pprStmtContext ctxt
 
 ---------------------- 
+checkLastStmt :: HsStmtContext Name
+              -> LStmt RdrName 
+              -> RnM (LStmt RdrName)
+checkLastStmt ctxt lstmt@(L loc stmt)
+  = case ctxt of 
+      ListComp  -> check_comp
+      MonadComp -> check_comp
+      PArrComp  -> check_comp
+      DoExpr   -> check_do
+      MDoExpr   -> check_do
+      _         -> check_other
+  where
+    check_do   -- Expect ExprStmt, and change it to LastStmt
+      = case stmt of 
+          ExprStmt e _ _ _ -> return (L loc (mkLastStmt e))
+          LastStmt {}      -> return lstmt   -- "Deriving" clauses may generate a
+                                            -- LastStmt directly (unlike the parser)
+         _                -> do { addErr (hang last_error 2 (ppr stmt)); return lstmt }
+    last_error = (ptext (sLit "The last statement in") <+> pprAStmtContext ctxt
+                  <+> ptext (sLit "must be an expression"))
+
+    check_comp -- Expect LastStmt; this should be enforced by the parser!
+      = case stmt of 
+          LastStmt {} -> return lstmt
+          _           -> pprPanic "checkLastStmt" (ppr lstmt)
+
+    check_other        -- Behave just as if this wasn't the last stmt
+      = do { checkStmt ctxt lstmt; return lstmt }
+
 -- Checking when a particular Stmt is ok
 checkStmt :: HsStmtContext Name
-          -> Bool                      -- True <=> this is the last Stmt in the sequence
           -> LStmt RdrName 
           -> RnM ()
-checkStmt ctxt is_last (L _ stmt)
+checkStmt ctxt (L _ stmt)
   = do { dflags <- getDOpts
-       ; case okStmt dflags ctxt is_last stmt of 
+       ; case okStmt dflags ctxt stmt of 
            Nothing    -> return ()
            Just extra -> addErr (msg $$ extra) }
   where
@@ -1250,42 +1268,32 @@ isOK, notOK :: Maybe SDoc
 isOK  = Nothing
 notOK = Just empty
 
-okStmt, okDoStmt, okCompStmt :: DynFlags -> HsStmtContext Name -> Bool 
+okStmt, okDoStmt, okCompStmt :: DynFlags -> HsStmtContext Name
                              -> Stmt RdrName -> Maybe SDoc
 -- Return Nothing if OK, (Just extra) if not ok
 -- The "extra" is an SDoc that is appended to an generic error message
-okStmt _ (PatGuard {}) _ stmt
+okStmt _ (PatGuard {}) stmt
   = case stmt of
       ExprStmt {} -> isOK
       BindStmt {} -> isOK
       LetStmt {}  -> isOK
       _           -> notOK
 
-okStmt dflags (ParStmtCtxt ctxt) _ stmt
+okStmt dflags (ParStmtCtxt ctxt) stmt
   = case stmt of
       LetStmt (HsIPBinds {}) -> notOK
-      _                      -> okStmt dflags ctxt False stmt
-                               -- NB: is_last=False in recursive
-                               -- call; the branches of of a Par
-                               -- not finish with a LastStmt
+      _                      -> okStmt dflags ctxt stmt
 
-okStmt dflags (TransformStmtCtxt ctxt) _ stmt 
-  = okStmt dflags ctxt False stmt
+okStmt dflags (TransformStmtCtxt ctxt) stmt 
+  = okStmt dflags ctxt stmt
 
-okStmt dflags ctxt is_last stmt 
-  | isDoExpr       ctxt = okDoStmt   dflags ctxt is_last stmt
-  | isListCompExpr ctxt = okCompStmt dflags ctxt is_last stmt
+okStmt dflags ctxt stmt 
+  | isDoExpr       ctxt = okDoStmt   dflags ctxt stmt
+  | isListCompExpr ctxt = okCompStmt dflags ctxt stmt
   | otherwise           = pprPanic "okStmt" (pprStmtContext ctxt)
 
 ----------------
-okDoStmt dflags ctxt is_last stmt
-  | is_last
-  = case stmt of 
-      LastStmt {} -> isOK
-      _ -> Just (ptext (sLit "The last statement in") <+> pprAStmtContext ctxt
-                 <+> ptext (sLit "must be an expression"))
-
-  | otherwise
+okDoStmt dflags _ stmt
   = case stmt of
        RecStmt {} 
          | Opt_DoRec `xopt` dflags -> isOK
@@ -1297,13 +1305,7 @@ okDoStmt dflags ctxt is_last stmt
 
 
 ----------------
-okCompStmt dflags _ is_last stmt
-  | is_last
-  = case stmt of
-      LastStmt {} -> Nothing
-      _ -> pprPanic "Unexpected stmt" (ppr stmt)  -- Not a user error
-
-  | otherwise
+okCompStmt dflags _ stmt
   = case stmt of
        BindStmt {} -> isOK
        LetStmt {}  -> isOK
index 820e517..87449b6 100644 (file)
@@ -358,7 +358,7 @@ tcLcStmt :: TyCon   -- The list/Parray type constructor ([] or PArray)
         -> TcStmtChecker
 
 tcLcStmt _ _ (LastStmt body _) elt_ty thing_inside
-  = do { body' <- tcMonoExpr body elt_ty
+  = do { body' <- tcMonoExprNC body elt_ty
        ; thing <- thing_inside (panic "tcLcStmt: thing_inside")
        ; return (LastStmt body' noSyntaxExpr, thing) }
 
@@ -502,7 +502,7 @@ tcMcStmt _ (LastStmt body return_op) res_ty thing_inside
   = do  { a_ty       <- newFlexiTyVarTy liftedTypeKind
         ; return_op' <- tcSyntaxOp MCompOrigin return_op
                                    (a_ty `mkFunTy` res_ty)
-        ; body'      <- tcMonoExpr body a_ty
+        ; body'      <- tcMonoExprNC body a_ty
         ; thing      <- thing_inside (panic "tcMcStmt: thing_inside")
         ; return (LastStmt body' return_op', thing) } 
 
@@ -558,15 +558,20 @@ tcMcStmt _ (ExprStmt rhs then_op guard_op _) res_ty thing_inside
 --   [ body | stmts, then f by e ]  ->  f :: forall a. (a -> t) -> m a -> m a
 --
 tcMcStmt ctxt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op) res_ty thing_inside
-  = do  {
-        -- We don't know the types of binders yet, so we use this dummy and
-        -- later unify this type with the `m_bndr_ty`
-          ty_dummy <- newFlexiTyVarTy liftedTypeKind
+  = do { let star_star_kind = liftedTypeKind `mkArrowKind` liftedTypeKind
+       ; m1_ty      <- newFlexiTyVarTy star_star_kind
+       ; m2_ty      <- newFlexiTyVarTy star_star_kind
+       ; n_ty       <- newFlexiTyVarTy star_star_kind
+       ; tup_ty_var <- newFlexiTyVarTy liftedTypeKind
+       ; new_res_ty <- newFlexiTyVarTy liftedTypeKind
+       ; let m1_tup_ty = m1_ty `mkAppTy` tup_ty_var
 
+            -- 'stmts' returns a result of type (m1_ty tuple_ty),
+            -- typically something like [(Int,Bool,Int)]
+            -- We don't know what tuple_ty is yet, so we use a variable
         ; (stmts', (binders', usingExpr', maybeByExpr', return_op', bind_op', thing)) <- 
-              tcStmtsAndThen (TransformStmtCtxt ctxt) tcMcStmt stmts ty_dummy $ \res_ty' -> do
-                  { (_, (m_ty, _)) <- matchExpectedAppTy res_ty'
-                  ; (usingExpr', maybeByExpr') <- 
+              tcStmtsAndThen (TransformStmtCtxt ctxt) tcMcStmt stmts m1_tup_ty $ \res_ty' -> do
+                  { (usingExpr', maybeByExpr') <- 
                         case maybeByExpr of
                             Nothing -> do
                                 -- We must validate that usingExpr :: forall a. m a -> m a
@@ -671,8 +676,8 @@ tcMcStmt ctxt (GroupStmt { grpS_stmts = stmts, grpS_bndrs = bindersMap
              using_res_ty = m2_ty `mkAppTy` n_tup_ty   -- m2 (n (a,b,c))
             using_fun_ty = using_arg_ty `mkFunTy` using_arg_ty
               
-                -- (>>=) :: m2 (n (a,b,c)) -> ( n (a,b,c) -> new_res_ty ) -> res_ty
-                -- using :: ((a,b,c)->t) -> m1 (a,b,c) -> m2 (n (a,b,c))
+          -- (>>=) :: m2 (n (a,b,c)) -> ( n (a,b,c) -> new_res_ty ) -> res_ty
+          -- using :: ((a,b,c)->t) -> m1 (a,b,c) -> m2 (n (a,b,c))
 
        --------------- Typecheck the 'bind' function -------------
        ; bind_op' <- tcSyntaxOp MCompOrigin bind_op $