Fix Trac #2856: make deriving work for type families
[ghc-hetmet.git] / compiler / typecheck / TcDeriv.lhs
index 44ea1fc..1a21240 100644 (file)
@@ -30,6 +30,7 @@ import HscTypes
 
 import Class
 import Type
+import Coercion
 import ErrUtils
 import MkId
 import DataCon
@@ -75,6 +76,7 @@ data DerivSpec  = DS { ds_loc     :: SrcSpan
                     , ds_cls     :: Class
                     , ds_tys     :: [Type]
                     , ds_tc      :: TyCon
+                    , ds_tc_args :: [Type]
                     , ds_newtype :: Bool }
        -- This spec implies a dfun declaration of the form
        --       df :: forall tvs. theta => C tys
@@ -82,7 +84,7 @@ data DerivSpec  = DS { ds_loc     :: SrcSpan
        -- The tyvars bind all the variables in the theta
        -- For family indexes, the tycon in 
        --       in ds_tys is the *family* tycon
-       --       in ds_tc  is the *representation* tycon
+       --       in ds_tc, ds_tc_args is the *representation* tycon
        -- For non-family tycons, both are the same
 
        -- ds_newtype = True  <=> Newtype deriving
@@ -339,8 +341,8 @@ renameDeriv is_boot gen_binds insts
                       | otherwise            = rm_dups (b:acc) bs
 
 
-    rn_inst_info (InstInfo { iSpec = inst, iBinds = NewTypeDerived })
-       = return (InstInfo { iSpec = inst, iBinds = NewTypeDerived })
+    rn_inst_info (InstInfo { iSpec = inst, iBinds = NewTypeDerived co })
+       = return (InstInfo { iSpec = inst, iBinds = NewTypeDerived co })
 
     rn_inst_info (InstInfo { iSpec = inst, iBinds = VanillaInst binds sigs })
        =       -- Bring the right type variables into 
@@ -459,19 +461,34 @@ deriveTyData (L loc deriv_pred, L _ decl@(TyData { tcdLName = L _ tycon_name,
              (arg_kinds, _) = splitKindFunTys kind
              n_args_to_drop = length arg_kinds 
              n_args_to_keep = tyConArity tc - n_args_to_drop
-             inst_ty = mkTyConApp tc (take n_args_to_keep tc_args)
-             inst_ty_kind = typeKind inst_ty
-
+             args_to_drop   = drop n_args_to_keep tc_args
+             inst_ty        = mkTyConApp tc (take n_args_to_keep tc_args)
+             inst_ty_kind   = typeKind inst_ty
+             dropped_tvs    = mkVarSet (mapCatMaybes getTyVar_maybe args_to_drop)
+             univ_tvs       = (mkVarSet tvs `extendVarSetList` deriv_tvs)
+                                       `minusVarSet` dropped_tvs
        -- Check that the result really is well-kinded
        ; checkTc (n_args_to_keep >= 0 && (inst_ty_kind `eqKind` kind))
                  (derivingKindErr tc cls cls_tys kind)
 
+       ; checkTc (sizeVarSet dropped_tvs == n_args_to_drop &&           -- (a)
+                  tyVarsOfTypes (inst_ty:cls_tys) `subVarSet` univ_tvs) -- (b)
+                 (derivingEtaErr cls cls_tys inst_ty)
+               -- Check that 
+               --  (a) The data type can be eta-reduced; eg reject:
+               --              data instance T a a = ... deriving( Monad )
+               --  (b) The type class args do not mention any of the dropped type
+               --      variables 
+               --              newtype T a s = ... deriving( ST s )
+
        -- Type families can't be partially applied
-       -- e.g.   newtype instance T Int a = ... deriving( Monad )
+       -- e.g.   newtype instance T Int a = MkT [a] deriving( Monad )
+       -- Note [Deriving, type families, and partial applications]
        ; checkTc (not (isOpenTyCon tc) || n_args_to_drop == 0)
                  (typeFamilyPapErr tc cls cls_tys inst_ty)
 
-       ; mkEqnHelp DerivOrigin (tvs++deriv_tvs) cls cls_tys inst_ty Nothing } }
+       ; mkEqnHelp DerivOrigin (varSetElems univ_tvs) cls cls_tys inst_ty Nothing } }
   where
        -- Tiresomely we must figure out the "lhs", which is awkward for type families
        -- E.g.   data T a b = .. deriving( Eq )
@@ -490,8 +507,37 @@ deriveTyData (L loc deriv_pred, L _ decl@(TyData { tcdLName = L _ tycon_name,
 
 deriveTyData _other
   = panic "derivTyData"        -- Caller ensures that only TyData can happen
+\end{code}
 
-------------------------------------------------------------------
+Note [Deriving, type families, and partial applications]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+When there are no type families, it's quite easy:
+
+    newtype S a = MkS [a]
+    -- :CoS :: S  ~ [] -- Eta-reduced
+
+    instance Eq [a] => Eq (S a)        -- by coercion sym (Eq (coMkS a)) : Eq [a] ~ Eq (S a)
+    instance Monad [] => Monad S       -- by coercion sym (Monad coMkS)  : Monad [] ~ Monad S 
+
+When type familes are involved it's trickier:
+
+    data family T a b
+    newtype instance T Int a = MkT [a] deriving( Eq, Monad )
+    -- :RT is the representation type for (T Int a)
+    --  :CoF:R1T a :: T Int a ~ :RT a  -- Not eta reduced
+    --  :Co:R1T    :: :RT ~ []         -- Eta-reduced
+
+    instance Eq [a] => Eq (T Int a)    -- easy by coercion
+    instance Monad [] => Monad (T Int) -- only if we can eta reduce???
+
+The "???" bit is that we don't build the :CoF thing in eta-reduced form
+Henc the current typeFamilyPapErr, even though the instance makes sense.
+After all, we can write it out
+    instance Monad [] => Monad (T Int) -- only if we can eta reduce???
+      return x = MkT [x]
+      ... etc ...      
+
+\begin{code}
 mkEqnHelp :: InstOrigin -> [TyVar] -> Class -> [Type] -> Type
           -> Maybe ThetaType   -- Just    => context supplied (standalone deriving)
                                -- Nothing => context inferred (deriving on data decl)
@@ -630,14 +676,15 @@ mk_data_eqn orig tvs cls tycon tc_args rep_tc rep_tc_args mtheta
 
              spec = DS { ds_loc = loc, ds_orig = orig
                        , ds_name = dfun_name, ds_tvs = tvs 
-                       , ds_cls = cls, ds_tys = inst_tys, ds_tc = rep_tc
+                       , ds_cls = cls, ds_tys = inst_tys
+                       , ds_tc = rep_tc, ds_tc_args = rep_tc_args
                        , ds_theta =  mtheta `orElse` all_constraints
                        , ds_newtype = False }
 
        ; return (if isJust mtheta then Right spec      -- Specified context
                                   else Left spec) }    -- Infer context
 
-mk_typeable_eqn orig tvs cls tycon tc_args rep_tc _rep_tc_args mtheta
+mk_typeable_eqn orig tvs cls tycon tc_args rep_tc rep_tc_args mtheta
        -- The Typeable class is special in several ways
        --        data T a b = ... deriving( Typeable )
        -- gives
@@ -661,7 +708,8 @@ mk_typeable_eqn orig tvs cls tycon tc_args rep_tc _rep_tc_args mtheta
        ; loc <- getSrcSpanM
        ; return (Right $
                  DS { ds_loc = loc, ds_orig = orig, ds_name = dfun_name, ds_tvs = []
-                    , ds_cls = cls, ds_tys = [mkTyConApp tycon []], ds_tc = rep_tc
+                    , ds_cls = cls, ds_tys = [mkTyConApp tycon []]
+                    , ds_tc = rep_tc, ds_tc_args = rep_tc_args
                     , ds_theta = mtheta `orElse` [], ds_newtype = False })  }
 
 ------------------------------------------------------------------
@@ -673,19 +721,17 @@ mk_typeable_eqn orig tvs cls tycon tc_args rep_tc _rep_tc_args mtheta
 -- family tycon (with indexes) in error messages.
 
 data DerivStatus = CanDerive
-                | NonDerivableClass
-                | DerivableClassError SDoc
+                | DerivableClassError SDoc     -- Standard class, but can't do it
+                | NonDerivableClass            -- Non-standard class
 
 checkSideConditions :: Bool -> Class -> [TcType] -> TyCon -> DerivStatus
 checkSideConditions mayDeriveDataTypeable cls cls_tys rep_tc
-  | notNull cls_tys    
-  = DerivableClassError ty_args_why    -- e.g. deriving( Foo s )
-  | otherwise
-  = case sideConditions cls of
-       Nothing   -> NonDerivableClass
-       Just cond -> case (cond (mayDeriveDataTypeable, rep_tc)) of
-                       Nothing  -> CanDerive
-                       Just err -> DerivableClassError err
+  | Just cond <- sideConditions cls
+  = case (cond (mayDeriveDataTypeable, rep_tc)) of
+       Just err -> DerivableClassError err     -- Class-specific error
+       Nothing  | null cls_tys -> CanDerive
+                | otherwise    -> DerivableClassError ty_args_why      -- e.g. deriving( Eq s )
+  | otherwise = NonDerivableClass      -- Not a standard class
   where
     ty_args_why        = quotes (ppr (mkClassPred cls cls_tys)) <+> ptext (sLit "is not a class")
 
@@ -850,13 +896,15 @@ mkNewTypeEqn :: InstOrigin -> Bool -> Bool -> [Var] -> Class
              -> TcRn EarlyDerivSpec
 mkNewTypeEqn orig mayDeriveDataTypeable newtype_deriving tvs
              cls cls_tys tycon tc_args rep_tycon rep_tc_args mtheta
+-- Want: instance (...) => cls (cls_tys ++ [tycon tc_args]) where ...
   | can_derive_via_isomorphism && (newtype_deriving || std_class_via_iso cls)
   = do { traceTc (text "newtype deriving:" <+> ppr tycon <+> ppr rep_tys)
        ; dfun_name <- new_dfun_name cls tycon
        ; loc <- getSrcSpanM
        ; let spec = DS { ds_loc = loc, ds_orig = orig
                        , ds_name = dfun_name, ds_tvs = varSetElems dfun_tvs 
-                       , ds_cls = cls, ds_tys = inst_tys, ds_tc = rep_tycon
+                       , ds_cls = cls, ds_tys = inst_tys
+                       , ds_tc = rep_tycon, ds_tc_args = rep_tc_args
                        , ds_theta =  mtheta `orElse` all_preds
                        , ds_newtype = True }
        ; return (if isJust mtheta then Right spec
@@ -909,7 +957,7 @@ mkNewTypeEqn orig mayDeriveDataTypeable newtype_deriving tvs
 
        nt_eta_arity = length (fst (newTyConEtadRhs rep_tycon))
                -- For newtype T a b = MkT (S a a b), the TyCon machinery already
-               -- eta-reduces the represenation type, so we know that
+               -- eta-reduces the representation type, so we know that
                --      T a ~ S a a
                -- That's convenient here, because we may have to apply
                -- it to fewer than its original complement of arguments
@@ -934,7 +982,7 @@ mkNewTypeEqn orig mayDeriveDataTypeable newtype_deriving tvs
     -- See Note [Newtype deriving superclasses] above
 
        cls_tyvars = classTyVars cls
-       dfun_tvs = tyVarsOfTypes tc_args
+       dfun_tvs = tyVarsOfTypes inst_tys
        inst_ty = mkTyConApp tycon tc_args
        inst_tys = cls_tys ++ [inst_ty]
        sc_theta = substTheta (zipOpenTvSubst cls_tyvars inst_tys)
@@ -963,33 +1011,17 @@ mkNewTypeEqn orig mayDeriveDataTypeable newtype_deriving tvs
                                                -- eg not: newtype T ... deriving( ST )
                                                --      because ST needs *2* type params
           && eta_ok                            -- Eta reduction works
-          && not (isRecursiveTyCon tycon)      -- Does not work for recursive tycons:
-                                               --      newtype A = MkA [A]
-                                               -- Don't want
-                                               --      instance Eq [A] => Eq A !!
-                       -- Here's a recursive newtype that's actually OK
-                       --      newtype S1 = S1 [T1 ()]
-                       --      newtype T1 a = T1 (StateT S1 IO a ) deriving( Monad )
-                       -- It's currently rejected.  Oh well.
-                       -- In fact we generate an instance decl that has method of form
-                       --      meth @ instTy = meth @ repTy
-                       -- (no coerce's).  We'd need a coerce if we wanted to handle
-                       -- recursive newtypes too
+--        && not (isRecursiveTyCon tycon)      -- Note [Recursive newtypes]
 
        -- Check that eta reduction is OK
-       eta_ok = (nt_eta_arity <= length rep_tc_args)
-               -- (a) the newtype can be eta-reduced to match the number
+       eta_ok = nt_eta_arity <= length rep_tc_args
+               -- The newtype can be eta-reduced to match the number
                --     of type argument actually supplied
                --        newtype T a b = MkT (S [a] b) deriving( Monad )
                --     Here the 'b' must be the same in the rep type (S [a] b)
                --     And the [a] must not mention 'b'.  That's all handled
                --     by nt_eta_rity.
 
-             && (tyVarsOfTypes cls_tys `subVarSet` dfun_tvs)
-               -- (c) the type class args do not mention any of the dropped type
-               --     variables 
-               --              newtype T a b = ... deriving( Monad b )
-
        cant_derive_err = vcat [ptext (sLit "even with cunning newtype deriving:"),
                                if isRecursiveTyCon tycon then
                                  ptext (sLit "the newtype may be recursive")
@@ -1003,6 +1035,21 @@ mkNewTypeEqn orig mayDeriveDataTypeable newtype_deriving tvs
                                ]
 \end{code}
 
+Note [Recursive newtypes]
+~~~~~~~~~~~~~~~~~~~~~~~~~
+Newtype deriving works fine, even if the newtype is recursive.
+e.g.   newtype S1 = S1 [T1 ()]
+       newtype T1 a = T1 (StateT S1 IO a ) deriving( Monad )
+Remember, too, that type families are curretly (conservatively) given
+a recursive flag, so this also allows newtype deriving to work
+for type famillies.
+
+We used to exclude recursive types, because we had a rather simple
+minded way of generating the instance decl:
+   newtype A = MkA [A]
+   instance Eq [A] => Eq A     -- Makes typechecker loop!
+But now we require a simple context, so it's ok.
+
 
 %************************************************************************
 %*                                                                     *
@@ -1055,7 +1102,7 @@ inferInstanceContexts oflag infer_specs
       | otherwise
       =        do {      -- Extend the inst info from the explicit instance decls
                  -- with the current set of solutions, and simplify each RHS
-            let inst_specs = zipWithEqual "add_solns" (mkInstance2 oflag)
+            let inst_specs = zipWithEqual "add_solns" (mkInstance oflag)
                                           current_solns infer_specs
           ; new_solns <- checkNoErrs $
                          extendLocalInstEnv inst_specs $
@@ -1093,11 +1140,8 @@ inferInstanceContexts oflag infer_specs
           ; return (sortLe (<=) theta) }       -- Canonicalise before returning the solution
 
 ------------------------------------------------------------------
-mkInstance1 :: OverlapFlag -> DerivSpec -> Instance
-mkInstance1 overlap_flag spec = mkInstance2 overlap_flag (ds_theta spec) spec
-
-mkInstance2 :: OverlapFlag -> ThetaType -> DerivSpec -> Instance
-mkInstance2 overlap_flag theta
+mkInstance :: OverlapFlag -> ThetaType -> DerivSpec -> Instance
+mkInstance overlap_flag theta
            (DS { ds_name = dfun_name
                , ds_tvs = tyvars, ds_cls = clas, ds_tys = tys })
   = mkLocalInstance dfun overlap_flag
@@ -1189,14 +1233,13 @@ the renamer.  What a great hack!
 genInst :: OverlapFlag -> DerivSpec -> TcM (InstInfo RdrName, DerivAuxBinds)
 genInst oflag spec
   | ds_newtype spec
-  = return (InstInfo { iSpec  = mkInstance1 oflag spec 
-                    , iBinds = NewTypeDerived }, [])
+  = return (InstInfo { iSpec  = mkInstance oflag (ds_theta spec) spec
+                    , iBinds = NewTypeDerived co }, [])
 
   | otherwise
   = do { let loc        = getSrcSpan (ds_name spec)
-             inst       = mkInstance1 oflag spec
+             inst       = mkInstance oflag (ds_theta spec) spec
              clas       = ds_cls spec
-             rep_tycon  = ds_tc spec
 
           -- In case of a family instance, we need to use the representation
           -- tycon (after all, it has the data constructors)
@@ -1208,6 +1251,23 @@ genInst oflag spec
                             iBinds = VanillaInst meth_binds [] },
                  aux_binds)
         }
+  where
+    rep_tycon   = ds_tc spec
+    rep_tc_args = ds_tc_args spec
+    co1 = case tyConFamilyCoercion_maybe rep_tycon of
+             Nothing     -> IdCo
+             Just co_con -> ACo (mkTyConApp co_con rep_tc_args)
+    co2 = case newTyConCo_maybe rep_tycon of
+              Nothing     -> IdCo      -- The newtype is transparent; no need for a cast
+             Just co_con -> ACo (mkTyConApp co_con rep_tc_args)
+    co = co1 `mkTransCoI` co2
+
+-- Example: newtype instance N [a] = N1 (Tree a) 
+--          deriving instance Eq b => Eq (N [(b,b)])
+-- From the instance, we get an implicit newtype R1:N a = N1 (Tree a)
+-- When dealing with the deriving clause
+--    co1 : N [(b,b)] ~ R1:N (b,b)
+--    co2 : R1:N (b,b) ~ Tree (b,b)
 
 genDerivBinds :: SrcSpan -> FixityEnv -> Class -> TyCon -> (LHsBinds RdrName, DerivAuxBinds)
 genDerivBinds loc fix_env clas tycon
@@ -1246,6 +1306,12 @@ derivingKindErr tc cls cls_tys cls_kind
        2 (ptext (sLit "Class") <+> quotes (ppr cls)
            <+> ptext (sLit "expects an argument of kind") <+> quotes (pprKind cls_kind))
 
+derivingEtaErr :: Class -> [Type] -> Type -> Message
+derivingEtaErr cls cls_tys inst_ty
+  = sep [ptext (sLit "Cannot eta-reduce to an instance of form"),
+        nest 2 (ptext (sLit "instance (...) =>")
+               <+> pprClassPred cls (cls_tys ++ [inst_ty]))]
+
 typeFamilyPapErr :: TyCon -> Class -> [Type] -> Type -> Message
 typeFamilyPapErr tc cls cls_tys inst_ty
   = hang (ptext (sLit "Derived instance") <+> quotes (pprClassPred cls (cls_tys ++ [inst_ty])))