More hacking on monad-comp; now works
[ghc-hetmet.git] / compiler / typecheck / TcMatches.lhs
index 60bf7e2..820e517 100644 (file)
@@ -8,7 +8,7 @@ TcMatches: Typecheck some @Matches@
 \begin{code}
 module TcMatches ( tcMatchesFun, tcGRHSsPat, tcMatchesCase, tcMatchLambda,
                   TcMatchCtxt(..), 
-                  tcStmts, tcDoStmts, tcBody,
+                  tcStmts, tcStmtsAndThen, tcDoStmts, tcBody,
                   tcDoStmt, tcMDoStmt, tcGuardStmt
        ) where
 
@@ -224,7 +224,7 @@ tcGRHSs ctxt (GRHSs grhss binds) res_ty
 tcGRHS :: TcMatchCtxt -> TcRhoType -> GRHS Name -> TcM (GRHS TcId)
 
 tcGRHS ctxt res_ty (GRHS guards rhs)
-  = do  { (guards', rhs') <- tcStmts stmt_ctxt tcGuardStmt guards res_ty $
+  = do  { (guards', rhs') <- tcStmtsAndThen stmt_ctxt tcGuardStmt guards res_ty $
                             mc_body ctxt rhs
        ; return (GRHS guards' rhs') }
   where
@@ -245,7 +245,7 @@ tcDoStmts :: HsStmtContext Name
          -> TcM (HsExpr TcId)          -- Returns a HsDo
 tcDoStmts ListComp stmts res_ty
   = do { (coi, elt_ty) <- matchExpectedListTy res_ty
-       ; stmts' <- tcStmts ListComp (tcLcStmt listTyCon) stmts res_ty
+       ; stmts' <- tcStmts ListComp (tcLcStmt listTyCon) stmts elt_ty
        ; return $ mkHsWrapCoI coi 
                      (HsDo ListComp stmts' (mkListTy elt_ty)) }
 
@@ -267,7 +267,7 @@ tcDoStmts MonadComp stmts res_ty
   = do  { stmts' <- tcStmts MonadComp tcMcStmt stmts res_ty 
         ; return (HsDo MonadComp stmts' res_ty) }
 
-tcDoStmts ctxt _ _ _ _ = pprPanic "tcDoStmts" (pprStmtContext ctxt)
+tcDoStmts ctxt _ _ = pprPanic "tcDoStmts" (pprStmtContext ctxt)
 
 tcBody :: LHsExpr Name -> TcRhoType -> TcM (LHsExpr TcId)
 tcBody body res_ty
@@ -298,7 +298,7 @@ tcStmts :: HsStmtContext Name
        -> TcRhoType
         -> TcM [LStmt TcId]
 tcStmts ctxt stmt_chk stmts res_ty
-  = do { (stmts', _) <- tcStmtsAndThen ctxt stmt_check stmts res_ty $
+  = do { (stmts', _) <- tcStmtsAndThen ctxt stmt_chk stmts res_ty $
                         const (return ())
        ; return stmts' }
 
@@ -357,9 +357,9 @@ tcGuardStmt _ stmt _ _
 tcLcStmt :: TyCon      -- The list/Parray type constructor ([] or PArray)
         -> TcStmtChecker
 
-tcLcStmt m_tc ctxt (LastStmt body _) elt_ty thing_inside
+tcLcStmt _ _ (LastStmt body _) elt_ty thing_inside
   = do { body' <- tcMonoExpr body elt_ty
-       ; thing <- thing_inside elt_ty
+       ; thing <- thing_inside (panic "tcLcStmt: thing_inside")
        ; return (LastStmt body' noSyntaxExpr, thing) }
 
 -- A generator, pat <- rhs
@@ -407,7 +407,7 @@ tcLcStmt m_tc ctxt (ParStmt bndr_stmts_s _ _ _) elt_ty thing_inside
 
     loop ((stmts, names) : pairs)
       = do { (stmts', (ids, pairs', thing))
-               <- tcStmts ctxt (tcLcStmt m_tc) stmts elt_ty $ \ _elt_ty' ->
+               <- tcStmtsAndThen ctxt (tcLcStmt m_tc) stmts elt_ty $ \ _elt_ty' ->
                   do { ids <- tcLookupLocalIds names
                      ; (pairs', thing) <- loop pairs
                      ; return (ids, pairs', thing) }
@@ -415,7 +415,7 @@ tcLcStmt m_tc ctxt (ParStmt bndr_stmts_s _ _ _) elt_ty thing_inside
 
 tcLcStmt m_tc ctxt (TransformStmt stmts binders usingExpr maybeByExpr _ _) elt_ty thing_inside = do
     (stmts', (binders', usingExpr', maybeByExpr', thing)) <- 
-        tcStmts (TransformStmtCtxt ctxt) (tcLcStmt m_tc) stmts elt_ty $ \elt_ty' -> do
+        tcStmtsAndThen (TransformStmtCtxt ctxt) (tcLcStmt m_tc) stmts elt_ty $ \elt_ty' -> do
             let alphaListTy = mkTyConApp m_tc [alphaTy]
                     
             (usingExpr', maybeByExpr') <- 
@@ -442,11 +442,13 @@ tcLcStmt m_tc ctxt (TransformStmt stmts binders usingExpr maybeByExpr _ _) elt_t
 
     return (TransformStmt stmts' binders' usingExpr' maybeByExpr' noSyntaxExpr noSyntaxExpr, thing)
 
-tcLcStmt m_tc ctxt (GroupStmt stmts bindersMap by using _ _ _) elt_ty thing_inside
+tcLcStmt m_tc ctxt (GroupStmt { grpS_stmts = stmts, grpS_bndrs =  bindersMap
+                              , grpS_by = by, grpS_using = using
+                              , grpS_explicit = explicit }) elt_ty thing_inside
   = do { let (bndr_names, list_bndr_names) = unzip bindersMap
 
        ; (stmts', (bndr_ids, by', using_ty, elt_ty')) <-
-            tcStmts (TransformStmtCtxt ctxt) (tcLcStmt m_tc) stmts elt_ty $ \elt_ty' -> do
+            tcStmtsAndThen (TransformStmtCtxt ctxt) (tcLcStmt m_tc) stmts elt_ty $ \elt_ty' -> do
                (by', using_ty) <- 
                    case by of
                      Nothing   -> -- check that using :: forall a. [a] -> [[a]]
@@ -471,14 +473,14 @@ tcLcStmt m_tc ctxt (GroupStmt stmts bindersMap by using _ _ _) elt_ty thing_insi
              bindersMap' = bndr_ids `zip` list_bndr_ids
             -- See Note [GroupStmt binder map] in HsExpr
             
-       ; using' <- case using of
-                     Left  e -> do { e' <- tcPolyExpr e         using_ty; return (Left  e') }
-                     Right e -> do { e' <- tcPolyExpr (noLoc e) using_ty; return (Right (unLoc e')) }
+       ; using' <- tcPolyExpr using using_ty
 
              -- Type check the thing in the environment with 
             -- these new binders and return the result
        ; thing <- tcExtendIdEnv list_bndr_ids (thing_inside elt_ty')
-       ; return (GroupStmt stmts' bindersMap' by' using' noSyntaxExpr noSyntaxExpr noSyntaxExpr, thing) }
+       ; return (emptyGroupStmt { grpS_stmts = stmts', grpS_bndrs = bindersMap'
+                                , grpS_by = by', grpS_using = using'
+                                , grpS_explicit = explicit }, thing) }
   where
     alphaListTy = mkTyConApp m_tc [alphaTy]
     alphaListListTy = mkTyConApp m_tc [alphaListTy]
@@ -496,12 +498,13 @@ tcLcStmt _ _ stmt _ _
 
 tcMcStmt :: TcStmtChecker
 
-tcMcStmt ctxt (LastStmt body return_op) res_ty thing_inside
+tcMcStmt _ (LastStmt body return_op) res_ty thing_inside
   = do  { a_ty       <- newFlexiTyVarTy liftedTypeKind
         ; return_op' <- tcSyntaxOp MCompOrigin return_op
                                    (a_ty `mkFunTy` res_ty)
         ; body'      <- tcMonoExpr body a_ty
-        ; return (body', return_op') } 
+        ; thing      <- thing_inside (panic "tcMcStmt: thing_inside")
+        ; return (LastStmt body' return_op', thing) } 
 
 -- Generators for monad comprehensions ( pat <- rhs )
 --
@@ -561,7 +564,7 @@ tcMcStmt ctxt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_
           ty_dummy <- newFlexiTyVarTy liftedTypeKind
 
         ; (stmts', (binders', usingExpr', maybeByExpr', return_op', bind_op', thing)) <- 
-              tcStmts (TransformStmtCtxt ctxt) tcMcStmt stmts ty_dummy $ \res_ty' -> do
+              tcStmtsAndThen (TransformStmtCtxt ctxt) tcMcStmt stmts ty_dummy $ \res_ty' -> do
                   { (_, (m_ty, _)) <- matchExpectedAppTy res_ty'
                   ; (usingExpr', maybeByExpr') <- 
                         case maybeByExpr of
@@ -627,10 +630,14 @@ tcMcStmt ctxt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_
 --   [ body | stmts, then group using f ]
 --     ->  f :: forall a. m a -> m (m a)
 --
-tcMcStmt ctxt (GroupStmt stmts bindersMap by using return_op bind_op fmap_op) res_ty thing_inside
-  = do { m1_ty      <- newFlexiTyVarTy liftedTypeKind
-       ; m2_ty      <- newFlexiTyVarTy liftedTypeKind
-       ; n_ty       <- newFlexiTyVarTy liftedTypeKind
+tcMcStmt ctxt (GroupStmt { grpS_stmts = stmts, grpS_bndrs = bindersMap
+                         , grpS_by = by, grpS_using = using, grpS_explicit = explicit
+                         , grpS_ret = return_op, grpS_bind = bind_op 
+                         , grpS_fmap = fmap_op }) res_ty thing_inside
+  = do { let star_star_kind = liftedTypeKind `mkArrowKind` liftedTypeKind
+       ; m1_ty      <- newFlexiTyVarTy star_star_kind
+       ; m2_ty      <- newFlexiTyVarTy star_star_kind
+       ; n_ty       <- newFlexiTyVarTy star_star_kind
        ; tup_ty_var <- newFlexiTyVarTy liftedTypeKind
        ; new_res_ty <- newFlexiTyVarTy liftedTypeKind
        ; let (bndr_names, n_bndr_names) = unzip bindersMap
@@ -640,8 +647,10 @@ tcMcStmt ctxt (GroupStmt stmts bindersMap by using return_op bind_op fmap_op) re
             -- typically something like [(Int,Bool,Int)]
             -- We don't know what tuple_ty is yet, so we use a variable
        ; (stmts', (bndr_ids, by_e_ty, return_op')) <-
-            tcStmts (TransformStmtCtxt ctxt) tcMcStmt stmts m1_tup_ty $ \res_ty' -> do
-               { by_e_ty <- mapM tcInferRhoNC by_e
+            tcStmtsAndThen (TransformStmtCtxt ctxt) tcMcStmt stmts m1_tup_ty $ \res_ty' -> do
+               { by_e_ty <- case by of
+                               Nothing -> return Nothing
+                               Just e  -> do { e_ty <- tcInferRhoNC e; return (Just e_ty) }
 
                 -- Find the Ids (and hence types) of all old binders
                 ; bndr_ids <- tcLookupLocalIds bndr_names
@@ -671,40 +680,34 @@ tcMcStmt ctxt (GroupStmt stmts bindersMap by using return_op bind_op fmap_op) re
                                              `mkFunTy` res_ty
 
        --------------- Typecheck the 'using' function -------------
-       ; let using_fun_ty = (m1_ty `mkAppTy` alphaTy) `mkFunTy` 
+       ; let poly_fun_ty = (m1_ty `mkAppTy` alphaTy) `mkFunTy` 
                                      (m2_ty `mkAppTy` (n_ty `mkAppTy` alphaTy))
              using_poly_ty = case by_e_ty of
-               Nothing       -> mkForAllTy alphaTyVar using_fun_ty
+               Nothing       -> mkForAllTy alphaTyVar poly_fun_ty
                                 -- using :: forall a. m1 a -> m2 (n a)
 
               Just (_,t_ty) -> mkForAllTy alphaTyVar $
-                                (alphaTy `mkFunTy` t_ty) `mkFunTy` using_fun_ty
+                                (alphaTy `mkFunTy` t_ty) `mkFunTy` poly_fun_ty
                                 -- using :: forall a. (a->t) -> m1 a -> m2 (n a)
                                -- where by :: t
 
-       ; using' <- case using of
-                     Left  e -> do { e' <- tcPolyExpr e         using_poly_ty
-                                   ; return (Left  e') }
-                     Right e -> do { e' <- tcPolyExpr (noLoc e) using_poly_ty
-                                   ; return (Right (unLoc e')) }
+       ; using' <- tcPolyExpr using using_poly_ty
        ; coi <- unifyType (applyTy using_poly_ty tup_ty)
                           (case by_e_ty of
                              Nothing       -> using_fun_ty
                             Just (_,t_ty) -> (tup_ty `mkFunTy` t_ty) `mkFunTy` using_fun_ty)
-       ; let final_using = mkHsWrapCoI coi (HsWrap (WpTyApp tup_ty) using') 
+       ; let final_using = fmap (mkHsWrapCoI coi . HsWrap (WpTyApp tup_ty)) using' 
 
        --------------- Typecheck the 'fmap' function -------------
        ; fmap_op' <- fmap unLoc . tcPolyExpr (noLoc fmap_op) $
                          mkForAllTy alphaTyVar $ mkForAllTy betaTyVar $
-                             (alphaTy `mkFunTy` betaTy)
-                             `mkFunTy`
-                             (m_ty `mkAppTy` alphaTy)
-                             `mkFunTy`
-                             (m_ty `mkAppTy` betaTy)
+                         (alphaTy `mkFunTy` betaTy)
+                         `mkFunTy` (n_ty `mkAppTy` alphaTy)
+                         `mkFunTy` (n_ty `mkAppTy` betaTy)
 
        ; let mk_n_bndr :: Name -> TcId -> TcId
              mk_n_bndr n_bndr_name bndr_id 
-                = mkLocalId bndr_name (n_ty `mkAppTy` idType bndr_id)
+                = mkLocalId n_bndr_name (n_ty `mkAppTy` idType bndr_id)
 
              -- Ensure that every old binder of type `b` is linked up with its
              -- new binder which should have type `n b`
@@ -716,9 +719,10 @@ tcMcStmt ctxt (GroupStmt stmts bindersMap by using return_op bind_op fmap_op) re
        -- return the result
        ; thing <- tcExtendIdEnv n_bndr_ids (thing_inside res_ty)
 
-       ; return (GroupStmt stmts' bindersMap' 
-                           (fmap fst by_e_ty) final_using 
-                           return_op' bind_op' fmap_op', thing) }
+       ; return (GroupStmt { grpS_stmts = stmts', grpS_bndrs = bindersMap' 
+                           , grpS_by = fmap fst by_e_ty, grpS_using = final_using 
+                           , grpS_ret = return_op', grpS_bind = bind_op'
+                           , grpS_fmap = fmap_op', grpS_explicit = explicit }, thing) }
 
 -- Typecheck `ParStmt`. See `tcLcStmt` for more informations about typechecking
 -- of `ParStmt`s.
@@ -733,6 +737,8 @@ tcMcStmt ctxt (GroupStmt stmts bindersMap by using return_op bind_op fmap_op) re
 --
 tcMcStmt ctxt (ParStmt bndr_stmts_s mzip_op bind_op return_op) res_ty thing_inside
   = do { (_,(m_ty,_)) <- matchExpectedAppTy res_ty
+       -- ToDo: what if the coercion isn't the identity?
+
         ; (pairs', thing) <- loop m_ty bndr_stmts_s
 
         ; let mzip_ty  = mkForAllTys [alphaTyVar, betaTyVar] $
@@ -757,12 +763,10 @@ tcMcStmt ctxt (ParStmt bndr_stmts_s mzip_op bind_op return_op) res_ty thing_insi
         ; return_op' <- fmap unLoc . tcPolyExpr (noLoc return_op) $
                             mkForAllTy alphaTyVar $
                             alphaTy `mkFunTy` (m_ty `mkAppTy` alphaTy)
-                  ; return_op' <- tcSyntaxOp MCompOrigin return_op
-                                      (bndr_ty `mkFunTy` m_bndr_ty)
 
         ; return (ParStmt pairs' mzip_op' bind_op' return_op', thing) }
 
- where mk_tuple_ty tys = foldr (\tn tm -> mkBoxedTupleTy [tn, tm]) (last tys) (init tys)
+ where mk_tuple_ty tys = foldr1 (\tn tm -> mkBoxedTupleTy [tn, tm]) tys
 
        -- loop :: Type                                  -- m_ty
        --      -> [([LStmt Name], [Name])]
@@ -774,7 +778,7 @@ tcMcStmt ctxt (ParStmt bndr_stmts_s mzip_op bind_op return_op) res_ty thing_insi
          = do { -- type dummy since we don't know all binder types yet
                 ty_dummy <- newFlexiTyVarTy liftedTypeKind
               ; (stmts', (ids, pairs', thing))
-                   <- tcStmts ctxt tcMcStmt stmts ty_dummy $ \res_ty' ->
+                   <- tcStmtsAndThen ctxt tcMcStmt stmts ty_dummy $ \res_ty' ->
                       do { ids <- tcLookupLocalIds names
                          ; _ <- unifyType res_ty' (m_ty `mkAppTy` mkBigCoreVarTupTy ids)
                          ; (pairs', thing) <- loop m_ty pairs
@@ -790,9 +794,9 @@ tcMcStmt _ stmt _ _
 
 tcDoStmt :: TcStmtChecker
 
-tcDoStmt ctxt (LastStmt body _) res_ty thing_inside
-  = do { body' <- tcMonoExpr body res_ty
-       ; thing <- thing_inside body_ty
+tcDoStmt _ (LastStmt body _) res_ty thing_inside
+  = do { body' <- tcMonoExprNC body res_ty
+       ; thing <- thing_inside (panic "tcDoStmt: thing_inside")
        ; return (LastStmt body' noSyntaxExpr, thing) }
 
 tcDoStmt ctxt (BindStmt pat rhs bind_op fail_op) res_ty thing_inside
@@ -849,7 +853,7 @@ tcDoStmt ctxt (RecStmt { recS_stmts = stmts, recS_later_ids = later_names
         ; tcExtendIdEnv tup_ids $ do
         { stmts_ty <- newFlexiTyVarTy liftedTypeKind
         ; (stmts', (ret_op', tup_rets))
-                <- tcStmts ctxt tcDoStmt stmts stmts_ty   $ \ inner_res_ty ->
+                <- tcStmtsAndThen ctxt tcDoStmt stmts stmts_ty   $ \ inner_res_ty ->
                    do { tup_rets <- zipWithM tcCheckId tup_names tup_elt_tys
                              -- Unify the types of the "final" Ids (which may 
                              -- be polymorphic) with those of "knot-tied" Ids
@@ -916,9 +920,9 @@ tcMDoStmt tc_rhs ctxt (RecStmt { recS_stmts = stmts, recS_later_ids = laterNames
                                , recS_rec_ids = recNames }) res_ty thing_inside
   = do { rec_tys <- newFlexiTyVarTys (length recNames) liftedTypeKind
        ; let rec_ids = zipWith mkLocalId recNames rec_tys
-       ; tcExtendIdEnv rec_ids                 $ do
+       ; tcExtendIdEnv rec_ids $ do
        { (stmts', (later_ids, rec_rets))
-               <- tcStmts ctxt (tcMDoStmt tc_rhs) stmts res_ty $ \ _res_ty' ->
+               <- tcStmtsAndThen ctxt (tcMDoStmt tc_rhs) stmts res_ty  $ \ _res_ty' ->
                        -- ToDo: res_ty not really right
                   do { rec_rets <- zipWithM tcCheckId recNames rec_tys
                      ; later_ids <- tcLookupLocalIds laterNames
@@ -930,12 +934,13 @@ tcMDoStmt tc_rhs ctxt (RecStmt { recS_stmts = stmts, recS_later_ids = laterNames
                --      some of them with polymorphic things with the same Name
                --      (see note [RecStmt] in HsExpr)
 
-        ; return (RecStmt stmts' later_ids rec_ids noSyntaxExpr noSyntaxExpr noSyntaxExpr rec_rets, thing)
+        ; return (emptyRecStmt { recS_stmts = stmts', recS_later_ids = later_ids
+                               , recS_rec_ids = rec_ids, recS_rec_rets = rec_rets
+                               , recS_ret_ty = res_ty }, thing)
        }}
 
 tcMDoStmt _ _ stmt _ _
   = pprPanic "tcMDoStmt: unexpected Stmt" (ppr stmt)
-
 \end{code}