Add the ability to derive instances of Functor, Foldable, Traversable
[ghc-hetmet.git] / compiler / typecheck / TcDeriv.lhs
index eac2209..a507197 100644 (file)
@@ -49,6 +49,8 @@ import ListSetOps
 import Outputable
 import FastString
 import Bag
+
+import Control.Monad
 \end{code}
 
 %************************************************************************
@@ -566,15 +568,12 @@ mkEqnHelp orig tvs cls cls_tys tc_app mtheta
                   className cls `elem` typeableClassNames) 
                  (derivingHiddenErr tycon)
 
-       ; mayDeriveDataTypeable <- doptM Opt_DeriveDataTypeable
-       ; newtype_deriving <- doptM Opt_GeneralizedNewtypeDeriving
-
+       ; dflags <- getDOpts
        ; if isDataTyCon rep_tc then
-               mkDataTypeEqn orig mayDeriveDataTypeable tvs cls cls_tys 
+               mkDataTypeEqn orig dflags tvs cls cls_tys
                              tycon tc_args rep_tc rep_tc_args mtheta
          else
-               mkNewTypeEqn orig mayDeriveDataTypeable newtype_deriving
-                            tvs cls cls_tys 
+               mkNewTypeEqn orig dflags tvs cls cls_tys 
                             tycon tc_args rep_tc rep_tc_args mtheta }
   | otherwise
   = failWithTc (derivingThingErr cls cls_tys tc_app
@@ -631,13 +630,21 @@ famInstNotFound tycon tys
 %************************************************************************
 
 \begin{code}
-mkDataTypeEqn :: InstOrigin -> Bool -> [Var] -> Class -> [Type]
-              -> TyCon -> [Type] -> TyCon -> [Type] -> Maybe ThetaType
-              -> TcRn EarlyDerivSpec   -- Return 'Nothing' if error
-               
-mkDataTypeEqn orig mayDeriveDataTypeable tvs cls cls_tys
+mkDataTypeEqn :: InstOrigin
+              -> DynFlags
+              -> [Var]                  -- Universally quantified type variables in the instance
+              -> Class                  -- Class for which we need to derive an instance
+              -> [Type]                 -- Other parameters to the class except the last
+              -> TyCon                  -- Type constructor for which the instance is requested (last parameter to the type class)
+              -> [Type]                 -- Parameters to the type constructor
+              -> TyCon                  -- rep of the above (for type families)
+              -> [Type]                 -- rep of the above
+              -> Maybe ThetaType        -- Context of the instance, for standalone deriving
+              -> TcRn EarlyDerivSpec    -- Return 'Nothing' if error
+
+mkDataTypeEqn orig dflags tvs cls cls_tys
               tycon tc_args rep_tc rep_tc_args mtheta
-  = case checkSideConditions mayDeriveDataTypeable cls cls_tys rep_tc of
+  = case checkSideConditions dflags cls cls_tys rep_tc of
        -- NB: pass the *representation* tycon to checkSideConditions
        CanDerive -> mk_data_eqn orig tvs cls tycon tc_args rep_tc rep_tc_args mtheta
        NonDerivableClass       -> bale_out (nonStdErr cls)
@@ -656,7 +663,7 @@ mk_data_eqn orig tvs cls tycon tc_args rep_tc rep_tc_args mtheta
   | otherwise
   = do { dfun_name <- new_dfun_name cls tycon
        ; loc <- getSrcSpanM
-       ; let ordinary_constraints
+       ; let ordinary_constraints_simple
                = [ mkClassPred cls [arg_ty] 
                  | data_con <- tyConDataCons rep_tc,
                    arg_ty   <- ASSERT( isVanillaDataCon data_con )
@@ -665,13 +672,31 @@ mk_data_eqn orig tvs cls tycon tc_args rep_tc rep_tc_args mtheta
                        -- No constraints for unlifted types
                        -- Where they are legal we generate specilised function calls
 
+              -- constraints on all subtypes for classes like Functor
+              ordinary_constraints_deep
+                = [ mkClassPred cls [deept_ty]
+                  | data_con <- tyConDataCons rep_tc,
+                    arg_ty   <- ASSERT( isVanillaDataCon data_con )
+                                dataConInstOrigArgTys data_con (rep_tc_args++[mkTyVarTy dummy_ty]),
+                    deept_ty <- deepSubtypesContaining dummy_ty arg_ty,
+                    not (isUnLiftedType deept_ty) ]
+               where dummy_ty = last (tyConTyVars tycon) -- don't substitute the last var, this might not be a good idea
+
+              ordinary_constraints
+               | getUnique cls == functorClassKey     = ordinary_constraints_deep
+               | getUnique cls == foldableClassKey    = ordinary_constraints_deep
+               | getUnique cls == traversableClassKey = ordinary_constraints_deep
+               | otherwise                            = ordinary_constraints_simple
+
                        -- See Note [Superclasses of derived instance]
              sc_constraints = substTheta (zipOpenTvSubst (classTyVars cls) inst_tys)
                                          (classSCTheta cls)
              inst_tys = [mkTyConApp tycon tc_args]
 
-             stupid_subst = zipTopTvSubst (tyConTyVars rep_tc) rep_tc_args
+             nonfree_tycon_vars = dropTail (classArity cls) (tyConTyVars rep_tc)
+             stupid_subst = zipTopTvSubst nonfree_tycon_vars rep_tc_args
              stupid_constraints = substTheta stupid_subst (tyConStupidTheta rep_tc)
+
              all_constraints = stupid_constraints ++ sc_constraints ++ ordinary_constraints
 
              spec = DS { ds_loc = loc, ds_orig = orig
@@ -712,6 +737,7 @@ mk_typeable_eqn orig tvs cls tycon tc_args rep_tc rep_tc_args mtheta
                     , ds_tc = rep_tc, ds_tc_args = rep_tc_args
                     , ds_theta = mtheta `orElse` [], ds_newtype = False })  }
 
+
 ------------------------------------------------------------------
 -- Check side conditions that dis-allow derivability for particular classes
 -- This is *apart* from the newtype-deriving mechanism
@@ -724,10 +750,10 @@ data DerivStatus = CanDerive
                 | DerivableClassError SDoc     -- Standard class, but can't do it
                 | NonDerivableClass            -- Non-standard class
 
-checkSideConditions :: Bool -> Class -> [TcType] -> TyCon -> DerivStatus
-checkSideConditions mayDeriveDataTypeable cls cls_tys rep_tc
+checkSideConditions :: DynFlags -> Class -> [TcType] -> TyCon -> DerivStatus
+checkSideConditions dflags cls cls_tys rep_tc
   | Just cond <- sideConditions cls
-  = case (cond (mayDeriveDataTypeable, rep_tc)) of
+  = case (cond (dflags, rep_tc)) of
        Just err -> DerivableClassError err     -- Class-specific error
        Nothing  | null cls_tys -> CanDerive
                 | otherwise    -> DerivableClassError ty_args_why      -- e.g. deriving( Eq s )
@@ -748,13 +774,17 @@ sideConditions cls
   | cls_key == ixClassKey      = Just (cond_std `andCond` cond_enumOrProduct)
   | cls_key == boundedClassKey = Just (cond_std `andCond` cond_enumOrProduct)
   | cls_key == dataClassKey    = Just (cond_mayDeriveDataTypeable `andCond` cond_std `andCond` cond_noUnliftedArgs)
+  | cls_key == functorClassKey = Just (cond_mayDeriveFunctor `andCond` cond_std `andCond` cond_functorOK True)
+  | cls_key == foldableClassKey = Just (cond_mayDeriveFunctor `andCond` cond_std `andCond` cond_functorOK False)
+  | cls_key == traversableClassKey = Just (cond_mayDeriveFunctor `andCond` cond_std `andCond` cond_functorOK False)
   | getName cls `elem` typeableClassNames = Just (cond_mayDeriveDataTypeable `andCond` cond_typeableOK)
   | otherwise = Nothing
   where
     cls_key = getUnique cls
 
-type Condition = (Bool, TyCon) -> Maybe SDoc
-       -- Bool is whether or not we are allowed to derive Data and Typeable
+type Condition = (DynFlags, TyCon) -> Maybe SDoc
+       -- first Bool is whether or not we are allowed to derive Data and Typeable
+       -- second Bool is whether or not we are allowed to derive Functor
        -- TyCon is the *representation* tycon if the 
        --      data type is an indexed one
        -- Nothing => OK
@@ -835,13 +865,47 @@ cond_typeableOK (_, rep_tc)
     fam_inst = quotes (pprSourceTyCon rep_tc) <+> 
               ptext (sLit "is a type family")
 
+cond_functorOK :: Bool -> Condition
+-- OK for Functor class
+-- Currently: (a) at least one argument
+--            (b) don't use argument contravariantly
+--            (c) don't use argument in the wrong place, e.g. data T a = T (X a a)
+--            (d) optionally: don't use function types
+cond_functorOK allowFunctions (_, rep_tc) = msum (map check con_types)
+  where
+    data_cons = tyConDataCons rep_tc
+    con_types = concatMap dataConOrigArgTys data_cons
+    check = functorLikeTraverse
+                    Nothing
+                    Nothing
+                    (Just covariant)
+                    (\x y   -> if allowFunctions then x `mplus` y else Just functions)
+                    (\_ xs  -> msum xs)
+                    (\_ x   -> x)
+                    (Just wrong_arg)
+                    (\_ x   -> x)
+                    (last (tyConTyVars rep_tc))
+    covariant = quotes (pprSourceTyCon rep_tc) <+> 
+                ptext (sLit "uses the type variable in a function argument")
+    functions = quotes (pprSourceTyCon rep_tc) <+> 
+                ptext (sLit "contains function types")
+    wrong_arg = quotes (pprSourceTyCon rep_tc) <+> 
+                ptext (sLit "uses the type variable in an argument other than the last")
+
 cond_mayDeriveDataTypeable :: Condition
-cond_mayDeriveDataTypeable (mayDeriveDataTypeable, _)
- | mayDeriveDataTypeable = Nothing
+cond_mayDeriveDataTypeable (dflags, _)
+ | dopt Opt_DeriveDataTypeable dflags = Nothing
  | otherwise = Just why
   where
     why  = ptext (sLit "You need -XDeriveDataTypeable to derive an instance for this class")
 
+cond_mayDeriveFunctor :: Condition
+cond_mayDeriveFunctor (dflags, _)
+ | dopt Opt_DeriveFunctor dflags = Nothing
+ | otherwise = Just why
+  where
+    why  = ptext (sLit "You need -XDeriveFunctor to derive an instance for this class")
+
 std_class_via_iso :: Class -> Bool
 std_class_via_iso clas -- These standard classes can be derived for a newtype
                        -- using the isomorphism trick *even if no -fglasgow-exts*
@@ -890,11 +954,11 @@ a context for the Data instances:
 %************************************************************************
 
 \begin{code}
-mkNewTypeEqn :: InstOrigin -> Bool -> Bool -> [Var] -> Class
+mkNewTypeEqn :: InstOrigin -> DynFlags -> [Var] -> Class
              -> [Type] -> TyCon -> [Type] -> TyCon -> [Type]
              -> Maybe ThetaType
              -> TcRn EarlyDerivSpec
-mkNewTypeEqn orig mayDeriveDataTypeable newtype_deriving tvs
+mkNewTypeEqn orig dflags tvs
              cls cls_tys tycon tc_args rep_tycon rep_tc_args mtheta
 -- Want: instance (...) => cls (cls_tys ++ [tycon tc_args]) where ...
   | can_derive_via_isomorphism && (newtype_deriving || std_class_via_iso cls)
@@ -919,7 +983,8 @@ mkNewTypeEqn orig mayDeriveDataTypeable newtype_deriving tvs
        | newtype_deriving    -> bale_out cant_derive_err  -- Too hard, even with newtype deriving
        | otherwise           -> bale_out non_std_err      -- Try newtype deriving!
   where
-       check_conditions = checkSideConditions mayDeriveDataTypeable cls cls_tys rep_tycon
+        newtype_deriving = dopt Opt_GeneralizedNewtypeDeriving dflags
+       check_conditions = checkSideConditions dflags cls cls_tys rep_tycon
        bale_out msg = failWithTc (derivingThingErr cls cls_tys inst_ty msg)
 
        non_std_err = nonStdErr cls $$
@@ -1292,6 +1357,9 @@ genDerivBinds loc fix_env clas tycon
               ,(showClassKey,     gen_Show_binds fix_env)
               ,(readClassKey,     gen_Read_binds fix_env)
               ,(dataClassKey,     gen_Data_binds)
+              ,(functorClassKey,  gen_Functor_binds)
+              ,(foldableClassKey, gen_Foldable_binds)
+              ,(traversableClassKey, gen_Traversable_binds)
               ]
 \end{code}