Generate conversion to PRepr during vectorisation
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag,
5   splitClosureTy,
6   mkPReprType, mkPReprAlts,
7   mkPADictType, mkPArrayType,
8   parrayReprTyCon, parrayReprDataCon, mkVScrut,
9   paDictArgType, paDictOfType, paDFunType,
10   paMethod, lengthPA, replicatePA, emptyPA, liftPA,
11   polyAbstract, polyApply, polyVApply,
12   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
13   buildClosure, buildClosures,
14   mkClosureApp
15 ) where
16
17 #include "HsVersions.h"
18
19 import VectCore
20 import VectMonad
21
22 import DsUtils
23 import CoreSyn
24 import CoreUtils
25 import Type
26 import TypeRep
27 import TyCon
28 import DataCon            ( DataCon, dataConWrapId, dataConTag )
29 import Var
30 import Id                 ( mkWildId )
31 import MkId               ( unwrapFamInstScrut )
32 import PrelNames
33 import TysWiredIn
34 import BasicTypes         ( Boxity(..) )
35
36 import Outputable
37 import FastString
38
39 import Control.Monad         ( liftM, zipWithM_ )
40
41 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
42 collectAnnTypeArgs expr = go expr []
43   where
44     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
45     go e                             tys = (e, tys)
46
47 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
48 collectAnnTypeBinders expr = go [] expr
49   where
50     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
51     go bs e                           = (reverse bs, e)
52
53 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
54 collectAnnValBinders expr = go [] expr
55   where
56     go bs (_, AnnLam b e) | isId b = go (b:bs) e
57     go bs e                        = (reverse bs, e)
58
59 isAnnTypeArg :: AnnExpr b ann -> Bool
60 isAnnTypeArg (_, AnnType t) = True
61 isAnnTypeArg _              = False
62
63 mkDataConTag :: DataCon -> CoreExpr
64 mkDataConTag dc = mkConApp intDataCon [mkIntLitInt $ dataConTag dc]
65
66 isClosureTyCon :: TyCon -> Bool
67 isClosureTyCon tc = tyConName tc == closureTyConName
68
69 splitClosureTy :: Type -> (Type, Type)
70 splitClosureTy ty
71   | Just (tc, [arg_ty, res_ty]) <- splitTyConApp_maybe ty
72   , isClosureTyCon tc
73   = (arg_ty, res_ty)
74
75   | otherwise = pprPanic "splitClosureTy" (ppr ty)
76
77 isPArrayTyCon :: TyCon -> Bool
78 isPArrayTyCon tc = tyConName tc == parrayTyConName
79
80 splitPArrayTy :: Type -> Type
81 splitPArrayTy ty
82   | Just (tc, [arg_ty]) <- splitTyConApp_maybe ty
83   , isPArrayTyCon tc
84   = arg_ty
85
86   | otherwise = pprPanic "splitPArrayTy" (ppr ty)
87
88 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
89 mkBuiltinTyConApp get_tc tys
90   = do
91       tc <- builtin get_tc
92       return $ mkTyConApp tc tys
93
94 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
95 mkBuiltinTyConApps get_tc tys ty
96   = do
97       tc <- builtin get_tc
98       return $ foldr (mk tc) ty tys
99   where
100     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
101
102 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
103 mkBuiltinTyConApps1 get_tc dft [] = return dft
104 mkBuiltinTyConApps1 get_tc dft tys
105   = do
106       tc <- builtin get_tc
107       case tys of
108         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
109         _  -> return $ foldr1 (mk tc) tys
110   where
111     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
112
113 mkPReprType :: [[Type]] -> VM Type
114 mkPReprType [] = return unitTy
115 mkPReprType tys
116   = do
117       embed <- builtin embedTyCon
118       cross <- builtin crossTyCon
119       plus  <- builtin plusTyCon
120
121       let mk_embed ty      = mkTyConApp embed [ty]
122           mk_cross ty1 ty2 = mkTyConApp cross [ty1, ty2]
123           mk_plus  ty1 ty2 = mkTyConApp plus  [ty1, ty2]
124
125           mk_tup   []      = unitTy
126           mk_tup   tys     = foldr1 mk_cross tys
127
128           mk_sum   []      = unitTy
129           mk_sum   tys     = foldr1 mk_plus  tys
130
131       return . mk_sum
132              . map (mk_tup . map mk_embed)
133              $ tys
134
135 mkPReprAlts :: [[CoreExpr]] -> VM ([CoreExpr], Type)
136 mkPReprAlts ess
137   = do
138       embed_tc <- builtin embedTyCon
139       embed_dc <- builtin embedDataCon
140       cross_tc <- builtin crossTyCon
141       cross_dc <- builtin crossDataCon
142       plus_tc  <- builtin plusTyCon
143       left_dc  <- builtin leftDataCon
144       right_dc <- builtin rightDataCon
145
146       let mk_embed (expr, ty, pa)
147             = (mkConApp   embed_dc [Type ty, pa, expr],
148                mkTyConApp embed_tc [ty])
149
150           mk_cross (expr1, ty1) (expr2, ty2)
151             = (mkConApp   cross_dc [Type ty1, Type ty2, expr1, expr2],
152                mkTyConApp cross_tc [ty1, ty2])
153
154           mk_tup [] = (Var unitDataConId, unitTy)
155           mk_tup es = foldr1 mk_cross es
156
157           mk_sum []           = ([Var unitDataConId], unitTy)
158           mk_sum [(expr, ty)] = ([expr], ty)
159           mk_sum ((expr, lty) : es)
160             = let (alts, rty) = mk_sum es
161               in
162               (mkConApp left_dc [Type lty, Type rty, expr]
163                  : [mkConApp right_dc [Type lty, Type rty, alt] | alt <- alts],
164                mkTyConApp plus_tc [lty, rty])
165       
166       liftM (mk_sum . map (mk_tup . map mk_embed))
167             (mapM (mapM init) ess)
168   where
169     init expr = let ty = exprType expr
170                 in do
171                      pa <- paDictOfType ty
172                      return (expr, ty, pa)
173
174 mkClosureType :: Type -> Type -> VM Type
175 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
176
177 mkClosureTypes :: [Type] -> Type -> VM Type
178 mkClosureTypes = mkBuiltinTyConApps closureTyCon
179
180 mkPADictType :: Type -> VM Type
181 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
182
183 mkPArrayType :: Type -> VM Type
184 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
185
186 parrayReprTyCon :: Type -> VM (TyCon, [Type])
187 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
188
189 parrayReprDataCon :: Type -> VM (DataCon, [Type])
190 parrayReprDataCon ty
191   = do
192       (tc, arg_tys) <- parrayReprTyCon ty
193       let [dc] = tyConDataCons tc
194       return (dc, arg_tys)
195
196 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
197 mkVScrut (ve, le)
198   = do
199       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
200       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
201
202 paDictArgType :: TyVar -> VM (Maybe Type)
203 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
204   where
205     go ty k | Just k' <- kindView k = go ty k'
206     go ty (FunTy k1 k2)
207       = do
208           tv   <- newTyVar FSLIT("a") k1
209           mty1 <- go (TyVarTy tv) k1
210           case mty1 of
211             Just ty1 -> do
212                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
213                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
214             Nothing  -> go ty k2
215
216     go ty k
217       | isLiftedTypeKind k
218       = liftM Just (mkPADictType ty)
219
220     go ty k = return Nothing
221
222 paDictOfType :: Type -> VM CoreExpr
223 paDictOfType ty = paDictOfTyApp ty_fn ty_args
224   where
225     (ty_fn, ty_args) = splitAppTys ty
226
227 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
228 paDictOfTyApp ty_fn ty_args
229   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
230 paDictOfTyApp (TyVarTy tv) ty_args
231   = do
232       dfun <- maybeV (lookupTyVarPA tv)
233       paDFunApply dfun ty_args
234 paDictOfTyApp (TyConApp tc _) ty_args
235   = do
236       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
237       paDFunApply (Var dfun) ty_args
238 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
239
240 paDFunType :: TyCon -> VM Type
241 paDFunType tc
242   = do
243       margs <- mapM paDictArgType tvs
244       res   <- mkPADictType (mkTyConApp tc arg_tys)
245       return . mkForAllTys tvs
246              $ mkFunTys [arg | Just arg <- margs] res
247   where
248     tvs = tyConTyVars tc
249     arg_tys = mkTyVarTys tvs
250
251 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
252 paDFunApply dfun tys
253   = do
254       dicts <- mapM paDictOfType tys
255       return $ mkApps (mkTyApps dfun tys) dicts
256
257 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
258 paMethod method ty
259   = do
260       fn   <- builtin method
261       dict <- paDictOfType ty
262       return $ mkApps (Var fn) [Type ty, dict]
263
264 lengthPA :: CoreExpr -> VM CoreExpr
265 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
266   where
267     ty = splitPArrayTy (exprType x)
268
269 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
270 replicatePA len x = liftM (`mkApps` [len,x])
271                           (paMethod replicatePAVar (exprType x))
272
273 emptyPA :: Type -> VM CoreExpr
274 emptyPA = paMethod emptyPAVar
275
276 liftPA :: CoreExpr -> VM CoreExpr
277 liftPA x
278   = do
279       lc <- builtin liftingContext
280       replicatePA (Var lc) x
281
282 newLocalVVar :: FastString -> Type -> VM VVar
283 newLocalVVar fs vty
284   = do
285       lty <- mkPArrayType vty
286       vv  <- newLocalVar fs vty
287       lv  <- newLocalVar fs lty
288       return (vv,lv)
289
290 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
291 polyAbstract tvs p
292   = localV
293   $ do
294       mdicts <- mapM mk_dict_var tvs
295       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
296       p (mk_lams mdicts)
297   where
298     mk_dict_var tv = do
299                        r <- paDictArgType tv
300                        case r of
301                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
302                          Nothing -> return Nothing
303
304     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
305
306 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
307 polyApply expr tys
308   = do
309       dicts <- mapM paDictOfType tys
310       return $ expr `mkTyApps` tys `mkApps` dicts
311
312 polyVApply :: VExpr -> [Type] -> VM VExpr
313 polyVApply expr tys
314   = do
315       dicts <- mapM paDictOfType tys
316       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
317
318 hoistBinding :: Var -> CoreExpr -> VM ()
319 hoistBinding v e = updGEnv $ \env ->
320   env { global_bindings = (v,e) : global_bindings env }
321
322 hoistExpr :: FastString -> CoreExpr -> VM Var
323 hoistExpr fs expr
324   = do
325       var <- newLocalVar fs (exprType expr)
326       hoistBinding var expr
327       return var
328
329 hoistVExpr :: VExpr -> VM VVar
330 hoistVExpr (ve, le)
331   = do
332       fs <- getBindName
333       vv <- hoistExpr ('v' `consFS` fs) ve
334       lv <- hoistExpr ('l' `consFS` fs) le
335       return (vv, lv)
336
337 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
338 hoistPolyVExpr tvs p
339   = do
340       expr <- closedV . polyAbstract tvs $ \abstract ->
341               liftM (mapVect abstract) p
342       fn   <- hoistVExpr expr
343       polyVApply (vVar fn) (mkTyVarTys tvs)
344
345 takeHoisted :: VM [(Var, CoreExpr)]
346 takeHoisted
347   = do
348       env <- readGEnv id
349       setGEnv $ env { global_bindings = [] }
350       return $ global_bindings env
351
352 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
353 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
354   = do
355       dict <- paDictOfType env_ty
356       mkv  <- builtin mkClosureVar
357       mkl  <- builtin mkClosurePVar
358       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
359               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
360
361 mkClosureApp :: VExpr -> VExpr -> VM VExpr
362 mkClosureApp (vclo, lclo) (varg, larg)
363   = do
364       vapply <- builtin applyClosureVar
365       lapply <- builtin applyClosurePVar
366       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
367               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
368   where
369     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
370
371 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
372 buildClosures tvs vars [] res_ty mk_body
373   = mk_body
374 buildClosures tvs vars [arg_ty] res_ty mk_body
375   = buildClosure tvs vars arg_ty res_ty mk_body
376 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
377   = do
378       res_ty' <- mkClosureTypes arg_tys res_ty
379       arg <- newLocalVVar FSLIT("x") arg_ty
380       buildClosure tvs vars arg_ty res_ty'
381         . hoistPolyVExpr tvs
382         $ do
383             lc <- builtin liftingContext
384             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
385             return $ vLams lc (vars ++ [arg]) clo
386
387 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
388 --   where
389 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
390 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
391 --
392 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
393 buildClosure tvs vars arg_ty res_ty mk_body
394   = do
395       (env_ty, env, bind) <- buildEnv vars
396       env_bndr <- newLocalVVar FSLIT("env") env_ty
397       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
398
399       fn <- hoistPolyVExpr tvs
400           $ do
401               lc    <- builtin liftingContext
402               body  <- mk_body
403               body' <- bind (vVar env_bndr)
404                             (vVarApps lc body (vars ++ [arg_bndr]))
405               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
406
407       mkClosure arg_ty res_ty env_ty fn env
408
409 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
410 buildEnv vvs
411   = do
412       lc <- builtin liftingContext
413       let (ty, venv, vbind) = mkVectEnv tys vs
414       (lenv, lbind) <- mkLiftEnv lc tys ls
415       return (ty, (venv, lenv),
416               \(venv,lenv) (vbody,lbody) ->
417               do
418                 let vbody' = vbind venv vbody
419                 lbody' <- lbind lenv lbody
420                 return (vbody', lbody'))
421   where
422     (vs,ls) = unzip vvs
423     tys     = map idType vs
424
425 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
426 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
427 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
428 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
429                         \env body -> Case env (mkWildId ty) (exprType body)
430                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
431   where
432     ty = mkCoreTupTy tys
433
434 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
435 mkLiftEnv lc [ty] [v]
436   = return (Var v, \env body ->
437                    do
438                      len <- lengthPA (Var v)
439                      return . Let (NonRec v env)
440                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
441
442 -- NOTE: this transparently deals with empty environments
443 mkLiftEnv lc tys vs
444   = do
445       (env_tc, env_tyargs) <- parrayReprTyCon vty
446       let [env_con] = tyConDataCons env_tc
447           
448           env = Var (dataConWrapId env_con)
449                 `mkTyApps`  env_tyargs
450                 `mkVarApps` (lc : vs)
451
452           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
453                           in
454                           return $ Case scrut (mkWildId (exprType scrut))
455                                         (exprType body)
456                                         [(DataAlt env_con, lc : bndrs, body)]
457       return (env, bind)
458   where
459     vty = mkCoreTupTy tys
460
461     bndrs | null vs   = [mkWildId unitTy]
462           | otherwise = vs
463