X-Git-Url: http://git.megacz.com/?a=blobdiff_plain;f=compiler%2Fvectorise%2FVectorise%2FExp.hs;h=569057e5e88deaa982ff0ad4ec69695e605f6ea3;hb=6cec61d14a324285dbb8ce73d4c7215f1f8d6766;hp=d35c9473416925c96daeef7825bc9fd07677b6f8;hpb=6766a6827970b340233a35faa9557455a4e11c1e;p=ghc-hetmet.git diff --git a/compiler/vectorise/Vectorise/Exp.hs b/compiler/vectorise/Vectorise/Exp.hs index d35c947..dbdf6e1 100644 --- a/compiler/vectorise/Vectorise/Exp.hs +++ b/compiler/vectorise/Vectorise/Exp.hs @@ -1,15 +1,23 @@ -- | Vectorisation of expressions. -module Vectorise.Exp - (vectPolyExpr) -where -import VectUtils -import VectVar -import VectType +module Vectorise.Exp ( + + -- Vectorise a polymorphic expression + vectPolyExpr, + + -- Vectorise a scalar expression of functional type + vectScalarFun +) where + +#include "HsVersions.h" + +import Vectorise.Type.Type +import Vectorise.Var import Vectorise.Vect import Vectorise.Env import Vectorise.Monad import Vectorise.Builtins +import Vectorise.Utils import CoreSyn import CoreUtils @@ -22,7 +30,7 @@ import Var import VarEnv import VarSet import Id -import BasicTypes +import BasicTypes( isLoopBreaker ) import Literal import TysWiredIn import TysPrim @@ -33,23 +41,22 @@ import Data.List -- | Vectorise a polymorphic expression. -vectPolyExpr - :: Bool -- ^ When vectorising the RHS of a binding, whether that - -- binding is a loop breaker. - -> CoreExprWithFVs - -> VM (Inline, VExpr) - -vectPolyExpr loop_breaker (_, AnnNote note expr) - = do (inline, expr') <- vectPolyExpr loop_breaker expr - return (inline, vNote note expr') - -vectPolyExpr loop_breaker expr +-- +vectPolyExpr :: Bool -- ^ When vectorising the RHS of a binding, whether that + -- binding is a loop breaker. + -> [Var] + -> CoreExprWithFVs + -> VM (Inline, Bool, VExpr) +vectPolyExpr loop_breaker recFns (_, AnnNote note expr) + = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr + return (inline, isScalarFn, vNote note expr') +vectPolyExpr loop_breaker recFns expr = do arity <- polyArity tvs polyAbstract tvs $ \args -> do - (inline, mono') <- vectFnExpr False loop_breaker mono - return (addInlineArity inline arity, + (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono + return (addInlineArity inline arity, isScalarFn, mapVect (mkLams $ tvs ++ args) mono') where (tvs, mono) = collectAnnTypeBinders expr @@ -111,12 +118,13 @@ vectExpr (_, AnnCase scrut bndr ty alts) | Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty , isAlgTyCon tycon = vectAlgCase tycon ty_args scrut bndr ty alts + | otherwise = cantVectorise "Can't vectorise expression" (ppr scrut_ty) where scrut_ty = exprType (deAnnotate scrut) vectExpr (_, AnnLet (AnnNonRec bndr rhs) body) = do - vrhs <- localV . inBind bndr . liftM snd $ vectPolyExpr False rhs + vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs (vbndr, vbody) <- vectBndrIn bndr (vectExpr body) return $ vLet (vNonRec vbndr vrhs) vbody @@ -132,11 +140,11 @@ vectExpr (_, AnnLet (AnnRec bs) body) vect_rhs bndr rhs = localV . inBind bndr - . liftM snd - $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs + . liftM (\(_,_,z)->z) + $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) [] rhs vectExpr e@(_, AnnLam bndr _) - | isId bndr = liftM snd $ vectFnExpr True False e + | isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False [] e {- onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body) `orElseV` vectLam True fvs bs body @@ -144,87 +152,145 @@ onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body) (bs,body) = collectAnnValBinders e -} -vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e) - +vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnnotate e) -- | Vectorise an expression with an outer lambda abstraction. -vectFnExpr - :: Bool -- ^ When the RHS of a binding, whether that binding should be inlined. - -> Bool -- ^ Whether the binding is a loop breaker. - -> CoreExprWithFVs -- ^ Expression to vectorise. Must have an outer `AnnLam`. - -> VM (Inline, VExpr) - -vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _) - | isId bndr = onlyIfV (isEmptyVarSet fvs) - (mark DontInline . vectScalarLam bs $ deAnnotate body) - `orElseV` mark inlineMe (vectLam inline loop_breaker fvs bs body) - where - (bs,body) = collectAnnValBinders e - -vectFnExpr _ _ e = mark DontInline $ vectExpr e - -mark :: Inline -> VM a -> VM (Inline, a) -mark b p = do { x <- p; return (b,x) } - - --- | Vectorise a function where are the args have scalar type, --- that is Int, Float, Double etc. -vectScalarLam - :: [Var] -- ^ Bound variables of function. - -> CoreExpr -- ^ Function body. - -> VM VExpr - -vectScalarLam args body - = do scalars <- globalScalars - onlyIfV (all is_scalar_ty arg_tys - && is_scalar_ty res_ty - && is_scalar (extendVarSetList scalars args) body - && uses scalars body) - $ do - fn_var <- hoistExpr (fsLit "fn") (mkLams args body) DontInline - zipf <- zipScalars arg_tys res_ty - clo <- scalarClosure arg_tys res_ty (Var fn_var) - (zipf `App` Var fn_var) - clo_var <- hoistExpr (fsLit "clo") clo DontInline - lclo <- liftPD (Var clo_var) - return (Var clo_var, lclo) +-- +vectFnExpr :: Bool -- ^ If we process the RHS of a binding, whether that binding should + -- be inlined + -> Bool -- ^ Whether the binding is a loop breaker + -> [Var] -- ^ Names of function in same recursive binding group + -> CoreExprWithFVs -- ^ Expression to vectorise; must have an outer `AnnLam` + -> VM (Inline, Bool, VExpr) +vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr _) + | isId bndr = mark DontInline True (vectScalarFun False recFns (deAnnotate expr)) + `orElseV` + mark inlineMe False (vectLam inline loop_breaker expr) +vectFnExpr _ _ _ e = mark DontInline False $ vectExpr e + +mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a) +mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) } + +-- |Vectorise an expression of functional type, where all arguments and the result are of scalar +-- type (i.e., 'Int', 'Float', 'Double' etc.) and which does not contain any subcomputations that +-- involve parallel arrays. Such functionals do not requires the full blown vectorisation +-- transformation; instead, they can be lifted by application of a member of the zipWith family +-- (i.e., 'map', 'zipWith', zipWith3', etc.) +-- +vectScalarFun :: Bool -- ^ Was the function marked as scalar by the user? + -> [Var] -- ^ Functions names in same recursive binding group + -> CoreExpr -- ^ Expression to be vectorised + -> VM VExpr +vectScalarFun forceScalar recFns expr + = do { gscalars <- globalScalars + ; let scalars = gscalars `extendVarSetList` recFns + (arg_tys, res_ty) = splitFunTys (exprType expr) + ; MASSERT( not $ null arg_tys ) + ; onlyIfV (forceScalar -- user asserts the functions is scalar + || + all is_prim_ty arg_tys -- check whether the function is scalar + && is_prim_ty res_ty + && is_scalar scalars expr + && uses scalars expr) + $ mkScalarFun arg_tys res_ty expr + } where - arg_tys = map idType args - res_ty = exprType body - - is_scalar_ty ty + -- FIXME: This is woefully insufficient!!! We need a scalar pragma for types!!! + is_prim_ty ty | Just (tycon, []) <- splitTyConApp_maybe ty = tycon == intTyCon || tycon == floatTyCon || tycon == doubleTyCon - | otherwise = False - is_scalar vs (Var v) = v `elemVarSet` vs - is_scalar _ e@(Lit _) = is_scalar_ty $ exprType e - is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2 - is_scalar _ _ = False - + -- Checks whether an expression contain a non-scalar subexpression. + -- + -- Precodition: The variables in the first argument are scalar. + -- + -- In case of a recursive binding group, we /assume/ that all bindings are scalar (by adding + -- them to the list of scalar variables) and then check them. If one of them turns out not to + -- be scalar, the entire group is regarded as not being scalar. + -- + -- FIXME: Currently, doesn't regard external (non-data constructor) variable and anonymous + -- data constructor as scalar. Should be changed once scalar types are passed + -- through VectInfo. + -- + is_scalar :: VarSet -> CoreExpr -> Bool + is_scalar scalars (Var v) = v `elemVarSet` scalars + is_scalar _scalars (Lit _) = True + is_scalar scalars e@(App e1 e2) + | maybe_parr_ty (exprType e) = False + | otherwise = is_scalar scalars e1 && is_scalar scalars e2 + is_scalar scalars (Lam var body) + | maybe_parr_ty (varType var) = False + | otherwise = is_scalar (scalars `extendVarSet` var) body + is_scalar scalars (Let bind body) = bindsAreScalar && is_scalar scalars' body + where + (bindsAreScalar, scalars') = is_scalar_bind scalars bind + is_scalar scalars (Case e var ty alts) + | is_prim_ty ty = is_scalar scalars' e && all (is_scalar_alt scalars') alts + | otherwise = False + where + scalars' = scalars `extendVarSet` var + is_scalar scalars (Cast e _coe) = is_scalar scalars e + is_scalar scalars (Note _ e ) = is_scalar scalars e + is_scalar _scalars (Type _) = True + + -- Result: (, scalars ++ variables bound in this group) + is_scalar_bind scalars (NonRec var e) = (is_scalar scalars e, scalars `extendVarSet` var) + is_scalar_bind scalars (Rec bnds) = (all (is_scalar scalars') es, scalars') + where + (vars, es) = unzip bnds + scalars' = scalars `extendVarSetList` vars + + is_scalar_alt scalars (_, vars, e) = is_scalar (scalars `extendVarSetList ` vars) e + + -- Checks whether the type might be a parallel array type. In particular, if the outermost + -- constructor is a type family, we conservatively assume that it may be a parallel array type. + maybe_parr_ty :: Type -> Bool + maybe_parr_ty ty + | Just ty' <- coreView ty = maybe_parr_ty ty' + | Just (tyCon, _) <- splitTyConApp_maybe ty = isPArrTyCon tyCon || isSynFamilyTyCon tyCon + maybe_parr_ty _ = False + + -- FIXME: I'm not convinced that this reasoning is (always) sound. If the identify functions + -- is called by some other function that is otherwise scalar, it would be very bad + -- that just this call to the identity makes it not be scalar. -- A scalar function has to actually compute something. Without the check, -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to -- (map (\x -> x)) which is very bad. Normal lifting transforms it to -- (\n# x -> x) which is what we want. - uses funs (Var v) = v `elemVarSet` funs - uses funs (App e1 e2) = uses funs e1 || uses funs e2 - uses _ _ = False - + uses funs (Var v) = v `elemVarSet` funs + uses funs (App e1 e2) = uses funs e1 || uses funs e2 + uses funs (Lam b body) = uses (funs `extendVarSet` b) body + uses funs (Let (NonRec _b letExpr) body) + = uses funs letExpr || uses funs body + uses funs (Case e _eId _ty alts) + = uses funs e || any (uses_alt funs) alts + uses _ _ = False + + uses_alt funs (_, _bs, e) = uses funs e + +mkScalarFun :: [Type] -> Type -> CoreExpr -> VM VExpr +mkScalarFun arg_tys res_ty expr + = do { fn_var <- hoistExpr (fsLit "fn") expr DontInline + ; zipf <- zipScalars arg_tys res_ty + ; clo <- scalarClosure arg_tys res_ty (Var fn_var) (zipf `App` Var fn_var) + ; clo_var <- hoistExpr (fsLit "clo") clo DontInline + ; lclo <- liftPD (Var clo_var) + ; return (Var clo_var, lclo) + } -- | Vectorise a lambda abstraction. -vectLam - :: Bool -- ^ When the RHS of a binding, whether that binding should be inlined. - -> Bool -- ^ Whether the binding is a loop breaker. - -> VarSet -- ^ The free variables in the body. - -> [Var] -- ^ Binding variables. - -> CoreExprWithFVs -- ^ Body of abstraction. - -> VM VExpr - -vectLam inline loop_breaker fvs bs body - = do tyvars <- localTyVars +-- +vectLam :: Bool -- ^ When the RHS of a binding, whether that binding should be inlined. + -> Bool -- ^ Whether the binding is a loop breaker. + -> CoreExprWithFVs -- ^ Body of abstraction. + -> VM VExpr +vectLam inline loop_breaker expr@(fvs, AnnLam _ _) + = do let (bs, body) = collectAnnValBinders expr + + tyvars <- localTyVars (vs, vvs) <- readLEnv $ \env -> unzip [(var, vv) | var <- varSetElems fvs , Just vv <- [lookupVarEnv (local_vars env) var]] @@ -254,11 +320,12 @@ vectLam inline loop_breaker fvs bs body (LitAlt (mkMachInt 0), [], empty)]) | otherwise = return (ve, le) +vectLam _ _ _ = panic "vectLam" vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr vectTyAppExpr (_, AnnVar v) tys = vectPolyVar v tys -vectTyAppExpr e tys = cantVectorise "Can't vectorise expression" +vectTyAppExpr e tys = cantVectorise "Can't vectorise expression (vectTyExpr)" (ppr $ deAnnotate e `mkTyApps` tys)