Update vectoriser now that PRepr has moved
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Builtins / Initialise.hs
1
2
3 module Vectorise.Builtins.Initialise (
4         -- * Initialisation
5         initBuiltins, initBuiltinVars, initBuiltinTyCons, initBuiltinDataCons,
6         initBuiltinPAs, initBuiltinPRs,
7         initBuiltinBoxedTyCons, initBuiltinScalars,
8 ) where
9 import Vectorise.Builtins.Base
10 import Vectorise.Builtins.Modules
11 import Vectorise.Builtins.Prelude
12
13 import BasicTypes
14 import PrelNames
15 import TysPrim
16 import DsMonad
17 import IfaceEnv
18 import InstEnv
19 import TysWiredIn
20 import DataCon
21 import TyCon
22 import Class
23 import CoreSyn
24 import Type
25 import Name
26 import Module
27 import Id
28 import FastString
29 import Outputable
30
31 import Control.Monad
32 import Data.Array
33 import Data.List
34
35 -- | Create the initial map of builtin types and functions.
36 initBuiltins 
37         :: PackageId    -- ^ package id the builtins are in, eg dph-common
38         -> DsM Builtins
39
40 initBuiltins pkg
41  = do mapM_ load dph_Orphans
42
43       -- From dph-common:Data.Array.Parallel.PArray.PData
44       --     PData is a type family that maps an element type onto the type
45       --     we use to hold an array of those elements.
46       pdataTyCon        <- externalTyCon        dph_PArray_PData  (fsLit "PData")
47
48       --     PR is a type class that holds the primitive operators we can 
49       --     apply to array data. Its functions take arrays in terms of PData types.
50       prClass           <- externalClass        dph_PArray_PData  (fsLit "PR")
51       let prTyCon     = classTyCon prClass
52           [prDataCon] = tyConDataCons prTyCon
53
54
55       -- From dph-common:Data.Array.Parallel.PArray.PRepr
56       preprTyCon        <- externalTyCon        dph_PArray_PRepr  (fsLit "PRepr")
57       paClass           <- externalClass        dph_PArray_PRepr  (fsLit "PA")
58       let paTyCon     = classTyCon paClass
59           [paDataCon] = tyConDataCons paTyCon
60           paPRSel     = classSCSelId paClass 0
61
62       replicatePDVar    <- externalVar          dph_PArray_PRepr  (fsLit "replicatePD")
63       emptyPDVar        <- externalVar          dph_PArray_PRepr  (fsLit "emptyPD")
64       packByTagPDVar    <- externalVar          dph_PArray_PRepr  (fsLit "packByTagPD")
65       combines          <- mapM (externalVar dph_PArray_PRepr)
66                                 [mkFastString ("combine" ++ show i ++ "PD")
67                                         | i <- [2..mAX_DPH_COMBINE]]
68
69       let combinePDVars = listArray (2, mAX_DPH_COMBINE) combines
70
71
72       -- From dph-common:Data.Array.Parallel.PArray.Scalar
73       --     Scalar is the class of scalar values. 
74       --     The dictionary contains functions to coerce U.Arrays of scalars
75       --     to and from the PData representation.
76       scalarClass       <- externalClass        dph_PArray_Scalar (fsLit "Scalar")
77
78
79       -- From dph-common:Data.Array.Parallel.Lifted.PArray
80       --   A PArray (Parallel Array) holds the array length and some array elements
81       --   represented by the PData type family.
82       parrayTyCon       <- externalTyCon        dph_PArray      (fsLit "PArray")
83       let [parrayDataCon] = tyConDataCons parrayTyCon
84
85
86       closureTyCon      <- externalTyCon dph_Closure             (fsLit ":->")
87
88       -- From dph-common:Data.Array.Parallel.Lifted.Repr
89       voidTyCon         <- externalTyCon        dph_Repr        (fsLit "Void")
90       wrapTyCon         <- externalTyCon        dph_Repr        (fsLit "Wrap")
91
92       -- From dph-common:Data.Array.Parallel.Lifted.Unboxed
93       sel_tys           <- mapM (externalType dph_Unboxed)
94                                 (numbered "Sel" 2 mAX_DPH_SUM)
95
96       sel_replicates    <- mapM (externalFun dph_Unboxed)
97                                 (numbered_hash "replicateSel" 2 mAX_DPH_SUM)
98
99       sel_picks         <- mapM (externalFun dph_Unboxed)
100                                 (numbered_hash "pickSel" 2 mAX_DPH_SUM)
101
102       sel_tags          <- mapM (externalFun dph_Unboxed)
103                                 (numbered "tagsSel" 2 mAX_DPH_SUM)
104
105       sel_els           <- mapM mk_elements
106                                 [(i,j) | i <- [2..mAX_DPH_SUM], j <- [0..i-1]]
107
108       sum_tcs           <- mapM (externalTyCon dph_Repr)
109                                 (numbered "Sum" 2 mAX_DPH_SUM)
110
111       let selTys        = listArray (2, mAX_DPH_SUM) sel_tys
112           selReplicates = listArray (2, mAX_DPH_SUM) sel_replicates
113           selPicks      = listArray (2, mAX_DPH_SUM) sel_picks
114           selTagss      = listArray (2, mAX_DPH_SUM) sel_tags
115           selEls        = array     ((2,0), (mAX_DPH_SUM, mAX_DPH_SUM)) sel_els
116           sumTyCons     = listArray (2, mAX_DPH_SUM) sum_tcs
117
118
119       voidVar          <- externalVar dph_Repr          (fsLit "void")
120       pvoidVar         <- externalVar dph_Repr          (fsLit "pvoid")
121       fromVoidVar      <- externalVar dph_Repr          (fsLit "fromVoid")
122       punitVar         <- externalVar dph_Repr          (fsLit "punit")
123       closureVar       <- externalVar dph_Closure       (fsLit "closure")
124       applyVar         <- externalVar dph_Closure       (fsLit "$:")
125       liftedClosureVar <- externalVar dph_Closure       (fsLit "liftedClosure")
126       liftedApplyVar   <- externalVar dph_Closure       (fsLit "liftedApply")
127
128       scalar_map        <- externalVar  dph_Scalar      (fsLit "scalar_map")
129       scalar_zip2       <- externalVar  dph_Scalar      (fsLit "scalar_zipWith")
130       scalar_zips       <- mapM (externalVar dph_Scalar)
131                                 (numbered "scalar_zipWith" 3 mAX_DPH_SCALAR_ARGS)
132
133       let scalarZips    = listArray (1, mAX_DPH_SCALAR_ARGS)
134                                  (scalar_map : scalar_zip2 : scalar_zips)
135
136       closures          <- mapM (externalVar dph_Closure)
137                                 (numbered "closure" 1 mAX_DPH_SCALAR_ARGS)
138
139       let closureCtrFuns = listArray (1, mAX_DPH_COMBINE) closures
140
141       liftingContext    <- liftM (\u -> mkSysLocal (fsLit "lc") u intPrimTy)
142                                 newUnique
143
144       return   $ Builtins 
145                { dphModules       = mods
146                , parrayTyCon      = parrayTyCon
147                , parrayDataCon    = parrayDataCon
148                , pdataTyCon       = pdataTyCon
149                , paClass          = paClass
150                , paTyCon          = paTyCon
151                , paDataCon        = paDataCon
152                , paPRSel          = paPRSel
153                , preprTyCon       = preprTyCon
154                , prClass          = prClass
155                , prTyCon          = prTyCon
156                , prDataCon        = prDataCon
157                , voidTyCon        = voidTyCon
158                , wrapTyCon        = wrapTyCon
159                , selTys           = selTys
160                , selReplicates    = selReplicates
161                , selPicks         = selPicks
162                , selTagss         = selTagss
163                , selEls           = selEls
164                , sumTyCons        = sumTyCons
165                , closureTyCon     = closureTyCon
166                , voidVar          = voidVar
167                , pvoidVar         = pvoidVar
168                , fromVoidVar      = fromVoidVar
169                , punitVar         = punitVar
170                , closureVar       = closureVar
171                , applyVar         = applyVar
172                , liftedClosureVar = liftedClosureVar
173                , liftedApplyVar   = liftedApplyVar
174                , replicatePDVar   = replicatePDVar
175                , emptyPDVar       = emptyPDVar
176                , packByTagPDVar   = packByTagPDVar
177                , combinePDVars    = combinePDVars
178                , scalarClass      = scalarClass
179                , scalarZips       = scalarZips
180                , closureCtrFuns   = closureCtrFuns
181                , liftingContext   = liftingContext
182                }
183   where
184     -- Extract out all the modules we'll use.
185     -- These are the modules from the DPH base library that contain
186     --  the primitive array types and functions that vectorised code uses.
187     mods@(Modules 
188                 { dph_PArray            = dph_PArray
189                 , dph_PArray_Scalar     = dph_PArray_Scalar
190                 , dph_PArray_PRepr      = dph_PArray_PRepr
191                 , dph_PArray_PData      = dph_PArray_PData
192                 , dph_Repr              = dph_Repr
193                 , dph_Closure           = dph_Closure
194                 , dph_Scalar            = dph_Scalar
195                 , dph_Unboxed           = dph_Unboxed
196                 })
197       = dph_Modules pkg
198
199     load get_mod = dsLoadModule doc mod
200       where
201         mod = get_mod mods 
202         doc = ppr mod <+> ptext (sLit "is a DPH module")
203
204     -- Make a list of numbered strings in some range, eg foo3, foo4, foo5
205     numbered :: String -> Int -> Int -> [FastString]
206     numbered pfx m n = [mkFastString (pfx ++ show i) | i <- [m..n]]
207
208     numbered_hash :: String -> Int -> Int -> [FastString]
209     numbered_hash pfx m n = [mkFastString (pfx ++ show i ++ "#") | i <- [m..n]]
210
211     mk_elements :: (Int, Int) -> DsM ((Int, Int), CoreExpr)
212     mk_elements (i,j)
213       = do
214           v <- externalVar dph_Unboxed
215              $ mkFastString ("elementsSel" ++ show i ++ "_" ++ show j ++ "#")
216           return ((i,j), Var v)
217
218 -- | Get the mapping of names in the Prelude to names in the DPH library.
219 --
220 initBuiltinVars :: Bool   -- FIXME
221                 -> Builtins -> DsM [(Var, Var)]
222 initBuiltinVars compilingDPH (Builtins { dphModules = mods })
223   = do
224       uvars <- zipWithM externalVar umods ufs
225       vvars <- zipWithM externalVar vmods vfs
226       cvars <- zipWithM externalVar cmods cfs
227       return $ [(v,v) | v <- map dataConWorkId defaultDataConWorkers]
228                ++ zip (map dataConWorkId cons) cvars
229                ++ zip uvars vvars
230   where
231     (umods, ufs, vmods, vfs) = if compilingDPH then ([], [], [], []) else unzip4 (preludeVars mods)
232     (cons, cmods, cfs)       = unzip3 (preludeDataCons mods)
233
234     defaultDataConWorkers :: [DataCon]
235     defaultDataConWorkers = [trueDataCon, falseDataCon, unitDataCon]
236
237
238 preludeDataCons :: Modules -> [(DataCon, Module, FastString)]
239 preludeDataCons (Modules { dph_Prelude_Tuple = dph_Prelude_Tuple })
240   = [mk_tup n dph_Prelude_Tuple (mkFastString $ "tup" ++ show n) | n <- [2..3]]
241   where
242     mk_tup n mod name = (tupleCon Boxed n, mod, name)
243
244
245 -- | Get a list of names to `TyCon`s in the mock prelude.
246 initBuiltinTyCons :: Builtins -> DsM [(Name, TyCon)]
247 initBuiltinTyCons bi
248   = do
249       -- parr <- externalTyCon dph_Prelude_PArr (fsLit "PArr")
250       dft_tcs <- defaultTyCons
251       return $ (tyConName funTyCon, closureTyCon bi)
252              : (parrTyConName,      parrayTyCon bi)
253
254              -- FIXME: temporary
255              : (tyConName $ parrayTyCon bi, parrayTyCon bi)
256
257              : [(tyConName tc, tc) | tc <- dft_tcs]
258
259   where defaultTyCons :: DsM [TyCon]
260         defaultTyCons
261          = do   word8 <- dsLookupTyCon word8TyConName
262                 return [intTyCon, boolTyCon, doubleTyCon, word8]
263
264
265 -- | Get a list of names to `DataCon`s in the mock prelude.
266 initBuiltinDataCons :: Builtins -> [(Name, DataCon)]
267 initBuiltinDataCons _
268   = [(dataConName dc, dc)| dc <- defaultDataCons]
269   where defaultDataCons :: [DataCon]
270         defaultDataCons = [trueDataCon, falseDataCon, unitDataCon]
271
272
273 -- | Get the names of all buildin instance functions for the PA class.
274 initBuiltinPAs :: Builtins -> (InstEnv, InstEnv) -> DsM [(Name, Var)]
275 initBuiltinPAs (Builtins { dphModules = mods }) insts
276   = liftM (initBuiltinDicts insts) (externalClass (dph_PArray_PRepr mods) (fsLit "PA"))
277
278
279 -- | Get the names of all builtin instance functions for the PR class.
280 initBuiltinPRs :: Builtins -> (InstEnv, InstEnv) -> DsM [(Name, Var)]
281 initBuiltinPRs (Builtins { dphModules = mods }) insts
282   = liftM (initBuiltinDicts insts) (externalClass (dph_PArray_PData mods) (fsLit "PR"))
283
284
285 -- | Get the names of all DPH instance functions for this class.
286 initBuiltinDicts :: (InstEnv, InstEnv) -> Class -> [(Name, Var)]
287 initBuiltinDicts insts cls = map find $ classInstances insts cls
288   where
289     find i | [Just tc] <- instanceRoughTcs i    = (tc, instanceDFunId i)
290            | otherwise                          = pprPanic "Invalid DPH instance" (ppr i)
291
292
293 -- | Get a list of boxed `TyCons` in the mock prelude. This is Int only.
294 initBuiltinBoxedTyCons :: Builtins -> DsM [(Name, TyCon)]
295 initBuiltinBoxedTyCons 
296   = return . builtinBoxedTyCons
297   where builtinBoxedTyCons :: Builtins -> [(Name, TyCon)]
298         builtinBoxedTyCons _ 
299                 = [(tyConName intPrimTyCon, intTyCon)]
300
301 -- | Get a list of all scalar functions in the mock prelude.
302 --
303 initBuiltinScalars :: Bool 
304                    -> Builtins -> DsM [Var]
305 initBuiltinScalars True  _bi = return []
306 initBuiltinScalars False bi  = mapM (uncurry externalVar) (preludeScalars $ dphModules bi)
307
308 -- | Lookup some variable given its name and the module that contains it.
309 externalVar :: Module -> FastString -> DsM Var
310 externalVar mod fs
311   = dsLookupGlobalId =<< lookupOrig mod (mkVarOccFS fs)
312
313
314 -- | Like `externalVar` but wrap the `Var` in a `CoreExpr`
315 externalFun :: Module -> FastString -> DsM CoreExpr
316 externalFun mod fs
317  = do var <- externalVar mod fs
318       return $ Var var
319
320
321 -- | Lookup some `TyCon` given its name and the module that contains it.
322 externalTyCon :: Module -> FastString -> DsM TyCon
323 externalTyCon mod fs
324   = dsLookupTyCon =<< lookupOrig mod (mkTcOccFS fs)
325
326
327 -- | Lookup some `Type` given its name and the module that contains it.
328 externalType :: Module -> FastString -> DsM Type
329 externalType mod fs
330  = do  tycon <- externalTyCon mod fs
331        return $ mkTyConApp tycon []
332
333
334 -- | Lookup some `Class` given its name and the module that contains it.
335 externalClass :: Module -> FastString -> DsM Class
336 externalClass mod fs
337   = dsLookupClass =<< lookupOrig mod (mkClsOccFS fs)
338