3dd54256318259d2a3586e623b589253d7659876
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Utils.hs
1
2 module Vectorise.Utils (
3   module Vectorise.Utils.Base,
4   module Vectorise.Utils.Closure,
5   module Vectorise.Utils.Hoisting,
6   module Vectorise.Utils.PADict,
7   module Vectorise.Utils.PRDict,
8   module Vectorise.Utils.Poly,
9
10   -- * Annotated Exprs
11   collectAnnTypeArgs,
12   collectAnnTypeBinders,
13   collectAnnValBinders,
14   isAnnTypeArg,
15
16   -- * PD Functions
17   replicatePD, emptyPD, packByTagPD,
18   combinePD, liftPD,
19
20   -- * Scalars
21   zipScalars, scalarClosure,
22
23   -- * Naming
24   newLocalVar
25
26 where
27 import Vectorise.Utils.Base
28 import Vectorise.Utils.Closure
29 import Vectorise.Utils.Hoisting
30 import Vectorise.Utils.PADict
31 import Vectorise.Utils.PRDict
32 import Vectorise.Utils.Poly
33 import Vectorise.Monad
34 import Vectorise.Builtins
35 import CoreSyn
36 import CoreUtils
37 import Type
38 import Var
39 import Control.Monad
40
41
42 -- Annotated Exprs ------------------------------------------------------------
43 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
44 collectAnnTypeArgs expr = go expr []
45   where
46     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
47     go e                             tys = (e, tys)
48
49 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
50 collectAnnTypeBinders expr = go [] expr
51   where
52     go bs (_, AnnLam b e) | isTyCoVar b = go (b:bs) e
53     go bs e                           = (reverse bs, e)
54
55 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
56 collectAnnValBinders expr = go [] expr
57   where
58     go bs (_, AnnLam b e) | isId b = go (b:bs) e
59     go bs e                        = (reverse bs, e)
60
61 isAnnTypeArg :: AnnExpr b ann -> Bool
62 isAnnTypeArg (_, AnnType _) = True
63 isAnnTypeArg _              = False
64
65
66 -- PD Functions ---------------------------------------------------------------
67 replicatePD :: CoreExpr -> CoreExpr -> VM CoreExpr
68 replicatePD len x = liftM (`mkApps` [len,x])
69                           (paMethod replicatePDVar "replicatePD" (exprType x))
70
71 emptyPD :: Type -> VM CoreExpr
72 emptyPD = paMethod emptyPDVar "emptyPD"
73
74
75 packByTagPD :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> CoreExpr -> VM CoreExpr
76 packByTagPD ty xs len tags t
77   = liftM (`mkApps` [xs, len, tags, t])
78           (paMethod packByTagPDVar "packByTagPD" ty)
79
80
81 combinePD :: Type -> CoreExpr -> CoreExpr -> [CoreExpr] -> VM CoreExpr
82 combinePD ty len sel xs
83   = liftM (`mkApps` (len : sel : xs))
84           (paMethod (combinePDVar n) ("combine" ++ show n ++ "PD") ty)
85   where
86     n = length xs
87
88
89 -- | Like `replicatePD` but use the lifting context in the vectoriser state.
90 liftPD :: CoreExpr -> VM CoreExpr
91 liftPD x
92   = do
93       lc <- builtin liftingContext
94       replicatePD (Var lc) x
95
96
97 -- Scalars --------------------------------------------------------------------
98 zipScalars :: [Type] -> Type -> VM CoreExpr
99 zipScalars arg_tys res_ty
100   = do
101       scalar <- builtin scalarClass
102       (dfuns, _) <- mapAndUnzipM (\ty -> lookupInst scalar [ty]) ty_args
103       zipf <- builtin (scalarZip $ length arg_tys)
104       return $ Var zipf `mkTyApps` ty_args `mkApps` map Var dfuns
105     where
106       ty_args = arg_tys ++ [res_ty]
107
108
109 scalarClosure :: [Type] -> Type -> CoreExpr -> CoreExpr -> VM CoreExpr
110 scalarClosure arg_tys res_ty scalar_fun array_fun
111   = do
112       ctr      <- builtin (closureCtrFun $ length arg_tys)
113       Just pas <- liftM sequence $ mapM paDictOfType (init arg_tys)
114       return $ Var ctr `mkTyApps` (arg_tys ++ [res_ty])
115                        `mkApps`   (pas ++ [scalar_fun, array_fun])
116
117
118
119 {-
120 boxExpr :: Type -> VExpr -> VM VExpr
121 boxExpr ty (vexpr, lexpr)
122   | Just (tycon, []) <- splitTyConApp_maybe ty
123   , isUnLiftedTyCon tycon
124   = do
125       r <- lookupBoxedTyCon tycon
126       case r of
127         Just tycon' -> let [dc] = tyConDataCons tycon'
128                        in
129                        return (mkConApp dc [vexpr], lexpr)
130         Nothing     -> return (vexpr, lexpr)
131 -}