From 28bb3c3c8c1467ca31db59f0b3d1a21df6607742 Mon Sep 17 00:00:00 2001 From: Roman Leshchinskiy Date: Fri, 6 Mar 2009 11:55:08 +0000 Subject: [PATCH] Try not to avoid vectorising purely scalar functions --- compiler/vectorise/VectBuiltIn.hs | 129 ++++++++++++++++++++++++++++++++++++- compiler/vectorise/VectMonad.hs | 43 ++++++++++--- compiler/vectorise/VectUtils.hs | 19 ++++++ compiler/vectorise/Vectorise.hs | 34 +++++++++- 4 files changed, 213 insertions(+), 12 deletions(-) diff --git a/compiler/vectorise/VectBuiltIn.hs b/compiler/vectorise/VectBuiltIn.hs index cbcb47d..4fe7e9e 100644 --- a/compiler/vectorise/VectBuiltIn.hs +++ b/compiler/vectorise/VectBuiltIn.hs @@ -1,9 +1,9 @@ module VectBuiltIn ( Builtins(..), sumTyCon, prodTyCon, - combinePAVar, + combinePAVar, scalarZip, closureCtrFun, initBuiltins, initBuiltinVars, initBuiltinTyCons, initBuiltinDataCons, initBuiltinPAs, initBuiltinPRs, - initBuiltinBoxedTyCons, + initBuiltinBoxedTyCons, initBuiltinScalars, primMethod, primPArray ) where @@ -14,6 +14,7 @@ import IfaceEnv ( lookupOrig ) import Module import DataCon ( DataCon, dataConName, dataConWorkId ) import TyCon ( TyCon, tyConName, tyConDataCons ) +import Class ( Class ) import Var ( Var ) import Id ( mkSysLocal ) import Name ( Name, getOccString ) @@ -48,6 +49,9 @@ mAX_DPH_SUM = 3 mAX_DPH_COMBINE :: Int mAX_DPH_COMBINE = 2 +mAX_DPH_SCALAR_ARGS :: Int +mAX_DPH_SCALAR_ARGS = 3 + data Modules = Modules { dph_PArray :: Module , dph_Repr :: Module @@ -55,6 +59,7 @@ data Modules = Modules { , dph_Unboxed :: Module , dph_Instances :: Module , dph_Combinators :: Module + , dph_Scalar :: Module , dph_Prelude_PArr :: Module , dph_Prelude_Int :: Module , dph_Prelude_Word8 :: Module @@ -71,6 +76,7 @@ dph_Modules pkg = Modules { , dph_Unboxed = mk (fsLit "Data.Array.Parallel.Lifted.Unboxed") , dph_Instances = mk (fsLit "Data.Array.Parallel.Lifted.Instances") , dph_Combinators = mk (fsLit "Data.Array.Parallel.Lifted.Combinators") + , dph_Scalar = mk (fsLit "Data.Array.Parallel.Lifted.Scalar") , dph_Prelude_PArr = mk (fsLit "Data.Array.Parallel.Prelude.Base.PArr") , dph_Prelude_Int = mk (fsLit "Data.Array.Parallel.Prelude.Base.Int") @@ -112,6 +118,9 @@ data Builtins = Builtins { , emptyPAVar :: Var , packPAVar :: Var , combinePAVars :: Array Int Var + , scalarClass :: Class + , scalarZips :: Array Int Var + , closureCtrFuns :: Array Int Var , liftingContext :: Var } @@ -131,6 +140,16 @@ combinePAVar n bi | n >= 2 && n <= mAX_DPH_COMBINE = combinePAVars bi ! n | otherwise = pprPanic "combinePAVar" (ppr n) +scalarZip :: Int -> Builtins -> Var +scalarZip n bi + | n >= 1 && n <= mAX_DPH_SCALAR_ARGS = scalarZips bi ! n + | otherwise = pprPanic "scalarZip" (ppr n) + +closureCtrFun :: Int -> Builtins -> Var +closureCtrFun n bi + | n >= 1 && n <= mAX_DPH_SCALAR_ARGS = closureCtrFuns bi ! n + | otherwise = pprPanic "closureCtrFun" (ppr n) + initBuiltins :: PackageId -> DsM Builtins initBuiltins pkg = do @@ -171,6 +190,19 @@ initBuiltins pkg | i <- [2..mAX_DPH_COMBINE]] let combinePAVars = listArray (2, mAX_DPH_COMBINE) combines + scalarClass <- externalClass dph_Scalar (fsLit "Scalar") + scalar_map <- externalVar dph_Scalar (fsLit "scalar_map") + scalar_zip2 <- externalVar dph_Scalar (fsLit "scalar_zipWith") + scalar_zips <- mapM (externalVar dph_Scalar) + [mkFastString ("scalar_zipWith" ++ show i) + | i <- [3 .. mAX_DPH_SCALAR_ARGS]] + let scalarZips = listArray (1, mAX_DPH_SCALAR_ARGS) + (scalar_map : scalar_zip2 : scalar_zips) + closures <- mapM (externalVar dph_Closure) + [mkFastString ("closure" ++ show i) + | i <- [1 .. mAX_DPH_SCALAR_ARGS]] + let closureCtrFuns = listArray (1, mAX_DPH_COMBINE) closures + liftingContext <- liftM (\u -> mkSysLocal (fsLit "lc") u intPrimTy) newUnique @@ -203,6 +235,9 @@ initBuiltins pkg , emptyPAVar = emptyPAVar , packPAVar = packPAVar , combinePAVars = combinePAVars + , scalarClass = scalarClass + , scalarZips = scalarZips + , closureCtrFuns = closureCtrFuns , liftingContext = liftingContext } where @@ -211,6 +246,7 @@ initBuiltins pkg , dph_Repr = dph_Repr , dph_Closure = dph_Closure , dph_Unboxed = dph_Unboxed + , dph_Scalar = dph_Scalar }) = dph_Modules pkg @@ -447,6 +483,91 @@ builtinBoxedTyCons :: Builtins -> [(Name, TyCon)] builtinBoxedTyCons _ = [(tyConName intPrimTyCon, intTyCon)] + +initBuiltinScalars :: Builtins -> DsM [Var] +initBuiltinScalars bi + = mapM (uncurry externalVar) (preludeScalars $ dphModules bi) + + +preludeScalars :: Modules -> [(Module, FastString)] +preludeScalars (Modules { dph_Prelude_Int = dph_Prelude_Int + , dph_Prelude_Word8 = dph_Prelude_Word8 + , dph_Prelude_Double = dph_Prelude_Double + }) + = [ + mk dph_Prelude_Int "div" + , mk dph_Prelude_Int "mod" + , mk dph_Prelude_Int "sqrt" + ] + ++ scalars_Ord dph_Prelude_Int + ++ scalars_Num dph_Prelude_Int + + ++ scalars_Ord dph_Prelude_Word8 + ++ scalars_Num dph_Prelude_Word8 + ++ + [ mk dph_Prelude_Word8 "div" + , mk dph_Prelude_Word8 "mod" + , mk dph_Prelude_Word8 "fromInt" + , mk dph_Prelude_Word8 "toInt" + ] + + ++ scalars_Ord dph_Prelude_Double + ++ scalars_Num dph_Prelude_Double + ++ scalars_Fractional dph_Prelude_Double + ++ scalars_Floating dph_Prelude_Double + ++ scalars_RealFrac dph_Prelude_Double + where + mk mod s = (mod, fsLit s) + + scalars_Ord mod = [mk mod "==" + ,mk mod "/=" + ,mk mod "<=" + ,mk mod "<" + ,mk mod ">=" + ,mk mod ">" + ,mk mod "min" + ,mk mod "max" + ] + + scalars_Num mod = [mk mod "+" + ,mk mod "-" + ,mk mod "*" + ,mk mod "negate" + ,mk mod "abs" + ] + + scalars_Fractional mod = [mk mod "/" + ,mk mod "recip" + ] + + scalars_Floating mod = [mk mod "pi" + ,mk mod "exp" + ,mk mod "sqrt" + ,mk mod "log" + ,mk mod "sin" + ,mk mod "tan" + ,mk mod "cos" + ,mk mod "asin" + ,mk mod "atan" + ,mk mod "acos" + ,mk mod "sinh" + ,mk mod "tanh" + ,mk mod "cosh" + ,mk mod "asinh" + ,mk mod "atanh" + ,mk mod "acosh" + ,mk mod "**" + ,mk mod "logBase" + ] + + scalars_RealFrac mod = [mk mod "fromInt" + ,mk mod "truncate" + ,mk mod "round" + ,mk mod "ceiling" + ,mk mod "floor" + ] + + externalVar :: Module -> FastString -> DsM Var externalVar mod fs = dsLookupGlobalId =<< lookupOrig mod (mkVarOccFS fs) @@ -461,6 +582,10 @@ externalType mod fs tycon <- externalTyCon mod fs return $ mkTyConApp tycon [] +externalClass :: Module -> FastString -> DsM Class +externalClass mod fs + = dsLookupClass =<< lookupOrig mod (mkTcOccFS fs) + unitTyConName :: Name unitTyConName = tyConName unitTyCon diff --git a/compiler/vectorise/VectMonad.hs b/compiler/vectorise/VectMonad.hs index 56f5b8f..bc120cd 100644 --- a/compiler/vectorise/VectMonad.hs +++ b/compiler/vectorise/VectMonad.hs @@ -2,14 +2,15 @@ module VectMonad ( Scope(..), VM, - noV, traceNoV, tryV, maybeV, traceMaybeV, orElseV, fixV, localV, closedV, + noV, traceNoV, ensureV, traceEnsureV, tryV, maybeV, traceMaybeV, orElseV, + onlyIfV, fixV, localV, closedV, initV, cantVectorise, maybeCantVectorise, maybeCantVectoriseM, liftDs, cloneName, cloneId, cloneVar, newExportedVar, newLocalVar, newDummyVar, newTyVar, Builtins(..), sumTyCon, prodTyCon, - combinePAVar, + combinePAVar, scalarZip, closureCtrFun, builtin, builtins, GlobalEnv(..), @@ -21,7 +22,7 @@ module VectMonad ( getBindName, inBind, - lookupVar, defGlobalVar, + lookupVar, defGlobalVar, globalScalars, lookupTyCon, defTyCon, lookupDataCon, defDataCon, lookupTyConPA, defTyConPA, defTyConPAs, @@ -30,7 +31,7 @@ module VectMonad ( lookupPrimMethod, lookupPrimPArray, lookupTyVarPA, defLocalTyVar, defLocalTyVarWithPA, localTyVars, - {-lookupInst,-} lookupFamInst + lookupInst, lookupFamInst ) where #include "HsVersions.h" @@ -40,10 +41,12 @@ import VectBuiltIn import HscTypes hiding ( MonadThings(..) ) import Module ( PackageId ) import CoreSyn +import Class import TyCon import DataCon import Type import Var +import VarSet import VarEnv import Id import Name @@ -71,6 +74,10 @@ data GlobalEnv = GlobalEnv { -- global_vars :: VarEnv Var + -- Purely scalar variables. Code which mentions only these + -- variables doesn't have to be lifted. + , global_scalars :: VarSet + -- Exported variables which have a vectorised version -- , global_exported_vars :: VarEnv (Var, Var) @@ -130,6 +137,7 @@ initGlobalEnv :: VectInfo -> (InstEnv, InstEnv) -> FamInstEnvs -> GlobalEnv initGlobalEnv info instEnvs famInstEnvs = GlobalEnv { global_vars = mapVarEnv snd $ vectInfoVar info + , global_scalars = emptyVarSet , global_exported_vars = emptyVarEnv , global_tycons = mapNameEnv snd $ vectInfoTyCon info , global_datacons = mapNameEnv snd $ vectInfoDataCon info @@ -145,6 +153,10 @@ extendImportedVarsEnv :: [(Var, Var)] -> GlobalEnv -> GlobalEnv extendImportedVarsEnv ps genv = genv { global_vars = extendVarEnvList (global_vars genv) ps } +extendScalars :: [Var] -> GlobalEnv -> GlobalEnv +extendScalars vs genv + = genv { global_scalars = extendVarSetList (global_scalars genv) vs } + setFamInstEnv :: FamInstEnv -> GlobalEnv -> GlobalEnv setFamInstEnv l_fam_inst genv = genv { global_fam_inst_env = (g_fam_inst, l_fam_inst) } @@ -231,6 +243,17 @@ noV = VM $ \_ _ _ -> return No traceNoV :: String -> SDoc -> VM a traceNoV s d = pprTrace s d noV +ensureV :: Bool -> VM () +ensureV False = noV +ensureV True = return () + +onlyIfV :: Bool -> VM a -> VM a +onlyIfV b p = ensureV b >> p + +traceEnsureV :: String -> SDoc -> Bool -> VM () +traceEnsureV s d False = traceNoV s d +traceEnsureV s d True = return () + tryV :: VM a -> VM (Maybe a) tryV (VM p) = VM $ \bi genv lenv -> do @@ -301,10 +324,8 @@ setLEnv lenv = VM $ \_ genv _ -> return (Yes genv lenv ()) updLEnv :: (LocalEnv -> LocalEnv) -> VM () updLEnv f = VM $ \_ genv lenv -> return (Yes genv (f lenv) ()) -{- getInstEnv :: VM (InstEnv, InstEnv) getInstEnv = readGEnv global_inst_env --} getFamInstEnv :: VM FamInstEnvs getFamInstEnv = readGEnv global_fam_inst_env @@ -382,6 +403,9 @@ lookupVar v . maybeCantVectoriseM "Variable not vectorised:" (ppr v) . readGEnv $ \env -> lookupVarEnv (global_vars env) v +globalScalars :: VM VarSet +globalScalars = readGEnv global_scalars + lookupTyCon :: TyCon -> VM (Maybe TyCon) lookupTyCon tc | isUnLiftedTyCon tc || isTupleTyCon tc = return (Just tc) @@ -453,7 +477,6 @@ localTyVars = readLEnv (reverse . local_tyvars) -- instances head (i.e., no flexi vars); for details for what this means, -- see the docs at InstEnv.lookupInstEnv. -- -{- lookupInst :: Class -> [Type] -> VM (DFunId, [Type]) lookupInst cls tys = do { instEnv <- getInstEnv @@ -465,12 +488,12 @@ lookupInst cls tys where inst_tys' = [ty | Right ty <- inst_tys] noFlexiVar = all isRight inst_tys - _other -> traceNoV "lookupInst" (ppr cls <+> ppr tys) + _other -> + pprPanic "VectMonad.lookupInst: not found " (ppr cls <+> ppr tys) } where isRight (Left _) = False isRight (Right _) = True --} -- Look up the representation tycon of a family instance. -- @@ -520,12 +543,14 @@ initV pkg hsc_env guts info p builtin_pas <- initBuiltinPAs builtins builtin_prs <- initBuiltinPRs builtins builtin_boxed <- initBuiltinBoxedTyCons builtins + builtin_scalars <- initBuiltinScalars builtins eps <- liftIO $ hscEPS hsc_env let famInstEnvs = (eps_fam_inst_env eps, mg_fam_inst_env guts) instEnvs = (eps_inst_env eps, mg_inst_env guts) let genv = extendImportedVarsEnv builtin_vars + . extendScalars builtin_scalars . extendTyConsEnv builtin_tycons . extendDataConsEnv builtin_datacons . extendPAFunsEnv builtin_pas diff --git a/compiler/vectorise/VectUtils.hs b/compiler/vectorise/VectUtils.hs index 7aef39b..5c01461 100644 --- a/compiler/vectorise/VectUtils.hs +++ b/compiler/vectorise/VectUtils.hs @@ -12,6 +12,7 @@ module VectUtils ( prDFunOfTyCon, paDictArgType, paDictOfType, paDFunType, paMethod, mkPR, lengthPA, replicatePA, emptyPA, packPA, combinePA, liftPA, + zipScalars, scalarClosure, polyAbstract, polyApply, polyVApply, hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted, buildClosure, buildClosures, @@ -270,6 +271,24 @@ liftPA x lc <- builtin liftingContext replicatePA (Var lc) x +zipScalars :: [Type] -> Type -> VM CoreExpr +zipScalars arg_tys res_ty + = do + scalar <- builtin scalarClass + (dfuns, _) <- mapAndUnzipM (\ty -> lookupInst scalar [ty]) ty_args + zipf <- builtin (scalarZip $ length arg_tys) + return $ Var zipf `mkTyApps` ty_args `mkApps` map Var dfuns + where + ty_args = arg_tys ++ [res_ty] + +scalarClosure :: [Type] -> Type -> CoreExpr -> CoreExpr -> VM CoreExpr +scalarClosure arg_tys res_ty scalar_fun array_fun + = do + ctr <- builtin (closureCtrFun $ length arg_tys) + pas <- mapM paDictOfType (init arg_tys) + return $ Var ctr `mkTyApps` (arg_tys ++ [res_ty]) + `mkApps` (pas ++ [scalar_fun, array_fun]) + newLocalVVar :: FastString -> Type -> VM VVar newLocalVVar fs vty = do diff --git a/compiler/vectorise/Vectorise.hs b/compiler/vectorise/Vectorise.hs index cd1f429..bee160c 100644 --- a/compiler/vectorise/Vectorise.hs +++ b/compiler/vectorise/Vectorise.hs @@ -264,12 +264,44 @@ vectExpr (_, AnnLet (AnnRec bs) body) $ vectExpr rhs vectExpr e@(fvs, AnnLam bndr _) - | isId bndr = vectLam fvs bs body + | isId bndr = onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body) + `orElseV` vectLam fvs bs body where (bs,body) = collectAnnValBinders e vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e) +vectScalarLam :: [Var] -> CoreExpr -> VM VExpr +vectScalarLam args body + = do + scalars <- globalScalars + onlyIfV (all is_scalar_ty arg_tys + && is_scalar_ty res_ty + && is_scalar (extendVarSetList scalars args) body) + $ do + fn_var <- hoistExpr (fsLit "fn") (mkLams args body) + zipf <- zipScalars arg_tys res_ty + clo <- scalarClosure arg_tys res_ty (Var fn_var) + (zipf `App` Var fn_var) + clo_var <- hoistExpr (fsLit "clo") clo + lclo <- liftPA (Var clo_var) + return (Var clo_var, lclo) + where + arg_tys = map idType args + res_ty = exprType body + + is_scalar_ty ty | Just (tycon, []) <- splitTyConApp_maybe ty + = tycon == intTyCon + || tycon == floatTyCon + || tycon == doubleTyCon + + | otherwise = False + + is_scalar vs (Var v) = v `elemVarSet` vs + is_scalar _ e@(Lit l) = is_scalar_ty $ exprType e + is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2 + is_scalar _ _ = False + vectLam :: VarSet -> [Var] -> CoreExprWithFVs -> VM VExpr vectLam fvs bs body = do -- 1.7.10.4