Move vectorisation-related smart constructors into a separate module
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   splitClosureTy,
4   mkPADictType, mkPArrayType,
5   paDictArgType, paDictOfType,
6   paMethod, lengthPA, replicatePA, emptyPA,
7   polyAbstract, polyApply,
8   lookupPArrayFamInst,
9   hoistExpr, takeHoisted
10 ) where
11
12 #include "HsVersions.h"
13
14 import VectCore
15 import VectMonad
16
17 import DsUtils
18 import CoreSyn
19 import CoreUtils
20 import Type
21 import TypeRep
22 import TyCon
23 import DataCon            ( dataConWrapId )
24 import Var
25 import Id                 ( mkWildId )
26 import MkId               ( unwrapFamInstScrut )
27 import PrelNames
28 import TysWiredIn
29 import BasicTypes         ( Boxity(..) )
30
31 import Outputable
32 import FastString
33
34 import Control.Monad         ( liftM, zipWithM_ )
35
36 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
37 collectAnnTypeArgs expr = go expr []
38   where
39     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
40     go e                             tys = (e, tys)
41
42 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
43 collectAnnTypeBinders expr = go [] expr
44   where
45     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
46     go bs e                           = (reverse bs, e)
47
48 isAnnTypeArg :: AnnExpr b ann -> Bool
49 isAnnTypeArg (_, AnnType t) = True
50 isAnnTypeArg _              = False
51
52 isClosureTyCon :: TyCon -> Bool
53 isClosureTyCon tc = tyConName tc == closureTyConName
54
55 splitClosureTy :: Type -> (Type, Type)
56 splitClosureTy ty
57   | Just (tc, [arg_ty, res_ty]) <- splitTyConApp_maybe ty
58   , isClosureTyCon tc
59   = (arg_ty, res_ty)
60
61   | otherwise = pprPanic "splitClosureTy" (ppr ty)
62
63 isPArrayTyCon :: TyCon -> Bool
64 isPArrayTyCon tc = tyConName tc == parrayTyConName
65
66 splitPArrayTy :: Type -> Type
67 splitPArrayTy ty
68   | Just (tc, [arg_ty]) <- splitTyConApp_maybe ty
69   , isPArrayTyCon tc
70   = arg_ty
71
72   | otherwise = pprPanic "splitPArrayTy" (ppr ty)
73
74 mkPADictType :: Type -> VM Type
75 mkPADictType ty
76   = do
77       tc <- builtin paDictTyCon
78       return $ TyConApp tc [ty]
79
80 mkPArrayType :: Type -> VM Type
81 mkPArrayType ty
82   = do
83       tc <- builtin parrayTyCon
84       return $ TyConApp tc [ty]
85
86 paDictArgType :: TyVar -> VM (Maybe Type)
87 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
88   where
89     go ty k | Just k' <- kindView k = go ty k'
90     go ty (FunTy k1 k2)
91       = do
92           tv   <- newTyVar FSLIT("a") k1
93           mty1 <- go (TyVarTy tv) k1
94           case mty1 of
95             Just ty1 -> do
96                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
97                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
98             Nothing  -> go ty k2
99
100     go ty k
101       | isLiftedTypeKind k
102       = liftM Just (mkPADictType ty)
103
104     go ty k = return Nothing
105
106 paDictOfType :: Type -> VM CoreExpr
107 paDictOfType ty = paDictOfTyApp ty_fn ty_args
108   where
109     (ty_fn, ty_args) = splitAppTys ty
110
111 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
112 paDictOfTyApp ty_fn ty_args
113   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
114 paDictOfTyApp (TyVarTy tv) ty_args
115   = do
116       dfun <- maybeV (lookupTyVarPA tv)
117       paDFunApply dfun ty_args
118 paDictOfTyApp (TyConApp tc _) ty_args
119   = do
120       pa_class <- builtin paClass
121       (dfun, ty_args') <- lookupInst pa_class [TyConApp tc ty_args]
122       paDFunApply (Var dfun) ty_args'
123 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
124
125 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
126 paDFunApply dfun tys
127   = do
128       dicts <- mapM paDictOfType tys
129       return $ mkApps (mkTyApps dfun tys) dicts
130
131 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
132 paMethod method ty
133   = do
134       fn   <- builtin method
135       dict <- paDictOfType ty
136       return $ mkApps (Var fn) [Type ty, dict]
137
138 lengthPA :: CoreExpr -> VM CoreExpr
139 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
140   where
141     ty = splitPArrayTy (exprType x)
142
143 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
144 replicatePA len x = liftM (`mkApps` [len,x])
145                           (paMethod replicatePAVar (exprType x))
146
147 emptyPA :: Type -> VM CoreExpr
148 emptyPA = paMethod emptyPAVar
149
150 newLocalVVar :: FastString -> Type -> VM VVar
151 newLocalVVar fs vty
152   = do
153       lty <- mkPArrayType vty
154       vv  <- newLocalVar fs vty
155       lv  <- newLocalVar fs lty
156       return (vv,lv)
157
158 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
159 polyAbstract tvs p
160   = localV
161   $ do
162       mdicts <- mapM mk_dict_var tvs
163       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
164       p (mk_lams mdicts)
165   where
166     mk_dict_var tv = do
167                        r <- paDictArgType tv
168                        case r of
169                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
170                          Nothing -> return Nothing
171
172     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
173
174 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
175 polyApply expr tys
176   = do
177       dicts <- mapM paDictOfType tys
178       return $ expr `mkTyApps` tys `mkApps` dicts
179
180 lookupPArrayFamInst :: Type -> VM (TyCon, [Type])
181 lookupPArrayFamInst ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
182
183 hoistExpr :: FastString -> CoreExpr -> VM Var
184 hoistExpr fs expr
185   = do
186       var <- newLocalVar fs (exprType expr)
187       updGEnv $ \env ->
188         env { global_bindings = (var, expr) : global_bindings env }
189       return var
190
191 hoistPolyExpr :: FastString -> [TyVar] -> CoreExpr -> VM CoreExpr
192 hoistPolyExpr fs tvs expr
193   = do
194       poly_expr <- closedV . polyAbstract tvs $ \abstract -> return (abstract expr)
195       fn        <- hoistExpr fs poly_expr
196       polyApply (Var fn) (mkTyVarTys tvs)
197
198 hoistPolyVExpr :: FastString -> [TyVar] -> VExpr -> VM VExpr
199 hoistPolyVExpr fs tvs (ve, le)
200   = do
201       ve' <- hoistPolyExpr ('v' `consFS` fs) tvs ve
202       le' <- hoistPolyExpr ('l' `consFS` fs) tvs le
203       return (ve',le')
204
205 takeHoisted :: VM [(Var, CoreExpr)]
206 takeHoisted
207   = do
208       env <- readGEnv id
209       setGEnv $ env { global_bindings = [] }
210       return $ global_bindings env
211
212
213 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
214 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
215   = do
216       dict <- paDictOfType env_ty
217       mkv  <- builtin mkClosureVar
218       mkl  <- builtin mkClosurePVar
219       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
220               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
221
222 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
223 --   where
224 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
225 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
226
227 buildClosure :: [TyVar] -> Var -> [VVar] -> VVar -> VExpr -> VM VExpr
228 buildClosure tvs lv vars arg body
229   = do
230       (env_ty, env, bind) <- buildEnv lv vars
231       env_bndr            <- newLocalVVar FSLIT("env") env_ty
232
233       fn <- hoistPolyVExpr FSLIT("fn") tvs
234           . mkVLams [env_bndr, arg]
235           . bind (vVar env_bndr)
236           $ mkVVarApps lv body (vars ++ [arg])
237
238       mkClosure arg_ty res_ty env_ty fn env
239
240   where
241     arg_ty = idType (vectorised arg)
242     res_ty = exprType (vectorised body)
243
244
245 buildEnv :: Var -> [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VExpr)
246 buildEnv lv vvs
247   = do
248       let (ty, venv, vbind) = mkVectEnv tys vs
249       (lenv, lbind) <- mkLiftEnv lv tys ls
250       return (ty, (venv, lenv),
251               \(venv,lenv) (vbody,lbody) -> (vbind venv vbody, lbind lenv lbody))
252   where
253     (vs,ls) = unzip vvs
254     tys     = map idType vs
255
256 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
257 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
258 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
259 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
260                         \env body -> Case env (mkWildId ty) (exprType body)
261                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
262   where
263     ty = mkCoreTupTy tys
264
265 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
266 mkLiftEnv lv [ty] [v]
267   = do
268       len <- lengthPA (Var v)
269       return (Var v, \env body -> Let (NonRec v env)
270                                 $ Case len lv (exprType body) [(DEFAULT, [], body)])
271
272 -- NOTE: this transparently deals with empty environments
273 mkLiftEnv lv tys vs
274   = do
275       (env_tc, env_tyargs) <- lookupPArrayFamInst vty
276       let [env_con] = tyConDataCons env_tc
277           
278           env = Var (dataConWrapId env_con)
279                 `mkTyApps`  env_tyargs
280                 `mkVarApps` (lv : vs)
281
282           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
283                           in
284                           Case scrut (mkWildId (exprType scrut)) (exprType body)
285                             [(DataAlt env_con, lv : vs, body)]
286       return (env, bind)
287   where
288     vty = mkCoreTupTy tys
289