More vectorisation-related built-ins
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 {-# OPTIONS -w #-}
2 -- The above warning supression flag is a temporary kludge.
3 -- While working on this module you are encouraged to remove it and fix
4 -- any warnings in the module. See
5 --     http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings
6 -- for details
7
8 module VectUtils (
9   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
10   collectAnnValBinders,
11   dataConTagZ, mkDataConTag, mkDataConTagLit,
12
13   newLocalVVar,
14
15   mkBuiltinCo,
16   mkPADictType, mkPArrayType, mkPReprType,
17
18   parrayReprTyCon, parrayReprDataCon, mkVScrut,
19   prDFunOfTyCon,
20   paDictArgType, paDictOfType, paDFunType,
21   paMethod, mkPR, lengthPA, replicatePA, emptyPA, packPA, combinePA, liftPA,
22   polyAbstract, polyApply, polyVApply,
23   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
24   buildClosure, buildClosures,
25   mkClosureApp
26 ) where
27
28 #include "HsVersions.h"
29
30 import VectCore
31 import VectMonad
32
33 import DsUtils
34 import CoreSyn
35 import CoreUtils
36 import Coercion
37 import Type
38 import TypeRep
39 import TyCon
40 import DataCon
41 import Var
42 import Id                 ( mkWildId )
43 import MkId               ( unwrapFamInstScrut )
44 import Name               ( Name )
45 import PrelNames
46 import TysWiredIn
47 import TysPrim            ( intPrimTy )
48 import BasicTypes         ( Boxity(..) )
49 import Literal            ( Literal, mkMachInt )
50
51 import Outputable
52 import FastString
53
54 import Data.List             ( zipWith4 )
55 import Control.Monad         ( liftM, liftM2, zipWithM_ )
56
57 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
58 collectAnnTypeArgs expr = go expr []
59   where
60     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
61     go e                             tys = (e, tys)
62
63 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
64 collectAnnTypeBinders expr = go [] expr
65   where
66     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
67     go bs e                           = (reverse bs, e)
68
69 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
70 collectAnnValBinders expr = go [] expr
71   where
72     go bs (_, AnnLam b e) | isId b = go (b:bs) e
73     go bs e                        = (reverse bs, e)
74
75 isAnnTypeArg :: AnnExpr b ann -> Bool
76 isAnnTypeArg (_, AnnType t) = True
77 isAnnTypeArg _              = False
78
79 dataConTagZ :: DataCon -> Int
80 dataConTagZ con = dataConTag con - fIRST_TAG
81
82 mkDataConTagLit :: DataCon -> Literal
83 mkDataConTagLit = mkMachInt . toInteger . dataConTagZ
84
85 mkDataConTag :: DataCon -> CoreExpr
86 mkDataConTag = mkIntLitInt . dataConTagZ
87
88 splitPrimTyCon :: Type -> Maybe TyCon
89 splitPrimTyCon ty
90   | Just (tycon, []) <- splitTyConApp_maybe ty
91   , isPrimTyCon tycon
92   = Just tycon
93
94   | otherwise = Nothing
95
96 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
97 mkBuiltinTyConApp get_tc tys
98   = do
99       tc <- builtin get_tc
100       return $ mkTyConApp tc tys
101
102 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
103 mkBuiltinTyConApps get_tc tys ty
104   = do
105       tc <- builtin get_tc
106       return $ foldr (mk tc) ty tys
107   where
108     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
109
110 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
111 mkBuiltinTyConApps1 get_tc dft [] = return dft
112 mkBuiltinTyConApps1 get_tc dft tys
113   = do
114       tc <- builtin get_tc
115       case tys of
116         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
117         _  -> return $ foldr1 (mk tc) tys
118   where
119     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
120
121 mkClosureType :: Type -> Type -> VM Type
122 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
123
124 mkClosureTypes :: [Type] -> Type -> VM Type
125 mkClosureTypes = mkBuiltinTyConApps closureTyCon
126
127 mkPReprType :: Type -> VM Type
128 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
129
130 mkPADictType :: Type -> VM Type
131 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
132
133 mkPArrayType :: Type -> VM Type
134 mkPArrayType ty
135   | Just tycon <- splitPrimTyCon ty
136   = do
137       arr <- traceMaybeV "mkPArrayType" (ppr tycon)
138            $ lookupPrimPArray tycon
139       return $ mkTyConApp arr []
140 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
141
142 mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
143 mkBuiltinCo get_tc
144   = do
145       tc <- builtin get_tc
146       return $ mkTyConApp tc []
147
148 parrayReprTyCon :: Type -> VM (TyCon, [Type])
149 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
150
151 parrayReprDataCon :: Type -> VM (DataCon, [Type])
152 parrayReprDataCon ty
153   = do
154       (tc, arg_tys) <- parrayReprTyCon ty
155       let [dc] = tyConDataCons tc
156       return (dc, arg_tys)
157
158 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
159 mkVScrut (ve, le)
160   = do
161       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
162       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
163
164 prDFunOfTyCon :: TyCon -> VM CoreExpr
165 prDFunOfTyCon tycon
166   = liftM Var (traceMaybeV "prDictOfTyCon" (ppr tycon) (lookupTyConPR tycon))
167
168 paDictArgType :: TyVar -> VM (Maybe Type)
169 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
170   where
171     go ty k | Just k' <- kindView k = go ty k'
172     go ty (FunTy k1 k2)
173       = do
174           tv   <- newTyVar FSLIT("a") k1
175           mty1 <- go (TyVarTy tv) k1
176           case mty1 of
177             Just ty1 -> do
178                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
179                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
180             Nothing  -> go ty k2
181
182     go ty k
183       | isLiftedTypeKind k
184       = liftM Just (mkPADictType ty)
185
186     go ty k = return Nothing
187
188 paDictOfType :: Type -> VM CoreExpr
189 paDictOfType ty = paDictOfTyApp ty_fn ty_args
190   where
191     (ty_fn, ty_args) = splitAppTys ty
192
193 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
194 paDictOfTyApp ty_fn ty_args
195   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
196 paDictOfTyApp (TyVarTy tv) ty_args
197   = do
198       dfun <- maybeV (lookupTyVarPA tv)
199       paDFunApply dfun ty_args
200 paDictOfTyApp (TyConApp tc _) ty_args
201   = do
202       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
203       paDFunApply (Var dfun) ty_args
204 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
205
206 paDFunType :: TyCon -> VM Type
207 paDFunType tc
208   = do
209       margs <- mapM paDictArgType tvs
210       res   <- mkPADictType (mkTyConApp tc arg_tys)
211       return . mkForAllTys tvs
212              $ mkFunTys [arg | Just arg <- margs] res
213   where
214     tvs = tyConTyVars tc
215     arg_tys = mkTyVarTys tvs
216
217 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
218 paDFunApply dfun tys
219   = do
220       dicts <- mapM paDictOfType tys
221       return $ mkApps (mkTyApps dfun tys) dicts
222
223 type PAMethod = (Builtins -> Var, String)
224
225 pa_length    = (lengthPAVar,    "lengthPA")
226 pa_replicate = (replicatePAVar, "replicatePA")
227 pa_empty     = (emptyPAVar,     "emptyPA")
228 pa_pack      = (packPAVar,      "packPA")
229
230 paMethod :: PAMethod -> Type -> VM CoreExpr
231 paMethod (method, name) ty
232   | Just tycon <- splitPrimTyCon ty
233   = do
234       fn <- traceMaybeV "paMethod" (ppr tycon <+> text name)
235           $ lookupPrimMethod tycon name
236       return (Var fn)
237
238 paMethod (method, name) ty
239   = do
240       fn   <- builtin method
241       dict <- paDictOfType ty
242       return $ mkApps (Var fn) [Type ty, dict]
243
244 mkPR :: Type -> VM CoreExpr
245 mkPR ty
246   = do
247       fn   <- builtin mkPRVar
248       dict <- paDictOfType ty
249       return $ mkApps (Var fn) [Type ty, dict]
250
251 lengthPA :: Type -> CoreExpr -> VM CoreExpr
252 lengthPA ty x = liftM (`App` x) (paMethod pa_length ty)
253
254 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
255 replicatePA len x = liftM (`mkApps` [len,x])
256                           (paMethod pa_replicate (exprType x))
257
258 emptyPA :: Type -> VM CoreExpr
259 emptyPA = paMethod pa_empty
260
261 packPA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> VM CoreExpr
262 packPA ty xs len sel = liftM (`mkApps` [len, sel])
263                              (paMethod pa_pack ty)
264
265 combinePA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> [CoreExpr]
266           -> VM CoreExpr
267 combinePA ty len sel is xs
268   = liftM (`mkApps` (len : sel : is : xs))
269           (paMethod (combinePAVar n, "combine" ++ show n ++ "PA") ty)
270   where
271     n = length xs
272
273 liftPA :: CoreExpr -> VM CoreExpr
274 liftPA x
275   = do
276       lc <- builtin liftingContext
277       replicatePA (Var lc) x
278
279 newLocalVVar :: FastString -> Type -> VM VVar
280 newLocalVVar fs vty
281   = do
282       lty <- mkPArrayType vty
283       vv  <- newLocalVar fs vty
284       lv  <- newLocalVar fs lty
285       return (vv,lv)
286
287 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
288 polyAbstract tvs p
289   = localV
290   $ do
291       mdicts <- mapM mk_dict_var tvs
292       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
293       p (mk_lams mdicts)
294   where
295     mk_dict_var tv = do
296                        r <- paDictArgType tv
297                        case r of
298                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
299                          Nothing -> return Nothing
300
301     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
302
303 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
304 polyApply expr tys
305   = do
306       dicts <- mapM paDictOfType tys
307       return $ expr `mkTyApps` tys `mkApps` dicts
308
309 polyVApply :: VExpr -> [Type] -> VM VExpr
310 polyVApply expr tys
311   = do
312       dicts <- mapM paDictOfType tys
313       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
314
315 hoistBinding :: Var -> CoreExpr -> VM ()
316 hoistBinding v e = updGEnv $ \env ->
317   env { global_bindings = (v,e) : global_bindings env }
318
319 hoistExpr :: FastString -> CoreExpr -> VM Var
320 hoistExpr fs expr
321   = do
322       var <- newLocalVar fs (exprType expr)
323       hoistBinding var expr
324       return var
325
326 hoistVExpr :: VExpr -> VM VVar
327 hoistVExpr (ve, le)
328   = do
329       fs <- getBindName
330       vv <- hoistExpr ('v' `consFS` fs) ve
331       lv <- hoistExpr ('l' `consFS` fs) le
332       return (vv, lv)
333
334 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
335 hoistPolyVExpr tvs p
336   = do
337       expr <- closedV . polyAbstract tvs $ \abstract ->
338               liftM (mapVect abstract) p
339       fn   <- hoistVExpr expr
340       polyVApply (vVar fn) (mkTyVarTys tvs)
341
342 takeHoisted :: VM [(Var, CoreExpr)]
343 takeHoisted
344   = do
345       env <- readGEnv id
346       setGEnv $ env { global_bindings = [] }
347       return $ global_bindings env
348
349 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
350 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
351   = do
352       dict <- paDictOfType env_ty
353       mkv  <- builtin mkClosureVar
354       mkl  <- builtin mkClosurePVar
355       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
356               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
357
358 mkClosureApp :: Type -> Type -> VExpr -> VExpr -> VM VExpr
359 mkClosureApp arg_ty res_ty (vclo, lclo) (varg, larg)
360   = do
361       vapply <- builtin applyClosureVar
362       lapply <- builtin applyClosurePVar
363       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
364               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
365
366 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
367 buildClosures tvs vars [] res_ty mk_body
368   = mk_body
369 buildClosures tvs vars [arg_ty] res_ty mk_body
370   = buildClosure tvs vars arg_ty res_ty mk_body
371 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
372   = do
373       res_ty' <- mkClosureTypes arg_tys res_ty
374       arg <- newLocalVVar FSLIT("x") arg_ty
375       buildClosure tvs vars arg_ty res_ty'
376         . hoistPolyVExpr tvs
377         $ do
378             lc <- builtin liftingContext
379             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
380             return $ vLams lc (vars ++ [arg]) clo
381
382 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
383 --   where
384 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
385 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
386 --
387 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
388 buildClosure tvs vars arg_ty res_ty mk_body
389   = do
390       (env_ty, env, bind) <- buildEnv vars
391       env_bndr <- newLocalVVar FSLIT("env") env_ty
392       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
393
394       fn <- hoistPolyVExpr tvs
395           $ do
396               lc    <- builtin liftingContext
397               body  <- mk_body
398               body' <- bind (vVar env_bndr)
399                             (vVarApps lc body (vars ++ [arg_bndr]))
400               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
401
402       mkClosure arg_ty res_ty env_ty fn env
403
404 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
405 buildEnv vvs
406   = do
407       lc <- builtin liftingContext
408       let (ty, venv, vbind) = mkVectEnv tys vs
409       (lenv, lbind) <- mkLiftEnv lc tys ls
410       return (ty, (venv, lenv),
411               \(venv,lenv) (vbody,lbody) ->
412               do
413                 let vbody' = vbind venv vbody
414                 lbody' <- lbind lenv lbody
415                 return (vbody', lbody'))
416   where
417     (vs,ls) = unzip vvs
418     tys     = map idType vs
419
420 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
421 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
422 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
423 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
424                         \env body -> Case env (mkWildId ty) (exprType body)
425                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
426   where
427     ty = mkCoreTupTy tys
428
429 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
430 mkLiftEnv lc [ty] [v]
431   = return (Var v, \env body ->
432                    do
433                      len <- lengthPA ty (Var v)
434                      return . Let (NonRec v env)
435                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
436
437 -- NOTE: this transparently deals with empty environments
438 mkLiftEnv lc tys vs
439   = do
440       (env_tc, env_tyargs) <- parrayReprTyCon vty
441       let [env_con] = tyConDataCons env_tc
442           
443           env = Var (dataConWrapId env_con)
444                 `mkTyApps`  env_tyargs
445                 `mkVarApps` (lc : vs)
446
447           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
448                           in
449                           return $ Case scrut (mkWildId (exprType scrut))
450                                         (exprType body)
451                                         [(DataAlt env_con, lc : bndrs, body)]
452       return (env, bind)
453   where
454     vty = mkCoreTupTy tys
455
456     bndrs | null vs   = [mkWildId unitTy]
457           | otherwise = vs
458