mkDFunUnfolding wants the type of the dfun to be a PredTy
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Type / PADict.hs
1
2 module Vectorise.Type.PADict
3         (buildPADict)
4 where
5 import Vectorise.Monad
6 import Vectorise.Builtins
7 import Vectorise.Type.Repr
8 import Vectorise.Type.PRepr
9 import Vectorise.Type.PRDict
10 import Vectorise.Utils
11
12 import BasicTypes
13 import CoreSyn
14 import CoreUtils
15 import CoreUnfold
16 import TyCon
17 import Type
18 import TypeRep
19 import Id
20 import Var
21 import Name
22 import Outputable
23 import Class
24
25 debug           = False
26 dtrace s x      = if debug then pprTrace "Vectoris.Type.PADict" s x else x
27
28 -- | Build the PA dictionary for some type and hoist it to top level.
29 --   The PA dictionary holds fns that convert values to and from their vectorised representations.
30 buildPADict
31         :: TyCon        -- ^ tycon of the type being vectorised.
32         -> TyCon        -- ^ tycon of the type used for the vectorised representation.
33         -> TyCon        -- 
34         -> SumRepr      -- ^ representation used for the type being vectorised.
35         -> VM Var       -- ^ name of the top-level dictionary function.
36
37 buildPADict vect_tc prepr_tc arr_tc repr
38  = dtrace (text "buildPADict" <+> ppr vect_tc <+> ppr prepr_tc <+> ppr arr_tc)
39  $ polyAbstract tvs $ \args@[] ->
40  do
41       -- TODO: I'm forcing args to [] because I'm not sure why we need them.
42       --       class PA has superclass (PR (PRepr a)) but we're not using
43       --       the superclass dictionary to build the PA dictionary.
44
45       -- Get ids for each of the methods in the dictionary.
46       method_ids <- mapM (method args) paMethods
47
48       -- Expression to build the dictionary.
49       pa_dc  <- builtin paDataCon
50       let dict = mkLams (tvs ++ args)
51                $ mkConApp pa_dc
52                $ Type inst_ty : map (method_call args) method_ids
53
54       dtrace (text "dict    = " <+> ppr dict) $ return ()
55
56       -- Build the type of the dictionary function.
57       pa_tc          <- builtin paTyCon
58       let pa_opitems = [(id, NoDefMeth) | id <- method_ids]
59       let pa_cls     = mkClass 
60                         (tyConName pa_tc)
61                         tvs             -- tyvars of class
62                         []              -- fundeps
63                         []              -- superclass predicates
64                         []              -- superclass dict selectors
65                         []              -- associated type families
66                         pa_opitems      -- class op items
67                         pa_tc           -- dictionary type constructor
68                         
69       let dfun_ty = mkForAllTys tvs
70                   $ mkFunTys (map varType args) (PredTy $ ClassP pa_cls [inst_ty])
71
72       -- Set the unfolding for the inliner.
73       raw_dfun <- newExportedVar dfun_name dfun_ty
74       let dfun = raw_dfun `setIdUnfolding`  mkDFunUnfolding dfun_ty (map Var method_ids)
75                           `setInlinePragma` dfunInlinePragma
76
77       -- Add the new binding to the top-level environment.
78       hoistBinding dfun dict
79       return dfun
80   where
81     tvs       = tyConTyVars vect_tc
82     arg_tys   = mkTyVarTys tvs
83     inst_ty   = mkTyConApp vect_tc arg_tys
84
85     dfun_name = mkPADFunOcc (getOccName vect_tc)
86
87     method args (name, build)
88       = localV
89       $ do
90           expr     <- build vect_tc prepr_tc arr_tc repr
91           let body = mkLams (tvs ++ args) expr
92           raw_var  <- newExportedVar (method_name name) (exprType body)
93           let var  = raw_var
94                       `setIdUnfolding` mkInlineUnfolding (Just (length args)) body
95                       `setInlinePragma` alwaysInlinePragma
96           hoistBinding var body
97           return var
98
99     method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
100     method_name name    = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
101
102
103 paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
104 paMethods = [("dictPRepr",    buildPRDict),
105              ("toPRepr",      buildToPRepr),
106              ("fromPRepr",    buildFromPRepr),
107              ("toArrPRepr",   buildToArrPRepr),
108              ("fromArrPRepr", buildFromArrPRepr)]
109