Fix scoped type variables for expression type signatures
[ghc-hetmet.git] / compiler / deSugar / MatchCon.lhs
index 6ff502a..fd840e6 100644 (file)
@@ -8,24 +8,21 @@ module MatchCon ( matchConFamily ) where
 
 #include "HsVersions.h"
 
-import Id( idType )
-
 import {-# SOURCE #-} Match    ( match )
 
-import HsSyn           ( Pat(..), HsConDetails(..) )
+import HsSyn           ( Pat(..), LPat, HsConDetails(..) )
 import DsBinds         ( dsLHsBinds )
-import DataCon         ( isVanillaDataCon, dataConInstOrigArgTys )
+import DataCon         ( DataCon, dataConInstOrigArgTys, dataConEqSpec,
+                         dataConFieldLabels, dataConSourceArity )
 import TcType          ( tcTyConAppArgs )
 import Type            ( mkTyVarTys )
 import CoreSyn
 import DsMonad
 import DsUtils
 
-import Id              ( Id )
+import Id              ( Id, idName )
 import Type             ( Type )
-import ListSetOps      ( equivClassesByUniq )
 import SrcLoc          ( unLoc, Located(..) )
-import Unique          ( Uniquable(..) )
 import Outputable
 \end{code}
 
@@ -82,63 +79,62 @@ have-we-used-all-the-constructors? question; the local function
 \begin{code}
 matchConFamily :: [Id]
                -> Type
-              -> [EquationInfo]
+              -> [[EquationInfo]]
               -> DsM MatchResult
-matchConFamily (var:vars) ty eqns_info
-  = let
-       -- Sort into equivalence classes by the unique on the constructor
-       -- All the EqnInfos should start with a ConPat
-       groups = equivClassesByUniq get_uniq eqns_info
-       get_uniq (EqnInfo { eqn_pats = ConPatOut (L _ data_con) _ _ _ _ _ : _}) = getUnique data_con
-
-       -- Get the wrapper from the head of each group.  We're going to
-       -- use it as the pattern in this case expression, so we need to 
-       -- ensure that any type variables it mentions in the pattern are
-       -- in scope.  So we put its wrappers outside the case, and
-       -- zap the wrapper for it. 
-       wraps :: [CoreExpr -> CoreExpr]
-       wraps = map (eqn_wrap . head) groups
-
-       groups' = [ eqn { eqn_wrap = idWrapper } : eqns | eqn:eqns <- groups ]
-    in
-       -- Now make a case alternative out of each group
-    mappM (match_con vars ty) groups'  `thenDs` \ alts ->
-    returnDs (adjustMatchResult (foldr (.) idWrapper wraps) $
-             mkCoAlgCaseMatchResult var ty alts)
-\end{code}
-
-And here is the local function that does all the work.  It is
-more-or-less the @matchCon@/@matchClause@ functions on page~94 in
-Wadler's chapter in SLPJ.  The function @shift_con_pats@ does what the
-list comprehension in @matchClause@ (SLPJ, p.~94) does, except things
-are trickier in real life.  Works for @ConPats@, and we want it to
-fail catastrophically for anything else (which a list comprehension
-wouldn't).  Cf.~@shift_lit_pats@ in @MatchLits@.
-
-\begin{code}
-match_con vars ty eqns
-  = do { -- Make new vars for the con arguments; avoid new locals where possible
-         arg_vars     <- selectMatchVars (map unLoc arg_pats1) arg_tys
-       ; eqns'        <- mapM shift eqns 
+-- Each group of eqns is for a single constructor
+matchConFamily (var:vars) ty groups
+  = do { alts <- mapM (matchOneCon vars ty) groups
+       ; return (mkCoAlgCaseMatchResult var ty alts) }
+
+matchOneCon vars ty (eqn1 : eqns)      -- All eqns for a single constructor
+  = do { (wraps, eqns') <- mapAndUnzipM shift (eqn1:eqns)
+       ; arg_vars <- selectMatchVars (take (dataConSourceArity con) 
+                                           (eqn_pats (head eqns')))
+               -- Use the new arugment patterns as a source of 
+               -- suggestions for the new variables
        ; match_result <- match (arg_vars ++ vars) ty eqns'
-       ; return (con, tvs1 ++ dicts1 ++ arg_vars, match_result) }
+       ; return (con, tvs1 ++ dicts1 ++ arg_vars, 
+                 adjustMatchResult (foldr1 (.) wraps) match_result) }
   where
-    ConPatOut (L _ con) tvs1 dicts1 _ (PrefixCon arg_pats1) pat_ty = firstPat (head eqns)
-
-    shift eqn@(EqnInfo { eqn_wrap = wrap, 
-                        eqn_pats = ConPatOut _ tvs ds bind (PrefixCon arg_pats) _ : pats })
+    ConPatOut { pat_con = L _ con, pat_ty = pat_ty1,
+               pat_tvs = tvs1, pat_dicts = dicts1 } = firstPat eqn1
+       
+    arg_tys  = dataConInstOrigArgTys con inst_tys
+    n_co_args = length (dataConEqSpec con)
+    inst_tys = tcTyConAppArgs pat_ty1 ++ (drop n_co_args $ mkTyVarTys tvs1)
+       -- Newtypes opaque, hence tcTyConAppArgs
+
+    shift eqn@(EqnInfo { eqn_pats = ConPatOut{ pat_tvs = tvs, pat_dicts = ds, 
+                                              pat_binds = bind, pat_args = args
+                                             } : pats })
        = do { prs <- dsLHsBinds bind
-            ; return (eqn { eqn_wrap = wrap . wrapBinds (tvs `zip` tvs1) 
-                                            . wrapBinds (ds  `zip` dicts1)
-                                            . mkDsLet (Rec prs),
-                            eqn_pats = map unLoc arg_pats ++ pats }) }
-
-       -- Get the arg types, which we use to type the new vars
-       -- to match on, from the "outside"; the types of pats1 may 
-       -- be more refined, and hence won't do
-    arg_tys = dataConInstOrigArgTys con inst_tys
-    inst_tys | isVanillaDataCon con = tcTyConAppArgs pat_ty    -- Newtypes opaque!
-            | otherwise            = mkTyVarTys tvs1
+            ; return (wrapBinds (tvs `zip` tvs1) 
+                      . wrapBinds (ds  `zip` dicts1)
+                      . mkDsLet (Rec prs),
+                      eqn { eqn_pats = conArgPats con arg_tys args ++ pats }) }
+
+conArgPats :: DataCon 
+          -> [Type]    -- Instantiated argument types 
+          -> HsConDetails Id (LPat Id)
+          -> [Pat Id]
+conArgPats data_con arg_tys (PrefixCon ps)   = map unLoc ps
+conArgPats data_con arg_tys (InfixCon p1 p2) = [unLoc p1, unLoc p2]
+conArgPats data_con arg_tys (RecCon rpats)
+  | null rpats
+  =    -- Special case for C {}, which can be used for 
+       -- a constructor that isn't declared to have
+       -- fields at all
+    map WildPat arg_tys
+
+  | otherwise
+  = zipWith mk_pat (dataConFieldLabels data_con) arg_tys
+  where
+       -- mk_pat picks a WildPat of the appropriate type for absent fields,
+       -- and the specified pattern for present fields
+    mk_pat lbl arg_ty
+       = case [ pat | (sel_id,pat) <- rpats, idName (unLoc sel_id) == lbl] of
+           (pat:pats) -> ASSERT( null pats ) unLoc pat
+           []         -> WildPat arg_ty
 \end{code}
 
 Note [Existentials in shift_con_pat]