2 module Vectorise.Utils (
3 module Vectorise.Utils.Base,
4 module Vectorise.Utils.Closure,
5 module Vectorise.Utils.Hoisting,
6 module Vectorise.Utils.PADict,
7 module Vectorise.Utils.Poly,
11 collectAnnTypeBinders,
16 replicatePD, emptyPD, packByTagPD,
20 zipScalars, scalarClosure,
26 import Vectorise.Utils.Base
27 import Vectorise.Utils.Closure
28 import Vectorise.Utils.Hoisting
29 import Vectorise.Utils.PADict
30 import Vectorise.Utils.Poly
31 import Vectorise.Monad
32 import Vectorise.Builtins
39 -- Annotated Exprs ------------------------------------------------------------
-- | Split an annotated expression into its head and the type arguments it is
--   applied to.
--
--   'go' descends through the function side of every 'AnnApp' whose argument
--   is an 'AnnType', consing that type onto the accumulator (which starts as
--   @[]@). The outermost application is visited first, so the accumulated
--   types end up in left-to-right (source) order.
--
--   NOTE(review): the @where@ line and the fall-through equation of 'go' are
--   not visible in this view; presumably 'go' returns @(e, tys)@ at the first
--   node that is not a type application. Confirm.
40 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
41 collectAnnTypeArgs expr = go expr []
43 go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
-- | Peel off the leading lambdas whose binders are type variables.
--
--   Returns the stripped binders, outermost first, paired with the remaining
--   body. Peeling stops at the first lambda whose binder fails 'isTyVar', or
--   at any node that is not an 'AnnLam'.
--
--   NOTE(review): the @where@ line that introduces 'go' is not visible in
--   this view.
46 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
47 collectAnnTypeBinders expr = go [] expr
-- Binders are pushed onto the front of the accumulator, hence the 'reverse'
-- below to restore outermost-first order.
49 go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
50 go bs e = (reverse bs, e)
-- | Peel off the leading lambdas whose binders satisfy 'isId' (value-level
--   binders).
--
--   Returns the stripped binders, outermost first, paired with the remaining
--   body. Peeling stops at the first lambda whose binder fails 'isId', or at
--   any node that is not an 'AnnLam'.
--
--   NOTE(review): the @where@ line that introduces 'go' is not visible in
--   this view.
52 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
53 collectAnnValBinders expr = go [] expr
-- Binders are pushed onto the front of the accumulator, hence the 'reverse'
-- below to restore outermost-first order.
55 go bs (_, AnnLam b e) | isId b = go (b:bs) e
56 go bs e = (reverse bs, e)
-- | Does this annotated expression denote a type argument?
--
--   An annotated expression is an (annotation, payload) pair; the answer is
--   'True' exactly when the payload is built with 'AnnType'.
isAnnTypeArg :: AnnExpr b ann -> Bool
isAnnTypeArg (_, payload) =
  case payload of
    AnnType _ -> True
    _         -> False
63 -- PD "Parallel Data" Functions -----------------------------------------------
65 -- Given some data that has a PA dictionary, we can convert it to its
66 -- representation type, perform some operation on the data, then convert it back.
68 -- In the DPH backend, the types of these functions are defined
69 -- in dph-common/D.A.P.Lifted/PArray.hs
-- | An empty array of the given type.
--
--   The expression is obtained with @paMethod emptyPDVar "emptyPD"@,
--   instantiated at the supplied type.
emptyPD :: Type -> VM CoreExpr
emptyPD ty = paMethod emptyPDVar "emptyPD" ty
77 -- | Produce an array containing copies of a given element.
--
--   The method is obtained with @paMethod replicatePDVar "replicatePD"@ at
--   the element's own type (@exprType x@). It is then applied to the count
--   and the element, in that order.
--
--   NOTE(review): the line naming the function, its result type (presumably
--   @-> VM CoreExpr@) and the equation head binding @len@ and @x@ are not
--   visible in this view.
79 :: CoreExpr -- ^ Number of copies in the resulting array.
80 -> CoreExpr -- ^ Value to replicate.
84 = liftM (`mkApps` [len,x])
85 $ paMethod replicatePDVar "replicatePD" (exprType x)
88 -- | Select some elements from an array that correspond to a particular tag value
89 -- and pack them into a new array.
90 -- eg packByTagPD Int# [:23, 42, 95, 50, 27, 49:] 3 [:1, 2, 1, 2, 3, 2:] 2
--        ==> [:42, 50, 49:]
--
--   The method is obtained with @paMethod packByTagPDVar "packByTagPD" ty@.
--   It is applied to the remaining arguments in the order xs, len, tags, t.
--
--   NOTE(review): the line naming the function and its result type
--   (presumably @-> VM CoreExpr@) are not visible in this view.
94 :: Type -- ^ Element type.
95 -> CoreExpr -- ^ Source array.
96 -> CoreExpr -- ^ Length of resulting array.
97 -> CoreExpr -- ^ Tag values of elements in source array.
98 -> CoreExpr -- ^ The tag value for the elements to select.
101 packByTagPD ty xs len tags t
102 = liftM (`mkApps` [xs, len, tags, t])
103 (paMethod packByTagPDVar "packByTagPD" ty)
106 -- | Combine some arrays based on a selector.
107 -- The selector says which source array to choose for each element of the
--   resulting array.
--
--   The method is picked by arity: @combinePDVar n@, with the name
--   @"combine" ++ show n ++ "PD"@ (e.g. @combine2PD@ when n = 2). It is
--   applied to the length, the selector and then all the arrays, in that
--   order.
--
--   NOTE(review): @n@ is not bound in the visible lines; presumably
--   @n = length xs@ in a @where@ clause outside this view. Confirm. The line
--   naming the function and its result type are not visible either.
110 :: Type -- ^ Element type
111 -> CoreExpr -- ^ Length of resulting array
112 -> CoreExpr -- ^ Selector.
113 -> [CoreExpr] -- ^ Arrays to combine.
116 combinePD ty len sel xs
117 = liftM (`mkApps` (len : sel : xs))
118 (paMethod (combinePDVar n) ("combine" ++ show n ++ "PD") ty)
123 -- | Like `replicatePD` but use the lifting context in the vectoriser state.
--
--   The copy count is the variable returned by @builtin liftingContext@,
--   wrapped with 'Var'. The element to replicate is the argument @x@.
--
--   NOTE(review): the equation head (binding @x@) and the @do@ that opens
--   the body are not visible in this view.
124 liftPD :: CoreExpr -> VM CoreExpr
127 lc <- builtin liftingContext
128 replicatePD (Var lc) x
131 -- Scalars --------------------------------------------------------------------
-- | Build an application of the @scalarZip@ builtin for arity
--   @length arg_tys@.
--
--   The builtin is first instantiated at every argument type followed by the
--   result type (@ty_args@). It is then applied to @dfuns@: the first
--   components of 'lookupInst' for a 'scalarClass' instance at each type in
--   @ty_args@, in order. These are presumably the instance dictionary
--   functions. The second components are discarded.
--
--   NOTE(review): the @do@ opening the body and the @where@ that binds
--   @ty_args@ are not visible in this view.
132 zipScalars :: [Type] -> Type -> VM CoreExpr
133 zipScalars arg_tys res_ty
135 scalar <- builtin scalarClass
136 (dfuns, _) <- mapAndUnzipM (\ty -> lookupInst scalar [ty]) ty_args
137 zipf <- builtin (scalarZip $ length arg_tys)
-- Grouping is ((Var zipf `mkTyApps` ty_args) `mkApps` (map Var dfuns)):
-- type arguments first, then the looked-up instances.
138 return $ Var zipf `mkTyApps` ty_args `mkApps` map Var dfuns
140 ty_args = arg_tys ++ [res_ty]
-- | Build an application of the @closureCtrFun@ builtin for arity
--   @length arg_tys@.
--
--   The builtin is instantiated at all argument types followed by the result
--   type. It is then applied to the 'paDictOfType' results (@pas@) for every
--   argument type except the last, followed by @scalar_fun@ and @array_fun@.
--
--   NOTE(review): 'init' is partial, so an empty @arg_tys@ fails here.
--   Presumably callers always pass at least one argument type; confirm. The
--   @do@ that opens the body is not visible in this view.
143 scalarClosure :: [Type] -> Type -> CoreExpr -> CoreExpr -> VM CoreExpr
144 scalarClosure arg_tys res_ty scalar_fun array_fun
146 ctr <- builtin (closureCtrFun $ length arg_tys)
147 pas <- mapM paDictOfType (init arg_tys)
148 return $ Var ctr `mkTyApps` (arg_tys ++ [res_ty])
149 `mkApps` (pas ++ [scalar_fun, array_fun])
-- | Box the first component of a 'VExpr' pair when its type is a nullary
--   unlifted type constructor with a boxed counterpart.
--
--   The boxing happens only when all of the following hold:
--
--   * @ty@ splits as an unlifted 'TyCon' applied to no arguments.
--   * 'lookupBoxedTyCon' returns @Just tycon'@.
--
--   In that case @vexpr@ is wrapped in @tycon'@'s data constructor, and
--   @lexpr@ is returned unchanged. If the lookup returns 'Nothing', the pair
--   is returned as is.
--
--   NOTE(review): no equation is visible for types that fail the guard. If
--   none follows this view, 'boxExpr' is partial for such types; confirm.
--   The @do@, @case r of@ and @in@ layout lines are also not visible here.
154 boxExpr :: Type -> VExpr -> VM VExpr
155 boxExpr ty (vexpr, lexpr)
156 | Just (tycon, []) <- splitTyConApp_maybe ty
157 , isUnLiftedTyCon tycon
159 r <- lookupBoxedTyCon tycon
-- The lazy @[dc]@ pattern assumes the boxed type has exactly one data
-- constructor. Any other shape fails at runtime when @dc@ is demanded.
161 Just tycon' -> let [dc] = tyConDataCons tycon'
163 return (mkConApp dc [vexpr], lexpr)
164 Nothing -> return (vexpr, lexpr)