Add vectorisation built-ins
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 {-# OPTIONS -w #-}
2 -- The above warning supression flag is a temporary kludge.
3 -- While working on this module you are encouraged to remove it and fix
4 -- any warnings in the module. See
5 --     http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings
6 -- for details
7
8 module VectUtils (
9   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
10   collectAnnValBinders,
11   mkDataConTag, mkDataConTagLit,
12
13   mkBuiltinCo,
14   mkPADictType, mkPArrayType, mkPReprType,
15
16   parrayReprTyCon, parrayReprDataCon, mkVScrut,
17   prDFunOfTyCon,
18   paDictArgType, paDictOfType, paDFunType,
19   paMethod, mkPR, lengthPA, replicatePA, emptyPA, packPA, liftPA,
20   polyAbstract, polyApply, polyVApply,
21   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
22   buildClosure, buildClosures,
23   mkClosureApp
24 ) where
25
26 #include "HsVersions.h"
27
28 import VectCore
29 import VectMonad
30
31 import DsUtils
32 import CoreSyn
33 import CoreUtils
34 import Coercion
35 import Type
36 import TypeRep
37 import TyCon
38 import DataCon
39 import Var
40 import Id                 ( mkWildId )
41 import MkId               ( unwrapFamInstScrut )
42 import Name               ( Name )
43 import PrelNames
44 import TysWiredIn
45 import TysPrim            ( intPrimTy )
46 import BasicTypes         ( Boxity(..) )
47 import Literal            ( Literal, mkMachInt )
48
49 import Outputable
50 import FastString
51
52 import Data.List             ( zipWith4 )
53 import Control.Monad         ( liftM, liftM2, zipWithM_ )
54
55 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
56 collectAnnTypeArgs expr = go expr []
57   where
58     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
59     go e                             tys = (e, tys)
60
61 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
62 collectAnnTypeBinders expr = go [] expr
63   where
64     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
65     go bs e                           = (reverse bs, e)
66
67 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
68 collectAnnValBinders expr = go [] expr
69   where
70     go bs (_, AnnLam b e) | isId b = go (b:bs) e
71     go bs e                        = (reverse bs, e)
72
73 isAnnTypeArg :: AnnExpr b ann -> Bool
74 isAnnTypeArg (_, AnnType t) = True
75 isAnnTypeArg _              = False
76
77 mkDataConTagLit :: DataCon -> Literal
78 mkDataConTagLit con
79   = mkMachInt . toInteger $ dataConTag con - fIRST_TAG
80
81 mkDataConTag :: DataCon -> CoreExpr
82 mkDataConTag con = mkIntLitInt (dataConTag con - fIRST_TAG)
83
84 splitPrimTyCon :: Type -> Maybe TyCon
85 splitPrimTyCon ty
86   | Just (tycon, []) <- splitTyConApp_maybe ty
87   , isPrimTyCon tycon
88   = Just tycon
89
90   | otherwise = Nothing
91
92 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
93 mkBuiltinTyConApp get_tc tys
94   = do
95       tc <- builtin get_tc
96       return $ mkTyConApp tc tys
97
98 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
99 mkBuiltinTyConApps get_tc tys ty
100   = do
101       tc <- builtin get_tc
102       return $ foldr (mk tc) ty tys
103   where
104     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
105
106 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
107 mkBuiltinTyConApps1 get_tc dft [] = return dft
108 mkBuiltinTyConApps1 get_tc dft tys
109   = do
110       tc <- builtin get_tc
111       case tys of
112         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
113         _  -> return $ foldr1 (mk tc) tys
114   where
115     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
116
117 mkClosureType :: Type -> Type -> VM Type
118 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
119
120 mkClosureTypes :: [Type] -> Type -> VM Type
121 mkClosureTypes = mkBuiltinTyConApps closureTyCon
122
123 mkPReprType :: Type -> VM Type
124 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
125
126 mkPADictType :: Type -> VM Type
127 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
128
129 mkPArrayType :: Type -> VM Type
130 mkPArrayType ty
131   | Just tycon <- splitPrimTyCon ty
132   = do
133       arr <- traceMaybeV "mkPArrayType" (ppr tycon)
134            $ lookupPrimPArray tycon
135       return $ mkTyConApp arr []
136 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
137
138 mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
139 mkBuiltinCo get_tc
140   = do
141       tc <- builtin get_tc
142       return $ mkTyConApp tc []
143
144 parrayReprTyCon :: Type -> VM (TyCon, [Type])
145 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
146
147 parrayReprDataCon :: Type -> VM (DataCon, [Type])
148 parrayReprDataCon ty
149   = do
150       (tc, arg_tys) <- parrayReprTyCon ty
151       let [dc] = tyConDataCons tc
152       return (dc, arg_tys)
153
154 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
155 mkVScrut (ve, le)
156   = do
157       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
158       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
159
160 prDFunOfTyCon :: TyCon -> VM CoreExpr
161 prDFunOfTyCon tycon
162   = liftM Var (traceMaybeV "prDictOfTyCon" (ppr tycon) (lookupTyConPR tycon))
163
164 paDictArgType :: TyVar -> VM (Maybe Type)
165 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
166   where
167     go ty k | Just k' <- kindView k = go ty k'
168     go ty (FunTy k1 k2)
169       = do
170           tv   <- newTyVar FSLIT("a") k1
171           mty1 <- go (TyVarTy tv) k1
172           case mty1 of
173             Just ty1 -> do
174                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
175                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
176             Nothing  -> go ty k2
177
178     go ty k
179       | isLiftedTypeKind k
180       = liftM Just (mkPADictType ty)
181
182     go ty k = return Nothing
183
184 paDictOfType :: Type -> VM CoreExpr
185 paDictOfType ty = paDictOfTyApp ty_fn ty_args
186   where
187     (ty_fn, ty_args) = splitAppTys ty
188
189 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
190 paDictOfTyApp ty_fn ty_args
191   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
192 paDictOfTyApp (TyVarTy tv) ty_args
193   = do
194       dfun <- maybeV (lookupTyVarPA tv)
195       paDFunApply dfun ty_args
196 paDictOfTyApp (TyConApp tc _) ty_args
197   = do
198       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
199       paDFunApply (Var dfun) ty_args
200 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
201
202 paDFunType :: TyCon -> VM Type
203 paDFunType tc
204   = do
205       margs <- mapM paDictArgType tvs
206       res   <- mkPADictType (mkTyConApp tc arg_tys)
207       return . mkForAllTys tvs
208              $ mkFunTys [arg | Just arg <- margs] res
209   where
210     tvs = tyConTyVars tc
211     arg_tys = mkTyVarTys tvs
212
213 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
214 paDFunApply dfun tys
215   = do
216       dicts <- mapM paDictOfType tys
217       return $ mkApps (mkTyApps dfun tys) dicts
218
219 type PAMethod = (Builtins -> Var, String)
220
221 pa_length    = (lengthPAVar,    "lengthPA")
222 pa_replicate = (replicatePAVar, "replicatePA")
223 pa_empty     = (emptyPAVar,     "emptyPA")
224 pa_pack      = (packPAVar,      "packPA")
225
226 paMethod :: PAMethod -> Type -> VM CoreExpr
227 paMethod (method, name) ty
228   | Just tycon <- splitPrimTyCon ty
229   = do
230       fn <- traceMaybeV "paMethod" (ppr tycon <+> text name)
231           $ lookupPrimMethod tycon name
232       return (Var fn)
233
234 paMethod (method, name) ty
235   = do
236       fn   <- builtin method
237       dict <- paDictOfType ty
238       return $ mkApps (Var fn) [Type ty, dict]
239
240 mkPR :: Type -> VM CoreExpr
241 mkPR ty
242   = do
243       fn   <- builtin mkPRVar
244       dict <- paDictOfType ty
245       return $ mkApps (Var fn) [Type ty, dict]
246
247 lengthPA :: Type -> CoreExpr -> VM CoreExpr
248 lengthPA ty x = liftM (`App` x) (paMethod pa_length ty)
249
250 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
251 replicatePA len x = liftM (`mkApps` [len,x])
252                           (paMethod pa_replicate (exprType x))
253
254 emptyPA :: Type -> VM CoreExpr
255 emptyPA = paMethod pa_empty
256
257 packPA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> VM CoreExpr
258 packPA ty xs len sel = liftM (`mkApps` [len, sel])
259                              (paMethod pa_pack ty)
260
261 liftPA :: CoreExpr -> VM CoreExpr
262 liftPA x
263   = do
264       lc <- builtin liftingContext
265       replicatePA (Var lc) x
266
267 newLocalVVar :: FastString -> Type -> VM VVar
268 newLocalVVar fs vty
269   = do
270       lty <- mkPArrayType vty
271       vv  <- newLocalVar fs vty
272       lv  <- newLocalVar fs lty
273       return (vv,lv)
274
275 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
276 polyAbstract tvs p
277   = localV
278   $ do
279       mdicts <- mapM mk_dict_var tvs
280       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
281       p (mk_lams mdicts)
282   where
283     mk_dict_var tv = do
284                        r <- paDictArgType tv
285                        case r of
286                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
287                          Nothing -> return Nothing
288
289     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
290
291 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
292 polyApply expr tys
293   = do
294       dicts <- mapM paDictOfType tys
295       return $ expr `mkTyApps` tys `mkApps` dicts
296
297 polyVApply :: VExpr -> [Type] -> VM VExpr
298 polyVApply expr tys
299   = do
300       dicts <- mapM paDictOfType tys
301       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
302
303 hoistBinding :: Var -> CoreExpr -> VM ()
304 hoistBinding v e = updGEnv $ \env ->
305   env { global_bindings = (v,e) : global_bindings env }
306
307 hoistExpr :: FastString -> CoreExpr -> VM Var
308 hoistExpr fs expr
309   = do
310       var <- newLocalVar fs (exprType expr)
311       hoistBinding var expr
312       return var
313
314 hoistVExpr :: VExpr -> VM VVar
315 hoistVExpr (ve, le)
316   = do
317       fs <- getBindName
318       vv <- hoistExpr ('v' `consFS` fs) ve
319       lv <- hoistExpr ('l' `consFS` fs) le
320       return (vv, lv)
321
322 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
323 hoistPolyVExpr tvs p
324   = do
325       expr <- closedV . polyAbstract tvs $ \abstract ->
326               liftM (mapVect abstract) p
327       fn   <- hoistVExpr expr
328       polyVApply (vVar fn) (mkTyVarTys tvs)
329
330 takeHoisted :: VM [(Var, CoreExpr)]
331 takeHoisted
332   = do
333       env <- readGEnv id
334       setGEnv $ env { global_bindings = [] }
335       return $ global_bindings env
336
337 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
338 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
339   = do
340       dict <- paDictOfType env_ty
341       mkv  <- builtin mkClosureVar
342       mkl  <- builtin mkClosurePVar
343       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
344               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
345
346 mkClosureApp :: Type -> Type -> VExpr -> VExpr -> VM VExpr
347 mkClosureApp arg_ty res_ty (vclo, lclo) (varg, larg)
348   = do
349       vapply <- builtin applyClosureVar
350       lapply <- builtin applyClosurePVar
351       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
352               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
353
354 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
355 buildClosures tvs vars [] res_ty mk_body
356   = mk_body
357 buildClosures tvs vars [arg_ty] res_ty mk_body
358   = buildClosure tvs vars arg_ty res_ty mk_body
359 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
360   = do
361       res_ty' <- mkClosureTypes arg_tys res_ty
362       arg <- newLocalVVar FSLIT("x") arg_ty
363       buildClosure tvs vars arg_ty res_ty'
364         . hoistPolyVExpr tvs
365         $ do
366             lc <- builtin liftingContext
367             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
368             return $ vLams lc (vars ++ [arg]) clo
369
370 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
371 --   where
372 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
373 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
374 --
375 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
376 buildClosure tvs vars arg_ty res_ty mk_body
377   = do
378       (env_ty, env, bind) <- buildEnv vars
379       env_bndr <- newLocalVVar FSLIT("env") env_ty
380       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
381
382       fn <- hoistPolyVExpr tvs
383           $ do
384               lc    <- builtin liftingContext
385               body  <- mk_body
386               body' <- bind (vVar env_bndr)
387                             (vVarApps lc body (vars ++ [arg_bndr]))
388               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
389
390       mkClosure arg_ty res_ty env_ty fn env
391
392 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
393 buildEnv vvs
394   = do
395       lc <- builtin liftingContext
396       let (ty, venv, vbind) = mkVectEnv tys vs
397       (lenv, lbind) <- mkLiftEnv lc tys ls
398       return (ty, (venv, lenv),
399               \(venv,lenv) (vbody,lbody) ->
400               do
401                 let vbody' = vbind venv vbody
402                 lbody' <- lbind lenv lbody
403                 return (vbody', lbody'))
404   where
405     (vs,ls) = unzip vvs
406     tys     = map idType vs
407
408 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
409 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
410 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
411 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
412                         \env body -> Case env (mkWildId ty) (exprType body)
413                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
414   where
415     ty = mkCoreTupTy tys
416
417 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
418 mkLiftEnv lc [ty] [v]
419   = return (Var v, \env body ->
420                    do
421                      len <- lengthPA ty (Var v)
422                      return . Let (NonRec v env)
423                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
424
425 -- NOTE: this transparently deals with empty environments
426 mkLiftEnv lc tys vs
427   = do
428       (env_tc, env_tyargs) <- parrayReprTyCon vty
429       let [env_con] = tyConDataCons env_tc
430           
431           env = Var (dataConWrapId env_con)
432                 `mkTyApps`  env_tyargs
433                 `mkVarApps` (lc : vs)
434
435           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
436                           in
437                           return $ Case scrut (mkWildId (exprType scrut))
438                                         (exprType body)
439                                         [(DataAlt env_con, lc : bndrs, body)]
440       return (env, bind)
441   where
442     vty = mkCoreTupTy tys
443
444     bndrs | null vs   = [mkWildId unitTy]
445           | otherwise = vs
446