X-Git-Url: http://git.megacz.com/?p=ghc-hetmet.git;a=blobdiff_plain;f=compiler%2Fvectorise%2FVectBuiltIn.hs;h=4fe7e9e6e2bf96b775f0dde0a13fecf8a5df34d4;hp=cbcb47db76720104d258b7d992d1ce4768693fb4;hb=28bb3c3c8c1467ca31db59f0b3d1a21df6607742;hpb=7106cd1bb3633ee274673cd0d1ea82315ca8b56d diff --git a/compiler/vectorise/VectBuiltIn.hs b/compiler/vectorise/VectBuiltIn.hs index cbcb47d..4fe7e9e 100644 --- a/compiler/vectorise/VectBuiltIn.hs +++ b/compiler/vectorise/VectBuiltIn.hs @@ -1,9 +1,9 @@ module VectBuiltIn ( Builtins(..), sumTyCon, prodTyCon, - combinePAVar, + combinePAVar, scalarZip, closureCtrFun, initBuiltins, initBuiltinVars, initBuiltinTyCons, initBuiltinDataCons, initBuiltinPAs, initBuiltinPRs, - initBuiltinBoxedTyCons, + initBuiltinBoxedTyCons, initBuiltinScalars, primMethod, primPArray ) where @@ -14,6 +14,7 @@ import IfaceEnv ( lookupOrig ) import Module import DataCon ( DataCon, dataConName, dataConWorkId ) import TyCon ( TyCon, tyConName, tyConDataCons ) +import Class ( Class ) import Var ( Var ) import Id ( mkSysLocal ) import Name ( Name, getOccString ) @@ -48,6 +49,9 @@ mAX_DPH_SUM = 3 mAX_DPH_COMBINE :: Int mAX_DPH_COMBINE = 2 +mAX_DPH_SCALAR_ARGS :: Int +mAX_DPH_SCALAR_ARGS = 3 + data Modules = Modules { dph_PArray :: Module , dph_Repr :: Module @@ -55,6 +59,7 @@ data Modules = Modules { , dph_Unboxed :: Module , dph_Instances :: Module , dph_Combinators :: Module + , dph_Scalar :: Module , dph_Prelude_PArr :: Module , dph_Prelude_Int :: Module , dph_Prelude_Word8 :: Module @@ -71,6 +76,7 @@ dph_Modules pkg = Modules { , dph_Unboxed = mk (fsLit "Data.Array.Parallel.Lifted.Unboxed") , dph_Instances = mk (fsLit "Data.Array.Parallel.Lifted.Instances") , dph_Combinators = mk (fsLit "Data.Array.Parallel.Lifted.Combinators") + , dph_Scalar = mk (fsLit "Data.Array.Parallel.Lifted.Scalar") , dph_Prelude_PArr = mk (fsLit "Data.Array.Parallel.Prelude.Base.PArr") , dph_Prelude_Int = mk (fsLit "Data.Array.Parallel.Prelude.Base.Int") @@ -112,6 +118,9 @@ data Builtins = Builtins { , emptyPAVar :: Var , packPAVar :: Var , combinePAVars :: Array Int Var + , scalarClass :: Class + , scalarZips :: Array Int Var + , closureCtrFuns :: Array Int Var , liftingContext :: Var } @@ -131,6 +140,16 @@ combinePAVar n bi | n >= 2 && n <= mAX_DPH_COMBINE = combinePAVars bi ! n | otherwise = pprPanic "combinePAVar" (ppr n) +scalarZip :: Int -> Builtins -> Var +scalarZip n bi + | n >= 1 && n <= mAX_DPH_SCALAR_ARGS = scalarZips bi ! n + | otherwise = pprPanic "scalarZip" (ppr n) + +closureCtrFun :: Int -> Builtins -> Var +closureCtrFun n bi + | n >= 1 && n <= mAX_DPH_SCALAR_ARGS = closureCtrFuns bi ! n + | otherwise = pprPanic "closureCtrFun" (ppr n) + initBuiltins :: PackageId -> DsM Builtins initBuiltins pkg = do @@ -171,6 +190,19 @@ initBuiltins pkg | i <- [2..mAX_DPH_COMBINE]] let combinePAVars = listArray (2, mAX_DPH_COMBINE) combines + scalarClass <- externalClass dph_Scalar (fsLit "Scalar") + scalar_map <- externalVar dph_Scalar (fsLit "scalar_map") + scalar_zip2 <- externalVar dph_Scalar (fsLit "scalar_zipWith") + scalar_zips <- mapM (externalVar dph_Scalar) + [mkFastString ("scalar_zipWith" ++ show i) + | i <- [3 .. mAX_DPH_SCALAR_ARGS]] + let scalarZips = listArray (1, mAX_DPH_SCALAR_ARGS) + (scalar_map : scalar_zip2 : scalar_zips) + closures <- mapM (externalVar dph_Closure) + [mkFastString ("closure" ++ show i) + | i <- [1 .. mAX_DPH_SCALAR_ARGS]] + let closureCtrFuns = listArray (1, mAX_DPH_COMBINE) closures + liftingContext <- liftM (\u -> mkSysLocal (fsLit "lc") u intPrimTy) newUnique @@ -203,6 +235,9 @@ initBuiltins pkg , emptyPAVar = emptyPAVar , packPAVar = packPAVar , combinePAVars = combinePAVars + , scalarClass = scalarClass + , scalarZips = scalarZips + , closureCtrFuns = closureCtrFuns , liftingContext = liftingContext } where @@ -211,6 +246,7 @@ initBuiltins pkg , dph_Repr = dph_Repr , dph_Closure = dph_Closure , dph_Unboxed = dph_Unboxed + , dph_Scalar = dph_Scalar }) = dph_Modules pkg @@ -447,6 +483,91 @@ builtinBoxedTyCons :: Builtins -> [(Name, TyCon)] builtinBoxedTyCons _ = [(tyConName intPrimTyCon, intTyCon)] + +initBuiltinScalars :: Builtins -> DsM [Var] +initBuiltinScalars bi + = mapM (uncurry externalVar) (preludeScalars $ dphModules bi) + + +preludeScalars :: Modules -> [(Module, FastString)] +preludeScalars (Modules { dph_Prelude_Int = dph_Prelude_Int + , dph_Prelude_Word8 = dph_Prelude_Word8 + , dph_Prelude_Double = dph_Prelude_Double + }) + = [ + mk dph_Prelude_Int "div" + , mk dph_Prelude_Int "mod" + , mk dph_Prelude_Int "sqrt" + ] + ++ scalars_Ord dph_Prelude_Int + ++ scalars_Num dph_Prelude_Int + + ++ scalars_Ord dph_Prelude_Word8 + ++ scalars_Num dph_Prelude_Word8 + ++ + [ mk dph_Prelude_Word8 "div" + , mk dph_Prelude_Word8 "mod" + , mk dph_Prelude_Word8 "fromInt" + , mk dph_Prelude_Word8 "toInt" + ] + + ++ scalars_Ord dph_Prelude_Double + ++ scalars_Num dph_Prelude_Double + ++ scalars_Fractional dph_Prelude_Double + ++ scalars_Floating dph_Prelude_Double + ++ scalars_RealFrac dph_Prelude_Double + where + mk mod s = (mod, fsLit s) + + scalars_Ord mod = [mk mod "==" + ,mk mod "/=" + ,mk mod "<=" + ,mk mod "<" + ,mk mod ">=" + ,mk mod ">" + ,mk mod "min" + ,mk mod "max" + ] + + scalars_Num mod = [mk mod "+" + ,mk mod "-" + ,mk mod "*" + ,mk mod "negate" + ,mk mod "abs" + ] + + scalars_Fractional mod = [mk mod "/" + ,mk mod "recip" + ] + + scalars_Floating mod = [mk mod "pi" + ,mk mod "exp" + ,mk mod "sqrt" + ,mk mod "log" + ,mk mod "sin" + ,mk mod "tan" + ,mk mod "cos" + ,mk mod "asin" + ,mk mod "atan" + ,mk mod "acos" + ,mk mod "sinh" + ,mk mod "tanh" + ,mk mod "cosh" + ,mk mod "asinh" + ,mk mod "atanh" + ,mk mod "acosh" + ,mk mod "**" + ,mk mod "logBase" + ] + + scalars_RealFrac mod = [mk mod "fromInt" + ,mk mod "truncate" + ,mk mod "round" + ,mk mod "ceiling" + ,mk mod "floor" + ] + + externalVar :: Module -> FastString -> DsM Var externalVar mod fs = dsLookupGlobalId =<< lookupOrig mod (mkVarOccFS fs) @@ -461,6 +582,10 @@ externalType mod fs tycon <- externalTyCon mod fs return $ mkTyConApp tycon [] +externalClass :: Module -> FastString -> DsM Class +externalClass mod fs + = dsLookupClass =<< lookupOrig mod (mkTcOccFS fs) + unitTyConName :: Name unitTyConName = tyConName unitTyCon