Make TH capable of quoting GADT declarations (Trac #5217)
authorSimon Peyton Jones <simonpj@microsoft.com>
Sat, 11 Jun 2011 15:44:07 +0000 (16:44 +0100)
committerSimon Peyton Jones <simonpj@microsoft.com>
Sat, 11 Jun 2011 15:44:07 +0000 (16:44 +0100)
Template Haskell doesn't support GADTs directly but
we can use equality constraints to do the job. Here's
an example of the dump from splicing such a declaration:

    [d| data T a b
            where
              T1 :: Int -> T Int Char
              T2 :: a -> T a a
              T3 :: a -> T [a] a
              T4 :: a -> b -> T b [a] |]
  ======>
    T5217.hs:(6,3)-(9,53)
    data T a[aQW] b[aQX]
        = (b[aQX] ~ Char, a[aQW] ~ Int) => T1 Int |
          b[aQX] ~ a[aQW] => T2 a[aQW] |
          a[aQW] ~ [b[aQX]] => T3 b[aQX] |
          forall a[aQY]. b[aQX] ~ [a[aQY]] => T4 a[aQY] a[aQW]

compiler/deSugar/DsMeta.hs
compiler/hsSyn/HsTypes.lhs
compiler/typecheck/TcHsType.lhs
compiler/typecheck/TcMType.lhs

index a5cbdd3..ffcd0d4 100644 (file)
@@ -57,6 +57,7 @@ import Bag
 import FastString
 import ForeignCall
 import MonadUtils
+import Util( equalLength )
 
 import Data.Maybe
 import Control.Monad
@@ -173,7 +174,7 @@ repTyClD (L loc (TyData { tcdND = DataType, tcdCtxt = cxt,
            do { cxt1     <- repLContext cxt
               ; opt_tys1 <- maybeMapM repLTys opt_tys   -- only for family insts
               ; opt_tys2 <- maybeMapM (coreList typeQTyConName) opt_tys1
-              ; cons1    <- mapM repC cons
+              ; cons1    <- mapM (repC (hsLTyVarNames tvs)) cons
              ; cons2    <- coreList conQTyConName cons1
              ; derivs1  <- repDerivs mb_derivs
              ; bndrs1   <- coreList tyVarBndrTyConName bndrs
@@ -190,7 +191,7 @@ repTyClD (L loc (TyData { tcdND = NewType, tcdCtxt = cxt,
            do { cxt1     <- repLContext cxt
               ; opt_tys1 <- maybeMapM repLTys opt_tys   -- only for family insts
               ; opt_tys2 <- maybeMapM (coreList typeQTyConName) opt_tys1
-              ; con1     <- repC con
+              ; con1     <- repC (hsLTyVarNames tvs) con
              ; derivs1  <- repDerivs mb_derivs
              ; bndrs1   <- coreList tyVarBndrTyConName bndrs
              ; repNewtype cxt1 tc1 bndrs1 opt_tys2 con1 derivs1
@@ -360,23 +361,73 @@ ds_msg = ptext (sLit "Cannot desugar this Template Haskell declaration:")
 --                     Constructors
 -------------------------------------------------------
 
-repC :: LConDecl Name -> DsM (Core TH.ConQ)
-repC (L _ (ConDecl { con_name = con, con_qvars = [], con_cxt = L _ []
-                  , con_details = details, con_res = ResTyH98 }))
+repC :: [Name] -> LConDecl Name -> DsM (Core TH.ConQ)
+repC _ (L _ (ConDecl { con_name = con, con_qvars = [], con_cxt = L _ []
+                       , con_details = details, con_res = ResTyH98 }))
   = do { con1 <- lookupLOcc con        -- See note [Binders and occurrences] 
-       ; repConstr con1 details 
-       }
-repC (L loc con_decl@(ConDecl { con_qvars = tvs, con_cxt = L cloc ctxt, con_res = ResTyH98 }))
-  = addTyVarBinds tvs $ \bndrs -> 
-      do { c' <- repC (L loc (con_decl { con_qvars = [], con_cxt = L cloc [] }))
-         ; ctxt' <- repContext ctxt
-         ; bndrs' <- coreList tyVarBndrTyConName bndrs
-         ; rep2 forallCName [unC bndrs', unC ctxt', unC c']
-         }
-repC (L loc con_decl)          -- GADTs
-  = putSrcSpanDs loc $
-    notHandled "GADT declaration" (ppr con_decl) 
-
+       ; repConstr con1 details  }
+repC tvs (L _ (ConDecl { con_name = con
+                       , con_qvars = con_tvs, con_cxt = L _ ctxt
+                       , con_details = details
+                       , con_res = res_ty }))
+  = do { (eq_ctxt, con_tv_subst) <- mkGadtCtxt tvs res_ty
+       ; let ex_tvs = [ tv | tv <- con_tvs, not (hsLTyVarName tv `in_subst` con_tv_subst)]
+       ; binds <- mapM dupBinder con_tv_subst 
+       ; dsExtendMetaEnv (mkNameEnv binds) $     -- Binds some of the con_tvs
+         addTyVarBinds ex_tvs $ \ ex_bndrs ->   -- Binds the remaining con_tvs
+    do { con1      <- lookupLOcc con   -- See note [Binders and occurrences] 
+       ; c'        <- repConstr con1 details
+       ; ctxt'     <- repContext (eq_ctxt ++ ctxt)
+       ; ex_bndrs' <- coreList tyVarBndrTyConName ex_bndrs
+       ; rep2 forallCName [unC ex_bndrs', unC ctxt', unC c'] } }
+
+in_subst :: Name -> [(Name,Name)] -> Bool
+in_subst _ []          = False
+in_subst n ((n',_):ns) = n==n' || in_subst n ns
+
+mkGadtCtxt :: [Name]           -- Tyvars of the data type
+           -> ResType Name
+          -> DsM (HsContext Name, [(Name,Name)])
+-- Given a data type in GADT syntax, figure out the equality 
+-- context, so that we can represent it with an explicit 
+-- equality context, because that is the only way to express
+-- the GADT in TH syntax
+--
+-- Example:   
+-- data T a b c where { MkT :: forall d e. d -> e -> T d [e] e
+--     mkGadtCtxt [a,b,c] [d,e] (T d [e] e)
+--   returns 
+--     (b~[e], c~e), [d->a] 
+-- 
+-- This function is fiddly, but not really hard
+mkGadtCtxt _ ResTyH98
+  = return ([], [])
+mkGadtCtxt data_tvs (ResTyGADT res_ty)
+  | let (head_ty, tys) = splitHsAppTys res_ty []
+  , Just _ <- is_hs_tyvar head_ty
+  , data_tvs `equalLength` tys
+  = return (go [] [] (data_tvs `zip` tys))
+
+  | otherwise 
+  = failWithDs (ptext (sLit "Malformed constructor result type") <+> ppr res_ty)
+  where
+    go cxt subst [] = (cxt, subst)
+    go cxt subst ((data_tv, ty) : rest)
+       | Just con_tv <- is_hs_tyvar ty
+       , isTyVarName con_tv
+       , not (in_subst con_tv subst)
+       = go cxt ((con_tv, data_tv) : subst) rest
+       | otherwise
+       = go (eq_pred : cxt) subst rest
+       where
+         loc = getLoc ty
+         eq_pred = L loc (HsEqualP (L loc (HsTyVar data_tv)) ty)
+
+    is_hs_tyvar (L _ (HsTyVar n))  = Just n   -- Type variables *and* tycons
+    is_hs_tyvar (L _ (HsParTy ty)) = is_hs_tyvar ty
+    is_hs_tyvar _                  = Nothing
+
+    
 repBangTy :: LBangType Name -> DsM (Core (TH.StrictTypeQ))
 repBangTy ty= do 
   MkC s <- rep2 str []
@@ -506,16 +557,14 @@ type ProcessTyVarBinds a =
 -- meta environment and gets the *new* names on Core-level as an argument
 --
 addTyVarBinds :: ProcessTyVarBinds a
-addTyVarBinds tvs m =
-  do
-    let names       = hsLTyVarNames tvs
-        mkWithKinds = map repTyVarBndrWithKind tvs
-    freshNames <- mkGenSyms names
-    term       <- addBinds freshNames $ do
-                   bndrs       <- mapM lookupBinder names 
-                    kindedBndrs <- zipWithM ($) mkWithKinds bndrs
-                   m kindedBndrs
-    wrapGenSyms freshNames term
+addTyVarBinds tvs m
+  = do { freshNames <- mkGenSyms (hsLTyVarNames tvs)
+       ; term <- addBinds freshNames $ 
+                do { kindedBndrs <- mapM mk_tv_bndr (tvs `zip` freshNames)
+                   ; m kindedBndrs }
+       ; wrapGenSyms freshNames term }
+  where
+    mk_tv_bndr (tv, (_,v)) = repTyVarBndrWithKind tv (coreVar v)
 
 -- Look up a list of type variables; the computations passed as the second 
 -- argument gets the *new* names on Core-level as an argument
@@ -1112,6 +1161,13 @@ lookupBinder n
   where
     msg = ptext (sLit "DsMeta: failed binder lookup when desugaring a TH bracket:") <+> ppr n
 
+dupBinder :: (Name, Name) -> DsM (Name, DsMetaVal)
+dupBinder (new, old) 
+  = do { mb_val <- dsLookupMetaEnv old
+       ; case mb_val of
+           Just val -> return (new, val)
+           Nothing  -> pprPanic "dupBinder" (ppr old) }
+
 -- Look up a name that is either locally bound or a global name
 --
 --  * If it is a global name, generate the "original name" representation (ie,
index 7dbb16d..d565c96 100644 (file)
@@ -26,6 +26,7 @@ module HsTypes (
        hsTyVarKind, hsTyVarNameKind,
        hsLTyVarName, hsLTyVarNames, hsLTyVarLocName, hsLTyVarLocNames,
        splitHsInstDeclTy, splitHsFunType,
+       splitHsAppTys, mkHsAppTys,
        
        -- Type place holder
        PostTcType, placeHolderType, PostTcKind, placeHolderKind,
@@ -292,6 +293,19 @@ replaceTyVarName (KindedTyVar _ k) n' = KindedTyVar n' k
 
 
 \begin{code}
+splitHsAppTys :: LHsType n -> [LHsType n] -> (LHsType n, [LHsType n])
+splitHsAppTys (L _ (HsAppTy f a)) as = splitHsAppTys f (a:as)
+splitHsAppTys f                  as = (f,as)
+
+mkHsAppTys :: OutputableBndr n => LHsType n -> [LHsType n] -> HsType n
+mkHsAppTys fun_ty [] = pprPanic "mkHsAppTys" (ppr fun_ty)
+mkHsAppTys fun_ty (arg_ty:arg_tys)
+  = foldl mk_app (HsAppTy fun_ty arg_ty) arg_tys
+  where
+    mk_app fun arg = HsAppTy (noLoc fun) arg   
+       -- Add noLocs for inner nodes of the application; 
+       -- they are never used 
+
 splitHsInstDeclTy 
     :: OutputableBndr name
     => HsType name 
index 65f16c5..7d9f93c 100644 (file)
@@ -299,7 +299,7 @@ kc_check_hs_type (HsParTy ty) exp_kind
   = do { ty' <- kc_check_lhs_type ty exp_kind; return (HsParTy ty') }
 
 kc_check_hs_type ty@(HsAppTy ty1 ty2) exp_kind
-  = do { let (fun_ty, arg_tys) = splitHsAppTys ty1 ty2
+  = do { let (fun_ty, arg_tys) = splitHsAppTys ty1 [ty2]
        ; (fun_ty', fun_kind) <- kc_lhs_type fun_ty
        ; arg_tys' <- kcCheckApps fun_ty fun_kind arg_tys ty exp_kind
        ; return (mkHsAppTys fun_ty' arg_tys') }
@@ -387,11 +387,10 @@ kc_hs_type (HsOpTy ty1 op ty2) = do
     return (HsOpTy ty1' op ty2', res_kind)
 
 kc_hs_type (HsAppTy ty1 ty2) = do
+    let (fun_ty, arg_tys) = splitHsAppTys ty1 [ty2]
     (fun_ty', fun_kind) <- kc_lhs_type fun_ty
     (arg_tys', res_kind) <- kcApps fun_ty fun_kind arg_tys
     return (mkHsAppTys fun_ty' arg_tys', res_kind)
-  where
-    (fun_ty, arg_tys) = splitHsAppTys ty1 ty2
 
 kc_hs_type (HsPredTy pred)
   = wrongPredErr pred
@@ -458,20 +457,6 @@ kcCheckApps the_fun fun_kind args ty exp_kind
             -- This improves error message; Trac #2994
        ; kc_check_lhs_types args_w_kinds }
 
-splitHsAppTys :: LHsType Name -> LHsType Name -> (LHsType Name, [LHsType Name])
-splitHsAppTys fun_ty arg_ty = split fun_ty [arg_ty]
-  where
-    split (L _ (HsAppTy f a)) as = split f (a:as)
-    split f                  as = (f,as)
-
-mkHsAppTys :: LHsType Name -> [LHsType Name] -> HsType Name
-mkHsAppTys fun_ty [] = pprPanic "mkHsAppTys" (ppr fun_ty)
-mkHsAppTys fun_ty (arg_ty:arg_tys)
-  = foldl mk_app (HsAppTy fun_ty arg_ty) arg_tys
-  where
-    mk_app fun arg = HsAppTy (noLoc fun) arg   -- Add noLocs for inner nodes of
-                                               -- the application; they are
-                                               -- never used 
 
 ---------------------------
 splitFunKind :: SDoc -> Int -> TcKind -> [b] -> TcM ([(b,ExpKind)], TcKind)
index 2c01d23..6423a83 100644 (file)
@@ -1162,7 +1162,8 @@ check_pred_ty dflags ctxt pred@(ClassP cls tys)
 check_pred_ty dflags ctxt pred@(EqPred ty1 ty2)
   = do {       -- Equational constraints are valid in all contexts if type
                -- families are permitted
-       ; checkTc (xopt Opt_TypeFamilies dflags) (eqPredTyErr pred)
+       ; checkTc (xopt Opt_TypeFamilies dflags || xopt Opt_GADTs dflags) 
+                 (eqPredTyErr pred)
        ; checkTc (case ctxt of ClassSCCtxt {} -> False; _ -> True)
                  (eqSuperClassErr pred)
 
@@ -1330,7 +1331,7 @@ badPredTyErr, eqPredTyErr, predTyVarErr :: PredType -> SDoc
 badPredTyErr pred = ptext (sLit "Illegal constraint") <+> pprPredTy pred
 eqPredTyErr  pred = ptext (sLit "Illegal equational constraint") <+> pprPredTy pred
                    $$
-                   parens (ptext (sLit "Use -XTypeFamilies to permit this"))
+                   parens (ptext (sLit "Use -XGADTs or -XTypeFamilies to permit this"))
 predTyVarErr pred  = sep [ptext (sLit "Non type-variable argument"),
                          nest 2 (ptext (sLit "in the constraint:") <+> pprPredTy pred)]
 dupPredWarn :: [[PredType]] -> SDoc