Implement INLINABLE pragma
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Type / PADict.hs
1
2 module Vectorise.Type.PADict
3         (buildPADict)
4 where
5 import Vectorise.Monad
6 import Vectorise.Builtins
7 import Vectorise.Type.Repr
8 import Vectorise.Type.PRepr
9 import Vectorise.Type.PRDict
10 import Vectorise.Utils
11
12 import BasicTypes
13 import CoreSyn
14 import CoreUtils
15 import CoreUnfold
16 import TyCon
17 import Type
18 import Id
19 import Var
20 import Name
21
22
23
24 buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
25 buildPADict vect_tc prepr_tc arr_tc repr
26   = polyAbstract tvs $ \args ->
27     do
28       method_ids <- mapM (method args) paMethods
29
30       pa_tc  <- builtin paTyCon
31       pa_dc  <- builtin paDataCon
32       let dict = mkLams (tvs ++ args)
33                $ mkConApp pa_dc
34                $ Type inst_ty : map (method_call args) method_ids
35
36           dfun_ty = mkForAllTys tvs
37                   $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty])
38
39       raw_dfun <- newExportedVar dfun_name dfun_ty
40       let dfun = raw_dfun `setIdUnfolding` mkDFunUnfolding dfun_ty (map Var method_ids)
41                           `setInlinePragma` dfunInlinePragma
42
43       hoistBinding dfun dict
44       return dfun
45   where
46     tvs = tyConTyVars vect_tc
47     arg_tys = mkTyVarTys tvs
48     inst_ty = mkTyConApp vect_tc arg_tys
49
50     dfun_name = mkPADFunOcc (getOccName vect_tc)
51
52     method args (name, build)
53       = localV
54       $ do
55           expr <- build vect_tc prepr_tc arr_tc repr
56           let body = mkLams (tvs ++ args) expr
57           raw_var <- newExportedVar (method_name name) (exprType body)
58           let var = raw_var
59                       `setIdUnfolding` mkInlineUnfolding (Just (length args)) body
60                       `setInlinePragma` alwaysInlinePragma
61           hoistBinding var body
62           return var
63
64     method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
65
66     method_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
67
68
69 paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
70 paMethods = [("dictPRepr",    buildPRDict),
71              ("toPRepr",      buildToPRepr),
72              ("fromPRepr",    buildFromPRepr),
73              ("toArrPRepr",   buildToArrPRepr),
74              ("fromArrPRepr", buildFromArrPRepr)]
75