5f4735c0db86cd1d5969b2eea1d439e3cfa048fd
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Builtins / Initialise.hs
1
2
3 module Vectorise.Builtins.Initialise (
4         -- * Initialisation
5         initBuiltins, initBuiltinVars, initBuiltinTyCons, initBuiltinDataCons,
6         initBuiltinPAs, initBuiltinPRs,
7         initBuiltinBoxedTyCons, initBuiltinScalars,
8 ) where
9 import Vectorise.Builtins.Base
10 import Vectorise.Builtins.Modules
11 import Vectorise.Builtins.Prelude
12
13 import BasicTypes
14 import PrelNames
15 import TysPrim
16 import DsMonad
17 import IfaceEnv
18 import InstEnv
19 import TysWiredIn
20 import DataCon
21 import TyCon
22 import Class
23 import CoreSyn
24 import Type
25 import Name
26 import Module
27 import Id
28 import FastString
29 import Outputable
30
31 import Control.Monad
32 import Data.Array
33 import Data.List
34
35 -- | Create the initial map of builtin types and functions.
36 initBuiltins 
37         :: PackageId    -- ^ package id the builtins are in, eg dph-common
38         -> DsM Builtins
39
40 initBuiltins pkg
41  = do mapM_ load dph_Orphans
42
43       -- From dph-common:Data.Array.Parallel.PArray.PData
44       --   PData is a type family that maps an element type onto the type
45       --   we use to hold an array of those elements.
46       pdataTyCon        <- externalTyCon        dph_PData       (fsLit "PData")
47
48       --   PR is a type class that holds the primitive operators we can 
49       --   apply to array data. Its functions take arrays in terms of PData types.
50       prClass           <- externalClass        dph_PData      (fsLit "PR")
51       let prTyCon     = classTyCon prClass
52           [prDataCon] = tyConDataCons prTyCon
53
54       -- From dph-common:Data.Array.Parallel.Lifted.PArray
55       --   A PArray (Parallel Array) holds the array length and some array elements
56       --   represented by the PData type family.
57       parrayTyCon       <- externalTyCon        dph_PArray      (fsLit "PArray")
58       let [parrayDataCon] = tyConDataCons parrayTyCon
59
60       paClass           <- externalClass        dph_PArray      (fsLit "PA")
61       let paTyCon     = classTyCon paClass
62           [paDataCon] = tyConDataCons paTyCon
63           paPRSel     = classSCSelId paClass 0
64
65       preprTyCon        <- externalTyCon        dph_PArray      (fsLit "PRepr")
66
67       closureTyCon      <- externalTyCon dph_Closure            (fsLit ":->")
68
69       -- From dph-common:Data.Array.Parallel.Lifted.Repr
70       voidTyCon         <- externalTyCon        dph_Repr        (fsLit "Void")
71       wrapTyCon         <- externalTyCon        dph_Repr        (fsLit "Wrap")
72
73       -- From dph-common:Data.Array.Parallel.Lifted.Unboxed
74       sel_tys           <- mapM (externalType dph_Unboxed)
75                                 (numbered "Sel" 2 mAX_DPH_SUM)
76
77       sel_replicates    <- mapM (externalFun dph_Unboxed)
78                                 (numbered_hash "replicateSel" 2 mAX_DPH_SUM)
79
80       sel_picks         <- mapM (externalFun dph_Unboxed)
81                                 (numbered_hash "pickSel" 2 mAX_DPH_SUM)
82
83       sel_tags          <- mapM (externalFun dph_Unboxed)
84                                 (numbered "tagsSel" 2 mAX_DPH_SUM)
85
86       sel_els           <- mapM mk_elements
87                                 [(i,j) | i <- [2..mAX_DPH_SUM], j <- [0..i-1]]
88
89       sum_tcs           <- mapM (externalTyCon dph_Repr)
90                                 (numbered "Sum" 2 mAX_DPH_SUM)
91
92       let selTys        = listArray (2, mAX_DPH_SUM) sel_tys
93           selReplicates = listArray (2, mAX_DPH_SUM) sel_replicates
94           selPicks      = listArray (2, mAX_DPH_SUM) sel_picks
95           selTagss      = listArray (2, mAX_DPH_SUM) sel_tags
96           selEls        = array     ((2,0), (mAX_DPH_SUM, mAX_DPH_SUM)) sel_els
97           sumTyCons     = listArray (2, mAX_DPH_SUM) sum_tcs
98
99
100       voidVar          <- externalVar dph_Repr          (fsLit "void")
101       pvoidVar         <- externalVar dph_Repr          (fsLit "pvoid")
102       fromVoidVar      <- externalVar dph_Repr          (fsLit "fromVoid")
103       punitVar         <- externalVar dph_Repr          (fsLit "punit")
104       closureVar       <- externalVar dph_Closure       (fsLit "closure")
105       applyVar         <- externalVar dph_Closure       (fsLit "$:")
106       liftedClosureVar <- externalVar dph_Closure       (fsLit "liftedClosure")
107       liftedApplyVar   <- externalVar dph_Closure       (fsLit "liftedApply")
108       replicatePDVar   <- externalVar dph_PArray        (fsLit "replicatePD")
109       emptyPDVar       <- externalVar dph_PArray        (fsLit "emptyPD")
110       packByTagPDVar   <- externalVar dph_PArray        (fsLit "packByTagPD")
111
112       combines          <- mapM (externalVar dph_PArray)
113                                 [mkFastString ("combine" ++ show i ++ "PD")
114                                         | i <- [2..mAX_DPH_COMBINE]]
115       let combinePDVars = listArray (2, mAX_DPH_COMBINE) combines
116
117       scalarClass       <- externalClass dph_PArray     (fsLit "Scalar")
118       scalar_map        <- externalVar  dph_Scalar      (fsLit "scalar_map")
119       scalar_zip2       <- externalVar  dph_Scalar      (fsLit "scalar_zipWith")
120       scalar_zips       <- mapM (externalVar dph_Scalar)
121                                 (numbered "scalar_zipWith" 3 mAX_DPH_SCALAR_ARGS)
122
123       let scalarZips    = listArray (1, mAX_DPH_SCALAR_ARGS)
124                                  (scalar_map : scalar_zip2 : scalar_zips)
125
126       closures          <- mapM (externalVar dph_Closure)
127                                 (numbered "closure" 1 mAX_DPH_SCALAR_ARGS)
128
129       let closureCtrFuns = listArray (1, mAX_DPH_COMBINE) closures
130
131       liftingContext    <- liftM (\u -> mkSysLocal (fsLit "lc") u intPrimTy)
132                                 newUnique
133
134       return   $ Builtins 
135                { dphModules       = mods
136                , parrayTyCon      = parrayTyCon
137                , parrayDataCon    = parrayDataCon
138                , pdataTyCon       = pdataTyCon
139                , paClass          = paClass
140                , paTyCon          = paTyCon
141                , paDataCon        = paDataCon
142                , paPRSel          = paPRSel
143                , preprTyCon       = preprTyCon
144                , prClass          = prClass
145                , prTyCon          = prTyCon
146                , prDataCon        = prDataCon
147                , voidTyCon        = voidTyCon
148                , wrapTyCon        = wrapTyCon
149                , selTys           = selTys
150                , selReplicates    = selReplicates
151                , selPicks         = selPicks
152                , selTagss         = selTagss
153                , selEls           = selEls
154                , sumTyCons        = sumTyCons
155                , closureTyCon     = closureTyCon
156                , voidVar          = voidVar
157                , pvoidVar         = pvoidVar
158                , fromVoidVar      = fromVoidVar
159                , punitVar         = punitVar
160                , closureVar       = closureVar
161                , applyVar         = applyVar
162                , liftedClosureVar = liftedClosureVar
163                , liftedApplyVar   = liftedApplyVar
164                , replicatePDVar   = replicatePDVar
165                , emptyPDVar       = emptyPDVar
166                , packByTagPDVar   = packByTagPDVar
167                , combinePDVars    = combinePDVars
168                , scalarClass      = scalarClass
169                , scalarZips       = scalarZips
170                , closureCtrFuns   = closureCtrFuns
171                , liftingContext   = liftingContext
172                }
173   where
174     -- Extract out all the modules we'll use.
175     -- These are the modules from the DPH base library that contain
176     --  the primitive array types and functions that vectorised code uses.
177     mods@(Modules 
178                 { dph_PArray    = dph_PArray
179                 , dph_PData     = dph_PData
180                 , dph_Repr      = dph_Repr
181                 , dph_Closure   = dph_Closure
182                 , dph_Scalar    = dph_Scalar
183                 , dph_Unboxed   = dph_Unboxed
184                 })
185       = dph_Modules pkg
186
187     load get_mod = dsLoadModule doc mod
188       where
189         mod = get_mod mods 
190         doc = ppr mod <+> ptext (sLit "is a DPH module")
191
192     -- Make a list of numbered strings in some range, eg foo3, foo4, foo5
193     numbered :: String -> Int -> Int -> [FastString]
194     numbered pfx m n = [mkFastString (pfx ++ show i) | i <- [m..n]]
195
196     numbered_hash :: String -> Int -> Int -> [FastString]
197     numbered_hash pfx m n = [mkFastString (pfx ++ show i ++ "#") | i <- [m..n]]
198
199     mk_elements :: (Int, Int) -> DsM ((Int, Int), CoreExpr)
200     mk_elements (i,j)
201       = do
202           v <- externalVar dph_Unboxed
203              $ mkFastString ("elementsSel" ++ show i ++ "_" ++ show j ++ "#")
204           return ((i,j), Var v)
205
206 -- | Get the mapping of names in the Prelude to names in the DPH library.
207 --
208 initBuiltinVars :: Bool   -- FIXME
209                 -> Builtins -> DsM [(Var, Var)]
210 initBuiltinVars compilingDPH (Builtins { dphModules = mods })
211   = do
212       uvars <- zipWithM externalVar umods ufs
213       vvars <- zipWithM externalVar vmods vfs
214       cvars <- zipWithM externalVar cmods cfs
215       return $ [(v,v) | v <- map dataConWorkId defaultDataConWorkers]
216                ++ zip (map dataConWorkId cons) cvars
217                ++ zip uvars vvars
218   where
219     (umods, ufs, vmods, vfs) = if compilingDPH then ([], [], [], []) else unzip4 (preludeVars mods)
220     (cons, cmods, cfs)       = unzip3 (preludeDataCons mods)
221
222     defaultDataConWorkers :: [DataCon]
223     defaultDataConWorkers = [trueDataCon, falseDataCon, unitDataCon]
224
225
226 preludeDataCons :: Modules -> [(DataCon, Module, FastString)]
227 preludeDataCons (Modules { dph_Prelude_Tuple = dph_Prelude_Tuple })
228   = [mk_tup n dph_Prelude_Tuple (mkFastString $ "tup" ++ show n) | n <- [2..3]]
229   where
230     mk_tup n mod name = (tupleCon Boxed n, mod, name)
231
232
233 -- | Get a list of names to `TyCon`s in the mock prelude.
234 initBuiltinTyCons :: Builtins -> DsM [(Name, TyCon)]
235 initBuiltinTyCons bi
236   = do
237       -- parr <- externalTyCon dph_Prelude_PArr (fsLit "PArr")
238       dft_tcs <- defaultTyCons
239       return $ (tyConName funTyCon, closureTyCon bi)
240              : (parrTyConName,      parrayTyCon bi)
241
242              -- FIXME: temporary
243              : (tyConName $ parrayTyCon bi, parrayTyCon bi)
244
245              : [(tyConName tc, tc) | tc <- dft_tcs]
246
247   where defaultTyCons :: DsM [TyCon]
248         defaultTyCons
249          = do   word8 <- dsLookupTyCon word8TyConName
250                 return [intTyCon, boolTyCon, doubleTyCon, word8]
251
252
253 -- | Get a list of names to `DataCon`s in the mock prelude.
254 initBuiltinDataCons :: Builtins -> [(Name, DataCon)]
255 initBuiltinDataCons _
256   = [(dataConName dc, dc)| dc <- defaultDataCons]
257   where defaultDataCons :: [DataCon]
258         defaultDataCons = [trueDataCon, falseDataCon, unitDataCon]
259
260
261 -- | Get the names of all buildin instance functions for the PA class.
262 initBuiltinPAs :: Builtins -> (InstEnv, InstEnv) -> DsM [(Name, Var)]
263 initBuiltinPAs (Builtins { dphModules = mods }) insts
264   = liftM (initBuiltinDicts insts) (externalClass (dph_PArray mods) (fsLit "PA"))
265
266
267 -- | Get the names of all builtin instance functions for the PR class.
268 initBuiltinPRs :: Builtins -> (InstEnv, InstEnv) -> DsM [(Name, Var)]
269 initBuiltinPRs (Builtins { dphModules = mods }) insts
270   = liftM (initBuiltinDicts insts) (externalClass (dph_PData mods) (fsLit "PR"))
271
272
273 -- | Get the names of all DPH instance functions for this class.
274 initBuiltinDicts :: (InstEnv, InstEnv) -> Class -> [(Name, Var)]
275 initBuiltinDicts insts cls = map find $ classInstances insts cls
276   where
277     find i | [Just tc] <- instanceRoughTcs i    = (tc, instanceDFunId i)
278            | otherwise                          = pprPanic "Invalid DPH instance" (ppr i)
279
280
281 -- | Get a list of boxed `TyCons` in the mock prelude. This is Int only.
282 initBuiltinBoxedTyCons :: Builtins -> DsM [(Name, TyCon)]
283 initBuiltinBoxedTyCons 
284   = return . builtinBoxedTyCons
285   where builtinBoxedTyCons :: Builtins -> [(Name, TyCon)]
286         builtinBoxedTyCons _ 
287                 = [(tyConName intPrimTyCon, intTyCon)]
288
289 -- | Get a list of all scalar functions in the mock prelude.
290 --
291 initBuiltinScalars :: Bool 
292                    -> Builtins -> DsM [Var]
293 initBuiltinScalars True  _bi = return []
294 initBuiltinScalars False bi  = mapM (uncurry externalVar) (preludeScalars $ dphModules bi)
295
296 -- | Lookup some variable given its name and the module that contains it.
297 externalVar :: Module -> FastString -> DsM Var
298 externalVar mod fs
299   = dsLookupGlobalId =<< lookupOrig mod (mkVarOccFS fs)
300
301
302 -- | Like `externalVar` but wrap the `Var` in a `CoreExpr`
303 externalFun :: Module -> FastString -> DsM CoreExpr
304 externalFun mod fs
305  = do var <- externalVar mod fs
306       return $ Var var
307
308
309 -- | Lookup some `TyCon` given its name and the module that contains it.
310 externalTyCon :: Module -> FastString -> DsM TyCon
311 externalTyCon mod fs
312   = dsLookupTyCon =<< lookupOrig mod (mkTcOccFS fs)
313
314
315 -- | Lookup some `Type` given its name and the module that contains it.
316 externalType :: Module -> FastString -> DsM Type
317 externalType mod fs
318  = do  tycon <- externalTyCon mod fs
319        return $ mkTyConApp tycon []
320
321
322 -- | Lookup some `Class` given its name and the module that contains it.
323 externalClass :: Module -> FastString -> DsM Class
324 externalClass mod fs
325   = dsLookupClass =<< lookupOrig mod (mkClsOccFS fs)
326