2 module Vectorise.Utils (
3 module Vectorise.Utils.Base,
4 module Vectorise.Utils.Closure,
5 module Vectorise.Utils.Hoisting,
6 module Vectorise.Utils.PADict,
7 module Vectorise.Utils.Poly,
11 collectAnnTypeBinders,
16 replicatePD, emptyPD, packByTagPD,
20 zipScalars, scalarClosure,
26 import Vectorise.Utils.Base
27 import Vectorise.Utils.Closure
28 import Vectorise.Utils.Hoisting
29 import Vectorise.Utils.PADict
30 import Vectorise.Utils.Poly
31 import Vectorise.Monad
32 import Vectorise.Builtins
40 -- Annotated Exprs ------------------------------------------------------------
-- | Peel type arguments off the outer end of an annotated application spine.
--
--   Returns the remaining function expression together with the collected
--   type arguments, in the order they were applied (left to right).
--   Collection stops at the first argument that is not an 'AnnType'.
collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
collectAnnTypeArgs expr = go expr []
  where
    go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
    -- Without this equation 'go' is non-exhaustive: any expression whose
    -- outermost node is not a type application would fail to match.
    go e                             tys = (e, tys)
-- | Split an annotated expression into its leading lambda binders for which
--   'isTyCoVar' holds (in binding order) and the body that follows them.
collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnTypeBinders = peel
  where
    peel (_, AnnLam b body)
      | isTyCoVar b
      = let (bndrs, rest) = peel body in (b : bndrs, rest)
    peel e = ([], e)
-- | Split an annotated expression into its leading lambda binders for which
--   'isId' holds (in binding order) and the body that follows them.
collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
collectAnnValBinders = peel
  where
    peel (_, AnnLam b body)
      | isId b
      = let (bndrs, rest) = peel body in (b : bndrs, rest)
    peel e = ([], e)
-- | True exactly when the annotated node is an 'AnnType' (a type argument).
isAnnTypeArg :: AnnExpr b ann -> Bool
isAnnTypeArg (_, node) =
  case node of
    AnnType _ -> True
    _         -> False
64 -- PD "Parallel Data" Functions -----------------------------------------------
66 -- Given some data that has a PA dictionary, we can convert it to its
67 -- representation type, perform some operation on the data, then convert it back.
69 -- In the DPH backend, the types of these functions are defined
70 -- in dph-common/D.A.P.Lifted/PArray.hs
-- | An empty array whose elements have the given type, obtained through
--   'paMethod' with 'emptyPDVar'.
emptyPD :: Type -> VM CoreExpr
emptyPD ty = paMethod emptyPDVar "emptyPD" ty
-- | Produce an array containing copies of a given element.
--
--   The element type handed to 'paMethod' is the type of the value being
--   replicated; the method is then applied to the count and the value.
replicatePD
  :: CoreExpr     -- ^ Number of copies in the resulting array.
  -> CoreExpr     -- ^ Value to replicate.
  -> VM CoreExpr
replicatePD len x = do
  method <- paMethod replicatePDVar "replicatePD" (exprType x)
  return (method `mkApps` [len, x])
-- | Select some elements from an array that correspond to a particular tag
--   value and pack them into a new array.
--
--   eg  packByTagPD Int# [:23, 42, 95, 50, 27, 49:] 3 [:1, 2, 1, 2, 3, 2:] 2
--         ==> [:42, 50, 49:]
packByTagPD
  :: Type         -- ^ Element type.
  -> CoreExpr     -- ^ Source array.
  -> CoreExpr     -- ^ Length of resulting array.
  -> CoreExpr     -- ^ Tag values of elements in source array.
  -> CoreExpr     -- ^ The tag value for the elements to select.
  -> VM CoreExpr
packByTagPD ty xs len tags t = do
  method <- paMethod packByTagPDVar "packByTagPD" ty
  return (mkApps method [xs, len, tags, t])
-- | Combine some arrays based on a selector.
--   The selector says which source array to choose for each element of the
--   resulting array.
combinePD
  :: Type         -- ^ Element type
  -> CoreExpr     -- ^ Length of resulting array
  -> CoreExpr     -- ^ Selector.
  -> [CoreExpr]   -- ^ Arrays to combine.
  -> VM CoreExpr
combinePD ty len sel xs
  = liftM (`mkApps` (len : sel : xs))
          (paMethod (combinePDVar n) ("combine" ++ show n ++ "PD") ty)
  where
    -- The method is selected by the number of arrays being combined,
    -- e.g. two arrays give combinePDVar 2 / "combine2PD".
    n = length xs
-- | Like 'replicatePD', but the number of copies is the lifting context
--   ('liftingContext') held by the vectoriser.
liftPD :: CoreExpr -> VM CoreExpr
liftPD x = builtin liftingContext >>= \lc -> replicatePD (Var lc) x
132 -- Scalars --------------------------------------------------------------------
-- | Build an application of the builtin @scalarZip@ for the given arity
--   (the number of argument types). It is applied to every argument type
--   and the result type, then to the 'scalarClass' instance functions
--   looked up for each of those types.
zipScalars :: [Type] -> Type -> VM CoreExpr
zipScalars arg_tys res_ty = do
  scalarCls     <- builtin scalarClass
  (instFuns, _) <- mapAndUnzipM (lookupInst scalarCls . (: [])) allTys
  zipFn         <- builtin (scalarZip (length arg_tys))
  let typed = Var zipFn `mkTyApps` allTys
  return (typed `mkApps` map Var instFuns)
  where
    allTys = arg_tys ++ [res_ty]
-- | Build a closure from the given scalar and array functions using the
--   builtin closure constructor for @length arg_tys@ arguments.
--
--   The constructor is applied to all argument types plus the result type,
--   then to PA dictionaries for every argument type except the last, and
--   finally to the two functions.
--   NOTE(review): 'init' fails if arg_tys is empty; callers presumably never
--   pass an empty list - confirm.
scalarClosure :: [Type] -> Type -> CoreExpr -> CoreExpr -> VM CoreExpr
scalarClosure arg_tys res_ty scalar_fun array_fun = do
  closureCtr <- builtin (closureCtrFun (length arg_tys))
  dicts      <- mapM paDictOfType (init arg_tys)
  let tyArgs  = arg_tys ++ [res_ty]
      valArgs = dicts ++ [scalar_fun, array_fun]
  return (mkApps (mkTyApps (Var closureCtr) tyArgs) valArgs)
155 boxExpr :: Type -> VExpr -> VM VExpr
156 boxExpr ty (vexpr, lexpr)
157 | Just (tycon, []) <- splitTyConApp_maybe ty
158 , isUnLiftedTyCon tycon
160 r <- lookupBoxedTyCon tycon
162 Just tycon' -> let [dc] = tyConDataCons tycon'
164 return (mkConApp dc [vexpr], lexpr)
165 Nothing -> return (vexpr, lexpr)