2 module Vectorise.Type.PADict
6 import Vectorise.Builtins
7 import Vectorise.Type.Repr
8 import Vectorise.Type.PRepr
9 import Vectorise.Type.PRDict
10 import Vectorise.Utils
26 -- dtrace s x = if debug then pprTrace "Vectorise.Type.PADict" s x else x
28 -- | Build the PA dictionary for some type and hoist it to top level.
29 -- The PA dictionary holds functions that convert values to and from their vectorised representations.
-- One dictionary method is generated for each entry of 'paMethods'.
31 :: TyCon -- ^ tycon of the type being vectorised.
32 -> TyCon -- ^ tycon of the type used for the vectorised representation.
34 -> SumRepr -- ^ representation used for the type being vectorised.
35 -> VM Var -- ^ name of the top-level dictionary function.
-- All four arguments (vect_tc, prepr_tc, arr_tc, repr) are forwarded unchanged
-- to every builder in 'paMethods'. vect_tc additionally supplies the type
-- variables, the instance type and the base name of the dictionary function.
37 buildPADict vect_tc prepr_tc arr_tc repr
38 = polyAbstract tvs $ \args ->
-- A non-empty 'args' list is rejected with pprPanic.
40 (_:_) -> pprPanic "Vectorise.Type.PADict.buildPADict" (text "why do we need superclass dicts?")
42 -- TODO: I'm forcing args to [] because I'm not sure why we need them.
43 -- class PA has superclass (PR (PRepr a)) but we're not using
44 -- the superclass dictionary to build the PA dictionary.
46 -- Create an id for each method of the dictionary, in the order given by 'paMethods'.
47 method_ids <- mapM (method args) paMethods
49 -- Expression to build the dictionary: a lambda over the type variables and args
-- whose argument list is the instance type followed by every method id applied
-- (via 'method_call') to those same type variables and args.
50 pa_dc <- builtin paDataCon
51 let dict = mkLams (tvs ++ args)
53 $ Type inst_ty : map (method_call args) method_ids
55 -- Build the type of the dictionary function.
56 pa_tc <- builtin paTyCon
-- Pair every method id with NoDefMeth to form the class op items.
57 let pa_opitems = [(id, NoDefMeth) | id <- method_ids]
60 tvs -- tyvars of class
62 [] -- superclass predicates
63 0 -- number of equalities
64 [] -- superclass dict selectors
65 [] -- associated type families
66 pa_opitems -- class op items
67 pa_tc -- dictionary type constructor
-- dfun_ty has the shape: forall tvs. <types of args> -> (pa_cls inst_ty)
69 let dfun_ty = mkForAllTys tvs
70 $ mkFunTys (map varType args) (PredTy $ ClassP pa_cls [inst_ty])
72 -- Set the unfolding for the inliner: a dfun unfolding over the method ids,
-- together with the dfun inline pragma.
73 raw_dfun <- newExportedVar dfun_name dfun_ty
74 let dfun = raw_dfun `setIdUnfolding` mkDFunUnfolding dfun_ty (map Var method_ids)
75 `setInlinePragma` dfunInlinePragma
77 -- Add the new binding to the top-level environment.
78 hoistBinding dfun dict
-- Type variables of the type being vectorised; the dictionary abstracts over these.
81 tvs = tyConTyVars vect_tc
-- The same type variables, as types.
82 arg_tys = mkTyVarTys tvs
-- The instance type: vect_tc applied to its own type variables.
83 inst_ty = mkTyConApp vect_tc arg_tys
-- Occurrence name of the dictionary function, derived from vect_tc's occurrence name.
85 dfun_name = mkPADFunOcc (getOccName vect_tc)
-- Build a single method from a (name, builder) entry of 'paMethods': run the
-- builder, abstract its result over the type variables and args, and create a
-- variable (via newExportedVar) whose name comes from 'method_name' and whose
-- type is the type of that abstracted body.
87 method args (name, build)
90 expr <- build vect_tc prepr_tc arr_tc repr
91 let body = mkLams (tvs ++ args) expr
92 raw_var <- newExportedVar (method_name name) (exprType body)
-- The unfolding is an inline unfolding of 'body' with arity (length args),
-- combined with the always-inline pragma.
94 `setIdUnfolding` mkInlineUnfolding (Just (length args)) body
95 `setInlinePragma` alwaysInlinePragma
-- Apply a method id to the instance's type arguments followed by the value args.
99 method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
-- Occurrence name of a method: the dictionary function's name, then '$', then the method's name.
100 method_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
-- | Table of the methods that make up a PA dictionary.
--
-- Each entry pairs a method's name with the builder that produces the
-- expression for that method. 'buildPADict' runs the builders in list order,
-- passing each one the same four arguments it received, and names the
-- resulting top-level binding by appending @$name@ to the dictionary
-- function's name.
--
-- NOTE(review): the order presumably has to match the field order of the PA
-- dictionary data constructor -- confirm before reordering entries.
paMethods
  :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
paMethods =
  [ ("dictPRepr",    buildPRDict)
  , ("toPRepr",      buildToPRepr)
  , ("fromPRepr",    buildFromPRepr)
  , ("toArrPRepr",   buildToArrPRepr)
  , ("fromArrPRepr", buildFromArrPRepr)
  ]