Move code
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   splitClosureTy,
4   mkPADictType, mkPArrayType,
5   paDictArgType, paDictOfType,
6   paMethod, lengthPA, replicatePA, emptyPA,
7   abstractOverTyVars, applyToTypes,
8   lookupPArrayFamInst,
9   hoistExpr, takeHoisted
10 ) where
11
12 #include "HsVersions.h"
13
14 import VectMonad
15
16 import CoreSyn
17 import CoreUtils
18 import Type
19 import TypeRep
20 import TyCon
21 import Var
22 import PrelNames
23
24 import Outputable
25 import FastString
26
27 import Control.Monad         ( liftM, zipWithM_ )
28
29 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
30 collectAnnTypeArgs expr = go expr []
31   where
32     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
33     go e                             tys = (e, tys)
34
35 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
36 collectAnnTypeBinders expr = go [] expr
37   where
38     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
39     go bs e                           = (reverse bs, e)
40
41 isAnnTypeArg :: AnnExpr b ann -> Bool
42 isAnnTypeArg (_, AnnType t) = True
43 isAnnTypeArg _              = False
44
45 isClosureTyCon :: TyCon -> Bool
46 isClosureTyCon tc = tyConUnique tc == closureTyConKey
47
48 splitClosureTy :: Type -> (Type, Type)
49 splitClosureTy ty
50   | Just (tc, [arg_ty, res_ty]) <- splitTyConApp_maybe ty
51   , isClosureTyCon tc
52   = (arg_ty, res_ty)
53
54   | otherwise = pprPanic "splitClosureTy" (ppr ty)
55
56 mkPADictType :: Type -> VM Type
57 mkPADictType ty
58   = do
59       tc <- builtin paDictTyCon
60       return $ TyConApp tc [ty]
61
62 mkPArrayType :: Type -> VM Type
63 mkPArrayType ty
64   = do
65       tc <- builtin parrayTyCon
66       return $ TyConApp tc [ty]
67
68 paDictArgType :: TyVar -> VM (Maybe Type)
69 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
70   where
71     go ty k | Just k' <- kindView k = go ty k'
72     go ty (FunTy k1 k2)
73       = do
74           tv   <- newTyVar FSLIT("a") k1
75           mty1 <- go (TyVarTy tv) k1
76           case mty1 of
77             Just ty1 -> do
78                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
79                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
80             Nothing  -> go ty k2
81
82     go ty k
83       | isLiftedTypeKind k
84       = liftM Just (mkPADictType ty)
85
86     go ty k = return Nothing
87
88 paDictOfType :: Type -> VM CoreExpr
89 paDictOfType ty = paDictOfTyApp ty_fn ty_args
90   where
91     (ty_fn, ty_args) = splitAppTys ty
92
93 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
94 paDictOfTyApp ty_fn ty_args
95   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
96 paDictOfTyApp (TyVarTy tv) ty_args
97   = do
98       dfun <- maybeV (lookupTyVarPA tv)
99       paDFunApply dfun ty_args
100 paDictOfTyApp (TyConApp tc _) ty_args
101   = do
102       pa_class <- builtin paClass
103       (dfun, ty_args') <- lookupInst pa_class [TyConApp tc ty_args]
104       paDFunApply (Var dfun) ty_args'
105 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
106
107 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
108 paDFunApply dfun tys
109   = do
110       dicts <- mapM paDictOfType tys
111       return $ mkApps (mkTyApps dfun tys) dicts
112
113 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
114 paMethod method ty
115   = do
116       fn   <- builtin method
117       dict <- paDictOfType ty
118       return $ mkApps (Var fn) [Type ty, dict]
119
120 lengthPA :: CoreExpr -> VM CoreExpr
121 lengthPA x = liftM (`App` x) (paMethod lengthPAVar (exprType x))
122
123 replicatePA :: CoreExpr -> CoreExpr -> VM CoreExpr
124 replicatePA len x = liftM (`mkApps` [len,x])
125                           (paMethod replicatePAVar (exprType x))
126
127 emptyPA :: Type -> VM CoreExpr
128 emptyPA = paMethod emptyPAVar
129
130 abstractOverTyVars :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
131 abstractOverTyVars tvs p
132   = do
133       mdicts <- mapM mk_dict_var tvs
134       zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
135       p (mk_lams mdicts)
136   where
137     mk_dict_var tv = do
138                        r <- paDictArgType tv
139                        case r of
140                          Just ty -> liftM Just (newLocalVar FSLIT("dPA") ty)
141                          Nothing -> return Nothing
142
143     mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
144
145 applyToTypes :: CoreExpr -> [Type] -> VM CoreExpr
146 applyToTypes expr tys
147   = do
148       dicts <- mapM paDictOfType tys
149       return $ expr `mkTyApps` tys `mkApps` dicts
150
151 lookupPArrayFamInst :: Type -> VM (TyCon, [Type])
152 lookupPArrayFamInst ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
153
154 hoistExpr :: FastString -> CoreExpr -> VM Var
155 hoistExpr fs expr
156   = do
157       var <- newLocalVar fs (exprType expr)
158       updGEnv $ \env ->
159         env { global_bindings = (var, expr) : global_bindings env }
160       return var
161
162 takeHoisted :: VM [(Var, CoreExpr)]
163 takeHoisted
164   = do
165       env <- readGEnv id
166       setGEnv $ env { global_bindings = [] }
167       return $ global_bindings env
168