[project @ 1998-12-02 13:17:09 by simonm]
[ghc-hetmet.git] / ghc / compiler / typecheck / TcGRHSs.lhs
index 9dd435a..ce685fa 100644 (file)
@@ -1,30 +1,34 @@
 %
-% (c) The GRASP/AQUA Project, Glasgow University, 1992-1995
+% (c) The GRASP/AQUA Project, Glasgow University, 1992-1998
 %
 \section[TcGRHSs]{Typecheck guarded right-hand-sides}
 
 \begin{code}
-module TcGRHSs ( tcGRHSsAndBinds, tcStmt ) where
+module TcGRHSs ( tcGRHSsAndBinds, tcStmts ) where
 
 #include "HsVersions.h"
 
 import {-# SOURCE #-}  TcExpr( tcExpr )
 
-import HsSyn           ( HsBinds(..), GRHSsAndBinds(..), GRHS(..), DoOrListComp(..), 
-                         Stmt(..),
-                         collectPatBinders
+import HsSyn           ( HsBinds(..), GRHSsAndBinds(..), GRHS(..), StmtCtxt(..), 
+                         Stmt(..)
                        )
 import RnHsSyn         ( RenamedGRHSsAndBinds, RenamedGRHS, RenamedStmt )
 import TcHsSyn         ( TcGRHSsAndBinds, TcGRHS, TcStmt )
 
+import TcEnv           ( tcExtendGlobalTyVars, tcExtendEnvWithPat )
 import TcMonad
-import Inst            ( Inst, LIE, plusLIE )
+import Inst            ( LIE, plusLIE )
 import TcBinds         ( tcBindsAndThen )
+import TcSimplify      ( tcSimplifyAndCheck )
 import TcPat           ( tcPat )
+import TcMonoType      ( checkSigTyVars, noSigs, existentialPatCtxt )
 import TcType          ( TcType, newTyVarTy ) 
-import TcEnv           ( newMonoIds )
 import TysWiredIn      ( boolTy )
-import Kind            ( mkTypeKind, mkBoxedTypeKind )
+import Type            ( tyVarsOfType, openTypeKind, boxedTypeKind )
+import BasicTypes      ( RecFlag(..) )
+import VarSet
+import Bag
 import Outputable
 \end{code}
 
@@ -36,28 +40,21 @@ import Outputable
 %************************************************************************
 
 \begin{code}
-tcGRHSs :: TcType s -> [RenamedGRHS] -> TcM s ([TcGRHS s], LIE s)
+tcGRHSs :: [RenamedGRHS] -> TcType s -> StmtCtxt -> TcM s ([TcGRHS s], LIE s)
 
-tcGRHSs expected_ty [grhs]
-  = tcGRHS expected_ty grhs            `thenTc` \ (grhs', lie) ->
+tcGRHSs [grhs] expected_ty ctxt
+  = tcGRHS grhs expected_ty ctxt       `thenTc` \ (grhs', lie) ->
     returnTc ([grhs'], lie)
 
-tcGRHSs expected_ty (grhs:grhss)
-  = tcGRHS  expected_ty grhs   `thenTc` \ (grhs',  lie1) ->
-    tcGRHSs expected_ty grhss  `thenTc` \ (grhss', lie2) ->
+tcGRHSs (grhs:grhss) expected_ty ctxt
+  = tcGRHS  grhs  expected_ty ctxt     `thenTc` \ (grhs',  lie1) ->
+    tcGRHSs grhss expected_ty ctxt     `thenTc` \ (grhss', lie2) ->
     returnTc (grhs' : grhss', lie1 `plusLIE` lie2)
 
-tcGRHS expected_ty (GRHS guard expr locn)
-  = tcAddSrcLoc locn           $
-    tcStmts guard              `thenTc` \ ((guard', expr'), lie) ->
-    returnTc (GRHS guard' expr' locn, lie)
-  where
-    tcStmts []          = tcExpr expr expected_ty        `thenTc`    \ (expr2, expr_lie) ->
-                          returnTc (([], expr2), expr_lie)
-    tcStmts (stmt:stmts) = tcStmt Guard (\x->x) combine stmt $
-                          tcStmts stmts
-
-    combine stmt _ (stmts, expr) = (stmt:stmts, expr)
+tcGRHS (GRHS guarded locn) expected_ty ctxt
+  = tcAddSrcLoc locn                                   $
+    tcStmts ctxt (\ty -> ty) guarded expected_ty       `thenTc` \ (guarded', lie) ->
+    returnTc (GRHS guarded' locn, lie)
 \end{code}
 
 
@@ -71,22 +68,19 @@ tcGRHS expected_ty (GRHS guard expr locn)
 pieces.
 
 \begin{code}
-tcGRHSsAndBinds :: TcType s                    -- Expected type of RHSs
-               -> RenamedGRHSsAndBinds
+tcGRHSsAndBinds :: RenamedGRHSsAndBinds
+               -> TcType s                     -- Expected type of RHSs
+               -> StmtCtxt 
                -> TcM s (TcGRHSsAndBinds s, LIE s)
 
--- Shortcut for common case
-tcGRHSsAndBinds expected_ty (GRHSsAndBindsIn grhss EmptyBinds) 
-  = tcGRHSs expected_ty grhss         `thenTc` \ (grhss', lie) ->
-    returnTc (GRHSsAndBindsOut grhss' EmptyBinds expected_ty, lie)
-
-tcGRHSsAndBinds expected_ty (GRHSsAndBindsIn grhss binds)
+tcGRHSsAndBinds (GRHSsAndBindsIn grhss binds) expected_ty ctxt
   = tcBindsAndThen
         combiner binds
-        (tcGRHSs expected_ty grhss)
+        (tcGRHSs grhss expected_ty ctxt        `thenTc` \ (grhss, lie) ->
+         returnTc (GRHSsAndBindsOut grhss EmptyBinds expected_ty, lie))
   where
-    combiner is_rec binds grhss
-       = GRHSsAndBindsOut grhss (MonoBind binds [] is_rec) expected_ty
+    combiner is_rec mbinds (GRHSsAndBindsOut grhss binds expected_ty)
+       = GRHSsAndBindsOut grhss (MonoBind mbinds [] is_rec `ThenBinds` binds) expected_ty
 \end{code}
 
 
@@ -98,87 +92,107 @@ tcGRHSsAndBinds expected_ty (GRHSsAndBindsIn grhss binds)
 
 
 \begin{code}
-tcStmt :: DoOrListComp
-       -> (TcType s -> TcType s)               -- Relationship type of pat and rhs in pat <- rhs
-       -> (TcStmt s -> Maybe (TcType s) -> thing -> thing)
-       -> RenamedStmt
-       -> TcM s (thing, LIE s)
-       -> TcM s (thing, LIE s)
-
-tcStmt do_or_lc m combine stmt@(ReturnStmt exp) do_next
-  = ASSERT( case do_or_lc of { DoStmt -> False; ListComp -> True; Guard -> True } )
-    tcSetErrCtxt (stmtCtxt do_or_lc stmt) (
-        newTyVarTy mkTypeKind                `thenNF_Tc` \ exp_ty ->
-       tcExpr exp exp_ty                    `thenTc`    \ (exp', exp_lie) ->
-       returnTc (ReturnStmt exp', exp_lie, m exp_ty)
-    )                                  `thenTc` \ (stmt', stmt_lie, stmt_ty) ->
-    do_next                            `thenTc` \ (thing', thing_lie) ->
-    returnTc (combine stmt' (Just stmt_ty) thing',
-             stmt_lie `plusLIE` thing_lie)
-
-tcStmt do_or_lc m combine stmt@(GuardStmt exp src_loc) do_next
-  = ASSERT( case do_or_lc of { DoStmt -> False; ListComp -> True; Guard -> True } )
-    newTyVarTy mkTypeKind                    `thenNF_Tc` \ exp_ty ->
+tcStmts :: StmtCtxt
+        -> (TcType s -> TcType s)      -- m, the relationship type of pat and rhs in pat <- rhs
+        -> [RenamedStmt]
+       -> TcType s                     -- elt_ty, where type of the comprehension is (m elt_ty)
+        -> TcM s ([TcStmt s], LIE s)
+
+tcStmts do_or_lc m (stmt@(ReturnStmt exp) : stmts) elt_ty
+  = ASSERT( null stmts )
+    tcSetErrCtxt (stmtCtxt do_or_lc stmt)      $
+    tcExpr exp elt_ty                          `thenTc`    \ (exp', exp_lie) ->
+    returnTc ([ReturnStmt exp'], exp_lie)
+
+       -- ExprStmt at the end
+tcStmts do_or_lc m [stmt@(ExprStmt exp src_loc)] elt_ty
+  = tcSetErrCtxt (stmtCtxt do_or_lc stmt)      $
+    tcExpr exp (m elt_ty)                      `thenTc`    \ (exp', exp_lie) ->
+    returnTc ([ExprStmt exp' src_loc], exp_lie)
+
+       -- ExprStmt not at the end
+tcStmts do_or_lc m (stmt@(ExprStmt exp src_loc) : stmts) elt_ty
+  = ASSERT( isDoStmt do_or_lc )
     tcAddSrcLoc src_loc                (
-    tcSetErrCtxt (stmtCtxt do_or_lc stmt) (
-       tcExpr exp boolTy               `thenTc`    \ (exp', exp_lie) ->
-       returnTc (GuardStmt exp' src_loc, exp_lie)
-    ))                                 `thenTc` \ (stmt', stmt_lie) ->
-    do_next                            `thenTc` \ (thing', thing_lie) ->
-    returnTc (combine stmt' Nothing thing',
-             stmt_lie `plusLIE` thing_lie)
-
-tcStmt do_or_lc m combine stmt@(ExprStmt exp src_loc) do_next
-  = ASSERT( case do_or_lc of { DoStmt -> True; ListComp -> False; Guard -> False } )
-    newTyVarTy mkTypeKind                    `thenNF_Tc` \ exp_ty ->
-    tcAddSrcLoc src_loc                (
-    tcSetErrCtxt (stmtCtxt do_or_lc stmt)      (
-       newTyVarTy mkTypeKind           `thenNF_Tc` \ tau ->
-       let
+       tcSetErrCtxt (stmtCtxt do_or_lc stmt)   $
            -- exp has type (m tau) for some tau (doesn't matter what)
-           exp_ty = m tau
-       in
-       tcExpr exp exp_ty               `thenTc`    \ (exp', exp_lie) ->
-       returnTc (ExprStmt exp' src_loc, exp_lie, exp_ty)
-    ))                                 `thenTc` \ (stmt',  stmt_lie, stmt_ty) ->
-    do_next                            `thenTc` \ (thing', thing_lie) ->
-    returnTc (combine stmt' (Just stmt_ty) thing',
-             stmt_lie `plusLIE` thing_lie)
-
-tcStmt do_or_lc m combine stmt@(BindStmt pat exp src_loc) do_next
-  = newMonoIds (collectPatBinders pat) mkBoxedTypeKind $ \ _ ->
-    tcAddSrcLoc src_loc                (
-    tcSetErrCtxt (stmtCtxt do_or_lc stmt)      (
-       tcPat pat                       `thenTc`    \ (pat', pat_lie, pat_ty) ->  
-       tcExpr exp (m pat_ty)           `thenTc`    \ (exp', exp_lie) ->
-
-       -- NB: the environment has been extended with the new binders
-       -- which the rhs can't "see", but the renamer should have made
-       -- sure that everything is distinct by now, so there's no problem.
-       -- Putting the tcExpr before the newMonoIds messes up the nesting
-       -- of error contexts, so I didn't  bother
-
-       returnTc (BindStmt pat' exp' src_loc, pat_lie `plusLIE` exp_lie)
-    ))                                 `thenTc` \ (stmt', stmt_lie) ->
-    do_next                            `thenTc` \ (thing', thing_lie) ->
-    returnTc (combine stmt' Nothing thing',
-             stmt_lie `plusLIE` thing_lie)
-
-tcStmt do_or_lc m combine (LetStmt binds) do_next
+       newTyVarTy openTypeKind                 `thenNF_Tc` \ any_ty ->
+       tcExpr exp (m any_ty)
+    )                                  `thenTc` \ (exp', exp_lie) ->
+    tcStmts do_or_lc m stmts elt_ty    `thenTc` \ (stmts', stmts_lie) ->
+    returnTc (ExprStmt exp' src_loc : stmts',
+             exp_lie `plusLIE` stmts_lie)
+
+tcStmts do_or_lc m (stmt@(GuardStmt exp src_loc) : stmts) elt_ty
+  = ASSERT( not (isDoStmt do_or_lc) )
+    tcSetErrCtxt (stmtCtxt do_or_lc stmt) (
+       tcAddSrcLoc src_loc             $
+       tcExpr exp boolTy
+    )                                  `thenTc` \ (exp', exp_lie) ->
+    tcStmts do_or_lc m stmts elt_ty    `thenTc` \ (stmts', stmts_lie) ->
+    returnTc (GuardStmt exp' src_loc : stmts',
+             exp_lie `plusLIE` stmts_lie)
+
+tcStmts do_or_lc m (stmt@(BindStmt pat exp src_loc) : stmts) elt_ty
+  = tcAddSrcLoc src_loc                (
+       tcSetErrCtxt (stmtCtxt do_or_lc stmt)   $
+       newTyVarTy boxedTypeKind                `thenNF_Tc` \ pat_ty ->
+       tcPat noSigs pat pat_ty                 `thenTc` \ (pat', pat_lie, pat_tvs, pat_ids, avail) ->  
+       tcExpr exp (m pat_ty)                   `thenTc` \ (exp', exp_lie) ->
+       returnTc (pat', exp',
+                 pat_lie `plusLIE` exp_lie,
+                 pat_tvs, pat_ids, avail)
+    )                                  `thenTc` \ (pat', exp', lie_req, pat_tvs, pat_ids, lie_avail) ->
+
+       -- Do the rest; we don't need to add the pat_tvs to the envt
+       -- because they all appear in the pat_ids's types
+    tcExtendEnvWithPat pat_ids (
+       tcStmts do_or_lc m stmts elt_ty
+    )                                          `thenTc` \ (stmts', stmts_lie) ->
+
+
+       -- Reinstate context for existential checks
+    tcSetErrCtxt (stmtCtxt do_or_lc stmt)              $
+    tcExtendGlobalTyVars (tyVarsOfType (m elt_ty))     $
+    tcAddErrCtxtM (existentialPatCtxt pat_tvs pat_ids) $
+
+    checkSigTyVars (bagToList pat_tvs)                 `thenTc` \ zonked_pat_tvs ->
+
+    tcSimplifyAndCheck 
+       (text ("the existential context of a data constructor"))
+       (mkVarSet zonked_pat_tvs)
+       lie_avail stmts_lie                     `thenTc` \ (final_lie, dict_binds) ->
+
+    returnTc (BindStmt pat' exp' src_loc : 
+               LetStmt (MonoBind dict_binds [] Recursive) :
+                 stmts',
+             lie_req `plusLIE` final_lie)
+
+tcStmts do_or_lc m (LetStmt binds : stmts) elt_ty
      = tcBindsAndThen          -- No error context, but a binding group is
-       combine'                -- rather a large thing for an error context anyway
+       combine                 -- rather a large thing for an error context anyway
        binds
-       do_next
+       (tcStmts do_or_lc m stmts elt_ty)
      where
-       combine' is_rec binds' thing' = combine (LetStmt (MonoBind binds' [] is_rec)) Nothing thing'
+       combine is_rec binds' stmts' = LetStmt (MonoBind binds' [] is_rec) : stmts'
+
 
+isDoStmt DoStmt = True
+isDoStmt other  = False
 
 stmtCtxt do_or_lc stmt
-  = hang (ptext SLIT("In a") <+> whatever <> colon)
+  = hang (ptext SLIT("In") <+> what <> colon)
          4 (ppr stmt)
   where
-    whatever = case do_or_lc of
-                ListComp -> ptext SLIT("list-comprehension qualifier")
-                DoStmt   -> ptext SLIT("do statement")
-                Guard    -> ptext SLIT("guard")
+    what = case do_or_lc of
+               ListComp -> ptext SLIT("a list-comprehension qualifier")
+               DoStmt   -> ptext SLIT("a do statement:")
+               PatBindRhs -> thing <+> ptext SLIT("a pattern binding")
+               FunRhs f   -> thing <+> ptext SLIT("an equation for") <+> quotes (ppr f)
+               CaseAlt    -> thing <+> ptext SLIT("a case alternative")
+               LambdaBody -> thing <+> ptext SLIT("a lambda abstraction")
+    thing = case stmt of
+               BindStmt _ _ _ -> ptext SLIT("a pattern guard for")
+               GuardStmt _ _  -> ptext SLIT("a guard for")
+               ExprStmt _ _   -> ptext SLIT("the right-hand side of")
 \end{code}