From ff3bfae6010625b7ffe96bc62e8e139870684600 Mon Sep 17 00:00:00 2001 From: Roman Leshchinskiy Date: Wed, 26 Jan 2011 23:18:43 +0000 Subject: [PATCH] Fix vectorisation of recursive types --- compiler/ghc.cabal.in | 2 - compiler/vectorise/Vectorise.hs | 7 +-- compiler/vectorise/Vectorise/Builtins/Base.hs | 2 + .../vectorise/Vectorise/Builtins/Initialise.hs | 19 +++--- compiler/vectorise/Vectorise/Env.hs | 12 +++- compiler/vectorise/Vectorise/Monad/Base.hs | 1 - compiler/vectorise/Vectorise/Type/Env.hs | 9 ++- compiler/vectorise/Vectorise/Type/PADict.hs | 51 +++++++++------ compiler/vectorise/Vectorise/Type/PRDict.hs | 56 ----------------- compiler/vectorise/Vectorise/Type/Repr.hs | 2 +- compiler/vectorise/Vectorise/Utils.hs | 2 - compiler/vectorise/Vectorise/Utils/PADict.hs | 65 ++++++++++++++------ compiler/vectorise/Vectorise/Utils/PRDict.hs | 43 ------------- 13 files changed, 107 insertions(+), 164 deletions(-) delete mode 100644 compiler/vectorise/Vectorise/Type/PRDict.hs delete mode 100644 compiler/vectorise/Vectorise/Utils/PRDict.hs diff --git a/compiler/ghc.cabal.in b/compiler/ghc.cabal.in index a8c2d54..cc4c562 100644 --- a/compiler/ghc.cabal.in +++ b/compiler/ghc.cabal.in @@ -479,7 +479,6 @@ Library Vectorise.Utils.Closure Vectorise.Utils.Hoisting Vectorise.Utils.PADict - Vectorise.Utils.PRDict Vectorise.Utils.Poly Vectorise.Utils Vectorise.Type.Env @@ -487,7 +486,6 @@ Library Vectorise.Type.PData Vectorise.Type.PRepr Vectorise.Type.PADict - Vectorise.Type.PRDict Vectorise.Type.Type Vectorise.Type.TyConDecl Vectorise.Type.Classify diff --git a/compiler/vectorise/Vectorise.hs b/compiler/vectorise/Vectorise.hs index b4b383e..5e45c97 100644 --- a/compiler/vectorise/Vectorise.hs +++ b/compiler/vectorise/Vectorise.hs @@ -18,7 +18,6 @@ import CoreSyn import CoreUnfold ( mkInlineUnfolding ) import CoreFVs import CoreMonad ( CoreM, getHscEnv ) -import FamInstEnv ( extendFamInstEnvList ) import Var import Id import OccName @@ -62,9 +61,7 @@ vectModule guts -- TODO: What new binds do we get back here? (types', fam_insts, tc_binds) <- vectTypeEnv (mg_types guts) - -- TODO: What is this? - let fam_inst_env' = extendFamInstEnvList (mg_fam_inst_env guts) fam_insts - updGEnv (setFamInstEnv fam_inst_env') + (_, fam_inst_env) <- readGEnv global_fam_inst_env -- dicts <- mapM buildPADict pa_insts -- workers <- mapM vectDataConWorkers pa_insts @@ -74,7 +71,7 @@ vectModule guts return $ guts { mg_types = types' , mg_binds = Rec tc_binds : binds' - , mg_fam_inst_env = fam_inst_env' + , mg_fam_inst_env = fam_inst_env , mg_fam_insts = mg_fam_insts guts ++ fam_insts } diff --git a/compiler/vectorise/Vectorise/Builtins/Base.hs b/compiler/vectorise/Vectorise/Builtins/Base.hs index 5e4d47d..69ae84f 100644 --- a/compiler/vectorise/Vectorise/Builtins/Base.hs +++ b/compiler/vectorise/Vectorise/Builtins/Base.hs @@ -61,10 +61,12 @@ data Builtins , parrayTyCon :: TyCon -- ^ PArray , parrayDataCon :: DataCon -- ^ PArray , pdataTyCon :: TyCon -- ^ PData + , paClass :: Class -- ^ PA , paTyCon :: TyCon -- ^ PA , paDataCon :: DataCon -- ^ PA , paPRSel :: Var -- ^ PA , preprTyCon :: TyCon -- ^ PRepr + , prClass :: Class -- ^ PR , prTyCon :: TyCon -- ^ PR , prDataCon :: DataCon -- ^ PR , replicatePDVar :: Var -- ^ replicatePD diff --git a/compiler/vectorise/Vectorise/Builtins/Initialise.hs b/compiler/vectorise/Vectorise/Builtins/Initialise.hs index d9a1f0d..9e78f11 100644 --- a/compiler/vectorise/Vectorise/Builtins/Initialise.hs +++ b/compiler/vectorise/Vectorise/Builtins/Initialise.hs @@ -46,14 +46,15 @@ initBuiltins pkg let [parrayDataCon] = tyConDataCons parrayTyCon pdataTyCon <- externalTyCon dph_PArray (fsLit "PData") - pa <- externalClass dph_PArray (fsLit "PA") - let paTyCon = classTyCon pa + paClass <- externalClass dph_PArray (fsLit "PA") + let paTyCon = classTyCon paClass [paDataCon] = tyConDataCons paTyCon - paPRSel = classSCSelId pa 0 + paPRSel = classSCSelId paClass 0 preprTyCon <- externalTyCon dph_PArray (fsLit "PRepr") - prTyCon <- externalClassTyCon dph_PArray (fsLit "PR") - let [prDataCon] = tyConDataCons prTyCon + prClass <- externalClass dph_PArray (fsLit "PR") + let prTyCon = classTyCon prClass + [prDataCon] = tyConDataCons prTyCon closureTyCon <- externalTyCon dph_Closure (fsLit ":->") @@ -127,10 +128,12 @@ initBuiltins pkg , parrayTyCon = parrayTyCon , parrayDataCon = parrayDataCon , pdataTyCon = pdataTyCon + , paClass = paClass , paTyCon = paTyCon , paDataCon = paDataCon , paPRSel = paPRSel , preprTyCon = preprTyCon + , prClass = prClass , prTyCon = prTyCon , prDataCon = prDataCon , voidTyCon = voidTyCon @@ -308,9 +311,3 @@ externalClass :: Module -> FastString -> DsM Class externalClass mod fs = dsLookupClass =<< lookupOrig mod (mkClsOccFS fs) - --- | Like `externalClass`, but get the TyCon of of the class. -externalClassTyCon :: Module -> FastString -> DsM TyCon -externalClassTyCon mod fs = liftM classTyCon (externalClass mod fs) - - diff --git a/compiler/vectorise/Vectorise/Env.hs b/compiler/vectorise/Vectorise/Env.hs index 30f259b..70ed8c4 100644 --- a/compiler/vectorise/Vectorise/Env.hs +++ b/compiler/vectorise/Vectorise/Env.hs @@ -11,7 +11,8 @@ module Vectorise.Env ( initGlobalEnv, extendImportedVarsEnv, extendScalars, - setFamInstEnv, + setFamEnv, + extendFamEnv, extendTyConsEnv, extendDataConsEnv, extendPAFunsEnv, @@ -142,11 +143,16 @@ extendScalars vs genv -- | Set the list of type family instances in an environment. -setFamInstEnv :: FamInstEnv -> GlobalEnv -> GlobalEnv -setFamInstEnv l_fam_inst genv +setFamEnv :: FamInstEnv -> GlobalEnv -> GlobalEnv +setFamEnv l_fam_inst genv = genv { global_fam_inst_env = (g_fam_inst, l_fam_inst) } where (g_fam_inst, _) = global_fam_inst_env genv +extendFamEnv :: [FamInst] -> GlobalEnv -> GlobalEnv +extendFamEnv new genv + = genv { global_fam_inst_env = (g_fam_inst, extendFamInstEnvList l_fam_inst new) } + where (g_fam_inst, l_fam_inst) = global_fam_inst_env genv + -- | Extend the list of type constructors in an environment. extendTyConsEnv :: [(Name, TyCon)] -> GlobalEnv -> GlobalEnv diff --git a/compiler/vectorise/Vectorise/Monad/Base.hs b/compiler/vectorise/Vectorise/Monad/Base.hs index 98da3fe..c2c314f 100644 --- a/compiler/vectorise/Vectorise/Monad/Base.hs +++ b/compiler/vectorise/Vectorise/Monad/Base.hs @@ -77,7 +77,6 @@ maybeCantVectoriseM s d p Just x -> return x Nothing -> cantVectorise s d - -- Control -------------------------------------------------------------------- -- | Return some result saying we've failed. noV :: VM a diff --git a/compiler/vectorise/Vectorise/Type/Env.hs b/compiler/vectorise/Vectorise/Type/Env.hs index 99c1746..61a52bc 100644 --- a/compiler/vectorise/Vectorise/Type/Env.hs +++ b/compiler/vectorise/Vectorise/Type/Env.hs @@ -82,6 +82,13 @@ vectTypeEnv env let vect_tcs = filter (not . isClassTyCon) $ keep_tcs ++ new_tcs + reprs <- mapM tyConRepr vect_tcs + repr_tcs <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs + pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs + updGEnv $ extendFamEnv + $ map mkLocalFamInst + $ repr_tcs ++ pdata_tcs + -- Create PRepr and PData instances for the vectorised types. -- We get back the binds for the instance functions, -- and some new type constructors for the representation types. @@ -89,8 +96,6 @@ vectTypeEnv env do defTyConPAs (zipLazy vect_tcs dfuns') reprs <- mapM tyConRepr vect_tcs - repr_tcs <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs - pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs dfuns <- sequence $ zipWith5 buildTyConBindings diff --git a/compiler/vectorise/Vectorise/Type/PADict.hs b/compiler/vectorise/Vectorise/Type/PADict.hs index ed6264a..4c786cf 100644 --- a/compiler/vectorise/Vectorise/Type/PADict.hs +++ b/compiler/vectorise/Vectorise/Type/PADict.hs @@ -6,7 +6,6 @@ import Vectorise.Monad import Vectorise.Builtins import Vectorise.Type.Repr import Vectorise.Type.PRepr -import Vectorise.Type.PRDict import Vectorise.Utils import BasicTypes @@ -19,7 +18,8 @@ import TypeRep import Id import Var import Name -import Outputable +import FastString +-- import Outputable -- debug = False -- dtrace s x = if debug then pprTrace "Vectoris.Type.PADict" s x else x @@ -29,38 +29,52 @@ import Outputable buildPADict :: TyCon -- ^ tycon of the type being vectorised. -> TyCon -- ^ tycon of the type used for the vectorised representation. - -> TyCon -- + -> TyCon -- ^ PRepr instance tycon -> SumRepr -- ^ representation used for the type being vectorised. -> VM Var -- ^ name of the top-level dictionary function. buildPADict vect_tc prepr_tc arr_tc repr = polyAbstract tvs $ \args -> - case args of - (_:_) -> pprPanic "Vectorise.Type.PADict.buildPADict" (text "why do we need superclass dicts?") - [] -> do - -- TODO: I'm forcing args to [] because I'm not sure why we need them. - -- class PA has superclass (PR (PRepr a)) but we're not using - -- the superclass dictionary to build the PA dictionary. + do + -- The superclass dictionary is an argument if the tycon is polymorphic + let mk_super_ty = do + r <- mkPReprType inst_ty + pr_cls <- builtin prClass + return $ PredTy $ ClassP pr_cls [r] + super_tys <- sequence [mk_super_ty | not (null tvs)] + super_args <- mapM (newLocalVar (fsLit "pr")) super_tys + let args' = super_args ++ args + + -- it is constant otherwise + super_consts <- sequence [prDictOfPReprInstTyCon inst_ty prepr_tc [] + | null tvs] -- Get ids for each of the methods in the dictionary. - method_ids <- mapM (method args) paMethods + method_ids <- mapM (method args') paMethods -- Expression to build the dictionary. pa_dc <- builtin paDataCon - let dict = mkLams (tvs ++ args) + let dict = mkLams (tvs ++ args') $ mkConApp pa_dc - $ Type inst_ty : map (method_call args) method_ids + $ Type inst_ty + : map Var super_args ++ super_consts + -- the superclass dictionary is + -- either lambda-bound or + -- constant + ++ map (method_call args') method_ids -- Build the type of the dictionary function. - pa_tc <- builtin paTyCon - let Just pa_cls = tyConClass_maybe pa_tc - + pa_cls <- builtin paClass let dfun_ty = mkForAllTys tvs - $ mkFunTys (map varType args) (PredTy $ ClassP pa_cls [inst_ty]) + $ mkFunTys (map varType args') + (PredTy $ ClassP pa_cls [inst_ty]) -- Set the unfolding for the inliner. raw_dfun <- newExportedVar dfun_name dfun_ty - let dfun_unf = mkDFunUnfolding dfun_ty (map (DFunPolyArg . Var) method_ids) + let dfun_unf = mkDFunUnfolding dfun_ty + $ map (const $ DFunLamArg 0) super_args + ++ map DFunConstArg super_consts + ++ map (DFunPolyArg . Var) method_ids dfun = raw_dfun `setIdUnfolding` dfun_unf `setInlinePragma` dfunInlinePragma @@ -91,8 +105,7 @@ buildPADict vect_tc prepr_tc arr_tc repr paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)] -paMethods = [("dictPRepr", buildPRDict), - ("toPRepr", buildToPRepr), +paMethods = [("toPRepr", buildToPRepr), ("fromPRepr", buildFromPRepr), ("toArrPRepr", buildToArrPRepr), ("fromArrPRepr", buildFromArrPRepr)] diff --git a/compiler/vectorise/Vectorise/Type/PRDict.hs b/compiler/vectorise/Vectorise/Type/PRDict.hs deleted file mode 100644 index 1a55116..0000000 --- a/compiler/vectorise/Vectorise/Type/PRDict.hs +++ /dev/null @@ -1,56 +0,0 @@ - -module Vectorise.Type.PRDict - (buildPRDict) -where -import Vectorise.Utils -import Vectorise.Monad -import Vectorise.Builtins -import Vectorise.Type.Repr -import CoreSyn -import CoreUtils -import TyCon -import Type -import Coercion - - - -buildPRDict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr -buildPRDict vect_tc prepr_tc _ r - = do - dict <- sum_dict r - pr_co <- mkBuiltinCo prTyCon - let co = mkAppCoercion pr_co - . mkSymCoercion - $ mkTyConApp arg_co ty_args - return (mkCoerce co dict) - where - ty_args = mkTyVarTys (tyConTyVars vect_tc) - Just arg_co = tyConFamilyCoercion_maybe prepr_tc - - sum_dict EmptySum = prDFunOfTyCon =<< builtin voidTyCon - sum_dict (UnarySum r) = con_dict r - sum_dict (Sum { repr_sum_tc = sum_tc - , repr_con_tys = tys - , repr_cons = cons - }) - = do - dicts <- mapM con_dict cons - dfun <- prDFunOfTyCon sum_tc - return $ dfun `mkTyApps` tys `mkApps` dicts - - con_dict (ConRepr _ r) = prod_dict r - - prod_dict EmptyProd = prDFunOfTyCon =<< builtin voidTyCon - prod_dict (UnaryProd r) = comp_dict r - prod_dict (Prod { repr_tup_tc = tup_tc - , repr_comp_tys = tys - , repr_comps = comps }) - = do - dicts <- mapM comp_dict comps - dfun <- prDFunOfTyCon tup_tc - return $ dfun `mkTyApps` tys `mkApps` dicts - - comp_dict (Keep _ pr) = return pr - comp_dict (Wrap ty) = wrapPR ty - - diff --git a/compiler/vectorise/Vectorise/Type/Repr.hs b/compiler/vectorise/Vectorise/Type/Repr.hs index 40242ae..bb300ca 100644 --- a/compiler/vectorise/Vectorise/Type/Repr.hs +++ b/compiler/vectorise/Vectorise/Type/Repr.hs @@ -82,7 +82,7 @@ tyConRepr tc = sum_repr (tyConDataCons tc) where arity = length tys - comp_repr ty = liftM (Keep ty) (prDictOfType ty) + comp_repr ty = liftM (Keep ty) (prDictOfReprType ty) `orElseV` return (Wrap ty) sumReprType :: SumRepr -> VM Type diff --git a/compiler/vectorise/Vectorise/Utils.hs b/compiler/vectorise/Vectorise/Utils.hs index e701383..1a099e3 100644 --- a/compiler/vectorise/Vectorise/Utils.hs +++ b/compiler/vectorise/Vectorise/Utils.hs @@ -4,7 +4,6 @@ module Vectorise.Utils ( module Vectorise.Utils.Closure, module Vectorise.Utils.Hoisting, module Vectorise.Utils.PADict, - module Vectorise.Utils.PRDict, module Vectorise.Utils.Poly, -- * Annotated Exprs @@ -28,7 +27,6 @@ import Vectorise.Utils.Base import Vectorise.Utils.Closure import Vectorise.Utils.Hoisting import Vectorise.Utils.PADict -import Vectorise.Utils.PRDict import Vectorise.Utils.Poly import Vectorise.Monad import Vectorise.Builtins diff --git a/compiler/vectorise/Vectorise/Utils/PADict.hs b/compiler/vectorise/Vectorise/Utils/PADict.hs index 93f2297..329cb63 100644 --- a/compiler/vectorise/Vectorise/Utils/PADict.hs +++ b/compiler/vectorise/Vectorise/Utils/PADict.hs @@ -2,7 +2,9 @@ module Vectorise.Utils.PADict ( paDictArgType, paDictOfType, - paMethod + paMethod, + prDictOfReprType, + prDictOfPReprInstTyCon ) where import Vectorise.Monad @@ -42,7 +44,9 @@ paDictArgType tv = go (TyVarTy tv) (tyVarKind tv) go ty k | isLiftedTypeKind k - = liftM Just (mkBuiltinTyConApp paTyCon [ty]) + = do + pa_cls <- builtin paClass + return $ Just $ PredTy $ ClassP pa_cls [ty] go _ _ = return Nothing @@ -108,17 +112,36 @@ prDictOfPReprInst :: Type -> VM CoreExpr prDictOfPReprInst ty = do (prepr_tc, prepr_args) <- preprSynTyCon ty - case coreView (mkTyConApp prepr_tc prepr_args) of - Just rhs -> do - dict <- prDictOfReprType rhs - pr_co <- mkBuiltinCo prTyCon - let Just arg_co = tyConFamilyCoercion_maybe prepr_tc - let co = mkAppCoercion pr_co - $ mkSymCoercion - $ mkTyConApp arg_co prepr_args - return $ mkCoerce co dict - Nothing -> cantVectorise "Invalid PRepr type instance" - $ ppr ty + prDictOfPReprInstTyCon ty prepr_tc prepr_args + +-- | Given a type @ty@, its PRepr synonym tycon and its type arguments, +-- return the PR @PRepr ty@. Suppose we have: +-- +-- > type instance PRepr (T a1 ... an) = t +-- +-- which is internally translated into +-- +-- > type :R:PRepr a1 ... an = t +-- +-- and the corresponding coercion. Then, +-- +-- > prDictOfPReprInstTyCon (T a1 ... an) :R:PRepr u1 ... un = PR (T u1 ... un) +-- +-- Note that @ty@ is only used for error messages +-- +prDictOfPReprInstTyCon :: Type -> TyCon -> [Type] -> VM CoreExpr +prDictOfPReprInstTyCon ty prepr_tc prepr_args + | Just rhs <- coreView (mkTyConApp prepr_tc prepr_args) + = do + dict <- prDictOfReprType' rhs + pr_co <- mkBuiltinCo prTyCon + let Just arg_co = tyConFamilyCoercion_maybe prepr_tc + let co = mkAppCoercion pr_co + $ mkSymCoercion + $ mkTyConApp arg_co prepr_args + return $ mkCoerce co dict + + | otherwise = cantVectorise "Invalid PRepr type instance" (ppr ty) -- | Get the PR dictionary for a type. The argument must be a representation -- type. @@ -129,14 +152,13 @@ prDictOfReprType ty prepr <- builtin preprTyCon if tycon == prepr then do - [ty'] <- return tyargs - prDictOfPReprInst ty' + let [ty'] = tyargs + pa <- paDictOfType ty' + sel <- builtin paPRSel + return $ Var sel `App` Type ty' `App` pa else do -- a representation tycon must have a PR instance - dfun <- maybeCantVectoriseM - "No PR dictionary for type constructor" - (ppr tycon <+> text "in" <+> ppr ty) - $ lookupTyConPR tycon + dfun <- maybeV $ lookupTyConPR tycon prDFunApply dfun tyargs | otherwise @@ -153,6 +175,11 @@ prDictOfReprType ty prsel <- builtin paPRSel return $ Var prsel `mkApps` [Type ty, pa] +prDictOfReprType' :: Type -> VM CoreExpr +prDictOfReprType' ty = prDictOfReprType ty `orElseV` + cantVectorise "No PR dictionary for representation type" + (ppr ty) + -- | Apply a tycon's PR dfun to dictionary arguments (PR or PA) corresponding -- to the argument types. prDFunApply :: Var -> [Type] -> VM CoreExpr diff --git a/compiler/vectorise/Vectorise/Utils/PRDict.hs b/compiler/vectorise/Vectorise/Utils/PRDict.hs deleted file mode 100644 index a5d09df..0000000 --- a/compiler/vectorise/Vectorise/Utils/PRDict.hs +++ /dev/null @@ -1,43 +0,0 @@ - -module Vectorise.Utils.PRDict ( - prDictOfType, - wrapPR -) -where -import Vectorise.Monad -import Vectorise.Builtins -import Vectorise.Utils.Base -import Vectorise.Utils.PADict - -import CoreSyn -import Type -import TypeRep -import Control.Monad - - -prDictOfType :: Type -> VM CoreExpr -prDictOfType ty = prDictOfTyApp ty_fn ty_args - where - (ty_fn, ty_args) = splitAppTys ty - -prDictOfTyApp :: Type -> [Type] -> VM CoreExpr -prDictOfTyApp ty_fn ty_args - | Just ty_fn' <- coreView ty_fn = prDictOfTyApp ty_fn' ty_args -prDictOfTyApp (TyConApp tc _) ty_args - = do - dfun <- liftM Var $ maybeV (lookupTyConPR tc) - prDFunApply dfun ty_args -prDictOfTyApp _ _ = noV - -prDFunApply :: CoreExpr -> [Type] -> VM CoreExpr -prDFunApply dfun tys - = do - dicts <- mapM prDictOfType tys - return $ mkApps (mkTyApps dfun tys) dicts - -wrapPR :: Type -> VM CoreExpr -wrapPR ty - = do - pa_dict <- paDictOfType ty - pr_dfun <- prDFunOfTyCon =<< builtin wrapTyCon - return $ mkApps pr_dfun [Type ty, pa_dict] -- 1.7.10.4