Finish breaking up vectoriser utils
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Type / PADict.hs
1
2 module Vectorise.Type.PADict
3         (buildPADict)
4 where
5 import Vectorise.Monad
6 import Vectorise.Builtins
7 import Vectorise.Type.Repr
8 import Vectorise.Type.PRepr
9 import Vectorise.Type.PRDict
10 import Vectorise.Utils
11
12 import BasicTypes
13 import CoreSyn
14 import CoreUtils
15 import CoreUnfold
16 import TyCon
17 import Type
18 import OccName
19 import Id
20 import Var
21 import Name
22
23
24
25 buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
26 buildPADict vect_tc prepr_tc arr_tc repr
27   = polyAbstract tvs $ \args ->
28     do
29       method_ids <- mapM (method args) paMethods
30
31       pa_tc  <- builtin paTyCon
32       pa_dc  <- builtin paDataCon
33       let dict = mkLams (tvs ++ args)
34                $ mkConApp pa_dc
35                $ Type inst_ty : map (method_call args) method_ids
36
37           dfun_ty = mkForAllTys tvs
38                   $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty])
39
40       raw_dfun <- newExportedVar dfun_name dfun_ty
41       let dfun = raw_dfun `setIdUnfolding` mkDFunUnfolding dfun_ty (map Var method_ids)
42                           `setInlinePragma` dfunInlinePragma
43
44       hoistBinding dfun dict
45       return dfun
46   where
47     tvs = tyConTyVars vect_tc
48     arg_tys = mkTyVarTys tvs
49     inst_ty = mkTyConApp vect_tc arg_tys
50
51     dfun_name = mkPADFunOcc (getOccName vect_tc)
52
53     method args (name, build)
54       = localV
55       $ do
56           expr <- build vect_tc prepr_tc arr_tc repr
57           let body = mkLams (tvs ++ args) expr
58           raw_var <- newExportedVar (method_name name) (exprType body)
59           let var = raw_var
60                       `setIdUnfolding` mkInlineRule body (Just (length args))
61                       `setInlinePragma` alwaysInlinePragma
62           hoistBinding var body
63           return var
64
65     method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
66
67     method_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
68
69
70 paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
71 paMethods = [("dictPRepr",    buildPRDict),
72              ("toPRepr",      buildToPRepr),
73              ("fromPRepr",    buildFromPRepr),
74              ("toArrPRepr",   buildToArrPRepr),
75              ("fromArrPRepr", buildFromArrPRepr)]
76