X-Git-Url: http://git.megacz.com/?a=blobdiff_plain;f=compiler%2Fvectorise%2FVectorise%2FExp.hs;h=4676e182a9221aad77b1041842f0f713a03d20db;hb=febf1ced754a3996ac1a5877dcded87828560d1c;hp=b94224ab7b34a715b855162e2e1263cb806acbdb;hpb=37b0cb1147cadef4d68f3fc61faa3ec11ad47440;p=ghc-hetmet.git diff --git a/compiler/vectorise/Vectorise/Exp.hs b/compiler/vectorise/Vectorise/Exp.hs index b94224a..4676e18 100644 --- a/compiler/vectorise/Vectorise/Exp.hs +++ b/compiler/vectorise/Vectorise/Exp.hs @@ -1,15 +1,23 @@ -- | Vectorisation of expressions. -module Vectorise.Exp - (vectPolyExpr) -where -import Vectorise.Utils +module Vectorise.Exp ( + + -- Vectorise a polymorphic expression + vectPolyExpr, + + -- Vectorise a scalar expression of functional type + vectScalarFun +) where + +#include "HsVersions.h" + import Vectorise.Type.Type import Vectorise.Var import Vectorise.Vect import Vectorise.Env import Vectorise.Monad import Vectorise.Builtins +import Vectorise.Utils import CoreSyn import CoreUtils @@ -33,22 +41,21 @@ import Data.List -- | Vectorise a polymorphic expression. -vectPolyExpr - :: Bool -- ^ When vectorising the RHS of a binding, whether that - -- binding is a loop breaker. - -> CoreExprWithFVs - -> VM (Inline, Bool, VExpr) - -vectPolyExpr loop_breaker (_, AnnNote note expr) - = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker expr +-- +vectPolyExpr :: Bool -- ^ When vectorising the RHS of a binding, whether that + -- binding is a loop breaker. + -> [Var] + -> CoreExprWithFVs + -> VM (Inline, Bool, VExpr) +vectPolyExpr loop_breaker recFns (_, AnnNote note expr) + = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr return (inline, isScalarFn, vNote note expr') - -vectPolyExpr loop_breaker expr +vectPolyExpr loop_breaker recFns expr = do arity <- polyArity tvs polyAbstract tvs $ \args -> do - (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker mono + (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono return (addInlineArity inline arity, isScalarFn, mapVect (mkLams $ tvs ++ args) mono') where @@ -117,7 +124,7 @@ vectExpr (_, AnnCase scrut bndr ty alts) vectExpr (_, AnnLet (AnnNonRec bndr rhs) body) = do - vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False rhs + vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs (vbndr, vbody) <- vectBndrIn bndr (vectExpr body) return $ vLet (vNonRec vbndr vrhs) vbody @@ -134,10 +141,10 @@ vectExpr (_, AnnLet (AnnRec bs) body) vect_rhs bndr rhs = localV . inBind bndr . liftM (\(_,_,z)->z) - $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs + $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) [] rhs vectExpr e@(_, AnnLam bndr _) - | isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False e + | isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False [] e {- onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body) `orElseV` vectLam True fvs bs body @@ -147,146 +154,144 @@ onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body) vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnnotate e) - -- | Vectorise an expression with an outer lambda abstraction. -vectFnExpr - :: Bool -- ^ When the RHS of a binding, whether that binding should be inlined. - -> Bool -- ^ Whether the binding is a loop breaker. - -> CoreExprWithFVs -- ^ Expression to vectorise. Must have an outer `AnnLam`. - -> VM (Inline, Bool, VExpr) - -vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _) - | isId bndr = pprTrace "vectFnExpr -- id" (ppr fvs )$ - onlyIfV True -- (isEmptyVarSet fvs) -- we check for free variables later. TODO: clean up - (mark DontInline True . vectScalarLam bs $ deAnnotate body) - `orElseV` mark inlineMe False (vectLam inline loop_breaker fvs bs body) - where - (bs,body) = collectAnnValBinders e - -vectFnExpr _ _ e = pprTrace "vectFnExpr -- otherwise" (ppr "a" )$ mark DontInline False $ vectExpr e +-- +vectFnExpr :: Bool -- ^ If we process the RHS of a binding, whether that binding should + -- be inlined + -> Bool -- ^ Whether the binding is a loop breaker + -> [Var] -- ^ Names of function in same recursive binding group + -> CoreExprWithFVs -- ^ Expression to vectorise; must have an outer `AnnLam` + -> VM (Inline, Bool, VExpr) +vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr _) + | isId bndr = mark DontInline True (vectScalarFun False recFns (deAnnotate expr)) + `orElseV` + mark inlineMe False (vectLam inline loop_breaker expr) +vectFnExpr _ _ _ e = mark DontInline False $ vectExpr e mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a) mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) } - --- | Vectorise a function where are the args have scalar type, --- that is Int, Float, Double etc. -vectScalarLam - :: [Var] -- ^ Bound variables of function. - -> CoreExpr -- ^ Function body. - -> VM VExpr - -vectScalarLam args body - = do scalars <- globalScalars - pprTrace "vectScalarLam" (ppr $ is_scalar (extendVarSetList scalars args) body) $ - onlyIfV (all is_prim_ty arg_tys - && is_prim_ty res_ty - && is_scalar (extendVarSetList scalars args) body - && uses scalars body) - $ do - fn_var <- hoistExpr (fsLit "fn") (mkLams args body) DontInline - zipf <- zipScalars arg_tys res_ty - clo <- scalarClosure arg_tys res_ty (Var fn_var) - (zipf `App` Var fn_var) - clo_var <- hoistExpr (fsLit "clo") clo DontInline - lclo <- liftPD (Var clo_var) - pprTrace " lam is scalar" (ppr "") $ - return (Var clo_var, lclo) +-- |Vectorise an expression of functional type, where all arguments and the result are of scalar +-- type (i.e., 'Int', 'Float', 'Double' etc.) and which does not contain any subcomputations that +-- involve parallel arrays. Such functionals do not requires the full blown vectorisation +-- transformation; instead, they can be lifted by application of a member of the zipWith family +-- (i.e., 'map', 'zipWith', zipWith3', etc.) +-- +vectScalarFun :: Bool -- ^ Was the function marked as scalar by the user? + -> [Var] -- ^ Functions names in same recursive binding group + -> CoreExpr -- ^ Expression to be vectorised + -> VM VExpr +vectScalarFun forceScalar recFns expr + = do { gscalars <- globalScalars + ; let scalars = gscalars `extendVarSetList` recFns + (arg_tys, res_ty) = splitFunTys (exprType expr) + ; MASSERT( not $ null arg_tys ) + ; onlyIfV (forceScalar -- user asserts the functions is scalar + || + all is_prim_ty arg_tys -- check whether the function is scalar + && is_prim_ty res_ty + && is_scalar scalars expr + && uses scalars expr) + $ mkScalarFun arg_tys res_ty expr + } where - arg_tys = map idType args - res_ty = exprType body - + -- FIXME: This is woefully insufficient!!! We need a scalar pragma for types!!! is_prim_ty ty | Just (tycon, []) <- splitTyConApp_maybe ty = tycon == intTyCon || tycon == floatTyCon || tycon == doubleTyCon - | otherwise = False - - cantbe_parr_expr expr = not $ maybe_parr_ty $ exprType expr - - maybe_parr_ty ty = maybe_parr_ty' [] ty - - maybe_parr_ty' _ ty | Nothing <- splitTyConApp_maybe ty = False -- TODO: is this really what we want to do with polym. types? - maybe_parr_ty' alreadySeen ty - | isPArrTyCon tycon = True - | isPrimTyCon tycon = False - | isAbstractTyCon tycon = True - | isFunTyCon tycon || isProductTyCon tycon || isTupleTyCon tycon = any (maybe_parr_ty' alreadySeen) args - | isDataTyCon tycon = pprTrace "isDataTyCon" (ppr tycon) $ - any (maybe_parr_ty' alreadySeen) args || - hasParrDataCon alreadySeen tycon - | otherwise = True - where - Just (tycon, args) = splitTyConApp_maybe ty - - - hasParrDataCon alreadySeen tycon - | tycon `elem` alreadySeen = False - | otherwise = - any (maybe_parr_ty' $ tycon : alreadySeen) $ concat $ map dataConOrigArgTys $ tyConDataCons tycon - - -- checks to make sure expression can't contain a non-scalar subexpression. Might err on the side of caution whenever - -- an external (non data constructor) variable is used, or anonymous data constructor - is_scalar vs e@(Var v) - | Just _ <- isDataConId_maybe v = cantbe_parr_expr e - | otherwise = cantbe_parr_expr e && (v `elemVarSet` vs) - is_scalar _ e@(Lit _) = -- pprTrace "is_scalar Lit" (ppr e) $ - cantbe_parr_expr e - - is_scalar vs e@(App e1 e2) = -- pprTrace "is_scalar App" (ppr e) $ - cantbe_parr_expr e && - is_scalar vs e1 && is_scalar vs e2 - is_scalar vs e@(Let (NonRec b letExpr) body) - = -- pprTrace "is_scalar Let" (ppr e) $ - cantbe_parr_expr e && - is_scalar vs letExpr && is_scalar (extendVarSet vs b) body - is_scalar vs e@(Let (Rec bnds) body) - = let vs' = extendVarSetList vs (map fst bnds) - in -- pprTrace "is_scalar Rec" (ppr e) $ - cantbe_parr_expr e && - all (is_scalar vs') (map snd bnds) && is_scalar vs' body - is_scalar vs e@(Case eC eId ty alts) - = let vs' = extendVarSet vs eId - in -- pprTrace "is_scalar Case" (ppr e) $ - cantbe_parr_expr e && - is_prim_ty ty && - is_scalar vs' eC && - (all (is_scalar_alt vs') alts) - - is_scalar _ e = -- pprTrace "is_scalar other" (ppr e) $ - False - - is_scalar_alt vs (_, bs, e) - = is_scalar (extendVarSetList vs bs) e + -- Checks whether an expression contain a non-scalar subexpression. + -- + -- Precodition: The variables in the first argument are scalar. + -- + -- In case of a recursive binding group, we /assume/ that all bindings are scalar (by adding + -- them to the list of scalar variables) and then check them. If one of them turns out not to + -- be scalar, the entire group is regarded as not being scalar. + -- + -- FIXME: Currently, doesn't regard external (non-data constructor) variable and anonymous + -- data constructor as scalar. Should be changed once scalar types are passed + -- through VectInfo. + -- + is_scalar :: VarSet -> CoreExpr -> Bool + is_scalar scalars (Var v) = v `elemVarSet` scalars + is_scalar _scalars (Lit _) = True + is_scalar scalars e@(App e1 e2) + | maybe_parr_ty (exprType e) = False + | otherwise = is_scalar scalars e1 && is_scalar scalars e2 + is_scalar scalars (Lam var body) + | maybe_parr_ty (varType var) = False + | otherwise = is_scalar (scalars `extendVarSet` var) body + is_scalar scalars (Let bind body) = bindsAreScalar && is_scalar scalars' body + where + (bindsAreScalar, scalars') = is_scalar_bind scalars bind + is_scalar scalars (Case e var ty alts) + | is_prim_ty ty = is_scalar scalars' e && all (is_scalar_alt scalars') alts + | otherwise = False + where + scalars' = scalars `extendVarSet` var + is_scalar scalars (Cast e _coe) = is_scalar scalars e + is_scalar scalars (Note _ e ) = is_scalar scalars e + is_scalar _scalars (Type {}) = True + is_scalar _scalars (Coercion {}) = True + + -- Result: (, scalars ++ variables bound in this group) + is_scalar_bind scalars (NonRec var e) = (is_scalar scalars e, scalars `extendVarSet` var) + is_scalar_bind scalars (Rec bnds) = (all (is_scalar scalars') es, scalars') + where + (vars, es) = unzip bnds + scalars' = scalars `extendVarSetList` vars + + is_scalar_alt scalars (_, vars, e) = is_scalar (scalars `extendVarSetList ` vars) e + + -- Checks whether the type might be a parallel array type. In particular, if the outermost + -- constructor is a type family, we conservatively assume that it may be a parallel array type. + maybe_parr_ty :: Type -> Bool + maybe_parr_ty ty + | Just ty' <- coreView ty = maybe_parr_ty ty' + | Just (tyCon, _) <- splitTyConApp_maybe ty = isPArrTyCon tyCon || isSynFamilyTyCon tyCon + maybe_parr_ty _ = False + + -- FIXME: I'm not convinced that this reasoning is (always) sound. If the identify functions + -- is called by some other function that is otherwise scalar, it would be very bad + -- that just this call to the identity makes it not be scalar. -- A scalar function has to actually compute something. Without the check, -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to -- (map (\x -> x)) which is very bad. Normal lifting transforms it to -- (\n# x -> x) which is what we want. - uses funs (Var v) = v `elemVarSet` funs - uses funs (App e1 e2) = uses funs e1 || uses funs e2 + uses funs (Var v) = v `elemVarSet` funs + uses funs (App e1 e2) = uses funs e1 || uses funs e2 + uses funs (Lam b body) = uses (funs `extendVarSet` b) body uses funs (Let (NonRec _b letExpr) body) - = uses funs letExpr || uses funs body + = uses funs letExpr || uses funs body uses funs (Case e _eId _ty alts) - = uses funs e || any (uses_alt funs) alts - uses _ _ = False + = uses funs e || any (uses_alt funs) alts + uses _ _ = False + + uses_alt funs (_, _bs, e) = uses funs e - uses_alt funs (_, _bs, e) - = uses funs e +mkScalarFun :: [Type] -> Type -> CoreExpr -> VM VExpr +mkScalarFun arg_tys res_ty expr + = do { fn_var <- hoistExpr (fsLit "fn") expr DontInline + ; zipf <- zipScalars arg_tys res_ty + ; clo <- scalarClosure arg_tys res_ty (Var fn_var) (zipf `App` Var fn_var) + ; clo_var <- hoistExpr (fsLit "clo") clo DontInline + ; lclo <- liftPD (Var clo_var) + ; return (Var clo_var, lclo) + } -- | Vectorise a lambda abstraction. -vectLam - :: Bool -- ^ When the RHS of a binding, whether that binding should be inlined. - -> Bool -- ^ Whether the binding is a loop breaker. - -> VarSet -- ^ The free variables in the body. - -> [Var] -- ^ Binding variables. - -> CoreExprWithFVs -- ^ Body of abstraction. - -> VM VExpr - -vectLam inline loop_breaker fvs bs body - = do tyvars <- localTyVars +-- +vectLam :: Bool -- ^ When the RHS of a binding, whether that binding should be inlined. + -> Bool -- ^ Whether the binding is a loop breaker. + -> CoreExprWithFVs -- ^ Body of abstraction. + -> VM VExpr +vectLam inline loop_breaker expr@(fvs, AnnLam _ _) + = do let (bs, body) = collectAnnValBinders expr + + tyvars <- localTyVars (vs, vvs) <- readLEnv $ \env -> unzip [(var, vv) | var <- varSetElems fvs , Just vv <- [lookupVarEnv (local_vars env) var]] @@ -316,6 +321,7 @@ vectLam inline loop_breaker fvs bs body (LitAlt (mkMachInt 0), [], empty)]) | otherwise = return (ve, le) +vectLam _ _ _ = panic "vectLam" vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr