Move VectType module to Vectorise tree
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   dataConTagZ, mkDataConTag, mkDataConTagLit,
5
6   newLocalVVar,
7
8   mkBuiltinCo, voidType, mkWrapType,
9   mkPADictType, mkPArrayType, mkPDataType, mkPReprType, mkPArray,
10   mkBuiltinTyConApps, mkClosureTypes,
11
12   pdataReprTyCon, pdataReprDataCon, mkVScrut,
13   prDictOfType, prDFunOfTyCon,
14   paDictArgType, paDictOfType, paDFunType,
15   paMethod, wrapPR, replicatePD, emptyPD, packByTagPD,
16   combinePD,
17   liftPD,
18   zipScalars, scalarClosure,
19   polyAbstract, polyApply, polyVApply, polyArity
20 ) where
21 import Vectorise.Monad
22 import Vectorise.Vect
23 import Vectorise.Builtins
24
25 import CoreSyn
26 import CoreUtils
27 import Coercion
28 import Type
29 import TypeRep
30 import TyCon
31 import DataCon
32 import Var
33 import MkId
34 import Literal
35 import Outputable
36 import FastString
37 import Control.Monad
38
39
40 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
41 collectAnnTypeArgs expr = go expr []
42   where
43     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
44     go e                             tys = (e, tys)
45
46 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
47 collectAnnTypeBinders expr = go [] expr
48   where
49     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
50     go bs e                           = (reverse bs, e)
51
52 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
53 collectAnnValBinders expr = go [] expr
54   where
55     go bs (_, AnnLam b e) | isId b = go (b:bs) e
56     go bs e                        = (reverse bs, e)
57
58 isAnnTypeArg :: AnnExpr b ann -> Bool
59 isAnnTypeArg (_, AnnType _) = True
60 isAnnTypeArg _              = False
61
62 dataConTagZ :: DataCon -> Int
63 dataConTagZ con = dataConTag con - fIRST_TAG
64
65 mkDataConTagLit :: DataCon -> Literal
66 mkDataConTagLit = mkMachInt . toInteger . dataConTagZ
67
68 mkDataConTag :: DataCon -> CoreExpr
69 mkDataConTag = mkIntLitInt . dataConTagZ
70
71 splitPrimTyCon :: Type -> Maybe TyCon
72 splitPrimTyCon ty
73   | Just (tycon, []) <- splitTyConApp_maybe ty
74   , isPrimTyCon tycon
75   = Just tycon
76
77   | otherwise = Nothing
78
79 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
80 mkBuiltinTyConApp get_tc tys
81   = do
82       tc <- builtin get_tc
83       return $ mkTyConApp tc tys
84
85 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
86 mkBuiltinTyConApps get_tc tys ty
87   = do
88       tc <- builtin get_tc
89       return $ foldr (mk tc) ty tys
90   where
91     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
92
93 voidType :: VM Type
94 voidType = mkBuiltinTyConApp voidTyCon []
95
96 mkWrapType :: Type -> VM Type
97 mkWrapType ty = mkBuiltinTyConApp wrapTyCon [ty]
98
99
100 mkClosureTypes :: [Type] -> Type -> VM Type
101 mkClosureTypes = mkBuiltinTyConApps closureTyCon
102
103 mkPReprType :: Type -> VM Type
104 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
105
106 mkPADictType :: Type -> VM Type
107 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
108
109 mkPArrayType :: Type -> VM Type
110 mkPArrayType ty
111   | Just tycon <- splitPrimTyCon ty
112   = do
113       r <- lookupPrimPArray tycon
114       case r of
115         Just arr -> return $ mkTyConApp arr []
116         Nothing  -> cantVectorise "Primitive tycon not vectorised" (ppr tycon)
117 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
118
119 mkPDataType :: Type -> VM Type
120 mkPDataType ty = mkBuiltinTyConApp pdataTyCon [ty]
121
122 mkPArray :: Type -> CoreExpr -> CoreExpr -> VM CoreExpr
123 mkPArray ty len dat = do
124                         tc <- builtin parrayTyCon
125                         let [dc] = tyConDataCons tc
126                         return $ mkConApp dc [Type ty, len, dat]
127
128 mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
129 mkBuiltinCo get_tc
130   = do
131       tc <- builtin get_tc
132       return $ mkTyConApp tc []
133
134 pdataReprTyCon :: Type -> VM (TyCon, [Type])
135 pdataReprTyCon ty = builtin pdataTyCon >>= (`lookupFamInst` [ty])
136
137 pdataReprDataCon :: Type -> VM (DataCon, [Type])
138 pdataReprDataCon ty
139   = do
140       (tc, arg_tys) <- pdataReprTyCon ty
141       let [dc] = tyConDataCons tc
142       return (dc, arg_tys)
143
144 mkVScrut :: VExpr -> VM (CoreExpr, CoreExpr, TyCon, [Type])
145 mkVScrut (ve, le)
146   = do
147       (tc, arg_tys) <- pdataReprTyCon ty
148       return (ve, unwrapFamInstScrut tc arg_tys le, tc, arg_tys)
149   where
150     ty = exprType ve
151
152 prDFunOfTyCon :: TyCon -> VM CoreExpr
153 prDFunOfTyCon tycon
154   = liftM Var
155   . maybeCantVectoriseM "No PR dictionary for tycon" (ppr tycon)
156   $ lookupTyConPR tycon
157
158
159 paDictArgType :: TyVar -> VM (Maybe Type)
160 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
161   where
162     go ty k | Just k' <- kindView k = go ty k'
163     go ty (FunTy k1 k2)
164       = do
165           tv   <- newTyVar (fsLit "a") k1
166           mty1 <- go (TyVarTy tv) k1
167           case mty1 of
168             Just ty1 -> do
169                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
170                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
171             Nothing  -> go ty k2
172
173     go ty k
174       | isLiftedTypeKind k
175       = liftM Just (mkPADictType ty)
176
177     go _ _ = return Nothing
178
179
180 -- | Get the PA dictionary for some type, or `Nothing` if there isn't one.
181 paDictOfType :: Type -> VM (Maybe CoreExpr)
182 paDictOfType ty 
183   = paDictOfTyApp ty_fn ty_args
184   where
185     (ty_fn, ty_args) = splitAppTys ty
186
187     paDictOfTyApp :: Type -> [Type] -> VM (Maybe CoreExpr)
188     paDictOfTyApp ty_fn ty_args
189         | Just ty_fn' <- coreView ty_fn 
190         = paDictOfTyApp ty_fn' ty_args
191
192     paDictOfTyApp (TyVarTy tv) ty_args
193      = do dfun <- maybeV (lookupTyVarPA tv)
194           liftM Just $ paDFunApply dfun ty_args
195
196     paDictOfTyApp (TyConApp tc _) ty_args
197      = do mdfun <- lookupTyConPA tc
198           case mdfun of
199             Nothing     
200              -> pprTrace "VectUtils.paDictOfType"
201                          (vcat [ text "No PA dictionary"
202                                , text "for tycon: " <> ppr tc
203                                , text "in type:   " <> ppr ty])
204              $ return Nothing
205
206             Just dfun   -> liftM Just $ paDFunApply (Var dfun) ty_args
207
208     paDictOfTyApp ty _
209      = cantVectorise "Can't construct PA dictionary for type" (ppr ty)
210
211
212
213 paDFunType :: TyCon -> VM Type
214 paDFunType tc
215   = do
216       margs <- mapM paDictArgType tvs
217       res   <- mkPADictType (mkTyConApp tc arg_tys)
218       return . mkForAllTys tvs
219              $ mkFunTys [arg | Just arg <- margs] res
220   where
221     tvs = tyConTyVars tc
222     arg_tys = mkTyVarTys tvs
223
224 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
225 paDFunApply dfun tys
226  = do Just dicts <- liftM sequence $ mapM paDictOfType tys
227       return $ mkApps (mkTyApps dfun tys) dicts
228
229
230 paMethod :: (Builtins -> Var) -> String -> Type -> VM CoreExpr
231 paMethod _ name ty
232   | Just tycon <- splitPrimTyCon ty
233   = liftM Var
234   . maybeCantVectoriseM "No PA method" (text name <+> text "for" <+> ppr tycon)
235   $ lookupPrimMethod tycon name
236
237 paMethod method _ ty
238   = do
239       fn        <- builtin method
240       Just dict <- paDictOfType ty
241       return $ mkApps (Var fn) [Type ty, dict]
242
243 prDictOfType :: Type -> VM CoreExpr
244 prDictOfType ty = prDictOfTyApp ty_fn ty_args
245   where
246     (ty_fn, ty_args) = splitAppTys ty
247
248 prDictOfTyApp :: Type -> [Type] -> VM CoreExpr
249 prDictOfTyApp ty_fn ty_args
250   | Just ty_fn' <- coreView ty_fn = prDictOfTyApp ty_fn' ty_args
251 prDictOfTyApp (TyConApp tc _) ty_args
252   = do
253       dfun <- liftM Var $ maybeV (lookupTyConPR tc)
254       prDFunApply dfun ty_args
255 prDictOfTyApp _ _ = noV
256
257 prDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
258 prDFunApply dfun tys
259   = do
260       dicts <- mapM prDictOfType tys
261       return $ mkApps (mkTyApps dfun tys) dicts
262
263 wrapPR :: Type -> VM CoreExpr
264 wrapPR ty
265   = do
266       Just  pa_dict <- paDictOfType ty
267       pr_dfun       <- prDFunOfTyCon =<< builtin wrapTyCon
268       return $ mkApps pr_dfun [Type ty, pa_dict]
269
270 replicatePD :: CoreExpr -> CoreExpr -> VM CoreExpr
271 replicatePD len x = liftM (`mkApps` [len,x])
272                           (paMethod replicatePDVar "replicatePD" (exprType x))
273
274 emptyPD :: Type -> VM CoreExpr
275 emptyPD = paMethod emptyPDVar "emptyPD"
276
277 packByTagPD :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> CoreExpr
278                  -> VM CoreExpr
279 packByTagPD ty xs len tags t
280   = liftM (`mkApps` [xs, len, tags, t])
281           (paMethod packByTagPDVar "packByTagPD" ty)
282
283 combinePD :: Type -> CoreExpr -> CoreExpr -> [CoreExpr]
284           -> VM CoreExpr
285 combinePD ty len sel xs
286   = liftM (`mkApps` (len : sel : xs))
287           (paMethod (combinePDVar n) ("combine" ++ show n ++ "PD") ty)
288   where
289     n = length xs
290
291 -- | Like `replicatePD` but use the lifting context in the vectoriser state.
292 liftPD :: CoreExpr -> VM CoreExpr
293 liftPD x
294   = do
295       lc <- builtin liftingContext
296       replicatePD (Var lc) x
297
298 zipScalars :: [Type] -> Type -> VM CoreExpr
299 zipScalars arg_tys res_ty
300   = do
301       scalar <- builtin scalarClass
302       (dfuns, _) <- mapAndUnzipM (\ty -> lookupInst scalar [ty]) ty_args
303       zipf <- builtin (scalarZip $ length arg_tys)
304       return $ Var zipf `mkTyApps` ty_args `mkApps` map Var dfuns
305     where
306       ty_args = arg_tys ++ [res_ty]
307
308 scalarClosure :: [Type] -> Type -> CoreExpr -> CoreExpr -> VM CoreExpr
309 scalarClosure arg_tys res_ty scalar_fun array_fun
310   = do
311       ctr      <- builtin (closureCtrFun $ length arg_tys)
312       Just pas <- liftM sequence $ mapM paDictOfType (init arg_tys)
313       return $ Var ctr `mkTyApps` (arg_tys ++ [res_ty])
314                        `mkApps`   (pas ++ [scalar_fun, array_fun])
315
316 newLocalVVar :: FastString -> Type -> VM VVar
317 newLocalVVar fs vty
318   = do
319       lty <- mkPDataType vty
320       vv  <- newLocalVar fs vty
321       lv  <- newLocalVar fs lty
322       return (vv,lv)
323
324 polyAbstract :: [TyVar] -> ([Var] -> VM a) -> VM a
325 polyAbstract tvs p
326   = localV
327   $ do
328       mdicts <- mapM mk_dict_var tvs
329       zipWithM_ (\tv -> maybe (defLocalTyVar tv)
330                               (defLocalTyVarWithPA tv . Var)) tvs mdicts
331       p (mk_args mdicts)
332   where
333     mk_dict_var tv = do
334                        r <- paDictArgType tv
335                        case r of
336                          Just ty -> liftM Just (newLocalVar (fsLit "dPA") ty)
337                          Nothing -> return Nothing
338
339     mk_args mdicts = [dict | Just dict <- mdicts]
340
341 polyArity :: [TyVar] -> VM Int
342 polyArity tvs = do
343                   tys <- mapM paDictArgType tvs
344                   return $ length [() | Just _ <- tys]
345
346 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
347 polyApply expr tys
348  = do Just dicts <- liftM sequence $ mapM paDictOfType tys
349       return $ expr `mkTyApps` tys `mkApps` dicts
350
351 polyVApply :: VExpr -> [Type] -> VM VExpr
352 polyVApply expr tys
353  = do Just dicts <- liftM sequence $ mapM paDictOfType tys
354       return     $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
355
356
357 {-
358 boxExpr :: Type -> VExpr -> VM VExpr
359 boxExpr ty (vexpr, lexpr)
360   | Just (tycon, []) <- splitTyConApp_maybe ty
361   , isUnLiftedTyCon tycon
362   = do
363       r <- lookupBoxedTyCon tycon
364       case r of
365         Just tycon' -> let [dc] = tyConDataCons tycon'
366                        in
367                        return (mkConApp dc [vexpr], lexpr)
368         Nothing     -> return (vexpr, lexpr)
369 -}
370
371