Utility functions for accessing parallel array representations
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag,
5   splitClosureTy,
6   mkPADictType, mkPArrayType,
7   parrayReprTyCon, parrayReprDataCon,
8   paDictArgType, paDictOfType, paDFunType,
9   paMethod, lengthPA, replicatePA, emptyPA, liftPA,
10   polyAbstract, polyApply, polyVApply,
11   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
12   buildClosure, buildClosures,
13   mkClosureApp
14 ) where
15
16 #include "HsVersions.h"
17
18 import VectCore
19 import VectMonad
20
21 import DsUtils
22 import CoreSyn
23 import CoreUtils
24 import Type
25 import TypeRep
26 import TyCon
27 import DataCon            ( DataCon, dataConWrapId, dataConTag )
28 import Var
29 import Id                 ( mkWildId )
30 import MkId               ( unwrapFamInstScrut )
31 import PrelNames
32 import TysWiredIn
33 import BasicTypes         ( Boxity(..) )
34
35 import Outputable
36 import FastString
37
38 import Control.Monad         ( liftM, zipWithM_ )
39
40 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
41 collectAnnTypeArgs expr = go expr []
42   where
43     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
44     go e                             tys = (e, tys)
45
46 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
47 collectAnnTypeBinders expr = go [] expr
48   where
49     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
50     go bs e                           = (reverse bs, e)
51
52 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
53 collectAnnValBinders expr = go [] expr
54   where
55     go bs (_, AnnLam b e) | isId b = go (b:bs) e
56     go bs e                        = (reverse bs, e)
57
58 isAnnTypeArg :: AnnExpr b ann -> Bool
59 isAnnTypeArg (_, AnnType t) = True
60 isAnnTypeArg _              = False
61
62 mkDataConTag :: DataCon -> CoreExpr
63 mkDataConTag dc = mkConApp intDataCon [mkIntLitInt $ dataConTag dc]
64
65 isClosureTyCon :: TyCon -> Bool
66 isClosureTyCon tc = tyConName tc == closureTyConName
67
68 splitClosureTy :: Type -> (Type, Type)
69 splitClosureTy ty
70   | Just (tc, [arg_ty, res_ty]) <- splitTyConApp_maybe ty
71   , isClosureTyCon tc
72   = (arg_ty, res_ty)
73
74   | otherwise = pprPanic "splitClosureTy" (ppr ty)
75
76 isPArrayTyCon :: TyCon -> Bool
77 isPArrayTyCon tc = tyConName tc == parrayTyConName
78
79 splitPArrayTy :: Type -> Type
80 splitPArrayTy ty
81   | Just (tc, [arg_ty]) <- splitTyConApp_maybe ty
82   , isPArrayTyCon tc
83   = arg_ty
84
85   | otherwise = pprPanic "splitPArrayTy" (ppr ty)
86
87 mkClosureType :: Type -> Type -> VM Type
88 mkClosureType arg_ty res_ty
89   = do
90       tc <- builtin closureTyCon
91       return $ mkTyConApp tc [arg_ty, res_ty]
92
93 mkClosureTypes :: [Type] -> Type -> VM Type
94 mkClosureTypes arg_tys res_ty
95   = do
96       tc <- builtin closureTyCon
97       return $ foldr (mk tc) res_ty arg_tys
98   where
99     mk tc arg_ty res_ty = mkTyConApp tc [arg_ty, res_ty]
100
101 mkPADictType :: Type -> VM Type
102 mkPADictType ty
103   = do
104       tc <- builtin paTyCon
105       return $ TyConApp tc [ty]
106
107 mkPArrayType :: Type -> VM Type
108 mkPArrayType ty
109   = do
110       tc <- builtin parrayTyCon
111       return $ TyConApp tc [ty]
112
113 parrayReprTyCon :: Type -> VM (TyCon, [Type])
114 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
115
116 parrayReprDataCon :: Type -> VM (DataCon, [Type])
117 parrayReprDataCon ty
118   = do
119       (tc, arg_tys) <- parrayReprTyCon ty
120       let [dc] = tyConDataCons tc
121       return (dc, arg_tys)
122
123 paDictArgType :: TyVar -> VM (Maybe Type)
124 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
125   where
126     go ty k | Just k' <- kindView k = go ty k'
127     go ty (FunTy k1 k2)
128       = do
129           tv   <- newTyVar FSLIT("a") k1
130           mty1 <- go (TyVarTy tv) k1
131           case mty1 of
132             Just ty1 -> do
133                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
134                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
135             Nothing  -> go ty k2
136
137     go ty k
138       | isLiftedTypeKind k
139       = liftM Just (mkPADictType ty)
140
141     go ty k = return Nothing
142
143 paDictOfType :: Type -> VM CoreExpr
144 paDictOfType ty = paDictOfTyApp ty_fn ty_args
145   where
146     (ty_fn, ty_args) = splitAppTys ty
147
148 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
149 paDictOfTyApp ty_fn ty_args
150   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
151 paDictOfTyApp (TyVarTy tv) ty_args
152   = do
153       dfun <- maybeV (lookupTyVarPA tv)
154       paDFunApply dfun ty_args
155 paDictOfTyApp (TyConApp tc _) ty_args
156   = do
157       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
158       paDFunApply (Var dfun) ty_args
159 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
160
161 paDFunType :: TyCon -> VM Type
162 paDFunType tc
163   = do
164       margs <- mapM paDictArgType tvs
165       res   <- mkPADictType (mkTyConApp tc arg_tys)
166       return . mkForAllTys tvs
167              $ mkFunTys [arg | Just arg <- margs] res
168   where
169     tvs = tyConTyVars tc
170     arg_tys = mkTyVarTys tvs
171
172 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
173 paDFunApply dfun tys
174   = do
175       dicts <- mapM paDictOfType tys
176       return $ mkApps (mkTyApps dfun tys) dicts
177
178 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
179 paMethod method ty
180   = do
181       fn   <- builtin method
182       dict <- paDictOfType ty
183       return $ mkApps (Var fn) [Type ty, dict]
184
185 lengthPA :: CoreExpr -> VM CoreExpr
186 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
187   where
188     ty = splitPArrayTy (exprType x)
189
190 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
191 replicatePA len x = liftM (`mkApps` [len,x])
192                           (paMethod replicatePAVar (exprType x))
193
194 emptyPA :: Type -> VM CoreExpr
195 emptyPA = paMethod emptyPAVar
196
197 liftPA :: CoreExpr -> VM CoreExpr
198 liftPA x
199   = do
200       lc <- builtin liftingContext
201       replicatePA (Var lc) x
202
203 newLocalVVar :: FastString -> Type -> VM VVar
204 newLocalVVar fs vty
205   = do
206       lty <- mkPArrayType vty
207       vv  <- newLocalVar fs vty
208       lv  <- newLocalVar fs lty
209       return (vv,lv)
210
211 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
212 polyAbstract tvs p
213   = localV
214   $ do
215       mdicts <- mapM mk_dict_var tvs
216       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
217       p (mk_lams mdicts)
218   where
219     mk_dict_var tv = do
220                        r <- paDictArgType tv
221                        case r of
222                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
223                          Nothing -> return Nothing
224
225     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
226
227 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
228 polyApply expr tys
229   = do
230       dicts <- mapM paDictOfType tys
231       return $ expr `mkTyApps` tys `mkApps` dicts
232
233 polyVApply :: VExpr -> [Type] -> VM VExpr
234 polyVApply expr tys
235   = do
236       dicts <- mapM paDictOfType tys
237       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
238
239 hoistBinding :: Var -> CoreExpr -> VM ()
240 hoistBinding v e = updGEnv $ \env ->
241   env { global_bindings = (v,e) : global_bindings env }
242
243 hoistExpr :: FastString -> CoreExpr -> VM Var
244 hoistExpr fs expr
245   = do
246       var <- newLocalVar fs (exprType expr)
247       hoistBinding var expr
248       return var
249
250 hoistVExpr :: VExpr -> VM VVar
251 hoistVExpr (ve, le)
252   = do
253       fs <- getBindName
254       vv <- hoistExpr ('v' `consFS` fs) ve
255       lv <- hoistExpr ('l' `consFS` fs) le
256       return (vv, lv)
257
258 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
259 hoistPolyVExpr tvs p
260   = do
261       expr <- closedV . polyAbstract tvs $ \abstract ->
262               liftM (mapVect abstract) p
263       fn   <- hoistVExpr expr
264       polyVApply (vVar fn) (mkTyVarTys tvs)
265
266 takeHoisted :: VM [(Var, CoreExpr)]
267 takeHoisted
268   = do
269       env <- readGEnv id
270       setGEnv $ env { global_bindings = [] }
271       return $ global_bindings env
272
273 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
274 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
275   = do
276       dict <- paDictOfType env_ty
277       mkv  <- builtin mkClosureVar
278       mkl  <- builtin mkClosurePVar
279       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
280               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
281
282 mkClosureApp :: VExpr -> VExpr -> VM VExpr
283 mkClosureApp (vclo, lclo) (varg, larg)
284   = do
285       vapply <- builtin applyClosureVar
286       lapply <- builtin applyClosurePVar
287       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
288               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
289   where
290     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
291
292 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
293 buildClosures tvs vars [arg_ty] res_ty mk_body
294   = buildClosure tvs vars arg_ty res_ty mk_body
295 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
296   = do
297       res_ty' <- mkClosureTypes arg_tys res_ty
298       arg <- newLocalVVar FSLIT("x") arg_ty
299       buildClosure tvs vars arg_ty res_ty'
300         . hoistPolyVExpr tvs
301         $ do
302             lc <- builtin liftingContext
303             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
304             return $ vLams lc (vars ++ [arg]) clo
305
306 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
307 --   where
308 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
309 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
310 --
311 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
312 buildClosure tvs vars arg_ty res_ty mk_body
313   = do
314       (env_ty, env, bind) <- buildEnv vars
315       env_bndr <- newLocalVVar FSLIT("env") env_ty
316       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
317
318       fn <- hoistPolyVExpr tvs
319           $ do
320               lc    <- builtin liftingContext
321               body  <- mk_body
322               body' <- bind (vVar env_bndr)
323                             (vVarApps lc body (vars ++ [arg_bndr]))
324               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
325
326       mkClosure arg_ty res_ty env_ty fn env
327
328 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
329 buildEnv vvs
330   = do
331       lc <- builtin liftingContext
332       let (ty, venv, vbind) = mkVectEnv tys vs
333       (lenv, lbind) <- mkLiftEnv lc tys ls
334       return (ty, (venv, lenv),
335               \(venv,lenv) (vbody,lbody) ->
336               do
337                 let vbody' = vbind venv vbody
338                 lbody' <- lbind lenv lbody
339                 return (vbody', lbody'))
340   where
341     (vs,ls) = unzip vvs
342     tys     = map idType vs
343
344 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
345 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
346 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
347 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
348                         \env body -> Case env (mkWildId ty) (exprType body)
349                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
350   where
351     ty = mkCoreTupTy tys
352
353 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
354 mkLiftEnv lc [ty] [v]
355   = return (Var v, \env body ->
356                    do
357                      len <- lengthPA (Var v)
358                      return . Let (NonRec v env)
359                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
360
361 -- NOTE: this transparently deals with empty environments
362 mkLiftEnv lc tys vs
363   = do
364       (env_tc, env_tyargs) <- parrayReprTyCon vty
365       let [env_con] = tyConDataCons env_tc
366           
367           env = Var (dataConWrapId env_con)
368                 `mkTyApps`  env_tyargs
369                 `mkVarApps` (lc : vs)
370
371           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
372                           in
373                           return $ Case scrut (mkWildId (exprType scrut))
374                                         (exprType body)
375                                         [(DataAlt env_con, lc : bndrs, body)]
376       return (env, bind)
377   where
378     vty = mkCoreTupTy tys
379
380     bndrs | null vs   = [mkWildId unitTy]
381           | otherwise = vs
382