From 5c4a4c4bfe2a007f41f42ebab689bcd7219bed0d Mon Sep 17 00:00:00 2001 From: Simon Peyton Jones Date: Sat, 11 Jun 2011 16:44:07 +0100 Subject: [PATCH] Make TH capable of quoting GADT declarations (Trac #5217) Template Haskell doesn't support GADTs directly but we can use equality constraints to do the job. Here's an example of the dump from splicing such a declaration: [d| data T a b where T1 :: Int -> T Int Char T2 :: a -> T a a T3 :: a -> T [a] a T4 :: a -> b -> T b [a] |] ======> T5217.hs:(6,3)-(9,53) data T a[aQW] b[aQX] = (b[aQX] ~ Char, a[aQW] ~ Int) => T1 Int | b[aQX] ~ a[aQW] => T2 a[aQW] | a[aQW] ~ [b[aQX]] => T3 b[aQX] | forall a[aQY]. b[aQX] ~ [a[aQY]] => T4 a[aQY] a[aQW] --- compiler/deSugar/DsMeta.hs | 112 +++++++++++++++++++++++++++++---------- compiler/hsSyn/HsTypes.lhs | 14 +++++ compiler/typecheck/TcHsType.lhs | 19 +------ compiler/typecheck/TcMType.lhs | 5 +- 4 files changed, 103 insertions(+), 47 deletions(-) diff --git a/compiler/deSugar/DsMeta.hs b/compiler/deSugar/DsMeta.hs index a5cbdd3..ffcd0d4 100644 --- a/compiler/deSugar/DsMeta.hs +++ b/compiler/deSugar/DsMeta.hs @@ -57,6 +57,7 @@ import Bag import FastString import ForeignCall import MonadUtils +import Util( equalLength ) import Data.Maybe import Control.Monad @@ -173,7 +174,7 @@ repTyClD (L loc (TyData { tcdND = DataType, tcdCtxt = cxt, do { cxt1 <- repLContext cxt ; opt_tys1 <- maybeMapM repLTys opt_tys -- only for family insts ; opt_tys2 <- maybeMapM (coreList typeQTyConName) opt_tys1 - ; cons1 <- mapM repC cons + ; cons1 <- mapM (repC (hsLTyVarNames tvs)) cons ; cons2 <- coreList conQTyConName cons1 ; derivs1 <- repDerivs mb_derivs ; bndrs1 <- coreList tyVarBndrTyConName bndrs @@ -190,7 +191,7 @@ repTyClD (L loc (TyData { tcdND = NewType, tcdCtxt = cxt, do { cxt1 <- repLContext cxt ; opt_tys1 <- maybeMapM repLTys opt_tys -- only for family insts ; opt_tys2 <- maybeMapM (coreList typeQTyConName) opt_tys1 - ; con1 <- repC con + ; con1 <- repC (hsLTyVarNames tvs) con ; derivs1 <- repDerivs mb_derivs ; bndrs1 <- coreList tyVarBndrTyConName bndrs ; repNewtype cxt1 tc1 bndrs1 opt_tys2 con1 derivs1 @@ -360,23 +361,73 @@ ds_msg = ptext (sLit "Cannot desugar this Template Haskell declaration:") -- Constructors ------------------------------------------------------- -repC :: LConDecl Name -> DsM (Core TH.ConQ) -repC (L _ (ConDecl { con_name = con, con_qvars = [], con_cxt = L _ [] - , con_details = details, con_res = ResTyH98 })) +repC :: [Name] -> LConDecl Name -> DsM (Core TH.ConQ) +repC _ (L _ (ConDecl { con_name = con, con_qvars = [], con_cxt = L _ [] + , con_details = details, con_res = ResTyH98 })) = do { con1 <- lookupLOcc con -- See note [Binders and occurrences] - ; repConstr con1 details - } -repC (L loc con_decl@(ConDecl { con_qvars = tvs, con_cxt = L cloc ctxt, con_res = ResTyH98 })) - = addTyVarBinds tvs $ \bndrs -> - do { c' <- repC (L loc (con_decl { con_qvars = [], con_cxt = L cloc [] })) - ; ctxt' <- repContext ctxt - ; bndrs' <- coreList tyVarBndrTyConName bndrs - ; rep2 forallCName [unC bndrs', unC ctxt', unC c'] - } -repC (L loc con_decl) -- GADTs - = putSrcSpanDs loc $ - notHandled "GADT declaration" (ppr con_decl) - + ; repConstr con1 details } +repC tvs (L _ (ConDecl { con_name = con + , con_qvars = con_tvs, con_cxt = L _ ctxt + , con_details = details + , con_res = res_ty })) + = do { (eq_ctxt, con_tv_subst) <- mkGadtCtxt tvs res_ty + ; let ex_tvs = [ tv | tv <- con_tvs, not (hsLTyVarName tv `in_subst` con_tv_subst)] + ; binds <- mapM dupBinder con_tv_subst + ; dsExtendMetaEnv (mkNameEnv binds) $ -- Binds some of the con_tvs + addTyVarBinds ex_tvs $ \ ex_bndrs -> -- Binds the remaining con_tvs + do { con1 <- lookupLOcc con -- See note [Binders and occurrences] + ; c' <- repConstr con1 details + ; ctxt' <- repContext (eq_ctxt ++ ctxt) + ; ex_bndrs' <- coreList tyVarBndrTyConName ex_bndrs + ; rep2 forallCName [unC ex_bndrs', unC ctxt', unC c'] } } + +in_subst :: Name -> [(Name,Name)] -> Bool +in_subst _ [] = False +in_subst n ((n',_):ns) = n==n' || in_subst n ns + +mkGadtCtxt :: [Name] -- Tyvars of the data type + -> ResType Name + -> DsM (HsContext Name, [(Name,Name)]) +-- Given a data type in GADT syntax, figure out the equality +-- context, so that we can represent it with an explicit +-- equality context, because that is the only way to express +-- the GADT in TH syntax +-- +-- Example: +-- data T a b c where { MkT :: forall d e. d -> e -> T d [e] e +-- mkGadtCtxt [a,b,c] [d,e] (T d [e] e) +-- returns +-- (b~[e], c~e), [d->a] +-- +-- This function is fiddly, but not really hard +mkGadtCtxt _ ResTyH98 + = return ([], []) +mkGadtCtxt data_tvs (ResTyGADT res_ty) + | let (head_ty, tys) = splitHsAppTys res_ty [] + , Just _ <- is_hs_tyvar head_ty + , data_tvs `equalLength` tys + = return (go [] [] (data_tvs `zip` tys)) + + | otherwise + = failWithDs (ptext (sLit "Malformed constructor result type") <+> ppr res_ty) + where + go cxt subst [] = (cxt, subst) + go cxt subst ((data_tv, ty) : rest) + | Just con_tv <- is_hs_tyvar ty + , isTyVarName con_tv + , not (in_subst con_tv subst) + = go cxt ((con_tv, data_tv) : subst) rest + | otherwise + = go (eq_pred : cxt) subst rest + where + loc = getLoc ty + eq_pred = L loc (HsEqualP (L loc (HsTyVar data_tv)) ty) + + is_hs_tyvar (L _ (HsTyVar n)) = Just n -- Type variables *and* tycons + is_hs_tyvar (L _ (HsParTy ty)) = is_hs_tyvar ty + is_hs_tyvar _ = Nothing + + repBangTy :: LBangType Name -> DsM (Core (TH.StrictTypeQ)) repBangTy ty= do MkC s <- rep2 str [] @@ -506,16 +557,14 @@ type ProcessTyVarBinds a = -- meta environment and gets the *new* names on Core-level as an argument -- addTyVarBinds :: ProcessTyVarBinds a -addTyVarBinds tvs m = - do - let names = hsLTyVarNames tvs - mkWithKinds = map repTyVarBndrWithKind tvs - freshNames <- mkGenSyms names - term <- addBinds freshNames $ do - bndrs <- mapM lookupBinder names - kindedBndrs <- zipWithM ($) mkWithKinds bndrs - m kindedBndrs - wrapGenSyms freshNames term +addTyVarBinds tvs m + = do { freshNames <- mkGenSyms (hsLTyVarNames tvs) + ; term <- addBinds freshNames $ + do { kindedBndrs <- mapM mk_tv_bndr (tvs `zip` freshNames) + ; m kindedBndrs } + ; wrapGenSyms freshNames term } + where + mk_tv_bndr (tv, (_,v)) = repTyVarBndrWithKind tv (coreVar v) -- Look up a list of type variables; the computations passed as the second -- argument gets the *new* names on Core-level as an argument @@ -1112,6 +1161,13 @@ lookupBinder n where msg = ptext (sLit "DsMeta: failed binder lookup when desugaring a TH bracket:") <+> ppr n +dupBinder :: (Name, Name) -> DsM (Name, DsMetaVal) +dupBinder (new, old) + = do { mb_val <- dsLookupMetaEnv old + ; case mb_val of + Just val -> return (new, val) + Nothing -> pprPanic "dupBinder" (ppr old) } + -- Look up a name that is either locally bound or a global name -- -- * If it is a global name, generate the "original name" representation (ie, diff --git a/compiler/hsSyn/HsTypes.lhs b/compiler/hsSyn/HsTypes.lhs index 7dbb16d..d565c96 100644 --- a/compiler/hsSyn/HsTypes.lhs +++ b/compiler/hsSyn/HsTypes.lhs @@ -26,6 +26,7 @@ module HsTypes ( hsTyVarKind, hsTyVarNameKind, hsLTyVarName, hsLTyVarNames, hsLTyVarLocName, hsLTyVarLocNames, splitHsInstDeclTy, splitHsFunType, + splitHsAppTys, mkHsAppTys, -- Type place holder PostTcType, placeHolderType, PostTcKind, placeHolderKind, @@ -292,6 +293,19 @@ replaceTyVarName (KindedTyVar _ k) n' = KindedTyVar n' k \begin{code} +splitHsAppTys :: LHsType n -> [LHsType n] -> (LHsType n, [LHsType n]) +splitHsAppTys (L _ (HsAppTy f a)) as = splitHsAppTys f (a:as) +splitHsAppTys f as = (f,as) + +mkHsAppTys :: OutputableBndr n => LHsType n -> [LHsType n] -> HsType n +mkHsAppTys fun_ty [] = pprPanic "mkHsAppTys" (ppr fun_ty) +mkHsAppTys fun_ty (arg_ty:arg_tys) + = foldl mk_app (HsAppTy fun_ty arg_ty) arg_tys + where + mk_app fun arg = HsAppTy (noLoc fun) arg + -- Add noLocs for inner nodes of the application; + -- they are never used + splitHsInstDeclTy :: OutputableBndr name => HsType name diff --git a/compiler/typecheck/TcHsType.lhs b/compiler/typecheck/TcHsType.lhs index 65f16c5..7d9f93c 100644 --- a/compiler/typecheck/TcHsType.lhs +++ b/compiler/typecheck/TcHsType.lhs @@ -299,7 +299,7 @@ kc_check_hs_type (HsParTy ty) exp_kind = do { ty' <- kc_check_lhs_type ty exp_kind; return (HsParTy ty') } kc_check_hs_type ty@(HsAppTy ty1 ty2) exp_kind - = do { let (fun_ty, arg_tys) = splitHsAppTys ty1 ty2 + = do { let (fun_ty, arg_tys) = splitHsAppTys ty1 [ty2] ; (fun_ty', fun_kind) <- kc_lhs_type fun_ty ; arg_tys' <- kcCheckApps fun_ty fun_kind arg_tys ty exp_kind ; return (mkHsAppTys fun_ty' arg_tys') } @@ -387,11 +387,10 @@ kc_hs_type (HsOpTy ty1 op ty2) = do return (HsOpTy ty1' op ty2', res_kind) kc_hs_type (HsAppTy ty1 ty2) = do + let (fun_ty, arg_tys) = splitHsAppTys ty1 [ty2] (fun_ty', fun_kind) <- kc_lhs_type fun_ty (arg_tys', res_kind) <- kcApps fun_ty fun_kind arg_tys return (mkHsAppTys fun_ty' arg_tys', res_kind) - where - (fun_ty, arg_tys) = splitHsAppTys ty1 ty2 kc_hs_type (HsPredTy pred) = wrongPredErr pred @@ -458,20 +457,6 @@ kcCheckApps the_fun fun_kind args ty exp_kind -- This improves error message; Trac #2994 ; kc_check_lhs_types args_w_kinds } -splitHsAppTys :: LHsType Name -> LHsType Name -> (LHsType Name, [LHsType Name]) -splitHsAppTys fun_ty arg_ty = split fun_ty [arg_ty] - where - split (L _ (HsAppTy f a)) as = split f (a:as) - split f as = (f,as) - -mkHsAppTys :: LHsType Name -> [LHsType Name] -> HsType Name -mkHsAppTys fun_ty [] = pprPanic "mkHsAppTys" (ppr fun_ty) -mkHsAppTys fun_ty (arg_ty:arg_tys) - = foldl mk_app (HsAppTy fun_ty arg_ty) arg_tys - where - mk_app fun arg = HsAppTy (noLoc fun) arg -- Add noLocs for inner nodes of - -- the application; they are - -- never used --------------------------- splitFunKind :: SDoc -> Int -> TcKind -> [b] -> TcM ([(b,ExpKind)], TcKind) diff --git a/compiler/typecheck/TcMType.lhs b/compiler/typecheck/TcMType.lhs index 2c01d23..6423a83 100644 --- a/compiler/typecheck/TcMType.lhs +++ b/compiler/typecheck/TcMType.lhs @@ -1162,7 +1162,8 @@ check_pred_ty dflags ctxt pred@(ClassP cls tys) check_pred_ty dflags ctxt pred@(EqPred ty1 ty2) = do { -- Equational constraints are valid in all contexts if type -- families are permitted - ; checkTc (xopt Opt_TypeFamilies dflags) (eqPredTyErr pred) + ; checkTc (xopt Opt_TypeFamilies dflags || xopt Opt_GADTs dflags) + (eqPredTyErr pred) ; checkTc (case ctxt of ClassSCCtxt {} -> False; _ -> True) (eqSuperClassErr pred) @@ -1330,7 +1331,7 @@ badPredTyErr, eqPredTyErr, predTyVarErr :: PredType -> SDoc badPredTyErr pred = ptext (sLit "Illegal constraint") <+> pprPredTy pred eqPredTyErr pred = ptext (sLit "Illegal equational constraint") <+> pprPredTy pred $$ - parens (ptext (sLit "Use -XTypeFamilies to permit this")) + parens (ptext (sLit "Use -XGADTs or -XTypeFamilies to permit this")) predTyVarErr pred = sep [ptext (sLit "Non type-variable argument"), nest 2 (ptext (sLit "in the constraint:") <+> pprPredTy pred)] dupPredWarn :: [[PredType]] -> SDoc -- 1.7.10.4