Fix Trac #3057 in deriving Functor
authorsimonpj@microsoft.com <unknown>
Tue, 3 Mar 2009 17:06:12 +0000 (17:06 +0000)
committersimonpj@microsoft.com <unknown>
Tue, 3 Mar 2009 17:06:12 +0000 (17:06 +0000)
The universal type variables of a data constructor are not necessarily
identical to those of its parent type constructor, especially if the
data type is imported.

While I was at it, I did a significant refactoring to make all this
traversal of types more comprehensible, by adding the data type
FFoldType.

compiler/typecheck/TcDeriv.lhs
compiler/typecheck/TcGenDeriv.lhs

index 8352f58..c1025d4 100644 (file)
@@ -881,20 +881,19 @@ cond_functorOK allowFunctions (dflags, rep_tc)
   | not (dopt Opt_DeriveFunctor dflags)
   = Just (ptext (sLit "You need -XDeriveFunctor to derive an instance for this class"))
   | otherwise
-  = msum (map check con_types)
+  = msum (map check_con data_cons)     -- msum picks the first 'Just', if any
   where
     data_cons = tyConDataCons rep_tc
-    con_types = concatMap dataConOrigArgTys data_cons
-    check = functorLikeTraverse
-                    Nothing
-                    Nothing
-                    (Just covariant)
-                    (\x y   -> if allowFunctions then x `mplus` y else Just functions)
-                    (\_ xs  -> msum xs)
-                    (\_ x   -> x)
-                    (Just wrong_arg)
-                    (\_ x   -> x)
-                    (last (tyConTyVars rep_tc))
+    check_con con = msum (foldDataConArgs ft_check con)
+
+    ft_check :: FFoldType (Maybe SDoc)
+    ft_check = FT { ft_triv = Nothing, ft_var = Nothing, ft_co_var = Just covariant
+                 , ft_fun = \x y -> if allowFunctions then x `mplus` y else Just functions
+                  , ft_tup = \_ xs  -> msum xs
+                  , ft_ty_app = \_ x   -> x
+                  , ft_bad_app = Just wrong_arg
+                  , ft_forall = \_ x   -> x }
+                    
     covariant = quotes (pprSourceTyCon rep_tc) <+> 
                 ptext (sLit "uses the type variable in a function argument")
     functions = quotes (pprSourceTyCon rep_tc) <+> 
index 92a39d9..ba1c001 100644 (file)
@@ -23,7 +23,9 @@ module TcGenDeriv (
        gen_Show_binds,
        gen_Data_binds,
        gen_Typeable_binds,
-       gen_Functor_binds, functorLikeTraverse, deepSubtypesContaining,
+       gen_Functor_binds, 
+       FFoldType(..), functorLikeTraverse, 
+       deepSubtypesContaining, foldDataConArgs,
        gen_Foldable_binds,
        gen_Traversable_binds,
        genAuxBind
@@ -48,6 +50,7 @@ import TcType
 import TysPrim
 import TysWiredIn
 import Type
+import Var( TyVar )
 import TypeRep
 import VarSet
 import State
@@ -1280,27 +1283,27 @@ This is pretty much the same as $fmap, only without the $(cofmap 'a 'a) case:
 \begin{code}
 gen_Functor_binds :: SrcSpan -> TyCon -> (LHsBinds RdrName, DerivAuxBinds)
 gen_Functor_binds loc tycon
-  = (listToBag [fmap_bind], [])
+  = (unitBag fmap_bind, [])
   where
     data_cons = tyConDataCons tycon
-    arg = last (tyConTyVars tycon) -- argument to derive for, 'a in the above description
 
     fmap_bind = L loc $ mkFunBind (L loc fmap_RDR) (map fmap_eqn data_cons)
     fmap_eqn con = evalState (match_for_con [f_Pat] con parts) bs_RDRs
-      where parts = map derive_fmap_type (dataConOrigArgTys con)
-
-    derive_fmap_type :: Type -> LHsExpr RdrName -> State [RdrName] (LHsExpr RdrName)
-    derive_fmap_type = functorLikeTraverse
-        (\     x -> return x)                                         -- fmap f x = x
-        (\     x -> return (nlHsApp f_Expr x))                        -- fmap f x = f x
-        (panic "contravariant")
-        (\g h  x -> mkSimpleLam (\b -> h =<< (nlHsApp x `fmap` g b))) -- fmap f x = \b -> h (x (g b))
-        (mkSimpleTupleCase match_for_con)                             -- fmap f x = case x of (a1,a2,..) -> (g1 a1,g2 a2,..)
-        (\_ g  x -> do gg <- mkSimpleLam g
-                       return $ nlHsApps fmap_RDR [gg,x])             -- fmap f x = fmap g x
-        (panic "in other argument")
-        (\_ g  x -> g x)
-        arg
+      where 
+        parts = foldDataConArgs ft_fmap con
+
+    ft_fmap :: FFoldType (LHsExpr RdrName -> State [RdrName] (LHsExpr RdrName))
+    -- Tricky higher order type; I can't say I fully understand this code :-(
+    ft_fmap = FT { ft_triv = \x -> return x                    -- fmap f x = x
+                , ft_var  = \x -> return (nlHsApp f_Expr x)   -- fmap f x = f x
+                , ft_fun = \g h x -> mkSimpleLam (\b -> h =<< (nlHsApp x `fmap` g b)) 
+                                                              -- fmap f x = \b -> h (x (g b))
+                , ft_tup = mkSimpleTupleCase match_for_con    -- fmap f x = case x of (a1,a2,..) -> (g1 a1,g2 a2,..)
+                , ft_ty_app = \_ g  x -> do gg <- mkSimpleLam g      -- fmap f x = fmap g x
+                                            return $ nlHsApps fmap_RDR [gg,x]        
+                , ft_forall = \_ g  x -> g x
+                , ft_bad_app = panic "in other argument"
+                , ft_co_var = panic "contravariant" }
 
     match_for_con = mkSimpleConMatch $
         \con_name xsM -> do xs <- sequence xsM
@@ -1314,19 +1317,27 @@ This function works like a fold: it makes a value of type 'a' in a bottom up way
 
 \begin{code}
 -- Generic traversal for Functor deriving
-functorLikeTraverse :: a                    -- ^ Case: does not contain variable
-                    -> a                    -- ^ Case: the variable itself
-                    -> a                    -- ^ Case: the variable itself, contravariantly
-                    -> (a -> a -> a)        -- ^ Case: function type
-                    -> (Boxity -> [a] -> a) -- ^ Case: tuple type
-                    -> (Type -> a -> a)     -- ^ Case: type app, variable only in last argument
-                    -> a                    -- ^ Case: type app, variable other than in last argument
-                    -> (TcTyVar -> a -> a)  -- ^ Case: forall type
-                    -> TcTyVar              -- ^ Variable to look for
-                    -> Type                 -- ^ Type to process
-                    -> a
-functorLikeTraverse caseTrivial caseVar caseCoVar caseFun caseTuple caseTyApp caseWrongArg caseForAll var ty
-    = fst (go False ty)
+data FFoldType a      -- Describes how to fold over a Type in a functor like way
+   = FT { ft_triv    :: a                  -- Does not contain variable
+       , ft_var     :: a                   -- The variable itself                             
+       , ft_co_var  :: a                   -- The variable itself, contravariantly            
+       , ft_fun     :: a -> a -> a         -- Function type
+       , ft_tup     :: Boxity -> [a] -> a  -- Tuple type 
+       , ft_ty_app  :: Type -> a -> a      -- Type app, variable only in last argument        
+       , ft_bad_app :: a                   -- Type app, variable other than in last argument  
+       , ft_forall  :: TcTyVar -> a -> a   -- Forall type                                     
+     }
+
+functorLikeTraverse :: TyVar        -- ^ Variable to look for
+                   -> FFoldType a   -- ^ How to fold
+                   -> Type          -- ^ Type to process
+                   -> a
+functorLikeTraverse var (FT { ft_triv = caseTrivial,     ft_var = caseVar
+                            , ft_co_var = caseCoVar,     ft_fun = caseFun
+                            , ft_tup = caseTuple,        ft_ty_app = caseTyApp 
+                           , ft_bad_app = caseWrongArg, ft_forall = caseForAll })
+                   ty
+  = fst (go False ty)
   where -- go returns (result of type a, does type contain var)
         go co ty | Just ty' <- coreView ty = go co ty'
         go co (TyVarTy    v) | v == var = (if co then caseCoVar else caseVar,True)
@@ -1351,20 +1362,31 @@ functorLikeTraverse caseTrivial caseVar caseCoVar caseFun caseTuple caseTyApp ca
 
 -- Return all syntactic subterms of ty that contain var somewhere
 -- These are the things that should appear in instance constraints
-deepSubtypesContaining :: TcTyVar -> TcType -> [TcType]
-deepSubtypesContaining = functorLikeTraverse
-      []
-      []
-      (panic "contravariant")
-      (\x y   -> x ++ y)      -- function
-      (\_  xs -> concat xs)   -- tuple
-      (\ty x  -> ty : x)      -- tyapp
-      (panic "in other argument")
-      (\v x   -> filter (not . (v `elemVarSet`) . tyVarsOfType) x) -- forall v
-
+deepSubtypesContaining :: TyVar -> Type -> [TcType]
+deepSubtypesContaining tv
+  = functorLikeTraverse tv 
+       (FT { ft_triv = []
+           , ft_var = []
+           , ft_fun = (++), ft_tup = \_ xs -> concat xs
+           , ft_ty_app = (:)
+           , ft_bad_app = panic "in other argument"
+           , ft_co_var = panic "contravariant"
+           , ft_forall = \v xs -> filterOut ((v `elemVarSet`) . tyVarsOfType) xs })
+
+
+foldDataConArgs :: FFoldType a -> DataCon -> [a]
+-- Fold over the arguments of the datacon
+foldDataConArgs ft con
+  = map (functorLikeTraverse tv ft) (dataConOrigArgTys con)
+  where
+    tv = last (dataConUnivTyVars con) 
+                   -- Argument to derive for, 'a in the above description
+                   -- The validity checks have ensured that con is
+                   -- a vanilla data constructor
 
 -- Make a HsLam using a fresh variable from a State monad
 mkSimpleLam :: (LHsExpr id -> State [id] (LHsExpr id)) -> State [id] (LHsExpr id)
+-- (mkSimpleLam fn) returns (\x. fn(x))
 mkSimpleLam lam = do
     (n:names) <- get
     put names
@@ -1423,27 +1445,25 @@ since (f :: a -> b -> b), while (foldr f :: b -> t a -> b).
 \begin{code}
 gen_Foldable_binds :: SrcSpan -> TyCon -> (LHsBinds RdrName, DerivAuxBinds)
 gen_Foldable_binds loc tycon
-  = (listToBag [foldr_bind], [])
+  = (unitBag foldr_bind, [])
   where
     data_cons = tyConDataCons tycon
-    arg = last (tyConTyVars tycon) -- argument to derive for, 'a in the above description
 
     foldr_bind = L loc $ mkFunBind (L loc foldr_RDR) (map foldr_eqn data_cons)
     foldr_eqn con = evalState (match_for_con z_Expr [f_Pat,z_Pat] con parts) bs_RDRs
-      where parts = map derive_foldr_type (dataConOrigArgTys con)
-
-    derive_foldr_type :: Type -> LHsExpr RdrName -> LHsExpr RdrName -> State [RdrName] (LHsExpr RdrName)
-    derive_foldr_type = functorLikeTraverse
-        (\     _ z -> return z)                            -- foldr f z x = z
-        (\     x z -> return (nlHsApps f_RDR [x,z]))       -- foldr f z x = f x z
-        (panic "function")
-        (panic "function")
-        (\b gs x z -> mkSimpleTupleCase (match_for_con z) b gs x)
-        (\_ g  x z -> do gg <- mkSimpleLam2 g              -- foldr f z x = foldr (\xx zz -> g xx zz) z x
-                         return $ nlHsApps foldable_foldr_RDR [gg,z,x])
-        (panic "in other argument")
-        (\_ g  x z -> g x z)
-        arg
+      where 
+        parts = foldDataConArgs ft_foldr con
+
+    ft_foldr :: FFoldType (LHsExpr RdrName -> LHsExpr RdrName -> State [RdrName] (LHsExpr RdrName))
+    ft_foldr = FT { ft_triv = \_ z -> return z                        -- foldr f z x = z
+                 , ft_var  = \x z -> return (nlHsApps f_RDR [x,z])   -- foldr f z x = f x z
+                 , ft_tup = \b gs x z -> mkSimpleTupleCase (match_for_con z) b gs x
+                 , ft_ty_app = \_ g  x z -> do gg <- mkSimpleLam2 g   -- foldr f z x = foldr (\xx zz -> g xx zz) z x
+                                               return $ nlHsApps foldable_foldr_RDR [gg,z,x]
+                 , ft_forall = \_ g  x z -> g x z
+                 , ft_co_var = panic "covariant"
+                 , ft_fun = panic "function"
+                 , ft_bad_app = panic "in other argument" }
 
     match_for_con z = mkSimpleConMatch (\_con_name -> foldrM ($) z) -- g1 v1 (g2 v2 (.. z))
 \end{code}
@@ -1474,27 +1494,27 @@ instead of:         traverse f (T x y) = T x <$> f y
 \begin{code}
 gen_Traversable_binds :: SrcSpan -> TyCon -> (LHsBinds RdrName, DerivAuxBinds)
 gen_Traversable_binds loc tycon
-  = (listToBag [traverse_bind], [])
+  = (unitBag traverse_bind, [])
   where
     data_cons = tyConDataCons tycon
-    arg = last (tyConTyVars tycon) -- argument to derive for, 'a in the above description
 
     traverse_bind = L loc $ mkFunBind (L loc traverse_RDR) (map traverse_eqn data_cons)
     traverse_eqn con = evalState (match_for_con [f_Pat] con parts) bs_RDRs
-      where parts = map derive_travese_type (dataConOrigArgTys con)
-
-    derive_travese_type :: Type -> LHsExpr RdrName -> State [RdrName] (LHsExpr RdrName)
-    derive_travese_type = functorLikeTraverse
-        (\     x -> return (nlHsApps pure_RDR [x]))    -- traverse f x = pure x
-        (\     x -> return (nlHsApps f_RDR [x]))       -- travese f x = f x
-        (panic "function")
-        (panic "function")
-        (mkSimpleTupleCase match_for_con)              -- travese f x z = case x of (a1,a2,..) -> (,,) <$> g1 a1 <*> g2 a2 <*> ..
-        (\_ g  x -> do gg <- mkSimpleLam g             -- travese f x = travese (\xx -> g xx) x
-                       return $ nlHsApps traverse_RDR [gg,x])
-        (panic "in other argument")
-        (\_ g  x -> g x)
-        arg
+      where 
+        parts = foldDataConArgs ft_trav con
+
+
+    ft_trav :: FFoldType (LHsExpr RdrName -> State [RdrName] (LHsExpr RdrName))
+    ft_trav = FT { ft_triv = \x -> return (nlHsApps pure_RDR [x])   -- traverse f x = pure x
+                , ft_var = \x -> return (nlHsApps f_RDR [x])       -- travese f x = f x
+                , ft_tup = mkSimpleTupleCase match_for_con         -- travese f x z = case x of (a1,a2,..) -> 
+                                                                   --                   (,,) <$> g1 a1 <*> g2 a2 <*> ..
+                , ft_ty_app = \_ g  x -> do gg <- mkSimpleLam g    -- travese f x = travese (\xx -> g xx) x
+                                            return $ nlHsApps traverse_RDR [gg,x]
+                , ft_forall = \_ g  x -> g x
+                , ft_co_var = panic "covariant"
+                , ft_fun = panic "function"
+                , ft_bad_app = panic "in other argument" }
 
     match_for_con = mkSimpleConMatch $
         \con_name xsM -> do xs <- sequence xsM