Add the ability to derive instances of Functor, Foldable, Traversable
authorsimonpj@microsoft.com <unknown>
Mon, 2 Feb 2009 13:48:29 +0000 (13:48 +0000)
committersimonpj@microsoft.com <unknown>
Mon, 2 Feb 2009 13:48:29 +0000 (13:48 +0000)
This patch is a straightforward extension of the 'deriving' mechanism.
The ability to derive classes Functor, Foldable, Traverable is controlled
by a single flag  -XDeriveFunctor.  (Maybe that's a poor name.)

Still to come: documentation

Thanks to twanvl for developing the patch

compiler/main/DynFlags.hs
compiler/prelude/PrelNames.lhs
compiler/typecheck/TcDeriv.lhs
compiler/typecheck/TcGenDeriv.lhs
compiler/utils/Util.lhs

index 35949f7..44bd124 100644 (file)
@@ -220,6 +220,7 @@ data DynFlag
    | Opt_RelaxedPolyRec
    | Opt_StandaloneDeriving
    | Opt_DeriveDataTypeable
+   | Opt_DeriveFunctor
    | Opt_TypeSynonymInstances
    | Opt_FlexibleContexts
    | Opt_FlexibleInstances
@@ -1771,6 +1772,7 @@ xFlags = [
   ( "UnboxedTuples",                    Opt_UnboxedTuples, const Supported ),
   ( "StandaloneDeriving",               Opt_StandaloneDeriving, const Supported ),
   ( "DeriveDataTypeable",               Opt_DeriveDataTypeable, const Supported ),
+  ( "DeriveFunctor",                    Opt_DeriveFunctor, const Supported ),
   ( "TypeSynonymInstances",             Opt_TypeSynonymInstances, const Supported ),
   ( "FlexibleContexts",                 Opt_FlexibleContexts, const Supported ),
   ( "FlexibleInstances",                Opt_FlexibleInstances, const Supported ),
@@ -1809,6 +1811,7 @@ glasgowExtsFlags = [
            , Opt_TypeSynonymInstances
            , Opt_StandaloneDeriving
            , Opt_DeriveDataTypeable
+           , Opt_DeriveFunctor
            , Opt_FlexibleContexts
            , Opt_FlexibleInstances
            , Opt_ConstrainedClassMethods
index bf46465..235bc3c 100644 (file)
@@ -131,6 +131,9 @@ basicKnownKeyNames
        realFloatClassName,             -- numeric
        dataClassName, 
        isStringClassName,
+       applicativeClassName,
+       foldableClassName,
+       traversableClassName,
 
        -- Numeric stuff
        negateName, minusName, 
@@ -164,7 +167,7 @@ basicKnownKeyNames
 
        -- Read stuff
        readClassName, 
-       
+
        -- Stable pointers
        newStablePtrName,
 
@@ -204,11 +207,8 @@ basicKnownKeyNames
        randomClassName, randomGenClassName, monadPlusClassName,
 
         -- Annotation type checking
-        toAnnotationWrapperName,
+        toAnnotationWrapperName
 
-       -- Booleans
-       andName, orName
-       
        -- The Either type
        , eitherTyConName, leftDataConName, rightDataConName
 
@@ -236,10 +236,11 @@ pRELUDE           = mkBaseModule_ pRELUDE_NAME
 
 gHC_PRIM, gHC_TYPES, gHC_BOOL, gHC_UNIT, gHC_ORDERING, gHC_GENERICS, gHC_CLASSES, gHC_BASE, gHC_ENUM,
     gHC_SHOW, gHC_READ, gHC_NUM, gHC_INTEGER, gHC_INTEGER_INTERNALS, gHC_LIST, gHC_PARR,
-    gHC_TUPLE, dATA_TUPLE, dATA_EITHER, dATA_STRING, gHC_PACK, gHC_CONC, gHC_IO_BASE,
+    gHC_TUPLE, dATA_TUPLE, dATA_EITHER, dATA_STRING, dATA_FOLDABLE, dATA_TRAVERSABLE,
+    gHC_PACK, gHC_CONC, gHC_IO_BASE,
     gHC_ST, gHC_ARR, gHC_STABLE, gHC_ADDR, gHC_PTR, gHC_ERR, gHC_REAL,
     gHC_FLOAT, gHC_TOP_HANDLER, sYSTEM_IO, dYNAMIC, tYPEABLE, gENERICS,
-    dOTNET, rEAD_PREC, lEX, gHC_INT, gHC_WORD, mONAD, mONAD_FIX, aRROW,
+    dOTNET, rEAD_PREC, lEX, gHC_INT, gHC_WORD, mONAD, mONAD_FIX, aRROW, cONTROL_APPLICATIVE,
     gHC_DESUGAR, rANDOM, gHC_EXTS, cONTROL_EXCEPTION_BASE :: Module
 gHC_PRIM       = mkPrimModule (fsLit "GHC.Prim")   -- Primitive types and values
 gHC_TYPES       = mkPrimModule (fsLit "GHC.Types")
@@ -261,6 +262,8 @@ gHC_TUPLE   = mkPrimModule (fsLit "GHC.Tuple")
 dATA_TUPLE     = mkBaseModule (fsLit "Data.Tuple")
 dATA_EITHER    = mkBaseModule (fsLit "Data.Either")
 dATA_STRING    = mkBaseModule (fsLit "Data.String")
+dATA_FOLDABLE  = mkBaseModule (fsLit "Data.Foldable")
+dATA_TRAVERSABLE= mkBaseModule (fsLit "Data.Traversable")
 gHC_PACK       = mkBaseModule (fsLit "GHC.Pack")
 gHC_CONC       = mkBaseModule (fsLit "GHC.Conc")
 gHC_IO_BASE    = mkBaseModule (fsLit "GHC.IOBase")
@@ -285,6 +288,7 @@ gHC_WORD    = mkBaseModule (fsLit "GHC.Word")
 mONAD          = mkBaseModule (fsLit "Control.Monad")
 mONAD_FIX      = mkBaseModule (fsLit "Control.Monad.Fix")
 aRROW          = mkBaseModule (fsLit "Control.Arrow")
+cONTROL_APPLICATIVE = mkBaseModule (fsLit "Control.Applicative")
 gHC_DESUGAR = mkBaseModule (fsLit "GHC.Desugar")
 rANDOM         = mkBaseModule (fsLit "System.Random")
 gHC_EXTS       = mkBaseModule (fsLit "GHC.Exts")
@@ -389,9 +393,6 @@ returnM_RDR                 = nameRdrName returnMName
 bindM_RDR              = nameRdrName bindMName
 failM_RDR              = nameRdrName failMName
 
-and_RDR :: RdrName
-and_RDR                        = nameRdrName andName
-
 left_RDR, right_RDR :: RdrName
 left_RDR               = nameRdrName leftDataConName
 right_RDR              = nameRdrName rightDataConName
@@ -443,8 +444,9 @@ compose_RDR :: RdrName
 compose_RDR            = varQual_RDR gHC_BASE (fsLit ".")
 
 not_RDR, getTag_RDR, succ_RDR, pred_RDR, minBound_RDR, maxBound_RDR,
-    range_RDR, inRange_RDR, index_RDR,
+    and_RDR, range_RDR, inRange_RDR, index_RDR,
     unsafeIndex_RDR, unsafeRangeSize_RDR :: RdrName
+and_RDR                        = varQual_RDR gHC_CLASSES (fsLit "&&")
 not_RDR                = varQual_RDR gHC_CLASSES (fsLit "not")
 getTag_RDR             = varQual_RDR gHC_BASE (fsLit "getTag")
 succ_RDR               = varQual_RDR gHC_ENUM (fsLit "succ")
@@ -502,6 +504,13 @@ inlDataCon_RDR     = dataQual_RDR gHC_GENERICS (fsLit "Inl")
 inrDataCon_RDR     = dataQual_RDR gHC_GENERICS (fsLit "Inr")
 genUnitDataCon_RDR = dataQual_RDR gHC_GENERICS (fsLit "Unit")
 
+fmap_RDR, pure_RDR, ap_RDR, foldable_foldr_RDR, traverse_RDR :: RdrName
+fmap_RDR               = varQual_RDR gHC_BASE (fsLit "fmap")
+pure_RDR               = varQual_RDR cONTROL_APPLICATIVE (fsLit "pure")
+ap_RDR                         = varQual_RDR cONTROL_APPLICATIVE (fsLit "<*>")
+foldable_foldr_RDR     = varQual_RDR dATA_FOLDABLE       (fsLit "foldr")
+traverse_RDR           = varQual_RDR dATA_TRAVERSABLE    (fsLit "traverse")
+
 ----------------------
 varQual_RDR, tcQual_RDR, clsQual_RDR, dataQual_RDR
     :: Module -> FastString -> RdrName
@@ -573,13 +582,19 @@ bindMName    = methName gHC_BASE (fsLit ">>=")    bindMClassOpKey
 returnMName       = methName gHC_BASE (fsLit "return") returnMClassOpKey
 failMName         = methName gHC_BASE (fsLit "fail")   failMClassOpKey
 
+-- Classes (Applicative, Foldable, Traversable)
+applicativeClassName, foldableClassName, traversableClassName :: Name
+applicativeClassName  = clsQual  cONTROL_APPLICATIVE (fsLit "Applicative") applicativeClassKey
+foldableClassName     = clsQual  dATA_FOLDABLE       (fsLit "Foldable")    foldableClassKey
+traversableClassName  = clsQual  dATA_TRAVERSABLE    (fsLit "Traversable") traversableClassKey
+
 -- Functions for GHC extensions
 groupWithName :: Name
 groupWithName = varQual gHC_EXTS (fsLit "groupWith") groupWithIdKey
 
 -- Random PrelBase functions
 fromStringName, otherwiseIdName, foldrName, buildName, augmentName,
-    mapName, appendName, andName, orName, assertName,
+    mapName, appendName, assertName,
     breakpointName, breakpointCondName, breakpointAutoName,
     opaqueTyConName :: Name
 fromStringName = methName dATA_STRING (fsLit "fromString") fromStringClassOpKey
@@ -589,8 +604,6 @@ buildName     = varQual gHC_BASE (fsLit "build")      buildIdKey
 augmentName      = varQual gHC_BASE (fsLit "augment")    augmentIdKey
 mapName       = varQual gHC_BASE (fsLit "map")        mapIdKey
 appendName       = varQual gHC_BASE (fsLit "++")         appendIdKey
-andName                  = varQual gHC_CLASSES (fsLit "&&")     andIdKey
-orName           = varQual gHC_CLASSES (fsLit "||")     orIdKey
 assertName        = varQual gHC_BASE (fsLit "assert")     assertIdKey
 breakpointName    = varQual gHC_BASE (fsLit "breakpoint") breakpointIdKey
 breakpointCondName= varQual gHC_BASE (fsLit "breakpointCond") breakpointCondIdKey
@@ -889,6 +902,11 @@ randomGenClassKey  = mkPreludeClassUnique 32
 
 isStringClassKey :: Unique
 isStringClassKey       = mkPreludeClassUnique 33
+
+applicativeClassKey, foldableClassKey, traversableClassKey :: Unique
+applicativeClassKey    = mkPreludeClassUnique 34
+foldableClassKey       = mkPreludeClassUnique 35
+traversableClassKey    = mkPreludeClassUnique 36
 \end{code}
 
 %************************************************************************
@@ -1156,9 +1174,7 @@ rootMainKey, runMainKey :: Unique
 rootMainKey                  = mkPreludeMiscIdUnique 55
 runMainKey                   = mkPreludeMiscIdUnique 56
 
-andIdKey, orIdKey, thenIOIdKey, lazyIdKey, assertErrorIdKey :: Unique
-andIdKey                     = mkPreludeMiscIdUnique 57
-orIdKey                              = mkPreludeMiscIdUnique 58
+thenIOIdKey, lazyIdKey, assertErrorIdKey :: Unique
 thenIOIdKey                  = mkPreludeMiscIdUnique 59
 lazyIdKey                    = mkPreludeMiscIdUnique 60
 assertErrorIdKey             = mkPreludeMiscIdUnique 61
@@ -1260,6 +1276,7 @@ fromStringClassOpKey            = mkPreludeMiscIdUnique 125
 toAnnotationWrapperIdKey :: Unique
 toAnnotationWrapperIdKey      = mkPreludeMiscIdUnique 126
 
+
 ---------------- Template Haskell -------------------
 --     USES IdUniques 200-399
 -----------------------------------------------------
@@ -1325,7 +1342,8 @@ standardClassKeys = derivableClassKeys ++ numericClassKeys
                  ++ [randomClassKey, randomGenClassKey,
                      functorClassKey, 
                      monadClassKey, monadPlusClassKey,
-                     isStringClassKey
+                     isStringClassKey,
+                     applicativeClassKey, foldableClassKey, traversableClassKey
                     ]
 \end{code}
 
index eac2209..a507197 100644 (file)
@@ -49,6 +49,8 @@ import ListSetOps
 import Outputable
 import FastString
 import Bag
+
+import Control.Monad
 \end{code}
 
 %************************************************************************
@@ -566,15 +568,12 @@ mkEqnHelp orig tvs cls cls_tys tc_app mtheta
                   className cls `elem` typeableClassNames) 
                  (derivingHiddenErr tycon)
 
-       ; mayDeriveDataTypeable <- doptM Opt_DeriveDataTypeable
-       ; newtype_deriving <- doptM Opt_GeneralizedNewtypeDeriving
-
+       ; dflags <- getDOpts
        ; if isDataTyCon rep_tc then
-               mkDataTypeEqn orig mayDeriveDataTypeable tvs cls cls_tys 
+               mkDataTypeEqn orig dflags tvs cls cls_tys
                              tycon tc_args rep_tc rep_tc_args mtheta
          else
-               mkNewTypeEqn orig mayDeriveDataTypeable newtype_deriving
-                            tvs cls cls_tys 
+               mkNewTypeEqn orig dflags tvs cls cls_tys 
                             tycon tc_args rep_tc rep_tc_args mtheta }
   | otherwise
   = failWithTc (derivingThingErr cls cls_tys tc_app
@@ -631,13 +630,21 @@ famInstNotFound tycon tys
 %************************************************************************
 
 \begin{code}
-mkDataTypeEqn :: InstOrigin -> Bool -> [Var] -> Class -> [Type]
-              -> TyCon -> [Type] -> TyCon -> [Type] -> Maybe ThetaType
-              -> TcRn EarlyDerivSpec   -- Return 'Nothing' if error
-               
-mkDataTypeEqn orig mayDeriveDataTypeable tvs cls cls_tys
+mkDataTypeEqn :: InstOrigin
+              -> DynFlags
+              -> [Var]                  -- Universally quantified type variables in the instance
+              -> Class                  -- Class for which we need to derive an instance
+              -> [Type]                 -- Other parameters to the class except the last
+              -> TyCon                  -- Type constructor for which the instance is requested (last parameter to the type class)
+              -> [Type]                 -- Parameters to the type constructor
+              -> TyCon                  -- rep of the above (for type families)
+              -> [Type]                 -- rep of the above
+              -> Maybe ThetaType        -- Context of the instance, for standalone deriving
+              -> TcRn EarlyDerivSpec    -- Return 'Nothing' if error
+
+mkDataTypeEqn orig dflags tvs cls cls_tys
               tycon tc_args rep_tc rep_tc_args mtheta
-  = case checkSideConditions mayDeriveDataTypeable cls cls_tys rep_tc of
+  = case checkSideConditions dflags cls cls_tys rep_tc of
        -- NB: pass the *representation* tycon to checkSideConditions
        CanDerive -> mk_data_eqn orig tvs cls tycon tc_args rep_tc rep_tc_args mtheta
        NonDerivableClass       -> bale_out (nonStdErr cls)
@@ -656,7 +663,7 @@ mk_data_eqn orig tvs cls tycon tc_args rep_tc rep_tc_args mtheta
   | otherwise
   = do { dfun_name <- new_dfun_name cls tycon
        ; loc <- getSrcSpanM
-       ; let ordinary_constraints
+       ; let ordinary_constraints_simple
                = [ mkClassPred cls [arg_ty] 
                  | data_con <- tyConDataCons rep_tc,
                    arg_ty   <- ASSERT( isVanillaDataCon data_con )
@@ -665,13 +672,31 @@ mk_data_eqn orig tvs cls tycon tc_args rep_tc rep_tc_args mtheta
                        -- No constraints for unlifted types
                        -- Where they are legal we generate specilised function calls
 
+              -- constraints on all subtypes for classes like Functor
+              ordinary_constraints_deep
+                = [ mkClassPred cls [deept_ty]
+                  | data_con <- tyConDataCons rep_tc,
+                    arg_ty   <- ASSERT( isVanillaDataCon data_con )
+                                dataConInstOrigArgTys data_con (rep_tc_args++[mkTyVarTy dummy_ty]),
+                    deept_ty <- deepSubtypesContaining dummy_ty arg_ty,
+                    not (isUnLiftedType deept_ty) ]
+               where dummy_ty = last (tyConTyVars tycon) -- don't substitute the last var, this might not be a good idea
+
+              ordinary_constraints
+               | getUnique cls == functorClassKey     = ordinary_constraints_deep
+               | getUnique cls == foldableClassKey    = ordinary_constraints_deep
+               | getUnique cls == traversableClassKey = ordinary_constraints_deep
+               | otherwise                            = ordinary_constraints_simple
+
                        -- See Note [Superclasses of derived instance]
              sc_constraints = substTheta (zipOpenTvSubst (classTyVars cls) inst_tys)
                                          (classSCTheta cls)
              inst_tys = [mkTyConApp tycon tc_args]
 
-             stupid_subst = zipTopTvSubst (tyConTyVars rep_tc) rep_tc_args
+             nonfree_tycon_vars = dropTail (classArity cls) (tyConTyVars rep_tc)
+             stupid_subst = zipTopTvSubst nonfree_tycon_vars rep_tc_args
              stupid_constraints = substTheta stupid_subst (tyConStupidTheta rep_tc)
+
              all_constraints = stupid_constraints ++ sc_constraints ++ ordinary_constraints
 
              spec = DS { ds_loc = loc, ds_orig = orig
@@ -712,6 +737,7 @@ mk_typeable_eqn orig tvs cls tycon tc_args rep_tc rep_tc_args mtheta
                     , ds_tc = rep_tc, ds_tc_args = rep_tc_args
                     , ds_theta = mtheta `orElse` [], ds_newtype = False })  }
 
+
 ------------------------------------------------------------------
 -- Check side conditions that dis-allow derivability for particular classes
 -- This is *apart* from the newtype-deriving mechanism
@@ -724,10 +750,10 @@ data DerivStatus = CanDerive
                 | DerivableClassError SDoc     -- Standard class, but can't do it
                 | NonDerivableClass            -- Non-standard class
 
-checkSideConditions :: Bool -> Class -> [TcType] -> TyCon -> DerivStatus
-checkSideConditions mayDeriveDataTypeable cls cls_tys rep_tc
+checkSideConditions :: DynFlags -> Class -> [TcType] -> TyCon -> DerivStatus
+checkSideConditions dflags cls cls_tys rep_tc
   | Just cond <- sideConditions cls
-  = case (cond (mayDeriveDataTypeable, rep_tc)) of
+  = case (cond (dflags, rep_tc)) of
        Just err -> DerivableClassError err     -- Class-specific error
        Nothing  | null cls_tys -> CanDerive
                 | otherwise    -> DerivableClassError ty_args_why      -- e.g. deriving( Eq s )
@@ -748,13 +774,17 @@ sideConditions cls
   | cls_key == ixClassKey      = Just (cond_std `andCond` cond_enumOrProduct)
   | cls_key == boundedClassKey = Just (cond_std `andCond` cond_enumOrProduct)
   | cls_key == dataClassKey    = Just (cond_mayDeriveDataTypeable `andCond` cond_std `andCond` cond_noUnliftedArgs)
+  | cls_key == functorClassKey = Just (cond_mayDeriveFunctor `andCond` cond_std `andCond` cond_functorOK True)
+  | cls_key == foldableClassKey = Just (cond_mayDeriveFunctor `andCond` cond_std `andCond` cond_functorOK False)
+  | cls_key == traversableClassKey = Just (cond_mayDeriveFunctor `andCond` cond_std `andCond` cond_functorOK False)
   | getName cls `elem` typeableClassNames = Just (cond_mayDeriveDataTypeable `andCond` cond_typeableOK)
   | otherwise = Nothing
   where
     cls_key = getUnique cls
 
-type Condition = (Bool, TyCon) -> Maybe SDoc
-       -- Bool is whether or not we are allowed to derive Data and Typeable
+type Condition = (DynFlags, TyCon) -> Maybe SDoc
+       -- first Bool is whether or not we are allowed to derive Data and Typeable
+       -- second Bool is whether or not we are allowed to derive Functor
        -- TyCon is the *representation* tycon if the 
        --      data type is an indexed one
        -- Nothing => OK
@@ -835,13 +865,47 @@ cond_typeableOK (_, rep_tc)
     fam_inst = quotes (pprSourceTyCon rep_tc) <+> 
               ptext (sLit "is a type family")
 
+cond_functorOK :: Bool -> Condition
+-- OK for Functor class
+-- Currently: (a) at least one argument
+--            (b) don't use argument contravariantly
+--            (c) don't use argument in the wrong place, e.g. data T a = T (X a a)
+--            (d) optionally: don't use function types
+cond_functorOK allowFunctions (_, rep_tc) = msum (map check con_types)
+  where
+    data_cons = tyConDataCons rep_tc
+    con_types = concatMap dataConOrigArgTys data_cons
+    check = functorLikeTraverse
+                    Nothing
+                    Nothing
+                    (Just covariant)
+                    (\x y   -> if allowFunctions then x `mplus` y else Just functions)
+                    (\_ xs  -> msum xs)
+                    (\_ x   -> x)
+                    (Just wrong_arg)
+                    (\_ x   -> x)
+                    (last (tyConTyVars rep_tc))
+    covariant = quotes (pprSourceTyCon rep_tc) <+> 
+                ptext (sLit "uses the type variable in a function argument")
+    functions = quotes (pprSourceTyCon rep_tc) <+> 
+                ptext (sLit "contains function types")
+    wrong_arg = quotes (pprSourceTyCon rep_tc) <+> 
+                ptext (sLit "uses the type variable in an argument other than the last")
+
 cond_mayDeriveDataTypeable :: Condition
-cond_mayDeriveDataTypeable (mayDeriveDataTypeable, _)
- | mayDeriveDataTypeable = Nothing
+cond_mayDeriveDataTypeable (dflags, _)
+ | dopt Opt_DeriveDataTypeable dflags = Nothing
  | otherwise = Just why
   where
     why  = ptext (sLit "You need -XDeriveDataTypeable to derive an instance for this class")
 
+cond_mayDeriveFunctor :: Condition
+cond_mayDeriveFunctor (dflags, _)
+ | dopt Opt_DeriveFunctor dflags = Nothing
+ | otherwise = Just why
+  where
+    why  = ptext (sLit "You need -XDeriveFunctor to derive an instance for this class")
+
 std_class_via_iso :: Class -> Bool
 std_class_via_iso clas -- These standard classes can be derived for a newtype
                        -- using the isomorphism trick *even if no -fglasgow-exts*
@@ -890,11 +954,11 @@ a context for the Data instances:
 %************************************************************************
 
 \begin{code}
-mkNewTypeEqn :: InstOrigin -> Bool -> Bool -> [Var] -> Class
+mkNewTypeEqn :: InstOrigin -> DynFlags -> [Var] -> Class
              -> [Type] -> TyCon -> [Type] -> TyCon -> [Type]
              -> Maybe ThetaType
              -> TcRn EarlyDerivSpec
-mkNewTypeEqn orig mayDeriveDataTypeable newtype_deriving tvs
+mkNewTypeEqn orig dflags tvs
              cls cls_tys tycon tc_args rep_tycon rep_tc_args mtheta
 -- Want: instance (...) => cls (cls_tys ++ [tycon tc_args]) where ...
   | can_derive_via_isomorphism && (newtype_deriving || std_class_via_iso cls)
@@ -919,7 +983,8 @@ mkNewTypeEqn orig mayDeriveDataTypeable newtype_deriving tvs
        | newtype_deriving    -> bale_out cant_derive_err  -- Too hard, even with newtype deriving
        | otherwise           -> bale_out non_std_err      -- Try newtype deriving!
   where
-       check_conditions = checkSideConditions mayDeriveDataTypeable cls cls_tys rep_tycon
+        newtype_deriving = dopt Opt_GeneralizedNewtypeDeriving dflags
+       check_conditions = checkSideConditions dflags cls cls_tys rep_tycon
        bale_out msg = failWithTc (derivingThingErr cls cls_tys inst_ty msg)
 
        non_std_err = nonStdErr cls $$
@@ -1292,6 +1357,9 @@ genDerivBinds loc fix_env clas tycon
               ,(showClassKey,     gen_Show_binds fix_env)
               ,(readClassKey,     gen_Read_binds fix_env)
               ,(dataClassKey,     gen_Data_binds)
+              ,(functorClassKey,  gen_Functor_binds)
+              ,(foldableClassKey, gen_Foldable_binds)
+              ,(traversableClassKey, gen_Traversable_binds)
               ]
 \end{code}
 
index 9826f2f..845fecc 100644 (file)
@@ -23,6 +23,9 @@ module TcGenDeriv (
        gen_Show_binds,
        gen_Data_binds,
        gen_Typeable_binds,
+       gen_Functor_binds, functorLikeTraverse, deepSubtypesContaining,
+       gen_Foldable_binds,
+       gen_Traversable_binds,
        genAuxBind
     ) where
 
@@ -44,7 +47,12 @@ import TyCon
 import TcType
 import TysPrim
 import TysWiredIn
+import Type
+import TypeRep
+import VarSet
+import State
 import Util
+import MonadUtils
 import Outputable
 import FastString
 import OccName
@@ -1203,6 +1211,302 @@ prefix_RDR     = dataQual_RDR gENERICS (fsLit "Prefix")
 infix_RDR      = dataQual_RDR gENERICS (fsLit "Infix")
 \end{code}
 
+
+
+%************************************************************************
+%*                                                                     *
+       Functor instances
+%*                                                                     *
+%************************************************************************
+
+For the data type:
+
+  data T a = T1 Int a | T2 (T a)
+
+We generate the instance:
+
+  instance Functor T where
+      fmap f (T1 b1 a) = T1 b1 (f a)
+      fmap f (T2 ta)   = T2 (fmap f ta)
+
+Notice that we don't simply apply 'fmap' to the constructor arguments.
+Rather 
+  - Do nothing to an argument whose type doesn't mention 'a'
+  - Apply 'f' to an argument of type 'a'
+  - Apply 'fmap f' to other arguments 
+That's why we have to recurse deeply into the constructor argument types,
+rather than just one level, as we typically do.
+
+What about types with more than one type parameter?  In general, we only 
+derive Functor for the last position:
+
+  data S a b = S1 [b] | S2 a
+  instance Functor (S a) where
+    fmap f (S1 bs) = S1 (fmap f bs)
+    fmap f (S2 a)  = S2 a
+
+However, we have special cases for
+        - tuples
+        - functions
+
+More formally, we write the derivation of fmap code over type variable
+'a for type 'b as ($fmap 'a 'b).  In this general notation the derived
+instance for T is:
+
+  instance Functor T where
+      fmap f (T1 x1 x2) = T1 ($(fmap 'a 'b1) x1) ($(fmap 'a 'a) x2)
+      fmap f (T2 x1)    = T2 ($(fmap 'a '(T a)) x1)
+
+  $(fmap 'a 'b)         x  =  x     -- when b does not contain a
+  $(fmap 'a 'a)         x  =  f x
+  $(fmap 'a '(b1,b2))   x  =  case x of (x1,x2) -> ($(fmap 'a 'b1) x1, $(fmap 'a 'b2) x2)
+  $(fmap 'a '(T b1 b2)) x  =  fmap $(fmap 'a 'b2) x   -- when a only occurs in the last parameter, b2
+  $(fmap 'a '(b -> c))  x  =  \b -> $(fmap 'a' 'c) (x ($(cofmap 'a 'b) b))
+
+For functions, the type parameter 'a can occur in a contravariant position,
+which means we need to derive a function like:
+
+  cofmap :: (a -> b) -> (f b -> f a)
+
+This is pretty much the same as $fmap, only without the $(cofmap 'a 'a) case:
+
+  $(cofmap 'a 'b)         x  =  x     -- when b does not contain a
+  $(cofmap 'a 'a)         x  =  error "type variable in contravariant position"
+  $(cofmap 'a '(b1,b2))   x  =  case x of (x1,x2) -> ($(cofmap 'a 'b1) x1, $(cofmap 'a 'b2) x2)
+  $(cofmap 'a '[b])       x  =  map $(cofmap 'a 'b) x
+  $(cofmap 'a '(T b1 b2)) x  =  fmap $(cofmap 'a 'b2) x   -- when a only occurs in the last parameter, b2
+  $(cofmap 'a '(b -> c))  x  =  \b -> $(cofmap 'a' 'c) (x ($(fmap 'a 'c) b))
+
+\begin{code}
+gen_Functor_binds :: SrcSpan -> TyCon -> (LHsBinds RdrName, DerivAuxBinds)
+gen_Functor_binds loc tycon
+  = (listToBag [fmap_bind], [])
+  where
+    data_cons = tyConDataCons tycon
+    arg = last (tyConTyVars tycon) -- argument to derive for, 'a in the above description
+
+    fmap_bind = L loc $ mkFunBind (L loc fmap_RDR) (map fmap_eqn data_cons)
+    fmap_eqn con = evalState (match_for_con [f_Pat] con parts) bs_RDRs
+      where parts = map derive_fmap_type (dataConOrigArgTys con)
+
+    derive_fmap_type :: Type -> LHsExpr RdrName -> State [RdrName] (LHsExpr RdrName)
+    derive_fmap_type = functorLikeTraverse
+        (\     x -> return x)                                         -- fmap f x = x
+        (\     x -> return (nlHsApp f_Expr x))                        -- fmap f x = f x
+        (panic "contravariant")
+        (\g h  x -> mkSimpleLam (\b -> h =<< (nlHsApp x `fmap` g b))) -- fmap f x = \b -> h (x (g b))
+        (mkSimpleTupleCase match_for_con)                             -- fmap f x = case x of (a1,a2,..) -> (g1 a1,g2 a2,..)
+        (\_ g  x -> do gg <- mkSimpleLam g
+                       return $ nlHsApps fmap_RDR [gg,x])             -- fmap f x = fmap g x
+        (panic "in other argument")
+        (\_ g  x -> g x)
+        arg
+
+    match_for_con = mkSimpleConMatch $
+        \con_name xsM -> do xs <- sequence xsM
+                            return (nlHsApps con_name xs)  -- Con (g1 v1) (g2 v2) ..
+\end{code}
+
+Utility functions related to Functor deriving.
+
+Since several things use the same pattern of traversal, this is abstracted into functorLikeTraverse.
+This function works like a fold: it makes a value of type 'a' in a bottom up way.
+
+\begin{code}
+-- Generic traversal for Functor deriving
+functorLikeTraverse :: a                    -- ^ Case: does not contain variable
+                    -> a                    -- ^ Case: the variable itself
+                    -> a                    -- ^ Case: the variable itself, contravariantly
+                    -> (a -> a -> a)        -- ^ Case: function type
+                    -> (Boxity -> [a] -> a) -- ^ Case: tuple type
+                    -> (Type -> a -> a)     -- ^ Case: other tycon, variable only in last argument
+                    -> a                    -- ^ Case: other tycon, variable only in last argument
+                    -> (TcTyVar -> a -> a)  -- ^ Case: forall type
+                    -> TcTyVar              -- ^ Variable to look for
+                    -> Type                 -- ^ Type to process
+                    -> a
+functorLikeTraverse caseTrivial caseVar caseCoVar caseFun caseTuple caseTyApp caseWrongArg caseForAll var ty
+    = fst (go False ty)
+  where -- go returns (result of type a, does type contain var)
+        go co ty | Just ty' <- coreView ty = go co ty'
+        go co (TyVarTy    v) | v == var = (if co then caseCoVar else caseVar,True)
+        go co (FunTy (PredTy _) b)      = go co b
+        go co (FunTy x y)    | xc || yc = (caseFun xr yr,True)
+            where (xr,xc) = go (not co) x
+                  (yr,yc) = go co       y
+        go co (AppTy    x y) | xc       = (caseWrongArg,True)
+                             | yc       = (caseTyApp x yr,True)
+            where (_, xc) = go co x
+                  (yr,yc) = go co y
+        go co ty@(TyConApp con args)
+               | isTupleTyCon con       = (caseTuple (tupleTyConBoxity con) xrs,True)
+               | null args              = (caseTrivial,False)
+               | or (init xcs)          = (caseWrongArg,True)
+               | (last xcs)             = (caseTyApp (fst (splitAppTy ty)) (last xrs),True)
+            where (xrs,xcs) = unzip (map (go co) args)
+        go co (ForAllTy v x) | v /= var && xc = (caseForAll v xr,True)
+            where (xr,xc) = go co x
+        go _  _                         = (caseTrivial,False)
+
+-- return all subtypes of ty that contain var somewhere
+-- these are the things that should appear in instance constraints
+deepSubtypesContaining :: TcTyVar -> TcType -> [TcType]
+deepSubtypesContaining = functorLikeTraverse
+      []
+      []
+      (panic "contravariant")
+      (\x y   -> x ++ y)      -- function
+      (\_  xs -> concat xs)   -- tuple
+      (\ty x  -> ty : x)      -- tyapp
+      (panic "in other argument")
+      (\v x   -> filter (not . (v `elemVarSet`) . tyVarsOfType) x) -- forall v
+
+
+-- Make a HsLam using a fresh variable from a State monad
+mkSimpleLam :: (LHsExpr id -> State [id] (LHsExpr id)) -> State [id] (LHsExpr id)
+mkSimpleLam lam = do
+    (n:names) <- get
+    put names
+    body <- lam (nlHsVar n)
+    return (mkHsLam [nlVarPat n] body)
+
+mkSimpleLam2 :: (LHsExpr id -> LHsExpr id -> State [id] (LHsExpr id)) -> State [id] (LHsExpr id)
+mkSimpleLam2 lam = do
+    (n1:n2:names) <- get
+    put names
+    body <- lam (nlHsVar n1) (nlHsVar n2)
+    return (mkHsLam [nlVarPat n1,nlVarPat n2] body)
+
+-- "Con a1 a2 a3 -> fold [x1 a1, x2 a2, x3 a3]"
+mkSimpleConMatch :: Monad m => (RdrName -> [a] -> m (LHsExpr RdrName)) -> [LPat RdrName] -> DataCon -> [LHsExpr RdrName -> a] -> m (LMatch RdrName)
+mkSimpleConMatch fold extra_pats con insides = do
+    let con_name = getRdrName con
+    let vars_needed = takeList insides as_RDRs
+    let pat = nlConVarPat con_name vars_needed
+    rhs <- fold con_name (zipWith ($) insides (map nlHsVar vars_needed))
+    return $ mkMatch (extra_pats ++ [pat]) rhs emptyLocalBinds
+
+-- "case x of (a1,a2,a3) -> fold [x1 a1, x2 a2, x3 a3]"
+mkSimpleTupleCase :: Monad m => ([LPat RdrName] -> DataCon -> [LHsExpr RdrName -> a] -> m (LMatch RdrName))
+                  -> Boxity -> [LHsExpr RdrName -> a] -> LHsExpr RdrName -> m (LHsExpr RdrName)
+mkSimpleTupleCase match_for_con boxity insides x = do
+    let con = tupleCon boxity (length insides)
+    match <- match_for_con [] con insides
+    return $ nlHsCase x [match]
+\end{code}
+
+
+%************************************************************************
+%*                                                                     *
+       Foldable instances
+%*                                                                     *
+%************************************************************************
+
+Deriving Foldable instances works the same way as Functor instances,
+only Foldable instances are not possible for function types at all.
+Here the derived instance for the type T above is:
+
+  instance Foldable T where
+      foldr f z (T1 x1 x2 x3) = $(foldr 'a 'b1) x1 ( $(foldr 'a 'a) x2 ( $(foldr 'a 'b2) x3 z ) )
+
+The cases are:
+
+  $(foldr 'a 'b)         x z  =  z     -- when b does not contain a
+  $(foldr 'a 'a)         x z  =  f x z
+  $(foldr 'a '(b1,b2))   x z  =  case x of (x1,x2) -> $(foldr 'a 'b1) x1 ( $(foldr 'a 'b2) x2 z )
+  $(foldr 'a '(T b1 b2)) x z  =  foldr $(foldr 'a 'b2) x z  -- when a only occurs in the last parameter, b2
+
+Note that the arguments to the real foldr function are the wrong way around,
+since (f :: a -> b -> b), while (foldr f :: b -> t a -> b).
+
+\begin{code}
+gen_Foldable_binds :: SrcSpan -> TyCon -> (LHsBinds RdrName, DerivAuxBinds)
+gen_Foldable_binds loc tycon
+  = (listToBag [foldr_bind], [])
+  where
+    data_cons = tyConDataCons tycon
+    arg = last (tyConTyVars tycon) -- argument to derive for, 'a in the above description
+
+    foldr_bind = L loc $ mkFunBind (L loc foldr_RDR) (map foldr_eqn data_cons)
+    foldr_eqn con = evalState (match_for_con z_Expr [f_Pat,z_Pat] con parts) bs_RDRs
+      where parts = map derive_foldr_type (dataConOrigArgTys con)
+
+    derive_foldr_type :: Type -> LHsExpr RdrName -> LHsExpr RdrName -> State [RdrName] (LHsExpr RdrName)
+    derive_foldr_type = functorLikeTraverse
+        (\     _ z -> return z)                            -- foldr f z x = z
+        (\     x z -> return (nlHsApps f_RDR [x,z]))       -- foldr f z x = f x z
+        (panic "function")
+        (panic "function")
+        (\b gs x z -> mkSimpleTupleCase (match_for_con z) b gs x)
+        (\_ g  x z -> do gg <- mkSimpleLam2 g              -- foldr f z x = foldr (\xx zz -> g xx zz) z x
+                         return $ nlHsApps foldable_foldr_RDR [gg,z,x])
+        (panic "in other argument")
+        (\_ g  x z -> g x z)
+        arg
+
+    match_for_con z = mkSimpleConMatch (\_con_name -> foldrM ($) z) -- g1 v1 (g2 v2 (.. z))
+\end{code}
+
+
+%************************************************************************
+%*                                                                     *
+       Traversable instances
+%*                                                                     *
+%************************************************************************
+
+Again, Traversable is much like Functor and Foldable.
+
+The cases are:
+
+  $(traverse 'a 'b)         x  =  pure x     -- when b does not contain a
+  $(traverse 'a 'a)         x  =  f x
+  $(traverse 'a '(b1,b2))   x  =  case x of (x1,x2) -> (,) <$> $(traverse 'a 'b1) x1 <*> $(traverse 'a 'b2) x2
+  $(traverse 'a '(T b1 b2)) x  =  traverse $(traverse 'a 'b2) x  -- when a only occurs in the last parameter, b2
+
+Note that the generated code is not as efficient as it could be. For instance:
+
+  data T a = T Int a  deriving Traversable
+
+gives the function: traverse f (T x y) = T <$> pure x <*> f y
+instead of:         traverse f (T x y) = T x <$> f y
+
+\begin{code}
+gen_Traversable_binds :: SrcSpan -> TyCon -> (LHsBinds RdrName, DerivAuxBinds)
+gen_Traversable_binds loc tycon
+  = (listToBag [traverse_bind], [])
+  where
+    data_cons = tyConDataCons tycon
+    arg = last (tyConTyVars tycon) -- argument to derive for, 'a in the above description
+
+    traverse_bind = L loc $ mkFunBind (L loc traverse_RDR) (map traverse_eqn data_cons)
+    traverse_eqn con = evalState (match_for_con [f_Pat] con parts) bs_RDRs
+      where parts = map derive_travese_type (dataConOrigArgTys con)
+
+    derive_travese_type :: Type -> LHsExpr RdrName -> State [RdrName] (LHsExpr RdrName)
+    derive_travese_type = functorLikeTraverse
+        (\     x -> return (nlHsApps pure_RDR [x]))    -- traverse f x = pure x
+        (\     x -> return (nlHsApps f_RDR [x]))       -- travese f x = f x
+        (panic "function")
+        (panic "function")
+        (mkSimpleTupleCase match_for_con)              -- travese f x z = case x of (a1,a2,..) -> (,,) <$> g1 a1 <*> g2 a2 <*> ..
+        (\_ g  x -> do gg <- mkSimpleLam g             -- travese f x = travese (\xx -> g xx) x
+                       return $ nlHsApps traverse_RDR [gg,x])
+        (panic "in other argument")
+        (\_ g  x -> g x)
+        arg
+
+    match_for_con = mkSimpleConMatch $
+        \con_name xsM -> do xs <- sequence xsM
+                            return (mkApCon (nlHsVar con_name) xs)
+
+    -- ((Con <$> x1) <*> x2) <*> ..
+    mkApCon con []     = nlHsApps pure_RDR [con]
+    mkApCon con (x:xs) = foldl appAp (nlHsApps fmap_RDR [con,x]) xs
+       where appAp x y = nlHsApps ap_RDR [x,y]
+\end{code}
+
+
+
 %************************************************************************
 %*                                                                     *
 \subsection{Generating extra binds (@con2tag@ and @tag2con@)}
@@ -1500,12 +1804,13 @@ genOpApp e1 op e2 = nlHsPar (nlHsOpApp e1 op e2)
 \end{code}
 
 \begin{code}
-a_RDR, b_RDR, c_RDR, d_RDR, k_RDR, z_RDR, ah_RDR, bh_RDR, ch_RDR, dh_RDR,
+a_RDR, b_RDR, c_RDR, d_RDR, f_RDR, k_RDR, z_RDR, ah_RDR, bh_RDR, ch_RDR, dh_RDR,
     cmp_eq_RDR :: RdrName
 a_RDR          = mkVarUnqual (fsLit "a")
 b_RDR          = mkVarUnqual (fsLit "b")
 c_RDR          = mkVarUnqual (fsLit "c")
 d_RDR          = mkVarUnqual (fsLit "d")
+f_RDR          = mkVarUnqual (fsLit "f")
 k_RDR          = mkVarUnqual (fsLit "k")
 z_RDR          = mkVarUnqual (fsLit "z")
 ah_RDR         = mkVarUnqual (fsLit "a#")
@@ -1519,22 +1824,25 @@ as_RDRs         = [ mkVarUnqual (mkFastString ("a"++show i)) | i <- [(1::Int) .. ] ]
 bs_RDRs                = [ mkVarUnqual (mkFastString ("b"++show i)) | i <- [(1::Int) .. ] ]
 cs_RDRs                = [ mkVarUnqual (mkFastString ("c"++show i)) | i <- [(1::Int) .. ] ]
 
-a_Expr, b_Expr, c_Expr, ltTag_Expr, eqTag_Expr, gtTag_Expr,
+a_Expr, b_Expr, c_Expr, f_Expr, z_Expr, ltTag_Expr, eqTag_Expr, gtTag_Expr,
     false_Expr, true_Expr :: LHsExpr RdrName
 a_Expr         = nlHsVar a_RDR
 b_Expr         = nlHsVar b_RDR
 c_Expr         = nlHsVar c_RDR
+f_Expr         = nlHsVar f_RDR
+z_Expr         = nlHsVar z_RDR
 ltTag_Expr     = nlHsVar ltTag_RDR
 eqTag_Expr     = nlHsVar eqTag_RDR
 gtTag_Expr     = nlHsVar gtTag_RDR
 false_Expr     = nlHsVar false_RDR
 true_Expr      = nlHsVar true_RDR
 
-a_Pat, b_Pat, c_Pat, d_Pat, k_Pat, z_Pat :: LPat RdrName
+a_Pat, b_Pat, c_Pat, d_Pat, f_Pat, k_Pat, z_Pat :: LPat RdrName
 a_Pat          = nlVarPat a_RDR
 b_Pat          = nlVarPat b_RDR
 c_Pat          = nlVarPat c_RDR
 d_Pat          = nlVarPat d_RDR
+f_Pat          = nlVarPat f_RDR
 k_Pat          = nlVarPat k_RDR
 z_Pat          = nlVarPat z_RDR
 
index db6f96a..af81110 100644 (file)
@@ -32,6 +32,7 @@ module Util (
 
         -- * List operations controlled by another list
         takeList, dropList, splitAtList, split,
+        dropTail,
 
         -- * For loop
         nTimes,
@@ -608,6 +609,10 @@ splitAtList (_:xs) (y:ys) = (y:ys', ys'')
     where
       (ys', ys'') = splitAtList xs ys
 
+-- drop from the end of a list
+dropTail :: Int -> [a] -> [a]
+dropTail n = reverse . drop n . reverse
+
 snocView :: [a] -> Maybe ([a],a)
         -- Split off the last element
 snocView [] = Nothing