Merge remote branch 'origin/master'
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Utils.hs
1
2 module Vectorise.Utils (
3   module Vectorise.Utils.Base,
4   module Vectorise.Utils.Closure,
5   module Vectorise.Utils.Hoisting,
6   module Vectorise.Utils.PADict,
7   module Vectorise.Utils.Poly,
8
9   -- * Annotated Exprs
10   collectAnnTypeArgs,
11   collectAnnTypeBinders,
12   collectAnnValBinders,
13   isAnnTypeArg,
14
15   -- * PD Functions
16   replicatePD, emptyPD, packByTagPD,
17   combinePD, liftPD,
18
19   -- * Scalars
20   zipScalars, scalarClosure,
21
22   -- * Naming
23   newLocalVar
24
25 where
26 import Vectorise.Utils.Base
27 import Vectorise.Utils.Closure
28 import Vectorise.Utils.Hoisting
29 import Vectorise.Utils.PADict
30 import Vectorise.Utils.Poly
31 import Vectorise.Monad
32 import Vectorise.Builtins
33 import CoreSyn
34 import CoreUtils
35 import Type
36 import Control.Monad
37
38
39 -- Annotated Exprs ------------------------------------------------------------
40 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
41 collectAnnTypeArgs expr = go expr []
42   where
43     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
44     go e                             tys = (e, tys)
45
46 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
47 collectAnnTypeBinders expr = go [] expr
48   where
49     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
50     go bs e                           = (reverse bs, e)
51
52 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
53 collectAnnValBinders expr = go [] expr
54   where
55     go bs (_, AnnLam b e) | isId b = go (b:bs) e
56     go bs e                        = (reverse bs, e)
57
58 isAnnTypeArg :: AnnExpr b ann -> Bool
59 isAnnTypeArg (_, AnnType _) = True
60 isAnnTypeArg _              = False
61
62
63 -- PD "Parallel Data" Functions -----------------------------------------------
64 --
65 --   Given some data that has a PA dictionary, we can convert it to its 
66 --   representation type, perform some operation on the data, then convert it back.
67 --
68 --   In the DPH backend, the types of these functions are defined
69 --   in dph-common/D.A.P.Lifted/PArray.hs
70 --
71
72 -- | An empty array of the given type.
73 emptyPD :: Type -> VM CoreExpr
74 emptyPD = paMethod emptyPDVar "emptyPD"
75
76
77 -- | Produce an array containing copies of a given element.
78 replicatePD
79         :: CoreExpr     -- ^ Number of copies in the resulting array.
80         -> CoreExpr     -- ^ Value to replicate.
81         -> VM CoreExpr
82
83 replicatePD len x 
84         = liftM (`mkApps` [len,x])
85         $ paMethod replicatePDVar "replicatePD" (exprType x)
86
87
88 -- | Select some elements from an array that correspond to a particular tag value
89 ---  and pack them into a new array.
90 --   eg  packByTagPD Int# [:23, 42, 95, 50, 27, 49:]  3 [:1, 2, 1, 2, 3, 2:] 2 
91 --          ==> [:42, 50, 49:]
92 --
93 packByTagPD 
94         :: Type         -- ^ Element type.
95         -> CoreExpr     -- ^ Source array.
96         -> CoreExpr     -- ^ Length of resulting array.
97         -> CoreExpr     -- ^ Tag values of elements in source array.
98         -> CoreExpr     -- ^ The tag value for the elements to select.
99         -> VM CoreExpr
100
101 packByTagPD ty xs len tags t
102   = liftM (`mkApps` [xs, len, tags, t])
103           (paMethod packByTagPDVar "packByTagPD" ty)
104
105
106 -- | Combine some arrays based on a selector.
107 --     The selector says which source array to choose for each element of the
108 --     resulting array.
109 combinePD 
110         :: Type         -- ^ Element type
111         -> CoreExpr     -- ^ Length of resulting array
112         -> CoreExpr     -- ^ Selector.
113         -> [CoreExpr]   -- ^ Arrays to combine.
114         -> VM CoreExpr
115
116 combinePD ty len sel xs
117   = liftM (`mkApps` (len : sel : xs))
118           (paMethod (combinePDVar n) ("combine" ++ show n ++ "PD") ty)
119   where
120     n = length xs
121
122
123 -- | Like `replicatePD` but use the lifting context in the vectoriser state.
124 liftPD :: CoreExpr -> VM CoreExpr
125 liftPD x
126   = do
127       lc <- builtin liftingContext
128       replicatePD (Var lc) x
129
130
131 -- Scalars --------------------------------------------------------------------
132 zipScalars :: [Type] -> Type -> VM CoreExpr
133 zipScalars arg_tys res_ty
134   = do
135       scalar <- builtin scalarClass
136       (dfuns, _) <- mapAndUnzipM (\ty -> lookupInst scalar [ty]) ty_args
137       zipf <- builtin (scalarZip $ length arg_tys)
138       return $ Var zipf `mkTyApps` ty_args `mkApps` map Var dfuns
139     where
140       ty_args = arg_tys ++ [res_ty]
141
142
143 scalarClosure :: [Type] -> Type -> CoreExpr -> CoreExpr -> VM CoreExpr
144 scalarClosure arg_tys res_ty scalar_fun array_fun
145   = do
146       ctr <- builtin (closureCtrFun $ length arg_tys)
147       pas <- mapM paDictOfType (init arg_tys)
148       return $ Var ctr `mkTyApps` (arg_tys ++ [res_ty])
149                        `mkApps`   (pas ++ [scalar_fun, array_fun])
150
151
152
153 {-
154 boxExpr :: Type -> VExpr -> VM VExpr
155 boxExpr ty (vexpr, lexpr)
156   | Just (tycon, []) <- splitTyConApp_maybe ty
157   , isUnLiftedTyCon tycon
158   = do
159       r <- lookupBoxedTyCon tycon
160       case r of
161         Just tycon' -> let [dc] = tyConDataCons tycon'
162                        in
163                        return (mkConApp dc [vexpr], lexpr)
164         Nothing     -> return (vexpr, lexpr)
165 -}