Update vectoriser now that Scalar has moved
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Builtins / Initialise.hs
1
2
3 module Vectorise.Builtins.Initialise (
4         -- * Initialisation
5         initBuiltins, initBuiltinVars, initBuiltinTyCons, initBuiltinDataCons,
6         initBuiltinPAs, initBuiltinPRs,
7         initBuiltinBoxedTyCons, initBuiltinScalars,
8 ) where
9 import Vectorise.Builtins.Base
10 import Vectorise.Builtins.Modules
11 import Vectorise.Builtins.Prelude
12
13 import BasicTypes
14 import PrelNames
15 import TysPrim
16 import DsMonad
17 import IfaceEnv
18 import InstEnv
19 import TysWiredIn
20 import DataCon
21 import TyCon
22 import Class
23 import CoreSyn
24 import Type
25 import Name
26 import Module
27 import Id
28 import FastString
29 import Outputable
30
31 import Control.Monad
32 import Data.Array
33 import Data.List
34
35 -- | Create the initial map of builtin types and functions.
36 initBuiltins 
37         :: PackageId    -- ^ package id the builtins are in, eg dph-common
38         -> DsM Builtins
39
40 initBuiltins pkg
41  = do mapM_ load dph_Orphans
42
43       -- From dph-common:Data.Array.Parallel.PArray.PData
44       --   PData is a type family that maps an element type onto the type
45       --   we use to hold an array of those elements.
46       pdataTyCon        <- externalTyCon        dph_PArray_PData  (fsLit "PData")
47
48       --   PR is a type class that holds the primitive operators we can 
49       --   apply to array data. Its functions take arrays in terms of PData types.
50       prClass           <- externalClass        dph_PArray_PData  (fsLit "PR")
51       let prTyCon     = classTyCon prClass
52           [prDataCon] = tyConDataCons prTyCon
53
54
55       -- From dph-common:Data.Array.Parallel.PArray.Scalar
56       --   Scalar is the class of scalar values. 
57       --   The dictionary contains functions to coerce U.Arrays of scalars
58       --   to and from the PData representation.
59       scalarClass       <- externalClass        dph_PArray_Scalar (fsLit "Scalar")
60
61
62       -- From dph-common:Data.Array.Parallel.Lifted.PArray
63       --   A PArray (Parallel Array) holds the array length and some array elements
64       --   represented by the PData type family.
65       parrayTyCon       <- externalTyCon        dph_PArray      (fsLit "PArray")
66       let [parrayDataCon] = tyConDataCons parrayTyCon
67
68       paClass           <- externalClass        dph_PArray      (fsLit "PA")
69       let paTyCon     = classTyCon paClass
70           [paDataCon] = tyConDataCons paTyCon
71           paPRSel     = classSCSelId paClass 0
72
73       preprTyCon        <- externalTyCon        dph_PArray      (fsLit "PRepr")
74
75       closureTyCon      <- externalTyCon dph_Closure            (fsLit ":->")
76
77       -- From dph-common:Data.Array.Parallel.Lifted.Repr
78       voidTyCon         <- externalTyCon        dph_Repr        (fsLit "Void")
79       wrapTyCon         <- externalTyCon        dph_Repr        (fsLit "Wrap")
80
81       -- From dph-common:Data.Array.Parallel.Lifted.Unboxed
82       sel_tys           <- mapM (externalType dph_Unboxed)
83                                 (numbered "Sel" 2 mAX_DPH_SUM)
84
85       sel_replicates    <- mapM (externalFun dph_Unboxed)
86                                 (numbered_hash "replicateSel" 2 mAX_DPH_SUM)
87
88       sel_picks         <- mapM (externalFun dph_Unboxed)
89                                 (numbered_hash "pickSel" 2 mAX_DPH_SUM)
90
91       sel_tags          <- mapM (externalFun dph_Unboxed)
92                                 (numbered "tagsSel" 2 mAX_DPH_SUM)
93
94       sel_els           <- mapM mk_elements
95                                 [(i,j) | i <- [2..mAX_DPH_SUM], j <- [0..i-1]]
96
97       sum_tcs           <- mapM (externalTyCon dph_Repr)
98                                 (numbered "Sum" 2 mAX_DPH_SUM)
99
100       let selTys        = listArray (2, mAX_DPH_SUM) sel_tys
101           selReplicates = listArray (2, mAX_DPH_SUM) sel_replicates
102           selPicks      = listArray (2, mAX_DPH_SUM) sel_picks
103           selTagss      = listArray (2, mAX_DPH_SUM) sel_tags
104           selEls        = array     ((2,0), (mAX_DPH_SUM, mAX_DPH_SUM)) sel_els
105           sumTyCons     = listArray (2, mAX_DPH_SUM) sum_tcs
106
107
108       voidVar          <- externalVar dph_Repr          (fsLit "void")
109       pvoidVar         <- externalVar dph_Repr          (fsLit "pvoid")
110       fromVoidVar      <- externalVar dph_Repr          (fsLit "fromVoid")
111       punitVar         <- externalVar dph_Repr          (fsLit "punit")
112       closureVar       <- externalVar dph_Closure       (fsLit "closure")
113       applyVar         <- externalVar dph_Closure       (fsLit "$:")
114       liftedClosureVar <- externalVar dph_Closure       (fsLit "liftedClosure")
115       liftedApplyVar   <- externalVar dph_Closure       (fsLit "liftedApply")
116       replicatePDVar   <- externalVar dph_PArray        (fsLit "replicatePD")
117       emptyPDVar       <- externalVar dph_PArray        (fsLit "emptyPD")
118       packByTagPDVar   <- externalVar dph_PArray        (fsLit "packByTagPD")
119
120       combines          <- mapM (externalVar dph_PArray)
121                                 [mkFastString ("combine" ++ show i ++ "PD")
122                                         | i <- [2..mAX_DPH_COMBINE]]
123       let combinePDVars = listArray (2, mAX_DPH_COMBINE) combines
124
125       scalar_map        <- externalVar  dph_Scalar      (fsLit "scalar_map")
126       scalar_zip2       <- externalVar  dph_Scalar      (fsLit "scalar_zipWith")
127       scalar_zips       <- mapM (externalVar dph_Scalar)
128                                 (numbered "scalar_zipWith" 3 mAX_DPH_SCALAR_ARGS)
129
130       let scalarZips    = listArray (1, mAX_DPH_SCALAR_ARGS)
131                                  (scalar_map : scalar_zip2 : scalar_zips)
132
133       closures          <- mapM (externalVar dph_Closure)
134                                 (numbered "closure" 1 mAX_DPH_SCALAR_ARGS)
135
136       let closureCtrFuns = listArray (1, mAX_DPH_COMBINE) closures
137
138       liftingContext    <- liftM (\u -> mkSysLocal (fsLit "lc") u intPrimTy)
139                                 newUnique
140
141       return   $ Builtins 
142                { dphModules       = mods
143                , parrayTyCon      = parrayTyCon
144                , parrayDataCon    = parrayDataCon
145                , pdataTyCon       = pdataTyCon
146                , paClass          = paClass
147                , paTyCon          = paTyCon
148                , paDataCon        = paDataCon
149                , paPRSel          = paPRSel
150                , preprTyCon       = preprTyCon
151                , prClass          = prClass
152                , prTyCon          = prTyCon
153                , prDataCon        = prDataCon
154                , voidTyCon        = voidTyCon
155                , wrapTyCon        = wrapTyCon
156                , selTys           = selTys
157                , selReplicates    = selReplicates
158                , selPicks         = selPicks
159                , selTagss         = selTagss
160                , selEls           = selEls
161                , sumTyCons        = sumTyCons
162                , closureTyCon     = closureTyCon
163                , voidVar          = voidVar
164                , pvoidVar         = pvoidVar
165                , fromVoidVar      = fromVoidVar
166                , punitVar         = punitVar
167                , closureVar       = closureVar
168                , applyVar         = applyVar
169                , liftedClosureVar = liftedClosureVar
170                , liftedApplyVar   = liftedApplyVar
171                , replicatePDVar   = replicatePDVar
172                , emptyPDVar       = emptyPDVar
173                , packByTagPDVar   = packByTagPDVar
174                , combinePDVars    = combinePDVars
175                , scalarClass      = scalarClass
176                , scalarZips       = scalarZips
177                , closureCtrFuns   = closureCtrFuns
178                , liftingContext   = liftingContext
179                }
180   where
181     -- Extract out all the modules we'll use.
182     -- These are the modules from the DPH base library that contain
183     --  the primitive array types and functions that vectorised code uses.
184     mods@(Modules 
185                 { dph_PArray            = dph_PArray
186                 , dph_PArray_PData      = dph_PArray_PData
187                 , dph_PArray_Scalar     = dph_PArray_Scalar
188                 , dph_Repr              = dph_Repr
189                 , dph_Closure           = dph_Closure
190                 , dph_Scalar            = dph_Scalar
191                 , dph_Unboxed           = dph_Unboxed
192                 })
193       = dph_Modules pkg
194
195     load get_mod = dsLoadModule doc mod
196       where
197         mod = get_mod mods 
198         doc = ppr mod <+> ptext (sLit "is a DPH module")
199
200     -- Make a list of numbered strings in some range, eg foo3, foo4, foo5
201     numbered :: String -> Int -> Int -> [FastString]
202     numbered pfx m n = [mkFastString (pfx ++ show i) | i <- [m..n]]
203
204     numbered_hash :: String -> Int -> Int -> [FastString]
205     numbered_hash pfx m n = [mkFastString (pfx ++ show i ++ "#") | i <- [m..n]]
206
207     mk_elements :: (Int, Int) -> DsM ((Int, Int), CoreExpr)
208     mk_elements (i,j)
209       = do
210           v <- externalVar dph_Unboxed
211              $ mkFastString ("elementsSel" ++ show i ++ "_" ++ show j ++ "#")
212           return ((i,j), Var v)
213
214 -- | Get the mapping of names in the Prelude to names in the DPH library.
215 --
216 initBuiltinVars :: Bool   -- FIXME
217                 -> Builtins -> DsM [(Var, Var)]
218 initBuiltinVars compilingDPH (Builtins { dphModules = mods })
219   = do
220       uvars <- zipWithM externalVar umods ufs
221       vvars <- zipWithM externalVar vmods vfs
222       cvars <- zipWithM externalVar cmods cfs
223       return $ [(v,v) | v <- map dataConWorkId defaultDataConWorkers]
224                ++ zip (map dataConWorkId cons) cvars
225                ++ zip uvars vvars
226   where
227     (umods, ufs, vmods, vfs) = if compilingDPH then ([], [], [], []) else unzip4 (preludeVars mods)
228     (cons, cmods, cfs)       = unzip3 (preludeDataCons mods)
229
230     defaultDataConWorkers :: [DataCon]
231     defaultDataConWorkers = [trueDataCon, falseDataCon, unitDataCon]
232
233
234 preludeDataCons :: Modules -> [(DataCon, Module, FastString)]
235 preludeDataCons (Modules { dph_Prelude_Tuple = dph_Prelude_Tuple })
236   = [mk_tup n dph_Prelude_Tuple (mkFastString $ "tup" ++ show n) | n <- [2..3]]
237   where
238     mk_tup n mod name = (tupleCon Boxed n, mod, name)
239
240
241 -- | Get a list of names to `TyCon`s in the mock prelude.
242 initBuiltinTyCons :: Builtins -> DsM [(Name, TyCon)]
243 initBuiltinTyCons bi
244   = do
245       -- parr <- externalTyCon dph_Prelude_PArr (fsLit "PArr")
246       dft_tcs <- defaultTyCons
247       return $ (tyConName funTyCon, closureTyCon bi)
248              : (parrTyConName,      parrayTyCon bi)
249
250              -- FIXME: temporary
251              : (tyConName $ parrayTyCon bi, parrayTyCon bi)
252
253              : [(tyConName tc, tc) | tc <- dft_tcs]
254
255   where defaultTyCons :: DsM [TyCon]
256         defaultTyCons
257          = do   word8 <- dsLookupTyCon word8TyConName
258                 return [intTyCon, boolTyCon, doubleTyCon, word8]
259
260
261 -- | Get a list of names to `DataCon`s in the mock prelude.
262 initBuiltinDataCons :: Builtins -> [(Name, DataCon)]
263 initBuiltinDataCons _
264   = [(dataConName dc, dc)| dc <- defaultDataCons]
265   where defaultDataCons :: [DataCon]
266         defaultDataCons = [trueDataCon, falseDataCon, unitDataCon]
267
268
269 -- | Get the names of all buildin instance functions for the PA class.
270 initBuiltinPAs :: Builtins -> (InstEnv, InstEnv) -> DsM [(Name, Var)]
271 initBuiltinPAs (Builtins { dphModules = mods }) insts
272   = liftM (initBuiltinDicts insts) (externalClass (dph_PArray mods) (fsLit "PA"))
273
274
275 -- | Get the names of all builtin instance functions for the PR class.
276 initBuiltinPRs :: Builtins -> (InstEnv, InstEnv) -> DsM [(Name, Var)]
277 initBuiltinPRs (Builtins { dphModules = mods }) insts
278   = liftM (initBuiltinDicts insts) (externalClass (dph_PArray_PData mods) (fsLit "PR"))
279
280
281 -- | Get the names of all DPH instance functions for this class.
282 initBuiltinDicts :: (InstEnv, InstEnv) -> Class -> [(Name, Var)]
283 initBuiltinDicts insts cls = map find $ classInstances insts cls
284   where
285     find i | [Just tc] <- instanceRoughTcs i    = (tc, instanceDFunId i)
286            | otherwise                          = pprPanic "Invalid DPH instance" (ppr i)
287
288
289 -- | Get a list of boxed `TyCons` in the mock prelude. This is Int only.
290 initBuiltinBoxedTyCons :: Builtins -> DsM [(Name, TyCon)]
291 initBuiltinBoxedTyCons 
292   = return . builtinBoxedTyCons
293   where builtinBoxedTyCons :: Builtins -> [(Name, TyCon)]
294         builtinBoxedTyCons _ 
295                 = [(tyConName intPrimTyCon, intTyCon)]
296
297 -- | Get a list of all scalar functions in the mock prelude.
298 --
299 initBuiltinScalars :: Bool 
300                    -> Builtins -> DsM [Var]
301 initBuiltinScalars True  _bi = return []
302 initBuiltinScalars False bi  = mapM (uncurry externalVar) (preludeScalars $ dphModules bi)
303
304 -- | Lookup some variable given its name and the module that contains it.
305 externalVar :: Module -> FastString -> DsM Var
306 externalVar mod fs
307   = dsLookupGlobalId =<< lookupOrig mod (mkVarOccFS fs)
308
309
310 -- | Like `externalVar` but wrap the `Var` in a `CoreExpr`
311 externalFun :: Module -> FastString -> DsM CoreExpr
312 externalFun mod fs
313  = do var <- externalVar mod fs
314       return $ Var var
315
316
317 -- | Lookup some `TyCon` given its name and the module that contains it.
318 externalTyCon :: Module -> FastString -> DsM TyCon
319 externalTyCon mod fs
320   = dsLookupTyCon =<< lookupOrig mod (mkTcOccFS fs)
321
322
323 -- | Lookup some `Type` given its name and the module that contains it.
324 externalType :: Module -> FastString -> DsM Type
325 externalType mod fs
326  = do  tycon <- externalTyCon mod fs
327        return $ mkTyConApp tycon []
328
329
330 -- | Lookup some `Class` given its name and the module that contains it.
331 externalClass :: Module -> FastString -> DsM Class
332 externalClass mod fs
333   = dsLookupClass =<< lookupOrig mod (mkClsOccFS fs)
334