X-Git-Url: http://git.megacz.com/?a=blobdiff_plain;f=compiler%2Fvectorise%2FVectorise%2FExp.hs;h=4676e182a9221aad77b1041842f0f713a03d20db;hb=18691d440f90a3dff4ef538091c886af505e5cf5;hp=569057e5e88deaa982ff0ad4ec69695e605f6ea3;hpb=f2aaae9757e7532485c97f6c9a9ed5437542d1dd;p=ghc-hetmet.git diff --git a/compiler/vectorise/Vectorise/Exp.hs b/compiler/vectorise/Vectorise/Exp.hs index 569057e..4676e18 100644 --- a/compiler/vectorise/Vectorise/Exp.hs +++ b/compiler/vectorise/Vectorise/Exp.hs @@ -1,15 +1,23 @@ -- | Vectorisation of expressions. -module Vectorise.Exp - (vectPolyExpr) -where -import Vectorise.Utils +module Vectorise.Exp ( + + -- Vectorise a polymorphic expression + vectPolyExpr, + + -- Vectorise a scalar expression of functional type + vectScalarFun +) where + +#include "HsVersions.h" + import Vectorise.Type.Type import Vectorise.Var import Vectorise.Vect import Vectorise.Env import Vectorise.Monad import Vectorise.Builtins +import Vectorise.Utils import CoreSyn import CoreUtils @@ -148,134 +156,142 @@ vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnno -- | Vectorise an expression with an outer lambda abstraction. -- -vectFnExpr :: Bool -- ^ When the RHS of a binding, whether that binding should be inlined. - -> Bool -- ^ Whether the binding is a loop breaker. - -> [Var] - -> CoreExprWithFVs -- ^ Expression to vectorise. Must have an outer `AnnLam`. +vectFnExpr :: Bool -- ^ If we process the RHS of a binding, whether that binding should + -- be inlined + -> Bool -- ^ Whether the binding is a loop breaker + -> [Var] -- ^ Names of function in same recursive binding group + -> CoreExprWithFVs -- ^ Expression to vectorise; must have an outer `AnnLam` -> VM (Inline, Bool, VExpr) -vectFnExpr inline loop_breaker recFns e@(fvs, AnnLam bndr _) - | isId bndr = onlyIfV True -- (isEmptyVarSet fvs) -- we check for free variables later. TODO: clean up - (mark DontInline True . vectScalarLam bs recFns $ deAnnotate body) - `orElseV` mark inlineMe False (vectLam inline loop_breaker fvs bs body) - where - (bs,body) = collectAnnValBinders e +vectFnExpr inline loop_breaker recFns expr@(_fvs, AnnLam bndr _) + | isId bndr = mark DontInline True (vectScalarFun False recFns (deAnnotate expr)) + `orElseV` + mark inlineMe False (vectLam inline loop_breaker expr) vectFnExpr _ _ _ e = mark DontInline False $ vectExpr e mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a) mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) } - --- | Vectorise a function where are the args have scalar type, --- that is Int, Float, Double etc. -vectScalarLam - :: [Var] -- ^ Bound variables of function - -> [Var] - -> CoreExpr -- ^ Function body. - -> VM VExpr - -vectScalarLam args recFns body - = do scalars' <- globalScalars - let scalars = unionVarSet (mkVarSet recFns) scalars' - onlyIfV (all is_prim_ty arg_tys - && is_prim_ty res_ty - && is_scalar (extendVarSetList scalars args) body - && uses scalars body) - $ do - fn_var <- hoistExpr (fsLit "fn") (mkLams args body) DontInline - zipf <- zipScalars arg_tys res_ty - clo <- scalarClosure arg_tys res_ty (Var fn_var) - (zipf `App` Var fn_var) - clo_var <- hoistExpr (fsLit "clo") clo DontInline - lclo <- liftPD (Var clo_var) - return (Var clo_var, lclo) +-- |Vectorise an expression of functional type, where all arguments and the result are of scalar +-- type (i.e., 'Int', 'Float', 'Double' etc.) and which does not contain any subcomputations that +-- involve parallel arrays. Such functionals do not requires the full blown vectorisation +-- transformation; instead, they can be lifted by application of a member of the zipWith family +-- (i.e., 'map', 'zipWith', zipWith3', etc.) +-- +vectScalarFun :: Bool -- ^ Was the function marked as scalar by the user? + -> [Var] -- ^ Functions names in same recursive binding group + -> CoreExpr -- ^ Expression to be vectorised + -> VM VExpr +vectScalarFun forceScalar recFns expr + = do { gscalars <- globalScalars + ; let scalars = gscalars `extendVarSetList` recFns + (arg_tys, res_ty) = splitFunTys (exprType expr) + ; MASSERT( not $ null arg_tys ) + ; onlyIfV (forceScalar -- user asserts the functions is scalar + || + all is_prim_ty arg_tys -- check whether the function is scalar + && is_prim_ty res_ty + && is_scalar scalars expr + && uses scalars expr) + $ mkScalarFun arg_tys res_ty expr + } where - arg_tys = map idType args - res_ty = exprType body - + -- FIXME: This is woefully insufficient!!! We need a scalar pragma for types!!! is_prim_ty ty | Just (tycon, []) <- splitTyConApp_maybe ty = tycon == intTyCon || tycon == floatTyCon || tycon == doubleTyCon - | otherwise = False - - cantbe_parr_expr expr = not $ maybe_parr_ty $ exprType expr - - maybe_parr_ty ty = maybe_parr_ty' [] ty - - maybe_parr_ty' _ ty | Nothing <- splitTyConApp_maybe ty = False -- TODO: is this really what we want to do with polym. types? - maybe_parr_ty' alreadySeen ty - | isPArrTyCon tycon = True - | isPrimTyCon tycon = False - | isAbstractTyCon tycon = True - | isFunTyCon tycon || isProductTyCon tycon || isTupleTyCon tycon = any (maybe_parr_ty' alreadySeen) args - | isDataTyCon tycon = any (maybe_parr_ty' alreadySeen) args || - hasParrDataCon alreadySeen tycon - | otherwise = True - where - Just (tycon, args) = splitTyConApp_maybe ty - - - hasParrDataCon alreadySeen tycon - | tycon `elem` alreadySeen = False - | otherwise = - any (maybe_parr_ty' $ tycon : alreadySeen) $ concat $ map dataConOrigArgTys $ tyConDataCons tycon - - -- checks to make sure expression can't contain a non-scalar subexpression. Might err on the side of caution whenever - -- an external (non data constructor) variable is used, or anonymous data constructor - is_scalar vs e@(Var v) - | Just _ <- isDataConId_maybe v = cantbe_parr_expr e - | otherwise = cantbe_parr_expr e && (v `elemVarSet` vs) - is_scalar _ e@(Lit _) = cantbe_parr_expr e - - is_scalar vs e@(App e1 e2) = cantbe_parr_expr e && - is_scalar vs e1 && is_scalar vs e2 - is_scalar vs e@(Let (NonRec b letExpr) body) - = cantbe_parr_expr e && - is_scalar vs letExpr && is_scalar (extendVarSet vs b) body - is_scalar vs e@(Let (Rec bnds) body) - = let vs' = extendVarSetList vs (map fst bnds) - in cantbe_parr_expr e && - all (is_scalar vs') (map snd bnds) && is_scalar vs' body - is_scalar vs e@(Case eC eId ty alts) - = let vs' = extendVarSet vs eId - in cantbe_parr_expr e && - is_prim_ty ty && - is_scalar vs' eC && - (all (is_scalar_alt vs') alts) - - is_scalar _ _ = False - - is_scalar_alt vs (_, bs, e) - = is_scalar (extendVarSetList vs bs) e + -- Checks whether an expression contain a non-scalar subexpression. + -- + -- Precodition: The variables in the first argument are scalar. + -- + -- In case of a recursive binding group, we /assume/ that all bindings are scalar (by adding + -- them to the list of scalar variables) and then check them. If one of them turns out not to + -- be scalar, the entire group is regarded as not being scalar. + -- + -- FIXME: Currently, doesn't regard external (non-data constructor) variable and anonymous + -- data constructor as scalar. Should be changed once scalar types are passed + -- through VectInfo. + -- + is_scalar :: VarSet -> CoreExpr -> Bool + is_scalar scalars (Var v) = v `elemVarSet` scalars + is_scalar _scalars (Lit _) = True + is_scalar scalars e@(App e1 e2) + | maybe_parr_ty (exprType e) = False + | otherwise = is_scalar scalars e1 && is_scalar scalars e2 + is_scalar scalars (Lam var body) + | maybe_parr_ty (varType var) = False + | otherwise = is_scalar (scalars `extendVarSet` var) body + is_scalar scalars (Let bind body) = bindsAreScalar && is_scalar scalars' body + where + (bindsAreScalar, scalars') = is_scalar_bind scalars bind + is_scalar scalars (Case e var ty alts) + | is_prim_ty ty = is_scalar scalars' e && all (is_scalar_alt scalars') alts + | otherwise = False + where + scalars' = scalars `extendVarSet` var + is_scalar scalars (Cast e _coe) = is_scalar scalars e + is_scalar scalars (Note _ e ) = is_scalar scalars e + is_scalar _scalars (Type {}) = True + is_scalar _scalars (Coercion {}) = True + + -- Result: (, scalars ++ variables bound in this group) + is_scalar_bind scalars (NonRec var e) = (is_scalar scalars e, scalars `extendVarSet` var) + is_scalar_bind scalars (Rec bnds) = (all (is_scalar scalars') es, scalars') + where + (vars, es) = unzip bnds + scalars' = scalars `extendVarSetList` vars + + is_scalar_alt scalars (_, vars, e) = is_scalar (scalars `extendVarSetList ` vars) e + + -- Checks whether the type might be a parallel array type. In particular, if the outermost + -- constructor is a type family, we conservatively assume that it may be a parallel array type. + maybe_parr_ty :: Type -> Bool + maybe_parr_ty ty + | Just ty' <- coreView ty = maybe_parr_ty ty' + | Just (tyCon, _) <- splitTyConApp_maybe ty = isPArrTyCon tyCon || isSynFamilyTyCon tyCon + maybe_parr_ty _ = False + + -- FIXME: I'm not convinced that this reasoning is (always) sound. If the identify functions + -- is called by some other function that is otherwise scalar, it would be very bad + -- that just this call to the identity makes it not be scalar. -- A scalar function has to actually compute something. Without the check, -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to -- (map (\x -> x)) which is very bad. Normal lifting transforms it to -- (\n# x -> x) which is what we want. - uses funs (Var v) = v `elemVarSet` funs - uses funs (App e1 e2) = uses funs e1 || uses funs e2 + uses funs (Var v) = v `elemVarSet` funs + uses funs (App e1 e2) = uses funs e1 || uses funs e2 + uses funs (Lam b body) = uses (funs `extendVarSet` b) body uses funs (Let (NonRec _b letExpr) body) - = uses funs letExpr || uses funs body + = uses funs letExpr || uses funs body uses funs (Case e _eId _ty alts) - = uses funs e || any (uses_alt funs) alts - uses _ _ = False + = uses funs e || any (uses_alt funs) alts + uses _ _ = False - uses_alt funs (_, _bs, e) - = uses funs e + uses_alt funs (_, _bs, e) = uses funs e + +mkScalarFun :: [Type] -> Type -> CoreExpr -> VM VExpr +mkScalarFun arg_tys res_ty expr + = do { fn_var <- hoistExpr (fsLit "fn") expr DontInline + ; zipf <- zipScalars arg_tys res_ty + ; clo <- scalarClosure arg_tys res_ty (Var fn_var) (zipf `App` Var fn_var) + ; clo_var <- hoistExpr (fsLit "clo") clo DontInline + ; lclo <- liftPD (Var clo_var) + ; return (Var clo_var, lclo) + } -- | Vectorise a lambda abstraction. -vectLam - :: Bool -- ^ When the RHS of a binding, whether that binding should be inlined. - -> Bool -- ^ Whether the binding is a loop breaker. - -> VarSet -- ^ The free variables in the body. - -> [Var] -- ^ Binding variables. - -> CoreExprWithFVs -- ^ Body of abstraction. - -> VM VExpr - -vectLam inline loop_breaker fvs bs body - = do tyvars <- localTyVars +-- +vectLam :: Bool -- ^ When the RHS of a binding, whether that binding should be inlined. + -> Bool -- ^ Whether the binding is a loop breaker. + -> CoreExprWithFVs -- ^ Body of abstraction. + -> VM VExpr +vectLam inline loop_breaker expr@(fvs, AnnLam _ _) + = do let (bs, body) = collectAnnValBinders expr + + tyvars <- localTyVars (vs, vvs) <- readLEnv $ \env -> unzip [(var, vv) | var <- varSetElems fvs , Just vv <- [lookupVarEnv (local_vars env) var]] @@ -305,6 +321,7 @@ vectLam inline loop_breaker fvs bs body (LitAlt (mkMachInt 0), [], empty)]) | otherwise = return (ve, le) +vectLam _ _ _ = panic "vectLam" vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr