2 module Vectorise.Type.PADict
6 import Vectorise.Builtins
7 import Vectorise.Type.Repr
8 import Vectorise.Type.PRepr
25 -- dtrace s x = if debug then pprTrace "Vectorise.Type.PADict" s x else x
27 -- | Build the PA dictionary for some type and hoist it to top level.
28 -- The PA dictionary holds the functions that convert values to and from their vectorised representations.
--
-- The dictionary function is abstracted over the tycon's type variables and
-- the @args@ supplied by 'polyAbstract'. When the tycon is polymorphic, it is
-- also abstracted over an extra lambda-bound PR dictionary for the
-- representation type. Each PA method listed in 'paMethods' is hoisted as its
-- own exported, always-inlined top-level binding. The dictionary applies those
-- bindings to the type arguments and dictionary arguments.
30 :: TyCon -- ^ tycon of the type being vectorised.
31 -> TyCon -- ^ PRepr instance tycon (used for the constant PR superclass dictionary of a monomorphic type, and forwarded to each method builder).
32 -> TyCon -- ^ tycon forwarded to the method builders as @arr_tc@; presumably the PData (array representation) instance tycon -- TODO confirm.
33 -> SumRepr -- ^ representation used for the type being vectorised.
34 -> VM Var -- ^ the top-level dictionary function, as hoisted by 'hoistBinding'.
36 buildPADict vect_tc prepr_tc arr_tc repr
37 = polyAbstract tvs $ \args ->
39 -- The PR superclass dictionary is an extra lambda-bound argument if the tycon is polymorphic...
41 r <- mkPReprType inst_ty
42 pr_cls <- builtin prClass
43 return $ PredTy $ ClassP pr_cls [r]
44 super_tys <- sequence [mk_super_ty | not (null tvs)]
45 super_args <- mapM (newLocalVar (fsLit "pr")) super_tys
46 let args' = super_args ++ args
48 -- ...otherwise it is a constant, obtained from the PRepr instance tycon.
49 super_consts <- sequence [prDictOfPReprInstTyCon inst_ty prepr_tc []
52 -- Get ids for each of the methods in the dictionary.
53 method_ids <- mapM (method args') paMethods
55 -- Expression to build the dictionary.
56 pa_dc <- builtin paDataCon
57 let dict = mkLams (tvs ++ args')
60 : map Var super_args ++ super_consts
61 -- the superclass dictionary is
62 -- either lambda-bound (super_args) or constant (super_consts)
64 ++ map (method_call args') method_ids
66 -- Build the type of the dictionary function.
67 pa_cls <- builtin paClass
68 let dfun_ty = mkForAllTys tvs
69 $ mkFunTys (map varType args')
70 (PredTy $ ClassP pa_cls [inst_ty])
72 -- Create the exported dfun and give it a DFun unfolding and the DFun inline pragma for the inliner.
73 raw_dfun <- newExportedVar dfun_name dfun_ty
74 let dfun_unf = mkDFunUnfolding dfun_ty
75 $ map (const $ DFunLamArg 0) super_args
76 ++ map DFunConstArg super_consts
77 ++ map (DFunPolyArg . Var) method_ids
78 dfun = raw_dfun `setIdUnfolding` dfun_unf
79 `setInlinePragma` dfunInlinePragma
81 -- Add the new binding to the top-level environment.
82 hoistBinding dfun dict
-- Type variables of the vectorised tycon and the instance type they build.
85 tvs = tyConTyVars vect_tc
86 arg_tys = mkTyVarTys tvs
87 inst_ty = mkTyConApp vect_tc arg_tys
89 dfun_name = mkPADFunOcc (getOccName vect_tc)
-- Build one PA method. Generate its body with the given builder, abstract it
-- over the type variables and dictionary arguments, then hoist it as an
-- exported top-level binding with an inline unfolding and 'alwaysInlinePragma'.
91 method args (name, build)
94 expr <- build vect_tc prepr_tc arr_tc repr
95 let body = mkLams (tvs ++ args) expr
96 raw_var <- newExportedVar (method_name name) (exprType body)
98 `setIdUnfolding` mkInlineUnfolding (Just (length args)) body
99 `setInlinePragma` alwaysInlinePragma
100 hoistBinding var body
-- Apply a hoisted method to the instance's type arguments and the dictionary arguments.
103 method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
-- A method's name is the dfun's occurrence name, then '$', then the method name.
104 method_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
-- | The methods of a PA dictionary, listed in the order in which
-- 'buildPADict' appends their calls to the dictionary constructor's
-- arguments.
--
-- Each entry pairs a method name with that method's builder:
--
-- * the name is used to derive the name of the hoisted top-level binding;
--
-- * the builder produces the method's body from the vectorised tycon, the
--   two auxiliary tycons, and the 'SumRepr' of the type.
paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
paMethods =
  [ ("toPRepr"     , buildToPRepr)
  , ("fromPRepr"   , buildFromPRepr)
  , ("toArrPRepr"  , buildToArrPRepr)
  , ("fromArrPRepr", buildFromArrPRepr)
  ]