module Vectorise.Exp
(vectPolyExpr)
where
-import VectUtils
+import Vectorise.Utils
import Vectorise.Type.Type
-import Vectorise.Utils.Closure
-import Vectorise.Utils.Hoisting
import Vectorise.Var
import Vectorise.Vect
import Vectorise.Env
import VarEnv
import VarSet
import Id
-import BasicTypes
+import BasicTypes( isLoopBreaker )
import Literal
import TysWiredIn
import TysPrim
-- | Vectorise a polymorphic expression.
-vectPolyExpr
- :: Bool -- ^ When vectorising the RHS of a binding, whether that
- -- binding is a loop breaker.
- -> CoreExprWithFVs
- -> VM (Inline, VExpr)
-
-vectPolyExpr loop_breaker (_, AnnNote note expr)
- = do (inline, expr') <- vectPolyExpr loop_breaker expr
- return (inline, vNote note expr')
-
-vectPolyExpr loop_breaker expr
+--
+vectPolyExpr :: Bool -- ^ When vectorising the RHS of a binding, whether that
+ -- binding is a loop breaker.
+ -> [Var]
+ -> CoreExprWithFVs
+ -> VM (Inline, Bool, VExpr)
+vectPolyExpr loop_breaker recFns (_, AnnNote note expr)
+ = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr
+ return (inline, isScalarFn, vNote note expr')
+vectPolyExpr loop_breaker recFns expr
= do
arity <- polyArity tvs
polyAbstract tvs $ \args ->
do
- (inline, mono') <- vectFnExpr False loop_breaker mono
- return (addInlineArity inline arity,
+ (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono
+ return (addInlineArity inline arity, isScalarFn,
mapVect (mkLams $ tvs ++ args) mono')
where
(tvs, mono) = collectAnnTypeBinders expr
| Just (tycon, ty_args) <- splitTyConApp_maybe scrut_ty
, isAlgTyCon tycon
= vectAlgCase tycon ty_args scrut bndr ty alts
+ | otherwise = cantVectorise "Can't vectorise expression" (ppr scrut_ty)
where
scrut_ty = exprType (deAnnotate scrut)
vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
= do
- vrhs <- localV . inBind bndr . liftM snd $ vectPolyExpr False rhs
+ vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs
(vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
return $ vLet (vNonRec vbndr vrhs) vbody
vect_rhs bndr rhs = localV
. inBind bndr
- . liftM snd
- $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs
+ . liftM (\(_,_,z)->z)
+ $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) [] rhs
vectExpr e@(_, AnnLam bndr _)
- | isId bndr = liftM snd $ vectFnExpr True False e
+ | isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False [] e
{-
onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
`orElseV` vectLam True fvs bs body
(bs,body) = collectAnnValBinders e
-}
-vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e)
-
+vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnnotate e)
-- | Vectorise an expression with an outer lambda abstraction.
-vectFnExpr
- :: Bool -- ^ When the RHS of a binding, whether that binding should be inlined.
- -> Bool -- ^ Whether the binding is a loop breaker.
- -> CoreExprWithFVs -- ^ Expression to vectorise. Must have an outer `AnnLam`.
- -> VM (Inline, VExpr)
-
-vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _)
- | isId bndr = onlyIfV (isEmptyVarSet fvs)
- (mark DontInline . vectScalarLam bs $ deAnnotate body)
- `orElseV` mark inlineMe (vectLam inline loop_breaker fvs bs body)
+--
+vectFnExpr :: Bool -- ^ When the RHS of a binding, whether that binding should be inlined.
+ -> Bool -- ^ Whether the binding is a loop breaker.
+ -> [Var]
+ -> CoreExprWithFVs -- ^ Expression to vectorise. Must have an outer `AnnLam`.
+ -> VM (Inline, Bool, VExpr)
+vectFnExpr inline loop_breaker recFns e@(fvs, AnnLam bndr _)
+ | isId bndr = onlyIfV True -- (isEmptyVarSet fvs) -- we check for free variables later. TODO: clean up
+ (mark DontInline True . vectScalarLam bs recFns $ deAnnotate body)
+ `orElseV` mark inlineMe False (vectLam inline loop_breaker fvs bs body)
where
(bs,body) = collectAnnValBinders e
+vectFnExpr _ _ _ e = mark DontInline False $ vectExpr e
-vectFnExpr _ _ e = mark DontInline $ vectExpr e
-
-mark :: Inline -> VM a -> VM (Inline, a)
-mark b p = do { x <- p; return (b,x) }
+mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
+mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
-- | Vectorise a function where are the args have scalar type,
-- that is Int, Float, Double etc.
vectScalarLam
- :: [Var] -- ^ Bound variables of function.
+ :: [Var] -- ^ Bound variables of function
+ -> [Var]
-> CoreExpr -- ^ Function body.
-> VM VExpr
-vectScalarLam args body
- = do scalars <- globalScalars
- onlyIfV (all is_scalar_ty arg_tys
- && is_scalar_ty res_ty
+vectScalarLam args recFns body
+ = do scalars' <- globalScalars
+ let scalars = unionVarSet (mkVarSet recFns) scalars'
+ onlyIfV (all is_prim_ty arg_tys
+ && is_prim_ty res_ty
&& is_scalar (extendVarSetList scalars args) body
&& uses scalars body)
$ do
arg_tys = map idType args
res_ty = exprType body
- is_scalar_ty ty
+ is_prim_ty ty
| Just (tycon, []) <- splitTyConApp_maybe ty
= tycon == intTyCon
|| tycon == floatTyCon
|| tycon == doubleTyCon
| otherwise = False
-
- is_scalar vs (Var v) = v `elemVarSet` vs
- is_scalar _ e@(Lit _) = is_scalar_ty $ exprType e
- is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2
- is_scalar _ _ = False
+
+ cantbe_parr_expr expr = not $ maybe_parr_ty $ exprType expr
+
+ maybe_parr_ty ty = maybe_parr_ty' [] ty
+
+ maybe_parr_ty' _ ty | Nothing <- splitTyConApp_maybe ty = False -- TODO: is this really what we want to do with polym. types?
+ maybe_parr_ty' alreadySeen ty
+ | isPArrTyCon tycon = True
+ | isPrimTyCon tycon = False
+ | isAbstractTyCon tycon = True
+ | isFunTyCon tycon || isProductTyCon tycon || isTupleTyCon tycon = any (maybe_parr_ty' alreadySeen) args
+ | isDataTyCon tycon = any (maybe_parr_ty' alreadySeen) args ||
+ hasParrDataCon alreadySeen tycon
+ | otherwise = True
+ where
+ Just (tycon, args) = splitTyConApp_maybe ty
+
+
+ hasParrDataCon alreadySeen tycon
+ | tycon `elem` alreadySeen = False
+ | otherwise =
+ any (maybe_parr_ty' $ tycon : alreadySeen) $ concat $ map dataConOrigArgTys $ tyConDataCons tycon
+
+ -- checks to make sure expression can't contain a non-scalar subexpression. Might err on the side of caution whenever
+ -- an external (non data constructor) variable is used, or anonymous data constructor
+ is_scalar vs e@(Var v)
+ | Just _ <- isDataConId_maybe v = cantbe_parr_expr e
+ | otherwise = cantbe_parr_expr e && (v `elemVarSet` vs)
+ is_scalar _ e@(Lit _) = cantbe_parr_expr e
+
+ is_scalar vs e@(App e1 e2) = cantbe_parr_expr e &&
+ is_scalar vs e1 && is_scalar vs e2
+ is_scalar vs e@(Let (NonRec b letExpr) body)
+ = cantbe_parr_expr e &&
+ is_scalar vs letExpr && is_scalar (extendVarSet vs b) body
+ is_scalar vs e@(Let (Rec bnds) body)
+ = let vs' = extendVarSetList vs (map fst bnds)
+ in cantbe_parr_expr e &&
+ all (is_scalar vs') (map snd bnds) && is_scalar vs' body
+ is_scalar vs e@(Case eC eId ty alts)
+ = let vs' = extendVarSet vs eId
+ in cantbe_parr_expr e &&
+ is_prim_ty ty &&
+ is_scalar vs' eC &&
+ (all (is_scalar_alt vs') alts)
+
+ is_scalar _ _ = False
+
+ is_scalar_alt vs (_, bs, e)
+ = is_scalar (extendVarSetList vs bs) e
-- A scalar function has to actually compute something. Without the check,
-- we would treat (\(x :: Int) -> x) as a scalar function and lift it to
-- (\n# x -> x) which is what we want.
uses funs (Var v) = v `elemVarSet` funs
uses funs (App e1 e2) = uses funs e1 || uses funs e2
+ uses funs (Let (NonRec _b letExpr) body)
+ = uses funs letExpr || uses funs body
+ uses funs (Case e _eId _ty alts)
+ = uses funs e || any (uses_alt funs) alts
uses _ _ = False
+ uses_alt funs (_, _bs, e)
+ = uses funs e
-- | Vectorise a lambda abstraction.
vectLam
vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
vectTyAppExpr (_, AnnVar v) tys = vectPolyVar v tys
-vectTyAppExpr e tys = cantVectorise "Can't vectorise expression"
+vectTyAppExpr e tys = cantVectorise "Can't vectorise expression (vectTyExpr)"
(ppr $ deAnnotate e `mkTyApps` tys)