Add Word8 support to vectoriser
[ghc-hetmet.git] / compiler / vectorise / VectBuiltIn.hs
1 module VectBuiltIn (
2   Builtins(..), sumTyCon, prodTyCon,
3   combinePAVar,
4   initBuiltins, initBuiltinVars, initBuiltinTyCons, initBuiltinDataCons,
5   initBuiltinPAs, initBuiltinPRs,
6   initBuiltinBoxedTyCons,
7
8   primMethod, primPArray
9 ) where
10
11 import DsMonad
12 import IfaceEnv        ( lookupOrig )
13
14 import Module
15 import DataCon         ( DataCon, dataConName, dataConWorkId )
16 import TyCon           ( TyCon, tyConName, tyConDataCons )
17 import Var             ( Var )
18 import Id              ( mkSysLocal )
19 import Name            ( Name, getOccString )
20 import NameEnv
21 import OccName
22
23 import TypeRep         ( funTyCon )
24 import Type            ( Type, mkTyConApp )
25 import TysPrim
26 import TysWiredIn      ( unitTyCon, unitDataCon,
27                          tupleTyCon, tupleCon,
28                          intTyCon, intTyConName,
29                          doubleTyCon, doubleTyConName,
30                          boolTyCon, boolTyConName, trueDataCon, falseDataCon,
31                          parrTyConName )
32 import PrelNames       ( word8TyConName, gHC_PARR )
33 import BasicTypes      ( Boxity(..) )
34
35 import FastString
36 import Outputable
37
38 import Data.Array
39 import Control.Monad   ( liftM, zipWithM )
40 import Data.List       ( unzip4 )
41
42 mAX_DPH_PROD :: Int
43 mAX_DPH_PROD = 5
44
45 mAX_DPH_SUM :: Int
46 mAX_DPH_SUM = 3
47
48 mAX_DPH_COMBINE :: Int
49 mAX_DPH_COMBINE = 2
50
51 data Modules = Modules {
52                    dph_PArray :: Module
53                  , dph_Repr :: Module
54                  , dph_Closure :: Module
55                  , dph_Unboxed :: Module
56                  , dph_Instances :: Module
57                  , dph_Combinators :: Module
58                  , dph_Prelude_PArr :: Module
59                  , dph_Prelude_Int :: Module
60                  , dph_Prelude_Word8 :: Module
61                  , dph_Prelude_Double :: Module
62                  , dph_Prelude_Bool :: Module
63                  , dph_Prelude_Tuple :: Module
64                }
65
66 dph_Modules :: PackageId -> Modules
67 dph_Modules pkg = Modules {
68     dph_PArray         = mk (fsLit "Data.Array.Parallel.Lifted.PArray")
69   , dph_Repr           = mk (fsLit "Data.Array.Parallel.Lifted.Repr")
70   , dph_Closure        = mk (fsLit "Data.Array.Parallel.Lifted.Closure")
71   , dph_Unboxed        = mk (fsLit "Data.Array.Parallel.Lifted.Unboxed")
72   , dph_Instances      = mk (fsLit "Data.Array.Parallel.Lifted.Instances")
73   , dph_Combinators    = mk (fsLit "Data.Array.Parallel.Lifted.Combinators")
74
75   , dph_Prelude_PArr   = mk (fsLit "Data.Array.Parallel.Prelude.Base.PArr")
76   , dph_Prelude_Int    = mk (fsLit "Data.Array.Parallel.Prelude.Base.Int")
77   , dph_Prelude_Word8  = mk (fsLit "Data.Array.Parallel.Prelude.Base.Word8")
78   , dph_Prelude_Double = mk (fsLit "Data.Array.Parallel.Prelude.Base.Double")
79   , dph_Prelude_Bool   = mk (fsLit "Data.Array.Parallel.Prelude.Base.Bool")
80   , dph_Prelude_Tuple  = mk (fsLit "Data.Array.Parallel.Prelude.Base.Tuple")
81   }
82   where
83     mk = mkModule pkg . mkModuleNameFS
84
85
86 data Builtins = Builtins {
87                   dphModules       :: Modules
88                 , parrayTyCon      :: TyCon
89                 , paTyCon          :: TyCon
90                 , paDataCon        :: DataCon
91                 , preprTyCon       :: TyCon
92                 , prTyCon          :: TyCon
93                 , prDataCon        :: DataCon
94                 , intPrimArrayTy   :: Type
95                 , voidTyCon        :: TyCon
96                 , wrapTyCon        :: TyCon
97                 , enumerationTyCon :: TyCon
98                 , sumTyCons        :: Array Int TyCon
99                 , closureTyCon     :: TyCon
100                 , voidVar          :: Var
101                 , mkPRVar          :: Var
102                 , mkClosureVar     :: Var
103                 , applyClosureVar  :: Var
104                 , mkClosurePVar    :: Var
105                 , applyClosurePVar :: Var
106                 , replicatePAIntPrimVar :: Var
107                 , upToPAIntPrimVar :: Var
108                 , selectPAIntPrimVar :: Var
109                 , truesPABoolPrimVar :: Var
110                 , lengthPAVar      :: Var
111                 , replicatePAVar   :: Var
112                 , emptyPAVar       :: Var
113                 , packPAVar        :: Var
114                 , combinePAVars    :: Array Int Var
115                 , liftingContext   :: Var
116                 }
117
118 sumTyCon :: Int -> Builtins -> TyCon
119 sumTyCon n bi
120   | n >= 2 && n <= mAX_DPH_SUM = sumTyCons bi ! n
121   | otherwise = pprPanic "sumTyCon" (ppr n)
122
123 prodTyCon :: Int -> Builtins -> TyCon
124 prodTyCon n bi
125   | n == 1                      = wrapTyCon bi
126   | n >= 0 && n <= mAX_DPH_PROD = tupleTyCon Boxed n
127   | otherwise = pprPanic "prodTyCon" (ppr n)
128
129 combinePAVar :: Int -> Builtins -> Var
130 combinePAVar n bi
131   | n >= 2 && n <= mAX_DPH_COMBINE = combinePAVars bi ! n
132   | otherwise = pprPanic "combinePAVar" (ppr n)
133
134 initBuiltins :: PackageId -> DsM Builtins
135 initBuiltins pkg
136   = do
137       parrayTyCon  <- externalTyCon dph_PArray (fsLit "PArray")
138       paTyCon      <- externalTyCon dph_PArray (fsLit "PA")
139       let [paDataCon] = tyConDataCons paTyCon
140       preprTyCon   <- externalTyCon dph_PArray (fsLit "PRepr")
141       prTyCon      <- externalTyCon dph_PArray (fsLit "PR")
142       let [prDataCon] = tyConDataCons prTyCon
143       intPrimArrayTy <- externalType dph_Unboxed (fsLit "PArray_Int#")
144       closureTyCon <- externalTyCon dph_Closure (fsLit ":->")
145
146       voidTyCon    <- externalTyCon dph_Repr (fsLit "Void")
147       wrapTyCon    <- externalTyCon dph_Repr (fsLit "Wrap")
148       enumerationTyCon <- externalTyCon dph_Repr (fsLit "Enumeration")
149       sum_tcs <- mapM (externalTyCon dph_Repr)
150                       [mkFastString ("Sum" ++ show i) | i <- [2..mAX_DPH_SUM]]
151
152       let sumTyCons = listArray (2, mAX_DPH_SUM) sum_tcs
153
154       voidVar          <- externalVar dph_Repr (fsLit "void")
155       mkPRVar          <- externalVar dph_PArray (fsLit "mkPR")
156       mkClosureVar     <- externalVar dph_Closure (fsLit "mkClosure")
157       applyClosureVar  <- externalVar dph_Closure (fsLit "$:")
158       mkClosurePVar    <- externalVar dph_Closure (fsLit "mkClosureP")
159       applyClosurePVar <- externalVar dph_Closure (fsLit "$:^")
160       replicatePAIntPrimVar <- externalVar dph_Unboxed (fsLit "replicatePA_Int#")
161       upToPAIntPrimVar <- externalVar dph_Unboxed (fsLit "upToPA_Int#")
162       selectPAIntPrimVar <- externalVar dph_Unboxed (fsLit "selectPA_Int#")
163       truesPABoolPrimVar <- externalVar dph_Unboxed (fsLit "truesPA_Bool#")
164       lengthPAVar      <- externalVar dph_PArray (fsLit "lengthPA#")
165       replicatePAVar   <- externalVar dph_PArray (fsLit "replicatePA#")
166       emptyPAVar       <- externalVar dph_PArray (fsLit "emptyPA")
167       packPAVar        <- externalVar dph_PArray (fsLit "packPA#")
168
169       combines <- mapM (externalVar dph_PArray)
170                        [mkFastString ("combine" ++ show i ++ "PA#")
171                           | i <- [2..mAX_DPH_COMBINE]]
172       let combinePAVars = listArray (2, mAX_DPH_COMBINE) combines
173
174       liftingContext <- liftM (\u -> mkSysLocal (fsLit "lc") u intPrimTy)
175                               newUnique
176
177       return $ Builtins {
178                  dphModules       = modules
179                , parrayTyCon      = parrayTyCon
180                , paTyCon          = paTyCon
181                , paDataCon        = paDataCon
182                , preprTyCon       = preprTyCon
183                , prTyCon          = prTyCon
184                , prDataCon        = prDataCon
185                , intPrimArrayTy   = intPrimArrayTy
186                , voidTyCon        = voidTyCon
187                , wrapTyCon        = wrapTyCon
188                , enumerationTyCon = enumerationTyCon
189                , sumTyCons        = sumTyCons
190                , closureTyCon     = closureTyCon
191                , voidVar          = voidVar
192                , mkPRVar          = mkPRVar
193                , mkClosureVar     = mkClosureVar
194                , applyClosureVar  = applyClosureVar
195                , mkClosurePVar    = mkClosurePVar
196                , applyClosurePVar = applyClosurePVar
197                , replicatePAIntPrimVar = replicatePAIntPrimVar
198                , upToPAIntPrimVar = upToPAIntPrimVar
199                , selectPAIntPrimVar = selectPAIntPrimVar
200                , truesPABoolPrimVar = truesPABoolPrimVar
201                , lengthPAVar      = lengthPAVar
202                , replicatePAVar   = replicatePAVar
203                , emptyPAVar       = emptyPAVar
204                , packPAVar        = packPAVar
205                , combinePAVars    = combinePAVars
206                , liftingContext   = liftingContext
207                }
208   where
209     modules@(Modules {
210                dph_PArray         = dph_PArray
211              , dph_Repr           = dph_Repr
212              , dph_Closure        = dph_Closure
213              , dph_Unboxed        = dph_Unboxed
214              })
215       = dph_Modules pkg
216
217
218 initBuiltinVars :: Builtins -> DsM [(Var, Var)]
219 initBuiltinVars (Builtins { dphModules = mods })
220   = do
221       uvars <- zipWithM externalVar umods ufs
222       vvars <- zipWithM externalVar vmods vfs
223       cvars <- zipWithM externalVar cmods cfs
224       return $ [(v,v) | v <- map dataConWorkId defaultDataConWorkers]
225                ++ zip (map dataConWorkId cons) cvars
226                ++ zip uvars vvars
227   where
228     (umods, ufs, vmods, vfs) = unzip4 (preludeVars mods)
229
230     (cons, cmods, cfs) = unzip3 (preludeDataCons mods)
231
232 defaultDataConWorkers :: [DataCon]
233 defaultDataConWorkers = [trueDataCon, falseDataCon, unitDataCon]
234
235 preludeDataCons :: Modules -> [(DataCon, Module, FastString)]
236 preludeDataCons (Modules { dph_Prelude_Tuple = dph_Prelude_Tuple })
237   = [mk_tup n dph_Prelude_Tuple (mkFastString $ "tup" ++ show n) | n <- [2..3]]
238   where
239     mk_tup n mod name = (tupleCon Boxed n, mod, name)
240
241 preludeVars :: Modules -> [(Module, FastString, Module, FastString)]
242 preludeVars (Modules { dph_Combinators    = dph_Combinators
243                      , dph_PArray         = dph_PArray
244                      , dph_Prelude_Int    = dph_Prelude_Int
245                      , dph_Prelude_Word8  = dph_Prelude_Word8
246                      , dph_Prelude_Double = dph_Prelude_Double
247                      , dph_Prelude_Bool   = dph_Prelude_Bool 
248                      , dph_Prelude_PArr   = dph_Prelude_PArr
249                      })
250   = [
251       mk gHC_PARR (fsLit "mapP")       dph_Combinators (fsLit "mapPA")
252     , mk gHC_PARR (fsLit "zipWithP")   dph_Combinators (fsLit "zipWithPA")
253     , mk gHC_PARR (fsLit "zipP")       dph_Combinators (fsLit "zipPA")
254     , mk gHC_PARR (fsLit "unzipP")     dph_Combinators (fsLit "unzipPA")
255     , mk gHC_PARR (fsLit "filterP")    dph_Combinators (fsLit "filterPA")
256     , mk gHC_PARR (fsLit "lengthP")    dph_Combinators (fsLit "lengthPA")
257     , mk gHC_PARR (fsLit "replicateP") dph_Combinators (fsLit "replicatePA")
258     , mk gHC_PARR (fsLit "!:")         dph_Combinators (fsLit "indexPA")
259     , mk gHC_PARR (fsLit "crossMapP")  dph_Combinators (fsLit "crossMapPA")
260     , mk gHC_PARR (fsLit "singletonP") dph_Combinators (fsLit "singletonPA")
261     , mk gHC_PARR (fsLit "concatP")    dph_Combinators (fsLit "concatPA")
262     , mk gHC_PARR (fsLit "+:+")        dph_Combinators (fsLit "appPA")
263     , mk gHC_PARR (fsLit "emptyP")     dph_PArray (fsLit "emptyPA")
264
265     , mk' dph_Prelude_Int "div"  "divV"
266     , mk' dph_Prelude_Int "mod"  "modV"
267     , mk' dph_Prelude_Int "sqrt" "sqrtV"
268     , mk' dph_Prelude_Int "enumFromToP" "enumFromToPA"
269     , mk' dph_Prelude_Int "upToP" "upToPA"
270     ]
271     ++ vars_Ord dph_Prelude_Int
272     ++ vars_Num dph_Prelude_Int
273
274     ++ vars_Ord dph_Prelude_Word8
275     ++ vars_Num dph_Prelude_Word8
276     ++
277     [ mk' dph_Prelude_Word8 "div" "divV"
278     , mk' dph_Prelude_Word8 "mod" "modV"
279     ]
280
281     ++ vars_Ord dph_Prelude_Double
282     ++ vars_Num dph_Prelude_Double
283     ++ vars_Fractional dph_Prelude_Double
284     ++ vars_Floating dph_Prelude_Double
285     ++ vars_RealFrac dph_Prelude_Double
286     ++
287     [ mk dph_Prelude_Bool  (fsLit "andP")  dph_Prelude_Bool (fsLit "andPA")
288     , mk dph_Prelude_Bool  (fsLit "orP")  dph_Prelude_Bool (fsLit "orPA")
289
290     -- FIXME: temporary
291     , mk dph_Prelude_PArr (fsLit "fromPArrayP") dph_Prelude_PArr (fsLit "fromPArrayPA")
292     , mk dph_Prelude_PArr (fsLit "toPArrayP") dph_Prelude_PArr (fsLit "toPArrayPA")
293     , mk dph_Prelude_PArr (fsLit "fromNestedPArrayP") dph_Prelude_PArr (fsLit "fromNestedPArrayPA")
294     , mk dph_Prelude_PArr (fsLit "combineP")    dph_Combinators (fsLit "combine2PA")
295     ]
296   where
297     mk  = (,,,)
298     mk' mod v v' = mk mod (fsLit v) mod (fsLit v')
299
300     vars_Ord mod = [mk' mod "=="  "eqV"
301                    ,mk' mod "/=" "neqV"
302                    ,mk' mod "<="  "leV"
303                    ,mk' mod "<"   "ltV"
304                    ,mk' mod ">="  "geV"
305                    ,mk' mod ">"   "gtV"
306                    ,mk' mod "min" "minV"
307                    ,mk' mod "max" "maxV"
308                    ,mk' mod "minimumP" "minimumPA"
309                    ,mk' mod "maximumP" "maximumPA"
310                    ,mk' mod "minIndexP" "minIndexPA"
311                    ,mk' mod "maxIndexP" "maxIndexPA"
312                    ]
313
314     vars_Num mod = [mk' mod "+"        "plusV"
315                    ,mk' mod "-"        "minusV"
316                    ,mk' mod "*"        "multV"
317                    ,mk' mod "negate"   "negateV"
318                    ,mk' mod "abs"      "absV"
319                    ,mk' mod "sumP"     "sumPA"
320                    ,mk' mod "productP" "productPA"
321                    ]
322
323     vars_Fractional mod = [mk' mod "/"     "divideV"
324                           ,mk' mod "recip" "recipV"
325                           ]
326
327     vars_Floating mod = [mk' mod "pi" "pi"
328                         ,mk' mod "exp" "expV"
329                         ,mk' mod "sqrt" "sqrtV"
330                         ,mk' mod "log" "logV"
331                         ,mk' mod "sin" "sinV"
332                         ,mk' mod "tan" "tanV"
333                         ,mk' mod "cos" "cosV"
334                         ,mk' mod "asin" "asinV"
335                         ,mk' mod "atan" "atanV"
336                         ,mk' mod "acos" "acosV"
337                         ,mk' mod "sinh" "sinhV"
338                         ,mk' mod "tanh" "tanhV"
339                         ,mk' mod "cosh" "coshV"
340                         ,mk' mod "asinh" "asinhV"
341                         ,mk' mod "atanh" "atanhV"
342                         ,mk' mod "acosh" "acoshV"
343                         ,mk' mod "**"    "powV"
344                         ,mk' mod "logBase" "logBaseV"
345                         ]
346
347     vars_RealFrac mod = [mk' mod "fromInt" "fromIntV"
348                         ,mk' mod "truncate" "truncateV"
349                         ,mk' mod "round" "roundV"
350                         ,mk' mod "ceiling" "ceilingV"
351                         ,mk' mod "floor" "floorV"
352                         ]
353
354 initBuiltinTyCons :: Builtins -> DsM [(Name, TyCon)]
355 initBuiltinTyCons bi
356   = do
357       -- parr <- externalTyCon dph_Prelude_PArr (fsLit "PArr")
358       dft_tcs <- defaultTyCons
359       return $ (tyConName funTyCon, closureTyCon bi)
360              : (parrTyConName,      parrayTyCon bi)
361
362              -- FIXME: temporary
363              : (tyConName $ parrayTyCon bi, parrayTyCon bi)
364
365              : [(tyConName tc, tc) | tc <- dft_tcs]
366
367 defaultTyCons :: DsM [TyCon]
368 defaultTyCons
369   = do
370       word8 <- dsLookupTyCon word8TyConName
371       return [intTyCon, boolTyCon, doubleTyCon, word8]
372
373 initBuiltinDataCons :: Builtins -> [(Name, DataCon)]
374 initBuiltinDataCons _ = [(dataConName dc, dc)| dc <- defaultDataCons]
375
376 defaultDataCons :: [DataCon]
377 defaultDataCons = [trueDataCon, falseDataCon, unitDataCon]
378
379 initBuiltinDicts :: [(Name, Module, FastString)] -> DsM [(Name, Var)]
380 initBuiltinDicts ps
381   = do
382       dicts <- zipWithM externalVar mods fss
383       return $ zip tcs dicts
384   where
385     (tcs, mods, fss) = unzip3 ps
386
387 initBuiltinPAs :: Builtins -> DsM [(Name, Var)]
388 initBuiltinPAs = initBuiltinDicts . builtinPAs
389
390 builtinPAs :: Builtins -> [(Name, Module, FastString)]
391 builtinPAs bi@(Builtins { dphModules = mods })
392   = [
393       mk (tyConName $ closureTyCon bi)  (dph_Closure   mods) (fsLit "dPA_Clo")
394     , mk (tyConName $ voidTyCon bi)     (dph_Repr      mods) (fsLit "dPA_Void")
395     , mk (tyConName $ parrayTyCon bi)   (dph_Instances mods) (fsLit "dPA_PArray")
396     , mk unitTyConName                  (dph_Instances mods) (fsLit "dPA_Unit")
397
398     , mk intTyConName                   (dph_Instances mods) (fsLit "dPA_Int")
399     , mk word8TyConName                 (dph_Instances mods) (fsLit "dPA_Word8")
400     , mk doubleTyConName                (dph_Instances mods) (fsLit "dPA_Double")
401     , mk boolTyConName                  (dph_Instances mods) (fsLit "dPA_Bool")
402     ]
403     ++ tups
404   where
405     mk name mod fs = (name, mod, fs)
406
407     tups = map mk_tup [2..mAX_DPH_PROD]
408     mk_tup n = mk (tyConName $ tupleTyCon Boxed n)
409                   (dph_Instances mods)
410                   (mkFastString $ "dPA_" ++ show n)
411
412 initBuiltinPRs :: Builtins -> DsM [(Name, Var)]
413 initBuiltinPRs = initBuiltinDicts . builtinPRs
414
415 builtinPRs :: Builtins -> [(Name, Module, FastString)]
416 builtinPRs bi@(Builtins { dphModules = mods }) =
417   [
418     mk (tyConName   unitTyCon)           (dph_Repr mods)    (fsLit "dPR_Unit")
419   , mk (tyConName $ voidTyCon        bi) (dph_Repr mods)    (fsLit "dPR_Void")
420   , mk (tyConName $ wrapTyCon        bi) (dph_Repr mods)    (fsLit "dPR_Wrap")
421   , mk (tyConName $ enumerationTyCon bi) (dph_Repr mods)    (fsLit "dPR_Enumeration")
422   , mk (tyConName $ closureTyCon     bi) (dph_Closure mods) (fsLit "dPR_Clo")
423
424     -- temporary
425   , mk intTyConName          (dph_Instances mods) (fsLit "dPR_Int")
426   , mk word8TyConName        (dph_Instances mods) (fsLit "dPR_Word8")
427   , mk doubleTyConName       (dph_Instances mods) (fsLit "dPR_Double")
428   ]
429
430   ++ map mk_sum  [2..mAX_DPH_SUM]
431   ++ map mk_prod [2..mAX_DPH_PROD]
432   where
433     mk name mod fs = (name, mod, fs)
434
435     mk_sum n = (tyConName $ sumTyCon n bi, dph_Repr mods,
436                 mkFastString ("dPR_Sum" ++ show n))
437
438     mk_prod n = (tyConName $ prodTyCon n bi, dph_Repr mods,
439                  mkFastString ("dPR_" ++ show n))
440
441 initBuiltinBoxedTyCons :: Builtins -> DsM [(Name, TyCon)]
442 initBuiltinBoxedTyCons = return . builtinBoxedTyCons
443
444 builtinBoxedTyCons :: Builtins -> [(Name, TyCon)]
445 builtinBoxedTyCons _ =
446   [(tyConName intPrimTyCon, intTyCon)]
447
448 externalVar :: Module -> FastString -> DsM Var
449 externalVar mod fs
450   = dsLookupGlobalId =<< lookupOrig mod (mkVarOccFS fs)
451
452 externalTyCon :: Module -> FastString -> DsM TyCon
453 externalTyCon mod fs
454   = dsLookupTyCon =<< lookupOrig mod (mkTcOccFS fs)
455
456 externalType :: Module -> FastString -> DsM Type
457 externalType mod fs
458   = do
459       tycon <- externalTyCon mod fs
460       return $ mkTyConApp tycon []
461
462 unitTyConName :: Name
463 unitTyConName = tyConName unitTyCon
464
465
466 primMethod :: TyCon -> String -> Builtins -> DsM (Maybe Var)
467 primMethod  tycon method (Builtins { dphModules = mods })
468   | Just suffix <- lookupNameEnv prim_ty_cons (tyConName tycon)
469   = liftM Just
470   $ dsLookupGlobalId =<< lookupOrig (dph_Unboxed mods)
471                                     (mkVarOcc $ method ++ suffix)
472
473   | otherwise = return Nothing
474
475 primPArray :: TyCon -> Builtins -> DsM (Maybe TyCon)
476 primPArray tycon (Builtins { dphModules = mods })
477   | Just suffix <- lookupNameEnv prim_ty_cons (tyConName tycon)
478   = liftM Just
479   $ dsLookupTyCon =<< lookupOrig (dph_Unboxed mods)
480                                  (mkTcOcc $ "PArray" ++ suffix)
481
482   | otherwise = return Nothing
483
484 prim_ty_cons :: NameEnv String
485 prim_ty_cons = mkNameEnv [mk_prim intPrimTyCon]
486   where
487     mk_prim tycon = (tyConName tycon, '_' : getOccString tycon)
488