From 1c15bee5a8fc004c16693d7d7a2d95b442549b66 Mon Sep 17 00:00:00 2001 From: "simonpj@microsoft.com" Date: Mon, 2 Feb 2009 13:48:29 +0000 Subject: [PATCH] Add the ability to derive instances of Functor, Foldable, Traversable This patch is a straightforward extension of the 'deriving' mechanism. The ability to derive classes Functor, Foldable, Traverable is controlled by a single flag -XDeriveFunctor. (Maybe that's a poor name.) Still to come: documentation Thanks to twanvl for developing the patch --- compiler/main/DynFlags.hs | 3 + compiler/prelude/PrelNames.lhs | 54 ++++--- compiler/typecheck/TcDeriv.lhs | 116 +++++++++++--- compiler/typecheck/TcGenDeriv.lhs | 314 ++++++++++++++++++++++++++++++++++++- compiler/utils/Util.lhs | 5 + 5 files changed, 447 insertions(+), 45 deletions(-) diff --git a/compiler/main/DynFlags.hs b/compiler/main/DynFlags.hs index 35949f7..44bd124 100644 --- a/compiler/main/DynFlags.hs +++ b/compiler/main/DynFlags.hs @@ -220,6 +220,7 @@ data DynFlag | Opt_RelaxedPolyRec | Opt_StandaloneDeriving | Opt_DeriveDataTypeable + | Opt_DeriveFunctor | Opt_TypeSynonymInstances | Opt_FlexibleContexts | Opt_FlexibleInstances @@ -1771,6 +1772,7 @@ xFlags = [ ( "UnboxedTuples", Opt_UnboxedTuples, const Supported ), ( "StandaloneDeriving", Opt_StandaloneDeriving, const Supported ), ( "DeriveDataTypeable", Opt_DeriveDataTypeable, const Supported ), + ( "DeriveFunctor", Opt_DeriveFunctor, const Supported ), ( "TypeSynonymInstances", Opt_TypeSynonymInstances, const Supported ), ( "FlexibleContexts", Opt_FlexibleContexts, const Supported ), ( "FlexibleInstances", Opt_FlexibleInstances, const Supported ), @@ -1809,6 +1811,7 @@ glasgowExtsFlags = [ , Opt_TypeSynonymInstances , Opt_StandaloneDeriving , Opt_DeriveDataTypeable + , Opt_DeriveFunctor , Opt_FlexibleContexts , Opt_FlexibleInstances , Opt_ConstrainedClassMethods diff --git a/compiler/prelude/PrelNames.lhs b/compiler/prelude/PrelNames.lhs index bf46465..235bc3c 100644 --- a/compiler/prelude/PrelNames.lhs +++ b/compiler/prelude/PrelNames.lhs @@ -131,6 +131,9 @@ basicKnownKeyNames realFloatClassName, -- numeric dataClassName, isStringClassName, + applicativeClassName, + foldableClassName, + traversableClassName, -- Numeric stuff negateName, minusName, @@ -164,7 +167,7 @@ basicKnownKeyNames -- Read stuff readClassName, - + -- Stable pointers newStablePtrName, @@ -204,11 +207,8 @@ basicKnownKeyNames randomClassName, randomGenClassName, monadPlusClassName, -- Annotation type checking - toAnnotationWrapperName, + toAnnotationWrapperName - -- Booleans - andName, orName - -- The Either type , eitherTyConName, leftDataConName, rightDataConName @@ -236,10 +236,11 @@ pRELUDE = mkBaseModule_ pRELUDE_NAME gHC_PRIM, gHC_TYPES, gHC_BOOL, gHC_UNIT, gHC_ORDERING, gHC_GENERICS, gHC_CLASSES, gHC_BASE, gHC_ENUM, gHC_SHOW, gHC_READ, gHC_NUM, gHC_INTEGER, gHC_INTEGER_INTERNALS, gHC_LIST, gHC_PARR, - gHC_TUPLE, dATA_TUPLE, dATA_EITHER, dATA_STRING, gHC_PACK, gHC_CONC, gHC_IO_BASE, + gHC_TUPLE, dATA_TUPLE, dATA_EITHER, dATA_STRING, dATA_FOLDABLE, dATA_TRAVERSABLE, + gHC_PACK, gHC_CONC, gHC_IO_BASE, gHC_ST, gHC_ARR, gHC_STABLE, gHC_ADDR, gHC_PTR, gHC_ERR, gHC_REAL, gHC_FLOAT, gHC_TOP_HANDLER, sYSTEM_IO, dYNAMIC, tYPEABLE, gENERICS, - dOTNET, rEAD_PREC, lEX, gHC_INT, gHC_WORD, mONAD, mONAD_FIX, aRROW, + dOTNET, rEAD_PREC, lEX, gHC_INT, gHC_WORD, mONAD, mONAD_FIX, aRROW, cONTROL_APPLICATIVE, gHC_DESUGAR, rANDOM, gHC_EXTS, cONTROL_EXCEPTION_BASE :: Module gHC_PRIM = mkPrimModule (fsLit "GHC.Prim") -- Primitive types and values gHC_TYPES = mkPrimModule (fsLit "GHC.Types") @@ -261,6 +262,8 @@ gHC_TUPLE = mkPrimModule (fsLit "GHC.Tuple") dATA_TUPLE = mkBaseModule (fsLit "Data.Tuple") dATA_EITHER = mkBaseModule (fsLit "Data.Either") dATA_STRING = mkBaseModule (fsLit "Data.String") +dATA_FOLDABLE = mkBaseModule (fsLit "Data.Foldable") +dATA_TRAVERSABLE= mkBaseModule (fsLit "Data.Traversable") gHC_PACK = mkBaseModule (fsLit "GHC.Pack") gHC_CONC = mkBaseModule (fsLit "GHC.Conc") gHC_IO_BASE = mkBaseModule (fsLit "GHC.IOBase") @@ -285,6 +288,7 @@ gHC_WORD = mkBaseModule (fsLit "GHC.Word") mONAD = mkBaseModule (fsLit "Control.Monad") mONAD_FIX = mkBaseModule (fsLit "Control.Monad.Fix") aRROW = mkBaseModule (fsLit "Control.Arrow") +cONTROL_APPLICATIVE = mkBaseModule (fsLit "Control.Applicative") gHC_DESUGAR = mkBaseModule (fsLit "GHC.Desugar") rANDOM = mkBaseModule (fsLit "System.Random") gHC_EXTS = mkBaseModule (fsLit "GHC.Exts") @@ -389,9 +393,6 @@ returnM_RDR = nameRdrName returnMName bindM_RDR = nameRdrName bindMName failM_RDR = nameRdrName failMName -and_RDR :: RdrName -and_RDR = nameRdrName andName - left_RDR, right_RDR :: RdrName left_RDR = nameRdrName leftDataConName right_RDR = nameRdrName rightDataConName @@ -443,8 +444,9 @@ compose_RDR :: RdrName compose_RDR = varQual_RDR gHC_BASE (fsLit ".") not_RDR, getTag_RDR, succ_RDR, pred_RDR, minBound_RDR, maxBound_RDR, - range_RDR, inRange_RDR, index_RDR, + and_RDR, range_RDR, inRange_RDR, index_RDR, unsafeIndex_RDR, unsafeRangeSize_RDR :: RdrName +and_RDR = varQual_RDR gHC_CLASSES (fsLit "&&") not_RDR = varQual_RDR gHC_CLASSES (fsLit "not") getTag_RDR = varQual_RDR gHC_BASE (fsLit "getTag") succ_RDR = varQual_RDR gHC_ENUM (fsLit "succ") @@ -502,6 +504,13 @@ inlDataCon_RDR = dataQual_RDR gHC_GENERICS (fsLit "Inl") inrDataCon_RDR = dataQual_RDR gHC_GENERICS (fsLit "Inr") genUnitDataCon_RDR = dataQual_RDR gHC_GENERICS (fsLit "Unit") +fmap_RDR, pure_RDR, ap_RDR, foldable_foldr_RDR, traverse_RDR :: RdrName +fmap_RDR = varQual_RDR gHC_BASE (fsLit "fmap") +pure_RDR = varQual_RDR cONTROL_APPLICATIVE (fsLit "pure") +ap_RDR = varQual_RDR cONTROL_APPLICATIVE (fsLit "<*>") +foldable_foldr_RDR = varQual_RDR dATA_FOLDABLE (fsLit "foldr") +traverse_RDR = varQual_RDR dATA_TRAVERSABLE (fsLit "traverse") + ---------------------- varQual_RDR, tcQual_RDR, clsQual_RDR, dataQual_RDR :: Module -> FastString -> RdrName @@ -573,13 +582,19 @@ bindMName = methName gHC_BASE (fsLit ">>=") bindMClassOpKey returnMName = methName gHC_BASE (fsLit "return") returnMClassOpKey failMName = methName gHC_BASE (fsLit "fail") failMClassOpKey +-- Classes (Applicative, Foldable, Traversable) +applicativeClassName, foldableClassName, traversableClassName :: Name +applicativeClassName = clsQual cONTROL_APPLICATIVE (fsLit "Applicative") applicativeClassKey +foldableClassName = clsQual dATA_FOLDABLE (fsLit "Foldable") foldableClassKey +traversableClassName = clsQual dATA_TRAVERSABLE (fsLit "Traversable") traversableClassKey + -- Functions for GHC extensions groupWithName :: Name groupWithName = varQual gHC_EXTS (fsLit "groupWith") groupWithIdKey -- Random PrelBase functions fromStringName, otherwiseIdName, foldrName, buildName, augmentName, - mapName, appendName, andName, orName, assertName, + mapName, appendName, assertName, breakpointName, breakpointCondName, breakpointAutoName, opaqueTyConName :: Name fromStringName = methName dATA_STRING (fsLit "fromString") fromStringClassOpKey @@ -589,8 +604,6 @@ buildName = varQual gHC_BASE (fsLit "build") buildIdKey augmentName = varQual gHC_BASE (fsLit "augment") augmentIdKey mapName = varQual gHC_BASE (fsLit "map") mapIdKey appendName = varQual gHC_BASE (fsLit "++") appendIdKey -andName = varQual gHC_CLASSES (fsLit "&&") andIdKey -orName = varQual gHC_CLASSES (fsLit "||") orIdKey assertName = varQual gHC_BASE (fsLit "assert") assertIdKey breakpointName = varQual gHC_BASE (fsLit "breakpoint") breakpointIdKey breakpointCondName= varQual gHC_BASE (fsLit "breakpointCond") breakpointCondIdKey @@ -889,6 +902,11 @@ randomGenClassKey = mkPreludeClassUnique 32 isStringClassKey :: Unique isStringClassKey = mkPreludeClassUnique 33 + +applicativeClassKey, foldableClassKey, traversableClassKey :: Unique +applicativeClassKey = mkPreludeClassUnique 34 +foldableClassKey = mkPreludeClassUnique 35 +traversableClassKey = mkPreludeClassUnique 36 \end{code} %************************************************************************ @@ -1156,9 +1174,7 @@ rootMainKey, runMainKey :: Unique rootMainKey = mkPreludeMiscIdUnique 55 runMainKey = mkPreludeMiscIdUnique 56 -andIdKey, orIdKey, thenIOIdKey, lazyIdKey, assertErrorIdKey :: Unique -andIdKey = mkPreludeMiscIdUnique 57 -orIdKey = mkPreludeMiscIdUnique 58 +thenIOIdKey, lazyIdKey, assertErrorIdKey :: Unique thenIOIdKey = mkPreludeMiscIdUnique 59 lazyIdKey = mkPreludeMiscIdUnique 60 assertErrorIdKey = mkPreludeMiscIdUnique 61 @@ -1260,6 +1276,7 @@ fromStringClassOpKey = mkPreludeMiscIdUnique 125 toAnnotationWrapperIdKey :: Unique toAnnotationWrapperIdKey = mkPreludeMiscIdUnique 126 + ---------------- Template Haskell ------------------- -- USES IdUniques 200-399 ----------------------------------------------------- @@ -1325,7 +1342,8 @@ standardClassKeys = derivableClassKeys ++ numericClassKeys ++ [randomClassKey, randomGenClassKey, functorClassKey, monadClassKey, monadPlusClassKey, - isStringClassKey + isStringClassKey, + applicativeClassKey, foldableClassKey, traversableClassKey ] \end{code} diff --git a/compiler/typecheck/TcDeriv.lhs b/compiler/typecheck/TcDeriv.lhs index eac2209..a507197 100644 --- a/compiler/typecheck/TcDeriv.lhs +++ b/compiler/typecheck/TcDeriv.lhs @@ -49,6 +49,8 @@ import ListSetOps import Outputable import FastString import Bag + +import Control.Monad \end{code} %************************************************************************ @@ -566,15 +568,12 @@ mkEqnHelp orig tvs cls cls_tys tc_app mtheta className cls `elem` typeableClassNames) (derivingHiddenErr tycon) - ; mayDeriveDataTypeable <- doptM Opt_DeriveDataTypeable - ; newtype_deriving <- doptM Opt_GeneralizedNewtypeDeriving - + ; dflags <- getDOpts ; if isDataTyCon rep_tc then - mkDataTypeEqn orig mayDeriveDataTypeable tvs cls cls_tys + mkDataTypeEqn orig dflags tvs cls cls_tys tycon tc_args rep_tc rep_tc_args mtheta else - mkNewTypeEqn orig mayDeriveDataTypeable newtype_deriving - tvs cls cls_tys + mkNewTypeEqn orig dflags tvs cls cls_tys tycon tc_args rep_tc rep_tc_args mtheta } | otherwise = failWithTc (derivingThingErr cls cls_tys tc_app @@ -631,13 +630,21 @@ famInstNotFound tycon tys %************************************************************************ \begin{code} -mkDataTypeEqn :: InstOrigin -> Bool -> [Var] -> Class -> [Type] - -> TyCon -> [Type] -> TyCon -> [Type] -> Maybe ThetaType - -> TcRn EarlyDerivSpec -- Return 'Nothing' if error - -mkDataTypeEqn orig mayDeriveDataTypeable tvs cls cls_tys +mkDataTypeEqn :: InstOrigin + -> DynFlags + -> [Var] -- Universally quantified type variables in the instance + -> Class -- Class for which we need to derive an instance + -> [Type] -- Other parameters to the class except the last + -> TyCon -- Type constructor for which the instance is requested (last parameter to the type class) + -> [Type] -- Parameters to the type constructor + -> TyCon -- rep of the above (for type families) + -> [Type] -- rep of the above + -> Maybe ThetaType -- Context of the instance, for standalone deriving + -> TcRn EarlyDerivSpec -- Return 'Nothing' if error + +mkDataTypeEqn orig dflags tvs cls cls_tys tycon tc_args rep_tc rep_tc_args mtheta - = case checkSideConditions mayDeriveDataTypeable cls cls_tys rep_tc of + = case checkSideConditions dflags cls cls_tys rep_tc of -- NB: pass the *representation* tycon to checkSideConditions CanDerive -> mk_data_eqn orig tvs cls tycon tc_args rep_tc rep_tc_args mtheta NonDerivableClass -> bale_out (nonStdErr cls) @@ -656,7 +663,7 @@ mk_data_eqn orig tvs cls tycon tc_args rep_tc rep_tc_args mtheta | otherwise = do { dfun_name <- new_dfun_name cls tycon ; loc <- getSrcSpanM - ; let ordinary_constraints + ; let ordinary_constraints_simple = [ mkClassPred cls [arg_ty] | data_con <- tyConDataCons rep_tc, arg_ty <- ASSERT( isVanillaDataCon data_con ) @@ -665,13 +672,31 @@ mk_data_eqn orig tvs cls tycon tc_args rep_tc rep_tc_args mtheta -- No constraints for unlifted types -- Where they are legal we generate specilised function calls + -- constraints on all subtypes for classes like Functor + ordinary_constraints_deep + = [ mkClassPred cls [deept_ty] + | data_con <- tyConDataCons rep_tc, + arg_ty <- ASSERT( isVanillaDataCon data_con ) + dataConInstOrigArgTys data_con (rep_tc_args++[mkTyVarTy dummy_ty]), + deept_ty <- deepSubtypesContaining dummy_ty arg_ty, + not (isUnLiftedType deept_ty) ] + where dummy_ty = last (tyConTyVars tycon) -- don't substitute the last var, this might not be a good idea + + ordinary_constraints + | getUnique cls == functorClassKey = ordinary_constraints_deep + | getUnique cls == foldableClassKey = ordinary_constraints_deep + | getUnique cls == traversableClassKey = ordinary_constraints_deep + | otherwise = ordinary_constraints_simple + -- See Note [Superclasses of derived instance] sc_constraints = substTheta (zipOpenTvSubst (classTyVars cls) inst_tys) (classSCTheta cls) inst_tys = [mkTyConApp tycon tc_args] - stupid_subst = zipTopTvSubst (tyConTyVars rep_tc) rep_tc_args + nonfree_tycon_vars = dropTail (classArity cls) (tyConTyVars rep_tc) + stupid_subst = zipTopTvSubst nonfree_tycon_vars rep_tc_args stupid_constraints = substTheta stupid_subst (tyConStupidTheta rep_tc) + all_constraints = stupid_constraints ++ sc_constraints ++ ordinary_constraints spec = DS { ds_loc = loc, ds_orig = orig @@ -712,6 +737,7 @@ mk_typeable_eqn orig tvs cls tycon tc_args rep_tc rep_tc_args mtheta , ds_tc = rep_tc, ds_tc_args = rep_tc_args , ds_theta = mtheta `orElse` [], ds_newtype = False }) } + ------------------------------------------------------------------ -- Check side conditions that dis-allow derivability for particular classes -- This is *apart* from the newtype-deriving mechanism @@ -724,10 +750,10 @@ data DerivStatus = CanDerive | DerivableClassError SDoc -- Standard class, but can't do it | NonDerivableClass -- Non-standard class -checkSideConditions :: Bool -> Class -> [TcType] -> TyCon -> DerivStatus -checkSideConditions mayDeriveDataTypeable cls cls_tys rep_tc +checkSideConditions :: DynFlags -> Class -> [TcType] -> TyCon -> DerivStatus +checkSideConditions dflags cls cls_tys rep_tc | Just cond <- sideConditions cls - = case (cond (mayDeriveDataTypeable, rep_tc)) of + = case (cond (dflags, rep_tc)) of Just err -> DerivableClassError err -- Class-specific error Nothing | null cls_tys -> CanDerive | otherwise -> DerivableClassError ty_args_why -- e.g. deriving( Eq s ) @@ -748,13 +774,17 @@ sideConditions cls | cls_key == ixClassKey = Just (cond_std `andCond` cond_enumOrProduct) | cls_key == boundedClassKey = Just (cond_std `andCond` cond_enumOrProduct) | cls_key == dataClassKey = Just (cond_mayDeriveDataTypeable `andCond` cond_std `andCond` cond_noUnliftedArgs) + | cls_key == functorClassKey = Just (cond_mayDeriveFunctor `andCond` cond_std `andCond` cond_functorOK True) + | cls_key == foldableClassKey = Just (cond_mayDeriveFunctor `andCond` cond_std `andCond` cond_functorOK False) + | cls_key == traversableClassKey = Just (cond_mayDeriveFunctor `andCond` cond_std `andCond` cond_functorOK False) | getName cls `elem` typeableClassNames = Just (cond_mayDeriveDataTypeable `andCond` cond_typeableOK) | otherwise = Nothing where cls_key = getUnique cls -type Condition = (Bool, TyCon) -> Maybe SDoc - -- Bool is whether or not we are allowed to derive Data and Typeable +type Condition = (DynFlags, TyCon) -> Maybe SDoc + -- first Bool is whether or not we are allowed to derive Data and Typeable + -- second Bool is whether or not we are allowed to derive Functor -- TyCon is the *representation* tycon if the -- data type is an indexed one -- Nothing => OK @@ -835,13 +865,47 @@ cond_typeableOK (_, rep_tc) fam_inst = quotes (pprSourceTyCon rep_tc) <+> ptext (sLit "is a type family") +cond_functorOK :: Bool -> Condition +-- OK for Functor class +-- Currently: (a) at least one argument +-- (b) don't use argument contravariantly +-- (c) don't use argument in the wrong place, e.g. data T a = T (X a a) +-- (d) optionally: don't use function types +cond_functorOK allowFunctions (_, rep_tc) = msum (map check con_types) + where + data_cons = tyConDataCons rep_tc + con_types = concatMap dataConOrigArgTys data_cons + check = functorLikeTraverse + Nothing + Nothing + (Just covariant) + (\x y -> if allowFunctions then x `mplus` y else Just functions) + (\_ xs -> msum xs) + (\_ x -> x) + (Just wrong_arg) + (\_ x -> x) + (last (tyConTyVars rep_tc)) + covariant = quotes (pprSourceTyCon rep_tc) <+> + ptext (sLit "uses the type variable in a function argument") + functions = quotes (pprSourceTyCon rep_tc) <+> + ptext (sLit "contains function types") + wrong_arg = quotes (pprSourceTyCon rep_tc) <+> + ptext (sLit "uses the type variable in an argument other than the last") + cond_mayDeriveDataTypeable :: Condition -cond_mayDeriveDataTypeable (mayDeriveDataTypeable, _) - | mayDeriveDataTypeable = Nothing +cond_mayDeriveDataTypeable (dflags, _) + | dopt Opt_DeriveDataTypeable dflags = Nothing | otherwise = Just why where why = ptext (sLit "You need -XDeriveDataTypeable to derive an instance for this class") +cond_mayDeriveFunctor :: Condition +cond_mayDeriveFunctor (dflags, _) + | dopt Opt_DeriveFunctor dflags = Nothing + | otherwise = Just why + where + why = ptext (sLit "You need -XDeriveFunctor to derive an instance for this class") + std_class_via_iso :: Class -> Bool std_class_via_iso clas -- These standard classes can be derived for a newtype -- using the isomorphism trick *even if no -fglasgow-exts* @@ -890,11 +954,11 @@ a context for the Data instances: %************************************************************************ \begin{code} -mkNewTypeEqn :: InstOrigin -> Bool -> Bool -> [Var] -> Class +mkNewTypeEqn :: InstOrigin -> DynFlags -> [Var] -> Class -> [Type] -> TyCon -> [Type] -> TyCon -> [Type] -> Maybe ThetaType -> TcRn EarlyDerivSpec -mkNewTypeEqn orig mayDeriveDataTypeable newtype_deriving tvs +mkNewTypeEqn orig dflags tvs cls cls_tys tycon tc_args rep_tycon rep_tc_args mtheta -- Want: instance (...) => cls (cls_tys ++ [tycon tc_args]) where ... | can_derive_via_isomorphism && (newtype_deriving || std_class_via_iso cls) @@ -919,7 +983,8 @@ mkNewTypeEqn orig mayDeriveDataTypeable newtype_deriving tvs | newtype_deriving -> bale_out cant_derive_err -- Too hard, even with newtype deriving | otherwise -> bale_out non_std_err -- Try newtype deriving! where - check_conditions = checkSideConditions mayDeriveDataTypeable cls cls_tys rep_tycon + newtype_deriving = dopt Opt_GeneralizedNewtypeDeriving dflags + check_conditions = checkSideConditions dflags cls cls_tys rep_tycon bale_out msg = failWithTc (derivingThingErr cls cls_tys inst_ty msg) non_std_err = nonStdErr cls $$ @@ -1292,6 +1357,9 @@ genDerivBinds loc fix_env clas tycon ,(showClassKey, gen_Show_binds fix_env) ,(readClassKey, gen_Read_binds fix_env) ,(dataClassKey, gen_Data_binds) + ,(functorClassKey, gen_Functor_binds) + ,(foldableClassKey, gen_Foldable_binds) + ,(traversableClassKey, gen_Traversable_binds) ] \end{code} diff --git a/compiler/typecheck/TcGenDeriv.lhs b/compiler/typecheck/TcGenDeriv.lhs index 9826f2f..845fecc 100644 --- a/compiler/typecheck/TcGenDeriv.lhs +++ b/compiler/typecheck/TcGenDeriv.lhs @@ -23,6 +23,9 @@ module TcGenDeriv ( gen_Show_binds, gen_Data_binds, gen_Typeable_binds, + gen_Functor_binds, functorLikeTraverse, deepSubtypesContaining, + gen_Foldable_binds, + gen_Traversable_binds, genAuxBind ) where @@ -44,7 +47,12 @@ import TyCon import TcType import TysPrim import TysWiredIn +import Type +import TypeRep +import VarSet +import State import Util +import MonadUtils import Outputable import FastString import OccName @@ -1203,6 +1211,302 @@ prefix_RDR = dataQual_RDR gENERICS (fsLit "Prefix") infix_RDR = dataQual_RDR gENERICS (fsLit "Infix") \end{code} + + +%************************************************************************ +%* * + Functor instances +%* * +%************************************************************************ + +For the data type: + + data T a = T1 Int a | T2 (T a) + +We generate the instance: + + instance Functor T where + fmap f (T1 b1 a) = T1 b1 (f a) + fmap f (T2 ta) = T2 (fmap f ta) + +Notice that we don't simply apply 'fmap' to the constructor arguments. +Rather + - Do nothing to an argument whose type doesn't mention 'a' + - Apply 'f' to an argument of type 'a' + - Apply 'fmap f' to other arguments +That's why we have to recurse deeply into the constructor argument types, +rather than just one level, as we typically do. + +What about types with more than one type parameter? In general, we only +derive Functor for the last position: + + data S a b = S1 [b] | S2 a + instance Functor (S a) where + fmap f (S1 bs) = S1 (fmap f bs) + fmap f (S2 a) = S2 a + +However, we have special cases for + - tuples + - functions + +More formally, we write the derivation of fmap code over type variable +'a for type 'b as ($fmap 'a 'b). In this general notation the derived +instance for T is: + + instance Functor T where + fmap f (T1 x1 x2) = T1 ($(fmap 'a 'b1) x1) ($(fmap 'a 'a) x2) + fmap f (T2 x1) = T2 ($(fmap 'a '(T a)) x1) + + $(fmap 'a 'b) x = x -- when b does not contain a + $(fmap 'a 'a) x = f x + $(fmap 'a '(b1,b2)) x = case x of (x1,x2) -> ($(fmap 'a 'b1) x1, $(fmap 'a 'b2) x2) + $(fmap 'a '(T b1 b2)) x = fmap $(fmap 'a 'b2) x -- when a only occurs in the last parameter, b2 + $(fmap 'a '(b -> c)) x = \b -> $(fmap 'a' 'c) (x ($(cofmap 'a 'b) b)) + +For functions, the type parameter 'a can occur in a contravariant position, +which means we need to derive a function like: + + cofmap :: (a -> b) -> (f b -> f a) + +This is pretty much the same as $fmap, only without the $(cofmap 'a 'a) case: + + $(cofmap 'a 'b) x = x -- when b does not contain a + $(cofmap 'a 'a) x = error "type variable in contravariant position" + $(cofmap 'a '(b1,b2)) x = case x of (x1,x2) -> ($(cofmap 'a 'b1) x1, $(cofmap 'a 'b2) x2) + $(cofmap 'a '[b]) x = map $(cofmap 'a 'b) x + $(cofmap 'a '(T b1 b2)) x = fmap $(cofmap 'a 'b2) x -- when a only occurs in the last parameter, b2 + $(cofmap 'a '(b -> c)) x = \b -> $(cofmap 'a' 'c) (x ($(fmap 'a 'c) b)) + +\begin{code} +gen_Functor_binds :: SrcSpan -> TyCon -> (LHsBinds RdrName, DerivAuxBinds) +gen_Functor_binds loc tycon + = (listToBag [fmap_bind], []) + where + data_cons = tyConDataCons tycon + arg = last (tyConTyVars tycon) -- argument to derive for, 'a in the above description + + fmap_bind = L loc $ mkFunBind (L loc fmap_RDR) (map fmap_eqn data_cons) + fmap_eqn con = evalState (match_for_con [f_Pat] con parts) bs_RDRs + where parts = map derive_fmap_type (dataConOrigArgTys con) + + derive_fmap_type :: Type -> LHsExpr RdrName -> State [RdrName] (LHsExpr RdrName) + derive_fmap_type = functorLikeTraverse + (\ x -> return x) -- fmap f x = x + (\ x -> return (nlHsApp f_Expr x)) -- fmap f x = f x + (panic "contravariant") + (\g h x -> mkSimpleLam (\b -> h =<< (nlHsApp x `fmap` g b))) -- fmap f x = \b -> h (x (g b)) + (mkSimpleTupleCase match_for_con) -- fmap f x = case x of (a1,a2,..) -> (g1 a1,g2 a2,..) + (\_ g x -> do gg <- mkSimpleLam g + return $ nlHsApps fmap_RDR [gg,x]) -- fmap f x = fmap g x + (panic "in other argument") + (\_ g x -> g x) + arg + + match_for_con = mkSimpleConMatch $ + \con_name xsM -> do xs <- sequence xsM + return (nlHsApps con_name xs) -- Con (g1 v1) (g2 v2) .. +\end{code} + +Utility functions related to Functor deriving. + +Since several things use the same pattern of traversal, this is abstracted into functorLikeTraverse. +This function works like a fold: it makes a value of type 'a' in a bottom up way. + +\begin{code} +-- Generic traversal for Functor deriving +functorLikeTraverse :: a -- ^ Case: does not contain variable + -> a -- ^ Case: the variable itself + -> a -- ^ Case: the variable itself, contravariantly + -> (a -> a -> a) -- ^ Case: function type + -> (Boxity -> [a] -> a) -- ^ Case: tuple type + -> (Type -> a -> a) -- ^ Case: other tycon, variable only in last argument + -> a -- ^ Case: other tycon, variable only in last argument + -> (TcTyVar -> a -> a) -- ^ Case: forall type + -> TcTyVar -- ^ Variable to look for + -> Type -- ^ Type to process + -> a +functorLikeTraverse caseTrivial caseVar caseCoVar caseFun caseTuple caseTyApp caseWrongArg caseForAll var ty + = fst (go False ty) + where -- go returns (result of type a, does type contain var) + go co ty | Just ty' <- coreView ty = go co ty' + go co (TyVarTy v) | v == var = (if co then caseCoVar else caseVar,True) + go co (FunTy (PredTy _) b) = go co b + go co (FunTy x y) | xc || yc = (caseFun xr yr,True) + where (xr,xc) = go (not co) x + (yr,yc) = go co y + go co (AppTy x y) | xc = (caseWrongArg,True) + | yc = (caseTyApp x yr,True) + where (_, xc) = go co x + (yr,yc) = go co y + go co ty@(TyConApp con args) + | isTupleTyCon con = (caseTuple (tupleTyConBoxity con) xrs,True) + | null args = (caseTrivial,False) + | or (init xcs) = (caseWrongArg,True) + | (last xcs) = (caseTyApp (fst (splitAppTy ty)) (last xrs),True) + where (xrs,xcs) = unzip (map (go co) args) + go co (ForAllTy v x) | v /= var && xc = (caseForAll v xr,True) + where (xr,xc) = go co x + go _ _ = (caseTrivial,False) + +-- return all subtypes of ty that contain var somewhere +-- these are the things that should appear in instance constraints +deepSubtypesContaining :: TcTyVar -> TcType -> [TcType] +deepSubtypesContaining = functorLikeTraverse + [] + [] + (panic "contravariant") + (\x y -> x ++ y) -- function + (\_ xs -> concat xs) -- tuple + (\ty x -> ty : x) -- tyapp + (panic "in other argument") + (\v x -> filter (not . (v `elemVarSet`) . tyVarsOfType) x) -- forall v + + +-- Make a HsLam using a fresh variable from a State monad +mkSimpleLam :: (LHsExpr id -> State [id] (LHsExpr id)) -> State [id] (LHsExpr id) +mkSimpleLam lam = do + (n:names) <- get + put names + body <- lam (nlHsVar n) + return (mkHsLam [nlVarPat n] body) + +mkSimpleLam2 :: (LHsExpr id -> LHsExpr id -> State [id] (LHsExpr id)) -> State [id] (LHsExpr id) +mkSimpleLam2 lam = do + (n1:n2:names) <- get + put names + body <- lam (nlHsVar n1) (nlHsVar n2) + return (mkHsLam [nlVarPat n1,nlVarPat n2] body) + +-- "Con a1 a2 a3 -> fold [x1 a1, x2 a2, x3 a3]" +mkSimpleConMatch :: Monad m => (RdrName -> [a] -> m (LHsExpr RdrName)) -> [LPat RdrName] -> DataCon -> [LHsExpr RdrName -> a] -> m (LMatch RdrName) +mkSimpleConMatch fold extra_pats con insides = do + let con_name = getRdrName con + let vars_needed = takeList insides as_RDRs + let pat = nlConVarPat con_name vars_needed + rhs <- fold con_name (zipWith ($) insides (map nlHsVar vars_needed)) + return $ mkMatch (extra_pats ++ [pat]) rhs emptyLocalBinds + +-- "case x of (a1,a2,a3) -> fold [x1 a1, x2 a2, x3 a3]" +mkSimpleTupleCase :: Monad m => ([LPat RdrName] -> DataCon -> [LHsExpr RdrName -> a] -> m (LMatch RdrName)) + -> Boxity -> [LHsExpr RdrName -> a] -> LHsExpr RdrName -> m (LHsExpr RdrName) +mkSimpleTupleCase match_for_con boxity insides x = do + let con = tupleCon boxity (length insides) + match <- match_for_con [] con insides + return $ nlHsCase x [match] +\end{code} + + +%************************************************************************ +%* * + Foldable instances +%* * +%************************************************************************ + +Deriving Foldable instances works the same way as Functor instances, +only Foldable instances are not possible for function types at all. +Here the derived instance for the type T above is: + + instance Foldable T where + foldr f z (T1 x1 x2 x3) = $(foldr 'a 'b1) x1 ( $(foldr 'a 'a) x2 ( $(foldr 'a 'b2) x3 z ) ) + +The cases are: + + $(foldr 'a 'b) x z = z -- when b does not contain a + $(foldr 'a 'a) x z = f x z + $(foldr 'a '(b1,b2)) x z = case x of (x1,x2) -> $(foldr 'a 'b1) x1 ( $(foldr 'a 'b2) x2 z ) + $(foldr 'a '(T b1 b2)) x z = foldr $(foldr 'a 'b2) x z -- when a only occurs in the last parameter, b2 + +Note that the arguments to the real foldr function are the wrong way around, +since (f :: a -> b -> b), while (foldr f :: b -> t a -> b). + +\begin{code} +gen_Foldable_binds :: SrcSpan -> TyCon -> (LHsBinds RdrName, DerivAuxBinds) +gen_Foldable_binds loc tycon + = (listToBag [foldr_bind], []) + where + data_cons = tyConDataCons tycon + arg = last (tyConTyVars tycon) -- argument to derive for, 'a in the above description + + foldr_bind = L loc $ mkFunBind (L loc foldr_RDR) (map foldr_eqn data_cons) + foldr_eqn con = evalState (match_for_con z_Expr [f_Pat,z_Pat] con parts) bs_RDRs + where parts = map derive_foldr_type (dataConOrigArgTys con) + + derive_foldr_type :: Type -> LHsExpr RdrName -> LHsExpr RdrName -> State [RdrName] (LHsExpr RdrName) + derive_foldr_type = functorLikeTraverse + (\ _ z -> return z) -- foldr f z x = z + (\ x z -> return (nlHsApps f_RDR [x,z])) -- foldr f z x = f x z + (panic "function") + (panic "function") + (\b gs x z -> mkSimpleTupleCase (match_for_con z) b gs x) + (\_ g x z -> do gg <- mkSimpleLam2 g -- foldr f z x = foldr (\xx zz -> g xx zz) z x + return $ nlHsApps foldable_foldr_RDR [gg,z,x]) + (panic "in other argument") + (\_ g x z -> g x z) + arg + + match_for_con z = mkSimpleConMatch (\_con_name -> foldrM ($) z) -- g1 v1 (g2 v2 (.. z)) +\end{code} + + +%************************************************************************ +%* * + Traversable instances +%* * +%************************************************************************ + +Again, Traversable is much like Functor and Foldable. + +The cases are: + + $(traverse 'a 'b) x = pure x -- when b does not contain a + $(traverse 'a 'a) x = f x + $(traverse 'a '(b1,b2)) x = case x of (x1,x2) -> (,) <$> $(traverse 'a 'b1) x1 <*> $(traverse 'a 'b2) x2 + $(traverse 'a '(T b1 b2)) x = traverse $(traverse 'a 'b2) x -- when a only occurs in the last parameter, b2 + +Note that the generated code is not as efficient as it could be. For instance: + + data T a = T Int a deriving Traversable + +gives the function: traverse f (T x y) = T <$> pure x <*> f y +instead of: traverse f (T x y) = T x <$> f y + +\begin{code} +gen_Traversable_binds :: SrcSpan -> TyCon -> (LHsBinds RdrName, DerivAuxBinds) +gen_Traversable_binds loc tycon + = (listToBag [traverse_bind], []) + where + data_cons = tyConDataCons tycon + arg = last (tyConTyVars tycon) -- argument to derive for, 'a in the above description + + traverse_bind = L loc $ mkFunBind (L loc traverse_RDR) (map traverse_eqn data_cons) + traverse_eqn con = evalState (match_for_con [f_Pat] con parts) bs_RDRs + where parts = map derive_travese_type (dataConOrigArgTys con) + + derive_travese_type :: Type -> LHsExpr RdrName -> State [RdrName] (LHsExpr RdrName) + derive_travese_type = functorLikeTraverse + (\ x -> return (nlHsApps pure_RDR [x])) -- traverse f x = pure x + (\ x -> return (nlHsApps f_RDR [x])) -- travese f x = f x + (panic "function") + (panic "function") + (mkSimpleTupleCase match_for_con) -- travese f x z = case x of (a1,a2,..) -> (,,) <$> g1 a1 <*> g2 a2 <*> .. + (\_ g x -> do gg <- mkSimpleLam g -- travese f x = travese (\xx -> g xx) x + return $ nlHsApps traverse_RDR [gg,x]) + (panic "in other argument") + (\_ g x -> g x) + arg + + match_for_con = mkSimpleConMatch $ + \con_name xsM -> do xs <- sequence xsM + return (mkApCon (nlHsVar con_name) xs) + + -- ((Con <$> x1) <*> x2) <*> .. + mkApCon con [] = nlHsApps pure_RDR [con] + mkApCon con (x:xs) = foldl appAp (nlHsApps fmap_RDR [con,x]) xs + where appAp x y = nlHsApps ap_RDR [x,y] +\end{code} + + + %************************************************************************ %* * \subsection{Generating extra binds (@con2tag@ and @tag2con@)} @@ -1500,12 +1804,13 @@ genOpApp e1 op e2 = nlHsPar (nlHsOpApp e1 op e2) \end{code} \begin{code} -a_RDR, b_RDR, c_RDR, d_RDR, k_RDR, z_RDR, ah_RDR, bh_RDR, ch_RDR, dh_RDR, +a_RDR, b_RDR, c_RDR, d_RDR, f_RDR, k_RDR, z_RDR, ah_RDR, bh_RDR, ch_RDR, dh_RDR, cmp_eq_RDR :: RdrName a_RDR = mkVarUnqual (fsLit "a") b_RDR = mkVarUnqual (fsLit "b") c_RDR = mkVarUnqual (fsLit "c") d_RDR = mkVarUnqual (fsLit "d") +f_RDR = mkVarUnqual (fsLit "f") k_RDR = mkVarUnqual (fsLit "k") z_RDR = mkVarUnqual (fsLit "z") ah_RDR = mkVarUnqual (fsLit "a#") @@ -1519,22 +1824,25 @@ as_RDRs = [ mkVarUnqual (mkFastString ("a"++show i)) | i <- [(1::Int) .. ] ] bs_RDRs = [ mkVarUnqual (mkFastString ("b"++show i)) | i <- [(1::Int) .. ] ] cs_RDRs = [ mkVarUnqual (mkFastString ("c"++show i)) | i <- [(1::Int) .. ] ] -a_Expr, b_Expr, c_Expr, ltTag_Expr, eqTag_Expr, gtTag_Expr, +a_Expr, b_Expr, c_Expr, f_Expr, z_Expr, ltTag_Expr, eqTag_Expr, gtTag_Expr, false_Expr, true_Expr :: LHsExpr RdrName a_Expr = nlHsVar a_RDR b_Expr = nlHsVar b_RDR c_Expr = nlHsVar c_RDR +f_Expr = nlHsVar f_RDR +z_Expr = nlHsVar z_RDR ltTag_Expr = nlHsVar ltTag_RDR eqTag_Expr = nlHsVar eqTag_RDR gtTag_Expr = nlHsVar gtTag_RDR false_Expr = nlHsVar false_RDR true_Expr = nlHsVar true_RDR -a_Pat, b_Pat, c_Pat, d_Pat, k_Pat, z_Pat :: LPat RdrName +a_Pat, b_Pat, c_Pat, d_Pat, f_Pat, k_Pat, z_Pat :: LPat RdrName a_Pat = nlVarPat a_RDR b_Pat = nlVarPat b_RDR c_Pat = nlVarPat c_RDR d_Pat = nlVarPat d_RDR +f_Pat = nlVarPat f_RDR k_Pat = nlVarPat k_RDR z_Pat = nlVarPat z_RDR diff --git a/compiler/utils/Util.lhs b/compiler/utils/Util.lhs index db6f96a..af81110 100644 --- a/compiler/utils/Util.lhs +++ b/compiler/utils/Util.lhs @@ -32,6 +32,7 @@ module Util ( -- * List operations controlled by another list takeList, dropList, splitAtList, split, + dropTail, -- * For loop nTimes, @@ -608,6 +609,10 @@ splitAtList (_:xs) (y:ys) = (y:ys', ys'') where (ys', ys'') = splitAtList xs ys +-- drop from the end of a list +dropTail :: Int -> [a] -> [a] +dropTail n = reverse . drop n . reverse + snocView :: [a] -> Maybe ([a],a) -- Split off the last element snocView [] = Nothing -- 1.7.10.4