Finish breaking up VectBuiltIn and VectMonad, and add comments
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Builtins / Prelude.hs
diff --git a/compiler/vectorise/Vectorise/Builtins/Prelude.hs b/compiler/vectorise/Vectorise/Builtins/Prelude.hs
new file mode 100644 (file)
index 0000000..b578f30
--- /dev/null
@@ -0,0 +1,236 @@
+
+-- | Mapping of prelude functions to vectorised versions.
+--     Functions like filterP currently have a working but naive version in GHC.PArr
+--     During vectorisation we replace these by calls to filterPA, which are
+--     defined in dph-common Data.Array.Parallel.Lifted.Combinators
+--
+--     As renamer only sees the GHC.PArr functions, if you want to add a new function
+--     to the vectoriser there has to be a definition for it in GHC.PArr, even though
+--     it will never be used at runtime.
+--
+module Vectorise.Builtins.Prelude
+       ( preludeVars
+       , preludeScalars)
+where
+import Vectorise.Builtins.Modules
+import PrelNames
+import Module
+import FastString
+
+
+preludeVars
+       :: Modules                      -- ^ Modules containing the DPH backens
+       -> [( Module, FastString        --   Maps the original variable to the one in the DPH 
+           , Module, FastString)]      --   packages that it should be rewritten to.
+
+preludeVars (Modules { dph_Combinators    = dph_Combinators
+                     , dph_PArray         = dph_PArray
+                     , dph_Prelude_Int    = dph_Prelude_Int
+                     , dph_Prelude_Word8  = dph_Prelude_Word8
+                     , dph_Prelude_Double = dph_Prelude_Double
+                     , dph_Prelude_Bool   = dph_Prelude_Bool 
+                     , dph_Prelude_PArr   = dph_Prelude_PArr
+                     })
+
+    -- Functions that work on whole PArrays, defined in GHC.PArr
+  = [ mk gHC_PARR (fsLit "mapP")       dph_Combinators (fsLit "mapPA")
+    , mk gHC_PARR (fsLit "zipWithP")   dph_Combinators (fsLit "zipWithPA")
+    , mk gHC_PARR (fsLit "zipP")       dph_Combinators (fsLit "zipPA")
+    , mk gHC_PARR (fsLit "unzipP")     dph_Combinators (fsLit "unzipPA")
+    , mk gHC_PARR (fsLit "filterP")    dph_Combinators (fsLit "filterPA")
+    , mk gHC_PARR (fsLit "lengthP")    dph_Combinators (fsLit "lengthPA")
+    , mk gHC_PARR (fsLit "replicateP") dph_Combinators (fsLit "replicatePA")
+    , mk gHC_PARR (fsLit "!:")         dph_Combinators (fsLit "indexPA")
+    , mk gHC_PARR (fsLit "sliceP")     dph_Combinators (fsLit "slicePA")
+    , mk gHC_PARR (fsLit "crossMapP")  dph_Combinators (fsLit "crossMapPA")
+    , mk gHC_PARR (fsLit "singletonP") dph_Combinators (fsLit "singletonPA")
+    , mk gHC_PARR (fsLit "concatP")    dph_Combinators (fsLit "concatPA")
+    , mk gHC_PARR (fsLit "+:+")        dph_Combinators (fsLit "appPA")
+    , mk gHC_PARR (fsLit "emptyP")     dph_PArray      (fsLit "emptyPA")
+
+    -- Map scalar functions to versions using closures. 
+    , mk' dph_Prelude_Int "div"         "divV"
+    , mk' dph_Prelude_Int "mod"         "modV"
+    , mk' dph_Prelude_Int "sqrt"        "sqrtV"
+    , mk' dph_Prelude_Int "enumFromToP" "enumFromToPA"
+    -- , mk' dph_Prelude_Int "upToP" "upToPA"
+    ]
+    ++ vars_Ord dph_Prelude_Int
+    ++ vars_Num dph_Prelude_Int
+
+    ++ vars_Ord dph_Prelude_Word8
+    ++ vars_Num dph_Prelude_Word8
+    ++
+    [ mk' dph_Prelude_Word8 "div"     "divV"
+    , mk' dph_Prelude_Word8 "mod"     "modV"
+    , mk' dph_Prelude_Word8 "fromInt" "fromIntV"
+    , mk' dph_Prelude_Word8 "toInt"   "toIntV"
+    ]
+
+    ++ vars_Ord        dph_Prelude_Double
+    ++ vars_Num        dph_Prelude_Double
+    ++ vars_Fractional dph_Prelude_Double
+    ++ vars_Floating   dph_Prelude_Double
+    ++ vars_RealFrac   dph_Prelude_Double
+    ++
+    [ mk dph_Prelude_Bool  (fsLit "andP")  dph_Prelude_Bool (fsLit "andPA")
+    , mk dph_Prelude_Bool  (fsLit "orP")   dph_Prelude_Bool (fsLit "orPA")
+
+    , mk gHC_CLASSES (fsLit "not")         dph_Prelude_Bool (fsLit "notV")
+    , mk gHC_CLASSES (fsLit "&&")          dph_Prelude_Bool (fsLit "andV")
+    , mk gHC_CLASSES (fsLit "||")          dph_Prelude_Bool (fsLit "orV")
+
+    -- FIXME: temporary
+    , mk dph_Prelude_PArr (fsLit "fromPArrayP")       dph_Prelude_PArr (fsLit "fromPArrayPA")
+    , mk dph_Prelude_PArr (fsLit "toPArrayP")         dph_Prelude_PArr (fsLit "toPArrayPA")
+    , mk dph_Prelude_PArr (fsLit "fromNestedPArrayP") dph_Prelude_PArr (fsLit "fromNestedPArrayPA")
+    , mk dph_Prelude_PArr (fsLit "combineP")          dph_Combinators  (fsLit "combine2PA")
+    , mk dph_Prelude_PArr (fsLit "updateP")           dph_Combinators  (fsLit "updatePA")
+    , mk dph_Prelude_PArr (fsLit "bpermuteP")         dph_Combinators  (fsLit "bpermutePA")
+    , mk dph_Prelude_PArr (fsLit "indexedP")          dph_Combinators  (fsLit "indexedPA")
+    ]
+  where
+    mk  = (,,,)
+    mk' mod v v' = mk mod (fsLit v) mod (fsLit v')
+
+    vars_Ord mod 
+     = [ mk' mod "=="        "eqV"
+       , mk' mod "/="        "neqV"
+       , mk' mod "<="        "leV"
+       , mk' mod "<"         "ltV"
+       , mk' mod ">="        "geV"
+       , mk' mod ">"         "gtV"
+       , mk' mod "min"       "minV"
+       , mk' mod "max"       "maxV"
+       , mk' mod "minimumP"  "minimumPA"
+       , mk' mod "maximumP"  "maximumPA"
+       , mk' mod "minIndexP" "minIndexPA"
+       , mk' mod "maxIndexP" "maxIndexPA"
+       ]
+
+    vars_Num mod 
+     = [ mk' mod "+"        "plusV"
+       , mk' mod "-"        "minusV"
+       , mk' mod "*"        "multV"
+       , mk' mod "negate"   "negateV"
+       , mk' mod "abs"      "absV"
+       , mk' mod "sumP"     "sumPA"
+       , mk' mod "productP" "productPA"
+       ]
+
+    vars_Fractional mod 
+     = [ mk' mod "/"     "divideV"
+       , mk' mod "recip" "recipV"
+       ]
+
+    vars_Floating mod 
+     = [ mk' mod "pi"      "pi"
+       , mk' mod "exp"     "expV"
+       , mk' mod "sqrt"    "sqrtV"
+       , mk' mod "log"     "logV"
+       , mk' mod "sin"     "sinV"
+       , mk' mod "tan"     "tanV"
+       , mk' mod "cos"     "cosV"
+       , mk' mod "asin"    "asinV"
+       , mk' mod "atan"    "atanV"
+       , mk' mod "acos"    "acosV"
+       , mk' mod "sinh"    "sinhV"
+       , mk' mod "tanh"    "tanhV"
+       , mk' mod "cosh"    "coshV"
+       , mk' mod "asinh"   "asinhV"
+       , mk' mod "atanh"   "atanhV"
+       , mk' mod "acosh"   "acoshV"
+       , mk' mod "**"      "powV"
+       , mk' mod "logBase" "logBaseV"
+       ]
+
+    vars_RealFrac mod
+     = [ mk' mod "fromInt"  "fromIntV"
+       , mk' mod "truncate" "truncateV"
+       , mk' mod "round"    "roundV"
+       , mk' mod "ceiling"  "ceilingV"
+       , mk' mod "floor"    "floorV"
+       ]
+
+
+preludeScalars :: Modules -> [(Module, FastString)]
+preludeScalars (Modules { dph_Prelude_Int    = dph_Prelude_Int
+                        , dph_Prelude_Word8  = dph_Prelude_Word8
+                        , dph_Prelude_Double = dph_Prelude_Double
+                        })
+  = [ mk dph_Prelude_Int "div"
+    , mk dph_Prelude_Int "mod"
+    , mk dph_Prelude_Int "sqrt"
+    ]
+    ++ scalars_Ord dph_Prelude_Int
+    ++ scalars_Num dph_Prelude_Int
+
+    ++ scalars_Ord dph_Prelude_Word8
+    ++ scalars_Num dph_Prelude_Word8
+    ++
+    [ mk dph_Prelude_Word8 "div"
+    , mk dph_Prelude_Word8 "mod"
+    , mk dph_Prelude_Word8 "fromInt"
+    , mk dph_Prelude_Word8 "toInt"
+    ]
+
+    ++ scalars_Ord dph_Prelude_Double
+    ++ scalars_Num dph_Prelude_Double
+    ++ scalars_Fractional dph_Prelude_Double
+    ++ scalars_Floating dph_Prelude_Double
+    ++ scalars_RealFrac dph_Prelude_Double
+  where
+    mk mod s = (mod, fsLit s)
+
+    scalars_Ord mod 
+     = [ mk mod "=="
+       , mk mod "/="
+       , mk mod "<="
+       , mk mod "<"
+       , mk mod ">="
+       , mk mod ">"
+       , mk mod "min"
+       , mk mod "max"
+       ]
+
+    scalars_Num mod 
+     = [ mk mod "+"
+       , mk mod "-"
+       , mk mod "*"
+       , mk mod "negate"
+       , mk mod "abs"
+       ]
+
+    scalars_Fractional mod 
+     = [ mk mod "/"
+       , mk mod "recip"
+       ]
+
+    scalars_Floating mod 
+     = [ mk mod "pi"
+       , mk mod "exp"
+       , mk mod "sqrt"
+       , mk mod "log"
+       , mk mod "sin"
+       , mk mod "tan"
+       , mk mod "cos"
+       , mk mod "asin"
+       , mk mod "atan"
+       , mk mod "acos"
+       , mk mod "sinh"
+       , mk mod "tanh"
+       , mk mod "cosh"
+       , mk mod "asinh"
+       , mk mod "atanh"
+       , mk mod "acosh"
+       , mk mod "**"
+       , mk mod "logBase"
+       ]
+
+    scalars_RealFrac mod 
+     = [ mk mod "fromInt"
+       , mk mod "truncate"
+       , mk mod "round"
+       , mk mod "ceiling"
+       , mk mod "floor"
+       ]