vectoriser: fix conflicts
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Type / PADict.hs
1
2 module Vectorise.Type.PADict
3         (buildPADict)
4 where
5 import Vectorise.Monad
6 import Vectorise.Builtins
7 import Vectorise.Type.Repr
8 import Vectorise.Type.PRepr
9 import Vectorise.Type.PRDict
10 import Vectorise.Utils
11
12 import BasicTypes
13 import CoreSyn
14 import CoreUtils
15 import CoreUnfold
16 import TyCon
17 import Type
18 import Id
19 import Var
20 import Name
21
22
23 -- | Build the PA dictionary for some type and hoist it to top level.
24 --   The PA dictionary holds fns that convert values to and from their vectorised representations.
25 buildPADict
26         :: TyCon        -- ^ tycon of the type being vectorised.
27         -> TyCon        -- ^ tycon of the type used for the vectorised representation.
28         -> TyCon        -- 
29         -> SumRepr      -- ^ representation used for the type being vectorised.
30         -> VM Var       -- ^ name of the top-level dictionary function.
31
32 buildPADict vect_tc prepr_tc arr_tc repr
33   = polyAbstract tvs $ \args ->
34     do
35       method_ids <- mapM (method args) paMethods
36
37       pa_tc  <- builtin paTyCon
38       pa_dc  <- builtin paDataCon
39       let dict = mkLams (tvs ++ args)
40                $ mkConApp pa_dc
41                $ Type inst_ty : map (method_call args) method_ids
42
43           dfun_ty = mkForAllTys tvs
44                   $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty])
45
46       -- Set the unfolding for the inliner.
47       raw_dfun <- newExportedVar dfun_name dfun_ty
48       let dfun = raw_dfun `setIdUnfolding`  mkDFunUnfolding dfun_ty (map Var method_ids)
49                           `setInlinePragma` dfunInlinePragma
50
51       -- Add the new binding to the top-level environment.
52       hoistBinding dfun dict
53       return dfun
54   where
55     tvs       = tyConTyVars vect_tc
56     arg_tys   = mkTyVarTys tvs
57     inst_ty   = mkTyConApp vect_tc arg_tys
58
59     dfun_name = mkPADFunOcc (getOccName vect_tc)
60
61     method args (name, build)
62       = localV
63       $ do
64           expr     <- build vect_tc prepr_tc arr_tc repr
65           let body = mkLams (tvs ++ args) expr
66           raw_var  <- newExportedVar (method_name name) (exprType body)
67           let var  = raw_var
68                       `setIdUnfolding` mkInlineUnfolding (Just (length args)) body
69                       `setInlinePragma` alwaysInlinePragma
70           hoistBinding var body
71           return var
72
73     method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
74     method_name name    = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
75
76
77 paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
78 paMethods = [("dictPRepr",    buildPRDict),
79              ("toPRepr",      buildToPRepr),
80              ("fromPRepr",    buildFromPRepr),
81              ("toArrPRepr",   buildToArrPRepr),
82              ("fromArrPRepr", buildFromArrPRepr)]
83