Fix Trac #2856: make deriving work for type families
authorsimonpj@microsoft.com <unknown>
Wed, 31 Dec 2008 14:41:51 +0000 (14:41 +0000)
committersimonpj@microsoft.com <unknown>
Wed, 31 Dec 2008 14:41:51 +0000 (14:41 +0000)
Darn, but TcDeriv is complicated, when type families get in on
the act!  This patch makes GeneralisedNewtypeDeriving work
properly for type families.  I think.

In order to do so, I found that GeneralisedNewtypeDeriving can
work for recursive newtypes too -- and since families are conservatively
marked recursive, that's a crucial part of the fix, and useful too.
See Note [Recursive newtypes] in TcDeriv.

compiler/typecheck/TcDeriv.lhs
compiler/typecheck/TcEnv.lhs
compiler/typecheck/TcInstDcls.lhs

index 419ec94..1a21240 100644 (file)
@@ -30,6 +30,7 @@ import HscTypes
 
 import Class
 import Type
+import Coercion
 import ErrUtils
 import MkId
 import DataCon
@@ -75,6 +76,7 @@ data DerivSpec  = DS { ds_loc     :: SrcSpan
                     , ds_cls     :: Class
                     , ds_tys     :: [Type]
                     , ds_tc      :: TyCon
+                    , ds_tc_args :: [Type]
                     , ds_newtype :: Bool }
        -- This spec implies a dfun declaration of the form
        --       df :: forall tvs. theta => C tys
@@ -82,7 +84,7 @@ data DerivSpec  = DS { ds_loc     :: SrcSpan
        -- The tyvars bind all the variables in the theta
        -- For family indexes, the tycon in 
        --       in ds_tys is the *family* tycon
-       --       in ds_tc  is the *representation* tycon
+       --       in ds_tc, ds_tc_args is the *representation* tycon
        -- For non-family tycons, both are the same
 
        -- ds_newtype = True  <=> Newtype deriving
@@ -339,8 +341,8 @@ renameDeriv is_boot gen_binds insts
                       | otherwise            = rm_dups (b:acc) bs
 
 
-    rn_inst_info (InstInfo { iSpec = inst, iBinds = NewTypeDerived })
-       = return (InstInfo { iSpec = inst, iBinds = NewTypeDerived })
+    rn_inst_info (InstInfo { iSpec = inst, iBinds = NewTypeDerived co })
+       = return (InstInfo { iSpec = inst, iBinds = NewTypeDerived co })
 
     rn_inst_info (InstInfo { iSpec = inst, iBinds = VanillaInst binds sigs })
        =       -- Bring the right type variables into 
@@ -674,14 +676,15 @@ mk_data_eqn orig tvs cls tycon tc_args rep_tc rep_tc_args mtheta
 
              spec = DS { ds_loc = loc, ds_orig = orig
                        , ds_name = dfun_name, ds_tvs = tvs 
-                       , ds_cls = cls, ds_tys = inst_tys, ds_tc = rep_tc
+                       , ds_cls = cls, ds_tys = inst_tys
+                       , ds_tc = rep_tc, ds_tc_args = rep_tc_args
                        , ds_theta =  mtheta `orElse` all_constraints
                        , ds_newtype = False }
 
        ; return (if isJust mtheta then Right spec      -- Specified context
                                   else Left spec) }    -- Infer context
 
-mk_typeable_eqn orig tvs cls tycon tc_args rep_tc _rep_tc_args mtheta
+mk_typeable_eqn orig tvs cls tycon tc_args rep_tc rep_tc_args mtheta
        -- The Typeable class is special in several ways
        --        data T a b = ... deriving( Typeable )
        -- gives
@@ -705,7 +708,8 @@ mk_typeable_eqn orig tvs cls tycon tc_args rep_tc _rep_tc_args mtheta
        ; loc <- getSrcSpanM
        ; return (Right $
                  DS { ds_loc = loc, ds_orig = orig, ds_name = dfun_name, ds_tvs = []
-                    , ds_cls = cls, ds_tys = [mkTyConApp tycon []], ds_tc = rep_tc
+                    , ds_cls = cls, ds_tys = [mkTyConApp tycon []]
+                    , ds_tc = rep_tc, ds_tc_args = rep_tc_args
                     , ds_theta = mtheta `orElse` [], ds_newtype = False })  }
 
 ------------------------------------------------------------------
@@ -899,7 +903,8 @@ mkNewTypeEqn orig mayDeriveDataTypeable newtype_deriving tvs
        ; loc <- getSrcSpanM
        ; let spec = DS { ds_loc = loc, ds_orig = orig
                        , ds_name = dfun_name, ds_tvs = varSetElems dfun_tvs 
-                       , ds_cls = cls, ds_tys = inst_tys, ds_tc = rep_tycon
+                       , ds_cls = cls, ds_tys = inst_tys
+                       , ds_tc = rep_tycon, ds_tc_args = rep_tc_args
                        , ds_theta =  mtheta `orElse` all_preds
                        , ds_newtype = True }
        ; return (if isJust mtheta then Right spec
@@ -952,7 +957,7 @@ mkNewTypeEqn orig mayDeriveDataTypeable newtype_deriving tvs
 
        nt_eta_arity = length (fst (newTyConEtadRhs rep_tycon))
                -- For newtype T a b = MkT (S a a b), the TyCon machinery already
-               -- eta-reduces the represenation type, so we know that
+               -- eta-reduces the representation type, so we know that
                --      T a ~ S a a
                -- That's convenient here, because we may have to apply
                -- it to fewer than its original complement of arguments
@@ -1006,18 +1011,7 @@ mkNewTypeEqn orig mayDeriveDataTypeable newtype_deriving tvs
                                                -- eg not: newtype T ... deriving( ST )
                                                --      because ST needs *2* type params
           && eta_ok                            -- Eta reduction works
-          && not (isRecursiveTyCon tycon)      -- Does not work for recursive tycons:
-                                               --      newtype A = MkA [A]
-                                               -- Don't want
-                                               --      instance Eq [A] => Eq A !!
-                       -- Here's a recursive newtype that's actually OK
-                       --      newtype S1 = S1 [T1 ()]
-                       --      newtype T1 a = T1 (StateT S1 IO a ) deriving( Monad )
-                       -- It's currently rejected.  Oh well.
-                       -- In fact we generate an instance decl that has method of form
-                       --      meth @ instTy = meth @ repTy
-                       -- (no coerce's).  We'd need a coerce if we wanted to handle
-                       -- recursive newtypes too
+--        && not (isRecursiveTyCon tycon)      -- Note [Recursive newtypes]
 
        -- Check that eta reduction is OK
        eta_ok = nt_eta_arity <= length rep_tc_args
@@ -1041,6 +1035,21 @@ mkNewTypeEqn orig mayDeriveDataTypeable newtype_deriving tvs
                                ]
 \end{code}
 
+Note [Recursive newtypes]
+~~~~~~~~~~~~~~~~~~~~~~~~~
+Newtype deriving works fine, even if the newtype is recursive.
+e.g.   newtype S1 = S1 [T1 ()]
+       newtype T1 a = T1 (StateT S1 IO a ) deriving( Monad )
+Remember, too, that type families are curretly (conservatively) given
+a recursive flag, so this also allows newtype deriving to work
+for type famillies.
+
+We used to exclude recursive types, because we had a rather simple
+minded way of generating the instance decl:
+   newtype A = MkA [A]
+   instance Eq [A] => Eq A     -- Makes typechecker loop!
+But now we require a simple context, so it's ok.
+
 
 %************************************************************************
 %*                                                                     *
@@ -1093,7 +1102,7 @@ inferInstanceContexts oflag infer_specs
       | otherwise
       =        do {      -- Extend the inst info from the explicit instance decls
                  -- with the current set of solutions, and simplify each RHS
-            let inst_specs = zipWithEqual "add_solns" (mkInstance2 oflag)
+            let inst_specs = zipWithEqual "add_solns" (mkInstance oflag)
                                           current_solns infer_specs
           ; new_solns <- checkNoErrs $
                          extendLocalInstEnv inst_specs $
@@ -1131,11 +1140,8 @@ inferInstanceContexts oflag infer_specs
           ; return (sortLe (<=) theta) }       -- Canonicalise before returning the solution
 
 ------------------------------------------------------------------
-mkInstance1 :: OverlapFlag -> DerivSpec -> Instance
-mkInstance1 overlap_flag spec = mkInstance2 overlap_flag (ds_theta spec) spec
-
-mkInstance2 :: OverlapFlag -> ThetaType -> DerivSpec -> Instance
-mkInstance2 overlap_flag theta
+mkInstance :: OverlapFlag -> ThetaType -> DerivSpec -> Instance
+mkInstance overlap_flag theta
            (DS { ds_name = dfun_name
                , ds_tvs = tyvars, ds_cls = clas, ds_tys = tys })
   = mkLocalInstance dfun overlap_flag
@@ -1227,14 +1233,13 @@ the renamer.  What a great hack!
 genInst :: OverlapFlag -> DerivSpec -> TcM (InstInfo RdrName, DerivAuxBinds)
 genInst oflag spec
   | ds_newtype spec
-  = return (InstInfo { iSpec  = mkInstance1 oflag spec 
-                    , iBinds = NewTypeDerived }, [])
+  = return (InstInfo { iSpec  = mkInstance oflag (ds_theta spec) spec
+                    , iBinds = NewTypeDerived co }, [])
 
   | otherwise
   = do { let loc        = getSrcSpan (ds_name spec)
-             inst       = mkInstance1 oflag spec
+             inst       = mkInstance oflag (ds_theta spec) spec
              clas       = ds_cls spec
-             rep_tycon  = ds_tc spec
 
           -- In case of a family instance, we need to use the representation
           -- tycon (after all, it has the data constructors)
@@ -1246,6 +1251,23 @@ genInst oflag spec
                             iBinds = VanillaInst meth_binds [] },
                  aux_binds)
         }
+  where
+    rep_tycon   = ds_tc spec
+    rep_tc_args = ds_tc_args spec
+    co1 = case tyConFamilyCoercion_maybe rep_tycon of
+             Nothing     -> IdCo
+             Just co_con -> ACo (mkTyConApp co_con rep_tc_args)
+    co2 = case newTyConCo_maybe rep_tycon of
+              Nothing     -> IdCo      -- The newtype is transparent; no need for a cast
+             Just co_con -> ACo (mkTyConApp co_con rep_tc_args)
+    co = co1 `mkTransCoI` co2
+
+-- Example: newtype instance N [a] = N1 (Tree a) 
+--          deriving instance Eq b => Eq (N [(b,b)])
+-- From the instance, we get an implicit newtype R1:N a = N1 (Tree a)
+-- When dealing with the deriving clause
+--    co1 : N [(b,b)] ~ R1:N (b,b)
+--    co2 : R1:N (b,b) ~ Tree (b,b)
 
 genDerivBinds :: SrcSpan -> FixityEnv -> Class -> TyCon -> (LHsBinds RdrName, DerivAuxBinds)
 genDerivBinds loc fix_env clas tycon
index 9afe28f..8c37e08 100644 (file)
@@ -57,6 +57,7 @@ import TcType
 -- import TcSuspension
 import qualified Type
 import Id
+import Coercion
 import Var
 import VarSet
 import VarEnv
@@ -642,8 +643,12 @@ data InstBindings a
                                -- specialised instances
 
   | NewTypeDerived              -- Used for deriving instances of newtypes, where the
-                               -- witness dictionary is identical to the argument 
+       CoercionI               -- witness dictionary is identical to the argument 
                                -- dictionary.  Hence no bindings, no pragmas.
+               -- The coercion maps from newtype to the representation type
+               -- (mentioning type variables bound by the forall'd iSpec variables)
+               -- E.g.   newtype instance N [a] = N1 (Tree a)
+               --        co : N [a] ~ Tree a
 
 pprInstInfo :: InstInfo a -> SDoc
 pprInstInfo info = vcat [ptext (sLit "InstInfo:") <+> ppr (idType (iDFunId info))]
@@ -651,8 +656,8 @@ pprInstInfo info = vcat [ptext (sLit "InstInfo:") <+> ppr (idType (iDFunId info)
 pprInstInfoDetails :: OutputableBndr a => InstInfo a -> SDoc
 pprInstInfoDetails info = pprInstInfo info $$ nest 2 (details (iBinds info))
   where
-    details (VanillaInst b _) = pprLHsBinds b
-    details NewTypeDerived    = text "Derived from the representation type"
+    details (VanillaInst b _)  = pprLHsBinds b
+    details (NewTypeDerived _) = text "Derived from the representation type"
 
 simpleInstInfoClsTy :: InstInfo a -> (Class, Type)
 simpleInstInfoClsTy info = case instanceHead (iSpec info) of
index 177a16f..3048174 100644 (file)
@@ -611,15 +611,28 @@ tc_inst_decl2 dfun_id (NewTypeDerived coi)
               (class_tyvars, sc_theta, _, _) = classBigSig cls
               cls_tycon = classTyCon cls
               sc_theta' = substTheta (zipOpenTvSubst class_tyvars cls_inst_tys) sc_theta
-
               Just (initial_cls_inst_tys, last_ty) = snocView cls_inst_tys
-              (nt_tycon, tc_args) = tcSplitTyConApp last_ty     -- Can't fail
-              rep_ty              = newTyConInstRhs nt_tycon tc_args
 
-              rep_pred     = mkClassPred cls (initial_cls_inst_tys ++ [rep_ty])
-                                -- In our example, rep_pred is (Foo Int (Tree [a]))
-              the_coercion = make_coercion cls_tycon initial_cls_inst_tys nt_tycon tc_args
-                                -- Coercion of kind (Foo Int (Tree [a]) ~ Foo Int (N a)
+              (rep_ty, wrapper) 
+                = case coi of
+                    IdCo   -> (last_ty, idHsWrapper)
+                    ACo co -> (snd (coercionKind co), WpCast (mk_full_coercion co))
+
+                -----------------------
+                --        mk_full_coercion
+                -- The inst_head looks like (C s1 .. sm (T a1 .. ak))
+                -- But we want the coercion (C s1 .. sm (sym (CoT a1 .. ak)))
+                --        with kind (C s1 .. sm (T a1 .. ak)  ~  C s1 .. sm <rep_ty>)
+                --        where rep_ty is the (eta-reduced) type rep of T
+                -- So we just replace T with CoT, and insert a 'sym'
+                -- NB: we know that k will be >= arity of CoT, because the latter fully eta-reduced
+
+             mk_full_coercion co = mkTyConApp cls_tycon 
+                                        (initial_cls_inst_tys ++ [mkSymCoercion co])
+                 -- Full coercion : (Foo Int (Tree [a]) ~ Foo Int (N a)
+
+              rep_pred = mkClassPred cls (initial_cls_inst_tys ++ [rep_ty])
+                 -- In our example, rep_pred is (Foo Int (Tree [a]))
 
         ; sc_loc     <- getInstLoc InstScOrigin
         ; sc_dicts   <- newDictBndrs sc_loc sc_theta'
@@ -639,7 +652,7 @@ tc_inst_decl2 dfun_id (NewTypeDerived coi)
        -- in the envt with one of the clas_tyvars
        ; checkSigTyVars inst_tvs'
 
-        ; let coerced_rep_dict = wrapId the_coercion (instToId rep_dict)
+        ; let coerced_rep_dict = wrapId wrapper (instToId rep_dict)
 
         ; body <- make_body cls_tycon cls_inst_tys sc_dicts coerced_rep_dict
         ; let dict_bind = noLoc $ VarBind (instToId this_dict) (noLoc body)
@@ -650,22 +663,6 @@ tc_inst_decl2 dfun_id (NewTypeDerived coi)
                             (dict_bind `consBag` sc_binds)) }
   where
       -----------------------
-      --        make_coercion
-      -- The inst_head looks like (C s1 .. sm (T a1 .. ak))
-      -- But we want the coercion (C s1 .. sm (sym (CoT a1 .. ak)))
-      --        with kind (C s1 .. sm (T a1 .. ak)  ~  C s1 .. sm <rep_ty>)
-      --        where rep_ty is the (eta-reduced) type rep of T
-      -- So we just replace T with CoT, and insert a 'sym'
-      -- NB: we know that k will be >= arity of CoT, because the latter fully eta-reduced
-
-    make_coercion cls_tycon initial_cls_inst_tys nt_tycon tc_args
-        | Just co_con <- newTyConCo_maybe nt_tycon
-        , let co = mkSymCoercion (mkTyConApp co_con tc_args)
-        = WpCast (mkTyConApp cls_tycon (initial_cls_inst_tys ++ [co]))
-        | otherwise     -- The newtype is transparent; no need for a cast
-        = idHsWrapper
-
-      -----------------------
       --     (make_body C tys scs coreced_rep_dict)
       --                returns
       --     (case coerced_rep_dict of { C _ ops -> C scs ops })