import Vectorise.Builtins
import Vectorise.Type.Repr
import Vectorise.Type.PRepr
-import Vectorise.Type.PRDict
import Vectorise.Utils
import BasicTypes
import CoreUnfold
import TyCon
import Type
+import TypeRep
import Id
import Var
import Name
+import FastString
+-- import Outputable
+-- debug = False
+-- dtrace s x = if debug then pprTrace "Vectoris.Type.PADict" s x else x
-- | Build the PA dictionary for some type and hoist it to top level.
-- The PA dictionary holds fns that convert values to and from their vectorised representations.
buildPADict
:: TyCon -- ^ tycon of the type being vectorised.
-> TyCon -- ^ tycon of the type used for the vectorised representation.
- -> TyCon --
+ -> TyCon -- ^ PRepr instance tycon
-> SumRepr -- ^ representation used for the type being vectorised.
-> VM Var -- ^ name of the top-level dictionary function.
buildPADict vect_tc prepr_tc arr_tc repr
- = polyAbstract tvs $ \args ->
- do
- method_ids <- mapM (method args) paMethods
+ = polyAbstract tvs $ \args ->
+ do
+ -- The superclass dictionary is an argument if the tycon is polymorphic
+ let mk_super_ty = do
+ r <- mkPReprType inst_ty
+ pr_cls <- builtin prClass
+ return $ PredTy $ ClassP pr_cls [r]
+ super_tys <- sequence [mk_super_ty | not (null tvs)]
+ super_args <- mapM (newLocalVar (fsLit "pr")) super_tys
+ let args' = super_args ++ args
- pa_tc <- builtin paTyCon
+ -- it is constant otherwise
+ super_consts <- sequence [prDictOfPReprInstTyCon inst_ty prepr_tc []
+ | null tvs]
+
+ -- Get ids for each of the methods in the dictionary.
+ method_ids <- mapM (method args') paMethods
+
+ -- Expression to build the dictionary.
pa_dc <- builtin paDataCon
- let dict = mkLams (tvs ++ args)
+ let dict = mkLams (tvs ++ args')
$ mkConApp pa_dc
- $ Type inst_ty : map (method_call args) method_ids
+ $ Type inst_ty
+ : map Var super_args ++ super_consts
+ -- the superclass dictionary is
+ -- either lambda-bound or
+ -- constant
+ ++ map (method_call args') method_ids
- dfun_ty = mkForAllTys tvs
- $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty])
+ -- Build the type of the dictionary function.
+ pa_cls <- builtin paClass
+ let dfun_ty = mkForAllTys tvs
+ $ mkFunTys (map varType args')
+ (PredTy $ ClassP pa_cls [inst_ty])
-- Set the unfolding for the inliner.
raw_dfun <- newExportedVar dfun_name dfun_ty
- let dfun = raw_dfun `setIdUnfolding` mkDFunUnfolding dfun_ty (map Var method_ids)
+ let dfun_unf = mkDFunUnfolding dfun_ty
+ $ map (const $ DFunLamArg 0) super_args
+ ++ map DFunConstArg super_consts
+ ++ map (DFunPolyArg . Var) method_ids
+ dfun = raw_dfun `setIdUnfolding` dfun_unf
`setInlinePragma` dfunInlinePragma
-- Add the new binding to the top-level environment.
$ do
expr <- build vect_tc prepr_tc arr_tc repr
let body = mkLams (tvs ++ args) expr
- raw_var <- newExportedVar (method_name name) (exprType body)
- let var = raw_var
- `setIdUnfolding` mkInlineRule body (Just (length args))
+ raw_var <- newExportedVar (method_name name) (exprType body)
+ let var = raw_var
+ `setIdUnfolding` mkInlineUnfolding (Just (length args)) body
`setInlinePragma` alwaysInlinePragma
hoistBinding var body
return var
paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
-paMethods = [("dictPRepr", buildPRDict),
- ("toPRepr", buildToPRepr),
+paMethods = [("toPRepr", buildToPRepr),
("fromPRepr", buildFromPRepr),
("toArrPRepr", buildToArrPRepr),
("fromArrPRepr", buildFromArrPRepr)]