b578f3087c5c8b6495c69e2b51f360efb43e17a9
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Builtins / Prelude.hs
1
2 -- | Mapping of prelude functions to vectorised versions.
3 --     Functions like filterP currently have a working but naive version in GHC.PArr
4 --     During vectorisation we replace these by calls to filterPA, which are
5 --     defined in dph-common Data.Array.Parallel.Lifted.Combinators
6 --
7 --     As renamer only sees the GHC.PArr functions, if you want to add a new function
8 --     to the vectoriser there has to be a definition for it in GHC.PArr, even though
9 --     it will never be used at runtime.
10 --
11 module Vectorise.Builtins.Prelude
12         ( preludeVars
13         , preludeScalars)
14 where
15 import Vectorise.Builtins.Modules
16 import PrelNames
17 import Module
18 import FastString
19
20
21 preludeVars
22         :: Modules                      -- ^ Modules containing the DPH backens
23         -> [( Module, FastString        --   Maps the original variable to the one in the DPH 
24             , Module, FastString)]      --   packages that it should be rewritten to.
25
26 preludeVars (Modules { dph_Combinators    = dph_Combinators
27                      , dph_PArray         = dph_PArray
28                      , dph_Prelude_Int    = dph_Prelude_Int
29                      , dph_Prelude_Word8  = dph_Prelude_Word8
30                      , dph_Prelude_Double = dph_Prelude_Double
31                      , dph_Prelude_Bool   = dph_Prelude_Bool 
32                      , dph_Prelude_PArr   = dph_Prelude_PArr
33                      })
34
35     -- Functions that work on whole PArrays, defined in GHC.PArr
36   = [ mk gHC_PARR (fsLit "mapP")       dph_Combinators (fsLit "mapPA")
37     , mk gHC_PARR (fsLit "zipWithP")   dph_Combinators (fsLit "zipWithPA")
38     , mk gHC_PARR (fsLit "zipP")       dph_Combinators (fsLit "zipPA")
39     , mk gHC_PARR (fsLit "unzipP")     dph_Combinators (fsLit "unzipPA")
40     , mk gHC_PARR (fsLit "filterP")    dph_Combinators (fsLit "filterPA")
41     , mk gHC_PARR (fsLit "lengthP")    dph_Combinators (fsLit "lengthPA")
42     , mk gHC_PARR (fsLit "replicateP") dph_Combinators (fsLit "replicatePA")
43     , mk gHC_PARR (fsLit "!:")         dph_Combinators (fsLit "indexPA")
44     , mk gHC_PARR (fsLit "sliceP")     dph_Combinators (fsLit "slicePA")
45     , mk gHC_PARR (fsLit "crossMapP")  dph_Combinators (fsLit "crossMapPA")
46     , mk gHC_PARR (fsLit "singletonP") dph_Combinators (fsLit "singletonPA")
47     , mk gHC_PARR (fsLit "concatP")    dph_Combinators (fsLit "concatPA")
48     , mk gHC_PARR (fsLit "+:+")        dph_Combinators (fsLit "appPA")
49     , mk gHC_PARR (fsLit "emptyP")     dph_PArray      (fsLit "emptyPA")
50
51     -- Map scalar functions to versions using closures. 
52     , mk' dph_Prelude_Int "div"         "divV"
53     , mk' dph_Prelude_Int "mod"         "modV"
54     , mk' dph_Prelude_Int "sqrt"        "sqrtV"
55     , mk' dph_Prelude_Int "enumFromToP" "enumFromToPA"
56     -- , mk' dph_Prelude_Int "upToP" "upToPA"
57     ]
58     ++ vars_Ord dph_Prelude_Int
59     ++ vars_Num dph_Prelude_Int
60
61     ++ vars_Ord dph_Prelude_Word8
62     ++ vars_Num dph_Prelude_Word8
63     ++
64     [ mk' dph_Prelude_Word8 "div"     "divV"
65     , mk' dph_Prelude_Word8 "mod"     "modV"
66     , mk' dph_Prelude_Word8 "fromInt" "fromIntV"
67     , mk' dph_Prelude_Word8 "toInt"   "toIntV"
68     ]
69
70     ++ vars_Ord        dph_Prelude_Double
71     ++ vars_Num        dph_Prelude_Double
72     ++ vars_Fractional dph_Prelude_Double
73     ++ vars_Floating   dph_Prelude_Double
74     ++ vars_RealFrac   dph_Prelude_Double
75     ++
76     [ mk dph_Prelude_Bool  (fsLit "andP")  dph_Prelude_Bool (fsLit "andPA")
77     , mk dph_Prelude_Bool  (fsLit "orP")   dph_Prelude_Bool (fsLit "orPA")
78
79     , mk gHC_CLASSES (fsLit "not")         dph_Prelude_Bool (fsLit "notV")
80     , mk gHC_CLASSES (fsLit "&&")          dph_Prelude_Bool (fsLit "andV")
81     , mk gHC_CLASSES (fsLit "||")          dph_Prelude_Bool (fsLit "orV")
82
83     -- FIXME: temporary
84     , mk dph_Prelude_PArr (fsLit "fromPArrayP")       dph_Prelude_PArr (fsLit "fromPArrayPA")
85     , mk dph_Prelude_PArr (fsLit "toPArrayP")         dph_Prelude_PArr (fsLit "toPArrayPA")
86     , mk dph_Prelude_PArr (fsLit "fromNestedPArrayP") dph_Prelude_PArr (fsLit "fromNestedPArrayPA")
87     , mk dph_Prelude_PArr (fsLit "combineP")          dph_Combinators  (fsLit "combine2PA")
88     , mk dph_Prelude_PArr (fsLit "updateP")           dph_Combinators  (fsLit "updatePA")
89     , mk dph_Prelude_PArr (fsLit "bpermuteP")         dph_Combinators  (fsLit "bpermutePA")
90     , mk dph_Prelude_PArr (fsLit "indexedP")          dph_Combinators  (fsLit "indexedPA")
91     ]
92   where
93     mk  = (,,,)
94     mk' mod v v' = mk mod (fsLit v) mod (fsLit v')
95
96     vars_Ord mod 
97      = [ mk' mod "=="        "eqV"
98        , mk' mod "/="        "neqV"
99        , mk' mod "<="        "leV"
100        , mk' mod "<"         "ltV"
101        , mk' mod ">="        "geV"
102        , mk' mod ">"         "gtV"
103        , mk' mod "min"       "minV"
104        , mk' mod "max"       "maxV"
105        , mk' mod "minimumP"  "minimumPA"
106        , mk' mod "maximumP"  "maximumPA"
107        , mk' mod "minIndexP" "minIndexPA"
108        , mk' mod "maxIndexP" "maxIndexPA"
109        ]
110
111     vars_Num mod 
112      = [ mk' mod "+"        "plusV"
113        , mk' mod "-"        "minusV"
114        , mk' mod "*"        "multV"
115        , mk' mod "negate"   "negateV"
116        , mk' mod "abs"      "absV"
117        , mk' mod "sumP"     "sumPA"
118        , mk' mod "productP" "productPA"
119        ]
120
121     vars_Fractional mod 
122      = [ mk' mod "/"     "divideV"
123        , mk' mod "recip" "recipV"
124        ]
125
126     vars_Floating mod 
127      = [ mk' mod "pi"      "pi"
128        , mk' mod "exp"     "expV"
129        , mk' mod "sqrt"    "sqrtV"
130        , mk' mod "log"     "logV"
131        , mk' mod "sin"     "sinV"
132        , mk' mod "tan"     "tanV"
133        , mk' mod "cos"     "cosV"
134        , mk' mod "asin"    "asinV"
135        , mk' mod "atan"    "atanV"
136        , mk' mod "acos"    "acosV"
137        , mk' mod "sinh"    "sinhV"
138        , mk' mod "tanh"    "tanhV"
139        , mk' mod "cosh"    "coshV"
140        , mk' mod "asinh"   "asinhV"
141        , mk' mod "atanh"   "atanhV"
142        , mk' mod "acosh"   "acoshV"
143        , mk' mod "**"      "powV"
144        , mk' mod "logBase" "logBaseV"
145        ]
146
147     vars_RealFrac mod
148      = [ mk' mod "fromInt"  "fromIntV"
149        , mk' mod "truncate" "truncateV"
150        , mk' mod "round"    "roundV"
151        , mk' mod "ceiling"  "ceilingV"
152        , mk' mod "floor"    "floorV"
153        ]
154
155
156 preludeScalars :: Modules -> [(Module, FastString)]
157 preludeScalars (Modules { dph_Prelude_Int    = dph_Prelude_Int
158                         , dph_Prelude_Word8  = dph_Prelude_Word8
159                         , dph_Prelude_Double = dph_Prelude_Double
160                         })
161   = [ mk dph_Prelude_Int "div"
162     , mk dph_Prelude_Int "mod"
163     , mk dph_Prelude_Int "sqrt"
164     ]
165     ++ scalars_Ord dph_Prelude_Int
166     ++ scalars_Num dph_Prelude_Int
167
168     ++ scalars_Ord dph_Prelude_Word8
169     ++ scalars_Num dph_Prelude_Word8
170     ++
171     [ mk dph_Prelude_Word8 "div"
172     , mk dph_Prelude_Word8 "mod"
173     , mk dph_Prelude_Word8 "fromInt"
174     , mk dph_Prelude_Word8 "toInt"
175     ]
176
177     ++ scalars_Ord dph_Prelude_Double
178     ++ scalars_Num dph_Prelude_Double
179     ++ scalars_Fractional dph_Prelude_Double
180     ++ scalars_Floating dph_Prelude_Double
181     ++ scalars_RealFrac dph_Prelude_Double
182   where
183     mk mod s = (mod, fsLit s)
184
185     scalars_Ord mod 
186      = [ mk mod "=="
187        , mk mod "/="
188        , mk mod "<="
189        , mk mod "<"
190        , mk mod ">="
191        , mk mod ">"
192        , mk mod "min"
193        , mk mod "max"
194        ]
195
196     scalars_Num mod 
197      = [ mk mod "+"
198        , mk mod "-"
199        , mk mod "*"
200        , mk mod "negate"
201        , mk mod "abs"
202        ]
203
204     scalars_Fractional mod 
205      = [ mk mod "/"
206        , mk mod "recip"
207        ]
208
209     scalars_Floating mod 
210      = [ mk mod "pi"
211        , mk mod "exp"
212        , mk mod "sqrt"
213        , mk mod "log"
214        , mk mod "sin"
215        , mk mod "tan"
216        , mk mod "cos"
217        , mk mod "asin"
218        , mk mod "atan"
219        , mk mod "acos"
220        , mk mod "sinh"
221        , mk mod "tanh"
222        , mk mod "cosh"
223        , mk mod "asinh"
224        , mk mod "atanh"
225        , mk mod "acosh"
226        , mk mod "**"
227        , mk mod "logBase"
228        ]
229
230     scalars_RealFrac mod 
231      = [ mk mod "fromInt"
232        , mk mod "truncate"
233        , mk mod "round"
234        , mk mod "ceiling"
235        , mk mod "floor"
236        ]