2 module Vectorise.Type.PADict
6 import Vectorise.Builtins
7 import Vectorise.Type.Repr
8 import Vectorise.Type.PRepr
9 import Vectorise.Type.PRDict
10 import Vectorise.Utils
23 -- | Build the PA dictionary for some type and hoist it to top level.
24 -- The PA dictionary holds functions that convert values to and from their vectorised representations.
--
-- One exported binding is created per entry of 'paMethods'. The dictionary
-- function is then created as an exported variable, given a DFun unfolding
-- that refers to those method bindings, and hoisted together with its
-- right-hand side.
26 :: TyCon -- ^ tycon of the type being vectorised.
27 -> TyCon -- ^ tycon of the type used for the vectorised representation.
29 -> SumRepr -- ^ representation used for the type being vectorised.
30 -> VM Var -- ^ name of the top-level dictionary function.
32 buildPADict vect_tc prepr_tc arr_tc repr
-- NOTE(review): 'args' are the variables that polyAbstract supplies for 'tvs' --
-- presumably the PA dictionaries of the type parameters; confirm against polyAbstract.
33 = polyAbstract tvs $ \args ->
-- Build and hoist one binding per method; the resulting Ids are kept in the
-- same order as 'paMethods'.
35 method_ids <- mapM (method args) paMethods
37 pa_tc <- builtin paTyCon
38 pa_dc <- builtin paDataCon
-- The dictionary expression abstracts over the type variables followed by 'args'.
39 let dict = mkLams (tvs ++ args)
-- Arguments of the dictionary: the instance type, then one call per method binding.
-- NOTE(review): presumably these are applied to pa_dc -- the applying line is not visible here; confirm.
41 $ Type inst_ty : map (method_call args) method_ids
-- Type of the dictionary function: forall tvs. (types of args) -> PA inst_ty.
43 dfun_ty = mkForAllTys tvs
44 $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty])
46 -- Create the exported dictionary function variable, then set its unfolding and inline pragma for the inliner.
47 raw_dfun <- newExportedVar dfun_name dfun_ty
48 let dfun = raw_dfun `setIdUnfolding` mkDFunUnfolding dfun_ty (map Var method_ids)
49 `setInlinePragma` dfunInlinePragma
51 -- Add the new binding to the top-level environment.
52 hoistBinding dfun dict
-- Type variables of the vectorised tycon, and that tycon applied to them.
55 tvs = tyConTyVars vect_tc
56 arg_tys = mkTyVarTys tvs
57 inst_ty = mkTyConApp vect_tc arg_tys
-- Name of the dictionary function, derived from the name of vect_tc.
59 dfun_name = mkPADFunOcc (getOccName vect_tc)
-- Build one method: run its builder, abstract the result over tvs ++ args,
-- and create an exported variable whose type is the type of that body.
61 method args (name, build)
64 expr <- build vect_tc prepr_tc arr_tc repr
65 let body = mkLams (tvs ++ args) expr
66 raw_var <- newExportedVar (method_name name) (exprType body)
-- The method body is attached as an inline unfolding with arity (length args),
-- together with the always-inline pragma.
68 `setIdUnfolding` mkInlineUnfolding (Just (length args)) body
69 `setInlinePragma` alwaysInlinePragma
-- Apply a method binding to the type arguments followed by the value arguments.
73 method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
-- Method binding names are the dictionary function's name, a '$', then the method name.
74 method_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
-- | The methods of a PA dictionary: each entry pairs a method name with the
-- builder that produces the Core expression for that method.
--
-- 'buildPADict' maps over this list, calling each builder with the same four
-- arguments it received itself, and uses the name (after a '$') as the suffix
-- of the generated method binding's name.
--
-- The order of the entries is the order in which the method calls are passed
-- as dictionary arguments.
-- NOTE(review): presumably this must match the field order of the PA data
-- constructor -- confirm before reordering.
77 paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
78 paMethods = [("dictPRepr", buildPRDict),
79 ("toPRepr", buildToPRepr),
80 ("fromPRepr", buildFromPRepr),
81 ("toArrPRepr", buildToArrPRepr),
82 ("fromArrPRepr", buildFromArrPRepr)]