2 -- | The vectoriser rewrites user code to use builtin types and functions exported by the DPH library.
3 -- We track the names of those things in the `Builtins` type, and provide selection functions
4 -- to help extract their names.
9 sumTyCon, prodTyCon, prodDataCon,
10 selTy,selReplicate, selPick, selTags, selElements,
11 combinePDVar, scalarZip, closureCtrFun,
14 initBuiltins, initBuiltinVars, initBuiltinTyCons, initBuiltinDataCons,
15 initBuiltinPAs, initBuiltinPRs,
16 initBuiltinBoxedTyCons, initBuiltinScalars,
18 primMethod, primPArray
21 import Vectorise.Builtins.Modules
22 import Vectorise.Builtins.Base
25 import IfaceEnv ( lookupOrig )
29 import DataCon ( DataCon, dataConName, dataConWorkId )
30 import TyCon ( TyCon, tyConName, tyConDataCons )
31 import Class ( Class, classTyCon )
32 import CoreSyn ( CoreExpr, Expr(..) )
34 import Id ( mkSysLocal )
35 import Name ( Name, getOccString )
39 import TypeRep ( funTyCon )
40 import Type ( Type, mkTyConApp )
42 import TysWiredIn ( unitDataCon,
46 boolTyCon, trueDataCon, falseDataCon,
48 import PrelNames ( word8TyConName, gHC_PARR, gHC_CLASSES )
49 import BasicTypes ( Boxity(..) )
55 import Control.Monad ( liftM, zipWithM )
56 import Data.List ( unzip4 )
61 -- Initialisation -------------------------------------------------------------
62 -- | Create the initial map of builtin types and functions.
--
-- Each builtin TyCon, DataCon and Var below is looked up by its occurrence
-- name in the DPH module that defines it, and the results are collected
-- into a 'Builtins' record.
64 :: PackageId -- ^ package id the builtins are in, eg dph-common
-- Load every module selected by dph_Orphans (with `load`, i.e. via
-- dsLoadModule) before any names are looked up.
69 mapM_ load dph_Orphans
71 -- From dph-common:Data.Array.Parallel.Lifted.PArray
-- NOTE: the `let [x] = tyConDataCons ...` bindings below are lazy pattern
-- bindings: they assume the TyCon has exactly one data constructor and fail
-- with a pattern-match error when the bound name is first used otherwise.
-- PA and PR are classes; externalClassTyCon yields their class TyCons, whose
-- single data constructor is bound as paDataCon / prDataCon.
72 parrayTyCon <- externalTyCon dph_PArray (fsLit "PArray")
73 let [parrayDataCon] = tyConDataCons parrayTyCon
74 pdataTyCon <- externalTyCon dph_PArray (fsLit "PData")
75 paTyCon <- externalClassTyCon dph_PArray (fsLit "PA")
76 let [paDataCon] = tyConDataCons paTyCon
77 preprTyCon <- externalTyCon dph_PArray (fsLit "PRepr")
78 prTyCon <- externalClassTyCon dph_PArray (fsLit "PR")
79 let [prDataCon] = tyConDataCons prTyCon
-- The closure type constructor (:->), from the dph_Closure module.
82 closureTyCon <- externalTyCon dph_Closure (fsLit ":->")
84 -- From dph-common:Data.Array.Parallel.Lifted.Repr
85 voidTyCon <- externalTyCon dph_Repr (fsLit "Void")
86 wrapTyCon <- externalTyCon dph_Repr (fsLit "Wrap")
88 -- From dph-common:Data.Array.Parallel.Lifted.Unboxed
-- One selector type and one of each selector function per sum arity
-- 2 .. mAX_DPH_SUM, e.g. Sel2, replicateSel2#, pickSel2#, tagsSel2.
89 sel_tys <- mapM (externalType dph_Unboxed)
90 (numbered "Sel" 2 mAX_DPH_SUM)
92 sel_replicates <- mapM (externalFun dph_Unboxed)
93 (numbered_hash "replicateSel" 2 mAX_DPH_SUM)
95 sel_picks <- mapM (externalFun dph_Unboxed)
96 (numbered_hash "pickSel" 2 mAX_DPH_SUM)
98 sel_tags <- mapM (externalFun dph_Unboxed)
99 (numbered "tagsSel" 2 mAX_DPH_SUM)
-- elementsSel<i>_<j># for every arity i and every alternative j < i.
101 sel_els <- mapM mk_elements
102 [(i,j) | i <- [2..mAX_DPH_SUM], j <- [0..i-1]]
-- The sum type constructors Sum2 .. Sum<mAX_DPH_SUM>, from dph_Repr.
104 sum_tcs <- mapM (externalTyCon dph_Repr)
105 (numbered "Sum" 2 mAX_DPH_SUM)
-- All per-arity tables below are indexed by arity, starting at 2.
107 let selTys = listArray (2, mAX_DPH_SUM) sel_tys
108 selReplicates = listArray (2, mAX_DPH_SUM) sel_replicates
109 selPicks = listArray (2, mAX_DPH_SUM) sel_picks
110 selTagss = listArray (2, mAX_DPH_SUM) sel_tags
-- NOTE: selEls spans the full square (2,0) .. (mAX_DPH_SUM, mAX_DPH_SUM) but
-- only the pairs with j < i are supplied; any other entry is undefined and
-- raises an error if it is ever indexed.
111 selEls = array ((2,0), (mAX_DPH_SUM, mAX_DPH_SUM)) sel_els
112 sumTyCons = listArray (2, mAX_DPH_SUM) sum_tcs
-- Individual functions from dph_Repr, dph_Closure and dph_PArray.
115 voidVar <- externalVar dph_Repr (fsLit "void")
116 pvoidVar <- externalVar dph_Repr (fsLit "pvoid")
117 fromVoidVar <- externalVar dph_Repr (fsLit "fromVoid")
118 punitVar <- externalVar dph_Repr (fsLit "punit")
119 closureVar <- externalVar dph_Closure (fsLit "closure")
120 applyVar <- externalVar dph_Closure (fsLit "$:")
121 liftedClosureVar <- externalVar dph_Closure (fsLit "liftedClosure")
122 liftedApplyVar <- externalVar dph_Closure (fsLit "liftedApply")
123 replicatePDVar <- externalVar dph_PArray (fsLit "replicatePD")
124 emptyPDVar <- externalVar dph_PArray (fsLit "emptyPD")
125 packByTagPDVar <- externalVar dph_PArray (fsLit "packByTagPD")
-- combine2PD .. combine<mAX_DPH_COMBINE>PD, indexed by arity starting at 2.
127 combines <- mapM (externalVar dph_PArray)
128 [mkFastString ("combine" ++ show i ++ "PD")
129 | i <- [2..mAX_DPH_COMBINE]]
130 let combinePDVars = listArray (2, mAX_DPH_COMBINE) combines
132 scalarClass <- externalClass dph_PArray (fsLit "Scalar")
-- scalarZips is indexed by argument count: 1 -> scalar_map,
-- 2 -> scalar_zipWith, 3 .. mAX_DPH_SCALAR_ARGS -> scalar_zipWith<k>.
133 scalar_map <- externalVar dph_Scalar (fsLit "scalar_map")
134 scalar_zip2 <- externalVar dph_Scalar (fsLit "scalar_zipWith")
135 scalar_zips <- mapM (externalVar dph_Scalar)
136 (numbered "scalar_zipWith" 3 mAX_DPH_SCALAR_ARGS)
137 let scalarZips = listArray (1, mAX_DPH_SCALAR_ARGS)
138 (scalar_map : scalar_zip2 : scalar_zips)
-- Closure constructor functions closure1 .. closure<mAX_DPH_SCALAR_ARGS>,
-- indexed by arity.  The array bounds must use the same limit as the list
-- built by `numbered` here.  The previous upper bound, mAX_DPH_COMBINE, made
-- listArray drop the higher-arity closures when it was smaller, or leave
-- undefined slots when it was larger.  This mirrors scalarZips, which is
-- indexed over (1, mAX_DPH_SCALAR_ARGS).
closures <- mapM (externalVar dph_Closure)
                 (numbered "closure" 1 mAX_DPH_SCALAR_ARGS)
let closureCtrFuns = listArray (1, mAX_DPH_SCALAR_ARGS) closures
-- liftingContext: a system-local Id named "lc" of type intPrimTy, built from
-- the unique u.
143 liftingContext <- liftM (\u -> mkSysLocal (fsLit "lc") u intPrimTy)
-- Assemble the Builtins record from the values looked up above.
148 , parrayTyCon = parrayTyCon
149 , parrayDataCon = parrayDataCon
150 , pdataTyCon = pdataTyCon
152 , paDataCon = paDataCon
153 , preprTyCon = preprTyCon
155 , prDataCon = prDataCon
156 , voidTyCon = voidTyCon
157 , wrapTyCon = wrapTyCon
159 , selReplicates = selReplicates
160 , selPicks = selPicks
161 , selTagss = selTagss
163 , sumTyCons = sumTyCons
164 , closureTyCon = closureTyCon
166 , pvoidVar = pvoidVar
167 , fromVoidVar = fromVoidVar
168 , punitVar = punitVar
169 , closureVar = closureVar
170 , applyVar = applyVar
171 , liftedClosureVar = liftedClosureVar
172 , liftedApplyVar = liftedApplyVar
173 , replicatePDVar = replicatePDVar
174 , emptyPDVar = emptyPDVar
175 , packByTagPDVar = packByTagPDVar
176 , combinePDVars = combinePDVars
177 , scalarClass = scalarClass
178 , scalarZips = scalarZips
179 , closureCtrFuns = closureCtrFuns
180 , liftingContext = liftingContext
-- Bind the DPH module record fields used by the lookups above to local
-- names of the same name.
184 dph_PArray = dph_PArray
185 , dph_Repr = dph_Repr
186 , dph_Closure = dph_Closure
187 , dph_Scalar = dph_Scalar
188 , dph_Unboxed = dph_Unboxed
-- Load one DPH module: get_mod selects it from `modules`, and `doc` is the
-- description handed to dsLoadModule.
192 load get_mod = dsLoadModule doc mod
194 mod = get_mod modules
195 doc = ppr mod <+> ptext (sLit "is a DPH module")
-- | @numbered pfx m n@ is the list of names @pfx ++ show i@ for each @i@ in
-- @[m .. n]@; e.g. @numbered "foo" 3 5@ yields foo3, foo4, foo5.
numbered :: String -> Int -> Int -> [FastString]
numbered pfx m n = map (\i -> mkFastString (pfx ++ show i)) [m .. n]

-- | Like 'numbered', but every name also carries a trailing @#@;
-- e.g. @numbered_hash "pickSel" 2 3@ yields pickSel2#, pickSel3#.
numbered_hash :: String -> Int -> Int -> [FastString]
numbered_hash pfx m n = map (\i -> mkFastString (pfx ++ show i ++ "#")) [m .. n]
-- Look up the selector-element function elementsSel<i>_<j># in dph_Unboxed
-- and return it as a Core variable expression, keyed by (i, j).
204 mk_elements :: (Int, Int) -> DsM ((Int, Int), CoreExpr)
207 v <- externalVar dph_Unboxed
208 $ mkFastString ("elementsSel" ++ show i ++ "_" ++ show j ++ "#")
209 return ((i,j), Var v)
212 -- | Get the mapping of names in the Prelude to names in the DPH library.
-- Each pair is (original Var, Var to use in vectorised code):
--   * the workers of defaultDataConWorkers map to themselves;
--   * the workers of the tuple data constructors from preludeDataCons map to
--     the DPH functions paired with them there.
213 initBuiltinVars :: Builtins -> DsM [(Var, Var)]
214 initBuiltinVars (Builtins { dphModules = mods })
216 uvars <- zipWithM externalVar umods ufs
217 vvars <- zipWithM externalVar vmods vfs
218 cvars <- zipWithM externalVar cmods cfs
-- NOTE(review): uvars/vvars (from preludeVars) are looked up above but are
-- not paired in the return expression shown here; presumably a
-- `++ zip uvars vvars` term belongs at its end -- confirm.
219 return $ [(v,v) | v <- map dataConWorkId defaultDataConWorkers]
220 ++ zip (map dataConWorkId cons) cvars
-- Split preludeVars / preludeDataCons into parallel lists for zipWithM.
223 (umods, ufs, vmods, vfs) = unzip4 (preludeVars mods)
224 (cons, cmods, cfs) = unzip3 (preludeDataCons mods)
-- | Data constructors whose workers 'initBuiltinVars' maps to themselves.
-- This is exactly the set listed in 'defaultDataCons', so share that
-- definition instead of maintaining a second copy of the same list.
defaultDataConWorkers :: [DataCon]
defaultDataConWorkers = defaultDataCons
-- Boxed 2- and 3-tuple data constructors, each paired with its replacement
-- function (tup2 / tup3) in the dph_Prelude_Tuple module.
229 preludeDataCons :: Modules -> [(DataCon, Module, FastString)]
230 preludeDataCons (Modules { dph_Prelude_Tuple = dph_Prelude_Tuple })
231 = [mk_tup n dph_Prelude_Tuple (mkFastString $ "tup" ++ show n) | n <- [2..3]]
233 mk_tup n mod name = (tupleCon Boxed n, mod, name)
236 -- | Mapping of prelude functions to vectorised versions.
237 -- Functions like filterP currently have a working but naive version in GHC.PArr.
238 -- During vectorisation we replace these by calls to their PA counterparts (e.g. filterPA),
239 -- defined in dph-common Data.Array.Parallel.Lifted.Combinators.
241 -- As the renamer only sees the GHC.PArr functions, if you want to add a new function
242 -- to the vectoriser there has to be a definition for it in GHC.PArr, even though
243 -- it will never be used at runtime.
--
-- Each entry is (original module, original name, replacement module,
-- replacement name), built with mk, or with mk' when both live in the same
-- module.
245 preludeVars :: Modules -> [(Module, FastString, Module, FastString)]
246 preludeVars (Modules { dph_Combinators = dph_Combinators
247 , dph_PArray = dph_PArray
248 , dph_Prelude_Int = dph_Prelude_Int
249 , dph_Prelude_Word8 = dph_Prelude_Word8
250 , dph_Prelude_Double = dph_Prelude_Double
251 , dph_Prelude_Bool = dph_Prelude_Bool
252 , dph_Prelude_PArr = dph_Prelude_PArr
255 -- Functions that work on whole PArrays, defined in GHC.PArr
256 = [ mk gHC_PARR (fsLit "mapP") dph_Combinators (fsLit "mapPA")
257 , mk gHC_PARR (fsLit "zipWithP") dph_Combinators (fsLit "zipWithPA")
258 , mk gHC_PARR (fsLit "zipP") dph_Combinators (fsLit "zipPA")
259 , mk gHC_PARR (fsLit "unzipP") dph_Combinators (fsLit "unzipPA")
260 , mk gHC_PARR (fsLit "filterP") dph_Combinators (fsLit "filterPA")
261 , mk gHC_PARR (fsLit "lengthP") dph_Combinators (fsLit "lengthPA")
262 , mk gHC_PARR (fsLit "replicateP") dph_Combinators (fsLit "replicatePA")
263 , mk gHC_PARR (fsLit "!:") dph_Combinators (fsLit "indexPA")
264 , mk gHC_PARR (fsLit "sliceP") dph_Combinators (fsLit "slicePA")
265 , mk gHC_PARR (fsLit "crossMapP") dph_Combinators (fsLit "crossMapPA")
266 , mk gHC_PARR (fsLit "singletonP") dph_Combinators (fsLit "singletonPA")
267 , mk gHC_PARR (fsLit "concatP") dph_Combinators (fsLit "concatPA")
268 , mk gHC_PARR (fsLit "+:+") dph_Combinators (fsLit "appPA")
269 , mk gHC_PARR (fsLit "emptyP") dph_PArray (fsLit "emptyPA")
271 -- Map scalar functions to versions using closures.
272 , mk' dph_Prelude_Int "div" "divV"
273 , mk' dph_Prelude_Int "mod" "modV"
274 , mk' dph_Prelude_Int "sqrt" "sqrtV"
275 , mk' dph_Prelude_Int "enumFromToP" "enumFromToPA"
276 -- , mk' dph_Prelude_Int "upToP" "upToPA"
278 ++ vars_Ord dph_Prelude_Int
279 ++ vars_Num dph_Prelude_Int
281 ++ vars_Ord dph_Prelude_Word8
282 ++ vars_Num dph_Prelude_Word8
-- Word8-specific functions.
284 [ mk' dph_Prelude_Word8 "div" "divV"
285 , mk' dph_Prelude_Word8 "mod" "modV"
286 , mk' dph_Prelude_Word8 "fromInt" "fromIntV"
287 , mk' dph_Prelude_Word8 "toInt" "toIntV"
-- Double: the Ord, Num, Fractional, Floating and RealFrac groups.
290 ++ vars_Ord dph_Prelude_Double
291 ++ vars_Num dph_Prelude_Double
292 ++ vars_Fractional dph_Prelude_Double
293 ++ vars_Floating dph_Prelude_Double
294 ++ vars_RealFrac dph_Prelude_Double
-- Bool functions; not, && and || are taken from GHC.Classes.
296 [ mk dph_Prelude_Bool (fsLit "andP") dph_Prelude_Bool (fsLit "andPA")
297 , mk dph_Prelude_Bool (fsLit "orP") dph_Prelude_Bool (fsLit "orPA")
299 , mk gHC_CLASSES (fsLit "not") dph_Prelude_Bool (fsLit "notV")
300 , mk gHC_CLASSES (fsLit "&&") dph_Prelude_Bool (fsLit "andV")
301 , mk gHC_CLASSES (fsLit "||") dph_Prelude_Bool (fsLit "orV")
-- PArr conversions and combinators.
304 , mk dph_Prelude_PArr (fsLit "fromPArrayP") dph_Prelude_PArr (fsLit "fromPArrayPA")
305 , mk dph_Prelude_PArr (fsLit "toPArrayP") dph_Prelude_PArr (fsLit "toPArrayPA")
306 , mk dph_Prelude_PArr (fsLit "fromNestedPArrayP") dph_Prelude_PArr (fsLit "fromNestedPArrayPA")
307 , mk dph_Prelude_PArr (fsLit "combineP") dph_Combinators (fsLit "combine2PA")
308 , mk dph_Prelude_PArr (fsLit "updateP") dph_Combinators (fsLit "updatePA")
309 , mk dph_Prelude_PArr (fsLit "bpermuteP") dph_Combinators (fsLit "bpermutePA")
310 , mk dph_Prelude_PArr (fsLit "indexedP") dph_Combinators (fsLit "indexedPA")
-- mk' pairs a function with its replacement from the same module.
314 mk' mod v v' = mk mod (fsLit v) mod (fsLit v')
-- Comparison functions (presumably the vars_Ord list).
317 = [ mk' mod "==" "eqV"
318 , mk' mod "/=" "neqV"
323 , mk' mod "min" "minV"
324 , mk' mod "max" "maxV"
325 , mk' mod "minimumP" "minimumPA"
326 , mk' mod "maximumP" "maximumPA"
327 , mk' mod "minIndexP" "minIndexPA"
328 , mk' mod "maxIndexP" "maxIndexPA"
-- Arithmetic functions (presumably the vars_Num list).
332 = [ mk' mod "+" "plusV"
333 , mk' mod "-" "minusV"
334 , mk' mod "*" "multV"
335 , mk' mod "negate" "negateV"
336 , mk' mod "abs" "absV"
337 , mk' mod "sumP" "sumPA"
338 , mk' mod "productP" "productPA"
-- Division and reciprocal (presumably the vars_Fractional list).
342 = [ mk' mod "/" "divideV"
343 , mk' mod "recip" "recipV"
-- Floating functions (presumably the vars_Floating list); pi is mapped to a
-- function also named pi rather than to a ...V variant.
347 = [ mk' mod "pi" "pi"
348 , mk' mod "exp" "expV"
349 , mk' mod "sqrt" "sqrtV"
350 , mk' mod "log" "logV"
351 , mk' mod "sin" "sinV"
352 , mk' mod "tan" "tanV"
353 , mk' mod "cos" "cosV"
354 , mk' mod "asin" "asinV"
355 , mk' mod "atan" "atanV"
356 , mk' mod "acos" "acosV"
357 , mk' mod "sinh" "sinhV"
358 , mk' mod "tanh" "tanhV"
359 , mk' mod "cosh" "coshV"
360 , mk' mod "asinh" "asinhV"
361 , mk' mod "atanh" "atanhV"
362 , mk' mod "acosh" "acoshV"
363 , mk' mod "**" "powV"
364 , mk' mod "logBase" "logBaseV"
-- Conversion and rounding functions (presumably the vars_RealFrac list).
368 = [ mk' mod "fromInt" "fromIntV"
369 , mk' mod "truncate" "truncateV"
370 , mk' mod "round" "roundV"
371 , mk' mod "ceiling" "ceilingV"
372 , mk' mod "floor" "floorV"
376 -- | Get a list of names to `TyCon`s in the mock prelude.
-- The function arrow maps to the closure TyCon; both the TyCon named by
-- parrTyConName (presumably the built-in [: :] type) and PArray itself map
-- to PArray; each TyCon from defaultTyCons maps to itself.
377 initBuiltinTyCons :: Builtins -> DsM [(Name, TyCon)]
380 -- parr <- externalTyCon dph_Prelude_PArr (fsLit "PArr")
381 dft_tcs <- defaultTyCons
382 return $ (tyConName funTyCon, closureTyCon bi)
383 : (parrTyConName, parrayTyCon bi)
386 : (tyConName $ parrayTyCon bi, parrayTyCon bi)
388 : [(tyConName tc, tc) | tc <- dft_tcs]
-- The TyCons that initBuiltinTyCons maps to themselves: Int, Bool, Double
-- and Word8 (the latter looked up by name via word8TyConName).
390 defaultTyCons :: DsM [TyCon]
393 word8 <- dsLookupTyCon word8TyConName
394 return [intTyCon, boolTyCon, doubleTyCon, word8]
-- | Pair each of the 'defaultDataCons' with its 'Name'.  The 'Builtins'
-- argument is not consulted.
initBuiltinDataCons :: Builtins -> [(Name, DataCon)]
initBuiltinDataCons _ = map withName defaultDataCons
  where
    withName dc = (dataConName dc, dc)

-- | The data constructors registered as builtins: 'True', 'False' and @()@.
defaultDataCons :: [DataCon]
defaultDataCons = [trueDataCon, falseDataCon, unitDataCon]
-- | Map the type constructor of each instance of the DPH @PA@ class to that
-- instance's dictionary function (see 'initBuiltinDicts').
initBuiltinPAs :: Builtins -> (InstEnv, InstEnv) -> DsM [(Name, Var)]
initBuiltinPAs (Builtins { dphModules = mods }) insts
  = do paClass <- externalClass (dph_PArray mods) (fsLit "PA")
       return (initBuiltinDicts insts paClass)


-- | Map the type constructor of each instance of the DPH @PR@ class to that
-- instance's dictionary function (see 'initBuiltinDicts').
initBuiltinPRs :: Builtins -> (InstEnv, InstEnv) -> DsM [(Name, Var)]
initBuiltinPRs (Builtins { dphModules = mods }) insts
  = do prClass <- externalClass (dph_PArray mods) (fsLit "PR")
       return (initBuiltinDicts insts prClass)
417 -- | Get the names of all DPH instance functions for this class.
-- Every instance of the class found in insts must have exactly one
-- rough-match TyCon; the result pairs that TyCon's Name with the instance's
-- dictionary function.  Any other instance shape triggers pprPanic.
418 initBuiltinDicts :: (InstEnv, InstEnv) -> Class -> [(Name, Var)]
419 initBuiltinDicts insts cls = map find $ classInstances insts cls
421 find i | [Just tc] <- instanceRoughTcs i = (tc, instanceDFunId i)
422 | otherwise = pprPanic "Invalid DPH instance" (ppr i)
-- | The boxed `TyCon`s of the mock prelude, each keyed by the `Name` of the
-- corresponding primitive `TyCon` (currently just Int, keyed by Int#).
initBuiltinBoxedTyCons :: Builtins -> DsM [(Name, TyCon)]
initBuiltinBoxedTyCons bi = return (builtinBoxedTyCons bi)
-- Maps the Name of the primitive Int# TyCon to the boxed Int TyCon.
429 builtinBoxedTyCons :: Builtins -> [(Name, TyCon)]
431 = [(tyConName intPrimTyCon, intTyCon)]
-- | Look up, in its defining DPH module, every scalar function listed by
-- 'preludeScalars'.
initBuiltinScalars :: Builtins -> DsM [Var]
initBuiltinScalars bi = mapM lookupScalar scalars
  where
    scalars      = preludeScalars (dphModules bi)
    lookupScalar = uncurry externalVar
-- (module, name) of each scalar function of the mock prelude.  The per-type
-- groups mirror the closure-based entries of preludeVars (e.g. div, mod and
-- sqrt for Int).
440 preludeScalars :: Modules -> [(Module, FastString)]
441 preludeScalars (Modules { dph_Prelude_Int = dph_Prelude_Int
442 , dph_Prelude_Word8 = dph_Prelude_Word8
443 , dph_Prelude_Double = dph_Prelude_Double
445 = [ mk dph_Prelude_Int "div"
446 , mk dph_Prelude_Int "mod"
447 , mk dph_Prelude_Int "sqrt"
449 ++ scalars_Ord dph_Prelude_Int
450 ++ scalars_Num dph_Prelude_Int
452 ++ scalars_Ord dph_Prelude_Word8
453 ++ scalars_Num dph_Prelude_Word8
455 [ mk dph_Prelude_Word8 "div"
456 , mk dph_Prelude_Word8 "mod"
457 , mk dph_Prelude_Word8 "fromInt"
458 , mk dph_Prelude_Word8 "toInt"
461 ++ scalars_Ord dph_Prelude_Double
462 ++ scalars_Num dph_Prelude_Double
463 ++ scalars_Fractional dph_Prelude_Double
464 ++ scalars_Floating dph_Prelude_Double
465 ++ scalars_RealFrac dph_Prelude_Double
-- mk pairs a module with a function name in it.
467 mk mod s = (mod, fsLit s)
488 scalars_Fractional mod
523 -- | Lookup some variable given its name and the module that contains it.
-- The name is interpreted as a variable occurrence name (mkVarOccFS).
524 externalVar :: Module -> FastString -> DsM Var
526 = dsLookupGlobalId =<< lookupOrig mod (mkVarOccFS fs)
529 -- | Like `externalVar` but wrap the `Var` in a `CoreExpr`
530 externalFun :: Module -> FastString -> DsM CoreExpr
532 = do var <- externalVar mod fs
536 -- | Lookup some `TyCon` given its name and the module that contains it.
-- The name is interpreted as a type-constructor occurrence name (mkTcOccFS).
537 externalTyCon :: Module -> FastString -> DsM TyCon
539 = dsLookupTyCon =<< lookupOrig mod (mkTcOccFS fs)
542 -- | Lookup some `Type` given its name and the module that contains it.
-- The result is the named TyCon applied to no type arguments.
543 externalType :: Module -> FastString -> DsM Type
545 = do tycon <- externalTyCon mod fs
546 return $ mkTyConApp tycon []
549 -- | Lookup some `Class` given its name and the module that contains it.
-- The name is interpreted as a class occurrence name (mkClsOccFS).
550 externalClass :: Module -> FastString -> DsM Class
552 = dsLookupClass =<< lookupOrig mod (mkClsOccFS fs)
-- | Like `externalClass`, but return the `TyCon` of the class rather than the
-- `Class` itself.
externalClassTyCon :: Module -> FastString -> DsM TyCon
externalClassTyCon mod fs
  = do cls <- externalClass mod fs
       return (classTyCon cls)
560 -- | Lookup a method function given its name and instance type.
-- Only TyCons listed in prim_ty_cons have such methods.  The function is
-- looked up in dph_Unboxed under the method name followed by the TyCon's
-- suffix (for Int#, presumably "_Int#").  Any other TyCon yields Nothing.
561 primMethod :: TyCon -> String -> Builtins -> DsM (Maybe Var)
562 primMethod tycon method (Builtins { dphModules = mods })
563 | Just suffix <- lookupNameEnv prim_ty_cons (tyConName tycon)
565 $ dsLookupGlobalId =<< lookupOrig (dph_Unboxed mods)
566 (mkVarOcc $ method ++ suffix)
568 | otherwise = return Nothing
570 -- | Lookup the representation type we use for PArrays that contain a given element type.
-- Only TyCons listed in prim_ty_cons have one.  It is the TyCon in
-- dph_Unboxed named "PArray" followed by the element TyCon's suffix (for
-- Int#, presumably "PArray_Int#").  Any other TyCon yields Nothing.
571 primPArray :: TyCon -> Builtins -> DsM (Maybe TyCon)
572 primPArray tycon (Builtins { dphModules = mods })
573 | Just suffix <- lookupNameEnv prim_ty_cons (tyConName tycon)
575 $ dsLookupTyCon =<< lookupOrig (dph_Unboxed mods)
576 (mkTcOcc $ "PArray" ++ suffix)
578 | otherwise = return Nothing
-- Primitive TyCons that have specialised DPH versions.  Each maps to the
-- name suffix used for those versions: '_' followed by the TyCon's
-- occurrence name.  Currently only intPrimTyCon is listed.
580 prim_ty_cons :: NameEnv String
581 prim_ty_cons = mkNameEnv [mk_prim intPrimTyCon]
583 mk_prim tycon = (tyConName tycon, '_' : getOccString tycon)