vectTopBind :: CoreBind -> VM CoreBind
vectTopBind b@(NonRec var expr)
= do
- (inline, expr') <- vectTopRhs var expr
+ (inline, _, expr') <- vectTopRhs [] var expr
var' <- vectTopBinder var inline expr'
-- Vectorising the body may create other top-level bindings.
vectTopBind b@(Rec bs)
= do
+ -- pprTrace "in Rec" (ppr vars) $ return ()
(vars', _, exprs')
<- fixV $ \ ~(_, inlines, rhss) ->
do vars' <- sequence [vectTopBinder var inline rhs
| (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)]
- (inlines', exprs')
- <- mapAndUnzipM (uncurry vectTopRhs) bs
-
- return (vars', inlines', exprs')
-
+ (inlines', areScalars', exprs')
+ <- mapAndUnzip3M (uncurry $ vectTopRhs vars) bs
+ if (and areScalars') || (length bs <= 1)
+ then do
+ -- pprTrace "in Rec - all scalars??" (ppr areScalars') $ return ()
+ return (vars', inlines', exprs')
+ else do
+ -- pprTrace "in Rec - not all scalars" (ppr areScalars') $ return ()
+ mapM deleteGlobalScalar vars
+ (inlines'', _, exprs'') <- mapAndUnzip3M (uncurry $ vectTopRhs []) bs
+ return (vars', inlines'', exprs'')
+
hs <- takeHoisted
cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
return b
where
(vars, exprs) = unzip bs
-
+ mapAndUnzip3M f xs = do
+ ys <- mapM f xs
+ return $ unzip3 ys
-- | Make the vectorised version of this top level binder, and add the mapping
-- between it and the original to the state. For some binder @foo@ the vectorised
-- | Vectorise the RHS of a top-level binding, in an empty local environment.
vectTopRhs
- :: Var -- ^ Name of the binding.
+ :: [Var] -- ^ Names of all functions in the rec block
+ -> Var -- ^ Name of the binding.
-> CoreExpr -- ^ Body of the binding.
- -> VM (Inline, CoreExpr)
+ -> VM (Inline, Bool, CoreExpr)
-vectTopRhs var expr
+vectTopRhs recFs var expr
= dtrace (vcat [text "vectTopRhs", ppr expr])
$ closedV
$ do (inline, isScalar, vexpr) <- inBind var
- $ pprTrace "vectTopRhs" (ppr var)
- $ vectPolyExpr (isLoopBreaker $ idOccInfo var)
+ -- $ pprTrace "vectTopRhs" (ppr var)
+ $ vectPolyExpr (isLoopBreaker $ idOccInfo var) recFs
(freeVars expr)
if isScalar
then addGlobalScalar var
- else return ()
- return (inline, vectorised vexpr)
+ else deleteGlobalScalar var
+ return (inline, isScalar, vectorised vexpr)
-- | Project out the vectorised version of a binding from some closure,
-- | Vectorise a polymorphic expression.
vectPolyExpr
:: Bool -- ^ When vectorising the RHS of a binding, whether that
- -- binding is a loop breaker.
+ -- binding is a loop breaker.
+ -> [Var]
-> CoreExprWithFVs
-> VM (Inline, Bool, VExpr)
-vectPolyExpr loop_breaker (_, AnnNote note expr)
- = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker expr
+vectPolyExpr loop_breaker recFns (_, AnnNote note expr)
+ = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr
return (inline, isScalarFn, vNote note expr')
-vectPolyExpr loop_breaker expr
+vectPolyExpr loop_breaker recFns expr
= do
arity <- polyArity tvs
polyAbstract tvs $ \args ->
do
- (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker mono
+ (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono
return (addInlineArity inline arity, isScalarFn,
mapVect (mkLams $ tvs ++ args) mono')
where
vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
= do
- vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False rhs
+ vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs
(vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
return $ vLet (vNonRec vbndr vrhs) vbody
vect_rhs bndr rhs = localV
. inBind bndr
. liftM (\(_,_,z)->z)
- $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs
+ $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) [] rhs
vectExpr e@(_, AnnLam bndr _)
- | isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False e
+ | isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False [] e
{-
onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
`orElseV` vectLam True fvs bs body
vectFnExpr
:: Bool -- ^ When the RHS of a binding, whether that binding should be inlined.
-> Bool -- ^ Whether the binding is a loop breaker.
+ -> [Var]
-> CoreExprWithFVs -- ^ Expression to vectorise. Must have an outer `AnnLam`.
-> VM (Inline, Bool, VExpr)
-vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _)
- | isId bndr = pprTrace "vectFnExpr -- id" (ppr fvs )$
+vectFnExpr inline loop_breaker recFns e@(fvs, AnnLam bndr _)
+ | isId bndr = -- pprTrace "vectFnExpr -- id" (ppr fvs )$
onlyIfV True -- (isEmptyVarSet fvs) -- we check for free variables later. TODO: clean up
- (mark DontInline True . vectScalarLam bs $ deAnnotate body)
+ (mark DontInline True . vectScalarLam bs recFns $ deAnnotate body)
`orElseV` mark inlineMe False (vectLam inline loop_breaker fvs bs body)
where
(bs,body) = collectAnnValBinders e
-vectFnExpr _ _ e = pprTrace "vectFnExpr -- otherwise" (ppr "a" )$ mark DontInline False $ vectExpr e
+vectFnExpr _ _ _ e = pprTrace "vectFnExpr -- otherwise" (ppr "a" )$ mark DontInline False $ vectExpr e
mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a)
mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) }
-- | Vectorise a function where are the args have scalar type,
-- that is Int, Float, Double etc.
vectScalarLam
- :: [Var] -- ^ Bound variables of function.
+ :: [Var] -- ^ Bound variables of function
+ -> [Var]
-> CoreExpr -- ^ Function body.
-> VM VExpr
-vectScalarLam args body
- = do scalars <- globalScalars
- pprTrace "vectScalarLam" (ppr $ is_scalar (extendVarSetList scalars args) body) $
+vectScalarLam args recFns body
+ = do scalars' <- globalScalars
+ let scalars = unionVarSet (mkVarSet recFns) scalars'
+{- pprTrace "vectScalarLam uses" (ppr $ uses scalars body) $
+ pprTrace "vectScalarLam is prim res" (ppr $ is_prim_ty res_ty) $
+ pprTrace "vectScalarLam is scalar body" (ppr $ is_scalar (extendVarSetList scalars args) body) $
+ pprTrace "vectScalarLam arg tys" (ppr $ arg_tys) $ -}
onlyIfV (all is_prim_ty arg_tys
&& is_prim_ty res_ty
&& is_scalar (extendVarSetList scalars args) body
(zipf `App` Var fn_var)
clo_var <- hoistExpr (fsLit "clo") clo DontInline
lclo <- liftPD (Var clo_var)
- pprTrace " lam is scalar" (ppr "") $
+ {- pprTrace " lam is scalar" (ppr "") $ -}
return (Var clo_var, lclo)
where
arg_tys = map idType args
| isPrimTyCon tycon = False
| isAbstractTyCon tycon = True
| isFunTyCon tycon || isProductTyCon tycon || isTupleTyCon tycon = any (maybe_parr_ty' alreadySeen) args
- | isDataTyCon tycon = pprTrace "isDataTyCon" (ppr tycon) $
+ | isDataTyCon tycon = -- pprTrace "isDataTyCon" (ppr tycon) $
any (maybe_parr_ty' alreadySeen) args ||
hasParrDataCon alreadySeen tycon
| otherwise = True