Fix vectorisation of recursive types
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Type / PADict.hs
1
2 module Vectorise.Type.PADict
3         (buildPADict)
4 where
5 import Vectorise.Monad
6 import Vectorise.Builtins
7 import Vectorise.Type.Repr
8 import Vectorise.Type.PRepr
9 import Vectorise.Utils
10
11 import BasicTypes
12 import CoreSyn
13 import CoreUtils
14 import CoreUnfold
15 import TyCon
16 import Type
17 import TypeRep
18 import Id
19 import Var
20 import Name
21 import FastString
22 -- import Outputable
23
24 -- debug                = False
25 -- dtrace s x   = if debug then pprTrace "Vectoris.Type.PADict" s x else x
26
27 -- | Build the PA dictionary for some type and hoist it to top level.
28 --   The PA dictionary holds fns that convert values to and from their vectorised representations.
29 buildPADict
30         :: TyCon        -- ^ tycon of the type being vectorised.
31         -> TyCon        -- ^ tycon of the type used for the vectorised representation.
32         -> TyCon        -- ^ PRepr instance tycon
33         -> SumRepr      -- ^ representation used for the type being vectorised.
34         -> VM Var       -- ^ name of the top-level dictionary function.
35
36 buildPADict vect_tc prepr_tc arr_tc repr
37  = polyAbstract tvs $ \args ->
38    do
39       -- The superclass dictionary is an argument if the tycon is polymorphic
40       let mk_super_ty = do
41                           r <- mkPReprType inst_ty
42                           pr_cls <- builtin prClass
43                           return $ PredTy $ ClassP pr_cls [r]
44       super_tys <- sequence [mk_super_ty | not (null tvs)]
45       super_args <- mapM (newLocalVar (fsLit "pr")) super_tys
46       let args' = super_args ++ args
47
48       -- it is constant otherwise
49       super_consts <- sequence [prDictOfPReprInstTyCon inst_ty prepr_tc []
50                                                 | null tvs]
51
52       -- Get ids for each of the methods in the dictionary.
53       method_ids <- mapM (method args') paMethods
54
55       -- Expression to build the dictionary.
56       pa_dc  <- builtin paDataCon
57       let dict = mkLams (tvs ++ args')
58                $ mkConApp pa_dc
59                $ Type inst_ty
60                : map Var super_args ++ super_consts
61                                    -- the superclass dictionary is
62                                    -- either lambda-bound or
63                                    -- constant
64                  ++ map (method_call args') method_ids
65
66       -- Build the type of the dictionary function.
67       pa_cls <- builtin paClass
68       let dfun_ty       = mkForAllTys tvs
69                         $ mkFunTys (map varType args')
70                                    (PredTy $ ClassP pa_cls [inst_ty])
71
72       -- Set the unfolding for the inliner.
73       raw_dfun <- newExportedVar dfun_name dfun_ty
74       let dfun_unf = mkDFunUnfolding dfun_ty
75                    $ map (const $ DFunLamArg 0) super_args
76                      ++ map DFunConstArg super_consts
77                      ++ map (DFunPolyArg . Var) method_ids
78           dfun = raw_dfun `setIdUnfolding`  dfun_unf
79                           `setInlinePragma` dfunInlinePragma
80
81       -- Add the new binding to the top-level environment.
82       hoistBinding dfun dict
83       return dfun
84   where
85     tvs       = tyConTyVars vect_tc
86     arg_tys   = mkTyVarTys tvs
87     inst_ty   = mkTyConApp vect_tc arg_tys
88
89     dfun_name = mkPADFunOcc (getOccName vect_tc)
90
91     method args (name, build)
92       = localV
93       $ do
94           expr     <- build vect_tc prepr_tc arr_tc repr
95           let body = mkLams (tvs ++ args) expr
96           raw_var  <- newExportedVar (method_name name) (exprType body)
97           let var  = raw_var
98                       `setIdUnfolding` mkInlineUnfolding (Just (length args)) body
99                       `setInlinePragma` alwaysInlinePragma
100           hoistBinding var body
101           return var
102
103     method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
104     method_name name    = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
105
106
107 paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
108 paMethods = [("toPRepr",      buildToPRepr),
109              ("fromPRepr",    buildFromPRepr),
110              ("toArrPRepr",   buildToArrPRepr),
111              ("fromArrPRepr", buildFromArrPRepr)]
112