d1c1dab9748d437083db89ac8382ab5d72794eeb
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   collectAnnValBinders,
4   mkDataConTag,
5   splitClosureTy,
6   mkPRepr, mkToPRepr, mkFromPRepr,
7   mkPADictType, mkPArrayType, mkPReprType,
8   parrayReprTyCon, parrayReprDataCon, mkVScrut,
9   prDictOfType, prCoerce,
10   paDictArgType, paDictOfType, paDFunType,
11   paMethod, lengthPA, replicatePA, emptyPA, liftPA,
12   polyAbstract, polyApply, polyVApply,
13   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
14   buildClosure, buildClosures,
15   mkClosureApp
16 ) where
17
18 #include "HsVersions.h"
19
20 import VectCore
21 import VectMonad
22
23 import DsUtils
24 import CoreSyn
25 import CoreUtils
26 import Coercion
27 import Type
28 import TypeRep
29 import TyCon
30 import DataCon            ( DataCon, dataConWrapId, dataConTag )
31 import Var
32 import Id                 ( mkWildId )
33 import MkId               ( unwrapFamInstScrut )
34 import Name               ( Name )
35 import PrelNames
36 import TysWiredIn
37 import BasicTypes         ( Boxity(..) )
38
39 import Outputable
40 import FastString
41
42 import Control.Monad         ( liftM, liftM2, zipWithM_ )
43
44 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
45 collectAnnTypeArgs expr = go expr []
46   where
47     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
48     go e                             tys = (e, tys)
49
50 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
51 collectAnnTypeBinders expr = go [] expr
52   where
53     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
54     go bs e                           = (reverse bs, e)
55
56 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
57 collectAnnValBinders expr = go [] expr
58   where
59     go bs (_, AnnLam b e) | isId b = go (b:bs) e
60     go bs e                        = (reverse bs, e)
61
62 isAnnTypeArg :: AnnExpr b ann -> Bool
63 isAnnTypeArg (_, AnnType t) = True
64 isAnnTypeArg _              = False
65
66 mkDataConTag :: DataCon -> CoreExpr
67 mkDataConTag dc = mkConApp intDataCon [mkIntLitInt $ dataConTag dc]
68
69 splitUnTy :: String -> Name -> Type -> Type
70 splitUnTy s name ty
71   | Just (tc, [ty']) <- splitTyConApp_maybe ty
72   , tyConName tc == name
73   = ty'
74
75   | otherwise = pprPanic s (ppr ty)
76
77 splitBinTy :: String -> Name -> Type -> (Type, Type)
78 splitBinTy s name ty
79   | Just (tc, [ty1, ty2]) <- splitTyConApp_maybe ty
80   , tyConName tc == name
81   = (ty1, ty2)
82
83   | otherwise = pprPanic s (ppr ty)
84
85 splitCrossTy :: Type -> (Type, Type)
86 splitCrossTy = splitBinTy "splitCrossTy" ndpCrossTyConName
87
88 splitPlusTy :: Type -> (Type, Type)
89 splitPlusTy = splitBinTy "splitSumTy" ndpPlusTyConName
90
91 splitEmbedTy :: Type -> Type
92 splitEmbedTy = splitUnTy "splitEmbedTy" embedTyConName
93
94 splitClosureTy :: Type -> (Type, Type)
95 splitClosureTy = splitBinTy "splitClosureTy" closureTyConName
96
97 splitPArrayTy :: Type -> Type
98 splitPArrayTy = splitUnTy "splitPArrayTy" parrayTyConName
99
100 mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
101 mkBuiltinTyConApp get_tc tys
102   = do
103       tc <- builtin get_tc
104       return $ mkTyConApp tc tys
105
106 mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
107 mkBuiltinTyConApps get_tc tys ty
108   = do
109       tc <- builtin get_tc
110       return $ foldr (mk tc) ty tys
111   where
112     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
113
114 mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
115 mkBuiltinTyConApps1 get_tc dft [] = return dft
116 mkBuiltinTyConApps1 get_tc dft tys
117   = do
118       tc <- builtin get_tc
119       case tys of
120         [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
121         _  -> return $ foldr1 (mk tc) tys
122   where
123     mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
124
125 mkPRepr :: [[Type]] -> VM Type
126 mkPRepr [] = return unitTy
127 mkPRepr tys
128   = do
129       embed <- builtin embedTyCon
130       cross <- builtin crossTyCon
131       plus  <- builtin plusTyCon
132
133       let mk_embed ty      = mkTyConApp embed [ty]
134           mk_cross ty1 ty2 = mkTyConApp cross [ty1, ty2]
135           mk_plus  ty1 ty2 = mkTyConApp plus  [ty1, ty2]
136
137           mk_tup   []      = unitTy
138           mk_tup   tys     = foldr1 mk_cross tys
139
140           mk_sum   []      = unitTy
141           mk_sum   tys     = foldr1 mk_plus  tys
142
143       return . mk_sum
144              . map (mk_tup . map mk_embed)
145              $ tys
146
147 mkToPRepr :: [[CoreExpr]] -> VM ([CoreExpr], Type)
148 mkToPRepr ess
149   = do
150       embed_tc <- builtin embedTyCon
151       embed_dc <- builtin embedDataCon
152       cross_tc <- builtin crossTyCon
153       cross_dc <- builtin crossDataCon
154       plus_tc  <- builtin plusTyCon
155       left_dc  <- builtin leftDataCon
156       right_dc <- builtin rightDataCon
157
158       let mk_embed expr
159             = (mkConApp   embed_dc [Type ty, expr],
160                mkTyConApp embed_tc [ty])
161             where ty = exprType expr
162
163           mk_cross (expr1, ty1) (expr2, ty2)
164             = (mkConApp   cross_dc [Type ty1, Type ty2, expr1, expr2],
165                mkTyConApp cross_tc [ty1, ty2])
166
167           mk_tup [] = (Var unitDataConId, unitTy)
168           mk_tup es = foldr1 mk_cross es
169
170           mk_sum []           = ([Var unitDataConId], unitTy)
171           mk_sum [(expr, ty)] = ([expr], ty)
172           mk_sum ((expr, lty) : es)
173             = let (alts, rty) = mk_sum es
174               in
175               (mkConApp left_dc [Type lty, Type rty, expr]
176                  : [mkConApp right_dc [Type lty, Type rty, alt] | alt <- alts],
177                mkTyConApp plus_tc [lty, rty])
178
179       return . mk_sum $ map (mk_tup . map mk_embed) ess
180
181 mkFromPRepr :: CoreExpr -> Type -> [([Var], CoreExpr)] -> VM CoreExpr
182 mkFromPRepr scrut res_ty alts
183   = do
184       embed_dc <- builtin embedDataCon
185       cross_dc <- builtin crossDataCon
186       left_dc  <- builtin leftDataCon
187       right_dc <- builtin rightDataCon
188       pa_tc    <- builtin paTyCon
189
190       let un_embed expr ty var res
191             = Case expr (mkWildId ty) res_ty
192                    [(DataAlt embed_dc, [var], res)]
193
194           un_cross expr ty var1 var2 res
195             = Case expr (mkWildId ty) res_ty
196                 [(DataAlt cross_dc, [var1, var2], res)]
197
198           un_tup expr ty []    res = return res
199           un_tup expr ty [var] res = return $ un_embed expr ty var res
200           un_tup expr ty (var : vars) res
201             = do
202                 lv <- newLocalVar FSLIT("x") lty
203                 rv <- newLocalVar FSLIT("y") rty
204                 liftM (un_cross expr ty lv rv
205                       . un_embed (Var lv) lty var)
206                       (un_tup (Var rv) rty vars res)
207             where
208               (lty, rty) = splitCrossTy ty
209
210           un_plus expr ty var1 var2 res1 res2
211             = Case expr (mkWildId ty) res_ty
212                 [(DataAlt left_dc,  [var1], res1),
213                  (DataAlt right_dc, [var2], res2)]
214
215           un_sum expr ty [(vars, res)] = un_tup expr ty vars res
216           un_sum expr ty ((vars, res) : alts)
217             = do
218                 lv <- newLocalVar FSLIT("l") lty
219                 rv <- newLocalVar FSLIT("r") rty
220                 liftM2 (un_plus expr ty lv rv)
221                          (un_tup (Var lv) lty vars res)
222                          (un_sum (Var rv) rty alts)
223             where
224               (lty, rty) = splitPlusTy ty
225
226       un_sum scrut (exprType scrut) alts
227
228 mkClosureType :: Type -> Type -> VM Type
229 mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
230
231 mkClosureTypes :: [Type] -> Type -> VM Type
232 mkClosureTypes = mkBuiltinTyConApps closureTyCon
233
234 mkPReprType :: Type -> VM Type
235 mkPReprType ty = mkBuiltinTyConApp preprTyCon [ty]
236
237 mkPADictType :: Type -> VM Type
238 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
239
240 mkPArrayType :: Type -> VM Type
241 mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
242
243 parrayReprTyCon :: Type -> VM (TyCon, [Type])
244 parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
245
246 parrayReprDataCon :: Type -> VM (DataCon, [Type])
247 parrayReprDataCon ty
248   = do
249       (tc, arg_tys) <- parrayReprTyCon ty
250       let [dc] = tyConDataCons tc
251       return (dc, arg_tys)
252
253 mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
254 mkVScrut (ve, le)
255   = do
256       (tc, arg_tys) <- parrayReprTyCon (exprType ve)
257       return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
258
259 prDictOfType :: Type -> VM CoreExpr
260 prDictOfType orig_ty
261   | Just (tycon, ty_args) <- splitTyConApp_maybe orig_ty
262   = do
263       dfun <- traceMaybeV "prDictOfType" (ppr tycon) (lookupTyConPR tycon)
264       prDFunApply (Var dfun) ty_args
265
266 prDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
267 prDFunApply dfun tys
268   = do
269       args <- mapM mkDFunArg arg_tys
270       return $ mkApps mono_dfun args
271   where
272     mono_dfun    = mkTyApps dfun tys
273     (arg_tys, _) = splitFunTys (exprType mono_dfun)
274
275 mkDFunArg :: Type -> VM CoreExpr
276 mkDFunArg ty
277   | Just (tycon, [arg]) <- splitTyConApp_maybe ty
278
279   = let name = tyConName tycon
280
281         get_dict | name == paTyConName = paDictOfType
282                  | name == prTyConName = prDictOfType
283                  | otherwise           = pprPanic "mkDFunArg" (ppr ty)
284
285     in get_dict arg
286
287 mkDFunArg ty = pprPanic "mkDFunArg" (ppr ty)
288
289 prCoerce :: TyCon -> [Type] -> CoreExpr -> VM CoreExpr
290 prCoerce repr_tc args expr
291   | Just arg_co <- tyConFamilyCoercion_maybe repr_tc
292   = do
293       pr_tc <- builtin prTyCon
294
295       let co = mkAppCoercion (mkTyConApp pr_tc [])
296                              (mkSymCoercion (mkTyConApp arg_co args))
297
298       return $ mkCoerce co expr
299
300 paDictArgType :: TyVar -> VM (Maybe Type)
301 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
302   where
303     go ty k | Just k' <- kindView k = go ty k'
304     go ty (FunTy k1 k2)
305       = do
306           tv   <- newTyVar FSLIT("a") k1
307           mty1 <- go (TyVarTy tv) k1
308           case mty1 of
309             Just ty1 -> do
310                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
311                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
312             Nothing  -> go ty k2
313
314     go ty k
315       | isLiftedTypeKind k
316       = liftM Just (mkPADictType ty)
317
318     go ty k = return Nothing
319
320 paDictOfType :: Type -> VM CoreExpr
321 paDictOfType ty = paDictOfTyApp ty_fn ty_args
322   where
323     (ty_fn, ty_args) = splitAppTys ty
324
325 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
326 paDictOfTyApp ty_fn ty_args
327   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
328 paDictOfTyApp (TyVarTy tv) ty_args
329   = do
330       dfun <- maybeV (lookupTyVarPA tv)
331       paDFunApply dfun ty_args
332 paDictOfTyApp (TyConApp tc _) ty_args
333   = do
334       dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
335       paDFunApply (Var dfun) ty_args
336 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
337
338 paDFunType :: TyCon -> VM Type
339 paDFunType tc
340   = do
341       margs <- mapM paDictArgType tvs
342       res   <- mkPADictType (mkTyConApp tc arg_tys)
343       return . mkForAllTys tvs
344              $ mkFunTys [arg | Just arg <- margs] res
345   where
346     tvs = tyConTyVars tc
347     arg_tys = mkTyVarTys tvs
348
349 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
350 paDFunApply dfun tys
351   = do
352       dicts <- mapM paDictOfType tys
353       return $ mkApps (mkTyApps dfun tys) dicts
354
355 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
356 paMethod method ty
357   = do
358       fn   <- builtin method
359       dict <- paDictOfType ty
360       return $ mkApps (Var fn) [Type ty, dict]
361
362 lengthPA :: CoreExpr -> VM CoreExpr
363 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
364   where
365     ty = splitPArrayTy (exprType x)
366
367 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
368 replicatePA len x = liftM (`mkApps` [len,x])
369                           (paMethod replicatePAVar (exprType x))
370
371 emptyPA :: Type -> VM CoreExpr
372 emptyPA = paMethod emptyPAVar
373
374 liftPA :: CoreExpr -> VM CoreExpr
375 liftPA x
376   = do
377       lc <- builtin liftingContext
378       replicatePA (Var lc) x
379
380 newLocalVVar :: FastString -> Type -> VM VVar
381 newLocalVVar fs vty
382   = do
383       lty <- mkPArrayType vty
384       vv  <- newLocalVar fs vty
385       lv  <- newLocalVar fs lty
386       return (vv,lv)
387
388 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
389 polyAbstract tvs p
390   = localV
391   $ do
392       mdicts <- mapM mk_dict_var tvs
393       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
394       p (mk_lams mdicts)
395   where
396     mk_dict_var tv = do
397                        r <- paDictArgType tv
398                        case r of
399                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
400                          Nothing -> return Nothing
401
402     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
403
404 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
405 polyApply expr tys
406   = do
407       dicts <- mapM paDictOfType tys
408       return $ expr `mkTyApps` tys `mkApps` dicts
409
410 polyVApply :: VExpr -> [Type] -> VM VExpr
411 polyVApply expr tys
412   = do
413       dicts <- mapM paDictOfType tys
414       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
415
416 hoistBinding :: Var -> CoreExpr -> VM ()
417 hoistBinding v e = updGEnv $ \env ->
418   env { global_bindings = (v,e) : global_bindings env }
419
420 hoistExpr :: FastString -> CoreExpr -> VM Var
421 hoistExpr fs expr
422   = do
423       var <- newLocalVar fs (exprType expr)
424       hoistBinding var expr
425       return var
426
427 hoistVExpr :: VExpr -> VM VVar
428 hoistVExpr (ve, le)
429   = do
430       fs <- getBindName
431       vv <- hoistExpr ('v' `consFS` fs) ve
432       lv <- hoistExpr ('l' `consFS` fs) le
433       return (vv, lv)
434
435 hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
436 hoistPolyVExpr tvs p
437   = do
438       expr <- closedV . polyAbstract tvs $ \abstract ->
439               liftM (mapVect abstract) p
440       fn   <- hoistVExpr expr
441       polyVApply (vVar fn) (mkTyVarTys tvs)
442
443 takeHoisted :: VM [(Var, CoreExpr)]
444 takeHoisted
445   = do
446       env <- readGEnv id
447       setGEnv $ env { global_bindings = [] }
448       return $ global_bindings env
449
450 mkClosure :: Type -> Type -> Type -> VExpr -> VExpr -> VM VExpr
451 mkClosure arg_ty res_ty env_ty (vfn,lfn) (venv,lenv)
452   = do
453       dict <- paDictOfType env_ty
454       mkv  <- builtin mkClosureVar
455       mkl  <- builtin mkClosurePVar
456       return (Var mkv `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, venv],
457               Var mkl `mkTyApps` [arg_ty, res_ty, env_ty] `mkApps` [dict, vfn, lfn, lenv])
458
459 mkClosureApp :: VExpr -> VExpr -> VM VExpr
460 mkClosureApp (vclo, lclo) (varg, larg)
461   = do
462       vapply <- builtin applyClosureVar
463       lapply <- builtin applyClosurePVar
464       return (Var vapply `mkTyApps` [arg_ty, res_ty] `mkApps` [vclo, varg],
465               Var lapply `mkTyApps` [arg_ty, res_ty] `mkApps` [lclo, larg])
466   where
467     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
468
469 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
470 buildClosures tvs vars [] res_ty mk_body
471   = mk_body
472 buildClosures tvs vars [arg_ty] res_ty mk_body
473   = buildClosure tvs vars arg_ty res_ty mk_body
474 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
475   = do
476       res_ty' <- mkClosureTypes arg_tys res_ty
477       arg <- newLocalVVar FSLIT("x") arg_ty
478       buildClosure tvs vars arg_ty res_ty'
479         . hoistPolyVExpr tvs
480         $ do
481             lc <- builtin liftingContext
482             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
483             return $ vLams lc (vars ++ [arg]) clo
484
485 -- (clo <x1,...,xn> <f,f^>, aclo (Arr lc xs1 ... xsn) <f,f^>)
486 --   where
487 --     f  = \env v -> case env of <x1,...,xn> -> e x1 ... xn v
488 --     f^ = \env v -> case env of Arr l xs1 ... xsn -> e^ l x1 ... xn v
489 --
490 buildClosure :: [TyVar] -> [VVar] -> Type -> Type -> VM VExpr -> VM VExpr
491 buildClosure tvs vars arg_ty res_ty mk_body
492   = do
493       (env_ty, env, bind) <- buildEnv vars
494       env_bndr <- newLocalVVar FSLIT("env") env_ty
495       arg_bndr <- newLocalVVar FSLIT("arg") arg_ty
496
497       fn <- hoistPolyVExpr tvs
498           $ do
499               lc    <- builtin liftingContext
500               body  <- mk_body
501               body' <- bind (vVar env_bndr)
502                             (vVarApps lc body (vars ++ [arg_bndr]))
503               return (vLamsWithoutLC [env_bndr, arg_bndr] body')
504
505       mkClosure arg_ty res_ty env_ty fn env
506
507 buildEnv :: [VVar] -> VM (Type, VExpr, VExpr -> VExpr -> VM VExpr)
508 buildEnv vvs
509   = do
510       lc <- builtin liftingContext
511       let (ty, venv, vbind) = mkVectEnv tys vs
512       (lenv, lbind) <- mkLiftEnv lc tys ls
513       return (ty, (venv, lenv),
514               \(venv,lenv) (vbody,lbody) ->
515               do
516                 let vbody' = vbind venv vbody
517                 lbody' <- lbind lenv lbody
518                 return (vbody', lbody'))
519   where
520     (vs,ls) = unzip vvs
521     tys     = map idType vs
522
523 mkVectEnv :: [Type] -> [Var] -> (Type, CoreExpr, CoreExpr -> CoreExpr -> CoreExpr)
524 mkVectEnv []   []  = (unitTy, Var unitDataConId, \env body -> body)
525 mkVectEnv [ty] [v] = (ty, Var v, \env body -> Let (NonRec v env) body)
526 mkVectEnv tys  vs  = (ty, mkCoreTup (map Var vs),
527                         \env body -> Case env (mkWildId ty) (exprType body)
528                                        [(DataAlt (tupleCon Boxed (length vs)), vs, body)])
529   where
530     ty = mkCoreTupTy tys
531
532 mkLiftEnv :: Var -> [Type] -> [Var] -> VM (CoreExpr, CoreExpr -> CoreExpr -> VM CoreExpr)
533 mkLiftEnv lc [ty] [v]
534   = return (Var v, \env body ->
535                    do
536                      len <- lengthPA (Var v)
537                      return . Let (NonRec v env)
538                             $ Case len lc (exprType body) [(DEFAULT, [], body)])
539
540 -- NOTE: this transparently deals with empty environments
541 mkLiftEnv lc tys vs
542   = do
543       (env_tc, env_tyargs) <- parrayReprTyCon vty
544       let [env_con] = tyConDataCons env_tc
545           
546           env = Var (dataConWrapId env_con)
547                 `mkTyApps`  env_tyargs
548                 `mkVarApps` (lc : vs)
549
550           bind env body = let scrut = unwrapFamInstScrut env_tc env_tyargs env
551                           in
552                           return $ Case scrut (mkWildId (exprType scrut))
553                                         (exprType body)
554                                         [(DataAlt env_con, lc : bndrs, body)]
555       return (env, bind)
556   where
557     vty = mkCoreTupTy tys
558
559     bndrs | null vs   = [mkWildId unitTy]
560           | otherwise = vs
561