Try not to avoid vectorising purely scalar functions
authorRoman Leshchinskiy <rl@cse.unsw.edu.au>
Fri, 6 Mar 2009 11:55:08 +0000 (11:55 +0000)
committerRoman Leshchinskiy <rl@cse.unsw.edu.au>
Fri, 6 Mar 2009 11:55:08 +0000 (11:55 +0000)
compiler/vectorise/VectBuiltIn.hs
compiler/vectorise/VectMonad.hs
compiler/vectorise/VectUtils.hs
compiler/vectorise/Vectorise.hs

index cbcb47d..4fe7e9e 100644 (file)
@@ -1,9 +1,9 @@
 module VectBuiltIn (
   Builtins(..), sumTyCon, prodTyCon,
-  combinePAVar,
+  combinePAVar, scalarZip, closureCtrFun,
   initBuiltins, initBuiltinVars, initBuiltinTyCons, initBuiltinDataCons,
   initBuiltinPAs, initBuiltinPRs,
-  initBuiltinBoxedTyCons,
+  initBuiltinBoxedTyCons, initBuiltinScalars,
 
   primMethod, primPArray
 ) where
@@ -14,6 +14,7 @@ import IfaceEnv        ( lookupOrig )
 import Module
 import DataCon         ( DataCon, dataConName, dataConWorkId )
 import TyCon           ( TyCon, tyConName, tyConDataCons )
+import Class           ( Class )
 import Var             ( Var )
 import Id              ( mkSysLocal )
 import Name            ( Name, getOccString )
@@ -48,6 +49,9 @@ mAX_DPH_SUM = 3
 mAX_DPH_COMBINE :: Int
 mAX_DPH_COMBINE = 2
 
+mAX_DPH_SCALAR_ARGS :: Int
+mAX_DPH_SCALAR_ARGS = 3
+
 data Modules = Modules {
                    dph_PArray :: Module
                  , dph_Repr :: Module
@@ -55,6 +59,7 @@ data Modules = Modules {
                  , dph_Unboxed :: Module
                  , dph_Instances :: Module
                  , dph_Combinators :: Module
+                 , dph_Scalar :: Module
                  , dph_Prelude_PArr :: Module
                  , dph_Prelude_Int :: Module
                  , dph_Prelude_Word8 :: Module
@@ -71,6 +76,7 @@ dph_Modules pkg = Modules {
   , dph_Unboxed        = mk (fsLit "Data.Array.Parallel.Lifted.Unboxed")
   , dph_Instances      = mk (fsLit "Data.Array.Parallel.Lifted.Instances")
   , dph_Combinators    = mk (fsLit "Data.Array.Parallel.Lifted.Combinators")
+  , dph_Scalar         = mk (fsLit "Data.Array.Parallel.Lifted.Scalar")
 
   , dph_Prelude_PArr   = mk (fsLit "Data.Array.Parallel.Prelude.Base.PArr")
   , dph_Prelude_Int    = mk (fsLit "Data.Array.Parallel.Prelude.Base.Int")
@@ -112,6 +118,9 @@ data Builtins = Builtins {
                 , emptyPAVar       :: Var
                 , packPAVar        :: Var
                 , combinePAVars    :: Array Int Var
+                , scalarClass      :: Class
+                , scalarZips       :: Array Int Var
+                , closureCtrFuns   :: Array Int Var
                 , liftingContext   :: Var
                 }
 
@@ -131,6 +140,16 @@ combinePAVar n bi
   | n >= 2 && n <= mAX_DPH_COMBINE = combinePAVars bi ! n
   | otherwise = pprPanic "combinePAVar" (ppr n)
 
+scalarZip :: Int -> Builtins -> Var
+scalarZip n bi
+  | n >= 1 && n <= mAX_DPH_SCALAR_ARGS = scalarZips bi ! n
+  | otherwise = pprPanic "scalarZip" (ppr n)
+
+closureCtrFun :: Int -> Builtins -> Var
+closureCtrFun n bi
+  | n >= 1 && n <= mAX_DPH_SCALAR_ARGS = closureCtrFuns bi ! n
+  | otherwise = pprPanic "closureCtrFun" (ppr n)
+
 initBuiltins :: PackageId -> DsM Builtins
 initBuiltins pkg
   = do
@@ -171,6 +190,19 @@ initBuiltins pkg
                           | i <- [2..mAX_DPH_COMBINE]]
       let combinePAVars = listArray (2, mAX_DPH_COMBINE) combines
 
+      scalarClass <- externalClass dph_Scalar (fsLit "Scalar")
+      scalar_map <- externalVar dph_Scalar (fsLit "scalar_map")
+      scalar_zip2 <- externalVar dph_Scalar (fsLit "scalar_zipWith")
+      scalar_zips <- mapM (externalVar dph_Scalar)
+                          [mkFastString ("scalar_zipWith" ++ show i)
+                             | i <- [3 .. mAX_DPH_SCALAR_ARGS]]
+      let scalarZips = listArray (1, mAX_DPH_SCALAR_ARGS)
+                                 (scalar_map : scalar_zip2 : scalar_zips)
+      closures <- mapM (externalVar dph_Closure)
+                       [mkFastString ("closure" ++ show i)
+                          | i <- [1 .. mAX_DPH_SCALAR_ARGS]]
+      let closureCtrFuns = listArray (1, mAX_DPH_COMBINE) closures
+
       liftingContext <- liftM (\u -> mkSysLocal (fsLit "lc") u intPrimTy)
                               newUnique
 
@@ -203,6 +235,9 @@ initBuiltins pkg
                , emptyPAVar       = emptyPAVar
                , packPAVar        = packPAVar
                , combinePAVars    = combinePAVars
+               , scalarClass      = scalarClass
+               , scalarZips       = scalarZips
+               , closureCtrFuns   = closureCtrFuns
                , liftingContext   = liftingContext
                }
   where
@@ -211,6 +246,7 @@ initBuiltins pkg
              , dph_Repr           = dph_Repr
              , dph_Closure        = dph_Closure
              , dph_Unboxed        = dph_Unboxed
+             , dph_Scalar         = dph_Scalar
              })
       = dph_Modules pkg
 
@@ -447,6 +483,91 @@ builtinBoxedTyCons :: Builtins -> [(Name, TyCon)]
 builtinBoxedTyCons _ =
   [(tyConName intPrimTyCon, intTyCon)]
 
+
+initBuiltinScalars :: Builtins -> DsM [Var]
+initBuiltinScalars bi
+  = mapM (uncurry externalVar) (preludeScalars $ dphModules bi)
+
+
+preludeScalars :: Modules -> [(Module, FastString)]
+preludeScalars (Modules { dph_Prelude_Int    = dph_Prelude_Int
+                        , dph_Prelude_Word8  = dph_Prelude_Word8
+                        , dph_Prelude_Double = dph_Prelude_Double
+                        })
+  = [
+      mk dph_Prelude_Int "div"
+    , mk dph_Prelude_Int "mod"
+    , mk dph_Prelude_Int "sqrt"
+    ]
+    ++ scalars_Ord dph_Prelude_Int
+    ++ scalars_Num dph_Prelude_Int
+
+    ++ scalars_Ord dph_Prelude_Word8
+    ++ scalars_Num dph_Prelude_Word8
+    ++
+    [ mk dph_Prelude_Word8 "div"
+    , mk dph_Prelude_Word8 "mod"
+    , mk dph_Prelude_Word8 "fromInt"
+    , mk dph_Prelude_Word8 "toInt"
+    ]
+
+    ++ scalars_Ord dph_Prelude_Double
+    ++ scalars_Num dph_Prelude_Double
+    ++ scalars_Fractional dph_Prelude_Double
+    ++ scalars_Floating dph_Prelude_Double
+    ++ scalars_RealFrac dph_Prelude_Double
+  where
+    mk mod s = (mod, fsLit s)
+
+    scalars_Ord mod = [mk mod "=="
+                      ,mk mod "/="
+                      ,mk mod "<="
+                      ,mk mod "<"
+                      ,mk mod ">="
+                      ,mk mod ">"
+                      ,mk mod "min"
+                      ,mk mod "max"
+                      ]
+
+    scalars_Num mod = [mk mod "+"
+                      ,mk mod "-"
+                      ,mk mod "*"
+                      ,mk mod "negate"
+                      ,mk mod "abs"
+                      ]
+
+    scalars_Fractional mod = [mk mod "/"
+                             ,mk mod "recip"
+                             ]
+
+    scalars_Floating mod = [mk mod "pi"
+                           ,mk mod "exp"
+                           ,mk mod "sqrt"
+                           ,mk mod "log"
+                           ,mk mod "sin"
+                           ,mk mod "tan"
+                           ,mk mod "cos"
+                           ,mk mod "asin"
+                           ,mk mod "atan"
+                           ,mk mod "acos"
+                           ,mk mod "sinh"
+                           ,mk mod "tanh"
+                           ,mk mod "cosh"
+                           ,mk mod "asinh"
+                           ,mk mod "atanh"
+                           ,mk mod "acosh"
+                           ,mk mod "**"
+                           ,mk mod "logBase"
+                           ]
+
+    scalars_RealFrac mod = [mk mod "fromInt"
+                           ,mk mod "truncate"
+                           ,mk mod "round"
+                           ,mk mod "ceiling"
+                           ,mk mod "floor"
+                           ]
+
+
 externalVar :: Module -> FastString -> DsM Var
 externalVar mod fs
   = dsLookupGlobalId =<< lookupOrig mod (mkVarOccFS fs)
@@ -461,6 +582,10 @@ externalType mod fs
       tycon <- externalTyCon mod fs
       return $ mkTyConApp tycon []
 
+externalClass :: Module -> FastString -> DsM Class
+externalClass mod fs
+  = dsLookupClass =<< lookupOrig mod (mkTcOccFS fs)
+
 unitTyConName :: Name
 unitTyConName = tyConName unitTyCon
 
index 56f5b8f..bc120cd 100644 (file)
@@ -2,14 +2,15 @@ module VectMonad (
   Scope(..),
   VM,
 
-  noV, traceNoV, tryV, maybeV, traceMaybeV, orElseV, fixV, localV, closedV,
+  noV, traceNoV, ensureV, traceEnsureV, tryV, maybeV, traceMaybeV, orElseV,
+  onlyIfV, fixV, localV, closedV,
   initV, cantVectorise, maybeCantVectorise, maybeCantVectoriseM,
   liftDs,
   cloneName, cloneId, cloneVar,
   newExportedVar, newLocalVar, newDummyVar, newTyVar,
   
   Builtins(..), sumTyCon, prodTyCon,
-  combinePAVar,
+  combinePAVar, scalarZip, closureCtrFun,
   builtin, builtins,
 
   GlobalEnv(..),
@@ -21,7 +22,7 @@ module VectMonad (
 
   getBindName, inBind,
 
-  lookupVar, defGlobalVar,
+  lookupVar, defGlobalVar, globalScalars,
   lookupTyCon, defTyCon,
   lookupDataCon, defDataCon,
   lookupTyConPA, defTyConPA, defTyConPAs,
@@ -30,7 +31,7 @@ module VectMonad (
   lookupPrimMethod, lookupPrimPArray,
   lookupTyVarPA, defLocalTyVar, defLocalTyVarWithPA, localTyVars,
 
-  {-lookupInst,-} lookupFamInst
+  lookupInst, lookupFamInst
 ) where
 
 #include "HsVersions.h"
@@ -40,10 +41,12 @@ import VectBuiltIn
 import HscTypes hiding  ( MonadThings(..) )
 import Module           ( PackageId )
 import CoreSyn
+import Class
 import TyCon
 import DataCon
 import Type
 import Var
+import VarSet
 import VarEnv
 import Id
 import Name
@@ -71,6 +74,10 @@ data GlobalEnv = GlobalEnv {
                   -- 
                   global_vars :: VarEnv Var
 
+                  -- Purely scalar variables. Code which mentions only these
+                  -- variables doesn't have to be lifted.
+                , global_scalars :: VarSet
+
                   -- Exported variables which have a vectorised version
                   --
                 , global_exported_vars :: VarEnv (Var, Var)
@@ -130,6 +137,7 @@ initGlobalEnv :: VectInfo -> (InstEnv, InstEnv) -> FamInstEnvs -> GlobalEnv
 initGlobalEnv info instEnvs famInstEnvs
   = GlobalEnv {
       global_vars          = mapVarEnv snd $ vectInfoVar info
+    , global_scalars   = emptyVarSet
     , global_exported_vars = emptyVarEnv
     , global_tycons        = mapNameEnv snd $ vectInfoTyCon info
     , global_datacons      = mapNameEnv snd $ vectInfoDataCon info
@@ -145,6 +153,10 @@ extendImportedVarsEnv :: [(Var, Var)] -> GlobalEnv -> GlobalEnv
 extendImportedVarsEnv ps genv
   = genv { global_vars = extendVarEnvList (global_vars genv) ps }
 
+extendScalars :: [Var] -> GlobalEnv -> GlobalEnv
+extendScalars vs genv
+  = genv { global_scalars = extendVarSetList (global_scalars genv) vs }
+
 setFamInstEnv :: FamInstEnv -> GlobalEnv -> GlobalEnv
 setFamInstEnv l_fam_inst genv
   = genv { global_fam_inst_env = (g_fam_inst, l_fam_inst) }
@@ -231,6 +243,17 @@ noV = VM $ \_ _ _ -> return No
 traceNoV :: String -> SDoc -> VM a
 traceNoV s d = pprTrace s d noV
 
+ensureV :: Bool -> VM ()
+ensureV False = noV
+ensureV True  = return ()
+
+onlyIfV :: Bool -> VM a -> VM a
+onlyIfV b p = ensureV b >> p
+
+traceEnsureV :: String -> SDoc -> Bool -> VM ()
+traceEnsureV s d False = traceNoV s d
+traceEnsureV s d True  = return ()
+
 tryV :: VM a -> VM (Maybe a)
 tryV (VM p) = VM $ \bi genv lenv ->
   do
@@ -301,10 +324,8 @@ setLEnv lenv = VM $ \_ genv _ -> return (Yes genv lenv ())
 updLEnv :: (LocalEnv -> LocalEnv) -> VM ()
 updLEnv f = VM $ \_ genv lenv -> return (Yes genv (f lenv) ())
 
-{-
 getInstEnv :: VM (InstEnv, InstEnv)
 getInstEnv = readGEnv global_inst_env
--}
 
 getFamInstEnv :: VM FamInstEnvs
 getFamInstEnv = readGEnv global_fam_inst_env
@@ -382,6 +403,9 @@ lookupVar v
                 . maybeCantVectoriseM "Variable not vectorised:" (ppr v)
                 . readGEnv $ \env -> lookupVarEnv (global_vars env) v
 
+globalScalars :: VM VarSet
+globalScalars = readGEnv global_scalars
+
 lookupTyCon :: TyCon -> VM (Maybe TyCon)
 lookupTyCon tc
   | isUnLiftedTyCon tc || isTupleTyCon tc = return (Just tc)
@@ -453,7 +477,6 @@ localTyVars = readLEnv (reverse . local_tyvars)
 -- instances head (i.e., no flexi vars); for details for what this means,
 -- see the docs at InstEnv.lookupInstEnv.
 --
-{-
 lookupInst :: Class -> [Type] -> VM (DFunId, [Type])
 lookupInst cls tys
   = do { instEnv <- getInstEnv
@@ -465,12 +488,12 @@ lookupInst cls tys
              where
                inst_tys'  = [ty | Right ty <- inst_tys]
                noFlexiVar = all isRight inst_tys
-          _other         -> traceNoV "lookupInst" (ppr cls <+> ppr tys)
+          _other         ->
+             pprPanic "VectMonad.lookupInst: not found " (ppr cls <+> ppr tys)
        }
   where
     isRight (Left  _) = False
     isRight (Right _) = True
--}
 
 -- Look up the representation tycon of a family instance.
 --
@@ -520,12 +543,14 @@ initV pkg hsc_env guts info p
         builtin_pas    <- initBuiltinPAs builtins
         builtin_prs    <- initBuiltinPRs builtins
         builtin_boxed  <- initBuiltinBoxedTyCons builtins
+        builtin_scalars <- initBuiltinScalars builtins
 
         eps <- liftIO $ hscEPS hsc_env
         let famInstEnvs = (eps_fam_inst_env eps, mg_fam_inst_env guts)
             instEnvs    = (eps_inst_env     eps, mg_inst_env     guts)
 
         let genv = extendImportedVarsEnv builtin_vars
+                 . extendScalars builtin_scalars
                  . extendTyConsEnv builtin_tycons
                  . extendDataConsEnv builtin_datacons
                  . extendPAFunsEnv builtin_pas
index 7aef39b..5c01461 100644 (file)
@@ -12,6 +12,7 @@ module VectUtils (
   prDFunOfTyCon,
   paDictArgType, paDictOfType, paDFunType,
   paMethod, mkPR, lengthPA, replicatePA, emptyPA, packPA, combinePA, liftPA,
+  zipScalars, scalarClosure,
   polyAbstract, polyApply, polyVApply,
   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
   buildClosure, buildClosures,
@@ -270,6 +271,24 @@ liftPA x
       lc <- builtin liftingContext
       replicatePA (Var lc) x
 
+zipScalars :: [Type] -> Type -> VM CoreExpr
+zipScalars arg_tys res_ty
+  = do
+      scalar <- builtin scalarClass
+      (dfuns, _) <- mapAndUnzipM (\ty -> lookupInst scalar [ty]) ty_args
+      zipf <- builtin (scalarZip $ length arg_tys)
+      return $ Var zipf `mkTyApps` ty_args `mkApps` map Var dfuns
+    where
+      ty_args = arg_tys ++ [res_ty]
+
+scalarClosure :: [Type] -> Type -> CoreExpr -> CoreExpr -> VM CoreExpr
+scalarClosure arg_tys res_ty scalar_fun array_fun
+  = do
+      ctr <- builtin (closureCtrFun $ length arg_tys)
+      pas <- mapM paDictOfType (init arg_tys)
+      return $ Var ctr `mkTyApps` (arg_tys ++ [res_ty])
+                       `mkApps`   (pas ++ [scalar_fun, array_fun])
+
 newLocalVVar :: FastString -> Type -> VM VVar
 newLocalVVar fs vty
   = do
index cd1f429..bee160c 100644 (file)
@@ -264,12 +264,44 @@ vectExpr (_, AnnLet (AnnRec bs) body)
                       $ vectExpr rhs
 
 vectExpr e@(fvs, AnnLam bndr _)
-  | isId bndr = vectLam fvs bs body
+  | isId bndr = onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
+                `orElseV` vectLam fvs bs body
   where
     (bs,body) = collectAnnValBinders e
 
 vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e)
 
+vectScalarLam :: [Var] -> CoreExpr -> VM VExpr
+vectScalarLam args body
+  = do
+      scalars <- globalScalars
+      onlyIfV (all is_scalar_ty arg_tys
+               && is_scalar_ty res_ty
+               && is_scalar (extendVarSetList scalars args) body)
+        $ do
+            fn_var <- hoistExpr (fsLit "fn") (mkLams args body)
+            zipf <- zipScalars arg_tys res_ty
+            clo <- scalarClosure arg_tys res_ty (Var fn_var)
+                                                (zipf `App` Var fn_var)
+            clo_var <- hoistExpr (fsLit "clo") clo
+            lclo <- liftPA (Var clo_var)
+            return (Var clo_var, lclo)
+  where
+    arg_tys = map idType args
+    res_ty  = exprType body
+
+    is_scalar_ty ty | Just (tycon, []) <- splitTyConApp_maybe ty
+                    = tycon == intTyCon
+                      || tycon == floatTyCon
+                      || tycon == doubleTyCon
+
+                    | otherwise = False
+
+    is_scalar vs (Var v)     = v `elemVarSet` vs
+    is_scalar _ e@(Lit l)    = is_scalar_ty $ exprType e
+    is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2
+    is_scalar _ _            = False
+
 vectLam :: VarSet -> [Var] -> CoreExprWithFVs -> VM VExpr
 vectLam fvs bs body
   = do