Update vectoriser now that PData instances have moved.
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Builtins / Initialise.hs
1
2
3 module Vectorise.Builtins.Initialise (
4         -- * Initialisation
5         initBuiltins, initBuiltinVars, initBuiltinTyCons, initBuiltinDataCons,
6         initBuiltinPAs, initBuiltinPRs,
7         initBuiltinBoxedTyCons, initBuiltinScalars,
8 ) where
9 import Vectorise.Builtins.Base
10 import Vectorise.Builtins.Modules
11 import Vectorise.Builtins.Prelude
12
13 import BasicTypes
14 import PrelNames
15 import TysPrim
16 import DsMonad
17 import IfaceEnv
18 import InstEnv
19 import TysWiredIn
20 import DataCon
21 import TyCon
22 import Class
23 import CoreSyn
24 import Type
25 import Name
26 import Module
27 import Id
28 import FastString
29 import Outputable
30
31 import Control.Monad
32 import Data.Array
33 import Data.List
34
35 -- | Create the initial map of builtin types and functions.
36 initBuiltins 
37         :: PackageId    -- ^ package id the builtins are in, eg dph-common
38         -> DsM Builtins
39
40 initBuiltins pkg
41  = do mapM_ load dph_Orphans
42
43       -- From dph-common:Data.Array.Parallel.PArray.PData
44       --     PData is a type family that maps an element type onto the type
45       --     we use to hold an array of those elements.
46       pdataTyCon        <- externalTyCon        dph_PArray_PData  (fsLit "PData")
47
48       --     PR is a type class that holds the primitive operators we can 
49       --     apply to array data. Its functions take arrays in terms of PData types.
50       prClass           <- externalClass        dph_PArray_PData  (fsLit "PR")
51       let prTyCon     = classTyCon prClass
52           [prDataCon] = tyConDataCons prTyCon
53
54
55       -- From dph-common:Data.Array.Parallel.PArray.PRepr
56       preprTyCon        <- externalTyCon        dph_PArray_PRepr  (fsLit "PRepr")
57       paClass           <- externalClass        dph_PArray_PRepr  (fsLit "PA")
58       let paTyCon     = classTyCon paClass
59           [paDataCon] = tyConDataCons paTyCon
60           paPRSel     = classSCSelId paClass 0
61
62       replicatePDVar    <- externalVar          dph_PArray_PRepr  (fsLit "replicatePD")
63       emptyPDVar        <- externalVar          dph_PArray_PRepr  (fsLit "emptyPD")
64       packByTagPDVar    <- externalVar          dph_PArray_PRepr  (fsLit "packByTagPD")
65       combines          <- mapM (externalVar dph_PArray_PRepr)
66                                 [mkFastString ("combine" ++ show i ++ "PD")
67                                         | i <- [2..mAX_DPH_COMBINE]]
68
69       let combinePDVars = listArray (2, mAX_DPH_COMBINE) combines
70
71
72       -- From dph-common:Data.Array.Parallel.PArray.Scalar
73       --     Scalar is the class of scalar values. 
74       --     The dictionary contains functions to coerce U.Arrays of scalars
75       --     to and from the PData representation.
76       scalarClass       <- externalClass        dph_PArray_Scalar (fsLit "Scalar")
77
78
79       -- From dph-common:Data.Array.Parallel.Lifted.PArray
80       --   A PArray (Parallel Array) holds the array length and some array elements
81       --   represented by the PData type family.
82       parrayTyCon       <- externalTyCon        dph_PArray_Base   (fsLit "PArray")
83       let [parrayDataCon] = tyConDataCons parrayTyCon
84
85       -- From dph-common:Data.Array.Parallel.PArray.Types
86       voidTyCon         <- externalTyCon        dph_PArray_Types  (fsLit "Void")
87       voidVar           <- externalVar          dph_PArray_Types  (fsLit "void")
88       fromVoidVar       <- externalVar          dph_PArray_Types  (fsLit "fromVoid")
89       wrapTyCon         <- externalTyCon        dph_PArray_Types  (fsLit "Wrap")
90       sum_tcs           <- mapM (externalTyCon  dph_PArray_Types) (numbered "Sum" 2 mAX_DPH_SUM)
91
92       -- from dph-common:Data.Array.Parallel.PArray.PDataInstances
93       pvoidVar          <- externalVar dph_PArray_PDataInstances  (fsLit "pvoid")
94       punitVar          <- externalVar dph_PArray_PDataInstances  (fsLit "punit")
95
96
97       closureTyCon      <- externalTyCon dph_Closure             (fsLit ":->")
98
99
100       -- From dph-common:Data.Array.Parallel.Lifted.Unboxed
101       sel_tys           <- mapM (externalType dph_Unboxed)
102                                 (numbered "Sel" 2 mAX_DPH_SUM)
103
104       sel_replicates    <- mapM (externalFun dph_Unboxed)
105                                 (numbered_hash "replicateSel" 2 mAX_DPH_SUM)
106
107       sel_picks         <- mapM (externalFun dph_Unboxed)
108                                 (numbered_hash "pickSel" 2 mAX_DPH_SUM)
109
110       sel_tags          <- mapM (externalFun dph_Unboxed)
111                                 (numbered "tagsSel" 2 mAX_DPH_SUM)
112
113       sel_els           <- mapM mk_elements
114                                 [(i,j) | i <- [2..mAX_DPH_SUM], j <- [0..i-1]]
115
116
117       let selTys        = listArray (2, mAX_DPH_SUM) sel_tys
118           selReplicates = listArray (2, mAX_DPH_SUM) sel_replicates
119           selPicks      = listArray (2, mAX_DPH_SUM) sel_picks
120           selTagss      = listArray (2, mAX_DPH_SUM) sel_tags
121           selEls        = array     ((2,0), (mAX_DPH_SUM, mAX_DPH_SUM)) sel_els
122           sumTyCons     = listArray (2, mAX_DPH_SUM) sum_tcs
123
124
125
126       closureVar       <- externalVar dph_Closure       (fsLit "closure")
127       applyVar         <- externalVar dph_Closure       (fsLit "$:")
128       liftedClosureVar <- externalVar dph_Closure       (fsLit "liftedClosure")
129       liftedApplyVar   <- externalVar dph_Closure       (fsLit "liftedApply")
130
131       scalar_map        <- externalVar  dph_Scalar      (fsLit "scalar_map")
132       scalar_zip2   <- externalVar      dph_Scalar      (fsLit "scalar_zipWith")
133       scalar_zips       <- mapM (externalVar dph_Scalar)
134                                 (numbered "scalar_zipWith" 3 mAX_DPH_SCALAR_ARGS)
135
136       let scalarZips    = listArray (1, mAX_DPH_SCALAR_ARGS)
137                                  (scalar_map : scalar_zip2 : scalar_zips)
138
139       closures          <- mapM (externalVar dph_Closure)
140                                 (numbered "closure" 1 mAX_DPH_SCALAR_ARGS)
141
142       let closureCtrFuns = listArray (1, mAX_DPH_COMBINE) closures
143
144       liftingContext    <- liftM (\u -> mkSysLocal (fsLit "lc") u intPrimTy)
145                                 newUnique
146
147       return   $ Builtins 
148                { dphModules       = mods
149                , parrayTyCon      = parrayTyCon
150                , parrayDataCon    = parrayDataCon
151                , pdataTyCon       = pdataTyCon
152                , paClass          = paClass
153                , paTyCon          = paTyCon
154                , paDataCon        = paDataCon
155                , paPRSel          = paPRSel
156                , preprTyCon       = preprTyCon
157                , prClass          = prClass
158                , prTyCon          = prTyCon
159                , prDataCon        = prDataCon
160                , voidTyCon        = voidTyCon
161                , wrapTyCon        = wrapTyCon
162                , selTys           = selTys
163                , selReplicates    = selReplicates
164                , selPicks         = selPicks
165                , selTagss         = selTagss
166                , selEls           = selEls
167                , sumTyCons        = sumTyCons
168                , closureTyCon     = closureTyCon
169                , voidVar          = voidVar
170                , pvoidVar         = pvoidVar
171                , fromVoidVar      = fromVoidVar
172                , punitVar         = punitVar
173                , closureVar       = closureVar
174                , applyVar         = applyVar
175                , liftedClosureVar = liftedClosureVar
176                , liftedApplyVar   = liftedApplyVar
177                , replicatePDVar   = replicatePDVar
178                , emptyPDVar       = emptyPDVar
179                , packByTagPDVar   = packByTagPDVar
180                , combinePDVars    = combinePDVars
181                , scalarClass      = scalarClass
182                , scalarZips       = scalarZips
183                , closureCtrFuns   = closureCtrFuns
184                , liftingContext   = liftingContext
185                }
186   where
187     -- Extract out all the modules we'll use.
188     -- These are the modules from the DPH base library that contain
189     --  the primitive array types and functions that vectorised code uses.
190     mods@(Modules 
191                 { dph_PArray_Base               = dph_PArray_Base
192                 , dph_PArray_Scalar             = dph_PArray_Scalar
193                 , dph_PArray_PRepr              = dph_PArray_PRepr
194                 , dph_PArray_PData              = dph_PArray_PData
195                 , dph_PArray_PDataInstances     = dph_PArray_PDataInstances
196                 , dph_PArray_Types              = dph_PArray_Types
197                 , dph_Repr                      = dph_Repr
198                 , dph_Closure                   = dph_Closure
199                 , dph_Scalar                    = dph_Scalar
200                 , dph_Unboxed                   = dph_Unboxed
201                 })
202       = dph_Modules pkg
203
204     load get_mod = dsLoadModule doc mod
205       where
206         mod = get_mod mods 
207         doc = ppr mod <+> ptext (sLit "is a DPH module")
208
209     -- Make a list of numbered strings in some range, eg foo3, foo4, foo5
210     numbered :: String -> Int -> Int -> [FastString]
211     numbered pfx m n = [mkFastString (pfx ++ show i) | i <- [m..n]]
212
213     numbered_hash :: String -> Int -> Int -> [FastString]
214     numbered_hash pfx m n = [mkFastString (pfx ++ show i ++ "#") | i <- [m..n]]
215
216     mk_elements :: (Int, Int) -> DsM ((Int, Int), CoreExpr)
217     mk_elements (i,j)
218       = do
219           v <- externalVar dph_Unboxed
220              $ mkFastString ("elementsSel" ++ show i ++ "_" ++ show j ++ "#")
221           return ((i,j), Var v)
222
223 -- | Get the mapping of names in the Prelude to names in the DPH library.
224 --
225 initBuiltinVars :: Bool   -- FIXME
226                 -> Builtins -> DsM [(Var, Var)]
227 initBuiltinVars compilingDPH (Builtins { dphModules = mods })
228   = do
229       uvars <- zipWithM externalVar umods ufs
230       vvars <- zipWithM externalVar vmods vfs
231       cvars <- zipWithM externalVar cmods cfs
232       return $ [(v,v) | v <- map dataConWorkId defaultDataConWorkers]
233                ++ zip (map dataConWorkId cons) cvars
234                ++ zip uvars vvars
235   where
236     (umods, ufs, vmods, vfs) = if compilingDPH then ([], [], [], []) else unzip4 (preludeVars mods)
237     (cons, cmods, cfs)       = unzip3 (preludeDataCons mods)
238
239     defaultDataConWorkers :: [DataCon]
240     defaultDataConWorkers = [trueDataCon, falseDataCon, unitDataCon]
241
242
243 preludeDataCons :: Modules -> [(DataCon, Module, FastString)]
244 preludeDataCons (Modules { dph_Prelude_Tuple = dph_Prelude_Tuple })
245   = [mk_tup n dph_Prelude_Tuple (mkFastString $ "tup" ++ show n) | n <- [2..3]]
246   where
247     mk_tup n mod name = (tupleCon Boxed n, mod, name)
248
249
250 -- | Get a list of names to `TyCon`s in the mock prelude.
251 initBuiltinTyCons :: Builtins -> DsM [(Name, TyCon)]
252 initBuiltinTyCons bi
253   = do
254       -- parr <- externalTyCon dph_Prelude_PArr (fsLit "PArr")
255       dft_tcs <- defaultTyCons
256       return $ (tyConName funTyCon, closureTyCon bi)
257              : (parrTyConName,      parrayTyCon bi)
258
259              -- FIXME: temporary
260              : (tyConName $ parrayTyCon bi, parrayTyCon bi)
261
262              : [(tyConName tc, tc) | tc <- dft_tcs]
263
264   where defaultTyCons :: DsM [TyCon]
265         defaultTyCons
266          = do   word8 <- dsLookupTyCon word8TyConName
267                 return [intTyCon, boolTyCon, doubleTyCon, word8]
268
269
270 -- | Get a list of names to `DataCon`s in the mock prelude.
271 initBuiltinDataCons :: Builtins -> [(Name, DataCon)]
272 initBuiltinDataCons _
273   = [(dataConName dc, dc)| dc <- defaultDataCons]
274   where defaultDataCons :: [DataCon]
275         defaultDataCons = [trueDataCon, falseDataCon, unitDataCon]
276
277
278 -- | Get the names of all buildin instance functions for the PA class.
279 initBuiltinPAs :: Builtins -> (InstEnv, InstEnv) -> DsM [(Name, Var)]
280 initBuiltinPAs (Builtins { dphModules = mods }) insts
281   = liftM (initBuiltinDicts insts) (externalClass (dph_PArray_PRepr mods) (fsLit "PA"))
282
283
284 -- | Get the names of all builtin instance functions for the PR class.
285 initBuiltinPRs :: Builtins -> (InstEnv, InstEnv) -> DsM [(Name, Var)]
286 initBuiltinPRs (Builtins { dphModules = mods }) insts
287   = liftM (initBuiltinDicts insts) (externalClass (dph_PArray_PData mods) (fsLit "PR"))
288
289
290 -- | Get the names of all DPH instance functions for this class.
291 initBuiltinDicts :: (InstEnv, InstEnv) -> Class -> [(Name, Var)]
292 initBuiltinDicts insts cls = map find $ classInstances insts cls
293   where
294     find i | [Just tc] <- instanceRoughTcs i    = (tc, instanceDFunId i)
295            | otherwise                          = pprPanic "Invalid DPH instance" (ppr i)
296
297
298 -- | Get a list of boxed `TyCons` in the mock prelude. This is Int only.
299 initBuiltinBoxedTyCons :: Builtins -> DsM [(Name, TyCon)]
300 initBuiltinBoxedTyCons 
301   = return . builtinBoxedTyCons
302   where builtinBoxedTyCons :: Builtins -> [(Name, TyCon)]
303         builtinBoxedTyCons _ 
304                 = [(tyConName intPrimTyCon, intTyCon)]
305
306 -- | Get a list of all scalar functions in the mock prelude.
307 --
308 initBuiltinScalars :: Bool 
309                    -> Builtins -> DsM [Var]
310 initBuiltinScalars True  _bi = return []
311 initBuiltinScalars False bi  = mapM (uncurry externalVar) (preludeScalars $ dphModules bi)
312
313 -- | Lookup some variable given its name and the module that contains it.
314 externalVar :: Module -> FastString -> DsM Var
315 externalVar mod fs
316   = dsLookupGlobalId =<< lookupOrig mod (mkVarOccFS fs)
317
318
319 -- | Like `externalVar` but wrap the `Var` in a `CoreExpr`
320 externalFun :: Module -> FastString -> DsM CoreExpr
321 externalFun mod fs
322  = do var <- externalVar mod fs
323       return $ Var var
324
325
326 -- | Lookup some `TyCon` given its name and the module that contains it.
327 externalTyCon :: Module -> FastString -> DsM TyCon
328 externalTyCon mod fs
329   = dsLookupTyCon =<< lookupOrig mod (mkTcOccFS fs)
330
331
332 -- | Lookup some `Type` given its name and the module that contains it.
333 externalType :: Module -> FastString -> DsM Type
334 externalType mod fs
335  = do  tycon <- externalTyCon mod fs
336        return $ mkTyConApp tycon []
337
338
339 -- | Lookup some `Class` given its name and the module that contains it.
340 externalClass :: Module -> FastString -> DsM Class
341 externalClass mod fs
342   = dsLookupClass =<< lookupOrig mod (mkClsOccFS fs)
343