Refactoring
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   splitClosureTy,
4   mkPADictType, mkPArrayType,
5   paDictArgType, paDictOfType, paMethod,
6   lookupPArrayFamInst,
7   hoistExpr, takeHoisted
8 ) where
9
10 #include "HsVersions.h"
11
12 import VectMonad
13
14 import CoreSyn
15 import CoreUtils
16 import Type
17 import TypeRep
18 import TyCon
19 import Var
20 import PrelNames
21
22 import Outputable
23 import FastString
24
25 import Control.Monad         ( liftM )
26
27 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
28 collectAnnTypeArgs expr = go expr []
29   where
30     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
31     go e                             tys = (e, tys)
32
33 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
34 collectAnnTypeBinders expr = go [] expr
35   where
36     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
37     go bs e                           = (reverse bs, e)
38
39 isAnnTypeArg :: AnnExpr b ann -> Bool
40 isAnnTypeArg (_, AnnType t) = True
41 isAnnTypeArg _              = False
42
43 isClosureTyCon :: TyCon -> Bool
44 isClosureTyCon tc = tyConUnique tc == closureTyConKey
45
46 splitClosureTy :: Type -> (Type, Type)
47 splitClosureTy ty
48   | Just (tc, [arg_ty, res_ty]) <- splitTyConApp_maybe ty
49   , isClosureTyCon tc
50   = (arg_ty, res_ty)
51
52   | otherwise = pprPanic "splitClosureTy" (ppr ty)
53
54 mkPADictType :: Type -> VM Type
55 mkPADictType ty
56   = do
57       tc <- builtin paDictTyCon
58       return $ TyConApp tc [ty]
59
60 mkPArrayType :: Type -> VM Type
61 mkPArrayType ty
62   = do
63       tc <- builtin parrayTyCon
64       return $ TyConApp tc [ty]
65
66 paDictArgType :: TyVar -> VM (Maybe Type)
67 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
68   where
69     go ty k | Just k' <- kindView k = go ty k'
70     go ty (FunTy k1 k2)
71       = do
72           tv   <- newTyVar FSLIT("a") k1
73           mty1 <- go (TyVarTy tv) k1
74           case mty1 of
75             Just ty1 -> do
76                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
77                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
78             Nothing  -> go ty k2
79
80     go ty k
81       | isLiftedTypeKind k
82       = liftM Just (mkPADictType ty)
83
84     go ty k = return Nothing
85
86 paDictOfType :: Type -> VM CoreExpr
87 paDictOfType ty = paDictOfTyApp ty_fn ty_args
88   where
89     (ty_fn, ty_args) = splitAppTys ty
90
91 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
92 paDictOfTyApp ty_fn ty_args
93   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
94 paDictOfTyApp (TyVarTy tv) ty_args
95   = do
96       dfun <- maybeV (lookupTyVarPA tv)
97       paDFunApply dfun ty_args
98 paDictOfTyApp (TyConApp tc _) ty_args
99   = do
100       pa_class <- builtin paClass
101       (dfun, ty_args') <- lookupInst pa_class [TyConApp tc ty_args]
102       paDFunApply (Var dfun) ty_args'
103 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
104
105 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
106 paDFunApply dfun tys
107   = do
108       dicts <- mapM paDictOfType tys
109       return $ mkApps (mkTyApps dfun tys) dicts
110
111 paMethod :: (Builtins -> Var) -> Type -> VM CoreExpr
112 paMethod method ty
113   = do
114       fn   <- builtin method
115       dict <- paDictOfType ty
116       return $ mkApps (Var fn) [Type ty, dict]
117
118 lookupPArrayFamInst :: Type -> VM (TyCon, [Type])
119 lookupPArrayFamInst ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
120
121 hoistExpr :: FastString -> CoreExpr -> VM Var
122 hoistExpr fs expr
123   = do
124       var <- newLocalVar fs (exprType expr)
125       updGEnv $ \env ->
126         env { global_bindings = (var, expr) : global_bindings env }
127       return var
128
129 takeHoisted :: VM [(Var, CoreExpr)]
130 takeHoisted
131   = do
132       env <- readGEnv id
133       setGEnv $ env { global_bindings = [] }
134       return $ global_bindings env
135