module VectBuiltIn (
Builtins(..), sumTyCon, prodTyCon,
- combinePAVar,
+ combinePAVar, scalarZip, closureCtrFun,
initBuiltins, initBuiltinVars, initBuiltinTyCons, initBuiltinDataCons,
initBuiltinPAs, initBuiltinPRs,
- initBuiltinBoxedTyCons,
+ initBuiltinBoxedTyCons, initBuiltinScalars,
primMethod, primPArray
) where
import Module
import DataCon ( DataCon, dataConName, dataConWorkId )
import TyCon ( TyCon, tyConName, tyConDataCons )
+import Class ( Class )
import Var ( Var )
import Id ( mkSysLocal )
import Name ( Name, getOccString )
mAX_DPH_COMBINE :: Int
mAX_DPH_COMBINE = 2
+mAX_DPH_SCALAR_ARGS :: Int
+mAX_DPH_SCALAR_ARGS = 3
+
data Modules = Modules {
dph_PArray :: Module
, dph_Repr :: Module
, dph_Unboxed :: Module
, dph_Instances :: Module
, dph_Combinators :: Module
+ , dph_Scalar :: Module
, dph_Prelude_PArr :: Module
, dph_Prelude_Int :: Module
, dph_Prelude_Word8 :: Module
, dph_Unboxed = mk (fsLit "Data.Array.Parallel.Lifted.Unboxed")
, dph_Instances = mk (fsLit "Data.Array.Parallel.Lifted.Instances")
, dph_Combinators = mk (fsLit "Data.Array.Parallel.Lifted.Combinators")
+ , dph_Scalar = mk (fsLit "Data.Array.Parallel.Lifted.Scalar")
, dph_Prelude_PArr = mk (fsLit "Data.Array.Parallel.Prelude.Base.PArr")
, dph_Prelude_Int = mk (fsLit "Data.Array.Parallel.Prelude.Base.Int")
, emptyPAVar :: Var
, packPAVar :: Var
, combinePAVars :: Array Int Var
+ , scalarClass :: Class
+ , scalarZips :: Array Int Var
+ , closureCtrFuns :: Array Int Var
, liftingContext :: Var
}
| n >= 2 && n <= mAX_DPH_COMBINE = combinePAVars bi ! n
| otherwise = pprPanic "combinePAVar" (ppr n)
+scalarZip :: Int -> Builtins -> Var
+scalarZip n bi
+ | n >= 1 && n <= mAX_DPH_SCALAR_ARGS = scalarZips bi ! n
+ | otherwise = pprPanic "scalarZip" (ppr n)
+
+closureCtrFun :: Int -> Builtins -> Var
+closureCtrFun n bi
+ | n >= 1 && n <= mAX_DPH_SCALAR_ARGS = closureCtrFuns bi ! n
+ | otherwise = pprPanic "closureCtrFun" (ppr n)
+
initBuiltins :: PackageId -> DsM Builtins
initBuiltins pkg
= do
| i <- [2..mAX_DPH_COMBINE]]
let combinePAVars = listArray (2, mAX_DPH_COMBINE) combines
+ scalarClass <- externalClass dph_Scalar (fsLit "Scalar")
+ scalar_map <- externalVar dph_Scalar (fsLit "scalar_map")
+ scalar_zip2 <- externalVar dph_Scalar (fsLit "scalar_zipWith")
+ scalar_zips <- mapM (externalVar dph_Scalar)
+ [mkFastString ("scalar_zipWith" ++ show i)
+ | i <- [3 .. mAX_DPH_SCALAR_ARGS]]
+ let scalarZips = listArray (1, mAX_DPH_SCALAR_ARGS)
+ (scalar_map : scalar_zip2 : scalar_zips)
+ closures <- mapM (externalVar dph_Closure)
+ [mkFastString ("closure" ++ show i)
+ | i <- [1 .. mAX_DPH_SCALAR_ARGS]]
+ let closureCtrFuns = listArray (1, mAX_DPH_COMBINE) closures
+
liftingContext <- liftM (\u -> mkSysLocal (fsLit "lc") u intPrimTy)
newUnique
, emptyPAVar = emptyPAVar
, packPAVar = packPAVar
, combinePAVars = combinePAVars
+ , scalarClass = scalarClass
+ , scalarZips = scalarZips
+ , closureCtrFuns = closureCtrFuns
, liftingContext = liftingContext
}
where
, dph_Repr = dph_Repr
, dph_Closure = dph_Closure
, dph_Unboxed = dph_Unboxed
+ , dph_Scalar = dph_Scalar
})
= dph_Modules pkg
builtinBoxedTyCons _ =
[(tyConName intPrimTyCon, intTyCon)]
+
+initBuiltinScalars :: Builtins -> DsM [Var]
+initBuiltinScalars bi
+ = mapM (uncurry externalVar) (preludeScalars $ dphModules bi)
+
+
+preludeScalars :: Modules -> [(Module, FastString)]
+preludeScalars (Modules { dph_Prelude_Int = dph_Prelude_Int
+ , dph_Prelude_Word8 = dph_Prelude_Word8
+ , dph_Prelude_Double = dph_Prelude_Double
+ })
+ = [
+ mk dph_Prelude_Int "div"
+ , mk dph_Prelude_Int "mod"
+ , mk dph_Prelude_Int "sqrt"
+ ]
+ ++ scalars_Ord dph_Prelude_Int
+ ++ scalars_Num dph_Prelude_Int
+
+ ++ scalars_Ord dph_Prelude_Word8
+ ++ scalars_Num dph_Prelude_Word8
+ ++
+ [ mk dph_Prelude_Word8 "div"
+ , mk dph_Prelude_Word8 "mod"
+ , mk dph_Prelude_Word8 "fromInt"
+ , mk dph_Prelude_Word8 "toInt"
+ ]
+
+ ++ scalars_Ord dph_Prelude_Double
+ ++ scalars_Num dph_Prelude_Double
+ ++ scalars_Fractional dph_Prelude_Double
+ ++ scalars_Floating dph_Prelude_Double
+ ++ scalars_RealFrac dph_Prelude_Double
+ where
+ mk mod s = (mod, fsLit s)
+
+ scalars_Ord mod = [mk mod "=="
+ ,mk mod "/="
+ ,mk mod "<="
+ ,mk mod "<"
+ ,mk mod ">="
+ ,mk mod ">"
+ ,mk mod "min"
+ ,mk mod "max"
+ ]
+
+ scalars_Num mod = [mk mod "+"
+ ,mk mod "-"
+ ,mk mod "*"
+ ,mk mod "negate"
+ ,mk mod "abs"
+ ]
+
+ scalars_Fractional mod = [mk mod "/"
+ ,mk mod "recip"
+ ]
+
+ scalars_Floating mod = [mk mod "pi"
+ ,mk mod "exp"
+ ,mk mod "sqrt"
+ ,mk mod "log"
+ ,mk mod "sin"
+ ,mk mod "tan"
+ ,mk mod "cos"
+ ,mk mod "asin"
+ ,mk mod "atan"
+ ,mk mod "acos"
+ ,mk mod "sinh"
+ ,mk mod "tanh"
+ ,mk mod "cosh"
+ ,mk mod "asinh"
+ ,mk mod "atanh"
+ ,mk mod "acosh"
+ ,mk mod "**"
+ ,mk mod "logBase"
+ ]
+
+ scalars_RealFrac mod = [mk mod "fromInt"
+ ,mk mod "truncate"
+ ,mk mod "round"
+ ,mk mod "ceiling"
+ ,mk mod "floor"
+ ]
+
+
externalVar :: Module -> FastString -> DsM Var
externalVar mod fs
= dsLookupGlobalId =<< lookupOrig mod (mkVarOccFS fs)
tycon <- externalTyCon mod fs
return $ mkTyConApp tycon []
+externalClass :: Module -> FastString -> DsM Class
+externalClass mod fs
+ = dsLookupClass =<< lookupOrig mod (mkTcOccFS fs)
+
unitTyConName :: Name
unitTyConName = tyConName unitTyCon
Scope(..),
VM,
- noV, traceNoV, tryV, maybeV, traceMaybeV, orElseV, fixV, localV, closedV,
+ noV, traceNoV, ensureV, traceEnsureV, tryV, maybeV, traceMaybeV, orElseV,
+ onlyIfV, fixV, localV, closedV,
initV, cantVectorise, maybeCantVectorise, maybeCantVectoriseM,
liftDs,
cloneName, cloneId, cloneVar,
newExportedVar, newLocalVar, newDummyVar, newTyVar,
Builtins(..), sumTyCon, prodTyCon,
- combinePAVar,
+ combinePAVar, scalarZip, closureCtrFun,
builtin, builtins,
GlobalEnv(..),
getBindName, inBind,
- lookupVar, defGlobalVar,
+ lookupVar, defGlobalVar, globalScalars,
lookupTyCon, defTyCon,
lookupDataCon, defDataCon,
lookupTyConPA, defTyConPA, defTyConPAs,
lookupPrimMethod, lookupPrimPArray,
lookupTyVarPA, defLocalTyVar, defLocalTyVarWithPA, localTyVars,
- {-lookupInst,-} lookupFamInst
+ lookupInst, lookupFamInst
) where
#include "HsVersions.h"
import HscTypes hiding ( MonadThings(..) )
import Module ( PackageId )
import CoreSyn
+import Class
import TyCon
import DataCon
import Type
import Var
+import VarSet
import VarEnv
import Id
import Name
--
global_vars :: VarEnv Var
+ -- Purely scalar variables. Code which mentions only these
+ -- variables doesn't have to be lifted.
+ , global_scalars :: VarSet
+
-- Exported variables which have a vectorised version
--
, global_exported_vars :: VarEnv (Var, Var)
initGlobalEnv info instEnvs famInstEnvs
= GlobalEnv {
global_vars = mapVarEnv snd $ vectInfoVar info
+ , global_scalars = emptyVarSet
, global_exported_vars = emptyVarEnv
, global_tycons = mapNameEnv snd $ vectInfoTyCon info
, global_datacons = mapNameEnv snd $ vectInfoDataCon info
extendImportedVarsEnv ps genv
= genv { global_vars = extendVarEnvList (global_vars genv) ps }
+extendScalars :: [Var] -> GlobalEnv -> GlobalEnv
+extendScalars vs genv
+ = genv { global_scalars = extendVarSetList (global_scalars genv) vs }
+
setFamInstEnv :: FamInstEnv -> GlobalEnv -> GlobalEnv
setFamInstEnv l_fam_inst genv
= genv { global_fam_inst_env = (g_fam_inst, l_fam_inst) }
traceNoV :: String -> SDoc -> VM a
traceNoV s d = pprTrace s d noV
+ensureV :: Bool -> VM ()
+ensureV False = noV
+ensureV True = return ()
+
+onlyIfV :: Bool -> VM a -> VM a
+onlyIfV b p = ensureV b >> p
+
+traceEnsureV :: String -> SDoc -> Bool -> VM ()
+traceEnsureV s d False = traceNoV s d
+traceEnsureV s d True = return ()
+
tryV :: VM a -> VM (Maybe a)
tryV (VM p) = VM $ \bi genv lenv ->
do
updLEnv :: (LocalEnv -> LocalEnv) -> VM ()
updLEnv f = VM $ \_ genv lenv -> return (Yes genv (f lenv) ())
-{-
getInstEnv :: VM (InstEnv, InstEnv)
getInstEnv = readGEnv global_inst_env
--}
getFamInstEnv :: VM FamInstEnvs
getFamInstEnv = readGEnv global_fam_inst_env
. maybeCantVectoriseM "Variable not vectorised:" (ppr v)
. readGEnv $ \env -> lookupVarEnv (global_vars env) v
+globalScalars :: VM VarSet
+globalScalars = readGEnv global_scalars
+
lookupTyCon :: TyCon -> VM (Maybe TyCon)
lookupTyCon tc
| isUnLiftedTyCon tc || isTupleTyCon tc = return (Just tc)
-- instances head (i.e., no flexi vars); for details for what this means,
-- see the docs at InstEnv.lookupInstEnv.
--
-{-
lookupInst :: Class -> [Type] -> VM (DFunId, [Type])
lookupInst cls tys
= do { instEnv <- getInstEnv
where
inst_tys' = [ty | Right ty <- inst_tys]
noFlexiVar = all isRight inst_tys
- _other -> traceNoV "lookupInst" (ppr cls <+> ppr tys)
+ _other ->
+ pprPanic "VectMonad.lookupInst: not found " (ppr cls <+> ppr tys)
}
where
isRight (Left _) = False
isRight (Right _) = True
--}
-- Look up the representation tycon of a family instance.
--
builtin_pas <- initBuiltinPAs builtins
builtin_prs <- initBuiltinPRs builtins
builtin_boxed <- initBuiltinBoxedTyCons builtins
+ builtin_scalars <- initBuiltinScalars builtins
eps <- liftIO $ hscEPS hsc_env
let famInstEnvs = (eps_fam_inst_env eps, mg_fam_inst_env guts)
instEnvs = (eps_inst_env eps, mg_inst_env guts)
let genv = extendImportedVarsEnv builtin_vars
+ . extendScalars builtin_scalars
. extendTyConsEnv builtin_tycons
. extendDataConsEnv builtin_datacons
. extendPAFunsEnv builtin_pas
prDFunOfTyCon,
paDictArgType, paDictOfType, paDFunType,
paMethod, mkPR, lengthPA, replicatePA, emptyPA, packPA, combinePA, liftPA,
+ zipScalars, scalarClosure,
polyAbstract, polyApply, polyVApply,
hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
buildClosure, buildClosures,
lc <- builtin liftingContext
replicatePA (Var lc) x
+zipScalars :: [Type] -> Type -> VM CoreExpr
+zipScalars arg_tys res_ty
+ = do
+ scalar <- builtin scalarClass
+ (dfuns, _) <- mapAndUnzipM (\ty -> lookupInst scalar [ty]) ty_args
+ zipf <- builtin (scalarZip $ length arg_tys)
+ return $ Var zipf `mkTyApps` ty_args `mkApps` map Var dfuns
+ where
+ ty_args = arg_tys ++ [res_ty]
+
+scalarClosure :: [Type] -> Type -> CoreExpr -> CoreExpr -> VM CoreExpr
+scalarClosure arg_tys res_ty scalar_fun array_fun
+ = do
+ ctr <- builtin (closureCtrFun $ length arg_tys)
+ pas <- mapM paDictOfType (init arg_tys)
+ return $ Var ctr `mkTyApps` (arg_tys ++ [res_ty])
+ `mkApps` (pas ++ [scalar_fun, array_fun])
+
newLocalVVar :: FastString -> Type -> VM VVar
newLocalVVar fs vty
= do
$ vectExpr rhs
vectExpr e@(fvs, AnnLam bndr _)
- | isId bndr = vectLam fvs bs body
+ | isId bndr = onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
+ `orElseV` vectLam fvs bs body
where
(bs,body) = collectAnnValBinders e
vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e)
+vectScalarLam :: [Var] -> CoreExpr -> VM VExpr
+vectScalarLam args body
+ = do
+ scalars <- globalScalars
+ onlyIfV (all is_scalar_ty arg_tys
+ && is_scalar_ty res_ty
+ && is_scalar (extendVarSetList scalars args) body)
+ $ do
+ fn_var <- hoistExpr (fsLit "fn") (mkLams args body)
+ zipf <- zipScalars arg_tys res_ty
+ clo <- scalarClosure arg_tys res_ty (Var fn_var)
+ (zipf `App` Var fn_var)
+ clo_var <- hoistExpr (fsLit "clo") clo
+ lclo <- liftPA (Var clo_var)
+ return (Var clo_var, lclo)
+ where
+ arg_tys = map idType args
+ res_ty = exprType body
+
+ is_scalar_ty ty | Just (tycon, []) <- splitTyConApp_maybe ty
+ = tycon == intTyCon
+ || tycon == floatTyCon
+ || tycon == doubleTyCon
+
+ | otherwise = False
+
+ is_scalar vs (Var v) = v `elemVarSet` vs
+ is_scalar _ e@(Lit l) = is_scalar_ty $ exprType e
+ is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2
+ is_scalar _ _ = False
+
vectLam :: VarSet -> [Var] -> CoreExprWithFVs -> VM VExpr
vectLam fvs bs body
= do