Fix vectorisation of sum type constructors
[ghc-hetmet.git] / compiler / vectorise / VectBuiltIn.hs
1 module VectBuiltIn (
2   Builtins(..), sumTyCon, prodTyCon,
3   initBuiltins, initBuiltinTyCons, initBuiltinPAs, initBuiltinPRs,
4
5   primMethod, primPArray
6 ) where
7
8 #include "HsVersions.h"
9
10 import DsMonad
11 import IfaceEnv        ( lookupOrig )
12
13 import Module          ( Module )
14 import DataCon         ( DataCon )
15 import TyCon           ( TyCon, tyConName, tyConDataCons )
16 import Var             ( Var )
17 import Id              ( mkSysLocal )
18 import Name            ( Name, getOccString )
19 import NameEnv
20 import OccName
21
22 import TypeRep         ( funTyCon )
23 import Type            ( Type )
24 import TysPrim
25 import TysWiredIn      ( unitTyCon, tupleTyCon, intTyConName )
26 import PrelNames
27 import BasicTypes      ( Boxity(..) )
28
29 import FastString
30 import Outputable
31
32 import Data.Array
33 import Control.Monad   ( liftM, zipWithM )
34
35 mAX_NDP_PROD :: Int
36 mAX_NDP_PROD = 3
37
38 mAX_NDP_SUM :: Int
39 mAX_NDP_SUM = 3
40
41 data Builtins = Builtins {
42                   parrayTyCon      :: TyCon
43                 , paTyCon          :: TyCon
44                 , paDataCon        :: DataCon
45                 , preprTyCon       :: TyCon
46                 , prTyCon          :: TyCon
47                 , prDataCon        :: DataCon
48                 , parrayIntPrimTyCon :: TyCon
49                 , sumTyCons        :: Array Int TyCon
50                 , closureTyCon     :: TyCon
51                 , mkPRVar          :: Var
52                 , mkClosureVar     :: Var
53                 , applyClosureVar  :: Var
54                 , mkClosurePVar    :: Var
55                 , applyClosurePVar :: Var
56                 , replicatePAIntPrimVar :: Var
57                 , upToPAIntPrimVar :: Var
58                 , lengthPAVar      :: Var
59                 , replicatePAVar   :: Var
60                 , emptyPAVar       :: Var
61                 -- , packPAVar        :: Var
62                 -- , combinePAVar     :: Var
63                 , liftingContext   :: Var
64                 }
65
66 sumTyCon :: Int -> Builtins -> TyCon
67 sumTyCon n bi
68   | n >= 2 && n <= mAX_NDP_SUM = sumTyCons bi ! n
69   | otherwise = pprPanic "sumTyCon" (ppr n)
70
71 prodTyCon :: Int -> Builtins -> TyCon
72 prodTyCon n bi
73   | n >= 2 && n <= mAX_NDP_PROD = tupleTyCon Boxed n
74   | otherwise = pprPanic "prodTyCon" (ppr n)
75
76 initBuiltins :: DsM Builtins
77 initBuiltins
78   = do
79       parrayTyCon  <- dsLookupTyCon parrayTyConName
80       paTyCon      <- dsLookupTyCon paTyConName
81       let [paDataCon] = tyConDataCons paTyCon
82       preprTyCon   <- dsLookupTyCon preprTyConName
83       prTyCon      <- dsLookupTyCon prTyConName
84       let [prDataCon] = tyConDataCons prTyCon
85       parrayIntPrimTyCon <- dsLookupTyCon parrayIntPrimTyConName
86       closureTyCon <- dsLookupTyCon closureTyConName
87
88       sum_tcs <- mapM (lookupExternalTyCon nDP_REPR)
89                       [mkFastString ("Sum" ++ show i) | i <- [2..mAX_NDP_SUM]]
90
91       let sumTyCons = listArray (2, mAX_NDP_SUM) sum_tcs
92
93       mkPRVar          <- dsLookupGlobalId mkPRName
94       mkClosureVar     <- dsLookupGlobalId mkClosureName
95       applyClosureVar  <- dsLookupGlobalId applyClosureName
96       mkClosurePVar    <- dsLookupGlobalId mkClosurePName
97       applyClosurePVar <- dsLookupGlobalId applyClosurePName
98       replicatePAIntPrimVar <- dsLookupGlobalId replicatePAIntPrimName
99       upToPAIntPrimVar <- dsLookupGlobalId upToPAIntPrimName
100       lengthPAVar      <- dsLookupGlobalId lengthPAName
101       replicatePAVar   <- dsLookupGlobalId replicatePAName
102       emptyPAVar       <- dsLookupGlobalId emptyPAName
103       -- packPAVar        <- dsLookupGlobalId packPAName
104       -- combinePAVar     <- dsLookupGlobalId combinePAName
105
106       liftingContext <- liftM (\u -> mkSysLocal FSLIT("lc") u intPrimTy)
107                               newUnique
108
109       return $ Builtins {
110                  parrayTyCon      = parrayTyCon
111                , paTyCon          = paTyCon
112                , paDataCon        = paDataCon
113                , preprTyCon       = preprTyCon
114                , prTyCon          = prTyCon
115                , prDataCon        = prDataCon
116                , parrayIntPrimTyCon = parrayIntPrimTyCon
117                , sumTyCons        = sumTyCons
118                , closureTyCon     = closureTyCon
119                , mkPRVar          = mkPRVar
120                , mkClosureVar     = mkClosureVar
121                , applyClosureVar  = applyClosureVar
122                , mkClosurePVar    = mkClosurePVar
123                , applyClosurePVar = applyClosurePVar
124                , replicatePAIntPrimVar = replicatePAIntPrimVar
125                , upToPAIntPrimVar = upToPAIntPrimVar
126                , lengthPAVar      = lengthPAVar
127                , replicatePAVar   = replicatePAVar
128                , emptyPAVar       = emptyPAVar
129                -- , packPAVar        = packPAVar
130                -- , combinePAVar     = combinePAVar
131                , liftingContext   = liftingContext
132                }
133
134 initBuiltinTyCons :: DsM [(Name, TyCon)]
135 initBuiltinTyCons
136   = do
137       vects <- sequence vs
138       return (zip origs vects)
139   where
140     (origs, vs) = unzip builtinTyCons
141
142 builtinTyCons :: [(Name, DsM TyCon)]
143 builtinTyCons = [(tyConName funTyCon, dsLookupTyCon closureTyConName)]
144
145 initBuiltinDicts :: [(Name, Module, FastString)] -> DsM [(Name, Var)]
146 initBuiltinDicts ps
147   = do
148       dicts <- zipWithM lookupExternalVar mods fss
149       return $ zip tcs dicts
150   where
151     (tcs, mods, fss) = unzip3 ps
152
153 initBuiltinPAs = initBuiltinDicts builtinPAs
154
155 builtinPAs :: [(Name, Module, FastString)]
156 builtinPAs = [
157                mk closureTyConName  nDP_CLOSURE   FSLIT("dPA_Clo")
158              , mk unitTyConName     nDP_INSTANCES FSLIT("dPA_Unit")
159
160              , mk intTyConName      nDP_INSTANCES FSLIT("dPA_Int")
161              ]
162              ++ tups
163   where
164     mk name mod fs = (name, mod, fs)
165
166     tups = map mk_tup [2..3]
167     mk_tup n = mk (tyConName $ tupleTyCon Boxed n)
168                   nDP_INSTANCES
169                   (mkFastString $ "dPA_" ++ show n)
170
171 initBuiltinPRs = initBuiltinDicts . builtinPRs
172
173 builtinPRs :: Builtins -> [(Name, Module, FastString)]
174 builtinPRs bi =
175   [
176     mk (tyConName unitTyCon) nDP_REPR      FSLIT("dPR_Unit")
177   , mk closureTyConName      nDP_CLOSURE   FSLIT("dPR_Clo")
178
179     -- temporary
180   , mk intTyConName          nDP_INSTANCES FSLIT("dPR_Int")
181   ]
182
183   ++ map mk_sum  [2..mAX_NDP_SUM]
184   ++ map mk_prod [2..mAX_NDP_PROD]
185   where
186     mk name mod fs = (name, mod, fs)
187
188     mk_sum n = (tyConName $ sumTyCon n bi, nDP_REPR,
189                 mkFastString ("dPR_Sum" ++ show n))
190
191     mk_prod n = (tyConName $ prodTyCon n bi, nDP_REPR,
192                  mkFastString ("dPR_" ++ show n))
193
194 lookupExternalVar :: Module -> FastString -> DsM Var
195 lookupExternalVar mod fs
196   = dsLookupGlobalId =<< lookupOrig mod (mkVarOccFS fs)
197
198 lookupExternalTyCon :: Module -> FastString -> DsM TyCon
199 lookupExternalTyCon mod fs
200   = dsLookupTyCon =<< lookupOrig mod (mkOccNameFS tcName fs)
201
202 unitTyConName = tyConName unitTyCon
203
204
205 primMethod :: TyCon -> String -> DsM (Maybe Var)
206 primMethod tycon method
207   | Just suffix <- lookupNameEnv prim_ty_cons (tyConName tycon)
208   = liftM Just
209   $ dsLookupGlobalId =<< lookupOrig nDP_PRIM (mkVarOcc $ method ++ suffix)
210
211   | otherwise = return Nothing
212
213 primPArray :: TyCon -> DsM (Maybe TyCon)
214 primPArray tycon
215   | Just suffix <- lookupNameEnv prim_ty_cons (tyConName tycon)
216   = liftM Just
217   $ dsLookupTyCon =<< lookupOrig nDP_PRIM (mkOccName tcName $ "PArray" ++ suffix)
218
219   | otherwise = return Nothing
220
221 prim_ty_cons = mkNameEnv [mk_prim intPrimTyCon]
222   where
223     mk_prim tycon = (tyConName tycon, '_' : getOccString tycon)