ndpNames :: [Name]
ndpNames = [ parrayTyConName, paTyConName, preprTyConName, prTyConName
- , ndpCrossTyConName, ndpPlusTyConName, embedTyConName
+ , embedTyConName
, closureTyConName
, mkClosureName, applyClosureName
, mkClosurePName, applyClosurePName
gLA_EXTS = mkBaseModule FSLIT("GHC.Exts")
nDP_PARRAY = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.PArray")
+nDP_REPR = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.Repr")
nDP_UTILS = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.Utils")
nDP_CLOSURE = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.Closure")
nDP_INSTANCES = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.Instances")
paTyConName = tcQual nDP_PARRAY FSLIT("PA") paTyConKey
preprTyConName = tcQual nDP_PARRAY FSLIT("PRepr") preprTyConKey
prTyConName = clsQual nDP_PARRAY FSLIT("PR") prTyConKey
-ndpCrossTyConName = tcQual nDP_PARRAY FSLIT(":*:") ndpCrossTyConKey
-ndpPlusTyConName = tcQual nDP_PARRAY FSLIT(":+:") ndpPlusTyConKey
-embedTyConName = tcQual nDP_PARRAY FSLIT("Embed") embedTyConKey
+embedTyConName = tcQual nDP_REPR FSLIT("Embed") embedTyConKey
lengthPAName = varQual nDP_PARRAY FSLIT("lengthPA") lengthPAIdKey
replicatePAName = varQual nDP_PARRAY FSLIT("replicatePA") replicatePAIdKey
emptyPAName = varQual nDP_PARRAY FSLIT("emptyPA") emptyPAIdKey
paTyConKey = mkPreludeTyConUnique 137
preprTyConKey = mkPreludeTyConUnique 138
embedTyConKey = mkPreludeTyConUnique 139
-ndpCrossTyConKey = mkPreludeTyConUnique 140
-ndpPlusTyConKey = mkPreludeTyConUnique 141
-prTyConKey = mkPreludeTyConUnique 142
+prTyConKey = mkPreludeTyConUnique 140
---------------- Template Haskell -------------------
module VectBuiltIn (
- Builtins(..),
+ Builtins(..), sumTyCon, prodTyCon,
initBuiltins, initBuiltinTyCons, initBuiltinPAs, initBuiltinPRs
) where
import Var ( Var )
import Id ( mkSysLocal )
import Name ( Name )
-import OccName ( mkVarOccFS )
+import OccName ( mkVarOccFS, mkOccNameFS, tcName )
import TypeRep ( funTyCon )
import TysPrim ( intPrimTy )
import BasicTypes ( Boxity(..) )
import FastString
+import Outputable
+import Data.Array
import Control.Monad ( liftM, zipWithM )
+mAX_NDP_PROD :: Int
+mAX_NDP_PROD = 3
+
+mAX_NDP_SUM :: Int
+mAX_NDP_SUM = 3
+
data Builtins = Builtins {
parrayTyCon :: TyCon
, paTyCon :: TyCon
, prDataCon :: DataCon
, embedTyCon :: TyCon
, embedDataCon :: DataCon
- , crossTyCon :: TyCon
- , crossDataCon :: DataCon
- , plusTyCon :: TyCon
- , leftDataCon :: DataCon
- , rightDataCon :: DataCon
+ , sumTyCons :: Array Int TyCon
, closureTyCon :: TyCon
, mkClosureVar :: Var
, applyClosureVar :: Var
, liftingContext :: Var
}
+sumTyCon :: Int -> Builtins -> TyCon
+sumTyCon n bi
+ | n >= 2 && n <= mAX_NDP_SUM = sumTyCons bi ! n
+ | otherwise = pprPanic "sumTyCon" (ppr n)
+
+prodTyCon :: Int -> Builtins -> TyCon
+prodTyCon n bi
+ | n >= 2 && n <= mAX_NDP_PROD = tupleTyCon Boxed n
+ | otherwise = pprPanic "prodTyCon" (ppr n)
+
+
initBuiltins :: DsM Builtins
initBuiltins
= do
let [prDataCon] = tyConDataCons prTyCon
embedTyCon <- dsLookupTyCon embedTyConName
let [embedDataCon] = tyConDataCons embedTyCon
- crossTyCon <- dsLookupTyCon ndpCrossTyConName
- let [crossDataCon] = tyConDataCons crossTyCon
- plusTyCon <- dsLookupTyCon ndpPlusTyConName
- let [leftDataCon, rightDataCon] = tyConDataCons plusTyCon
closureTyCon <- dsLookupTyCon closureTyConName
+ sum_tcs <- mapM (lookupExternalTyCon nDP_REPR)
+ [mkFastString ("Sum" ++ show i) | i <- [2..mAX_NDP_SUM]]
+
+ let sumTyCons = listArray (2, mAX_NDP_SUM) sum_tcs
+
mkClosureVar <- dsLookupGlobalId mkClosureName
applyClosureVar <- dsLookupGlobalId applyClosureName
mkClosurePVar <- dsLookupGlobalId mkClosurePName
, prDataCon = prDataCon
, embedTyCon = embedTyCon
, embedDataCon = embedDataCon
- , crossTyCon = crossTyCon
- , crossDataCon = crossDataCon
- , plusTyCon = plusTyCon
- , leftDataCon = leftDataCon
- , rightDataCon = rightDataCon
+ , sumTyCons = sumTyCons
, closureTyCon = closureTyCon
, mkClosureVar = mkClosureVar
, applyClosureVar = applyClosureVar
builtinPAs :: [(Name, Module, FastString)]
builtinPAs = [
mk closureTyConName nDP_CLOSURE FSLIT("dPA_Clo")
- , mk unitTyConName nDP_PARRAY FSLIT("dPA_Unit")
+ , mk unitTyConName nDP_INSTANCES FSLIT("dPA_Unit")
, mk intTyConName nDP_INSTANCES FSLIT("dPA_Int")
]
nDP_INSTANCES
(mkFastString $ "dPA_" ++ show n)
-initBuiltinPRs = initBuiltinDicts builtinPRs
+initBuiltinPRs = initBuiltinDicts . builtinPRs
-builtinPRs :: [(Name, Module, FastString)]
-builtinPRs = [
- mk (tyConName unitTyCon) nDP_PARRAY FSLIT("dPR_Unit")
- , mk ndpCrossTyConName nDP_PARRAY FSLIT("dPR_Cross")
- , mk ndpPlusTyConName nDP_PARRAY FSLIT("dPR_Plus")
- , mk embedTyConName nDP_PARRAY FSLIT("dPR_Embed")
- , mk closureTyConName nDP_CLOSURE FSLIT("dPR_Clo")
+builtinPRs :: Builtins -> [(Name, Module, FastString)]
+builtinPRs bi =
+ [
+ mk (tyConName unitTyCon) nDP_REPR FSLIT("dPR_Unit")
+ , mk embedTyConName nDP_REPR FSLIT("dPR_Embed")
+ , mk closureTyConName nDP_CLOSURE FSLIT("dPR_Clo")
- -- temporary
- , mk intTyConName nDP_INSTANCES FSLIT("dPR_Int")
- ]
+ -- temporary
+ , mk intTyConName nDP_INSTANCES FSLIT("dPR_Int")
+ ]
+
+ ++ map mk_sum [2..mAX_NDP_SUM]
+ ++ map mk_prod [2..mAX_NDP_PROD]
where
mk name mod fs = (name, mod, fs)
+ mk_sum n = (tyConName $ sumTyCon n bi, nDP_REPR,
+ mkFastString ("dPR_Sum" ++ show n))
+
+ mk_prod n = (tyConName $ prodTyCon n bi, nDP_REPR,
+ mkFastString ("dPR_" ++ show n))
+
lookupExternalVar :: Module -> FastString -> DsM Var
lookupExternalVar mod fs
= dsLookupGlobalId =<< lookupOrig mod (mkVarOccFS fs)
+lookupExternalTyCon :: Module -> FastString -> DsM TyCon
+lookupExternalTyCon mod fs
+ = dsLookupTyCon =<< lookupOrig mod (mkOccNameFS tcName fs)
+
unitTyConName = tyConName unitTyCon
import Outputable
import FastString
+import Data.List ( zipWith4 )
import Control.Monad ( liftM, liftM2, zipWithM_ )
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
| otherwise = pprPanic s (ppr ty)
-splitCrossTy :: Type -> (Type, Type)
-splitCrossTy = splitBinTy "splitCrossTy" ndpCrossTyConName
+splitFixedTyConApp :: TyCon -> Type -> [Type]
+splitFixedTyConApp tc ty
+ | Just (tc', tys) <- splitTyConApp_maybe ty
+ , tc == tc'
+ = tys
-splitPlusTy :: Type -> (Type, Type)
-splitPlusTy = splitBinTy "splitSumTy" ndpPlusTyConName
+ | otherwise = pprPanic "splitFixedTyConApp" (ppr tc <+> ppr ty)
splitEmbedTy :: Type -> Type
splitEmbedTy = splitUnTy "splitEmbedTy" embedTyConName
mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
mkPRepr :: [[Type]] -> VM Type
-mkPRepr [] = return unitTy
mkPRepr tys
= do
- embed <- builtin embedTyCon
- cross <- builtin crossTyCon
- plus <- builtin plusTyCon
+ embed_tc <- builtin embedTyCon
+ sum_tcs <- builtins sumTyCon
+ prod_tcs <- builtins prodTyCon
- let mk_embed ty = mkTyConApp embed [ty]
- mk_cross ty1 ty2 = mkTyConApp cross [ty1, ty2]
- mk_plus ty1 ty2 = mkTyConApp plus [ty1, ty2]
+ let mk_sum [] = unitTy
+ mk_sum [ty] = ty
+ mk_sum tys = mkTyConApp (sum_tcs $ length tys) tys
- mk_tup [] = unitTy
- mk_tup tys = foldr1 mk_cross tys
+ mk_prod [] = unitTy
+ mk_prod [ty] = ty
+ mk_prod tys = mkTyConApp (prod_tcs $ length tys) tys
- mk_sum [] = unitTy
- mk_sum tys = foldr1 mk_plus tys
+ mk_embed ty = mkTyConApp embed_tc [ty]
return . mk_sum
- . map (mk_tup . map mk_embed)
+ . map (mk_prod . map mk_embed)
$ tys
mkToPRepr :: [[CoreExpr]] -> VM ([CoreExpr], Type)
= do
embed_tc <- builtin embedTyCon
embed_dc <- builtin embedDataCon
- cross_tc <- builtin crossTyCon
- cross_dc <- builtin crossDataCon
- plus_tc <- builtin plusTyCon
- left_dc <- builtin leftDataCon
- right_dc <- builtin rightDataCon
-
- let mk_embed expr
- = (mkConApp embed_dc [Type ty, expr],
- mkTyConApp embed_tc [ty])
- where ty = exprType expr
-
- mk_cross (expr1, ty1) (expr2, ty2)
- = (mkConApp cross_dc [Type ty1, Type ty2, expr1, expr2],
- mkTyConApp cross_tc [ty1, ty2])
-
- mk_tup [] = (Var unitDataConId, unitTy)
- mk_tup es = foldr1 mk_cross es
+ sum_tcs <- builtins sumTyCon
+ prod_tcs <- builtins prodTyCon
- mk_sum [] = ([Var unitDataConId], unitTy)
+ let mk_sum [] = ([Var unitDataConId], unitTy)
mk_sum [(expr, ty)] = ([expr], ty)
- mk_sum ((expr, lty) : es)
- = let (alts, rty) = mk_sum es
- in
- (mkConApp left_dc [Type lty, Type rty, expr]
- : [mkConApp right_dc [Type lty, Type rty, alt] | alt <- alts],
- mkTyConApp plus_tc [lty, rty])
+ mk_sum es = (zipWith mk_alt (tyConDataCons sum_tc) exprs,
+ mkTyConApp sum_tc tys)
+ where
+ (exprs, tys) = unzip es
+ sum_tc = sum_tcs (length es)
+ mk_alt dc expr = mkConApp dc (map Type tys ++ [expr])
+
+ mk_prod [] = (Var unitDataConId, unitTy)
+ mk_prod [(expr, ty)] = (expr, ty)
+ mk_prod es = (mkConApp prod_dc (map Type tys ++ exprs),
+ mkTyConApp prod_tc tys)
+ where
+ (exprs, tys) = unzip es
+ prod_tc = prod_tcs (length es)
+ [prod_dc] = tyConDataCons prod_tc
+
+ mk_embed expr = (mkConApp embed_dc [Type ty, expr],
+ mkTyConApp embed_tc [ty])
+ where ty = exprType expr
- return . mk_sum $ map (mk_tup . map mk_embed) ess
+ return . mk_sum $ map (mk_prod . map mk_embed) ess
mkFromPRepr :: CoreExpr -> Type -> [([Var], CoreExpr)] -> VM CoreExpr
mkFromPRepr scrut res_ty alts
= do
embed_dc <- builtin embedDataCon
- cross_dc <- builtin crossDataCon
- left_dc <- builtin leftDataCon
- right_dc <- builtin rightDataCon
- pa_tc <- builtin paTyCon
+ sum_tcs <- builtins sumTyCon
+ prod_tcs <- builtins prodTyCon
- let un_embed expr ty var res
- = Case expr (mkWildId ty) res_ty
- [(DataAlt embed_dc, [var], res)]
-
- un_cross expr ty var1 var2 res
- = Case expr (mkWildId ty) res_ty
- [(DataAlt cross_dc, [var1, var2], res)]
-
- un_tup expr ty [] res = return res
- un_tup expr ty [var] res = return $ un_embed expr ty var res
- un_tup expr ty (var : vars) res
+ let un_sum expr ty [(vars, res)] = un_prod expr ty vars res
+ un_sum expr ty bs
= do
- lv <- newLocalVar FSLIT("x") lty
- rv <- newLocalVar FSLIT("y") rty
- liftM (un_cross expr ty lv rv
- . un_embed (Var lv) lty var)
- (un_tup (Var rv) rty vars res)
+ ps <- mapM (newLocalVar FSLIT("p")) tys
+ bodies <- sequence
+ $ zipWith4 un_prod (map Var ps) tys vars rs
+ return . Case expr (mkWildId ty) res_ty
+ $ zipWith3 mk_alt sum_dcs ps bodies
where
- (lty, rty) = splitCrossTy ty
+ (vars, rs) = unzip bs
+ tys = splitFixedTyConApp sum_tc ty
+ sum_tc = sum_tcs $ length bs
+ sum_dcs = tyConDataCons sum_tc
- un_plus expr ty var1 var2 res1 res2
- = Case expr (mkWildId ty) res_ty
- [(DataAlt left_dc, [var1], res1),
- (DataAlt right_dc, [var2], res2)]
+ mk_alt dc p body = (DataAlt dc, [p], body)
- un_sum expr ty [(vars, res)] = un_tup expr ty vars res
- un_sum expr ty ((vars, res) : alts)
+ un_prod expr ty [] r = return r
+ un_prod expr ty [var] r = return $ un_embed expr ty var r
+ un_prod expr ty vars r
= do
- lv <- newLocalVar FSLIT("l") lty
- rv <- newLocalVar FSLIT("r") rty
- liftM2 (un_plus expr ty lv rv)
- (un_tup (Var lv) lty vars res)
- (un_sum (Var rv) rty alts)
+ xs <- mapM (newLocalVar FSLIT("x")) tys
+ let body = foldr (\(e,t,v) r -> un_embed e t v r) r
+ $ zip3 (map Var xs) tys vars
+ return $ Case expr (mkWildId ty) res_ty
+ [(DataAlt prod_dc, xs, body)]
where
- (lty, rty) = splitPlusTy ty
+ tys = splitFixedTyConApp prod_tc ty
+ prod_tc = prod_tcs $ length vars
+ [prod_dc] = tyConDataCons prod_tc
+
+ un_embed expr ty var r
+ = Case expr (mkWildId ty) res_ty
+ [(DataAlt embed_dc, [var], r)]
un_sum scrut (exprType scrut) alts