From a52f14894e48d47e62b5b33f7d7f4b3f2cc88a79 Mon Sep 17 00:00:00 2001 From: Roman Leshchinskiy Date: Thu, 23 Aug 2007 06:09:45 +0000 Subject: [PATCH] Use n-ary sums and products for NDP's generic representation Originally, we wanted to only use binary ones, at least initially. But this would require a lot of fiddling with selectors when converting to/from generic array representations. This is both inefficient and hard to implement. Instead, we will limit the arity of our sum/product representations to, say, 16 (it's 3 at the moment) and initially refuse to vectorise programs for which this is not sufficient. This allows us to implement everything in the library. Later, we can implement the necessary splitting. --- compiler/prelude/PrelNames.lhs | 11 ++- compiler/vectorise/VectBuiltIn.hs | 80 +++++++++++++-------- compiler/vectorise/VectMonad.hs | 9 ++- compiler/vectorise/VectUtils.hs | 144 ++++++++++++++++++------------------- 4 files changed, 132 insertions(+), 112 deletions(-) diff --git a/compiler/prelude/PrelNames.lhs b/compiler/prelude/PrelNames.lhs index 8de554d..9839290 100644 --- a/compiler/prelude/PrelNames.lhs +++ b/compiler/prelude/PrelNames.lhs @@ -218,7 +218,7 @@ genericTyConNames = [crossTyConName, plusTyConName, genUnitTyConName] ndpNames :: [Name] ndpNames = [ parrayTyConName, paTyConName, preprTyConName, prTyConName - , ndpCrossTyConName, ndpPlusTyConName, embedTyConName + , embedTyConName , closureTyConName , mkClosureName, applyClosureName , mkClosurePName, applyClosurePName @@ -276,6 +276,7 @@ rANDOM = mkBaseModule FSLIT("System.Random") gLA_EXTS = mkBaseModule FSLIT("GHC.Exts") nDP_PARRAY = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.PArray") +nDP_REPR = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.Repr") nDP_UTILS = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.Utils") nDP_CLOSURE = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.Closure") nDP_INSTANCES = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.Instances") @@ -697,9 +698,7 @@ parrayTyConName
= tcQual nDP_PARRAY FSLIT("PArray") parrayTyConKey paTyConName = tcQual nDP_PARRAY FSLIT("PA") paTyConKey preprTyConName = tcQual nDP_PARRAY FSLIT("PRepr") preprTyConKey prTyConName = clsQual nDP_PARRAY FSLIT("PR") prTyConKey -ndpCrossTyConName = tcQual nDP_PARRAY FSLIT(":*:") ndpCrossTyConKey -ndpPlusTyConName = tcQual nDP_PARRAY FSLIT(":+:") ndpPlusTyConKey -embedTyConName = tcQual nDP_PARRAY FSLIT("Embed") embedTyConKey +embedTyConName = tcQual nDP_REPR FSLIT("Embed") embedTyConKey lengthPAName = varQual nDP_PARRAY FSLIT("lengthPA") lengthPAIdKey replicatePAName = varQual nDP_PARRAY FSLIT("replicatePA") replicatePAIdKey emptyPAName = varQual nDP_PARRAY FSLIT("emptyPA") emptyPAIdKey @@ -895,9 +894,7 @@ closureTyConKey = mkPreludeTyConUnique 136 paTyConKey = mkPreludeTyConUnique 137 preprTyConKey = mkPreludeTyConUnique 138 embedTyConKey = mkPreludeTyConUnique 139 -ndpCrossTyConKey = mkPreludeTyConUnique 140 -ndpPlusTyConKey = mkPreludeTyConUnique 141 -prTyConKey = mkPreludeTyConUnique 142 +prTyConKey = mkPreludeTyConUnique 140 ---------------- Template Haskell ------------------- diff --git a/compiler/vectorise/VectBuiltIn.hs b/compiler/vectorise/VectBuiltIn.hs index e6c65ac..0afef5b 100644 --- a/compiler/vectorise/VectBuiltIn.hs +++ b/compiler/vectorise/VectBuiltIn.hs @@ -1,5 +1,5 @@ module VectBuiltIn ( - Builtins(..), + Builtins(..), sumTyCon, prodTyCon, initBuiltins, initBuiltinTyCons, initBuiltinPAs, initBuiltinPRs ) where @@ -14,7 +14,7 @@ import TyCon ( TyCon, tyConName, tyConDataCons ) import Var ( Var ) import Id ( mkSysLocal ) import Name ( Name ) -import OccName ( mkVarOccFS ) +import OccName ( mkVarOccFS, mkOccNameFS, tcName ) import TypeRep ( funTyCon ) import TysPrim ( intPrimTy ) @@ -23,9 +23,17 @@ import PrelNames import BasicTypes ( Boxity(..) 
) import FastString +import Outputable +import Data.Array import Control.Monad ( liftM, zipWithM ) +mAX_NDP_PROD :: Int +mAX_NDP_PROD = 3 + +mAX_NDP_SUM :: Int +mAX_NDP_SUM = 3 + data Builtins = Builtins { parrayTyCon :: TyCon , paTyCon :: TyCon @@ -35,11 +43,7 @@ data Builtins = Builtins { , prDataCon :: DataCon , embedTyCon :: TyCon , embedDataCon :: DataCon - , crossTyCon :: TyCon - , crossDataCon :: DataCon - , plusTyCon :: TyCon - , leftDataCon :: DataCon - , rightDataCon :: DataCon + , sumTyCons :: Array Int TyCon , closureTyCon :: TyCon , mkClosureVar :: Var , applyClosureVar :: Var @@ -54,6 +58,17 @@ data Builtins = Builtins { , liftingContext :: Var } +sumTyCon :: Int -> Builtins -> TyCon +sumTyCon n bi + | n >= 2 && n <= mAX_NDP_SUM = sumTyCons bi ! n + | otherwise = pprPanic "sumTyCon" (ppr n) + +prodTyCon :: Int -> Builtins -> TyCon +prodTyCon n bi + | n >= 2 && n <= mAX_NDP_PROD = tupleTyCon Boxed n + | otherwise = pprPanic "prodTyCon" (ppr n) + + initBuiltins :: DsM Builtins initBuiltins = do @@ -65,12 +80,13 @@ initBuiltins let [prDataCon] = tyConDataCons prTyCon embedTyCon <- dsLookupTyCon embedTyConName let [embedDataCon] = tyConDataCons embedTyCon - crossTyCon <- dsLookupTyCon ndpCrossTyConName - let [crossDataCon] = tyConDataCons crossTyCon - plusTyCon <- dsLookupTyCon ndpPlusTyConName - let [leftDataCon, rightDataCon] = tyConDataCons plusTyCon closureTyCon <- dsLookupTyCon closureTyConName + sum_tcs <- mapM (lookupExternalTyCon nDP_REPR) + [mkFastString ("Sum" ++ show i) | i <- [2..mAX_NDP_SUM]] + + let sumTyCons = listArray (2, mAX_NDP_SUM) sum_tcs + mkClosureVar <- dsLookupGlobalId mkClosureName applyClosureVar <- dsLookupGlobalId applyClosureName mkClosurePVar <- dsLookupGlobalId mkClosurePName @@ -94,11 +110,7 @@ initBuiltins , prDataCon = prDataCon , embedTyCon = embedTyCon , embedDataCon = embedDataCon - , crossTyCon = crossTyCon - , crossDataCon = crossDataCon - , plusTyCon = plusTyCon - , leftDataCon = leftDataCon - , rightDataCon = 
rightDataCon + , sumTyCons = sumTyCons , closureTyCon = closureTyCon , mkClosureVar = mkClosureVar , applyClosureVar = applyClosureVar @@ -137,7 +149,7 @@ initBuiltinPAs = initBuiltinDicts builtinPAs builtinPAs :: [(Name, Module, FastString)] builtinPAs = [ mk closureTyConName nDP_CLOSURE FSLIT("dPA_Clo") - , mk unitTyConName nDP_PARRAY FSLIT("dPA_Unit") + , mk unitTyConName nDP_INSTANCES FSLIT("dPA_Unit") , mk intTyConName nDP_INSTANCES FSLIT("dPA_Int") ] @@ -150,25 +162,37 @@ builtinPAs = [ nDP_INSTANCES (mkFastString $ "dPA_" ++ show n) -initBuiltinPRs = initBuiltinDicts builtinPRs +initBuiltinPRs = initBuiltinDicts . builtinPRs -builtinPRs :: [(Name, Module, FastString)] -builtinPRs = [ - mk (tyConName unitTyCon) nDP_PARRAY FSLIT("dPR_Unit") - , mk ndpCrossTyConName nDP_PARRAY FSLIT("dPR_Cross") - , mk ndpPlusTyConName nDP_PARRAY FSLIT("dPR_Plus") - , mk embedTyConName nDP_PARRAY FSLIT("dPR_Embed") - , mk closureTyConName nDP_CLOSURE FSLIT("dPR_Clo") +builtinPRs :: Builtins -> [(Name, Module, FastString)] +builtinPRs bi = + [ + mk (tyConName unitTyCon) nDP_REPR FSLIT("dPR_Unit") + , mk embedTyConName nDP_REPR FSLIT("dPR_Embed") + , mk closureTyConName nDP_CLOSURE FSLIT("dPR_Clo") - -- temporary - , mk intTyConName nDP_INSTANCES FSLIT("dPR_Int") - ] + -- temporary + , mk intTyConName nDP_INSTANCES FSLIT("dPR_Int") + ] + + ++ map mk_sum [2..mAX_NDP_SUM] + ++ map mk_prod [2..mAX_NDP_PROD] where mk name mod fs = (name, mod, fs) + mk_sum n = (tyConName $ sumTyCon n bi, nDP_REPR, + mkFastString ("dPR_Sum" ++ show n)) + + mk_prod n = (tyConName $ prodTyCon n bi, nDP_REPR, + mkFastString ("dPR_" ++ show n)) + lookupExternalVar :: Module -> FastString -> DsM Var lookupExternalVar mod fs = dsLookupGlobalId =<< lookupOrig mod (mkVarOccFS fs) +lookupExternalTyCon :: Module -> FastString -> DsM TyCon +lookupExternalTyCon mod fs + = dsLookupTyCon =<< lookupOrig mod (mkOccNameFS tcName fs) + unitTyConName = tyConName unitTyCon diff --git a/compiler/vectorise/VectMonad.hs 
b/compiler/vectorise/VectMonad.hs index 6bc2f4d..320d192 100644 --- a/compiler/vectorise/VectMonad.hs +++ b/compiler/vectorise/VectMonad.hs @@ -7,8 +7,8 @@ module VectMonad ( cloneName, cloneId, newExportedVar, newLocalVar, newDummyVar, newTyVar, - Builtins(..), - builtin, + Builtins(..), sumTyCon, prodTyCon, + builtin, builtins, GlobalEnv(..), setFamInstEnv, @@ -240,6 +240,9 @@ liftDs p = VM $ \bi genv lenv -> do { x <- p; return (Yes genv lenv x) } builtin :: (Builtins -> a) -> VM a builtin f = VM $ \bi genv lenv -> return (Yes genv lenv (f bi)) +builtins :: (a -> Builtins -> b) -> VM (a -> b) +builtins f = VM $ \bi genv lenv -> return (Yes genv lenv (`f` bi)) + readGEnv :: (GlobalEnv -> a) -> VM a readGEnv f = VM $ \bi genv lenv -> return (Yes genv lenv (f genv)) @@ -454,7 +457,7 @@ initV hsc_env guts info p builtins <- initBuiltins builtin_tycons <- initBuiltinTyCons builtin_pas <- initBuiltinPAs - builtin_prs <- initBuiltinPRs + builtin_prs <- initBuiltinPRs builtins eps <- ioToIOEnv $ hscEPS hsc_env let famInstEnvs = (eps_fam_inst_env eps, mg_fam_inst_env guts) diff --git a/compiler/vectorise/VectUtils.hs b/compiler/vectorise/VectUtils.hs index d1c1dab..0f101bd 100644 --- a/compiler/vectorise/VectUtils.hs +++ b/compiler/vectorise/VectUtils.hs @@ -39,6 +39,7 @@ import BasicTypes ( Boxity(..) 
) import Outputable import FastString +import Data.List ( zipWith4 ) import Control.Monad ( liftM, liftM2, zipWithM_ ) collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type]) @@ -82,11 +83,13 @@ splitBinTy s name ty | otherwise = pprPanic s (ppr ty) -splitCrossTy :: Type -> (Type, Type) -splitCrossTy = splitBinTy "splitCrossTy" ndpCrossTyConName +splitFixedTyConApp :: TyCon -> Type -> [Type] +splitFixedTyConApp tc ty + | Just (tc', tys) <- splitTyConApp_maybe ty + , tc == tc' + = tys -splitPlusTy :: Type -> (Type, Type) -splitPlusTy = splitBinTy "splitSumTy" ndpPlusTyConName + | otherwise = pprPanic "splitFixedTyConApp" (ppr tc <+> ppr ty) splitEmbedTy :: Type -> Type splitEmbedTy = splitUnTy "splitEmbedTy" embedTyConName @@ -123,25 +126,24 @@ mkBuiltinTyConApps1 get_tc dft tys mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2] mkPRepr :: [[Type]] -> VM Type -mkPRepr [] = return unitTy mkPRepr tys = do - embed <- builtin embedTyCon - cross <- builtin crossTyCon - plus <- builtin plusTyCon + embed_tc <- builtin embedTyCon + sum_tcs <- builtins sumTyCon + prod_tcs <- builtins prodTyCon - let mk_embed ty = mkTyConApp embed [ty] - mk_cross ty1 ty2 = mkTyConApp cross [ty1, ty2] - mk_plus ty1 ty2 = mkTyConApp plus [ty1, ty2] + let mk_sum [] = unitTy + mk_sum [ty] = ty + mk_sum tys = mkTyConApp (sum_tcs $ length tys) tys - mk_tup [] = unitTy - mk_tup tys = foldr1 mk_cross tys + mk_prod [] = unitTy + mk_prod [ty] = ty + mk_prod tys = mkTyConApp (prod_tcs $ length tys) tys - mk_sum [] = unitTy - mk_sum tys = foldr1 mk_plus tys + mk_embed ty = mkTyConApp embed_tc [ty] return . mk_sum - . map (mk_tup . map mk_embed) + . map (mk_prod . 
map mk_embed) $ tys mkToPRepr :: [[CoreExpr]] -> VM ([CoreExpr], Type) @@ -149,79 +151,73 @@ mkToPRepr ess = do embed_tc <- builtin embedTyCon embed_dc <- builtin embedDataCon - cross_tc <- builtin crossTyCon - cross_dc <- builtin crossDataCon - plus_tc <- builtin plusTyCon - left_dc <- builtin leftDataCon - right_dc <- builtin rightDataCon - - let mk_embed expr - = (mkConApp embed_dc [Type ty, expr], - mkTyConApp embed_tc [ty]) - where ty = exprType expr - - mk_cross (expr1, ty1) (expr2, ty2) - = (mkConApp cross_dc [Type ty1, Type ty2, expr1, expr2], - mkTyConApp cross_tc [ty1, ty2]) - - mk_tup [] = (Var unitDataConId, unitTy) - mk_tup es = foldr1 mk_cross es + sum_tcs <- builtins sumTyCon + prod_tcs <- builtins prodTyCon - mk_sum [] = ([Var unitDataConId], unitTy) + let mk_sum [] = ([Var unitDataConId], unitTy) mk_sum [(expr, ty)] = ([expr], ty) - mk_sum ((expr, lty) : es) - = let (alts, rty) = mk_sum es - in - (mkConApp left_dc [Type lty, Type rty, expr] - : [mkConApp right_dc [Type lty, Type rty, alt] | alt <- alts], - mkTyConApp plus_tc [lty, rty]) + mk_sum es = (zipWith mk_alt (tyConDataCons sum_tc) exprs, + mkTyConApp sum_tc tys) + where + (exprs, tys) = unzip es + sum_tc = sum_tcs (length es) + mk_alt dc expr = mkConApp dc (map Type tys ++ [expr]) + + mk_prod [] = (Var unitDataConId, unitTy) + mk_prod [(expr, ty)] = (expr, ty) + mk_prod es = (mkConApp prod_dc (map Type tys ++ exprs), + mkTyConApp prod_tc tys) + where + (exprs, tys) = unzip es + prod_tc = prod_tcs (length es) + [prod_dc] = tyConDataCons prod_tc + + mk_embed expr = (mkConApp embed_dc [Type ty, expr], + mkTyConApp embed_tc [ty]) + where ty = exprType expr - return . mk_sum $ map (mk_tup . map mk_embed) ess + return . mk_sum $ map (mk_prod . 
map mk_embed) ess mkFromPRepr :: CoreExpr -> Type -> [([Var], CoreExpr)] -> VM CoreExpr mkFromPRepr scrut res_ty alts = do embed_dc <- builtin embedDataCon - cross_dc <- builtin crossDataCon - left_dc <- builtin leftDataCon - right_dc <- builtin rightDataCon - pa_tc <- builtin paTyCon + sum_tcs <- builtins sumTyCon + prod_tcs <- builtins prodTyCon - let un_embed expr ty var res - = Case expr (mkWildId ty) res_ty - [(DataAlt embed_dc, [var], res)] - - un_cross expr ty var1 var2 res - = Case expr (mkWildId ty) res_ty - [(DataAlt cross_dc, [var1, var2], res)] - - un_tup expr ty [] res = return res - un_tup expr ty [var] res = return $ un_embed expr ty var res - un_tup expr ty (var : vars) res + let un_sum expr ty [(vars, res)] = un_prod expr ty vars res + un_sum expr ty bs = do - lv <- newLocalVar FSLIT("x") lty - rv <- newLocalVar FSLIT("y") rty - liftM (un_cross expr ty lv rv - . un_embed (Var lv) lty var) - (un_tup (Var rv) rty vars res) + ps <- mapM (newLocalVar FSLIT("p")) tys + bodies <- sequence + $ zipWith4 un_prod (map Var ps) tys vars rs + return . 
Case expr (mkWildId ty) res_ty + $ zipWith3 mk_alt sum_dcs ps bodies where - (lty, rty) = splitCrossTy ty + (vars, rs) = unzip bs + tys = splitFixedTyConApp sum_tc ty + sum_tc = sum_tcs $ length bs + sum_dcs = tyConDataCons sum_tc - un_plus expr ty var1 var2 res1 res2 - = Case expr (mkWildId ty) res_ty - [(DataAlt left_dc, [var1], res1), - (DataAlt right_dc, [var2], res2)] + mk_alt dc p body = (DataAlt dc, [p], body) - un_sum expr ty [(vars, res)] = un_tup expr ty vars res - un_sum expr ty ((vars, res) : alts) + un_prod expr ty [] r = return r + un_prod expr ty [var] r = return $ un_embed expr ty var r + un_prod expr ty vars r = do - lv <- newLocalVar FSLIT("l") lty - rv <- newLocalVar FSLIT("r") rty - liftM2 (un_plus expr ty lv rv) - (un_tup (Var lv) lty vars res) - (un_sum (Var rv) rty alts) + xs <- mapM (newLocalVar FSLIT("x")) tys + let body = foldr (\(e,t,v) r -> un_embed e t v r) r + $ zip3 (map Var xs) tys vars + return $ Case expr (mkWildId ty) res_ty + [(DataAlt prod_dc, xs, body)] where - (lty, rty) = splitPlusTy ty + tys = splitFixedTyConApp prod_tc ty + prod_tc = prod_tcs $ length vars + [prod_dc] = tyConDataCons prod_tc + + un_embed expr ty var r + = Case expr (mkWildId ty) res_ty + [(DataAlt embed_dc, [var], r)] un_sum scrut (exprType scrut) alts -- 1.7.10.4