83f056d64fc29ad62d4111592d435d5411a582b7
[ghc-hetmet.git] / compiler / vectorise / VectBuiltIn.hs
1 module VectBuiltIn (
2   Builtins(..), sumTyCon, prodTyCon,
3   combinePAVar,
4   initBuiltins, initBuiltinVars, initBuiltinTyCons, initBuiltinDataCons,
5   initBuiltinPAs, initBuiltinPRs,
6   initBuiltinBoxedTyCons,
7
8   primMethod, primPArray
9 ) where
10
11 import DsMonad
12 import IfaceEnv        ( lookupOrig )
13
14 import Module
15 import DataCon         ( DataCon, dataConName, dataConWorkId )
16 import TyCon           ( TyCon, tyConName, tyConDataCons )
17 import Var             ( Var )
18 import Id              ( mkSysLocal )
19 import Name            ( Name, getOccString )
20 import NameEnv
21 import OccName
22
23 import TypeRep         ( funTyCon )
24 import Type            ( Type, mkTyConApp )
25 import TysPrim
26 import TysWiredIn      ( unitTyCon, unitDataCon,
27                          tupleTyCon, tupleCon,
28                          intTyCon, intTyConName,
29                          doubleTyCon, doubleTyConName,
30                          boolTyCon, boolTyConName, trueDataCon, falseDataCon,
31                          parrTyConName )
32 import PrelNames       ( gHC_PARR )
33 import BasicTypes      ( Boxity(..) )
34
35 import FastString
36 import Outputable
37
38 import Data.Array
39 import Control.Monad   ( liftM, zipWithM )
40 import Data.List       ( unzip4 )
41
42 mAX_DPH_PROD :: Int
43 mAX_DPH_PROD = 5
44
45 mAX_DPH_SUM :: Int
46 mAX_DPH_SUM = 3
47
48 mAX_DPH_COMBINE :: Int
49 mAX_DPH_COMBINE = 2
50
51 data Modules = Modules {
52                    dph_PArray :: Module
53                  , dph_Repr :: Module
54                  , dph_Closure :: Module
55                  , dph_Unboxed :: Module
56                  , dph_Instances :: Module
57                  , dph_Combinators :: Module
58                  , dph_Prelude_PArr :: Module
59                  , dph_Prelude_Int :: Module
60                  , dph_Prelude_Double :: Module
61                  , dph_Prelude_Bool :: Module
62                  , dph_Prelude_Tuple :: Module
63                }
64
65 dph_Modules :: PackageId -> Modules
66 dph_Modules pkg = Modules {
67     dph_PArray         = mk (fsLit "Data.Array.Parallel.Lifted.PArray")
68   , dph_Repr           = mk (fsLit "Data.Array.Parallel.Lifted.Repr")
69   , dph_Closure        = mk (fsLit "Data.Array.Parallel.Lifted.Closure")
70   , dph_Unboxed        = mk (fsLit "Data.Array.Parallel.Lifted.Unboxed")
71   , dph_Instances      = mk (fsLit "Data.Array.Parallel.Lifted.Instances")
72   , dph_Combinators    = mk (fsLit "Data.Array.Parallel.Lifted.Combinators")
73
74   , dph_Prelude_PArr   = mk (fsLit "Data.Array.Parallel.Prelude.Base.PArr")
75   , dph_Prelude_Int    = mk (fsLit "Data.Array.Parallel.Prelude.Base.Int")
76   , dph_Prelude_Double = mk (fsLit "Data.Array.Parallel.Prelude.Base.Double")
77   , dph_Prelude_Bool   = mk (fsLit "Data.Array.Parallel.Prelude.Base.Bool")
78   , dph_Prelude_Tuple  = mk (fsLit "Data.Array.Parallel.Prelude.Base.Tuple")
79   }
80   where
81     mk = mkModule pkg . mkModuleNameFS
82
83
84 data Builtins = Builtins {
85                   dphModules       :: Modules
86                 , parrayTyCon      :: TyCon
87                 , paTyCon          :: TyCon
88                 , paDataCon        :: DataCon
89                 , preprTyCon       :: TyCon
90                 , prTyCon          :: TyCon
91                 , prDataCon        :: DataCon
92                 , intPrimArrayTy   :: Type
93                 , voidTyCon        :: TyCon
94                 , wrapTyCon        :: TyCon
95                 , enumerationTyCon :: TyCon
96                 , sumTyCons        :: Array Int TyCon
97                 , closureTyCon     :: TyCon
98                 , voidVar          :: Var
99                 , mkPRVar          :: Var
100                 , mkClosureVar     :: Var
101                 , applyClosureVar  :: Var
102                 , mkClosurePVar    :: Var
103                 , applyClosurePVar :: Var
104                 , replicatePAIntPrimVar :: Var
105                 , upToPAIntPrimVar :: Var
106                 , selectPAIntPrimVar :: Var
107                 , truesPABoolPrimVar :: Var
108                 , lengthPAVar      :: Var
109                 , replicatePAVar   :: Var
110                 , emptyPAVar       :: Var
111                 , packPAVar        :: Var
112                 , combinePAVars    :: Array Int Var
113                 , liftingContext   :: Var
114                 }
115
116 sumTyCon :: Int -> Builtins -> TyCon
117 sumTyCon n bi
118   | n >= 2 && n <= mAX_DPH_SUM = sumTyCons bi ! n
119   | otherwise = pprPanic "sumTyCon" (ppr n)
120
121 prodTyCon :: Int -> Builtins -> TyCon
122 prodTyCon n bi
123   | n == 1                      = wrapTyCon bi
124   | n >= 0 && n <= mAX_DPH_PROD = tupleTyCon Boxed n
125   | otherwise = pprPanic "prodTyCon" (ppr n)
126
127 combinePAVar :: Int -> Builtins -> Var
128 combinePAVar n bi
129   | n >= 2 && n <= mAX_DPH_COMBINE = combinePAVars bi ! n
130   | otherwise = pprPanic "combinePAVar" (ppr n)
131
132 initBuiltins :: PackageId -> DsM Builtins
133 initBuiltins pkg
134   = do
135       parrayTyCon  <- externalTyCon dph_PArray (fsLit "PArray")
136       paTyCon      <- externalTyCon dph_PArray (fsLit "PA")
137       let [paDataCon] = tyConDataCons paTyCon
138       preprTyCon   <- externalTyCon dph_PArray (fsLit "PRepr")
139       prTyCon      <- externalTyCon dph_PArray (fsLit "PR")
140       let [prDataCon] = tyConDataCons prTyCon
141       intPrimArrayTy <- externalType dph_Unboxed (fsLit "PArray_Int#")
142       closureTyCon <- externalTyCon dph_Closure (fsLit ":->")
143
144       voidTyCon    <- externalTyCon dph_Repr (fsLit "Void")
145       wrapTyCon    <- externalTyCon dph_Repr (fsLit "Wrap")
146       enumerationTyCon <- externalTyCon dph_Repr (fsLit "Enumeration")
147       sum_tcs <- mapM (externalTyCon dph_Repr)
148                       [mkFastString ("Sum" ++ show i) | i <- [2..mAX_DPH_SUM]]
149
150       let sumTyCons = listArray (2, mAX_DPH_SUM) sum_tcs
151
152       voidVar          <- externalVar dph_Repr (fsLit "void")
153       mkPRVar          <- externalVar dph_PArray (fsLit "mkPR")
154       mkClosureVar     <- externalVar dph_Closure (fsLit "mkClosure")
155       applyClosureVar  <- externalVar dph_Closure (fsLit "$:")
156       mkClosurePVar    <- externalVar dph_Closure (fsLit "mkClosureP")
157       applyClosurePVar <- externalVar dph_Closure (fsLit "$:^")
158       replicatePAIntPrimVar <- externalVar dph_Unboxed (fsLit "replicatePA_Int#")
159       upToPAIntPrimVar <- externalVar dph_Unboxed (fsLit "upToPA_Int#")
160       selectPAIntPrimVar <- externalVar dph_Unboxed (fsLit "selectPA_Int#")
161       truesPABoolPrimVar <- externalVar dph_Unboxed (fsLit "truesPA_Bool#")
162       lengthPAVar      <- externalVar dph_PArray (fsLit "lengthPA#")
163       replicatePAVar   <- externalVar dph_PArray (fsLit "replicatePA#")
164       emptyPAVar       <- externalVar dph_PArray (fsLit "emptyPA")
165       packPAVar        <- externalVar dph_PArray (fsLit "packPA#")
166
167       combines <- mapM (externalVar dph_PArray)
168                        [mkFastString ("combine" ++ show i ++ "PA#")
169                           | i <- [2..mAX_DPH_COMBINE]]
170       let combinePAVars = listArray (2, mAX_DPH_COMBINE) combines
171
172       liftingContext <- liftM (\u -> mkSysLocal (fsLit "lc") u intPrimTy)
173                               newUnique
174
175       return $ Builtins {
176                  dphModules       = modules
177                , parrayTyCon      = parrayTyCon
178                , paTyCon          = paTyCon
179                , paDataCon        = paDataCon
180                , preprTyCon       = preprTyCon
181                , prTyCon          = prTyCon
182                , prDataCon        = prDataCon
183                , intPrimArrayTy   = intPrimArrayTy
184                , voidTyCon        = voidTyCon
185                , wrapTyCon        = wrapTyCon
186                , enumerationTyCon = enumerationTyCon
187                , sumTyCons        = sumTyCons
188                , closureTyCon     = closureTyCon
189                , voidVar          = voidVar
190                , mkPRVar          = mkPRVar
191                , mkClosureVar     = mkClosureVar
192                , applyClosureVar  = applyClosureVar
193                , mkClosurePVar    = mkClosurePVar
194                , applyClosurePVar = applyClosurePVar
195                , replicatePAIntPrimVar = replicatePAIntPrimVar
196                , upToPAIntPrimVar = upToPAIntPrimVar
197                , selectPAIntPrimVar = selectPAIntPrimVar
198                , truesPABoolPrimVar = truesPABoolPrimVar
199                , lengthPAVar      = lengthPAVar
200                , replicatePAVar   = replicatePAVar
201                , emptyPAVar       = emptyPAVar
202                , packPAVar        = packPAVar
203                , combinePAVars    = combinePAVars
204                , liftingContext   = liftingContext
205                }
206   where
207     modules@(Modules {
208                dph_PArray         = dph_PArray
209              , dph_Repr           = dph_Repr
210              , dph_Closure        = dph_Closure
211              , dph_Unboxed        = dph_Unboxed
212              })
213       = dph_Modules pkg
214
215
216 initBuiltinVars :: Builtins -> DsM [(Var, Var)]
217 initBuiltinVars (Builtins { dphModules = mods })
218   = do
219       uvars <- zipWithM externalVar umods ufs
220       vvars <- zipWithM externalVar vmods vfs
221       cvars <- zipWithM externalVar cmods cfs
222       return $ [(v,v) | v <- map dataConWorkId defaultDataConWorkers]
223                ++ zip (map dataConWorkId cons) cvars
224                ++ zip uvars vvars
225   where
226     (umods, ufs, vmods, vfs) = unzip4 (preludeVars mods)
227
228     (cons, cmods, cfs) = unzip3 (preludeDataCons mods)
229
230 defaultDataConWorkers :: [DataCon]
231 defaultDataConWorkers = [trueDataCon, falseDataCon, unitDataCon]
232
233 preludeDataCons :: Modules -> [(DataCon, Module, FastString)]
234 preludeDataCons (Modules { dph_Prelude_Tuple = dph_Prelude_Tuple })
235   = [mk_tup n dph_Prelude_Tuple (mkFastString $ "tup" ++ show n) | n <- [2..3]]
236   where
237     mk_tup n mod name = (tupleCon Boxed n, mod, name)
238
239 preludeVars :: Modules -> [(Module, FastString, Module, FastString)]
240 preludeVars (Modules { dph_Combinators    = dph_Combinators
241                      , dph_PArray         = dph_PArray
242                      , dph_Prelude_Int    = dph_Prelude_Int
243                      , dph_Prelude_Double = dph_Prelude_Double
244                      , dph_Prelude_Bool   = dph_Prelude_Bool 
245                      , dph_Prelude_PArr   = dph_Prelude_PArr
246                      })
247   = [
248       mk gHC_PARR (fsLit "mapP")       dph_Combinators (fsLit "mapPA")
249     , mk gHC_PARR (fsLit "zipWithP")   dph_Combinators (fsLit "zipWithPA")
250     , mk gHC_PARR (fsLit "zipP")       dph_Combinators (fsLit "zipPA")
251     , mk gHC_PARR (fsLit "unzipP")     dph_Combinators (fsLit "unzipPA")
252     , mk gHC_PARR (fsLit "filterP")    dph_Combinators (fsLit "filterPA")
253     , mk gHC_PARR (fsLit "lengthP")    dph_Combinators (fsLit "lengthPA")
254     , mk gHC_PARR (fsLit "replicateP") dph_Combinators (fsLit "replicatePA")
255     , mk gHC_PARR (fsLit "!:")         dph_Combinators (fsLit "indexPA")
256     , mk gHC_PARR (fsLit "crossMapP")  dph_Combinators (fsLit "crossMapPA")
257     , mk gHC_PARR (fsLit "singletonP") dph_Combinators (fsLit "singletonPA")
258     , mk gHC_PARR (fsLit "concatP")    dph_Combinators (fsLit "concatPA")
259     , mk gHC_PARR (fsLit "+:+")        dph_Combinators (fsLit "appPA")
260     , mk gHC_PARR (fsLit "emptyP")     dph_PArray (fsLit "emptyPA")
261
262     , mk' dph_Prelude_Int "div"  "divV"
263     , mk' dph_Prelude_Int "mod"  "modV"
264     , mk' dph_Prelude_Int "sqrt" "sqrtV"
265     , mk' dph_Prelude_Int "enumFromToP" "enumFromToPA"
266     , mk' dph_Prelude_Int "upToP" "upToPA"
267     ]
268     ++ vars_Ord dph_Prelude_Int
269     ++ vars_Num dph_Prelude_Int
270
271     ++ vars_Ord dph_Prelude_Double
272     ++ vars_Num dph_Prelude_Double
273     ++ vars_Fractional dph_Prelude_Double
274     ++ vars_Floating dph_Prelude_Double
275     ++ vars_RealFrac dph_Prelude_Double
276     ++
277     [ mk dph_Prelude_Bool  (fsLit "andP")  dph_Prelude_Bool (fsLit "andPA")
278     , mk dph_Prelude_Bool  (fsLit "orP")  dph_Prelude_Bool (fsLit "orPA")
279
280     -- FIXME: temporary
281     , mk dph_Prelude_PArr (fsLit "fromPArrayP") dph_Prelude_PArr (fsLit "fromPArrayPA")
282     , mk dph_Prelude_PArr (fsLit "toPArrayP") dph_Prelude_PArr (fsLit "toPArrayPA")
283     , mk dph_Prelude_PArr (fsLit "fromNestedPArrayP") dph_Prelude_PArr (fsLit "fromNestedPArrayPA")
284     , mk dph_Prelude_PArr (fsLit "combineP")    dph_Combinators (fsLit "combine2PA")
285     ]
286   where
287     mk  = (,,,)
288     mk' mod v v' = mk mod (fsLit v) mod (fsLit v')
289
290     vars_Ord mod = [mk' mod "=="  "eqV"
291                    ,mk' mod "/=" "neqV"
292                    ,mk' mod "<="  "leV"
293                    ,mk' mod "<"   "ltV"
294                    ,mk' mod ">="  "geV"
295                    ,mk' mod ">"   "gtV"
296                    ,mk' mod "min" "minV"
297                    ,mk' mod "max" "maxV"
298                    ,mk' mod "minimumP" "minimumPA"
299                    ,mk' mod "maximumP" "maximumPA"
300                    ,mk' mod "minIndexP" "minIndexPA"
301                    ,mk' mod "maxIndexP" "maxIndexPA"
302                    ]
303
304     vars_Num mod = [mk' mod "+"        "plusV"
305                    ,mk' mod "-"        "minusV"
306                    ,mk' mod "*"        "multV"
307                    ,mk' mod "negate"   "negateV"
308                    ,mk' mod "abs"      "absV"
309                    ,mk' mod "sumP"     "sumPA"
310                    ,mk' mod "productP" "productPA"
311                    ]
312
313     vars_Fractional mod = [mk' mod "/"     "divideV"
314                           ,mk' mod "recip" "recipV"
315                           ]
316
317     vars_Floating mod = [mk' mod "pi" "pi"
318                         ,mk' mod "exp" "expV"
319                         ,mk' mod "sqrt" "sqrtV"
320                         ,mk' mod "log" "logV"
321                         ,mk' mod "sin" "sinV"
322                         ,mk' mod "tan" "tanV"
323                         ,mk' mod "cos" "cosV"
324                         ,mk' mod "asin" "asinV"
325                         ,mk' mod "atan" "atanV"
326                         ,mk' mod "acos" "acosV"
327                         ,mk' mod "sinh" "sinhV"
328                         ,mk' mod "tanh" "tanhV"
329                         ,mk' mod "cosh" "coshV"
330                         ,mk' mod "asinh" "asinhV"
331                         ,mk' mod "atanh" "atanhV"
332                         ,mk' mod "acosh" "acoshV"
333                         ,mk' mod "**"    "powV"
334                         ,mk' mod "logBase" "logBaseV"
335                         ]
336
337     vars_RealFrac mod = [mk' mod "fromInt" "fromIntV"
338                         ,mk' mod "truncate" "truncateV"
339                         ,mk' mod "round" "roundV"
340                         ,mk' mod "ceiling" "ceilingV"
341                         ,mk' mod "floor" "floorV"
342                         ]
343
344 initBuiltinTyCons :: Builtins -> DsM [(Name, TyCon)]
345 initBuiltinTyCons bi
346   = do
347       -- parr <- externalTyCon dph_Prelude_PArr (fsLit "PArr")
348       return $ (tyConName funTyCon, closureTyCon bi)
349              : (parrTyConName,      parrayTyCon bi)
350
351              -- FIXME: temporary
352              : (tyConName $ parrayTyCon bi, parrayTyCon bi)
353
354              : [(tyConName tc, tc) | tc <- defaultTyCons]
355
356 defaultTyCons :: [TyCon]
357 defaultTyCons = [intTyCon, boolTyCon, doubleTyCon]
358
359 initBuiltinDataCons :: Builtins -> [(Name, DataCon)]
360 initBuiltinDataCons _ = [(dataConName dc, dc)| dc <- defaultDataCons]
361
362 defaultDataCons :: [DataCon]
363 defaultDataCons = [trueDataCon, falseDataCon, unitDataCon]
364
365 initBuiltinDicts :: [(Name, Module, FastString)] -> DsM [(Name, Var)]
366 initBuiltinDicts ps
367   = do
368       dicts <- zipWithM externalVar mods fss
369       return $ zip tcs dicts
370   where
371     (tcs, mods, fss) = unzip3 ps
372
373 initBuiltinPAs :: Builtins -> DsM [(Name, Var)]
374 initBuiltinPAs = initBuiltinDicts . builtinPAs
375
376 builtinPAs :: Builtins -> [(Name, Module, FastString)]
377 builtinPAs bi@(Builtins { dphModules = mods })
378   = [
379       mk (tyConName $ closureTyCon bi)  (dph_Closure   mods) (fsLit "dPA_Clo")
380     , mk (tyConName $ voidTyCon bi)     (dph_Repr      mods) (fsLit "dPA_Void")
381     , mk (tyConName $ parrayTyCon bi)   (dph_Instances mods) (fsLit "dPA_PArray")
382     , mk unitTyConName                  (dph_Instances mods) (fsLit "dPA_Unit")
383
384     , mk intTyConName                   (dph_Instances mods) (fsLit "dPA_Int")
385     , mk doubleTyConName                (dph_Instances mods) (fsLit "dPA_Double")
386     , mk boolTyConName                  (dph_Instances mods) (fsLit "dPA_Bool")
387     ]
388     ++ tups
389   where
390     mk name mod fs = (name, mod, fs)
391
392     tups = map mk_tup [2..mAX_DPH_PROD]
393     mk_tup n = mk (tyConName $ tupleTyCon Boxed n)
394                   (dph_Instances mods)
395                   (mkFastString $ "dPA_" ++ show n)
396
397 initBuiltinPRs :: Builtins -> DsM [(Name, Var)]
398 initBuiltinPRs = initBuiltinDicts . builtinPRs
399
400 builtinPRs :: Builtins -> [(Name, Module, FastString)]
401 builtinPRs bi@(Builtins { dphModules = mods }) =
402   [
403     mk (tyConName   unitTyCon)           (dph_Repr mods)    (fsLit "dPR_Unit")
404   , mk (tyConName $ voidTyCon        bi) (dph_Repr mods)    (fsLit "dPR_Void")
405   , mk (tyConName $ wrapTyCon        bi) (dph_Repr mods)    (fsLit "dPR_Wrap")
406   , mk (tyConName $ enumerationTyCon bi) (dph_Repr mods)    (fsLit "dPR_Enumeration")
407   , mk (tyConName $ closureTyCon     bi) (dph_Closure mods) (fsLit "dPR_Clo")
408
409     -- temporary
410   , mk intTyConName          (dph_Instances mods) (fsLit "dPR_Int")
411   , mk doubleTyConName       (dph_Instances mods) (fsLit "dPR_Double")
412   ]
413
414   ++ map mk_sum  [2..mAX_DPH_SUM]
415   ++ map mk_prod [2..mAX_DPH_PROD]
416   where
417     mk name mod fs = (name, mod, fs)
418
419     mk_sum n = (tyConName $ sumTyCon n bi, dph_Repr mods,
420                 mkFastString ("dPR_Sum" ++ show n))
421
422     mk_prod n = (tyConName $ prodTyCon n bi, dph_Repr mods,
423                  mkFastString ("dPR_" ++ show n))
424
425 initBuiltinBoxedTyCons :: Builtins -> DsM [(Name, TyCon)]
426 initBuiltinBoxedTyCons = return . builtinBoxedTyCons
427
428 builtinBoxedTyCons :: Builtins -> [(Name, TyCon)]
429 builtinBoxedTyCons _ =
430   [(tyConName intPrimTyCon, intTyCon)]
431
432 externalVar :: Module -> FastString -> DsM Var
433 externalVar mod fs
434   = dsLookupGlobalId =<< lookupOrig mod (mkVarOccFS fs)
435
436 externalTyCon :: Module -> FastString -> DsM TyCon
437 externalTyCon mod fs
438   = dsLookupTyCon =<< lookupOrig mod (mkTcOccFS fs)
439
440 externalType :: Module -> FastString -> DsM Type
441 externalType mod fs
442   = do
443       tycon <- externalTyCon mod fs
444       return $ mkTyConApp tycon []
445
446 unitTyConName :: Name
447 unitTyConName = tyConName unitTyCon
448
449
450 primMethod :: TyCon -> String -> Builtins -> DsM (Maybe Var)
451 primMethod  tycon method (Builtins { dphModules = mods })
452   | Just suffix <- lookupNameEnv prim_ty_cons (tyConName tycon)
453   = liftM Just
454   $ dsLookupGlobalId =<< lookupOrig (dph_Unboxed mods)
455                                     (mkVarOcc $ method ++ suffix)
456
457   | otherwise = return Nothing
458
459 primPArray :: TyCon -> Builtins -> DsM (Maybe TyCon)
460 primPArray tycon (Builtins { dphModules = mods })
461   | Just suffix <- lookupNameEnv prim_ty_cons (tyConName tycon)
462   = liftM Just
463   $ dsLookupTyCon =<< lookupOrig (dph_Unboxed mods)
464                                  (mkTcOcc $ "PArray" ++ suffix)
465
466   | otherwise = return Nothing
467
468 prim_ty_cons :: NameEnv String
469 prim_ty_cons = mkNameEnv [mk_prim intPrimTyCon]
470   where
471     mk_prim tycon = (tyConName tycon, '_' : getOccString tycon)
472