2 module Vectorise.Utils.PADict (
10 import Vectorise.Monad
11 import Vectorise.Builtins
12 import Vectorise.Utils.Base
-- | Build the type of a PA dictionary for @ty@: the builtin 'paTyCon'
-- applied to the single type argument @ty@.
26 mkPADictType :: Type -> VM Type
27 mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
-- | Compute the type of the PA dictionary argument needed for a type
-- variable, driven by the variable's kind.
--
-- Returns @Nothing@ when the kind calls for no dictionary (the final
-- catch-all case). When the kind is a function kind from @k1@ to @k2@
-- (presumably matched on the equation head — confirm), a fresh variable
-- @a :: k1@ is introduced and the result has the shape
-- @forall a. <dict for a> -> <dict for (ty a)>@.
30 paDictArgType :: TyVar -> VM (Maybe Type)
31 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
-- Normalise the kind with 'kindView' when possible and retry on the result.
33 go ty k | Just k' <- kindView k = go ty k'
-- Fresh type variable standing for the argument of @ty@ at kind @k1@.
36 tv <- newTyVar (fsLit "a") k1
37 mty1 <- go (TyVarTy tv) k1
-- Dictionary type for the application @ty a@ at the result kind @k2@.
40 mty2 <- go (AppTy ty (TyVarTy tv)) k2
-- Abstract over the fresh variable; @ty1@ is presumably the dictionary type
-- held in @mty1@ — confirm.
41 return $ fmap (ForAllTy tv . FunTy ty1) mty2
-- Base case: the argument is an ordinary @PA ty@ dictionary.
46 = liftM Just (mkPADictType ty)
-- Any other kind: no dictionary argument is required.
48 go _ _ = return Nothing
-- | Get the PA dictionary for some type.
--
-- The type is split with 'splitAppTys' into a head and its arguments, and
-- 'paDictOfTyApp' builds the dictionary from the head. The head is first
-- normalised with 'coreView' whenever that succeeds. Any head other than a
-- type variable or an argument-less 'TyConApp' is rejected with
-- 'cantVectorise'.
52 paDictOfType :: Type -> VM CoreExpr
54 = paDictOfTyApp ty_fn ty_args
56 (ty_fn, ty_args) = splitAppTys ty
58 paDictOfTyApp :: Type -> [Type] -> VM CoreExpr
59 paDictOfTyApp ty_fn ty_args
60 | Just ty_fn' <- coreView ty_fn
61 = paDictOfTyApp ty_fn' ty_args
63 -- for type variables, look up the dfun and apply to the PA dictionaries
64 -- of the type arguments
65 paDictOfTyApp (TyVarTy tv) ty_args
-- 'lookupTyVarPA' yields the dictionary expression itself, so it is applied
-- directly below (no 'Var' wrapper, unlike the tycon case).
66 = do dfun <- maybeV (lookupTyVarPA tv)
67 dicts <- mapM paDictOfType ty_args
68 return $ dfun `mkTyApps` ty_args `mkApps` dicts
70 -- for tycons, we also need to apply the dfun to the PR dictionary of
71 -- the representation type
72 paDictOfTyApp (TyConApp tc []) ty_args
74 dfun <- maybeV $ lookupTyConPA tc
75 pr <- prDictOfPRepr tc ty_args
76 dicts <- mapM paDictOfType ty_args
-- The PR dictionary is passed first, followed by the PA dictionaries of the
-- type arguments.
77 return $ Var dfun `mkTyApps` ty_args `mkApps` (pr:dicts)
79 paDictOfTyApp _ _ = failure
-- Shared error for unsupported heads; it reports the whole original type.
81 failure = cantVectorise "Can't construct PA dictionary for type" (ppr ty)
-- | Compute the type of the PA dfun for tycon @tc@:
-- @forall tvs. <dict args> -> PA (tc tvs)@.
--
-- One dictionary argument is taken for each type variable for which
-- 'paDictArgType' returns @Just@. @tvs@ is presumably the tycon's own type
-- variables — confirm.
85 paDFunType :: TyCon -> VM Type
88 margs <- mapM paDictArgType tvs
89 res <- mkPADictType (mkTyConApp tc arg_tys)
90 return . mkForAllTys tvs
-- Variables that need no dictionary contribute no argument.
91 $ mkFunTys [arg | Just arg <- margs] res
94 arg_tys = mkTyVarTys tvs
-- | Build an expression for a PA method at type @ty@.
--
-- For a primitive tycon (recognised by 'splitPrimTyCon') the method is looked
-- up by @name@ with 'lookupPrimMethod', failing with "No PA method" when it
-- is absent. Otherwise the method @fn@ (presumably selected from the
-- builtins by the first argument — confirm) is applied to the type @ty@ and
-- to the PA dictionary for @ty@.
96 paMethod :: (Builtins -> Var) -> String -> Type -> VM CoreExpr
98 | Just tycon <- splitPrimTyCon ty
100 . maybeCantVectoriseM "No PA method" (text name <+> text "for" <+> ppr tycon)
101 $ lookupPrimMethod tycon name
106 dict <- paDictOfType ty
107 return $ mkApps (Var fn) [Type ty, dict]
-- | Get the PR dictionary for @PRepr t@, where @t@ is @tycon@ applied to the
-- type arguments @tys@.
--
-- The PR dictionary of the expanded representation type (presumably bound
-- as @rhs@ when 'coreView' succeeds) is built with 'prDictOfReprType'. It is
-- then cast with 'mkCoerce', using a coercion assembled from
-- @mkBuiltinCo prTyCon@ and the coercion returned by
-- 'tyConFamilyCoercion_maybe'.
111 prDictOfPRepr :: TyCon -> [Type] -> VM CoreExpr
112 prDictOfPRepr tycon tys
114 (prepr_tc, prepr_args) <- preprSynTyCon (mkTyConApp tycon tys)
-- If 'coreView' cannot expand the PRepr application, the instance is
-- reported as invalid.
115 case coreView (mkTyConApp prepr_tc prepr_args) of
117 dict <- prDictOfReprType rhs
118 pr_co <- mkBuiltinCo prTyCon
-- NOTE(review): irrefutable 'Just' binding. If 'tyConFamilyCoercion_maybe'
-- returns Nothing, this fails with a pattern-match error when @arg_co@ is
-- forced, not through 'cantVectorise'. Consider reporting it explicitly.
119 let Just arg_co = tyConFamilyCoercion_maybe prepr_tc
120 let co = mkAppCoercion pr_co
122 $ mkTyConApp arg_co prepr_args
123 return $ mkCoerce co dict
124 Nothing -> cantVectorise "Invalid PRepr type instance"
125 $ ppr $ mkTyConApp prepr_tc prepr_args
-- | Get the PR dictionary for a type. The argument must be a representation
-- type.
--
-- A tycon application uses the tycon's PR dfun (applied via 'prDFunApply').
-- Any other type is handled through its PA dictionary and the 'paPRSel'
-- selector.
129 prDictOfReprType :: Type -> VM CoreExpr
131 | Just (tycon, tyargs) <- splitTyConApp_maybe ty
133 -- a representation tycon must have a PR instance
134 dfun <- maybeV $ lookupTyConPR tycon
135 prDFunApply dfun tyargs
139 -- it is a tyvar or an application of a tyvar
140 -- determine the PR dictionary from its PA dictionary
142 -- NOTE: This assumes that PRepr t ~ t for all representation types
145 -- FIXME: This doesn't work for kinds other than * at the moment. We'd
146 -- have to simply abstract the term over the missing type arguments.
147 pa <- paDictOfType ty
148 prsel <- builtin paPRSel
149 return $ Var prsel `mkApps` [Type ty, pa]
-- | Apply a tycon's PR dfun to dictionary arguments (PR or PA) corresponding
-- to the argument types.
--
-- The dfun's context determines, for each type argument, whether a PA or a
-- PR dictionary is built. The call is rejected with 'cantVectorise' in three
-- cases: the context cannot be decomposed, its length differs from the
-- number of type arguments, or it mentions any other class.
153 prDFunApply :: Var -> [Type] -> VM CoreExpr
155 | Just [] <- ctxs -- PR (a :-> b) doesn't have a context
156 = return $ Var dfun `mkTyApps` tys
158 | Just tycons <- ctxs
-- exactly one context entry is expected per type argument
159 , length tycons == length tys
161 pa <- builtin paTyCon
162 pr <- builtin prTyCon
-- pair each type argument with its context class and build that dictionary
163 args <- zipWithM (dictionary pa pr) tys tycons
164 return $ Var dfun `mkTyApps` tys `mkApps` args
166 | otherwise = invalid
168 -- the dfun's contexts - if its type is (PA a, PR b) => PR (C a b) then
169 -- ctxs is Just [PA, PR]
170 ctxs = fmap (map fst)
172 $ map splitTyConApp_maybe
-- Build the dictionary demanded by one context class for one type argument.
179 dictionary pa pr ty tycon
180 | tycon == pa = paDictOfType ty
181 | tycon == pr = prDictOfReprType ty
182 | otherwise = invalid
184 invalid = cantVectorise "Invalid PR dfun type" (ppr (varType dfun) <+> ppr tys)