Vectorisation utilities
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 {-# OPTIONS -w #-}
2 -- The above warning supression flag is a temporary kludge.
3 -- While working on this module you are encouraged to remove it and fix
4 -- any warnings in the module. See
5 --     http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings
6 -- for details
7
8 module VectUtils (
9   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
10   collectAnnValBinders,
11   dataConTagZ, mkDataConTag, mkDataConTagLit,
12
13   newLocalVVar,
14
15   mkBuiltinCo,
16   mkPADictType, mkPArrayType, mkPReprType,
17
18   parrayReprTyCon, parrayReprDataCon, mkVScrut,
19   prDFunOfTyCon,
20   paDictArgType, paDictOfType, paDFunType,
21   paMethod, mkPR, lengthPA, replicatePA, emptyPA, packPA, liftPA,
22   polyAbstract, polyApply, polyVApply,
23   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
24   buildClosure, buildClosures,
25   mkClosureApp
26 ) where
27
28 #include "HsVersions.h"
29
30 import VectCore
31 import VectMonad
32
33 import DsUtils
34 import CoreSyn
35 import CoreUtils
36 import Coercion
37 import Type
38 import TypeRep
39 import TyCon
40 import DataCon
41 import Var
42 import Id                 ( mkWildId )
43 import MkId               ( unwrapFamInstScrut )
44 import Name               ( Name )
45 import PrelNames
46 import TysWiredIn
47 import TysPrim            ( intPrimTy )
48 import BasicTypes         ( Boxity(..) )
49 import Literal            ( Literal, mkMachInt )
50
51 import Outputable
52 import FastString
53
54 import Data.List             ( zipWith4 )
55 import Control.Monad         ( liftM, liftM2, zipWithM_ )
56
57 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
58 collectAnnTypeArgs expr = go expr []
59   where
60     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
61     go e                             tys = (e, tys)
62
63 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
64 collectAnnTypeBinders expr = go [] expr
65   where
66     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
67     go bs e                           = (reverse bs, e)
68
69 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
70 collectAnnValBinders expr = go [] expr
71   where
72     go bs (_, AnnLam b e) | isId b = go (b:bs) e
73     go bs e                        = (reverse bs, e)
74
75 isAnnTypeArg :: AnnExpr b ann -> Bool
76 isAnnTypeArg (_, AnnType t) = True
77 isAnnTypeArg _              = False
78
79 dataConTagZ :: DataCon -> Int
80 dataConTagZ con = dataConTag con - fIRST_TAG
81
82 mkDataConTagLit :: DataCon -> Literal
83 mkDataConTagLit = mkMachInt . toInteger . dataConTagZ
84
85 mkDataConTag :: DataCon -> CoreExpr
86 mkDataConTag = mkIntLitInt . dataConTagZ
87
88 splitPrimTyCon :: Type -> Maybe TyCon
89 splitPrimTyCon ty
90   | Just (tycon, []) <- splitTyConApp_maybe ty
91   , isPrimTyCon tycon
92   = Just tycon
93
94   | otherwise = Nothing
95
96 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
97 mkBuiltinTyConApp get_tc tys
98   = do
99       tc <- builtin get_tc
100       return $ mkTyConApp tc tys
101
102 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
103 mkBuiltinTyConApps get_tc tys ty
104   = do
105       tc <- builtin get_tc
106       return $ foldr (mk tc) ty tys
107   where
108     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
109
110 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
111 mkBuiltinTyConApps1 get_tc dft [] = return dft
112 mkBuiltinTyConApps1 get_tc dft tys
113   = do
114       tc <- builtin get_tc
115       case tys of
116         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
117         _  -> return $ foldr1 (mk tc) tys
118   where
119     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
120
121 mkClosureType :: Type -> Type -> VM Type
122 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
123
124 mkClosureTypes :: [Type] -> Type -> VM Type
125 mkClosureTypes = mkBuiltinTyConApps closureTyCon
126
127 mkPReprType :: Type -> VM Type
128 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
129
130 mkPADictType :: Type -> VM Type
131 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
132
133 mkPArrayType :: Type -> VM Type
134 mkPArrayType ty
135   | Just tycon <- splitPrimTyCon ty
136   = do
137       arr <- traceMaybeV "mkPArrayType" (ppr tycon)
138            $ lookupPrimPArray tycon
139       return $ mkTyConApp arr []
140 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
141
142 mkBuiltinCo :: (Builtins -> TyCon) -> VM Coercion
143 mkBuiltinCo get_tc
144   = do
145       tc <- builtin get_tc
146       return $ mkTyConApp tc []
147
148 parrayReprTyCon :: Type -> VM (TyCon, [Type])
149 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
150
151 parrayReprDataCon :: Type -> VM (DataCon, [Type])
152 parrayReprDataCon ty
153   = do
154       (tc, arg_tys) <- parrayReprTyCon ty
155       let [dc] = tyConDataCons tc
156       return (dc, arg_tys)
157
158 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
159 mkVScrut (ve, le)
160   = do
161       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
162       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
163
164 prDFunOfTyCon :: TyCon -> VM CoreExpr
165 prDFunOfTyCon tycon
166   = liftM Var (traceMaybeV "prDictOfTyCon" (ppr tycon) (lookupTyConPR tycon))
167
168 paDictArgType :: TyVar -> VM (Maybe Type)
169 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
170   where
171     go ty k | Just k' <- kindView k = go ty k'
172     go ty (FunTy k1 k2)
173       = do
174           tv   <- newTyVar FSLIT("a") k1
175           mty1 <- go (TyVarTy tv) k1
176           case mty1 of
177             Just ty1 -> do
178                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
179                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
180             Nothing  -> go ty k2
181
182     go ty k
183       | isLiftedTypeKind k
184       = liftM Just (mkPADictType ty)
185
186     go ty k = return Nothing
187
188 paDictOfType :: Type -> VM CoreExpr
189 paDictOfType ty = paDictOfTyApp ty_fn ty_args
190   where
191     (ty_fn, ty_args) = splitAppTys ty
192
193 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
194 paDictOfTyApp ty_fn ty_args
195   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
196 paDictOfTyApp (TyVarTy tv) ty_args
197   = do
198       dfun <- maybeV (lookupTyVarPA tv)
199       paDFunApply dfun ty_args
200 paDictOfTyApp (TyConApp tc _) ty_args
201   = do
202       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
203       paDFunApply (Var dfun) ty_args
204 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
205
206 paDFunType :: TyCon -> VM Type
207 paDFunType tc
208   = do
209       margs <- mapM paDictArgType tvs
210       res   <- mkPADictType (mkTyConApp tc arg_tys)
211       return . mkForAllTys tvs
212              $ mkFunTys [arg | Just arg <- margs] res
213   where
214     tvs = tyConTyVars tc
215     arg_tys = mkTyVarTys tvs
216
217 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
218 paDFunApply dfun tys
219   = do
220       dicts <- mapM paDictOfType tys
221       return $ mkApps (mkTyApps dfun tys) dicts
222
223 type PAMethod = (Builtins -> Var, String)
224
225 pa_length    = (lengthPAVar,    "lengthPA")
226 pa_replicate = (replicatePAVar, "replicatePA")
227 pa_empty     = (emptyPAVar,     "emptyPA")
228 pa_pack      = (packPAVar,      "packPA")
229
230 paMethod :: PAMethod -> Type -> VM CoreExpr
231 paMethod (method, name) ty
232   | Just tycon <- splitPrimTyCon ty
233   = do
234       fn <- traceMaybeV "paMethod" (ppr tycon <+> text name)
235           $ lookupPrimMethod tycon name
236       return (Var fn)
237
238 paMethod (method, name) ty
239   = do
240       fn   <- builtin method
241       dict <- paDictOfType ty
242       return $ mkApps (Var fn) [Type ty, dict]
243
244 mkPR :: Type -> VM CoreExpr
245 mkPR ty
246   = do
247       fn   <- builtin mkPRVar
248       dict <- paDictOfType ty
249       return $ mkApps (Var fn) [Type ty, dict]
250
251 lengthPA :: Type -> CoreExpr -> VM CoreExpr
252 lengthPA ty x = liftM (`App` x) (paMethod pa_length ty)
253
254 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
255 replicatePA len x = liftM (`mkApps` [len,x])
256                           (paMethod pa_replicate (exprType x))
257
258 emptyPA :: Type -> VM CoreExpr
259 emptyPA = paMethod pa_empty
260
261 packPA :: Type -> CoreExpr -> CoreExpr -> CoreExpr -> VM CoreExpr
262 packPA ty xs len sel = liftM (`mkApps` [len, sel])
263                              (paMethod pa_pack ty)
264
265 liftPA :: CoreExpr -> VM CoreExpr
266 liftPA x
267   = do
268       lc <- builtin liftingContext
269       replicatePA (Var lc) x
270
271 newLocalVVar :: FastString -> Type -> VM VVar
272 newLocalVVar fs vty
273   = do
274       lty <- mkPArrayType vty
275       vv  <- newLocalVar fs vty
276       lv  <- newLocalVar fs lty
277       return (vv,lv)
278
279 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
280 polyAbstract tvs p
281   = localV
282   $ do
283       mdicts <- mapM mk_dict_var tvs
284       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
285       p (mk_lams mdicts)
286   where
287     mk_dict_var tv = do
288                        r <- paDictArgType tv
289                        case r of
290                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
291                          Nothing -> return Nothing
292
293     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
294
295 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
296 polyApply expr tys
297   = do
298       dicts <- mapM paDictOfType tys
299       return $ expr `mkTyApps` tys `mkApps` dicts
300
301 polyVApply :: VExpr -> [Type] -> VM VExpr
302 polyVApply expr tys
303   = do
304       dicts <- mapM paDictOfType tys
305       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
306
307 hoistBinding :: Var -> CoreExpr -> VM ()
308 hoistBinding v e = updGEnv $ \env ->
309   env { global_bindings = (v,e) : global_bindings env }
310
311 hoistExpr :: FastString -> CoreExpr -> VM Var
312 hoistExpr fs expr
313   = do
314       var <- newLocalVar fs (exprType expr)
315       hoistBinding var expr
316       return var
317
318 hoistVExpr :: VExpr -> VM VVar
319 hoistVExpr (ve, le)
320   = do
321       fs <- getBindName
322       vv <- hoistExpr ('v' `consFS` fs) ve
323       lv <- hoistExpr ('l' `consFS` fs) le
324       return (vv, lv)
325
326 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
327 hoistPolyVExpr tvs p
328   = do
329       expr <- closedV . polyAbstract tvs $ \abstract ->
330               liftM (mapVect abstract) p
331       fn   <- hoistVExpr expr
332       polyVApply (vVar fn) (mkTyVarTys tvs)
333
334 takeHoisted :: VM [(Var, CoreExpr)]
335 takeHoisted
336   = do
337       env <- readGEnv id
338       setGEnv $ env { global_bindings = [] }
339       return $ global_bindings env
340
341 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
342 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
343   = do
344       dict <- paDictOfType env_ty
345       mkv  <- builtin mkClosureVar
346       mkl  <- builtin mkClosurePVar
347       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
348               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
349
350 mkClosureApp :: Type -> Type -> VExpr -> VExpr -> VM VExpr
351 mkClosureApp arg_ty res_ty (vclo, lclo) (varg, larg)
352   = do
353       vapply <- builtin applyClosureVar
354       lapply <- builtin applyClosurePVar
355       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
356               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
357
358 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
359 buildClosures tvs vars [] res_ty mk_body
360   = mk_body
361 buildClosures tvs vars [arg_ty] res_ty mk_body
362   = buildClosure tvs vars arg_ty res_ty mk_body
363 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
364   = do
365       res_ty' <- mkClosureTypes arg_tys res_ty
366       arg <- newLocalVVar FSLIT("x") arg_ty
367       buildClosure tvs vars arg_ty res_ty'
368         . hoistPolyVExpr tvs
369         $ do
370             lc <- builtin liftingContext
371             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
372             return $ vLams lc (vars ++ [arg]) clo
373
374 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
375 --   where
376 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
377 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
378 --
379 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
380 buildClosure tvs vars arg_ty res_ty mk_body
381   = do
382       (env_ty, env, bind) <- buildEnv vars
383       env_bndr <- newLocalVVar FSLIT("env") env_ty
384       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
385
386       fn <- hoistPolyVExpr tvs
387           $ do
388               lc    <- builtin liftingContext
389               body  <- mk_body
390               body' <- bind (vVar env_bndr)
391                             (vVarApps lc body (vars ++ [arg_bndr]))
392               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
393
394       mkClosure arg_ty res_ty env_ty fn env
395
396 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
397 buildEnv vvs
398   = do
399       lc <- builtin liftingContext
400       let (ty, venv, vbind) = mkVectEnv tys vs
401       (lenv, lbind) <- mkLiftEnv lc tys ls
402       return (ty, (venv, lenv),
403               \(venv,lenv) (vbody,lbody) ->
404               do
405                 let vbody' = vbind venv vbody
406                 lbody' <- lbind lenv lbody
407                 return (vbody', lbody'))
408   where
409     (vs,ls) = unzip vvs
410     tys     = map idType vs
411
412 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
413 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
414 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
415 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
416                         \env body -> Case env (mkWildId ty) (exprType body)
417                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
418   where
419     ty = mkCoreTupTy tys
420
421 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
422 mkLiftEnv lc [ty] [v]
423   = return (Var v, \env body ->
424                    do
425                      len <- lengthPA ty (Var v)
426                      return . Let (NonRec v env)
427                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
428
429 -- NOTE: this transparently deals with empty environments
430 mkLiftEnv lc tys vs
431   = do
432       (env_tc, env_tyargs) <- parrayReprTyCon vty
433       let [env_con] = tyConDataCons env_tc
434           
435           env = Var (dataConWrapId env_con)
436                 `mkTyApps`  env_tyargs
437                 `mkVarApps` (lc : vs)
438
439           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
440                           in
441                           return $ Case scrut (mkWildId (exprType scrut))
442                                         (exprType body)
443                                         [(DataAlt env_con, lc : bndrs, body)]
444       return (env, bind)
445   where
446     vty = mkCoreTupTy tys
447
448     bndrs | null vs   = [mkWildId unitTy]
449           | otherwise = vs
450