7b0e4af4e12ba37c128e216ac784b4b8ca5f2dd3
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   splitClosureTy,
4   mkPADictType, mkPArrayType,
5   paDictArgType, paDictOfType,
6   paMethod, lengthPA, replicatePA, emptyPA,
7   polyAbstract, polyApply,
8   lookupPArrayFamInst,
9   hoistExpr, takeHoisted
10 ) where
11
12 #include "HsVersions.h"
13
14 import VectMonad
15
16 import CoreSyn
17 import CoreUtils
18 import Type
19 import TypeRep
20 import TyCon
21 import Var
22 import PrelNames
23
24 import Outputable
25 import FastString
26
27 import Control.Monad         ( liftM, zipWithM_ )
28
29 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
30 collectAnnTypeArgs expr = go expr []
31   where
32     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
33     go e                             tys = (e, tys)
34
35 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
36 collectAnnTypeBinders expr = go [] expr
37   where
38     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
39     go bs e                           = (reverse bs, e)
40
41 isAnnTypeArg :: AnnExpr b ann -> Bool
42 isAnnTypeArg (_, AnnType t) = True
43 isAnnTypeArg _              = False
44
45 isClosureTyCon :: TyCon -> Bool
46 isClosureTyCon tc = tyConName tc == closureTyConName
47
48 splitClosureTy :: Type -> (Type, Type)
49 splitClosureTy ty
50   | Just (tc, [arg_ty, res_ty]) <- splitTyConApp_maybe ty
51   , isClosureTyCon tc
52   = (arg_ty, res_ty)
53
54   | otherwise = pprPanic "splitClosureTy" (ppr ty)
55
56 isPArrayTyCon :: TyCon -> Bool
57 isPArrayTyCon tc = tyConName tc == parrayTyConName
58
59 splitPArrayTy :: Type -> Type
60 splitPArrayTy ty
61   | Just (tc, [arg_ty]) <- splitTyConApp_maybe ty
62   , isPArrayTyCon tc
63   = arg_ty
64
65   | otherwise = pprPanic "splitPArrayTy" (ppr ty)
66
67 mkPADictType :: Type -> VM Type
68 mkPADictType ty
69   = do
70       tc <- builtin paDictTyCon
71       return $ TyConApp tc [ty]
72
73 mkPArrayType :: Type -> VM Type
74 mkPArrayType ty
75   = do
76       tc <- builtin parrayTyCon
77       return $ TyConApp tc [ty]
78
79 paDictArgType :: TyVar -> VM (Maybe Type)
80 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
81   where
82     go ty k | Just k' <- kindView k = go ty k'
83     go ty (FunTy k1 k2)
84       = do
85           tv   <- newTyVar FSLIT("a") k1
86           mty1 <- go (TyVarTy tv) k1
87           case mty1 of
88             Just ty1 -> do
89                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
90                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
91             Nothing  -> go ty k2
92
93     go ty k
94       | isLiftedTypeKind k
95       = liftM Just (mkPADictType ty)
96
97     go ty k = return Nothing
98
99 paDictOfType :: Type -> VM CoreExpr
100 paDictOfType ty = paDictOfTyApp ty_fn ty_args
101   where
102     (ty_fn, ty_args) = splitAppTys ty
103
104 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
105 paDictOfTyApp ty_fn ty_args
106   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
107 paDictOfTyApp (TyVarTy tv) ty_args
108   = do
109       dfun <- maybeV (lookupTyVarPA tv)
110       paDFunApply dfun ty_args
111 paDictOfTyApp (TyConApp tc _) ty_args
112   = do
113       pa_class <- builtin paClass
114       (dfun, ty_args') <- lookupInst pa_class [TyConApp tc ty_args]
115       paDFunApply (Var dfun) ty_args'
116 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
117
118 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
119 paDFunApply dfun tys
120   = do
121       dicts <- mapM paDictOfType tys
122       return $ mkApps (mkTyApps dfun tys) dicts
123
124 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
125 paMethod method ty
126   = do
127       fn   <- builtin method
128       dict <- paDictOfType ty
129       return $ mkApps (Var fn) [Type ty, dict]
130
131 lengthPA :: CoreExpr -> VM CoreExpr
132 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
133   where
134     ty = splitPArrayTy (exprType x)
135
136 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
137 replicatePA len x = liftM (`mkApps` [len,x])
138                           (paMethod replicatePAVar (exprType x))
139
140 emptyPA :: Type -> VM CoreExpr
141 emptyPA = paMethod emptyPAVar
142
143 polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
144 polyAbstract tvs p
145   = localV
146   $ do
147       mdicts <- mapM mk_dict_var tvs
148       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
149       p (mk_lams mdicts)
150   where
151     mk_dict_var tv = do
152                        r <- paDictArgType tv
153                        case r of
154                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
155                          Nothing -> return Nothing
156
157     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
158
159 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
160 polyApply expr tys
161   = do
162       dicts <- mapM paDictOfType tys
163       return $ expr `mkTyApps` tys `mkApps` dicts
164
165 lookupPArrayFamInst :: Type -> VM (TyCon, [Type])
166 lookupPArrayFamInst ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
167
168 hoistExpr :: FastString -> CoreExpr -> VM Var
169 hoistExpr fs expr
170   = do
171       var <- newLocalVar fs (exprType expr)
172       updGEnv $ \env ->
173         env { global_bindings = (var, expr) : global_bindings env }
174       return var
175
176 takeHoisted :: VM [(Var, CoreExpr)]
177 takeHoisted
178   = do
179       env <- readGEnv id
180       setGEnv $ env { global_bindings = [] }
181       return $ global_bindings env
182