[project @ 2002-10-11 08:47:12 by simonpj]
authorsimonpj <unknown>
Fri, 11 Oct 2002 08:47:13 +0000 (08:47 +0000)
committersimonpj <unknown>
Fri, 11 Oct 2002 08:47:13 +0000 (08:47 +0000)
Fix mdo so that it works with polymorphic functions

ghc/compiler/typecheck/TcHsSyn.lhs
ghc/compiler/typecheck/TcMatches.lhs

index 386f4eb..88b745d 100644 (file)
@@ -617,56 +617,63 @@ zonkArithSeq env (FromThenTo e1 e2 e3)
 
 
 -------------------------------------------------------------------------
-zonkStmts :: ZonkEnv -> [TcStmt] -> TcM [TypecheckedStmt]
+zonkStmts  :: ZonkEnv -> [TcStmt] -> TcM [TypecheckedStmt]
 
-zonkStmts env [] = returnM []
+zonkStmts env stmts = zonk_stmts env stmts     `thenM` \ (_, stmts) ->
+                     returnM stmts
 
-zonkStmts env (ParStmtOut bndrstmtss : stmts)
+zonk_stmts :: ZonkEnv -> [TcStmt] -> TcM (ZonkEnv, [TypecheckedStmt])
+
+zonk_stmts env [] = returnM (env, [])
+
+zonk_stmts env (ParStmtOut bndrstmtss : stmts)
   = mappM (mappM zonkId) bndrss                `thenM` \ new_bndrss ->
     mappM (zonkStmts env) stmtss       `thenM` \ new_stmtss ->
     let 
        new_binders = concat new_bndrss
        env1 = extendZonkEnv env new_binders
     in
-    zonkStmts env1 stmts               `thenM` \ new_stmts ->
-    returnM (ParStmtOut (zip new_bndrss new_stmtss) : new_stmts)
+    zonk_stmts env1 stmts              `thenM` \ (env2, new_stmts) ->
+    returnM (env2, ParStmtOut (zip new_bndrss new_stmtss) : new_stmts)
   where
     (bndrss, stmtss) = unzip bndrstmtss
 
-zonkStmts env (RecStmt vs segStmts rets : stmts)
+zonk_stmts env (RecStmt vs segStmts rets : stmts)
   = mappM zonkId vs            `thenM` \ new_vs ->
     let
        env1 = extendZonkEnv env new_vs
     in
-    zonkStmts env1 segStmts    `thenM` \ new_segStmts ->
-    zonkExprs env1 rets                `thenM` \ new_rets ->
-    zonkStmts env1 stmts       `thenM` \ new_stmts ->
-    returnM (RecStmt new_vs new_segStmts new_rets : new_stmts)
-
-zonkStmts env (ResultStmt expr locn : stmts)
-  = zonkExpr env expr  `thenM` \ new_expr ->
-    zonkStmts env stmts        `thenM` \ new_stmts ->
-    returnM (ResultStmt new_expr locn : new_stmts)
+    zonk_stmts env1 segStmts   `thenM` \ (env2, new_segStmts) ->
+       -- Zonk the ret-expressions in an envt that 
+       -- has the polymorphic bindings in the envt
+    zonkExprs env2 rets                `thenM` \ new_rets ->
+    zonk_stmts env1 stmts      `thenM` \ (env3, new_stmts) ->
+    returnM (env3, RecStmt new_vs new_segStmts new_rets : new_stmts)
+
+zonk_stmts env (ResultStmt expr locn : stmts)
+  = ASSERT( null stmts )
+    zonkExpr env expr  `thenM` \ new_expr ->
+    returnM (env, [ResultStmt new_expr locn])
 
-zonkStmts env (ExprStmt expr ty locn : stmts)
+zonk_stmts env (ExprStmt expr ty locn : stmts)
   = zonkExpr env expr          `thenM` \ new_expr ->
     zonkTcTypeToType env ty    `thenM` \ new_ty ->
-    zonkStmts env stmts                `thenM` \ new_stmts ->
-    returnM (ExprStmt new_expr new_ty locn : new_stmts)
+    zonk_stmts env stmts       `thenM` \ (env1, new_stmts) ->
+    returnM (env1, ExprStmt new_expr new_ty locn : new_stmts)
 
-zonkStmts env (LetStmt binds : stmts)
-  = zonkBinds env binds                `thenM` \ (new_env, new_binds) ->
-    zonkStmts new_env stmts    `thenM` \ new_stmts ->
-    returnM (LetStmt new_binds : new_stmts)
+zonk_stmts env (LetStmt binds : stmts)
+  = zonkBinds env binds                `thenM` \ (env1, new_binds) ->
+    zonk_stmts env1 stmts      `thenM` \ (env2, new_stmts) ->
+    returnM (env2, LetStmt new_binds : new_stmts)
 
-zonkStmts env (BindStmt pat expr locn : stmts)
+zonk_stmts env (BindStmt pat expr locn : stmts)
   = zonkExpr env expr                  `thenM` \ new_expr ->
     zonkPat env pat                    `thenM` \ (new_pat, new_ids) ->
     let
        env1 = extendZonkEnv env (bagToList new_ids)
     in
-    zonkStmts env1 stmts               `thenM` \ new_stmts ->
-    returnM (BindStmt new_pat new_expr locn : new_stmts)
+    zonk_stmts env1 stmts              `thenM` \ (env2, new_stmts) ->
+    returnM (env2, BindStmt new_pat new_expr locn : new_stmts)
 
 
 
index a1a5758..985cc46 100644 (file)
@@ -26,7 +26,7 @@ import TcHsSyn                ( TcMatch, TcGRHSs, TcStmt, TcDictBinds,
 import TcRnMonad
 import TcMonoType      ( tcAddScopedTyVars, tcHsSigType, UserTypeCtxt(..) )
 import Inst            ( tcSyntaxName )
-import TcEnv           ( TcId, tcLookupLocalIds, tcExtendLocalValEnv, tcExtendLocalValEnv2 )
+import TcEnv           ( TcId, tcLookupLocalIds, tcLookupId, tcExtendLocalValEnv, tcExtendLocalValEnv2 )
 import TcPat           ( tcPat, tcMonoPatBndr )
 import TcMType         ( newTyVarTy, newTyVarTys, zonkTcType, zapToType )
 import TcType          ( TcType, TcTyVar, tyVarsOfType, tidyOpenTypes, tidyOpenType,
@@ -460,25 +460,29 @@ tcStmtAndThen combine do_or_lc m_ty (ParStmtOut bndr_stmts_s) thing_inside
        -- RecStmt
 tcStmtAndThen combine do_or_lc m_ty (RecStmt recNames stmts _) thing_inside
   = newTyVarTys (length recNames) liftedTypeKind               `thenM` \ recTys ->
-    tcExtendLocalValEnv (zipWith mkLocalId recNames recTys)    $
+    let
+       mono_ids = zipWith mkLocalId recNames recTys
+    in
+    tcExtendLocalValEnv mono_ids                       $
     tcStmtsAndThen combine_rec do_or_lc m_ty stmts (
-       tcLookupLocalIds recNames  `thenM` \ rn ->
-       returnM ([], rn)
-    )                                                          `thenM` \ (stmts', recIds) ->
+       mappM tc_ret (recNames `zip` recTys)    `thenM` \ rets ->
+       returnM ([], rets)
+    )                                          `thenM` \ (stmts', rets) ->
 
-    -- Unify the types of the "final" Ids with those of "knot-tied" Ids
-    mappM tc_ret (recIds `zip` recTys)                 `thenM` \ rets' ->
-  
-    thing_inside                                       `thenM` \ thing ->
+       -- NB: it's the mono_ids that scope over this part
+    thing_inside                               `thenM` \ thing ->
   
-    returnM (combine (RecStmt recIds stmts' rets') thing)
+    returnM (combine (RecStmt mono_ids stmts' rets) thing)
   where 
     combine_rec stmt (stmts, thing) = (stmt:stmts, thing)
 
     -- Unify the types of the "final" Ids with those of "knot-tied" Ids
-    tc_ret (rec_id, rec_ty)
-       = tcSubExp rec_ty (idType rec_id)       `thenM` \ co_fn ->
-         returnM (co_fn <$> HsVar rec_id) 
+    tc_ret (rec_name, mono_ty)
+       = tcLookupId rec_name                   `thenM` \ poly_id ->
+               -- poly_id may have a polymorphic type
+               -- but mono_ty is just a monomorphic type variable
+         tcSubExp mono_ty (idType poly_id)     `thenM` \ co_fn ->
+         returnM (co_fn <$> HsVar poly_id) 
 
        -- ExprStmt
 tcStmtAndThen combine do_or_lc m_ty@(m, res_elt_ty) stmt@(ExprStmt exp _ locn) thing_inside