Only vectorise rank-1 expressions
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
1 module VectUtils (
2   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
3   splitClosureTy,
4   mkPADictType, mkPArrayType,
5   paDictArgType, paDictOfType
6 ) where
7
8 #include "HsVersions.h"
9
10 import VectMonad
11
12 import CoreSyn
13 import Type
14 import TypeRep
15 import TyCon
16 import Var
17 import PrelNames
18
19 import Outputable
20
21 import Control.Monad         ( liftM )
22
23 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
24 collectAnnTypeArgs expr = go expr []
25   where
26     go (_, AnnApp f (_, AnnType ty)) tys = go f (ty : tys)
27     go e                             tys = (e, tys)
28
29 collectAnnTypeBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
30 collectAnnTypeBinders expr = go [] expr
31   where
32     go bs (_, AnnLam b e) | isTyVar b = go (b:bs) e
33     go bs e                           = (reverse bs, e)
34
35 isAnnTypeArg :: AnnExpr b ann -> Bool
36 isAnnTypeArg (_, AnnType t) = True
37 isAnnTypeArg _              = False
38
39 isClosureTyCon :: TyCon -> Bool
40 isClosureTyCon tc = tyConUnique tc == closureTyConKey
41
42 splitClosureTy :: Type -> (Type, Type)
43 splitClosureTy ty
44   | Just (tc, [arg_ty, res_ty]) <- splitTyConApp_maybe ty
45   , isClosureTyCon tc
46   = (arg_ty, res_ty)
47
48   | otherwise = pprPanic "splitClosureTy" (ppr ty)
49
50 mkPADictType :: Type -> VM Type
51 mkPADictType ty
52   = do
53       tc <- builtin paDictTyCon
54       return $ TyConApp tc [ty]
55
56 mkPArrayType :: Type -> VM Type
57 mkPArrayType ty
58   = do
59       tc <- builtin parrayTyCon
60       return $ TyConApp tc [ty]
61
62 paDictArgType :: TyVar -> VM (Maybe Type)
63 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
64   where
65     go ty k | Just k' <- kindView k = go ty k'
66     go ty (FunTy k1 k2)
67       = do
68           tv   <- newTyVar FSLIT("a") k1
69           mty1 <- go (TyVarTy tv) k1
70           case mty1 of
71             Just ty1 -> do
72                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
73                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
74             Nothing  -> go ty k2
75
76     go ty k
77       | isLiftedTypeKind k
78       = liftM Just (mkPADictType ty)
79
80     go ty k = return Nothing
81
82 paDictOfType :: Type -> VM CoreExpr
83 paDictOfType ty = paDictOfTyApp ty_fn ty_args
84   where
85     (ty_fn, ty_args) = splitAppTys ty
86
87 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
88 paDictOfTyApp ty_fn ty_args
89   | Just ty_fn' <- coreView ty_fn = paDictOfTyApp ty_fn' ty_args
90 paDictOfTyApp (TyVarTy tv) ty_args
91   = do
92       dfun <- maybeV (lookupTyVarPA tv)
93       paDFunApply dfun ty_args
94 paDictOfTyApp (TyConApp tc _) ty_args
95   = do
96       pa_class <- builtin paClass
97       (dfun, ty_args') <- lookupInst pa_class [TyConApp tc ty_args]
98       paDFunApply (Var dfun) ty_args'
99 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
100
101 paDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
102 paDFunApply dfun tys
103   = do
104       dicts <- mapM paDictOfType tys
105       return $ mkApps (mkTyApps dfun tys) dicts
106