Use the right dictionary when calling lengthPA
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   splitClosureTy,
4   mkPADictType, mkPArrayType,
5   paDictArgType, paDictOfType,
6   paMethod, lengthPA, replicatePA, emptyPA,
7   abstractOverTyVars, applyToTypes,
8   lookupPArrayFamInst,
9   hoistExpr, takeHoisted
10 ) where
11
12 #include "HsVersions.h"
13
14 import VectMonad
15
16 import CoreSyn
17 import CoreUtils
18 import Type
19 import TypeRep
20 import TyCon
21 import Var
22 import PrelNames
23
24 import Outputable
25 import FastString
26
27 import Control.Monad         ( liftM, zipWithM_ )
28
29 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
30 collectAnnTypeArgs expr = go expr []
31   where
32     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
33     go e                             tys = (e, tys)
34
35 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
36 collectAnnTypeBinders expr = go [] expr
37   where
38     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
39     go bs e                           = (reverse bs, e)
40
41 isAnnTypeArg :: AnnExpr b ann -> Bool
42 isAnnTypeArg (_, AnnType t) = True
43 isAnnTypeArg _              = False
44
45 isClosureTyCon :: TyCon -> Bool
46 isClosureTyCon tc = tyConName tc == closureTyConName
47
48 splitClosureTy :: Type -> (Type, Type)
49 splitClosureTy ty
50   | Just (tc, [arg_ty, res_ty]) <- splitTyConApp_maybe ty
51   , isClosureTyCon tc
52   = (arg_ty, res_ty)
53
54   | otherwise = pprPanic "splitClosureTy" (ppr ty)
55
56 isPArrayTyCon :: TyCon -> Bool
57 isPArrayTyCon tc = tyConName tc == parrayTyConName
58
59 splitPArrayTy :: Type -> Type
60 splitPArrayTy ty
61   | Just (tc, [arg_ty]) <- splitTyConApp_maybe ty
62   , isPArrayTyCon tc
63   = arg_ty
64
65   | otherwise = pprPanic "splitPArrayTy" (ppr ty)
66
67 mkPADictType :: Type -> VM Type
68 mkPADictType ty
69   = do
70       tc <- builtin paDictTyCon
71       return $ TyConApp tc [ty]
72
73 mkPArrayType :: Type -> VM Type
74 mkPArrayType ty
75   = do
76       tc <- builtin parrayTyCon
77       return $ TyConApp tc [ty]
78
79 paDictArgType :: TyVar -> VM (Maybe Type)
80 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
81   where
82     go ty k | Just k' <- kindView k = go ty k'
83     go ty (FunTy k1 k2)
84       = do
85           tv   <- newTyVar FSLIT("a") k1
86           mty1 <- go (TyVarTy tv) k1
87           case mty1 of
88             Just ty1 -> do
89                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
90                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
91             Nothing  -> go ty k2
92
93     go ty k
94       | isLiftedTypeKind k
95       = liftM Just (mkPADictType ty)
96
97     go ty k = return Nothing
98
99 paDictOfType :: Type -> VM CoreExpr
100 paDictOfType ty = paDictOfTyApp ty_fn ty_args
101   where
102     (ty_fn, ty_args) = splitAppTys ty
103
104 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
105 paDictOfTyApp ty_fn ty_args
106   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
107 paDictOfTyApp (TyVarTy tv) ty_args
108   = do
109       dfun <- maybeV (lookupTyVarPA tv)
110       paDFunApply dfun ty_args
111 paDictOfTyApp (TyConApp tc _) ty_args
112   = do
113       pa_class <- builtin paClass
114       (dfun, ty_args') <- lookupInst pa_class [TyConApp tc ty_args]
115       paDFunApply (Var dfun) ty_args'
116 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
117
118 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
119 paDFunApply dfun tys
120   = do
121       dicts <- mapM paDictOfType tys
122       return $ mkApps (mkTyApps dfun tys) dicts
123
124 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
125 paMethod method ty
126   = do
127       fn   <- builtin method
128       dict <- paDictOfType ty
129       return $ mkApps (Var fn) [Type ty, dict]
130
131 lengthPA :: CoreExpr -> VM CoreExpr
132 lengthPA x = liftM (`App` x) (paMethod lengthPAVar ty)
133   where
134     ty = splitPArrayTy (exprType x)
135
136 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
137 replicatePA len x = liftM (`mkApps` [len,x])
138                           (paMethod replicatePAVar (exprType x))
139
140 emptyPA :: Type -> VM CoreExpr
141 emptyPA = paMethod emptyPAVar
142
143 abstractOverTyVars :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
144 abstractOverTyVars tvs p
145   = do
146       mdicts <- mapM mk_dict_var tvs
147       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
148       p (mk_lams mdicts)
149   where
150     mk_dict_var tv = do
151                        r <- paDictArgType tv
152                        case r of
153                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
154                          Nothing -> return Nothing
155
156     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
157
158 applyToTypes :: CoreExpr -> [Type] -> VM CoreExpr
159 applyToTypes expr tys
160   = do
161       dicts <- mapM paDictOfType tys
162       return $ expr `mkTyApps` tys `mkApps` dicts
163
164 lookupPArrayFamInst :: Type -> VM (TyCon, [Type])
165 lookupPArrayFamInst ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
166
167 hoistExpr :: FastString -> CoreExpr -> VM Var
168 hoistExpr fs expr
169   = do
170       var <- newLocalVar fs (exprType expr)
171       updGEnv $ \env ->
172         env { global_bindings = (var, expr) : global_bindings env }
173       return var
174
175 takeHoisted :: VM [(Var, CoreExpr)]
176 takeHoisted
177   = do
178       env <- readGEnv id
179       setGEnv $ env { global_bindings = [] }
180       return $ global_bindings env
181