Added a VECTORISE pragma
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Builtins / Prelude.hs
1
2 -- WARNING: This module is a temporary kludge.  It will soon go away entirely (once 
3 --   VECTORISE SCALAR pragmas are fully implemented.)
4
5 -- | Mapping of prelude functions to vectorised versions.
6 --     Functions like filterP currently have a working but naive version in GHC.PArr
7 --     During vectorisation we replace these by calls to filterPA, which are
8 --     defined in dph-common Data.Array.Parallel.Lifted.Combinators
9 --
10 --     As renamer only sees the GHC.PArr functions, if you want to add a new function
11 --     to the vectoriser there has to be a definition for it in GHC.PArr, even though
12 --     it will never be used at runtime.
13 --
14 module Vectorise.Builtins.Prelude
15         ( preludeVars
16         , preludeScalars)
17 where
18 import Vectorise.Builtins.Modules
19 import PrelNames
20 import Module
21 import FastString
22
23
24 preludeVars :: Modules
25         -> [( Module, FastString        --   Maps the original variable to the one in the DPH 
26             , Module, FastString)]      --   packages that it should be rewritten to.
27 preludeVars (Modules { dph_Combinators    = _dph_Combinators
28                      , dph_PArray         = _dph_PArray
29                      , dph_Prelude_Int    = dph_Prelude_Int
30                      , dph_Prelude_Word8  = dph_Prelude_Word8
31                      , dph_Prelude_Double = dph_Prelude_Double
32                      , dph_Prelude_Bool   = dph_Prelude_Bool 
33                      , dph_Prelude_PArr   = _dph_Prelude_PArr
34                      })
35
36     -- Functions that work on whole PArrays, defined in GHC.PArr
37   = [ {- mk gHC_PARR' (fsLit "mapP")       dph_Combinators (fsLit "mapPA")
38     , mk gHC_PARR' (fsLit "zipWithP")   dph_Combinators (fsLit "zipWithPA")
39     , mk gHC_PARR' (fsLit "zipP")       dph_Combinators (fsLit "zipPA")
40     , mk gHC_PARR' (fsLit "unzipP")     dph_Combinators (fsLit "unzipPA")
41     , mk gHC_PARR' (fsLit "filterP")    dph_Combinators (fsLit "filterPA")
42     , mk gHC_PARR' (fsLit "lengthP")    dph_Combinators (fsLit "lengthPA")
43     , mk gHC_PARR' (fsLit "replicateP") dph_Combinators (fsLit "replicatePA")
44     , mk gHC_PARR' (fsLit "!:")         dph_Combinators (fsLit "indexPA")
45     , mk gHC_PARR' (fsLit "sliceP")     dph_Combinators (fsLit "slicePA")
46     , mk gHC_PARR' (fsLit "crossMapP")  dph_Combinators (fsLit "crossMapPA")
47     , mk gHC_PARR' (fsLit "singletonP") dph_Combinators (fsLit "singletonPA")
48     , mk gHC_PARR' (fsLit "concatP")    dph_Combinators (fsLit "concatPA")
49     , mk gHC_PARR' (fsLit "+:+")        dph_Combinators (fsLit "appPA")
50     , mk gHC_PARR' (fsLit "emptyP")     dph_PArray      (fsLit "emptyPA")
51
52     -- Map scalar functions to versions using closures. 
53     , -} mk' dph_Prelude_Int "div"         "divV"
54     , mk' dph_Prelude_Int "mod"         "modV"
55     , mk' dph_Prelude_Int "sqrt"        "sqrtV"
56     , mk' dph_Prelude_Int "enumFromToP" "enumFromToPA"
57     -- , mk' dph_Prelude_Int "upToP" "upToPA"
58     ]
59     ++ vars_Ord dph_Prelude_Int
60     ++ vars_Num dph_Prelude_Int
61
62     ++ vars_Ord dph_Prelude_Word8
63     ++ vars_Num dph_Prelude_Word8
64     ++
65     [ mk' dph_Prelude_Word8 "div"     "divV"
66     , mk' dph_Prelude_Word8 "mod"     "modV"
67     , mk' dph_Prelude_Word8 "fromInt" "fromIntV"
68     , mk' dph_Prelude_Word8 "toInt"   "toIntV"
69     ]
70
71     ++ vars_Ord        dph_Prelude_Double
72     ++ vars_Num        dph_Prelude_Double
73     ++ vars_Fractional dph_Prelude_Double
74     ++ vars_Floating   dph_Prelude_Double
75     ++ vars_RealFrac   dph_Prelude_Double
76     ++
77     [ mk dph_Prelude_Bool  (fsLit "andP")  dph_Prelude_Bool (fsLit "andPA")
78     , mk dph_Prelude_Bool  (fsLit "orP")   dph_Prelude_Bool (fsLit "orPA")
79
80     , mk gHC_CLASSES (fsLit "not")         dph_Prelude_Bool (fsLit "notV")
81     , mk gHC_CLASSES (fsLit "&&")          dph_Prelude_Bool (fsLit "andV")
82     , mk gHC_CLASSES (fsLit "||")          dph_Prelude_Bool (fsLit "orV")
83
84 {-
85     -- FIXME: temporary
86     , mk dph_Prelude_PArr (fsLit "fromPArrayP")       dph_Prelude_PArr (fsLit "fromPArrayPA")
87     , mk dph_Prelude_PArr (fsLit "toPArrayP")         dph_Prelude_PArr (fsLit "toPArrayPA")
88     , mk dph_Prelude_PArr (fsLit "fromNestedPArrayP") dph_Prelude_PArr (fsLit "fromNestedPArrayPA")
89     , mk dph_Prelude_PArr (fsLit "combineP")          dph_Combinators  (fsLit "combine2PA")
90     , mk dph_Prelude_PArr (fsLit "updateP")           dph_Combinators  (fsLit "updatePA")
91     , mk dph_Prelude_PArr (fsLit "bpermuteP")         dph_Combinators  (fsLit "bpermutePA")
92     , mk dph_Prelude_PArr (fsLit "indexedP")          dph_Combinators  (fsLit "indexedPA")
93 -}    ]
94   where
95     mk  = (,,,)
96     mk' mod v v' = mk mod (fsLit v) mod (fsLit v')
97
98     vars_Ord mod 
99      = [ mk' mod "=="        "eqV"
100        , mk' mod "/="        "neqV"
101        , mk' mod "<="        "leV"
102        , mk' mod "<"         "ltV"
103        , mk' mod ">="        "geV"
104        , mk' mod ">"         "gtV"
105        , mk' mod "min"       "minV"
106        , mk' mod "max"       "maxV"
107        , mk' mod "minimumP"  "minimumPA"
108        , mk' mod "maximumP"  "maximumPA"
109        , mk' mod "minIndexP" "minIndexPA"
110        , mk' mod "maxIndexP" "maxIndexPA"
111        ]
112
113     vars_Num mod 
114      = [ mk' mod "+"        "plusV"
115        , mk' mod "-"        "minusV"
116        , mk' mod "*"        "multV"
117        , mk' mod "negate"   "negateV"
118        , mk' mod "abs"      "absV"
119        , mk' mod "sumP"     "sumPA"
120        , mk' mod "productP" "productPA"
121        ]
122
123     vars_Fractional mod 
124      = [ mk' mod "/"     "divideV"
125        , mk' mod "recip" "recipV"
126        ]
127
128     vars_Floating mod 
129      = [ mk' mod "pi"      "pi"
130        , mk' mod "exp"     "expV"
131        , mk' mod "sqrt"    "sqrtV"
132        , mk' mod "log"     "logV"
133        , mk' mod "sin"     "sinV"
134        , mk' mod "tan"     "tanV"
135        , mk' mod "cos"     "cosV"
136        , mk' mod "asin"    "asinV"
137        , mk' mod "atan"    "atanV"
138        , mk' mod "acos"    "acosV"
139        , mk' mod "sinh"    "sinhV"
140        , mk' mod "tanh"    "tanhV"
141        , mk' mod "cosh"    "coshV"
142        , mk' mod "asinh"   "asinhV"
143        , mk' mod "atanh"   "atanhV"
144        , mk' mod "acosh"   "acoshV"
145        , mk' mod "**"      "powV"
146        , mk' mod "logBase" "logBaseV"
147        ]
148
149     vars_RealFrac mod
150      = [ mk' mod "fromInt"  "fromIntV"
151        , mk' mod "truncate" "truncateV"
152        , mk' mod "round"    "roundV"
153        , mk' mod "ceiling"  "ceilingV"
154        , mk' mod "floor"    "floorV"
155        ]
156
157 preludeScalars :: Modules -> [(Module, FastString)]
158 preludeScalars (Modules { dph_Prelude_Int    = dph_Prelude_Int
159                         , dph_Prelude_Word8  = dph_Prelude_Word8
160                         , dph_Prelude_Double = dph_Prelude_Double
161                         })
162   = [ mk dph_Prelude_Int "div"
163     , mk dph_Prelude_Int "mod"
164     , mk dph_Prelude_Int "sqrt"
165     ]
166     ++ scalars_Ord dph_Prelude_Int
167     ++ scalars_Num dph_Prelude_Int
168
169     ++ scalars_Ord dph_Prelude_Word8
170     ++ scalars_Num dph_Prelude_Word8
171     ++
172     [ mk dph_Prelude_Word8 "div"
173     , mk dph_Prelude_Word8 "mod"
174     , mk dph_Prelude_Word8 "fromInt"
175     , mk dph_Prelude_Word8 "toInt"
176     ]
177
178     ++ scalars_Ord dph_Prelude_Double
179     ++ scalars_Num dph_Prelude_Double
180     ++ scalars_Fractional dph_Prelude_Double
181     ++ scalars_Floating dph_Prelude_Double
182     ++ scalars_RealFrac dph_Prelude_Double
183   where
184     mk mod s = (mod, fsLit s)
185
186     scalars_Ord mod 
187      = [ mk mod "=="
188        , mk mod "/="
189        , mk mod "<="
190        , mk mod "<"
191        , mk mod ">="
192        , mk mod ">"
193        , mk mod "min"
194        , mk mod "max"
195        ]
196
197     scalars_Num mod 
198      = [ mk mod "+"
199        , mk mod "-"
200        , mk mod "*"
201        , mk mod "negate"
202        , mk mod "abs"
203        ]
204
205     scalars_Fractional mod 
206      = [ mk mod "/"
207        , mk mod "recip"
208        ]
209
210     scalars_Floating mod 
211      = [ mk mod "pi"
212        , mk mod "exp"
213        , mk mod "sqrt"
214        , mk mod "log"
215        , mk mod "sin"
216        , mk mod "tan"
217        , mk mod "cos"
218        , mk mod "asin"
219        , mk mod "atan"
220        , mk mod "acos"
221        , mk mod "sinh"
222        , mk mod "tanh"
223        , mk mod "cosh"
224        , mk mod "asinh"
225        , mk mod "atanh"
226        , mk mod "acosh"
227        , mk mod "**"
228        , mk mod "logBase"
229        ]
230
231     scalars_RealFrac mod 
232      = [ mk mod "fromInt"
233        , mk mod "truncate"
234        , mk mod "round"
235        , mk mod "ceiling"
236        , mk mod "floor"
237        ]