-- | Vectorisation of expressions.
-module Vectorise.Exp
- (vectPolyExpr)
-where
-import VectUtils
-import VectType
-import Vectorise.Utils.Closure
-import Vectorise.Utils.Hoisting
+module Vectorise.Exp (
+
+ -- Vectorise a polymorphic expression
+ vectPolyExpr,
+
+ -- Vectorise a scalar expression of functional type
+ vectScalarFun
+) where
+
+#include "HsVersions.h"
+
+import Vectorise.Type.Type
import Vectorise.Var
import Vectorise.Vect
import Vectorise.Env
import Vectorise.Monad
import Vectorise.Builtins
+import Vectorise.Utils
import CoreSyn
import CoreUtils
import VarEnv
import VarSet
import Id
-import BasicTypes
+import BasicTypes( isLoopBreaker )
import Literal
import TysWiredIn
import TysPrim
-- | Vectorise a polymorphic expression.
-vectPolyExpr
- :: Bool -- ^ When vectorising the RHS of a binding, whether that
- -- binding is a loop breaker.
- -> CoreExprWithFVs
- -> VM (Inline, VExpr)
-
-vectPolyExpr loop_breaker (_, AnnNote note expr)
- = do (inline, expr') <- vectPolyExpr loop_breaker expr
- return (inline, vNote note expr')
-
-vectPolyExpr loop_breaker expr
+--
+vectPolyExpr :: Bool -- ^ When vectorising the RHS of a binding, whether that
+ -- binding is a loop breaker.
+ -> [Var]
+ -> CoreExprWithFVs
+ -> VM (Inline, Bool, VExpr)
+vectPolyExpr loop_breaker recFns (_, AnnNote note expr)
+ = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr
+ return (inline, isScalarFn, vNote note expr')
+vectPolyExpr loop_breaker recFns expr
= do
arity <- polyArity tvs
polyAbstract tvs $ \args ->
do
- (inline, mono') <- vectFnExpr False loop_breaker mono
- return (addInlineArity inline arity,
+ (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono
+ return (addInlineArity inline arity, isScalarFn,
mapVect (mkLams $ tvs ++ args) mono')
where
(tvs, mono) = collectAnnTypeBinders expr
| Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
, isAlgTyCon tycon
= vectAlgCase tycon ty_args scrut bndr ty alts
+ | otherwise = cantVectorise "Can't vectorise expression" (ppr scrut_ty)
where
scrut_ty = exprType (deAnnotate scrut)
vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
= do
- vrhs <- localV . inBind bndr . liftM snd $ vectPolyExpr False rhs
+ vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs
(vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
return $ vLet (vNonRec vbndr vrhs) vbody
vect_rhs bndr rhs = localV
. inBind bndr
- . liftM snd
- $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs
+ . liftM (\(_,_,z)->z)
+ $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) [] rhs
vectExpr e@(_, AnnLam bndr _)
- | isId bndr = liftM snd $ vectFnExpr True False e
+ | isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False [] e
{-
onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
`orElseV` vectLam True fvs bs body
(bs,body) = collectAnnValBinders e
-}
-vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e)
-
+vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnnotate e)
-- | Vectorise an expression with an outer lambda abstraction.
-vectFnExpr
- :: Bool -- ^ When the RHS of a binding, whether that binding should be inlined.
- -> Bool -- ^ Whether the binding is a loop breaker.
- -> CoreExprWithFVs -- ^ Expression to vectorise. Must have an outer `AnnLam`.
- -> VM (Inline, VExpr)
-
-vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _)
- | isId bndr = onlyIfV (isEmptyVarSet fvs)
- (mark DontInline . vectScalarLam bs $ deAnnotate body)
- `orElseV` mark inlineMe (vectLam inline loop_breaker fvs bs body)
- where
- (bs,body) = collectAnnValBinders e
-
-vectFnExpr _ _ e = mark DontInline $ vectExpr e
-
-mark :: Inline -> VM a -> VM (Inline, a)
-mark b p = do { x <- p; return (b,x) }
-
-
--- | Vectorise a function where are the args have scalar type,
--- that is Int, Float, Double etc.
-vectScalarLam
- :: [Var] -- ^ Bound variables of function.
- -> CoreExpr -- ^ Function body.
- -> VM VExpr
-
-vectScalarLam args body
- = do scalars <- globalScalars
- onlyIfV (all is_scalar_ty arg_tys
- && is_scalar_ty res_ty
- && is_scalar (extendVarSetList scalars args) body
- && uses scalars body)
- $ do
- fn_var <- hoistExpr (fsLit "fn") (mkLams args body) DontInline
- zipf <- zipScalars arg_tys res_ty
- clo <- scalarClosure arg_tys res_ty (Var fn_var)
- (zipf `App` Var fn_var)
- clo_var <- hoistExpr (fsLit "clo") clo DontInline
- lclo <- liftPD (Var clo_var)
- return (Var clo_var, lclo)
+--
+vectFnExpr :: Bool -- ^ If we process the RHS of a binding, whether that binding should
+ -- be inlined
+ -> Bool -- ^ Whether the binding is a loop breaker
+ -> [Var] -- ^ Names of function in same recursive binding group
+ -> CoreExprWithFVs -- ^ Expression to vectorise; must have an outer `AnnLam`
+ -> VM (Inline, Bool, VExpr)
+vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr _)
+ | isId bndr = mark DontInline True (vectScalarFun False recFns (deAnnotate expr))
+ `orElseV`
+ mark inlineMe False (vectLam inline loop_breaker expr)
+vectFnExpr _ _ _ e = mark DontInline False $ vectExpr e
+
+mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
+mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
+
+-- |Vectorise an expression of functional type, where all arguments and the result are of scalar
+-- type (i.e., 'Int', 'Float', 'Double' etc.) and which does not contain any subcomputations that
+-- involve parallel arrays. Such functionals do not requires the full blown vectorisation
+-- transformation; instead, they can be lifted by application of a member of the zipWith family
+-- (i.e., 'map', 'zipWith', zipWith3', etc.)
+--
+vectScalarFun :: Bool -- ^ Was the function marked as scalar by the user?
+ -> [Var] -- ^ Functions names in same recursive binding group
+ -> CoreExpr -- ^ Expression to be vectorised
+ -> VM VExpr
+vectScalarFun forceScalar recFns expr
+ = do { gscalars <- globalScalars
+ ; let scalars = gscalars `extendVarSetList` recFns
+ (arg_tys, res_ty) = splitFunTys (exprType expr)
+ ; MASSERT( not $ null arg_tys )
+ ; onlyIfV (forceScalar -- user asserts the functions is scalar
+ ||
+ all is_prim_ty arg_tys -- check whether the function is scalar
+ && is_prim_ty res_ty
+ && is_scalar scalars expr
+ && uses scalars expr)
+ $ mkScalarFun arg_tys res_ty expr
+ }
where
- arg_tys = map idType args
- res_ty = exprType body
-
- is_scalar_ty ty
+ -- FIXME: This is woefully insufficient!!! We need a scalar pragma for types!!!
+ is_prim_ty ty
| Just (tycon, []) <- splitTyConApp_maybe ty
= tycon == intTyCon
|| tycon == floatTyCon
|| tycon == doubleTyCon
-
| otherwise = False
- is_scalar vs (Var v) = v `elemVarSet` vs
- is_scalar _ e@(Lit _) = is_scalar_ty $ exprType e
- is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2
- is_scalar _ _ = False
-
+ -- Checks whether an expression contain a non-scalar subexpression.
+ --
+ -- Precodition: The variables in the first argument are scalar.
+ --
+ -- In case of a recursive binding group, we /assume/ that all bindings are scalar (by adding
+ -- them to the list of scalar variables) and then check them. If one of them turns out not to
+ -- be scalar, the entire group is regarded as not being scalar.
+ --
+ -- FIXME: Currently, doesn't regard external (non-data constructor) variable and anonymous
+ -- data constructor as scalar. Should be changed once scalar types are passed
+ -- through VectInfo.
+ --
+ is_scalar :: VarSet -> CoreExpr -> Bool
+ is_scalar scalars (Var v) = v `elemVarSet` scalars
+ is_scalar _scalars (Lit _) = True
+ is_scalar scalars e@(App e1 e2)
+ | maybe_parr_ty (exprType e) = False
+ | otherwise = is_scalar scalars e1 && is_scalar scalars e2
+ is_scalar scalars (Lam var body)
+ | maybe_parr_ty (varType var) = False
+ | otherwise = is_scalar (scalars `extendVarSet` var) body
+ is_scalar scalars (Let bind body) = bindsAreScalar && is_scalar scalars' body
+ where
+ (bindsAreScalar, scalars') = is_scalar_bind scalars bind
+ is_scalar scalars (Case e var ty alts)
+ | is_prim_ty ty = is_scalar scalars' e && all (is_scalar_alt scalars') alts
+ | otherwise = False
+ where
+ scalars' = scalars `extendVarSet` var
+ is_scalar scalars (Cast e _coe) = is_scalar scalars e
+ is_scalar scalars (Note _ e ) = is_scalar scalars e
+ is_scalar _scalars (Type _) = True
+
+ -- Result: (<is this binding group scalar>, scalars ++ variables bound in this group)
+ is_scalar_bind scalars (NonRec var e) = (is_scalar scalars e, scalars `extendVarSet` var)
+ is_scalar_bind scalars (Rec bnds) = (all (is_scalar scalars') es, scalars')
+ where
+ (vars, es) = unzip bnds
+ scalars' = scalars `extendVarSetList` vars
+
+ is_scalar_alt scalars (_, vars, e) = is_scalar (scalars `extendVarSetList ` vars) e
+
+ -- Checks whether the type might be a parallel array type. In particular, if the outermost
+ -- constructor is a type family, we conservatively assume that it may be a parallel array type.
+ maybe_parr_ty :: Type -> Bool
+ maybe_parr_ty ty
+ | Just ty' <- coreView ty = maybe_parr_ty ty'
+ | Just (tyCon, _) <- splitTyConApp_maybe ty = isPArrTyCon tyCon || isSynFamilyTyCon tyCon
+ maybe_parr_ty _ = False
+
+ -- FIXME: I'm not convinced that this reasoning is (always) sound. If the identify functions
+ -- is called by some other function that is otherwise scalar, it would be very bad
+ -- that just this call to the identity makes it not be scalar.
-- A scalar function has to actually compute something. Without the check,
-- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
-- (map (\x -> x)) which is very bad. Normal lifting transforms it to
-- (\n# x -> x) which is what we want.
- uses funs (Var v) = v `elemVarSet` funs
- uses funs (App e1 e2) = uses funs e1 || uses funs e2
- uses _ _ = False
-
+ uses funs (Var v) = v `elemVarSet` funs
+ uses funs (App e1 e2) = uses funs e1 || uses funs e2
+ uses funs (Lam b body) = uses (funs `extendVarSet` b) body
+ uses funs (Let (NonRec _b letExpr) body)
+ = uses funs letExpr || uses funs body
+ uses funs (Case e _eId _ty alts)
+ = uses funs e || any (uses_alt funs) alts
+ uses _ _ = False
+
+ uses_alt funs (_, _bs, e) = uses funs e
+
+mkScalarFun :: [Type] -> Type -> CoreExpr -> VM VExpr
+mkScalarFun arg_tys res_ty expr
+ = do { fn_var <- hoistExpr (fsLit "fn") expr DontInline
+ ; zipf <- zipScalars arg_tys res_ty
+ ; clo <- scalarClosure arg_tys res_ty (Var fn_var) (zipf `App` Var fn_var)
+ ; clo_var <- hoistExpr (fsLit "clo") clo DontInline
+ ; lclo <- liftPD (Var clo_var)
+ ; return (Var clo_var, lclo)
+ }
-- | Vectorise a lambda abstraction.
-vectLam
- :: Bool -- ^ When the RHS of a binding, whether that binding should be inlined.
- -> Bool -- ^ Whether the binding is a loop breaker.
- -> VarSet -- ^ The free variables in the body.
- -> [Var] -- ^ Binding variables.
- -> CoreExprWithFVs -- ^ Body of abstraction.
- -> VM VExpr
-
-vectLam inline loop_breaker fvs bs body
- = do tyvars <- localTyVars
+--
+vectLam :: Bool -- ^ When the RHS of a binding, whether that binding should be inlined.
+ -> Bool -- ^ Whether the binding is a loop breaker.
+ -> CoreExprWithFVs -- ^ Body of abstraction.
+ -> VM VExpr
+vectLam inline loop_breaker expr@(fvs, AnnLam _ _)
+ = do let (bs, body) = collectAnnValBinders expr
+
+ tyvars <- localTyVars
(vs, vvs) <- readLEnv $ \env ->
unzip [(var, vv) | var <- varSetElems fvs
, Just vv <- [lookupVarEnv (local_vars env) var]]
(LitAlt (mkMachInt 0), [], empty)])
| otherwise = return (ve, le)
+vectLam _ _ _ = panic "vectLam"
vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
vectTyAppExpr (_, AnnVar v) tys = vectPolyVar v tys
-vectTyAppExpr e tys = cantVectorise "Can't vectorise expression"
+vectTyAppExpr e tys = cantVectorise "Can't vectorise expression (vectTyExpr)"
(ppr $ deAnnotate e `mkTyApps` tys)