gen_Show_binds,
gen_Data_binds,
gen_Typeable_binds,
+ gen_Functor_binds,
+ FFoldType(..), functorLikeTraverse,
+ deepSubtypesContaining, foldDataConArgs,
+ gen_Foldable_binds,
+ gen_Traversable_binds,
genAuxBind
) where
import TcType
import TysPrim
import TysWiredIn
+import Type
+import Var( TyVar )
+import TypeRep
+import VarSet
+import State
import Util
+import MonadUtils
import Outputable
import FastString
import OccName
data_con
= case tyConSingleDataCon_maybe tycon of -- just checking...
Nothing -> panic "get_Ix_binds"
- Just dc | any isUnLiftedType (dataConOrigArgTys dc)
- -> pprPanic "Can't derive Ix for a single-constructor type with primitive argument types:" (ppr tycon)
- | otherwise -> dc
+ Just dc -> dc
con_arity = dataConSourceArity data_con
data_con_RDR = getRdrName data_con
con_str = data_con_str data_con
prefix_parser = mk_parser prefix_prec prefix_stmts body
- prefix_stmts -- T a b c
- = (if not (isSym con_str) then
- [bindLex (ident_pat con_str)]
- else [read_punc "(", bindLex (symbol_pat con_str), read_punc ")"])
- ++ read_args
+
+ read_prefix_con
+ | isSym con_str = [read_punc "(", bindLex (symbol_pat con_str), read_punc ")"]
+ | otherwise = [bindLex (ident_pat con_str)]
+ read_infix_con
+ | isSym con_str = [bindLex (symbol_pat con_str)]
+ | otherwise = [read_punc "`", bindLex (ident_pat con_str), read_punc "`"]
+
+ prefix_stmts -- T a b c
+ = read_prefix_con ++ read_args
+
infix_stmts -- a %% b, or a `T` b
= [read_a1]
- ++ (if isSym con_str
- then [bindLex (symbol_pat con_str)]
- else [read_punc "`", bindLex (ident_pat con_str), read_punc "`"])
+ ++ read_infix_con
++ [read_a2]
record_stmts -- T { f1 = a, f2 = b }
- = [bindLex (ident_pat (wrapOpParens con_str)),
- read_punc "{"]
+ = read_prefix_con
+ ++ [read_punc "{"]
++ concat (intersperse [read_punc ","] field_stmts)
++ [read_punc "}"]
data_con_str con = occNameString (getOccName con)
read_punc c = bindLex (punc_pat c)
- read_arg a ty
- | isUnLiftedType ty = pprPanic "Error in deriving:" (text "Can't read unlifted types yet:" <+> ppr ty)
- | otherwise = noLoc (mkBindStmt (nlVarPat a) (nlHsVarApps step_RDR [readPrec_RDR]))
+ read_arg a ty = ASSERT( not (isUnLiftedType ty) )
+ noLoc (mkBindStmt (nlVarPat a) (nlHsVarApps step_RDR [readPrec_RDR]))
read_field lbl a = read_lbl lbl ++
[read_punc "=",
dataTypeOf _ = $dT
+ dataCast1 = gcast1 -- If T :: * -> *
+ dataCast2 = gcast2 -- if T :: * -> * -> *
+
+
\begin{code}
gen_Data_binds :: SrcSpan
-> TyCon
-> (LHsBinds RdrName, -- The method bindings
DerivAuxBinds) -- Auxiliary bindings
gen_Data_binds loc tycon
- = (listToBag [gfoldl_bind, gunfold_bind, toCon_bind, dataTypeOf_bind],
+ = (listToBag [gfoldl_bind, gunfold_bind, toCon_bind, dataTypeOf_bind]
+ `unionBags` gcast_binds,
-- Auxiliary definitions: the data type and constructors
MkTyCon tycon : map MkDataCon data_cons)
where
[nlWildPat]
(nlHsVar (mk_data_type_name tycon))
+ ------------ gcast1/2
+ tycon_kind = tyConKind tycon
+ gcast_binds | tycon_kind `eqKind` kind1 = mk_gcast dataCast1_RDR gcast1_RDR
+ | tycon_kind `eqKind` kind2 = mk_gcast dataCast2_RDR gcast2_RDR
+ | otherwise = emptyBag
+ mk_gcast dataCast_RDR gcast_RDR
+ = unitBag (mk_easy_FunBind loc dataCast_RDR [nlVarPat f_RDR]
+ (nlHsVar gcast_RDR `nlHsApp` nlHsVar f_RDR))
+
+
+kind1, kind2 :: Kind
+kind1 = liftedTypeKind `mkArrowKind` liftedTypeKind
+kind2 = liftedTypeKind `mkArrowKind` kind1
gfoldl_RDR, gunfold_RDR, toConstr_RDR, dataTypeOf_RDR, mkConstr_RDR,
- mkDataType_RDR, conIndex_RDR, prefix_RDR, infix_RDR :: RdrName
+ mkDataType_RDR, conIndex_RDR, prefix_RDR, infix_RDR,
+ dataCast1_RDR, dataCast2_RDR, gcast1_RDR, gcast2_RDR :: RdrName
gfoldl_RDR = varQual_RDR gENERICS (fsLit "gfoldl")
gunfold_RDR = varQual_RDR gENERICS (fsLit "gunfold")
toConstr_RDR = varQual_RDR gENERICS (fsLit "toConstr")
dataTypeOf_RDR = varQual_RDR gENERICS (fsLit "dataTypeOf")
+dataCast1_RDR = varQual_RDR gENERICS (fsLit "dataCast1")
+dataCast2_RDR = varQual_RDR gENERICS (fsLit "dataCast2")
+gcast1_RDR = varQual_RDR tYPEABLE (fsLit "gcast1")
+gcast2_RDR = varQual_RDR tYPEABLE (fsLit "gcast2")
mkConstr_RDR = varQual_RDR gENERICS (fsLit "mkConstr")
mkDataType_RDR = varQual_RDR gENERICS (fsLit "mkDataType")
conIndex_RDR = varQual_RDR gENERICS (fsLit "constrIndex")
infix_RDR = dataQual_RDR gENERICS (fsLit "Infix")
\end{code}
+
+
+%************************************************************************
+%* *
+ Functor instances
+%* *
+%************************************************************************
+
+For the data type:
+
+ data T a = T1 Int a | T2 (T a)
+
+We generate the instance:
+
+ instance Functor T where
+ fmap f (T1 b1 a) = T1 b1 (f a)
+ fmap f (T2 ta) = T2 (fmap f ta)
+
+Notice that we don't simply apply 'fmap' to the constructor arguments.
+Rather
+ - Do nothing to an argument whose type doesn't mention 'a'
+ - Apply 'f' to an argument of type 'a'
+ - Apply 'fmap f' to other arguments
+That's why we have to recurse deeply into the constructor argument types,
+rather than just one level, as we typically do.
+
+What about types with more than one type parameter? In general, we only
+derive Functor for the last position:
+
+ data S a b = S1 [b] | S2 (a, T a b)
+ instance Functor (S a) where
+ fmap f (S1 bs) = S1 (fmap f bs)
+ fmap f (S2 (p,q)) = S2 (a, fmap f q)
+
+However, we have special cases for
+ - tuples
+ - functions
+
+More formally, we write the derivation of fmap code over type variable
+'a for type 'b as ($fmap 'a 'b). In this general notation the derived
+instance for T is:
+
+ instance Functor T where
+ fmap f (T1 x1 x2) = T1 ($(fmap 'a 'b1) x1) ($(fmap 'a 'a) x2)
+ fmap f (T2 x1) = T2 ($(fmap 'a '(T a)) x1)
+
+ $(fmap 'a 'b) x = x -- when b does not contain a
+ $(fmap 'a 'a) x = f x
+ $(fmap 'a '(b1,b2)) x = case x of (x1,x2) -> ($(fmap 'a 'b1) x1, $(fmap 'a 'b2) x2)
+ $(fmap 'a '(T b1 b2)) x = fmap $(fmap 'a 'b2) x -- when a only occurs in the last parameter, b2
+ $(fmap 'a '(b -> c)) x = \b -> $(fmap 'a' 'c) (x ($(cofmap 'a 'b) b))
+
+For functions, the type parameter 'a can occur in a contravariant position,
+which means we need to derive a function like:
+
+ cofmap :: (a -> b) -> (f b -> f a)
+
+This is pretty much the same as $fmap, only without the $(cofmap 'a 'a) case:
+
+ $(cofmap 'a 'b) x = x -- when b does not contain a
+ $(cofmap 'a 'a) x = error "type variable in contravariant position"
+ $(cofmap 'a '(b1,b2)) x = case x of (x1,x2) -> ($(cofmap 'a 'b1) x1, $(cofmap 'a 'b2) x2)
+ $(cofmap 'a '[b]) x = map $(cofmap 'a 'b) x
+ $(cofmap 'a '(T b1 b2)) x = fmap $(cofmap 'a 'b2) x -- when a only occurs in the last parameter, b2
+ $(cofmap 'a '(b -> c)) x = \b -> $(cofmap 'a' 'c) (x ($(fmap 'a 'c) b))
+
+\begin{code}
+gen_Functor_binds :: SrcSpan -> TyCon -> (LHsBinds RdrName, DerivAuxBinds)
+gen_Functor_binds loc tycon
+ = (unitBag fmap_bind, [])
+ where
+ data_cons = tyConDataCons tycon
+
+ fmap_bind = L loc $ mkFunBind (L loc fmap_RDR) (map fmap_eqn data_cons)
+ fmap_eqn con = evalState (match_for_con [f_Pat] con parts) bs_RDRs
+ where
+ parts = foldDataConArgs ft_fmap con
+
+ ft_fmap :: FFoldType (LHsExpr RdrName -> State [RdrName] (LHsExpr RdrName))
+ -- Tricky higher order type; I can't say I fully understand this code :-(
+ ft_fmap = FT { ft_triv = \x -> return x -- fmap f x = x
+ , ft_var = \x -> return (nlHsApp f_Expr x) -- fmap f x = f x
+ , ft_fun = \g h x -> mkSimpleLam (\b -> h =<< (nlHsApp x `fmap` g b))
+ -- fmap f x = \b -> h (x (g b))
+ , ft_tup = mkSimpleTupleCase match_for_con -- fmap f x = case x of (a1,a2,..) -> (g1 a1,g2 a2,..)
+ , ft_ty_app = \_ g x -> do gg <- mkSimpleLam g -- fmap f x = fmap g x
+ return $ nlHsApps fmap_RDR [gg,x]
+ , ft_forall = \_ g x -> g x
+ , ft_bad_app = panic "in other argument"
+ , ft_co_var = panic "contravariant" }
+
+ match_for_con = mkSimpleConMatch $
+ \con_name xsM -> do xs <- sequence xsM
+ return (nlHsApps con_name xs) -- Con (g1 v1) (g2 v2) ..
+\end{code}
+
+Utility functions related to Functor deriving.
+
+Since several things use the same pattern of traversal, this is abstracted into functorLikeTraverse.
+This function works like a fold: it makes a value of type 'a' in a bottom up way.
+
+\begin{code}
+-- Generic traversal for Functor deriving
+data FFoldType a -- Describes how to fold over a Type in a functor like way
+ = FT { ft_triv :: a -- Does not contain variable
+ , ft_var :: a -- The variable itself
+ , ft_co_var :: a -- The variable itself, contravariantly
+ , ft_fun :: a -> a -> a -- Function type
+ , ft_tup :: Boxity -> [a] -> a -- Tuple type
+ , ft_ty_app :: Type -> a -> a -- Type app, variable only in last argument
+ , ft_bad_app :: a -- Type app, variable other than in last argument
+ , ft_forall :: TcTyVar -> a -> a -- Forall type
+ }
+
+functorLikeTraverse :: TyVar -- ^ Variable to look for
+ -> FFoldType a -- ^ How to fold
+ -> Type -- ^ Type to process
+ -> a
+functorLikeTraverse var (FT { ft_triv = caseTrivial, ft_var = caseVar
+ , ft_co_var = caseCoVar, ft_fun = caseFun
+ , ft_tup = caseTuple, ft_ty_app = caseTyApp
+ , ft_bad_app = caseWrongArg, ft_forall = caseForAll })
+ ty
+ = fst (go False ty)
+ where -- go returns (result of type a, does type contain var)
+ go co ty | Just ty' <- coreView ty = go co ty'
+ go co (TyVarTy v) | v == var = (if co then caseCoVar else caseVar,True)
+ go co (FunTy (PredTy _) b) = go co b
+ go co (FunTy x y) | xc || yc = (caseFun xr yr,True)
+ where (xr,xc) = go (not co) x
+ (yr,yc) = go co y
+ go co (AppTy x y) | xc = (caseWrongArg, True)
+ | yc = (caseTyApp x yr, True)
+ where (_, xc) = go co x
+ (yr,yc) = go co y
+ go co ty@(TyConApp con args)
+ | isTupleTyCon con = (caseTuple (tupleTyConBoxity con) xrs,True)
+ | null args = (caseTrivial,False) -- T
+ | or (init xcs) = (caseWrongArg,True) -- T (..var..) ty
+ | last xcs = -- T (..no var..) ty
+ (caseTyApp (fst (splitAppTy ty)) (last xrs),True)
+ where (xrs,xcs) = unzip (map (go co) args)
+ go co (ForAllTy v x) | v /= var && xc = (caseForAll v xr,True)
+ where (xr,xc) = go co x
+ go _ _ = (caseTrivial,False)
+
+-- Return all syntactic subterms of ty that contain var somewhere
+-- These are the things that should appear in instance constraints
+deepSubtypesContaining :: TyVar -> Type -> [TcType]
+deepSubtypesContaining tv
+ = functorLikeTraverse tv
+ (FT { ft_triv = []
+ , ft_var = []
+ , ft_fun = (++), ft_tup = \_ xs -> concat xs
+ , ft_ty_app = (:)
+ , ft_bad_app = panic "in other argument"
+ , ft_co_var = panic "contravariant"
+ , ft_forall = \v xs -> filterOut ((v `elemVarSet`) . tyVarsOfType) xs })
+
+
+foldDataConArgs :: FFoldType a -> DataCon -> [a]
+-- Fold over the arguments of the datacon
+foldDataConArgs ft con
+ = map (functorLikeTraverse tv ft) (dataConOrigArgTys con)
+ where
+ tv = last (dataConUnivTyVars con)
+ -- Argument to derive for, 'a in the above description
+ -- The validity checks have ensured that con is
+ -- a vanilla data constructor
+
+-- Make a HsLam using a fresh variable from a State monad
+mkSimpleLam :: (LHsExpr id -> State [id] (LHsExpr id)) -> State [id] (LHsExpr id)
+-- (mkSimpleLam fn) returns (\x. fn(x))
+mkSimpleLam lam = do
+ (n:names) <- get
+ put names
+ body <- lam (nlHsVar n)
+ return (mkHsLam [nlVarPat n] body)
+
+mkSimpleLam2 :: (LHsExpr id -> LHsExpr id -> State [id] (LHsExpr id)) -> State [id] (LHsExpr id)
+mkSimpleLam2 lam = do
+ (n1:n2:names) <- get
+ put names
+ body <- lam (nlHsVar n1) (nlHsVar n2)
+ return (mkHsLam [nlVarPat n1,nlVarPat n2] body)
+
+-- "Con a1 a2 a3 -> fold [x1 a1, x2 a2, x3 a3]"
+mkSimpleConMatch :: Monad m => (RdrName -> [a] -> m (LHsExpr RdrName)) -> [LPat RdrName] -> DataCon -> [LHsExpr RdrName -> a] -> m (LMatch RdrName)
+mkSimpleConMatch fold extra_pats con insides = do
+ let con_name = getRdrName con
+ let vars_needed = takeList insides as_RDRs
+ let pat = nlConVarPat con_name vars_needed
+ rhs <- fold con_name (zipWith ($) insides (map nlHsVar vars_needed))
+ return $ mkMatch (extra_pats ++ [pat]) rhs emptyLocalBinds
+
+-- "case x of (a1,a2,a3) -> fold [x1 a1, x2 a2, x3 a3]"
+mkSimpleTupleCase :: Monad m => ([LPat RdrName] -> DataCon -> [LHsExpr RdrName -> a] -> m (LMatch RdrName))
+ -> Boxity -> [LHsExpr RdrName -> a] -> LHsExpr RdrName -> m (LHsExpr RdrName)
+mkSimpleTupleCase match_for_con boxity insides x = do
+ let con = tupleCon boxity (length insides)
+ match <- match_for_con [] con insides
+ return $ nlHsCase x [match]
+\end{code}
+
+
+%************************************************************************
+%* *
+ Foldable instances
+%* *
+%************************************************************************
+
+Deriving Foldable instances works the same way as Functor instances,
+only Foldable instances are not possible for function types at all.
+Here the derived instance for the type T above is:
+
+ instance Foldable T where
+ foldr f z (T1 x1 x2 x3) = $(foldr 'a 'b1) x1 ( $(foldr 'a 'a) x2 ( $(foldr 'a 'b2) x3 z ) )
+
+The cases are:
+
+ $(foldr 'a 'b) x z = z -- when b does not contain a
+ $(foldr 'a 'a) x z = f x z
+ $(foldr 'a '(b1,b2)) x z = case x of (x1,x2) -> $(foldr 'a 'b1) x1 ( $(foldr 'a 'b2) x2 z )
+ $(foldr 'a '(T b1 b2)) x z = foldr $(foldr 'a 'b2) x z -- when a only occurs in the last parameter, b2
+
+Note that the arguments to the real foldr function are the wrong way around,
+since (f :: a -> b -> b), while (foldr f :: b -> t a -> b).
+
+\begin{code}
+gen_Foldable_binds :: SrcSpan -> TyCon -> (LHsBinds RdrName, DerivAuxBinds)
+gen_Foldable_binds loc tycon
+ = (unitBag foldr_bind, [])
+ where
+ data_cons = tyConDataCons tycon
+
+ foldr_bind = L loc $ mkFunBind (L loc foldr_RDR) (map foldr_eqn data_cons)
+ foldr_eqn con = evalState (match_for_con z_Expr [f_Pat,z_Pat] con parts) bs_RDRs
+ where
+ parts = foldDataConArgs ft_foldr con
+
+ ft_foldr :: FFoldType (LHsExpr RdrName -> LHsExpr RdrName -> State [RdrName] (LHsExpr RdrName))
+ ft_foldr = FT { ft_triv = \_ z -> return z -- foldr f z x = z
+ , ft_var = \x z -> return (nlHsApps f_RDR [x,z]) -- foldr f z x = f x z
+ , ft_tup = \b gs x z -> mkSimpleTupleCase (match_for_con z) b gs x
+ , ft_ty_app = \_ g x z -> do gg <- mkSimpleLam2 g -- foldr f z x = foldr (\xx zz -> g xx zz) z x
+ return $ nlHsApps foldable_foldr_RDR [gg,z,x]
+ , ft_forall = \_ g x z -> g x z
+ , ft_co_var = panic "covariant"
+ , ft_fun = panic "function"
+ , ft_bad_app = panic "in other argument" }
+
+ match_for_con z = mkSimpleConMatch (\_con_name -> foldrM ($) z) -- g1 v1 (g2 v2 (.. z))
+\end{code}
+
+
+%************************************************************************
+%* *
+ Traversable instances
+%* *
+%************************************************************************
+
+Again, Traversable is much like Functor and Foldable.
+
+The cases are:
+
+ $(traverse 'a 'b) x = pure x -- when b does not contain a
+ $(traverse 'a 'a) x = f x
+ $(traverse 'a '(b1,b2)) x = case x of (x1,x2) -> (,) <$> $(traverse 'a 'b1) x1 <*> $(traverse 'a 'b2) x2
+ $(traverse 'a '(T b1 b2)) x = traverse $(traverse 'a 'b2) x -- when a only occurs in the last parameter, b2
+
+Note that the generated code is not as efficient as it could be. For instance:
+
+ data T a = T Int a deriving Traversable
+
+gives the function: traverse f (T x y) = T <$> pure x <*> f y
+instead of: traverse f (T x y) = T x <$> f y
+
+\begin{code}
+gen_Traversable_binds :: SrcSpan -> TyCon -> (LHsBinds RdrName, DerivAuxBinds)
+gen_Traversable_binds loc tycon
+ = (unitBag traverse_bind, [])
+ where
+ data_cons = tyConDataCons tycon
+
+ traverse_bind = L loc $ mkFunBind (L loc traverse_RDR) (map traverse_eqn data_cons)
+ traverse_eqn con = evalState (match_for_con [f_Pat] con parts) bs_RDRs
+ where
+ parts = foldDataConArgs ft_trav con
+
+
+ ft_trav :: FFoldType (LHsExpr RdrName -> State [RdrName] (LHsExpr RdrName))
+ ft_trav = FT { ft_triv = \x -> return (nlHsApps pure_RDR [x]) -- traverse f x = pure x
+ , ft_var = \x -> return (nlHsApps f_RDR [x]) -- travese f x = f x
+ , ft_tup = mkSimpleTupleCase match_for_con -- travese f x z = case x of (a1,a2,..) ->
+ -- (,,) <$> g1 a1 <*> g2 a2 <*> ..
+ , ft_ty_app = \_ g x -> do gg <- mkSimpleLam g -- travese f x = travese (\xx -> g xx) x
+ return $ nlHsApps traverse_RDR [gg,x]
+ , ft_forall = \_ g x -> g x
+ , ft_co_var = panic "covariant"
+ , ft_fun = panic "function"
+ , ft_bad_app = panic "in other argument" }
+
+ match_for_con = mkSimpleConMatch $
+ \con_name xsM -> do xs <- sequence xsM
+ return (mkApCon (nlHsVar con_name) xs)
+
+ -- ((Con <$> x1) <*> x2) <*> ..
+ mkApCon con [] = nlHsApps pure_RDR [con]
+ mkApCon con (x:xs) = foldl appAp (nlHsApps fmap_RDR [con,x]) xs
+ where appAp x y = nlHsApps ap_RDR [x,y]
+\end{code}
+
+
+
%************************************************************************
%* *
\subsection{Generating extra binds (@con2tag@ and @tag2con@)}
\end{code}
\begin{code}
-a_RDR, b_RDR, c_RDR, d_RDR, k_RDR, z_RDR, ah_RDR, bh_RDR, ch_RDR, dh_RDR,
+a_RDR, b_RDR, c_RDR, d_RDR, f_RDR, k_RDR, z_RDR, ah_RDR, bh_RDR, ch_RDR, dh_RDR,
cmp_eq_RDR :: RdrName
a_RDR = mkVarUnqual (fsLit "a")
b_RDR = mkVarUnqual (fsLit "b")
c_RDR = mkVarUnqual (fsLit "c")
d_RDR = mkVarUnqual (fsLit "d")
+f_RDR = mkVarUnqual (fsLit "f")
k_RDR = mkVarUnqual (fsLit "k")
z_RDR = mkVarUnqual (fsLit "z")
ah_RDR = mkVarUnqual (fsLit "a#")
bs_RDRs = [ mkVarUnqual (mkFastString ("b"++show i)) | i <- [(1::Int) .. ] ]
cs_RDRs = [ mkVarUnqual (mkFastString ("c"++show i)) | i <- [(1::Int) .. ] ]
-a_Expr, b_Expr, c_Expr, ltTag_Expr, eqTag_Expr, gtTag_Expr,
+a_Expr, b_Expr, c_Expr, f_Expr, z_Expr, ltTag_Expr, eqTag_Expr, gtTag_Expr,
false_Expr, true_Expr :: LHsExpr RdrName
a_Expr = nlHsVar a_RDR
b_Expr = nlHsVar b_RDR
c_Expr = nlHsVar c_RDR
+f_Expr = nlHsVar f_RDR
+z_Expr = nlHsVar z_RDR
ltTag_Expr = nlHsVar ltTag_RDR
eqTag_Expr = nlHsVar eqTag_RDR
gtTag_Expr = nlHsVar gtTag_RDR
false_Expr = nlHsVar false_RDR
true_Expr = nlHsVar true_RDR
-a_Pat, b_Pat, c_Pat, d_Pat, k_Pat, z_Pat :: LPat RdrName
+a_Pat, b_Pat, c_Pat, d_Pat, f_Pat, k_Pat, z_Pat :: LPat RdrName
a_Pat = nlVarPat a_RDR
b_Pat = nlVarPat b_RDR
c_Pat = nlVarPat c_RDR
d_Pat = nlVarPat d_RDR
+f_Pat = nlVarPat f_RDR
k_Pat = nlVarPat k_RDR
z_Pat = nlVarPat z_RDR