Mostly-fix Trac #2595: updates for existentials
authorsimonpj@microsoft.com <unknown>
Tue, 28 Oct 2008 11:54:27 +0000 (11:54 +0000)
committersimonpj@microsoft.com <unknown>
Tue, 28 Oct 2008 11:54:27 +0000 (11:54 +0000)
Ganesh wanted to update records that involve existentials.  That
seems reasonable to me, and this patch covers existentials, GADTs,
and data type families.

The restriction is that
  The types of the updated fields may mention only the
  universally-quantified type variables of the data constructor

This doesn't allow everything in #2595 (it allows 'g' but not 'f' in
the ticket), but it gets a lot closer.

Lots of the new lines are comments!

compiler/deSugar/DsExpr.lhs
compiler/deSugar/DsMonad.lhs
compiler/typecheck/TcExpr.lhs
compiler/typecheck/TcHsSyn.lhs
docs/users_guide/glasgow_exts.xml

index 6cbd538..37129d8 100644 (file)
@@ -26,6 +26,7 @@ import DsUtils
 import DsArrows
 import DsMonad
 import Name
+import NameEnv
 
 #ifdef GHCI
 import PrelNames
@@ -40,6 +41,7 @@ import TcHsSyn
 --     needs to see source types
 import TcType
 import Type
+import Coercion
 import CoreSyn
 import CoreUtils
 import MkCore
@@ -52,6 +54,7 @@ import DataCon
 import TysWiredIn
 import BasicTypes
 import PrelNames
+import Maybes
 import SrcLoc
 import Util
 import Bag
@@ -426,52 +429,96 @@ RHSs, and do not generate a Core constructor application directly, because the c
 might do some argument-evaluation first; and may have to throw away some
 dictionaries.
 
+Note [Update for GADTs]
+~~~~~~~~~~~~~~~~~~~~~~~
+Consider 
+   data T a b where
+     T1 { f1 :: a } :: T a Int
+
+Then the wrapper function for T1 has type 
+   $WT1 :: a -> T a Int
+But if x::T a b, then
+   x { f1 = v } :: T a b   (not T a Int!)
+So we need to cast (T a Int) to (T a b).  Sigh.
+
 \begin{code}
 dsExpr expr@(RecordUpd record_expr (HsRecFields { rec_flds = fields })
                       cons_to_upd in_inst_tys out_inst_tys)
   | null fields
   = dsLExpr record_expr
   | otherwise
-  =    -- Record stuff doesn't work for existentials
-       -- The type checker checks for this, but we need 
-       -- worry only about the constructors that are to be updated
-    ASSERT2( notNull cons_to_upd && all isVanillaDataCon cons_to_upd, ppr expr )
+  = ASSERT2( notNull cons_to_upd, ppr expr )
 
     do { record_expr' <- dsLExpr record_expr
-       ; let   -- Awkwardly, for families, the match goes 
-               -- from instance type to family type
-               tycon     = dataConTyCon (head cons_to_upd)
-               in_ty     = mkTyConApp tycon in_inst_tys
-               in_out_ty = mkFunTy in_ty
-                                   (mkFamilyTyConApp tycon out_inst_tys)
-
-               mk_val_arg field old_arg_id 
-                 = case findField fields field  of
-                     (rhs:rest) -> ASSERT(null rest) rhs
-                     []         -> nlHsVar old_arg_id
-
-               mk_alt con
-                 = ASSERT( isVanillaDataCon con )
-                   do  { arg_ids <- newSysLocalsDs (dataConInstOrigArgTys con in_inst_tys)
-                       -- This call to dataConInstOrigArgTys won't work for existentials
-                       -- but existentials don't have record types anyway
-                       ; let val_args = zipWithEqual "dsExpr:RecordUpd" mk_val_arg
-                                               (dataConFieldLabels con) arg_ids
-                             rhs = foldl (\a b -> nlHsApp a b)
-                                         (nlHsTyApp (dataConWrapId con) out_inst_tys)
-                                         val_args
-                             pat = mkPrefixConPat con (map nlVarPat arg_ids) in_ty
-
-                       ; return (mkSimpleMatch [pat] rhs) }
+       ; field_binds' <- mapM ds_field fields
 
        -- It's important to generate the match with matchWrapper,
        -- and the right hand sides with applications of the wrapper Id
        -- so that everything works when we are doing fancy unboxing on the
        -- constructor aguments.
        ; alts <- mapM mk_alt cons_to_upd
-       ; ([discrim_var], matching_code) <- matchWrapper RecUpd (MatchGroup alts in_out_ty)
+       ; ([discrim_var], matching_code) 
+               <- matchWrapper RecUpd (MatchGroup alts in_out_ty)
 
-       ; return (bindNonRec discrim_var record_expr' matching_code) }
+       ; return (add_field_binds field_binds' $
+                 bindNonRec discrim_var record_expr' matching_code) }
+  where
+    ds_field :: HsRecField Id (LHsExpr Id) -> DsM (Id, CoreExpr)
+    ds_field rec_field = do { rhs <- dsLExpr (hsRecFieldArg rec_field)
+                           ; return (unLoc (hsRecFieldId rec_field), rhs) }
+
+    add_field_binds [] expr = expr
+    add_field_binds ((b,r):bs) expr = bindNonRec b r (add_field_binds bs expr)
+
+       -- Awkwardly, for families, the match goes 
+       -- from instance type to family type
+    tycon     = dataConTyCon (head cons_to_upd)
+    in_ty     = mkTyConApp tycon in_inst_tys
+    in_out_ty = mkFunTy in_ty (mkFamilyTyConApp tycon out_inst_tys)
+
+    mk_alt con
+      = do { let (univ_tvs, ex_tvs, eq_spec, 
+                 eq_theta, dict_theta, arg_tys, _) = dataConFullSig con
+                subst = mkTopTvSubst (univ_tvs `zip` in_inst_tys)
+
+               -- I'm not bothering to clone the ex_tvs
+          ; eqs_vars   <- mapM newPredVarDs (substTheta subst (eqSpecPreds eq_spec))
+          ; theta_vars <- mapM newPredVarDs (substTheta subst (eq_theta ++ dict_theta))
+          ; arg_ids    <- newSysLocalsDs (substTys subst arg_tys)
+          ; let val_args = zipWithEqual "dsExpr:RecordUpd" mk_val_arg
+                                        (dataConFieldLabels con) arg_ids
+                inst_con = noLoc $ HsWrap wrap (HsVar (dataConWrapId con))
+                       -- Reconstruct with the WrapId so that unpacking happens
+                wrap = mkWpApps theta_vars `WpCompose` 
+                       mkWpTyApps (mkTyVarTys ex_tvs) `WpCompose`
+                       mkWpTyApps [ty | (tv, ty) <- univ_tvs `zip` out_inst_tys
+                                      , isNothing (lookupTyVar wrap_subst tv) ]
+                rhs = foldl (\a b -> nlHsApp a b) inst_con val_args
+
+                       -- Tediously wrap the application in a cast
+                       -- Note [Update for GADTs]
+                wrapped_rhs | null eq_spec = rhs
+                            | otherwise    = mkLHsWrap (WpCast wrap_co) rhs
+                wrap_co = mkTyConApp tycon [ lookup tv ty 
+                                           | (tv,ty) <- univ_tvs `zip` out_inst_tys]
+                lookup univ_tv ty = case lookupTyVar wrap_subst univ_tv of
+                                       Just ty' -> ty'
+                                       Nothing  -> ty
+                wrap_subst = mkTopTvSubst [ (tv,mkSymCoercion (mkTyVarTy co_var))
+                                          | ((tv,_),co_var) <- eq_spec `zip` eqs_vars ]
+                
+                pat = noLoc $ ConPatOut { pat_con = noLoc con, pat_tvs = ex_tvs
+                                        , pat_dicts = eqs_vars ++ theta_vars
+                                        , pat_binds = emptyLHsBinds 
+                                        , pat_args = PrefixCon $ map nlVarPat arg_ids
+                                        , pat_ty = in_ty }
+          ; return (mkSimpleMatch [pat] wrapped_rhs) }
+
+    upd_field_ids :: NameEnv Id        -- Maps field name to the LocalId of the field binding
+    upd_field_ids = mkNameEnv [ (idName field_id, field_id) 
+                             | rec_fld <- fields, let field_id = unLoc (hsRecFieldId rec_fld) ]
+    mk_val_arg field_name pat_arg_id 
+      = nlHsVar (lookupNameEnv upd_field_ids field_name `orElse` pat_arg_id)
 \end{code}
 
 Here is where we desugar the Template Haskell brackets and escapes
index 145ba9e..83a5d21 100644 (file)
@@ -14,7 +14,7 @@ module DsMonad (
 
        newLocalName,
        duplicateLocalDs, newSysLocalDs, newSysLocalsDs, newUniqueId,
-       newFailLocalDs,
+       newFailLocalDs, newPredVarDs,
        getSrcSpanDs, putSrcSpanDs,
        getModuleDs,
        newUnique, 
@@ -224,12 +224,22 @@ newUniqueId :: Name -> Type -> DsM Id
 newUniqueId id = mkSysLocalM (occNameFS (nameOccName id))
 
 duplicateLocalDs :: Id -> DsM Id
-duplicateLocalDs old_local = do
-    uniq <- newUnique
-    return (setIdUnique old_local uniq)
-
+duplicateLocalDs old_local 
+  = do { uniq <- newUnique
+       ; return (setIdUnique old_local uniq) }
+
+newPredVarDs :: PredType -> DsM Var
+newPredVarDs pred
+ | isEqPred pred
+ = do { uniq <- newUnique; 
+      ; let name = mkSystemName uniq (mkOccNameFS tcName (fsLit "co"))
+           kind = mkPredTy pred
+      ; return (mkCoVar name kind) }
+ | otherwise
+ = newSysLocalDs (mkPredTy pred)
 newSysLocalDs, newFailLocalDs :: Type -> DsM Id
-newSysLocalDs = mkSysLocalM (fsLit "ds")
+newSysLocalDs  = mkSysLocalM (fsLit "ds")
 newFailLocalDs = mkSysLocalM (fsLit "fail")
 
 newSysLocalsDs :: [Type] -> DsM [Id]
index 721c57c..2eb10ef 100644 (file)
@@ -58,6 +58,7 @@ import Maybes
 import Outputable
 import FastString
 
+import Data.List( partition )
 import Control.Monad
 \end{code}
 
@@ -400,149 +401,203 @@ tcExpr expr@(RecordCon (L loc con_name) _ rbinds) res_ty
        ; (con_expr, rbinds') <- tcIdApp con_name arity check_fields res_ty
 
        ; return (RecordCon (L loc (dataConWrapId data_con)) con_expr rbinds') }
+\end{code}
 
--- The main complication with RecordUpd is that we need to explicitly
--- handle the *non-updated* fields.  Consider:
---
---     data T a b = MkT1 { fa :: a, fb :: b }
---                | MkT2 { fa :: a, fc :: Int -> Int }
---                | MkT3 { fd :: a }
---     
---     upd :: T a b -> c -> T a c
---     upd t x = t { fb = x}
---
--- The type signature on upd is correct (i.e. the result should not be (T a b))
--- because upd should be equivalent to:
---
---     upd t x = case t of 
---                     MkT1 p q -> MkT1 p x
---                     MkT2 a b -> MkT2 p b
---                     MkT3 d   -> error ...
---
--- So we need to give a completely fresh type to the result record,
--- and then constrain it by the fields that are *not* updated ("p" above).
---
--- Note that because MkT3 doesn't contain all the fields being updated,
--- its RHS is simply an error, so it doesn't impose any type constraints
---
--- All this is done in STEP 4 below.
---
--- Note about GADTs
--- ~~~~~~~~~~~~~~~~
--- For record update we require that every constructor involved in the
--- update (i.e. that has all the specified fields) is "vanilla".  I
--- don't know how to do the update otherwise.
-
+Note [Type of a record update]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+The main complication with RecordUpd is that we need to explicitly
+handle the *non-updated* fields.  Consider:
 
-tcExpr expr@(RecordUpd record_expr rbinds _ _ _) res_ty = do
+       data T a b = MkT1 { fa :: a, fb :: b }
+                  | MkT2 { fa :: a, fc :: Int -> Int }
+                  | MkT3 { fd :: a }
+       
+       upd :: T a b -> c -> T a c
+       upd t x = t { fb = x}
+
+The type signature on upd is correct (i.e. the result should not be (T a b))
+because upd should be equivalent to:
+
+       upd t x = case t of 
+                       MkT1 p q -> MkT1 p x
+                       MkT2 a b -> MkT2 p b
+                       MkT3 d   -> error ...
+
+So we need to give a completely fresh type to the result record,
+and then constrain it by the fields that are *not* updated ("p" above).
+
+Note that because MkT3 doesn't contain all the fields being updated,
+its RHS is simply an error, so it doesn't impose any type constraints
+
+Note [Implict type sharing]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+We also take into account any "implicit" non-update fields.  For example
+       data T a b where { MkT { f::a } :: T a a; ... }
+So the "real" type of MkT is: forall ab. (a~b) => a -> T a b
+
+Then consider
+       upd t x = t { f=x }
+We infer the type
+       upd :: T a b -> a -> T a b
+       upd (t::T a b) (x::a)
+          = case t of { MkT (co:a~b) (_:a) -> MkT co x }
+We can't give it the more general type
+       upd :: T a b -> c -> T c b
+
+Note [Criteria for update]
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+We want to allow update for existentials etc, provided the updated
+field isn't part of the existential. For example, this should be ok.
+  data T a where { MkT { f1::a, f2::b->b } :: T a }
+  f :: T a -> b -> T b
+  f t b = t { f1=b }
+The criterion we use is this:
+
+  The types of the updated fields
+  mention only the universally-quantified type variables
+  of the data constructor
+
+In principle one could go further, and allow
+  g :: T a -> T a
+  g t = t { f2 = \x -> x }
+because the expression is polymorphic...but that seems a bridge too far.
+
+Note [Data family example]
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+    data instance T (a,b) = MkT { x::a, y::b }
+  --->
+    data :TP a b = MkT { a::a, y::b }
+    coTP a b :: T (a,b) ~ :TP a b
+
+Suppose r :: T (t1,t2), e :: t3
+Then  r { x=e } :: T (t3,t1)
+  --->
+      case r |> co1 of
+       MkT x y -> MkT e y |> co2
+      where co1 :: T (t1,t2) ~ :TP t1 t2
+           co2 :: :TP t3 t2 ~ T (t3,t2)
+The wrapping with co2 is done by the constructor wrapper for MkT
+
+Outgoing invariants
+~~~~~~~~~~~~~~~~~~~
+In the outgoing (HsRecordUpd scrut binds cons in_inst_tys out_inst_tys):
+
+  * cons are the data constructors to be updated
+
+  * in_inst_tys, out_inst_tys have same length, and instantiate the
+       *representation* tycon of the data cons.  In Note [Data 
+       family example], in_inst_tys = [t1,t2], out_inst_tys = [t3,t2]
+       
+\begin{code}
+tcExpr expr@(RecordUpd record_expr rbinds _ _ _) res_ty
+  = do {
        -- STEP 0
        -- Check that the field names are really field names
-    let 
-       field_names = hsRecFields rbinds
-
-    MASSERT( notNull field_names )
-    sel_ids <- mapM tcLookupField field_names
-       -- The renamer has already checked that they
-       -- are all in scope
-    let
-       bad_guys = [ setSrcSpan loc $ addErrTc (notSelector field_name) 
-                  | (fld, sel_id) <- rec_flds rbinds `zip` sel_ids,
-                    not (isRecordSelector sel_id),     -- Excludes class ops
-                    let L loc field_name = hsRecFieldId fld
-                  ]
-
-    unless (null bad_guys) (sequence bad_guys >> failM)
+         let upd_fld_names = hsRecFields rbinds
+       ; MASSERT( notNull upd_fld_names )
+       ; sel_ids <- mapM tcLookupField upd_fld_names
+                       -- The renamer has already checked that
+                       -- selectors are all in scope
+       ; let bad_guys = [ setSrcSpan loc $ addErrTc (notSelector fld_name) 
+                        | (fld, sel_id) <- rec_flds rbinds `zip` sel_ids,
+                          not (isRecordSelector sel_id),       -- Excludes class ops
+                          let L loc fld_name = hsRecFieldId fld ]
+       ; unless (null bad_guys) (sequence bad_guys >> failM)
     
        -- STEP 1
        -- Figure out the tycon and data cons from the first field name
-    let
-               -- It's OK to use the non-tc splitters here (for a selector)
-       sel_id : _      = sel_ids
-       (tycon, _)      = recordSelectorFieldLabel sel_id       -- We've failed already if
-       data_cons       = tyConDataCons tycon                   -- it's not a field label
-               -- NB: for a data type family, the tycon is the instance tycon
-
-       relevant_cons   = filter is_relevant data_cons
-       is_relevant con = all (`elem` dataConFieldLabels con) field_names
-
+       ; let   -- It's OK to use the non-tc splitters here (for a selector)
+             sel_id : _  = sel_ids
+             (tycon, _)  = recordSelectorFieldLabel sel_id     -- We've failed already if
+             data_cons   = tyConDataCons tycon                 -- it's not a field label
+               -- NB: for a data type family, the tycon is the instance tycon
+
+             relevant_cons   = filter is_relevant data_cons
+             is_relevant con = all (`elem` dataConFieldLabels con) upd_fld_names
+               -- A constructor is only relevant to this process if
+               -- it contains *all* the fields that are being updated
+               -- Other ones will cause a runtime error if they occur
+
+               -- Take apart a representative constructor
+             con1 = ASSERT( not (null relevant_cons) ) head relevant_cons
+             (con1_tvs, _, _, _, _, con1_arg_tys, _) = dataConFullSig con1
+             con1_flds = dataConFieldLabels con1
+             con1_res_ty = mkFamilyTyConApp tycon (mkTyVarTys con1_tvs)
+             
        -- STEP 2
        -- Check that at least one constructor has all the named fields
        -- i.e. has an empty set of bad fields returned by badFields
-    checkTc (not (null relevant_cons))
-           (badFieldsUpd rbinds)
-
-       -- Check that all relevant data cons are vanilla.  Doing record updates on 
-       -- GADTs and/or existentials is more than my tiny brain can cope with today
-    checkTc (all isVanillaDataCon relevant_cons)
-           (nonVanillaUpd tycon)
-
-       -- STEP 4
-       -- Use the un-updated fields to find a vector of booleans saying
-       -- which type arguments must be the same in updatee and result.
-       --
-       -- WARNING: this code assumes that all data_cons in a common tycon
-       -- have FieldLabels abstracted over the same tyvars.
-    let
-               -- A constructor is only relevant to this process if
-               -- it contains *all* the fields that are being updated
-       con1 = ASSERT( not (null relevant_cons) ) head relevant_cons    -- A representative constructor
-       (con1_tyvars, theta, con1_arg_tys, con1_res_ty) = dataConSig con1
-       con1_flds     = dataConFieldLabels con1
-       common_tyvars = exactTyVarsOfTypes [ty | (fld,ty) <- con1_flds `zip` con1_arg_tys
-                                              , not (fld `elem` field_names) ]
-
-       is_common_tv tv = tv `elemVarSet` common_tyvars
-
-       mk_inst_ty tv result_inst_ty 
-         | is_common_tv tv = return result_inst_ty             -- Same as result type
-         | otherwise       = newFlexiTyVarTy (tyVarKind tv)    -- Fresh type, of correct kind
-
-    MASSERT( null theta )      -- Vanilla datacon
-    (_, result_inst_tys, result_inst_env) <- tcInstTyVars con1_tyvars
-    scrut_inst_tys <- zipWithM mk_inst_ty con1_tyvars result_inst_tys
-
-       -- STEP 3: Typecheck the update bindings.
-       -- Do this after checking for bad fields in case 
-       -- there's a field that doesn't match the constructor.
-    let
-       result_ty     = substTy result_inst_env con1_res_ty
-       con1_arg_tys' = map (substTy result_inst_env) con1_arg_tys
-       origin        = RecordUpdOrigin
-
-    co_fn   <- tcSubExp origin result_ty res_ty
-    rbinds' <- tcRecordBinds con1 con1_arg_tys' rbinds
-
-       -- STEP 5: Typecheck the expression to be updated
-    let
-       scrut_inst_env = zipTopTvSubst con1_tyvars scrut_inst_tys
-       scrut_ty = substTy scrut_inst_env con1_res_ty
-       -- This is one place where the isVanilla check is important
-       -- So that inst_tys matches the con1_tyvars
-
-    record_expr' <- tcMonoExpr record_expr scrut_ty
-
-       -- STEP 6: Figure out the LIE we need.  
-       -- We have to generate some dictionaries for the data type context, 
-       -- since we are going to do pattern matching over the data cons.
-       --
-       -- What dictionaries do we need?  The dataConStupidTheta tells us.
-    let
-       theta' = substTheta scrut_inst_env (dataConStupidTheta con1)
-
-    instStupidTheta origin theta'
+       ; checkTc (not (null relevant_cons)) (badFieldsUpd rbinds)
+
+       -- STEP 3    Note [Criteria for update]
+       -- Check that each updated field is polymorphic; that is, its type
+       -- mentions only the universally-quantified variables of the data con
+       ; let flds_w_tys = zipEqual "tcExpr:RecConUpd" con1_flds con1_arg_tys
+             (upd_flds_w_tys, fixed_flds_w_tys) = partition is_updated flds_w_tys
+             is_updated (fld,ty) = fld `elem` upd_fld_names
+
+             bad_upd_flds = filter bad_fld upd_flds_w_tys
+             con1_tv_set = mkVarSet con1_tvs
+             bad_fld (fld, ty) = fld `elem` upd_fld_names &&
+                                     not (tyVarsOfType ty `subVarSet` con1_tv_set)
+       ; checkTc (null bad_upd_flds) (badFieldTypes bad_upd_flds)
+
+       -- STEP 4  Note [Type of a record update]
+       -- Figure out types for the scrutinee and result
+       -- Both are of form (T a b c), with fresh type variables, but with
+       -- common variables where the scrutinee and result must have the same type
+       -- These are variables that appear anywhere *except* in the updated fields
+       ; let common_tvs = exactTyVarsOfTypes (map snd fixed_flds_w_tys)
+                          `unionVarSet` constrainedTyVars con1_tvs relevant_cons
+             is_common_tv tv = tv `elemVarSet` common_tvs
+
+             mk_inst_ty tv result_inst_ty 
+               | is_common_tv tv = return result_inst_ty           -- Same as result type
+               | otherwise       = newFlexiTyVarTy (tyVarKind tv)  -- Fresh type, of correct kind
+
+       ; (_, result_inst_tys, result_inst_env) <- tcInstTyVars con1_tvs
+       ; scrut_inst_tys <- zipWithM mk_inst_ty con1_tvs result_inst_tys
+
+       ; let result_ty     = substTy result_inst_env con1_res_ty
+             con1_arg_tys' = map (substTy result_inst_env) con1_arg_tys
+             scrut_subst   = zipTopTvSubst con1_tvs scrut_inst_tys
+             scrut_ty      = substTy scrut_subst con1_res_ty
+
+       -- STEP 5
+       -- Typecheck the thing to be updated, and the bindings
+       ; record_expr' <- tcMonoExpr record_expr scrut_ty
+       ; rbinds'      <- tcRecordBinds con1 con1_arg_tys' rbinds
+       
+       ; let origin = RecordUpdOrigin
+       ; co_fn <- tcSubExp origin result_ty res_ty
+
+       -- STEP 6: Deal with the stupid theta
+       ; let theta' = substTheta scrut_subst (dataConStupidTheta con1)
+       ; instStupidTheta origin theta'
 
        -- Step 7: make a cast for the scrutinee, in the case that it's from a type family
-    let scrut_co | Just co_con <- tyConFamilyCoercion_maybe tycon 
-                = WpCast $ mkTyConApp co_con scrut_inst_tys
-                | otherwise
-                = idHsWrapper
+       ; let scrut_co | Just co_con <- tyConFamilyCoercion_maybe tycon 
+                      = WpCast $ mkTyConApp co_con scrut_inst_tys
+                      | otherwise
+                      = idHsWrapper
 
        -- Phew!
-    return (mkHsWrap co_fn (RecordUpd (mkLHsWrap scrut_co record_expr') rbinds'
-                                      relevant_cons scrut_inst_tys result_inst_tys))
+       ; return (mkHsWrap co_fn (RecordUpd (mkLHsWrap scrut_co record_expr') rbinds'
+                                       relevant_cons scrut_inst_tys result_inst_tys)) }
+  where
+    constrainedTyVars :: [TyVar] -> [DataCon] -> TyVarSet
+    -- Universally-quantified tyvars that appear in any of the 
+    -- *implicit* arguments to the constructor
+    -- These tyvars must not change across the updates
+    -- See Note [Implict type sharing]
+    constrainedTyVars tvs1 cons
+      = mkVarSet [tv1 | con <- cons
+                     , let (tvs, theta, _, _) = dataConSig con
+                           bad_tvs = tyVarsOfTheta theta
+                      , (tv1,tv) <- tvs1 `zip` tvs     -- Discards existentials in tvs
+                     , tv `elemVarSet` bad_tvs ]
 \end{code}
 
-
 %************************************************************************
 %*                                                                     *
        Arithmetic sequences                    e.g. [a,b..]
@@ -1131,10 +1186,13 @@ tcRecordBinds data_con arg_tys (HsRecFields rbinds dd)
     do_bind fld@(HsRecField { hsRecFieldId = L loc field_lbl, hsRecFieldArg = rhs })
       | Just field_ty <- assocMaybe flds_w_tys field_lbl
       = addErrCtxt (fieldCtxt field_lbl)       $
-       do { rhs'   <- tcPolyExprNC rhs field_ty
-          ; sel_id <- tcLookupField field_lbl
-          ; ASSERT( isRecordSelector sel_id )
-            return (Just (fld { hsRecFieldId = L loc sel_id, hsRecFieldArg = rhs' })) }
+       do { rhs' <- tcPolyExprNC rhs field_ty
+          ; let field_id = mkUserLocal (nameOccName field_lbl)
+                                       (nameUnique field_lbl)
+                                       field_ty loc
+               -- The field_id has the *unique* of the selector Id
+               -- but is a LocalId with the appropriate type of the RHS
+          ; return (Just (fld { hsRecFieldId = L loc field_id, hsRecFieldArg = rhs' })) }
       | otherwise
       = do { addErrTc (badFieldCon data_con field_lbl)
           ; return Nothing }
@@ -1198,11 +1256,11 @@ funAppCtxt fun arg arg_no
                    quotes (ppr fun) <> text ", namely"])
         4 (quotes (ppr arg))
 
-nonVanillaUpd tycon
-  = vcat [ptext (sLit "Record update for the non-Haskell-98 data type") 
-               <+> quotes (pprSourceTyCon tycon)
-               <+> ptext (sLit "is not (yet) supported"),
-         ptext (sLit "Use pattern-matching instead")]
+badFieldTypes prs
+  = hang (ptext (sLit "Record update for insufficiently polymorphic field")
+                        <> plural prs <> colon)
+       2 (vcat [ ppr f <+> dcolon <+> ppr ty | (f,ty) <- prs ])
+
 badFieldsUpd rbinds
   = hang (ptext (sLit "No constructor has all these fields:"))
         4 (pprQuotedList (hsRecFields rbinds))
index 7e15770..491ca27 100644 (file)
@@ -766,9 +766,9 @@ zonkRecFields env (HsRecFields flds dd)
        ; return (HsRecFields flds' dd) }
   where
     zonk_rbind fld
-      = do { new_expr <- zonkLExpr env (hsRecFieldArg fld)
-          ; return (fld { hsRecFieldArg = new_expr }) }
-       -- Field selectors have declared types; hence no zonking
+      = do { new_id   <- wrapLocM (zonkIdBndr env) (hsRecFieldId fld)
+          ; new_expr <- zonkLExpr env (hsRecFieldArg fld)
+          ; return (fld { hsRecFieldId = new_id, hsRecFieldArg = new_expr }) }
 
 -------------------------------------------------------------------------
 mapIPNameTc :: (a -> TcM b) -> IPName a -> TcM (IPName b)
index dee62f4..b02b27e 100644 (file)
@@ -1990,15 +1990,28 @@ main = do
     display (inc (inc counterB))   -- prints "##"
 </programlisting>
 
-At the moment, record update syntax is only supported for Haskell 98 data types,
-so the following function does <emphasis>not</emphasis> work:
-
+Record update syntax is supported for existentials (and GADTs):
 <programlisting>
--- This is invalid; use explicit NewCounter instead for now
 setTag :: Counter a -> a -> Counter a
 setTag obj t = obj{ tag = t }
 </programlisting>
+The rule for record update is this: <emphasis>
+the types of the updated fields may
+mention only the universally-quantified type variables
+of the data constructor.  For GADTs, the field may mention only types
+that appear as a simple type-variable argument in the constructor's result type.
+</emphasis>.  For exmaple:
+<programlisting>
+data T a where { T1 { f1::a, f2::(a,b) } :: T a }    -- b is existential
+upd1 t x = t { f1=x }   -- OK:   upd1 :: T a -> b -> T b
+upd2 t x = t { f2=x }   -- BAD   (f2's type mentions b, which is
+                                  existentially quantified)
 
+data G a b where { G1 { f1::a, f2::c } :: G a [c] }
+upd3 g x = g { f1=x }   -- OK:   upd3 :: G a b -> c -> G c b
+upd4 g x = g { f2=x }   -- BAD (f2's type mentions c, which is not a simple
+                        --      type-varialbe argument in G1's result type)
+</programlisting>
 </para>
 
 </sect3>