Fix bug in vectorised DataCon worker generation
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag,
5   splitClosureTy,
6   mkPADictType, mkPArrayType,
7   paDictArgType, paDictOfType, paDFunType,
8   paMethod, lengthPA, replicatePA, emptyPA, liftPA,
9   polyAbstract, polyApply, polyVApply,
10   lookupPArrayFamInst,
11   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
12   buildClosure, buildClosures,
13   mkClosureApp
14 ) where
15
16 #include "HsVersions.h"
17
18 import VectCore
19 import VectMonad
20
21 import DsUtils
22 import CoreSyn
23 import CoreUtils
24 import Type
25 import TypeRep
26 import TyCon
27 import DataCon            ( DataCon, dataConWrapId, dataConTag )
28 import Var
29 import Id                 ( mkWildId )
30 import MkId               ( unwrapFamInstScrut )
31 import PrelNames
32 import TysWiredIn
33 import BasicTypes         ( Boxity(..) )
34
35 import Outputable
36 import FastString
37
38 import Control.Monad         ( liftM, zipWithM_ )
39
40 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
41 collectAnnTypeArgs expr = go expr []
42   where
43     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
44     go e                             tys = (e, tys)
45
46 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
47 collectAnnTypeBinders expr = go [] expr
48   where
49     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
50     go bs e                           = (reverse bs, e)
51
52 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
53 collectAnnValBinders expr = go [] expr
54   where
55     go bs (_, AnnLam b e) | isId b = go (b:bs) e
56     go bs e                        = (reverse bs, e)
57
58 isAnnTypeArg :: AnnExpr b ann -> Bool
59 isAnnTypeArg (_, AnnType t) = True
60 isAnnTypeArg _              = False
61
62 mkDataConTag :: DataCon -> CoreExpr
63 mkDataConTag dc = mkConApp intDataCon [mkIntLitInt $ dataConTag dc]
64
65 isClosureTyCon :: TyCon -> Bool
66 isClosureTyCon tc = tyConName tc == closureTyConName
67
68 splitClosureTy :: Type -> (Type, Type)
69 splitClosureTy ty
70   | Just (tc, [arg_ty, res_ty]) <- splitTyConApp_maybe ty
71   , isClosureTyCon tc
72   = (arg_ty, res_ty)
73
74   | otherwise = pprPanic "splitClosureTy" (ppr ty)
75
76 isPArrayTyCon :: TyCon -> Bool
77 isPArrayTyCon tc = tyConName tc == parrayTyConName
78
79 splitPArrayTy :: Type -> Type
80 splitPArrayTy ty
81   | Just (tc, [arg_ty]) <- splitTyConApp_maybe ty
82   , isPArrayTyCon tc
83   = arg_ty
84
85   | otherwise = pprPanic "splitPArrayTy" (ppr ty)
86
87 mkClosureType :: Type -> Type -> VM Type
88 mkClosureType arg_ty res_ty
89   = do
90       tc <- builtin closureTyCon
91       return $ mkTyConApp tc [arg_ty, res_ty]
92
93 mkClosureTypes :: [Type] -> Type -> VM Type
94 mkClosureTypes arg_tys res_ty
95   = do
96       tc <- builtin closureTyCon
97       return $ foldr (mk tc) res_ty arg_tys
98   where
99     mk tc arg_ty res_ty = mkTyConApp tc [arg_ty, res_ty]
100
101 mkPADictType :: Type -> VM Type
102 mkPADictType ty
103   = do
104       tc <- builtin paTyCon
105       return $ TyConApp tc [ty]
106
107 mkPArrayType :: Type -> VM Type
108 mkPArrayType ty
109   = do
110       tc <- builtin parrayTyCon
111       return $ TyConApp tc [ty]
112
113 paDictArgType :: TyVar -> VM (Maybe Type)
114 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
115   where
116     go ty k | Just k' <- kindView k = go ty k'
117     go ty (FunTy k1 k2)
118       = do
119           tv   <- newTyVar FSLIT("a") k1
120           mty1 <- go (TyVarTy tv) k1
121           case mty1 of
122             Just ty1 -> do
123                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
124                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
125             Nothing  -> go ty k2
126
127     go ty k
128       | isLiftedTypeKind k
129       = liftM Just (mkPADictType ty)
130
131     go ty k = return Nothing
132
133 paDictOfType :: Type -> VM CoreExpr
134 paDictOfType ty = paDictOfTyApp ty_fn ty_args
135   where
136     (ty_fn, ty_args) = splitAppTys ty
137
138 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
139 paDictOfTyApp ty_fn ty_args
140   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
141 paDictOfTyApp (TyVarTy tv) ty_args
142   = do
143       dfun <- maybeV (lookupTyVarPA tv)
144       paDFunApply dfun ty_args
145 paDictOfTyApp (TyConApp tc _) ty_args
146   = do
147       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
148       paDFunApply (Var dfun) ty_args
149 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
150
151 paDFunType :: TyCon -> VM Type
152 paDFunType tc
153   = do
154       margs <- mapM paDictArgType tvs
155       res   <- mkPADictType (mkTyConApp tc arg_tys)
156       return . mkForAllTys tvs
157              $ mkFunTys [arg | Just arg <- margs] res
158   where
159     tvs = tyConTyVars tc
160     arg_tys = mkTyVarTys tvs
161
162 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
163 paDFunApply dfun tys
164   = do
165       dicts <- mapM paDictOfType tys
166       return $ mkApps (mkTyApps dfun tys) dicts
167
168 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
169 paMethod method ty
170   = do
171       fn   <- builtin method
172       dict <- paDictOfType ty
173       return $ mkApps (Var fn) [Type ty, dict]
174
175 lengthPA :: CoreExpr -> VM CoreExpr
176 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
177   where
178     ty = splitPArrayTy (exprType x)
179
180 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
181 replicatePA len x = liftM (`mkApps` [len,x])
182                           (paMethod replicatePAVar (exprType x))
183
184 emptyPA :: Type -> VM CoreExpr
185 emptyPA = paMethod emptyPAVar
186
187 liftPA :: CoreExpr -> VM CoreExpr
188 liftPA x
189   = do
190       lc <- builtin liftingContext
191       replicatePA (Var lc) x
192
193 newLocalVVar :: FastString -> Type -> VM VVar
194 newLocalVVar fs vty
195   = do
196       lty <- mkPArrayType vty
197       vv  <- newLocalVar fs vty
198       lv  <- newLocalVar fs lty
199       return (vv,lv)
200
201 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
202 polyAbstract tvs p
203   = localV
204   $ do
205       mdicts <- mapM mk_dict_var tvs
206       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
207       p (mk_lams mdicts)
208   where
209     mk_dict_var tv = do
210                        r <- paDictArgType tv
211                        case r of
212                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
213                          Nothing -> return Nothing
214
215     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
216
217 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
218 polyApply expr tys
219   = do
220       dicts <- mapM paDictOfType tys
221       return $ expr `mkTyApps` tys `mkApps` dicts
222
223 polyVApply :: VExpr -> [Type] -> VM VExpr
224 polyVApply expr tys
225   = do
226       dicts <- mapM paDictOfType tys
227       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
228
229 lookupPArrayFamInst :: Type -> VM (TyCon, [Type])
230 lookupPArrayFamInst ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
231
232 hoistBinding :: Var -> CoreExpr -> VM ()
233 hoistBinding v e = updGEnv $ \env ->
234   env { global_bindings = (v,e) : global_bindings env }
235
236 hoistExpr :: FastString -> CoreExpr -> VM Var
237 hoistExpr fs expr
238   = do
239       var <- newLocalVar fs (exprType expr)
240       hoistBinding var expr
241       return var
242
243 hoistVExpr :: VExpr -> VM VVar
244 hoistVExpr (ve, le)
245   = do
246       fs <- getBindName
247       vv <- hoistExpr ('v' `consFS` fs) ve
248       lv <- hoistExpr ('l' `consFS` fs) le
249       return (vv, lv)
250
251 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
252 hoistPolyVExpr tvs p
253   = do
254       expr <- closedV . polyAbstract tvs $ \abstract ->
255               liftM (mapVect abstract) p
256       fn   <- hoistVExpr expr
257       polyVApply (vVar fn) (mkTyVarTys tvs)
258
259 takeHoisted :: VM [(Var, CoreExpr)]
260 takeHoisted
261   = do
262       env <- readGEnv id
263       setGEnv $ env { global_bindings = [] }
264       return $ global_bindings env
265
266 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
267 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
268   = do
269       dict <- paDictOfType env_ty
270       mkv  <- builtin mkClosureVar
271       mkl  <- builtin mkClosurePVar
272       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
273               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
274
275 mkClosureApp :: VExpr -> VExpr -> VM VExpr
276 mkClosureApp (vclo, lclo) (varg, larg)
277   = do
278       vapply <- builtin applyClosureVar
279       lapply <- builtin applyClosurePVar
280       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
281               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
282   where
283     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
284
285 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
286 buildClosures tvs vars [arg_ty] res_ty mk_body
287   = buildClosure tvs vars arg_ty res_ty mk_body
288 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
289   = do
290       res_ty' <- mkClosureTypes arg_tys res_ty
291       arg <- newLocalVVar FSLIT("x") arg_ty
292       buildClosure tvs vars arg_ty res_ty'
293         . hoistPolyVExpr tvs
294         $ do
295             lc <- builtin liftingContext
296             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
297             return $ vLams lc (vars ++ [arg]) clo
298
299 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
300 --   where
301 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
302 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
303 --
304 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
305 buildClosure tvs vars arg_ty res_ty mk_body
306   = do
307       (env_ty, env, bind) <- buildEnv vars
308       env_bndr <- newLocalVVar FSLIT("env") env_ty
309       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
310
311       fn <- hoistPolyVExpr tvs
312           $ do
313               lc    <- builtin liftingContext
314               body  <- mk_body
315               body' <- bind (vVar env_bndr)
316                             (vVarApps lc body (vars ++ [arg_bndr]))
317               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
318
319       mkClosure arg_ty res_ty env_ty fn env
320
321 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
322 buildEnv vvs
323   = do
324       lc <- builtin liftingContext
325       let (ty, venv, vbind) = mkVectEnv tys vs
326       (lenv, lbind) <- mkLiftEnv lc tys ls
327       return (ty, (venv, lenv),
328               \(venv,lenv) (vbody,lbody) ->
329               do
330                 let vbody' = vbind venv vbody
331                 lbody' <- lbind lenv lbody
332                 return (vbody', lbody'))
333   where
334     (vs,ls) = unzip vvs
335     tys     = map idType vs
336
337 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
338 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
339 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
340 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
341                         \env body -> Case env (mkWildId ty) (exprType body)
342                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
343   where
344     ty = mkCoreTupTy tys
345
346 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
347 mkLiftEnv lc [ty] [v]
348   = return (Var v, \env body ->
349                    do
350                      len <- lengthPA (Var v)
351                      return . Let (NonRec v env)
352                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
353
354 -- NOTE: this transparently deals with empty environments
355 mkLiftEnv lc tys vs
356   = do
357       (env_tc, env_tyargs) <- lookupPArrayFamInst vty
358       let [env_con] = tyConDataCons env_tc
359           
360           env = Var (dataConWrapId env_con)
361                 `mkTyApps`  env_tyargs
362                 `mkVarApps` (lc : vs)
363
364           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
365                           in
366                           return $ Case scrut (mkWildId (exprType scrut))
367                                         (exprType body)
368                                         [(DataAlt env_con, lc : bndrs, body)]
369       return (env, bind)
370   where
371     vty = mkCoreTupTy tys
372
373     bndrs | null vs   = [mkWildId unitTy]
374           | otherwise = vs
375