05b1289ee83703677dae1f45a1afd55d9a3eefc6
[ghc-hetmet.git] / compiler / vectorise / VectBuiltIn.hs
1 module VectBuiltIn (
2   Builtins(..), sumTyCon, prodTyCon,
3   initBuiltins, initBuiltinTyCons, initBuiltinPAs, initBuiltinPRs,
4
5   primMethod, primPArray
6 ) where
7
8 #include "HsVersions.h"
9
10 import DsMonad
11 import IfaceEnv        ( lookupOrig )
12
13 import Module          ( Module )
14 import DataCon         ( DataCon )
15 import TyCon           ( TyCon, tyConName, tyConDataCons )
16 import Var             ( Var )
17 import Id              ( mkSysLocal )
18 import Name            ( Name, getOccString )
19 import NameEnv
20 import OccName
21
22 import TypeRep         ( funTyCon )
23 import Type            ( Type )
24 import TysPrim
25 import TysWiredIn      ( unitTyCon, tupleTyCon, intTyConName )
26 import Module          ( Module, mkModule, mkModuleNameFS )
27 import PackageConfig   ( ndpPackageId )
28 import BasicTypes      ( Boxity(..) )
29
30 import FastString
31 import Outputable
32
33 import Data.Array
34 import Control.Monad   ( liftM, zipWithM )
35
36 mAX_NDP_PROD :: Int
37 mAX_NDP_PROD = 3
38
39 mAX_NDP_SUM :: Int
40 mAX_NDP_SUM = 3
41
42 mkNDPModule :: FastString -> Module
43 mkNDPModule m = mkModule ndpPackageId (mkModuleNameFS m)
44
45 nDP_PARRAY      = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.PArray")
46 nDP_REPR        = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.Repr")
47 nDP_CLOSURE     = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.Closure")
48 nDP_PRIM        = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.Prim")
49 nDP_INSTANCES   = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.Instances")
50
51 data Builtins = Builtins {
52                   parrayTyCon      :: TyCon
53                 , paTyCon          :: TyCon
54                 , paDataCon        :: DataCon
55                 , preprTyCon       :: TyCon
56                 , prTyCon          :: TyCon
57                 , prDataCon        :: DataCon
58                 , parrayIntPrimTyCon :: TyCon
59                 , voidTyCon        :: TyCon
60                 , wrapTyCon        :: TyCon
61                 , sumTyCons        :: Array Int TyCon
62                 , closureTyCon     :: TyCon
63                 , voidVar          :: Var
64                 , mkPRVar          :: Var
65                 , mkClosureVar     :: Var
66                 , applyClosureVar  :: Var
67                 , mkClosurePVar    :: Var
68                 , applyClosurePVar :: Var
69                 , replicatePAIntPrimVar :: Var
70                 , upToPAIntPrimVar :: Var
71                 , lengthPAVar      :: Var
72                 , replicatePAVar   :: Var
73                 , emptyPAVar       :: Var
74                 -- , packPAVar        :: Var
75                 -- , combinePAVar     :: Var
76                 , liftingContext   :: Var
77                 }
78
79 sumTyCon :: Int -> Builtins -> TyCon
80 sumTyCon n bi
81   | n >= 2 && n <= mAX_NDP_SUM = sumTyCons bi ! n
82   | otherwise = pprPanic "sumTyCon" (ppr n)
83
84 prodTyCon :: Int -> Builtins -> TyCon
85 prodTyCon n bi
86   | n == 1                      = wrapTyCon bi
87   | n >= 0 && n <= mAX_NDP_PROD = tupleTyCon Boxed n
88   | otherwise = pprPanic "prodTyCon" (ppr n)
89
90 initBuiltins :: DsM Builtins
91 initBuiltins
92   = do
93       parrayTyCon  <- externalTyCon nDP_PARRAY FSLIT("PArray")
94       paTyCon      <- externalTyCon nDP_PARRAY FSLIT("PA")
95       let [paDataCon] = tyConDataCons paTyCon
96       preprTyCon   <- externalTyCon nDP_PARRAY FSLIT("PRepr")
97       prTyCon      <- externalTyCon nDP_PARRAY FSLIT("PR")
98       let [prDataCon] = tyConDataCons prTyCon
99       parrayIntPrimTyCon <- externalTyCon nDP_PRIM FSLIT("PArray_Int#")
100       closureTyCon <- externalTyCon nDP_CLOSURE FSLIT(":->")
101
102       voidTyCon    <- externalTyCon nDP_REPR FSLIT("Void")
103       wrapTyCon    <- externalTyCon nDP_REPR FSLIT("Wrap")
104       sum_tcs <- mapM (externalTyCon nDP_REPR)
105                       [mkFastString ("Sum" ++ show i) | i <- [2..mAX_NDP_SUM]]
106
107       let sumTyCons = listArray (2, mAX_NDP_SUM) sum_tcs
108
109       voidVar          <- externalVar nDP_REPR FSLIT("void")
110       mkPRVar          <- externalVar nDP_PARRAY FSLIT("mkPR")
111       mkClosureVar     <- externalVar nDP_CLOSURE FSLIT("mkClosure")
112       applyClosureVar  <- externalVar nDP_CLOSURE FSLIT("$:")
113       mkClosurePVar    <- externalVar nDP_CLOSURE FSLIT("mkClosureP")
114       applyClosurePVar <- externalVar nDP_CLOSURE FSLIT("$:^")
115       replicatePAIntPrimVar <- externalVar nDP_PRIM FSLIT("replicatePA_Int#")
116       upToPAIntPrimVar <- externalVar nDP_PRIM FSLIT("upToPA_Int#")
117       lengthPAVar      <- externalVar nDP_PARRAY FSLIT("lengthPA")
118       replicatePAVar   <- externalVar nDP_PARRAY FSLIT("replicatePA")
119       emptyPAVar       <- externalVar nDP_PARRAY FSLIT("emptyPA")
120       -- packPAVar        <- dsLookupGlobalId packPAName
121       -- combinePAVar     <- dsLookupGlobalId combinePAName
122
123       liftingContext <- liftM (\u -> mkSysLocal FSLIT("lc") u intPrimTy)
124                               newUnique
125
126       return $ Builtins {
127                  parrayTyCon      = parrayTyCon
128                , paTyCon          = paTyCon
129                , paDataCon        = paDataCon
130                , preprTyCon       = preprTyCon
131                , prTyCon          = prTyCon
132                , prDataCon        = prDataCon
133                , parrayIntPrimTyCon = parrayIntPrimTyCon
134                , voidTyCon        = voidTyCon
135                , wrapTyCon        = wrapTyCon
136                , sumTyCons        = sumTyCons
137                , closureTyCon     = closureTyCon
138                , voidVar          = voidVar
139                , mkPRVar          = mkPRVar
140                , mkClosureVar     = mkClosureVar
141                , applyClosureVar  = applyClosureVar
142                , mkClosurePVar    = mkClosurePVar
143                , applyClosurePVar = applyClosurePVar
144                , replicatePAIntPrimVar = replicatePAIntPrimVar
145                , upToPAIntPrimVar = upToPAIntPrimVar
146                , lengthPAVar      = lengthPAVar
147                , replicatePAVar   = replicatePAVar
148                , emptyPAVar       = emptyPAVar
149                -- , packPAVar        = packPAVar
150                -- , combinePAVar     = combinePAVar
151                , liftingContext   = liftingContext
152                }
153
154 initBuiltinTyCons :: Builtins -> [(Name, TyCon)]
155 initBuiltinTyCons bi = [(tyConName funTyCon, closureTyCon bi)]
156
157 initBuiltinDicts :: [(Name, Module, FastString)] -> DsM [(Name, Var)]
158 initBuiltinDicts ps
159   = do
160       dicts <- zipWithM externalVar mods fss
161       return $ zip tcs dicts
162   where
163     (tcs, mods, fss) = unzip3 ps
164
165 initBuiltinPAs = initBuiltinDicts . builtinPAs
166
167 builtinPAs :: Builtins -> [(Name, Module, FastString)]
168 builtinPAs bi
169   = [
170       mk (tyConName $ closureTyCon bi)  nDP_CLOSURE     FSLIT("dPA_Clo")
171     , mk (tyConName $ voidTyCon bi)     nDP_REPR        FSLIT("dPA_Void")
172     , mk unitTyConName                  nDP_INSTANCES   FSLIT("dPA_Unit")
173
174     , mk intTyConName                   nDP_INSTANCES   FSLIT("dPA_Int")
175     ]
176     ++ tups
177   where
178     mk name mod fs = (name, mod, fs)
179
180     tups = map mk_tup [2..3]
181     mk_tup n = mk (tyConName $ tupleTyCon Boxed n)
182                   nDP_INSTANCES
183                   (mkFastString $ "dPA_" ++ show n)
184
185 initBuiltinPRs = initBuiltinDicts . builtinPRs
186
187 builtinPRs :: Builtins -> [(Name, Module, FastString)]
188 builtinPRs bi =
189   [
190     mk (tyConName unitTyCon)          nDP_REPR      FSLIT("dPR_Unit")
191   , mk (tyConName $ voidTyCon bi)     nDP_REPR      FSLIT("dPR_Void")
192   , mk (tyConName $ wrapTyCon bi)     nDP_REPR      FSLIT("dPR_Wrap")
193   , mk (tyConName $ closureTyCon bi)  nDP_CLOSURE   FSLIT("dPR_Clo")
194
195     -- temporary
196   , mk intTyConName          nDP_INSTANCES FSLIT("dPR_Int")
197   ]
198
199   ++ map mk_sum  [2..mAX_NDP_SUM]
200   ++ map mk_prod [2..mAX_NDP_PROD]
201   where
202     mk name mod fs = (name, mod, fs)
203
204     mk_sum n = (tyConName $ sumTyCon n bi, nDP_REPR,
205                 mkFastString ("dPR_Sum" ++ show n))
206
207     mk_prod n = (tyConName $ prodTyCon n bi, nDP_REPR,
208                  mkFastString ("dPR_" ++ show n))
209
210 externalVar :: Module -> FastString -> DsM Var
211 externalVar mod fs
212   = dsLookupGlobalId =<< lookupOrig mod (mkVarOccFS fs)
213
214 externalTyCon :: Module -> FastString -> DsM TyCon
215 externalTyCon mod fs
216   = dsLookupTyCon =<< lookupOrig mod (mkOccNameFS tcName fs)
217
218 unitTyConName = tyConName unitTyCon
219
220
221 primMethod :: TyCon -> String -> DsM (Maybe Var)
222 primMethod tycon method
223   | Just suffix <- lookupNameEnv prim_ty_cons (tyConName tycon)
224   = liftM Just
225   $ dsLookupGlobalId =<< lookupOrig nDP_PRIM (mkVarOcc $ method ++ suffix)
226
227   | otherwise = return Nothing
228
229 primPArray :: TyCon -> DsM (Maybe TyCon)
230 primPArray tycon
231   | Just suffix <- lookupNameEnv prim_ty_cons (tyConName tycon)
232   = liftM Just
233   $ dsLookupTyCon =<< lookupOrig nDP_PRIM (mkOccName tcName $ "PArray" ++ suffix)
234
235   | otherwise = return Nothing
236
237 prim_ty_cons = mkNameEnv [mk_prim intPrimTyCon]
238   where
239     mk_prim tycon = (tyConName tycon, '_' : getOccString tycon)