[project @ 2005-03-18 13:37:27 by simonmar]
[ghc-hetmet.git] / ghc / compiler / typecheck / TcPat.lhs
index f831b75..a6d9d1d 100644 (file)
@@ -4,7 +4,7 @@
 \section[TcPat]{Typechecking patterns}
 
 \begin{code}
-module TcPat ( tcPat, tcPats, PatCtxt(..), badFieldCon, polyPatSig ) where
+module TcPat ( tcPat, tcPats, PatCtxt(..), badFieldCon, polyPatSig, refineTyVars ) where
 
 #include "HsVersions.h"
 
@@ -18,29 +18,33 @@ import Inst         ( InstOrigin(..),
                          instToId, tcInstStupidTheta, tcSyntaxName
                        )
 import Id              ( Id, idType, mkLocalId )
+import Var             ( tyVarName )
 import Name            ( Name )
 import TcSimplify      ( tcSimplifyCheck, bindInstsOfLocalFuns )
-import TcEnv           ( newLocalName, tcExtendIdEnv1, tcExtendTyVarEnv,
+import TcEnv           ( newLocalName, tcExtendIdEnv1, tcExtendTyVarEnv2,
                          tcLookupClass, tcLookupDataCon, tcLookupId )
-import TcMType                 ( newTyFlexiVarTy, arityErr, tcSkolTyVars, isRigidType )
+import TcMType                 ( newTyFlexiVarTy, arityErr, tcSkolTyVars, readMetaTyVar )
 import TcType          ( TcType, TcTyVar, TcSigmaType, TcTauType, zipTopTvSubst,
-                         SkolemInfo(PatSkol), isSkolemTyVar, pprSkolemTyVar, 
+                         SkolemInfo(PatSkol), isSkolemTyVar, isMetaTyVar, pprTcTyVar, 
+                         TvSubst, mkOpenTvSubst, substTyVar, substTy, MetaDetails(..),
                          mkTyVarTys, mkClassPred, mkTyConApp, isOverloadedTy )
+import VarEnv          ( mkVarEnv )    -- ugly
 import Kind            ( argTypeKind, liftedTypeKind )
 import TcUnify         ( tcSubPat, Expected(..), zapExpectedType, 
                          zapExpectedTo, zapToListTy, zapToTyConApp )  
 import TcHsType                ( UserTypeCtxt(..), TcSigInfo( sig_tau ), TcSigFun, tcHsPatSigType )
 import TysWiredIn      ( stringTy, parrTyCon, tupleTyCon )
-import Unify           ( MaybeErr(..), tcRefineTys, tcMatchTys )
+import Unify           ( MaybeErr(..), gadtRefineTys, BindFlag(..) )
 import Type            ( substTys, substTheta )
-import CmdLineOpts     ( opt_IrrefutableTuples )
+import StaticFlags     ( opt_IrrefutableTuples )
 import TyCon           ( TyCon )
 import DataCon         ( DataCon, dataConTyCon, isVanillaDataCon, dataConInstOrigArgTys,
                          dataConFieldLabels, dataConSourceArity, dataConSig )
 import PrelNames       ( eqStringName, eqName, geName, negateName, minusName, 
                          integralClassName )
 import BasicTypes      ( isBoxed )
-import SrcLoc          ( Located(..), noLoc, unLoc )
+import SrcLoc          ( Located(..), SrcSpan, noLoc, unLoc )
+import Maybes          ( catMaybes )
 import ErrUtils                ( Message )
 import Outputable
 import FastString
@@ -114,14 +118,12 @@ tcCheckPats ctxt pats tys thing_inside    -- A trivial wrapper
 %************************************************************************
 
 \begin{code}
-data PatCtxt = LamPat Bool | LetPat TcSigFun
-       -- True <=> we are checking the case expression, 
-       --              so can do full-blown refinement
-       -- False <=> inferring, do no refinement
+data PatCtxt = LamPat          -- Used for lambda, case, do-notation etc
+            | LetPat TcSigFun  -- Used for let(rec) bindings
 
 -------------------
 tcPatBndr :: PatCtxt -> Name -> Expected TcSigmaType -> TcM TcId
-tcPatBndr (LamPat _) bndr_name pat_ty
+tcPatBndr LamPat bndr_name pat_ty
   = do { pat_ty' <- zapExpectedType pat_ty argTypeKind
                -- If pat_ty is Expected, this returns the appropriate
                -- SigmaType.  In Infer mode, we create a fresh type variable.
@@ -240,8 +242,13 @@ tc_pat ctxt (SigPatIn pat sig) pat_ty thing_inside
   = do {       -- See Note [Pattern coercions] below
          (sig_tvs, sig_ty) <- tcHsPatSigType PatSigCtxt sig
        ; tcSubPat sig_ty pat_ty
-       ; (pat', tvs, res) <- tcExtendTyVarEnv sig_tvs $
-                             tc_lpat ctxt pat (Check sig_ty) thing_inside
+       ; subst <- refineTyVars sig_tvs -- See note [Type matching]
+       ; let tv_binds = [(tyVarName tv, substTyVar subst tv) | tv <- sig_tvs]
+             sig_ty'  = substTy subst sig_ty
+       ; (pat', tvs, res) 
+             <- tcExtendTyVarEnv2 tv_binds $
+                tc_lpat ctxt pat (Check sig_ty') thing_inside
+
        ; return (SigPatOut pat' sig_ty, tvs, res) }
 
 tc_pat ctxt pat@(TypePat ty) pat_ty thing_inside
@@ -283,7 +290,7 @@ tc_pat ctxt pat_in@(ConPatIn (L con_span con_name) arg_pats) pat_ty thing_inside
   = do { data_con <- tcLookupDataCon con_name
        ; let tycon = dataConTyCon data_con
        ; ty_args <- zapToTyConApp tycon pat_ty
-       ; (pat', tvs, res) <- tcConPat ctxt data_con tycon ty_args arg_pats thing_inside
+       ; (pat', tvs, res) <- tcConPat ctxt con_span data_con tycon ty_args arg_pats thing_inside
        ; return (pat', tvs, res) }
 
 
@@ -361,16 +368,16 @@ tc_pat ctxt pat@(NPlusKPatIn (L nm_loc name) lit@(HsIntegral i _) minus_name) pa
 %************************************************************************
 
 \begin{code}
-tcConPat :: PatCtxt -> DataCon -> TyCon -> [TcTauType] 
+tcConPat :: PatCtxt -> SrcSpan -> DataCon -> TyCon -> [TcTauType] 
         -> HsConDetails Name (LPat Name) -> TcM a
         -> TcM (Pat TcId, [TcTyVar], a)
-tcConPat ctxt data_con tycon ty_args arg_pats thing_inside
+tcConPat ctxt span data_con tycon ty_args arg_pats thing_inside
   | isVanillaDataCon data_con
   = do { let arg_tys = dataConInstOrigArgTys data_con ty_args
        ; tcInstStupidTheta data_con ty_args
        ; traceTc (text "tcConPat" <+> vcat [ppr data_con, ppr ty_args, ppr arg_tys])
        ; (arg_pats', tvs, res) <- tcConArgs ctxt data_con arg_pats arg_tys thing_inside
-       ; return (ConPatOut data_con [] [] emptyLHsBinds 
+       ; return (ConPatOut (L span data_con) [] [] emptyLHsBinds 
                            arg_pats' (mkTyConApp tycon ty_args),
                  tvs, res) }
 
@@ -385,19 +392,22 @@ tcConPat ctxt data_con tycon ty_args arg_pats thing_inside
              arg_tys' = substTys tenv arg_tys
              res_tys' = substTys tenv res_tys
        ; dicts <- newDicts (SigOrigin rigid_info) theta'
-       ; tcInstStupidTheta data_con tv_tys'
 
        -- Do type refinement!
        ; traceTc (text "tcGadtPat" <+> vcat [ppr data_con, ppr tvs', ppr arg_tys', ppr res_tys', 
                                              text "ty-args:" <+> ppr ty_args ])
        ; refineAlt ctxt data_con tvs' ty_args res_tys' $ do    
 
-       { ((arg_pats', inner_tvs, res), lie_req) 
-               <- getLIE (tcConArgs ctxt data_con arg_pats arg_tys' thing_inside)
+       { ((arg_pats', inner_tvs, res), lie_req) <- getLIE $
+               do { tcInstStupidTheta data_con tv_tys'
+                       -- The stupid-theta mentions the newly-bound tyvars, so
+                       -- it must live inside the getLIE, so that the
+                       --  tcSimplifyCheck will apply the type refinement to it
+                  ; tcConArgs ctxt data_con arg_pats arg_tys' thing_inside }
 
        ; dict_binds <- tcSimplifyCheck doc tvs' dicts lie_req
 
-       ; return (ConPatOut data_con 
+       ; return (ConPatOut (L span data_con)
                            tvs' (map instToId dicts) dict_binds
                            arg_pats' (mkTyConApp tycon ty_args),
                  tvs' ++ inner_tvs, res) } }
@@ -488,17 +498,50 @@ refineAlt :: PatCtxt -> DataCon
            -> TcM a -> TcM a
 refineAlt ctxt con ex_tvs ctxt_tys pat_tys thing_inside 
   = do { old_subst <- getTypeRefinement
-       ; let refiner | can_i_refine ctxt = tcRefineTys
-                     | otherwise         = tcMatchTys
-       ; case refiner ex_tvs old_subst pat_tys ctxt_tys of
+       ; case gadtRefineTys bind_fn old_subst pat_tys ctxt_tys of
                Failed msg -> failWithTc (inaccessibleAlt msg)
                Succeeded new_subst -> do {
          traceTc (text "refineTypes:match" <+> ppr con <+> ppr new_subst)
        ; setTypeRefinement new_subst thing_inside } }
 
   where
-    can_i_refine (LamPat can_refine) = can_refine
-    can_i_refine other_ctxt         = False
+    bind_fn tv | isMetaTyVar tv = WildCard     -- Wobbly types behave as wild cards
+              | otherwise      = BindMe
+\end{code}
+
+Note [Type matching]
+~~~~~~~~~~~~~~~~~~~~
+This little function @refineTyVars@ is a little tricky.  Suppose we have a pattern type
+signature
+       f = \(x :: Term a) -> body
+Then 'a' should be bound to a wobbly type.  But if we have
+       f :: Term b -> some-type
+       f = \(x :: Term a) -> body
+then 'a' should be bound to the rigid type 'b'.  So we want to
+       * instantiate the type sig with fresh meta tyvars (e.g. \alpha)
+       * unify with the type coming from the context
+       * read out whatever information has been gleaned
+               from that unificaiton (e.g. unifying \alpha with 'b')
+       * and replace \alpha by 'b' in the type, when typechecking the
+               pattern inside the type sig (x in this case)
+It amounts to combining rigid info from the context and from the sig.
+
+Exactly the same thing happens for 'smart function application'.
+
+\begin{code}
+refineTyVars :: [TcTyVar]      -- Newly instantiated meta-tyvars of the function
+            -> TcM TvSubst     -- Substitution mapping any of the meta-tyvars that
+                               -- has been unifies to what it was instantiated to
+-- Just one level of de-wobblification though.  What a hack! 
+refineTyVars tvs
+  = do { mb_prs <- mapM mk_pr tvs
+       ; return (mkOpenTvSubst (mkVarEnv (catMaybes mb_prs))) }
+  where
+    mk_pr tv = do { details <- readMetaTyVar tv
+                 ; case details of
+                       Indirect ty -> return (Just (tv,ty))
+                       other       -> return Nothing 
+                 }
 \end{code}
 
 %************************************************************************
@@ -591,9 +634,7 @@ badTypePat pat = ptext SLIT("Illegal type pattern") <+> ppr pat
 lazyPatErr pat tvs
   = failWithTc $
     hang (ptext SLIT("A lazy (~) pattern connot bind existential type variables"))
-       2 (vcat (map get tvs))
-  where
-   get tv = ASSERT( isSkolemTyVar tv ) pprSkolemTyVar tv
+       2 (vcat (map pprTcTyVar tvs))
 
 inaccessibleAlt msg
   = hang (ptext SLIT("Inaccessible case alternative:")) 2 msg