Fix bug in vectorisation of case expressions
[ghc-hetmet.git] / compiler / vectorise / VectBuiltIn.hs
index 9d790c6..c61343b 100644 (file)
@@ -26,7 +26,7 @@ import TypeRep         ( funTyCon )
 import Type            ( Type, mkTyConApp )
 import TysPrim
 import TysWiredIn      ( unitTyCon, unitDataCon,
-                         tupleTyCon,
+                         tupleTyCon, tupleCon,
                          intTyCon, intTyConName, intTy,
                          doubleTyCon, doubleTyConName,
                          boolTyCon, boolTyConName, trueDataCon, falseDataCon,
@@ -53,8 +53,17 @@ mAX_NDP_COMBINE = 2
 mkNDPModule :: FastString -> Module
 mkNDPModule m = mkModule ndpPackageId (mkModuleNameFS m)
 
-nDP_UARR, nDP_PARRAY, nDP_REPR, nDP_CLOSURE, nDP_UNBOXED, nDP_INSTANCES, nDP_COMBINATORS,
-    nDP_PRELUDE_PARR, nDP_PRELUDE_INT, nDP_PRELUDE_DOUBLE :: Module
+nDP_UARR,
+  nDP_PARRAY,
+  nDP_REPR,
+  nDP_CLOSURE,
+  nDP_UNBOXED,
+  nDP_INSTANCES,
+  nDP_COMBINATORS,
+  nDP_PRELUDE_PARR,
+  nDP_PRELUDE_INT,
+  nDP_PRELUDE_DOUBLE,
+  nDP_PRELUDE_TUPLE :: Module
 
 nDP_UARR        = mkNDPModule FSLIT("Data.Array.Parallel.Unlifted.Flat.UArr")
 nDP_PARRAY      = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.PArray")
@@ -67,6 +76,7 @@ nDP_COMBINATORS = mkNDPModule FSLIT("Data.Array.Parallel.Lifted.Combinators")
 nDP_PRELUDE_PARR = mkNDPModule FSLIT("Data.Array.Parallel.Prelude.Base.PArr")
 nDP_PRELUDE_INT  = mkNDPModule FSLIT("Data.Array.Parallel.Prelude.Base.Int")
 nDP_PRELUDE_DOUBLE = mkNDPModule FSLIT("Data.Array.Parallel.Prelude.Base.Double")
+nDP_PRELUDE_TUPLE  = mkNDPModule FSLIT("Data.Array.Parallel.Prelude.Base.Tuple")
 
 data Builtins = Builtins {
                   parrayTyCon      :: TyCon
@@ -200,20 +210,31 @@ initBuiltinVars _
   = do
       uvars <- zipWithM externalVar umods ufs
       vvars <- zipWithM externalVar vmods vfs
+      cvars <- zipWithM externalVar cmods cfs
       return $ [(v,v) | v <- map dataConWorkId defaultDataConWorkers]
+               ++ zip (map dataConWorkId cons) cvars
                ++ zip uvars vvars
   where
     (umods, ufs, vmods, vfs) = unzip4 preludeVars
 
+    (cons, cmods, cfs) = unzip3 preludeDataCons
+
 defaultDataConWorkers :: [DataCon]
 defaultDataConWorkers = [trueDataCon, falseDataCon, unitDataCon]
 
+preludeDataCons :: [(DataCon, Module, FastString)]
+preludeDataCons
+  = [mk_tup n nDP_PRELUDE_TUPLE (mkFastString $ "tup" ++ show n) | n <- [2..3]]
+  where
+    mk_tup n mod name = (tupleCon Boxed n, mod, name)
+
 preludeVars :: [(Module, FastString, Module, FastString)]
 preludeVars
   = [
       mk gHC_PARR FSLIT("mapP")       nDP_COMBINATORS FSLIT("mapPA")
     , mk gHC_PARR FSLIT("zipWithP")   nDP_COMBINATORS FSLIT("zipWithPA")
     , mk gHC_PARR FSLIT("zipP")       nDP_COMBINATORS FSLIT("zipPA")
+    , mk gHC_PARR FSLIT("unzipP")     nDP_COMBINATORS FSLIT("unzipPA")
     , mk gHC_PARR FSLIT("filterP")    nDP_COMBINATORS FSLIT("filterPA")
     , mk gHC_PARR FSLIT("lengthP")    nDP_COMBINATORS FSLIT("lengthPA")
     , mk gHC_PARR FSLIT("replicateP") nDP_COMBINATORS FSLIT("replicatePA")
@@ -240,6 +261,7 @@ preludeVars
     , mk nDP_PRELUDE_DOUBLE  FSLIT("plus") nDP_PRELUDE_DOUBLE FSLIT("plusV")
     , mk nDP_PRELUDE_DOUBLE  FSLIT("minus") nDP_PRELUDE_DOUBLE FSLIT("minusV")
     , mk nDP_PRELUDE_DOUBLE  FSLIT("mult")  nDP_PRELUDE_DOUBLE FSLIT("multV")
+    , mk nDP_PRELUDE_DOUBLE  FSLIT("divide")  nDP_PRELUDE_DOUBLE FSLIT("divideV")
     , mk nDP_PRELUDE_DOUBLE  FSLIT("sumP")  nDP_PRELUDE_DOUBLE FSLIT("sumPA")
     , mk nDP_PRELUDE_DOUBLE  FSLIT("minIndexP") 
          nDP_PRELUDE_DOUBLE  FSLIT("minIndexPA")