Deriving for indexed data types
[ghc-hetmet.git] / compiler / typecheck / TcDeriv.lhs
index 90ff3a7..60a7499 100644 (file)
@@ -135,10 +135,14 @@ So, here are the synonyms for the ``equation'' structures:
 type DerivEqn = (SrcSpan, InstOrigin, Name, Class, TyCon, [TyVar], DerivRhs)
                -- The Name is the name for the DFun we'll build
                -- The tyvars bind all the variables in the RHS
+               -- For family indexes, the tycon is the representation tycon
 
 pprDerivEqn :: DerivEqn -> SDoc
-pprDerivEqn (l,_,n,c,tc,tvs,rhs)
-  = parens (hsep [ppr l, ppr n, ppr c, ppr tc, ppr tvs] <+> equals <+> ppr rhs)
+pprDerivEqn (l, _, n, c, tc, tvs, rhs)
+  = parens (hsep [ppr l, ppr n, ppr c, ppr origTc, ppr tys] <+> equals <+>
+           ppr rhs)
+  where
+    (origTc, tys) = tyConOrigHead tc
 
 type DerivRhs  = ThetaType
 type DerivSoln = DerivRhs
@@ -270,7 +274,8 @@ deriveOrdinaryStuff overlap_flag eqns
        ; extra_binds <- genTaggeryBinds inst_infos
 
        -- Done
-       ; returnM (inst_infos, unionManyBags (extra_binds : aux_binds_s))
+       ; returnM (map fst inst_infos, 
+                  unionManyBags (extra_binds : aux_binds_s))
    }
 
 -----------------------------------------
@@ -328,6 +333,13 @@ when the dict is constructed in TcInstDcls.tcInstDecl2
 
 
 \begin{code}
+type DerivSpec = (SrcSpan,             -- location of the deriving clause
+                 InstOrigin,           -- deriving at data decl or standalone?
+                 NewOrData,            -- newtype or data type
+                 Name,                 -- Type constructor for which we derive
+                 Maybe [LHsType Name], -- Type indexes if indexed type
+                 LHsType Name)         -- Class instance to be generated
+
 makeDerivEqns :: OverlapFlag
              -> [LTyClDecl Name] 
              -> [LDerivDecl Name] 
@@ -335,44 +347,60 @@ makeDerivEqns :: OverlapFlag
                      [InstInfo])       -- Special newtype derivings
 
 makeDerivEqns overlap_flag tycl_decls deriv_decls
-  = do derive_these_top_level <- mapM top_level_deriv deriv_decls >>= return . catMaybes
+  = do derive_top_level <- mapM top_level_deriv deriv_decls
        (maybe_ordinaries, maybe_newtypes) 
-           <- mapAndUnzipM mk_eqn (derive_these ++ derive_these_top_level)
+           <- mapAndUnzipM mk_eqn (derive_data ++ catMaybes derive_top_level)
        return (catMaybes maybe_ordinaries, catMaybes maybe_newtypes)
   where
     ------------------------------------------------------------------
-    derive_these :: [(SrcSpan, InstOrigin, NewOrData, Name, LHsType Name)]
-       -- Find the (nd, TyCon, Pred) pairs that must be `derived'
-    derive_these = [ (srcLocSpan (getSrcLoc tycon), DerivOrigin, nd, tycon, pred) 
-                  | L _ (TyData { tcdND = nd, tcdLName = L _ tycon, 
-                                 tcdDerivs = Just preds }) <- tycl_decls,
+    -- Deriving clauses at data declarations
+    derive_data :: [DerivSpec]
+    derive_data = [ (loc, DerivOrigin, nd, tycon, tyPats, pred) 
+                  | L loc (TyData { tcdND = nd, tcdLName = L _ tycon, 
+                                    tcdTyPats = tyPats,
+                                    tcdDerivs = Just preds }) <- tycl_decls,
                     pred <- preds ]
 
-    top_level_deriv :: LDerivDecl Name -> TcM (Maybe (SrcSpan, InstOrigin, NewOrData, Name, LHsType Name))
-    top_level_deriv d@(L l (DerivDecl inst ty_name)) = recoverM (returnM Nothing) $ setSrcSpan l $ 
+    -- Standalone deriving declarations
+    top_level_deriv :: LDerivDecl Name -> TcM (Maybe DerivSpec)
+    top_level_deriv d@(L loc (DerivDecl inst ty_name)) = 
+      recoverM (returnM Nothing) $ setSrcSpan loc $ 
         do tycon <- tcLookupLocatedTyCon ty_name
            let new_or_data = if isNewTyCon tycon then NewType else DataType
-           traceTc (text "Stand-alone deriving:" <+> ppr (new_or_data, unLoc ty_name, inst))
-           return $ Just (l, StandAloneDerivOrigin, new_or_data, unLoc ty_name, inst)
+           traceTc (text "Stand-alone deriving:" <+> 
+                   ppr (new_or_data, unLoc ty_name, inst))
+           return $ Just (loc, StandAloneDerivOrigin, new_or_data, 
+                         unLoc ty_name, Nothing, inst)
 
     ------------------------------------------------------------------
-    -- takes (whether newtype or data, name of data type, partially applied type class)
-    mk_eqn :: (SrcSpan, InstOrigin, NewOrData, Name, LHsType Name) -> TcM (Maybe DerivEqn, Maybe InstInfo)
+    -- Derive equation/inst info for one deriving clause (data or standalone)
+    mk_eqn :: DerivSpec -> TcM (Maybe DerivEqn, Maybe InstInfo)
        -- We swizzle the tyvars and datacons out of the tycon
        -- to make the rest of the equation
        --
-       -- The "deriv_ty" is a LHsType to take account of the fact that for newtype derivign
-       -- we allow deriving (forall a. C [a]).
-
-    mk_eqn (loc, orig, new_or_data, tycon_name, hs_deriv_ty)
-      = tcLookupTyCon tycon_name               `thenM` \ tycon ->
-       setSrcSpan loc          $
-        addErrCtxt (derivCtxt tycon)           $
-       tcExtendTyVarEnv (tyConTyVars tycon)    $       -- Deriving preds may (now) mention
-                                                       -- the type variables for the type constructor
-       tcHsDeriv hs_deriv_ty                   `thenM` \ (deriv_tvs, clas, tys) ->
-       doptM Opt_GlasgowExts                   `thenM` \ gla_exts ->
-        mk_eqn_help loc orig gla_exts new_or_data tycon deriv_tvs clas tys
+       -- The "deriv_ty" is a LHsType to take account of the fact that for
+       -- newtype deriving we allow deriving (forall a. C [a]).
+
+    mk_eqn (loc, orig, new_or_data, tycon_name, mb_tys, hs_deriv_ty)
+      = setSrcSpan loc                            $
+        addErrCtxt (derivCtxt tycon_name mb_tys)  $
+        do { named_tycon <- tcLookupTyCon tycon_name
+
+             -- Lookup representation tycon in case of a family instance
+          ; tycon <- case mb_tys of
+                       Nothing    -> return named_tycon
+                       Just hsTys -> do
+                                       tys <- mapM dsHsType hsTys
+                                       tcLookupFamInst named_tycon tys
+
+            -- Enable deriving preds to mention the type variables in the
+            -- instance type
+          ; tcExtendTyVarEnv (tyConTyVars tycon) $ do
+               -- 
+          { (deriv_tvs, clas, tys) <- tcHsDeriv hs_deriv_ty
+          ; gla_exts <- doptM Opt_GlasgowExts
+           ; mk_eqn_help loc orig gla_exts new_or_data tycon deriv_tvs clas tys
+          }}
 
     ------------------------------------------------------------------
     -- data/newtype T a = ... deriving( C t1 t2 )
@@ -381,10 +409,12 @@ makeDerivEqns overlap_flag tycl_decls deriv_decls
 
     mk_eqn_help loc orig gla_exts DataType tycon deriv_tvs clas tys
       | Just err <- checkSideConditions gla_exts tycon deriv_tvs clas tys
-      = bale_out (derivingThingErr clas tys tycon (tyConTyVars tycon) err)
+      = bale_out (derivingThingErr clas tys origTyCon ttys err)
       | otherwise 
       = do { eqn <- mkDataTypeEqn loc orig tycon clas
           ; returnM (Just eqn, Nothing) }
+      where
+        (origTyCon, ttys) = tyConOrigHead tycon
 
     mk_eqn_help loc orig gla_exts NewType tycon deriv_tvs clas tys
       | can_derive_via_isomorphism && (gla_exts || std_class_via_iso clas)
@@ -528,7 +558,7 @@ makeDerivEqns overlap_flag tycl_decls deriv_decls
              && (tyVarsOfType rep_fn' `disjointVarSet` dropped_tvs)
              && (tyVarsOfTypes tys    `disjointVarSet` dropped_tvs)
 
-       cant_derive_err = derivingThingErr clas tys tycon tyvars_to_keep
+       cant_derive_err = derivingThingErr clas tys tycon (mkTyVarTys tyvars_to_keep)
                                (vcat [ptext SLIT("even with cunning newtype deriving:"),
                                        if isRecursiveTyCon tycon then
                                          ptext SLIT("the newtype is recursive")
@@ -545,7 +575,7 @@ makeDerivEqns overlap_flag tycl_decls deriv_decls
                                        else empty
                                      ])
 
-       non_std_err = derivingThingErr clas tys tycon tyvars_to_keep
+       non_std_err = derivingThingErr clas tys tycon (mkTyVarTys tyvars_to_keep)
                                (vcat [non_std_why clas,
                                       ptext SLIT("Try -fglasgow-exts for GHC's newtype-deriving extension")])
 
@@ -588,7 +618,8 @@ mkDataTypeEqn loc orig tycon clas
 
   | otherwise
   = do { dfun_name <- new_dfun_name clas tycon
-       ; return (loc, orig, dfun_name, clas, tycon, tyvars, constraints) }
+       ; return (loc, orig, dfun_name, clas, tycon, tyvars, constraints)
+       }
   where
     tyvars            = tyConTyVars tycon
     constraints       = extra_constraints ++ ordinary_constraints
@@ -598,7 +629,7 @@ mkDataTypeEqn loc orig tycon clas
     ordinary_constraints
       = [ mkClassPred clas [arg_ty] 
         | data_con <- tyConDataCons tycon,
-          arg_ty <- dataConInstOrigArgTys data_con (map mkTyVarTy (tyConTyVars tycon)),
+          arg_ty <- dataConInstOrigArgTys data_con (mkTyVarTys tyvars),
           not (isUnLiftedType arg_ty)  -- No constraints for unlifted types?
         ]
 
@@ -678,12 +709,16 @@ cond_typeableOK :: Condition
 -- Currently: (a) args all of kind *
 --           (b) 7 or fewer args
 cond_typeableOK (gla_exts, tycon)
-  | tyConArity tycon > 7                                     = Just too_many
-  | not (all (isSubArgTypeKind . tyVarKind) (tyConTyVars tycon)) = Just bad_kind
-  | otherwise                                                = Nothing
+  | tyConArity tycon > 7       = Just too_many
+  | not (all (isSubArgTypeKind . tyVarKind) (tyConTyVars tycon)) 
+                                = Just bad_kind
+  | isFamInstTyCon tycon       = Just fam_inst  -- no Typable for family insts
+  | otherwise                  = Nothing
   where
     too_many = quotes (ppr tycon) <+> ptext SLIT("has too many arguments")
-    bad_kind = quotes (ppr tycon) <+> ptext SLIT("has arguments of kind other than `*'")
+    bad_kind = quotes (ppr tycon) <+> 
+              ptext SLIT("has arguments of kind other than `*'")
+    fam_inst = quotes (ppr tycon) <+> ptext SLIT("is a type family")
 
 cond_glaExts :: Condition
 cond_glaExts (gla_exts, tycon) | gla_exts  = Nothing
@@ -757,9 +792,9 @@ solveDerivEqns overlap_flag orig_eqns
 
     ------------------------------------------------------------------
     gen_soln :: DerivEqn -> TcM [PredType]
-    gen_soln (loc, orig, _, clas, tc,tyvars,deriv_rhs)
+    gen_soln (loc, orig, _, clas, tc, tyvars, deriv_rhs)
       = setSrcSpan loc $
-       do { let inst_tys = [mkTyConApp tc (mkTyVarTys tyvars)]
+       do { let inst_tys = [origHead]
           ; theta <- addErrCtxt (derivInstCtxt1 clas inst_tys) $
                      tcSimplifyDeriv orig tc tyvars deriv_rhs
                -- Claim: the result instance declaration is guaranteed valid
@@ -767,15 +802,15 @@ solveDerivEqns overlap_flag orig_eqns
                --   checkValidInstance tyvars theta clas inst_tys
           ; return (sortLe (<=) theta) }       -- Canonicalise before returning the solution
       where
-       
+         origHead = uncurry mkTyConApp (tyConOrigHead tc)      
 
     ------------------------------------------------------------------
     mk_inst_spec :: DerivEqn -> DerivSoln -> Instance
     mk_inst_spec (loc, orig, dfun_name, clas, tycon, tyvars, _) theta
        = mkLocalInstance dfun overlap_flag
        where
-         dfun = mkDictFunId dfun_name tyvars theta clas
-                            [mkTyConApp tycon (mkTyVarTys tyvars)]
+         dfun     = mkDictFunId dfun_name tyvars theta clas [origHead]
+         origHead = uncurry mkTyConApp (tyConOrigHead tycon)
 
 extendLocalInstEnv :: [Instance] -> TcM a -> TcM a
 -- Add new locally-defined instances; don't bother to check
@@ -850,16 +885,27 @@ the renamer.  What a great hack!
 \end{itemize}
 
 \begin{code}
--- Generate the InstInfo for the required instance,
+-- Generate the InstInfo for the required instance paired with the
+--   *representation* tycon for that instance,
 -- plus any auxiliary bindings required
-genInst :: Instance -> TcM (InstInfo, LHsBinds RdrName)
+--
+-- Representation tycons differ from the tycon in the instance signature in
+-- case of instances for indexed families.
+--
+genInst :: Instance -> TcM ((InstInfo, TyCon), LHsBinds RdrName)
 genInst spec
   = do { fix_env <- getFixityEnv
        ; let
            (tyvars,_,clas,[ty])    = instanceHead spec
            clas_nm                 = className clas
-           tycon                   = tcTyConAppTyCon ty 
-           (meth_binds, aux_binds) = genDerivBinds clas fix_env tycon
+           (visible_tycon, tyArgs) = tcSplitTyConApp ty 
+
+          -- In case of a family instance, we need to use the representation
+          -- tycon (after all it has the data constructors)
+        ; tycon <- if isOpenTyCon visible_tycon
+                  then tcLookupFamInst visible_tycon tyArgs
+                  else return visible_tycon
+       ; let (meth_binds, aux_binds) = genDerivBinds clas fix_env tycon
 
        -- Bring the right type variables into 
        -- scope, and rename the method binds
@@ -870,10 +916,10 @@ genInst spec
                                   rnMethodBinds clas_nm (\n -> []) [] meth_binds
 
        -- Build the InstInfo
-       ; return (InstInfo { iSpec = spec, 
-                            iBinds = VanillaInst rn_meth_binds [] }, 
+       ; return ((InstInfo { iSpec = spec, 
+                             iBinds = VanillaInst rn_meth_binds [] }, tycon),
                  aux_binds)
-       }
+        }
 
 genDerivBinds clas fix_env tycon
   | className clas `elem` typeableClassNames
@@ -936,15 +982,14 @@ We're deriving @Enum@, or @Ix@ (enum type only???)
 If we have a @tag2con@ function, we also generate a @maxtag@ constant.
 
 \begin{code}
-genTaggeryBinds :: [InstInfo] -> TcM (LHsBinds RdrName)
+genTaggeryBinds :: [(InstInfo, TyCon)] -> TcM (LHsBinds RdrName)
 genTaggeryBinds infos
   = do { names_so_far <- foldlM do_con2tag []           tycons_of_interest
        ; nm_alist_etc <- foldlM do_tag2con names_so_far tycons_of_interest
        ; return (listToBag (map gen_tag_n_con_monobind nm_alist_etc)) }
   where
-    all_CTs = [ (cls, tcTyConAppTyCon ty)
-             | info <- infos, 
-               let (cls,ty) = simpleInstInfoClsTy info ]
+    all_CTs                 = [ (fst (simpleInstInfoClsTy info), tc) 
+                             | (info, tc) <- infos]
     all_tycons             = map snd all_CTs
     (tycons_of_interest, _) = removeDups compare all_tycons
     
@@ -983,17 +1028,24 @@ genTaggeryBinds infos
 \end{code}
 
 \begin{code}
-derivingThingErr clas tys tycon tyvars why
-  = sep [hsep [ptext SLIT("Can't make a derived instance of"), quotes (ppr pred)],
+derivingThingErr clas tys tycon ttys why
+  = sep [hsep [ptext SLIT("Can't make a derived instance of"), 
+              quotes (ppr pred)],
         nest 2 (parens why)]
   where
-    pred = mkClassPred clas (tys ++ [mkTyConApp tycon (mkTyVarTys tyvars)])
+    pred = mkClassPred clas (tys ++ [mkTyConApp tycon ttys])
 
-derivCtxt :: TyCon -> SDoc
-derivCtxt tycon
-  = ptext SLIT("When deriving instances for") <+> quotes (ppr tycon)
+derivCtxt :: Name -> Maybe [LHsType Name] -> SDoc
+derivCtxt tycon mb_tys
+  = ptext SLIT("When deriving instances for") <+> quotes typeInst
+  where
+    typeInst = case mb_tys of
+                Nothing  -> ppr tycon
+                Just tys -> ppr tycon <+> 
+                            hsep (map (pprParendHsType . unLoc) tys)
 
 derivInstCtxt1 clas inst_tys
-  = ptext SLIT("When deriving the instance for") <+> quotes (pprClassPred clas inst_tys)
+  = ptext SLIT("When deriving the instance for") <+> 
+    quotes (pprClassPred clas inst_tys)
 \end{code}