X-Git-Url: http://git.megacz.com/?a=blobdiff_plain;f=compiler%2Fvectorise%2FVectorise%2FExp.hs;h=862a760a43a5cd7322acac10c364da2f5a2dbacf;hb=21703cf93de9e93f6b278b4d46f8511a813cbeda;hp=da783a923065b8810ff5e70bafa677badf2c1876;hpb=1158cc3254c5f14db28223966d8b666890f8beaa;p=ghc-hetmet.git diff --git a/compiler/vectorise/Vectorise/Exp.hs b/compiler/vectorise/Vectorise/Exp.hs index da783a9..862a760 100644 --- a/compiler/vectorise/Vectorise/Exp.hs +++ b/compiler/vectorise/Vectorise/Exp.hs @@ -3,8 +3,8 @@ module Vectorise.Exp (vectPolyExpr) where -import VectUtils -import VectType +import Vectorise.Utils +import Vectorise.Type.Type import Vectorise.Var import Vectorise.Vect import Vectorise.Env @@ -22,7 +22,7 @@ import Var import VarEnv import VarSet import Id -import BasicTypes +import BasicTypes( isLoopBreaker ) import Literal import TysWiredIn import TysPrim @@ -176,8 +176,8 @@ vectScalarLam vectScalarLam args body = do scalars <- globalScalars - onlyIfV (all is_scalar_ty arg_tys - && is_scalar_ty res_ty + onlyIfV (all is_prim_ty arg_tys + && is_prim_ty res_ty && is_scalar (extendVarSetList scalars args) body && uses scalars body) $ do @@ -192,18 +192,68 @@ vectScalarLam args body arg_tys = map idType args res_ty = exprType body - is_scalar_ty ty + is_prim_ty ty | Just (tycon, []) <- splitTyConApp_maybe ty = tycon == intTyCon || tycon == floatTyCon || tycon == doubleTyCon | otherwise = False - - is_scalar vs (Var v) = v `elemVarSet` vs - is_scalar _ e@(Lit _) = is_scalar_ty $ exprType e - is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2 - is_scalar _ _ = False + + cantbe_parr_expr expr = not $ maybe_parr_ty $ exprType expr + + maybe_parr_ty ty = maybe_parr_ty' [] ty + maybe_parr_ty' alreadySeen ty + | isPArrTyCon tycon = True + | isPrimTyCon tycon = False + | isAbstractTyCon tycon = True + | isFunTyCon tycon || isProductTyCon tycon || isTupleTyCon tycon = any (maybe_parr_ty' alreadySeen) args + | isDataTyCon tycon = pprTrace "isDataTyCon" (ppr tycon) $ + any (maybe_parr_ty' alreadySeen) args || + hasParrDataCon alreadySeen tycon + | otherwise = True + where + Just (tycon, args) = splitTyConApp_maybe ty + + + hasParrDataCon alreadySeen tycon + | tycon `elem` alreadySeen = False + | otherwise = + any (maybe_parr_ty' $ tycon : alreadySeen) $ concat $ map dataConOrigArgTys $ tyConDataCons tycon + + -- checks to make sure expression can't contain a non-scalar subexpression. Might err on the side of caution whenever + -- an external (non data constructor) variable is used, or anonymous data constructor + is_scalar vs e@(Var v) + | Just _ <- isDataConId_maybe v = cantbe_parr_expr e + | otherwise = cantbe_parr_expr e && (v `elemVarSet` vs) + is_scalar _ e@(Lit _) = -- pprTrace "is_scalar Lit" (ppr e) $ + cantbe_parr_expr e + + is_scalar vs e@(App e1 e2) = -- pprTrace "is_scalar App" (ppr e) $ + cantbe_parr_expr e && + is_scalar vs e1 && is_scalar vs e2 + is_scalar vs e@(Let (NonRec b letExpr) body) + = -- pprTrace "is_scalar Let" (ppr e) $ + cantbe_parr_expr e && + is_scalar vs letExpr && is_scalar (extendVarSet vs b) body + is_scalar vs e@(Let (Rec bnds) body) + = let vs' = extendVarSetList vs (map fst bnds) + in -- pprTrace "is_scalar Rec" (ppr e) $ + cantbe_parr_expr e && + all (is_scalar vs') (map snd bnds) && is_scalar vs' body + is_scalar vs e@(Case eC eId ty alts) + = let vs' = extendVarSet vs eId + in -- pprTrace "is_scalar Case" (ppr e) $ + cantbe_parr_expr e && + is_prim_ty ty && + is_scalar vs' eC && + (all (is_scalar_alt vs') alts) + + is_scalar _ e = -- pprTrace "is_scalar other" (ppr e) $ + False + + is_scalar_alt vs (_, bs, e) + = is_scalar (extendVarSetList vs bs) e -- A scalar function has to actually compute something. Without the check, -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to @@ -211,8 +261,14 @@ vectScalarLam args body -- (\n# x -> x) which is what we want. uses funs (Var v) = v `elemVarSet` funs uses funs (App e1 e2) = uses funs e1 || uses funs e2 + uses funs (Let (NonRec _b letExpr) body) + = uses funs letExpr || uses funs body + uses funs (Case e _eId _ty alts) + = uses funs e || any (uses_alt funs) alts uses _ _ = False + uses_alt funs (_, _bs, e) + = uses funs e -- | Vectorise a lambda abstraction. vectLam