X-Git-Url: http://git.megacz.com/?p=ghc-hetmet.git;a=blobdiff_plain;f=compiler%2Fvectorise%2FVectorise.hs;h=ea69c4ff6e8845a48a940e7845ab1c21e37fb41a;hp=2bce391a8f72788ee8d8ccc36116232cad754943;hb=c5af8d12b7a6e67f31b4792c1ebe3dc53326c1f3;hpb=cfccfa67393fcf8cb43aaa465d421b67c7117580 diff --git a/compiler/vectorise/Vectorise.hs b/compiler/vectorise/Vectorise.hs index 2bce391..ea69c4f 100644 --- a/compiler/vectorise/Vectorise.hs +++ b/compiler/vectorise/Vectorise.hs @@ -12,6 +12,7 @@ import HscTypes hiding ( MonadThings(..) ) import Module ( PackageId ) import CoreSyn import CoreUtils +import CoreUnfold ( mkInlineRule ) import MkCore ( mkWildCase ) import CoreFVs import CoreMonad ( CoreM, getHscEnv ) @@ -24,6 +25,7 @@ import VarEnv import VarSet import Id import OccName +import BasicTypes ( isLoopBreaker ) import Literal ( Literal, mkMachInt ) import TysWiredIn @@ -31,7 +33,8 @@ import TysPrim ( intPrimTy ) import Outputable import FastString -import Control.Monad ( liftM, liftM2, zipWithM ) +import Util ( zipLazy ) +import Control.Monad import Data.List ( sortBy, unzip4 ) vectorise :: PackageId -> ModGuts -> CoreM ModGuts @@ -67,8 +70,8 @@ vectModule guts vectTopBind :: CoreBind -> VM CoreBind vectTopBind b@(NonRec var expr) = do - var' <- vectTopBinder var - expr' <- vectTopRhs var expr + (inline, expr') <- vectTopRhs var expr + var' <- vectTopBinder var inline expr' hs <- takeHoisted cexpr <- tryConvert var var' expr return . Rec $ (var, cexpr) : (var', expr') : hs @@ -77,8 +80,13 @@ vectTopBind b@(NonRec var expr) vectTopBind b@(Rec bs) = do - vars' <- mapM vectTopBinder vars - exprs' <- zipWithM vectTopRhs vars exprs + (vars', _, exprs') <- fixV $ \ ~(_, inlines, rhss) -> + do + vars' <- sequence [vectTopBinder var inline rhs + | (var, ~(inline, rhs)) + <- zipLazy vars (zip inlines rhss)] + (inlines', exprs') <- mapAndUnzipM (uncurry vectTopRhs) bs + return (vars', inlines', exprs') hs <- takeHoisted cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs @@ -87,20 +95,28 @@ vectTopBind b@(Rec bs) where (vars, exprs) = unzip bs -vectTopBinder :: Var -> VM Var -vectTopBinder var +-- NOTE: vectTopBinder *MUST* be lazy in inline and expr because of how it is +-- used inside of fixV in vectTopBind +vectTopBinder :: Var -> Inline -> CoreExpr -> VM Var +vectTopBinder var inline expr = do vty <- vectType (idType var) - var' <- cloneId mkVectOcc var vty + var' <- liftM (`setIdUnfolding` unfolding) $ cloneId mkVectOcc var vty defGlobalVar var var' return var' + where + unfolding = case inline of + Inline arity -> mkInlineRule expr (Just arity) + DontInline -> noUnfolding -vectTopRhs :: Var -> CoreExpr -> VM CoreExpr +vectTopRhs :: Var -> CoreExpr -> VM (Inline, CoreExpr) vectTopRhs var expr - = do - closedV . liftM vectorised - . inBind var - $ vectPolyExpr (freeVars expr) + = closedV + $ do + (inline, vexpr) <- inBind var + $ vectPolyExpr (isLoopBreaker $ idOccInfo var) + (freeVars expr) + return (inline, vectorised vexpr) tryConvert :: Var -> Var -> CoreExpr -> VM CoreExpr tryConvert var vect_var rhs @@ -187,14 +203,19 @@ vectLiteral lit lexpr <- liftPD (Lit lit) return (Lit lit, lexpr) -vectPolyExpr :: CoreExprWithFVs -> VM VExpr -vectPolyExpr (_, AnnNote note expr) - = liftM (vNote note) $ vectPolyExpr expr -vectPolyExpr expr - = polyAbstract tvs $ \abstract -> - do - mono' <- vectFnExpr False mono - return $ mapVect abstract mono' +vectPolyExpr :: Bool -> CoreExprWithFVs -> VM (Inline, VExpr) +vectPolyExpr loop_breaker (_, AnnNote note expr) + = do + (inline, expr') <- vectPolyExpr loop_breaker expr + return (inline, vNote note expr') +vectPolyExpr loop_breaker expr + = do + arity <- polyArity tvs + polyAbstract tvs $ \args -> + do + (inline, mono') <- vectFnExpr False loop_breaker mono + return (addInlineArity inline arity, + mapVect (mkLams $ tvs ++ args) mono') where (tvs, mono) = collectAnnTypeBinders expr @@ -245,7 +266,7 @@ vectExpr (_, AnnCase scrut bndr ty alts) vectExpr (_, AnnLet (AnnNonRec bndr rhs) body) = do - vrhs <- localV . inBind bndr $ vectPolyExpr rhs + vrhs <- localV . inBind bndr . liftM snd $ vectPolyExpr False rhs (vbndr, vbody) <- vectBndrIn bndr (vectExpr body) return $ vLet (vNonRec vbndr vrhs) vbody @@ -254,17 +275,18 @@ vectExpr (_, AnnLet (AnnRec bs) body) (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs $ liftM2 (,) (zipWithM vect_rhs bndrs rhss) - (vectPolyExpr body) + (vectExpr body) return $ vLet (vRec vbndrs vrhss) vbody where (bndrs, rhss) = unzip bs vect_rhs bndr rhs = localV . inBind bndr - $ vectExpr rhs + . liftM snd + $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs vectExpr e@(_, AnnLam bndr _) - | isId bndr = vectFnExpr True e + | isId bndr = liftM snd $ vectFnExpr True False e {- onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body) `orElseV` vectLam True fvs bs body @@ -274,14 +296,17 @@ onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body) vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e) -vectFnExpr :: Bool -> CoreExprWithFVs -> VM VExpr -vectFnExpr inline e@(fvs, AnnLam bndr _) - | isId bndr = onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body) - `orElseV` vectLam inline fvs bs body +vectFnExpr :: Bool -> Bool -> CoreExprWithFVs -> VM (Inline, VExpr) +vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _) + | isId bndr = onlyIfV (isEmptyVarSet fvs) + (mark DontInline . vectScalarLam bs $ deAnnotate body) + `orElseV` mark inlineMe (vectLam inline loop_breaker fvs bs body) where (bs,body) = collectAnnValBinders e -vectFnExpr _ e = vectExpr e +vectFnExpr _ _ e = mark DontInline $ vectExpr e +mark :: Inline -> VM a -> VM (Inline, a) +mark b p = do { x <- p; return (b,x) } vectScalarLam :: [Var] -> CoreExpr -> VM VExpr vectScalarLam args body @@ -289,13 +314,14 @@ vectScalarLam args body scalars <- globalScalars onlyIfV (all is_scalar_ty arg_tys && is_scalar_ty res_ty - && is_scalar (extendVarSetList scalars args) body) + && is_scalar (extendVarSetList scalars args) body + && uses scalars body) $ do - fn_var <- hoistExpr (fsLit "fn") (mkLams args body) + fn_var <- hoistExpr (fsLit "fn") (mkLams args body) DontInline zipf <- zipScalars arg_tys res_ty clo <- scalarClosure arg_tys res_ty (Var fn_var) (zipf `App` Var fn_var) - clo_var <- hoistExpr (fsLit "clo") clo + clo_var <- hoistExpr (fsLit "clo") clo DontInline lclo <- liftPD (Var clo_var) return (Var clo_var, lclo) where @@ -314,8 +340,16 @@ vectScalarLam args body is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2 is_scalar _ _ = False -vectLam :: Bool -> VarSet -> [Var] -> CoreExprWithFVs -> VM VExpr -vectLam inline fvs bs body + -- A scalar function has to actually compute something. Without the check, + -- we would treat (\(x :: Int) -> x) as a scalar function and lift it to + -- (map (\x -> x)) which is very bad. Normal lifting transforms it to + -- (\n# x -> x) which is what we want. + uses funs (Var v) = v `elemVarSet` funs + uses funs (App e1 e2) = uses funs e1 || uses funs e2 + uses _ _ = False + +vectLam :: Bool -> Bool -> VarSet -> [Var] -> CoreExprWithFVs -> VM VExpr +vectLam inline loop_breaker fvs bs body = do tyvars <- localTyVars (vs, vvs) <- readLEnv $ \env -> @@ -326,14 +360,28 @@ vectLam inline fvs bs body res_ty <- vectType (exprType $ deAnnotate body) buildClosures tyvars vvs arg_tys res_ty - . hoistPolyVExpr tyvars + . hoistPolyVExpr tyvars (maybe_inline (length vs + length bs)) $ do lc <- builtin liftingContext (vbndrs, vbody) <- vectBndrsIn (vs ++ bs) (vectExpr body) - return . maybe_inline $ vLams lc vbndrs vbody + vbody' <- break_loop lc res_ty vbody + return $ vLams lc vbndrs vbody' where - maybe_inline = if inline then vInlineMe else id + maybe_inline n | inline = Inline n + | otherwise = DontInline + + break_loop lc ty (ve, le) + | loop_breaker + = do + empty <- emptyPD ty + lty <- mkPDataType ty + return (ve, mkWildCase (Var lc) intPrimTy lty + [(DEFAULT, [], le), + (LitAlt (mkMachInt 0), [], empty)]) + + | otherwise = return (ve, le) + vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr vectTyAppExpr (_, AnnVar v) tys = vectPolyVar v tys @@ -441,7 +489,7 @@ vectAlgCase tycon _ty_args scrut bndr ty alts cmp _ DEFAULT = GT cmp _ _ = panic "vectAlgCase/cmp" - proc_alt arity sel vty lty (DataAlt dc, bndrs, body) + proc_alt arity sel _ lty (DataAlt dc, bndrs, body) = do vect_dc <- maybeV (lookupDataCon dc) let ntag = dataConTagZ vect_dc