731371e1615a3166fbac71b2d6b05e8d0b840438
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Builtins / Prelude.hs
1
2 -- WARNING: This module is a temporary kludge.  It will soon go away entirely (once 
3 --   VECTORISE SCALAR pragmas are fully implemented.)
4
5 -- | Mapping of prelude functions to vectorised versions.
6 --     Functions like filterP currently have a working but naive version in GHC.PArr
7 --     During vectorisation we replace these by calls to filterPA, which are
8 --     defined in dph-common Data.Array.Parallel.Lifted.Combinators
9 --
10 --     As renamer only sees the GHC.PArr functions, if you want to add a new function
11 --     to the vectoriser there has to be a definition for it in GHC.PArr, even though
12 --     it will never be used at runtime.
13 --
14 module Vectorise.Builtins.Prelude
15         ( preludeVars
16         , preludeScalars)
17 where
18 import Vectorise.Builtins.Modules
19 import PrelNames
20 import Module
21 import FastString
22
23
24 preludeVars :: Modules
25         -> [( Module, FastString        --   Maps the original variable to the one in the DPH 
26             , Module, FastString)]      --   packages that it should be rewritten to.
27 preludeVars (Modules { dph_Combinators    = _dph_Combinators
28                      , dph_Prelude_Int    = dph_Prelude_Int
29                      , dph_Prelude_Word8  = dph_Prelude_Word8
30                      , dph_Prelude_Double = dph_Prelude_Double
31                      , dph_Prelude_Bool   = dph_Prelude_Bool 
32                      })
33
34     -- Functions that work on whole PArrays, defined in GHC.PArr
35   = [ {- mk gHC_PARR' (fsLit "mapP")       dph_Combinators (fsLit "mapPA")
36     , mk gHC_PARR' (fsLit "zipWithP")   dph_Combinators (fsLit "zipWithPA")
37     , mk gHC_PARR' (fsLit "zipP")       dph_Combinators (fsLit "zipPA")
38     , mk gHC_PARR' (fsLit "unzipP")     dph_Combinators (fsLit "unzipPA")
39     , mk gHC_PARR' (fsLit "filterP")    dph_Combinators (fsLit "filterPA")
40     , mk gHC_PARR' (fsLit "lengthP")    dph_Combinators (fsLit "lengthPA")
41     , mk gHC_PARR' (fsLit "replicateP") dph_Combinators (fsLit "replicatePA")
42     , mk gHC_PARR' (fsLit "!:")         dph_Combinators (fsLit "indexPA")
43     , mk gHC_PARR' (fsLit "sliceP")     dph_Combinators (fsLit "slicePA")
44     , mk gHC_PARR' (fsLit "crossMapP")  dph_Combinators (fsLit "crossMapPA")
45     , mk gHC_PARR' (fsLit "singletonP") dph_Combinators (fsLit "singletonPA")
46     , mk gHC_PARR' (fsLit "concatP")    dph_Combinators (fsLit "concatPA")
47     , mk gHC_PARR' (fsLit "+:+")        dph_Combinators (fsLit "appPA")
48     , mk gHC_PARR' (fsLit "emptyP")     dph_PArray      (fsLit "emptyPA")
49
50     -- Map scalar functions to versions using closures. 
51     , -} mk' dph_Prelude_Int "div"         "divV"
52     , mk' dph_Prelude_Int "mod"         "modV"
53     , mk' dph_Prelude_Int "sqrt"        "sqrtV"
54     , mk' dph_Prelude_Int "enumFromToP" "enumFromToPA"
55     -- , mk' dph_Prelude_Int "upToP" "upToPA"
56     ]
57     ++ vars_Ord dph_Prelude_Int
58     ++ vars_Num dph_Prelude_Int
59
60     ++ vars_Ord dph_Prelude_Word8
61     ++ vars_Num dph_Prelude_Word8
62     ++
63     [ mk' dph_Prelude_Word8 "div"     "divV"
64     , mk' dph_Prelude_Word8 "mod"     "modV"
65     , mk' dph_Prelude_Word8 "fromInt" "fromIntV"
66     , mk' dph_Prelude_Word8 "toInt"   "toIntV"
67     ]
68
69     ++ vars_Ord        dph_Prelude_Double
70     ++ vars_Num        dph_Prelude_Double
71     ++ vars_Fractional dph_Prelude_Double
72     ++ vars_Floating   dph_Prelude_Double
73     ++ vars_RealFrac   dph_Prelude_Double
74     ++
75     [ mk dph_Prelude_Bool  (fsLit "andP")  dph_Prelude_Bool (fsLit "andPA")
76     , mk dph_Prelude_Bool  (fsLit "orP")   dph_Prelude_Bool (fsLit "orPA")
77
78     , mk gHC_CLASSES (fsLit "not")         dph_Prelude_Bool (fsLit "notV")
79     , mk gHC_CLASSES (fsLit "&&")          dph_Prelude_Bool (fsLit "andV")
80     , mk gHC_CLASSES (fsLit "||")          dph_Prelude_Bool (fsLit "orV")
81
82 {-
83     -- FIXME: temporary
84     , mk dph_Prelude_PArr (fsLit "fromPArrayP")       dph_Prelude_PArr (fsLit "fromPArrayPA")
85     , mk dph_Prelude_PArr (fsLit "toPArrayP")         dph_Prelude_PArr (fsLit "toPArrayPA")
86     , mk dph_Prelude_PArr (fsLit "fromNestedPArrayP") dph_Prelude_PArr (fsLit "fromNestedPArrayPA")
87     , mk dph_Prelude_PArr (fsLit "combineP")          dph_Combinators  (fsLit "combine2PA")
88     , mk dph_Prelude_PArr (fsLit "updateP")           dph_Combinators  (fsLit "updatePA")
89     , mk dph_Prelude_PArr (fsLit "bpermuteP")         dph_Combinators  (fsLit "bpermutePA")
90     , mk dph_Prelude_PArr (fsLit "indexedP")          dph_Combinators  (fsLit "indexedPA")
91 -}    ]
92   where
93     mk  = (,,,)
94     mk' mod v v' = mk mod (fsLit v) mod (fsLit v')
95
96     vars_Ord mod 
97      = [ mk' mod "=="        "eqV"
98        , mk' mod "/="        "neqV"
99        , mk' mod "<="        "leV"
100        , mk' mod "<"         "ltV"
101        , mk' mod ">="        "geV"
102        , mk' mod ">"         "gtV"
103        , mk' mod "min"       "minV"
104        , mk' mod "max"       "maxV"
105        , mk' mod "minimumP"  "minimumPA"
106        , mk' mod "maximumP"  "maximumPA"
107        , mk' mod "minIndexP" "minIndexPA"
108        , mk' mod "maxIndexP" "maxIndexPA"
109        ]
110
111     vars_Num mod 
112      = [ mk' mod "+"        "plusV"
113        , mk' mod "-"        "minusV"
114        , mk' mod "*"        "multV"
115        , mk' mod "negate"   "negateV"
116        , mk' mod "abs"      "absV"
117        , mk' mod "sumP"     "sumPA"
118        , mk' mod "productP" "productPA"
119        ]
120
121     vars_Fractional mod 
122      = [ mk' mod "/"     "divideV"
123        , mk' mod "recip" "recipV"
124        ]
125
126     vars_Floating mod 
127      = [ mk' mod "pi"      "pi"
128        , mk' mod "exp"     "expV"
129        , mk' mod "sqrt"    "sqrtV"
130        , mk' mod "log"     "logV"
131        , mk' mod "sin"     "sinV"
132        , mk' mod "tan"     "tanV"
133        , mk' mod "cos"     "cosV"
134        , mk' mod "asin"    "asinV"
135        , mk' mod "atan"    "atanV"
136        , mk' mod "acos"    "acosV"
137        , mk' mod "sinh"    "sinhV"
138        , mk' mod "tanh"    "tanhV"
139        , mk' mod "cosh"    "coshV"
140        , mk' mod "asinh"   "asinhV"
141        , mk' mod "atanh"   "atanhV"
142        , mk' mod "acosh"   "acoshV"
143        , mk' mod "**"      "powV"
144        , mk' mod "logBase" "logBaseV"
145        ]
146
147     vars_RealFrac mod
148      = [ mk' mod "fromInt"  "fromIntV"
149        , mk' mod "truncate" "truncateV"
150        , mk' mod "round"    "roundV"
151        , mk' mod "ceiling"  "ceilingV"
152        , mk' mod "floor"    "floorV"
153        ]
154
155 preludeScalars :: Modules -> [(Module, FastString)]
156 preludeScalars (Modules { dph_Prelude_Int    = dph_Prelude_Int
157                         , dph_Prelude_Word8  = dph_Prelude_Word8
158                         , dph_Prelude_Double = dph_Prelude_Double
159                         })
160   = [ mk dph_Prelude_Int "div"
161     , mk dph_Prelude_Int "mod"
162     , mk dph_Prelude_Int "sqrt"
163     ]
164     ++ scalars_Ord dph_Prelude_Int
165     ++ scalars_Num dph_Prelude_Int
166
167     ++ scalars_Ord dph_Prelude_Word8
168     ++ scalars_Num dph_Prelude_Word8
169     ++
170     [ mk dph_Prelude_Word8 "div"
171     , mk dph_Prelude_Word8 "mod"
172     , mk dph_Prelude_Word8 "fromInt"
173     , mk dph_Prelude_Word8 "toInt"
174     ]
175
176     ++ scalars_Ord dph_Prelude_Double
177     ++ scalars_Num dph_Prelude_Double
178     ++ scalars_Fractional dph_Prelude_Double
179     ++ scalars_Floating dph_Prelude_Double
180     ++ scalars_RealFrac dph_Prelude_Double
181   where
182     mk mod s = (mod, fsLit s)
183
184     scalars_Ord mod 
185      = [ mk mod "=="
186        , mk mod "/="
187        , mk mod "<="
188        , mk mod "<"
189        , mk mod ">="
190        , mk mod ">"
191        , mk mod "min"
192        , mk mod "max"
193        ]
194
195     scalars_Num mod 
196      = [ mk mod "+"
197        , mk mod "-"
198        , mk mod "*"
199        , mk mod "negate"
200        , mk mod "abs"
201        ]
202
203     scalars_Fractional mod 
204      = [ mk mod "/"
205        , mk mod "recip"
206        ]
207
208     scalars_Floating mod 
209      = [ mk mod "pi"
210        , mk mod "exp"
211        , mk mod "sqrt"
212        , mk mod "log"
213        , mk mod "sin"
214        , mk mod "tan"
215        , mk mod "cos"
216        , mk mod "asin"
217        , mk mod "atan"
218        , mk mod "acos"
219        , mk mod "sinh"
220        , mk mod "tanh"
221        , mk mod "cosh"
222        , mk mod "asinh"
223        , mk mod "atanh"
224        , mk mod "acosh"
225        , mk mod "**"
226        , mk mod "logBase"
227        ]
228
229     scalars_RealFrac mod 
230      = [ mk mod "fromInt"
231        , mk mod "truncate"
232        , mk mod "round"
233        , mk mod "ceiling"
234        , mk mod "floor"
235        ]