From 80cb2c397aec9751586c3a2a753f848e143dbd67 Mon Sep 17 00:00:00 2001 From: "keller@cse.unsw.edu.au" Date: Mon, 14 Feb 2011 00:29:45 +0000 Subject: [PATCH] Handling of recursive scalar functions in isScalarLam --- compiler/vectorise/Vectorise.hs | 39 +++++++++++++++++++----------- compiler/vectorise/Vectorise/Exp.hs | 43 +++++++++++++++++++-------------- compiler/vectorise/Vectorise/Monad.hs | 8 +++++- 3 files changed, 57 insertions(+), 33 deletions(-) diff --git a/compiler/vectorise/Vectorise.hs b/compiler/vectorise/Vectorise.hs index 8c9579e..999e8ef 100644 --- a/compiler/vectorise/Vectorise.hs +++ b/compiler/vectorise/Vectorise.hs @@ -115,7 +115,7 @@ vectModule guts vectTopBind :: CoreBind -> VM CoreBind vectTopBind b@(NonRec var expr) = do - (inline, expr') <- vectTopRhs var expr + (inline, _, expr') <- vectTopRhs [] var expr var' <- vectTopBinder var inline expr' -- Vectorising the body may create other top-level bindings. @@ -131,15 +131,23 @@ vectTopBind b@(NonRec var expr) vectTopBind b@(Rec bs) = do + -- pprTrace "in Rec" (ppr vars) $ return () (vars', _, exprs') <- fixV $ \ ~(_, inlines, rhss) -> do vars' <- sequence [vectTopBinder var inline rhs | (var, ~(inline, rhs)) <- zipLazy vars (zip inlines rhss)] - (inlines', exprs') - <- mapAndUnzipM (uncurry vectTopRhs) bs - - return (vars', inlines', exprs') - + (inlines', areScalars', exprs') + <- mapAndUnzip3M (uncurry $ vectTopRhs vars) bs + if (and areScalars') || (length bs <= 1) + then do + -- pprTrace "in Rec - all scalars??" (ppr areScalars') $ return () + return (vars', inlines', exprs') + else do + -- pprTrace "in Rec - not all scalars" (ppr areScalars') $ return () + mapM deleteGlobalScalar vars + (inlines'', _, exprs'') <- mapAndUnzip3M (uncurry $ vectTopRhs []) bs + return (vars', inlines'', exprs'') + hs <- takeHoisted cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs @@ -147,7 +155,9 @@ vectTopBind b@(Rec bs) return b where (vars, exprs) = unzip bs - + mapAndUnzip3M f xs = do + ys <- mapM f xs + return $ unzip3 ys -- | Make the vectorised version of this top level binder, and add the mapping -- between it and the original to the state. For some binder @foo@ the vectorised @@ -182,21 +192,22 @@ vectTopBinder var inline expr -- | Vectorise the RHS of a top-level binding, in an empty local environment. vectTopRhs - :: Var -- ^ Name of the binding. + :: [Var] -- ^ Names of all functions in the rec block + -> Var -- ^ Name of the binding. -> CoreExpr -- ^ Body of the binding. - -> VM (Inline, CoreExpr) + -> VM (Inline, Bool, CoreExpr) -vectTopRhs var expr +vectTopRhs recFs var expr = dtrace (vcat [text "vectTopRhs", ppr expr]) $ closedV $ do (inline, isScalar, vexpr) <- inBind var - $ pprTrace "vectTopRhs" (ppr var) - $ vectPolyExpr (isLoopBreaker $ idOccInfo var) + -- $ pprTrace "vectTopRhs" (ppr var) + $ vectPolyExpr (isLoopBreaker $ idOccInfo var) recFs (freeVars expr) if isScalar then addGlobalScalar var - else return () - return (inline, vectorised vexpr) + else deleteGlobalScalar var + return (inline, isScalar, vectorised vexpr) -- | Project out the vectorised version of a binding from some closure, diff --git a/compiler/vectorise/Vectorise/Exp.hs b/compiler/vectorise/Vectorise/Exp.hs index b94224a..091a760 100644 --- a/compiler/vectorise/Vectorise/Exp.hs +++ b/compiler/vectorise/Vectorise/Exp.hs @@ -35,20 +35,21 @@ import Data.List -- | Vectorise a polymorphic expression. vectPolyExpr :: Bool -- ^ When vectorising the RHS of a binding, whether that - -- binding is a loop breaker. + -- binding is a loop breaker. + -> [Var] -> CoreExprWithFVs -> VM (Inline, Bool, VExpr) -vectPolyExpr loop_breaker (_, AnnNote note expr) - = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker expr +vectPolyExpr loop_breaker recFns (_, AnnNote note expr) + = do (inline, isScalarFn, expr') <- vectPolyExpr loop_breaker recFns expr return (inline, isScalarFn, vNote note expr') -vectPolyExpr loop_breaker expr +vectPolyExpr loop_breaker recFns expr = do arity <- polyArity tvs polyAbstract tvs $ \args -> do - (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker mono + (inline, isScalarFn, mono') <- vectFnExpr False loop_breaker recFns mono return (addInlineArity inline arity, isScalarFn, mapVect (mkLams $ tvs ++ args) mono') where @@ -117,7 +118,7 @@ vectExpr (_, AnnCase scrut bndr ty alts) vectExpr (_, AnnLet (AnnNonRec bndr rhs) body) = do - vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False rhs + vrhs <- localV . inBind bndr . liftM (\(_,_,z)->z) $ vectPolyExpr False [] rhs (vbndr, vbody) <- vectBndrIn bndr (vectExpr body) return $ vLet (vNonRec vbndr vrhs) vbody @@ -134,10 +135,10 @@ vectExpr (_, AnnLet (AnnRec bs) body) vect_rhs bndr rhs = localV . inBind bndr . liftM (\(_,_,z)->z) - $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs + $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) [] rhs vectExpr e@(_, AnnLam bndr _) - | isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False e + | isId bndr = liftM (\(_,_,z) ->z) $ vectFnExpr True False [] e {- onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body) `orElseV` vectLam True fvs bs body @@ -152,18 +153,19 @@ vectExpr e = cantVectorise "Can't vectorise expression (vectExpr)" (ppr $ deAnno vectFnExpr :: Bool -- ^ When the RHS of a binding, whether that binding should be inlined. -> Bool -- ^ Whether the binding is a loop breaker. + -> [Var] -> CoreExprWithFVs -- ^ Expression to vectorise. Must have an outer `AnnLam`. -> VM (Inline, Bool, VExpr) -vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _) - | isId bndr = pprTrace "vectFnExpr -- id" (ppr fvs )$ +vectFnExpr inline loop_breaker recFns e@(fvs, AnnLam bndr _) + | isId bndr = -- pprTrace "vectFnExpr -- id" (ppr fvs )$ onlyIfV True -- (isEmptyVarSet fvs) -- we check for free variables later. TODO: clean up - (mark DontInline True . vectScalarLam bs $ deAnnotate body) + (mark DontInline True . vectScalarLam bs recFns $ deAnnotate body) `orElseV` mark inlineMe False (vectLam inline loop_breaker fvs bs body) where (bs,body) = collectAnnValBinders e -vectFnExpr _ _ e = pprTrace "vectFnExpr -- otherwise" (ppr "a" )$ mark DontInline False $ vectExpr e +vectFnExpr _ _ _ e = pprTrace "vectFnExpr -- otherwise" (ppr "a" )$ mark DontInline False $ vectExpr e mark :: Inline -> Bool -> VM a -> VM (Inline, Bool, a) mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) } @@ -172,13 +174,18 @@ mark b isScalarFn p = do { x <- p; return (b, isScalarFn, x) } -- | Vectorise a function where are the args have scalar type, -- that is Int, Float, Double etc. vectScalarLam - :: [Var] -- ^ Bound variables of function. + :: [Var] -- ^ Bound variables of function + -> [Var] -> CoreExpr -- ^ Function body. -> VM VExpr -vectScalarLam args body - = do scalars <- globalScalars - pprTrace "vectScalarLam" (ppr $ is_scalar (extendVarSetList scalars args) body) $ +vectScalarLam args recFns body + = do scalars' <- globalScalars + let scalars = unionVarSet (mkVarSet recFns) scalars' +{- pprTrace "vectScalarLam uses" (ppr $ uses scalars body) $ + pprTrace "vectScalarLam is prim res" (ppr $ is_prim_ty res_ty) $ + pprTrace "vectScalarLam is scalar body" (ppr $ is_scalar (extendVarSetList scalars args) body) $ + pprTrace "vectScalarLam arg tys" (ppr $ arg_tys) $ -} onlyIfV (all is_prim_ty arg_tys && is_prim_ty res_ty && is_scalar (extendVarSetList scalars args) body @@ -190,7 +197,7 @@ vectScalarLam args body (zipf `App` Var fn_var) clo_var <- hoistExpr (fsLit "clo") clo DontInline lclo <- liftPD (Var clo_var) - pprTrace " lam is scalar" (ppr "") $ + {- pprTrace " lam is scalar" (ppr "") $ -} return (Var clo_var, lclo) where arg_tys = map idType args @@ -214,7 +221,7 @@ vectScalarLam args body | isPrimTyCon tycon = False | isAbstractTyCon tycon = True | isFunTyCon tycon || isProductTyCon tycon || isTupleTyCon tycon = any (maybe_parr_ty' alreadySeen) args - | isDataTyCon tycon = pprTrace "isDataTyCon" (ppr tycon) $ + | isDataTyCon tycon = -- pprTrace "isDataTyCon" (ppr tycon) $ any (maybe_parr_ty' alreadySeen) args || hasParrDataCon alreadySeen tycon | otherwise = True diff --git a/compiler/vectorise/Vectorise/Monad.hs b/compiler/vectorise/Vectorise/Monad.hs index 77b9b7f..2597430 100644 --- a/compiler/vectorise/Vectorise/Monad.hs +++ b/compiler/vectorise/Vectorise/Monad.hs @@ -17,7 +17,8 @@ module Vectorise.Monad ( maybeCantVectoriseVarM, dumpVar, addGlobalScalar, - + deleteGlobalScalar, + -- * Primitives lookupPrimPArray, lookupPrimMethod @@ -146,6 +147,11 @@ addGlobalScalar :: Var -> VM () addGlobalScalar var = updGEnv $ \env -> pprTrace "addGLobalScalar" (ppr var) env{global_scalars = extendVarSet (global_scalars env) var} +deleteGlobalScalar :: Var -> VM () +deleteGlobalScalar var + = updGEnv $ \env -> pprTrace "deleteGLobalScalar" (ppr var) env{global_scalars = delVarSet (global_scalars env) var} + + -- Primitives ----------------------------------------------------------------- lookupPrimPArray :: TyCon -> VM (Maybe TyCon) lookupPrimPArray = liftBuiltinDs . primPArray -- 1.7.10.4