vectoriser: adapt to new superclass story part I (dictionary construction)
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Utils / PADict.hs
1
2 module Vectorise.Utils.PADict (
3         mkPADictType,
4         paDictArgType,
5         paDictOfType,
6         paDFunType,
7         paMethod        
8 )
9 where
10 import Vectorise.Monad
11 import Vectorise.Builtins
12 import Vectorise.Utils.Base
13
14 import CoreSyn
15 import CoreUtils
16 import Coercion
17 import Type
18 import TypeRep
19 import TyCon
20 import Var
21 import Outputable
22 import FastString
23 import Control.Monad
24
25
26 mkPADictType :: Type -> VM Type
27 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
28
29
30 paDictArgType :: TyVar -> VM (Maybe Type)
31 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
32   where
33     go ty k | Just k' <- kindView k = go ty k'
34     go ty (FunTy k1 k2)
35       = do
36           tv   <- newTyVar (fsLit "a") k1
37           mty1 <- go (TyVarTy tv) k1
38           case mty1 of
39             Just ty1 -> do
40                           mty2 <- go (AppTy ty (TyVarTy tv)) k2
41                           return $ fmap (ForAllTy tv . FunTy ty1) mty2
42             Nothing  -> go ty k2
43
44     go ty k
45       | isLiftedTypeKind k
46       = liftM Just (mkPADictType ty)
47
48     go _ _ = return Nothing
49
50
51 -- | Get the PA dictionary for some type
52 paDictOfType :: Type -> VM CoreExpr
53 paDictOfType ty 
54   = paDictOfTyApp ty_fn ty_args
55   where
56     (ty_fn, ty_args) = splitAppTys ty
57
58     paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
59     paDictOfTyApp ty_fn ty_args
60         | Just ty_fn' <- coreView ty_fn 
61         = paDictOfTyApp ty_fn' ty_args
62
63     -- for type variables, look up the dfun and apply to the PA dictionaries
64     -- of the type arguments
65     paDictOfTyApp (TyVarTy tv) ty_args
66      = do dfun <- maybeV (lookupTyVarPA tv)
67           dicts <- mapM paDictOfType ty_args
68           return $ dfun `mkTyApps` ty_args `mkApps` dicts
69
70     -- for tycons, we also need to apply the dfun to the PR dictionary of
71     -- the representation type
72     paDictOfTyApp (TyConApp tc []) ty_args
73      = do
74          dfun <- maybeV $ lookupTyConPA tc
75          pr <- prDictOfPRepr tc ty_args
76          dicts <- mapM paDictOfType ty_args
77          return $ Var dfun `mkTyApps` ty_args `mkApps` (pr:dicts)
78
79     paDictOfTyApp _ _ = failure
80
81     failure = cantVectorise "Can't construct PA dictionary for type" (ppr ty)
82
83
84
85 paDFunType :: TyCon -> VM Type
86 paDFunType tc
87   = do
88       margs <- mapM paDictArgType tvs
89       res   <- mkPADictType (mkTyConApp tc arg_tys)
90       return . mkForAllTys tvs
91              $ mkFunTys [arg | Just arg <- margs] res
92   where
93     tvs = tyConTyVars tc
94     arg_tys = mkTyVarTys tvs
95
96 paMethod :: (Builtins -> Var) -> String -> Type -> VM CoreExpr
97 paMethod _ name ty
98   | Just tycon <- splitPrimTyCon ty
99   = liftM Var
100   . maybeCantVectoriseM "No PA method" (text name <+> text "for" <+> ppr tycon)
101   $ lookupPrimMethod tycon name
102
103 paMethod method _ ty
104   = do
105       fn   <- builtin method
106       dict <- paDictOfType ty
107       return $ mkApps (Var fn) [Type ty, dict]
108
109 -- | Get the PR (PRepr t) dictionary, where t is the tycon applied to the type
110 -- arguments
111 prDictOfPRepr :: TyCon -> [Type] -> VM CoreExpr
112 prDictOfPRepr tycon tys
113   = do
114       (prepr_tc, prepr_args) <- preprSynTyCon (mkTyConApp tycon tys)
115       case coreView (mkTyConApp prepr_tc prepr_args) of
116         Just rhs -> do
117                       dict <- prDictOfReprType rhs
118                       pr_co <- mkBuiltinCo prTyCon
119                       let Just arg_co = tyConFamilyCoercion_maybe prepr_tc
120                       let co = mkAppCoercion pr_co
121                              $ mkSymCoercion
122                              $ mkTyConApp arg_co prepr_args
123                       return $ mkCoerce co dict
124         Nothing  -> cantVectorise "Invalid PRepr type instance"
125                                   $ ppr $ mkTyConApp prepr_tc prepr_args
126
127 -- | Get the PR dictionary for a type. The argument must be a representation
128 -- type.
129 prDictOfReprType :: Type -> VM CoreExpr
130 prDictOfReprType ty
131   | Just (tycon, tyargs) <- splitTyConApp_maybe ty
132     = do
133         -- a representation tycon must have a PR instance
134         dfun <- maybeV $ lookupTyConPR tycon
135         prDFunApply dfun tyargs
136
137   | otherwise
138     = do
139         -- it is a tyvar or an application of a tyvar
140         -- determine the PR dictionary from its PA dictionary
141         --
142         -- NOTE: This assumes that PRepr t ~ t is for all representation types
143         -- t
144         --
145         -- FIXME: This doesn't work for kinds other than * at the moment. We'd
146         -- have to simply abstract the term over the missing type arguments.
147         pa    <- paDictOfType ty
148         prsel <- builtin paPRSel
149         return $ Var prsel `mkApps` [Type ty, pa]
150
151 -- | Apply a tycon's PR dfun to dictionary arguments (PR or PA) corresponding
152 -- to the argument types.
153 prDFunApply :: Var -> [Type] -> VM CoreExpr
154 prDFunApply dfun tys
155   | Just [] <- ctxs    -- PR (a :-> b) doesn't have a context
156   = return $ Var dfun `mkTyApps` tys
157
158   | Just tycons <- ctxs
159   , length tycons == length tys
160   = do
161       pa <- builtin paTyCon
162       pr <- builtin prTyCon 
163       args <- zipWithM (dictionary pa pr) tys tycons
164       return $ Var dfun `mkTyApps` tys `mkApps` args
165
166   | otherwise = invalid
167   where
168     -- the dfun's contexts - if its type is (PA a, PR b) => PR (C a b) then
169     -- ctxs is Just [PA, PR]
170     ctxs = fmap (map fst)
171          $ sequence
172          $ map splitTyConApp_maybe
173          $ fst
174          $ splitFunTys
175          $ snd
176          $ splitForAllTys
177          $ varType dfun
178
179     dictionary pa pr ty tycon
180       | tycon == pa = paDictOfType ty
181       | tycon == pr = prDictOfReprType ty
182       | otherwise   = invalid
183
184     invalid = cantVectorise "Invalid PR dfun type" (ppr (varType dfun) <+> ppr tys)
185