From: Roman Leshchinskiy Date: Fri, 30 Oct 2009 00:41:37 +0000 (+0000) Subject: Adapt vectoriser to new inlining mechanism X-Git-Url: http://git.megacz.com/?p=ghc-hetmet.git;a=commitdiff_plain;h=222415a5b658e737a0a1f2c980c6f80635289f75 Adapt vectoriser to new inlining mechanism --- diff --git a/compiler/vectorise/VectCore.hs b/compiler/vectorise/VectCore.hs index d651526..cdae4dd 100644 --- a/compiler/vectorise/VectCore.hs +++ b/compiler/vectorise/VectCore.hs @@ -10,7 +10,7 @@ module VectCore ( vVar, vType, vNote, vLet, vLams, vLamsWithoutLC, vVarApps, - vCaseDEFAULT, vInlineMe + vCaseDEFAULT ) where #include "HsVersions.h" @@ -18,7 +18,6 @@ module VectCore ( import CoreSyn import Type ( Type ) import Var -import Outputable type Vect a = (a,a) type VVar = Vect Var @@ -83,8 +82,3 @@ vCaseDEFAULT (vscrut, lscrut) (vbndr, lbndr) vty lty (vbody, lbody) where mkDEFAULT e = [(DEFAULT, [], e)] -vInlineMe :: VExpr -> VExpr -vInlineMe (vexpr, lexpr) = (mkInlineMe vexpr, mkInlineMe lexpr) - -mkInlineMe :: CoreExpr -> CoreExpr -mkInlineMe = pprTrace "VectCore.mkInlineMe" (text "Roman: need to replace mkInlineMe with an InlineRule somehow") diff --git a/compiler/vectorise/VectType.hs b/compiler/vectorise/VectType.hs index 7b9ec50..6e7557e 100644 --- a/compiler/vectorise/VectType.hs +++ b/compiler/vectorise/VectType.hs @@ -11,6 +11,7 @@ import VectCore import HscTypes ( TypeEnv, extendTypeEnvList, typeEnvTyCons ) import CoreSyn import CoreUtils +import CoreUnfold import MkCore ( mkWildCase ) import BuildTyCl import DataCon @@ -20,9 +21,11 @@ import TypeRep import Coercion import FamInstEnv ( FamInst, mkLocalFamInst ) import OccName +import Id import MkId -import BasicTypes ( StrictnessMark(..), boolToRecFlag ) -import Var ( Var, TyVar ) +import BasicTypes ( StrictnessMark(..), boolToRecFlag, + dfunInlinePragma ) +import Var ( Var, TyVar, varType ) import Name ( Name, getOccName ) import NameEnv @@ -37,7 +40,7 @@ import FastString import MonadUtils ( zipWith3M, foldrM, concatMapM ) import Control.Monad ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM ) -import Data.List ( inits, tails, zipWith4, zipWith6 ) +import Data.List ( inits, tails, zipWith4, zipWith5 ) -- ---------------------------------------------------------------------------- -- Types @@ -119,26 +122,28 @@ vectTypeEnv env let orig_tcs = keep_tcs ++ conv_tcs vect_tcs = keep_tcs ++ new_tcs - dfuns <- mapM mkPADFun vect_tcs - defTyConPAs (zip vect_tcs dfuns) - reprs <- mapM tyConRepr vect_tcs - repr_tcs <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs - pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs - binds <- sequence (zipWith6 buildTyConBindings orig_tcs - vect_tcs - repr_tcs - pdata_tcs - dfuns - reprs) - - let all_new_tcs = new_tcs ++ repr_tcs ++ pdata_tcs + (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) -> + do + defTyConPAs (zipLazy vect_tcs dfuns') + reprs <- mapM tyConRepr vect_tcs + repr_tcs <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs + pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs + dfuns <- sequence $ zipWith5 buildTyConBindings orig_tcs + vect_tcs + repr_tcs + pdata_tcs + reprs + binds <- takeHoisted + return (dfuns, binds, repr_tcs ++ pdata_tcs) + + let all_new_tcs = new_tcs ++ inst_tcs let new_env = extendTypeEnvList env (map ATyCon all_new_tcs ++ [ADataCon dc | tc <- all_new_tcs , dc <- tyConDataCons tc]) - return (new_env, map mkLocalFamInst (repr_tcs ++ pdata_tcs), concat binds) + return (new_env, map mkLocalFamInst inst_tcs, binds) where tycons = typeEnvTyCons env groups = tyConGroups tycons @@ -715,18 +720,12 @@ buildPDataDataCon orig_name vect_tc repr_tc repr comp_ty r = mkPDataType (compOrigType r) -mkPADFun :: TyCon -> VM Var -mkPADFun vect_tc - = newExportedVar (mkPADFunOcc $ getOccName vect_tc) =<< paDFunType vect_tc - -buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> Var -> SumRepr - -> VM [(Var, CoreExpr)] -buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc dfun repr +buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr + -> VM Var +buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr = do vectDataConWorkers orig_tc vect_tc pdata_tc - dict <- buildPADict vect_tc prepr_tc pdata_tc repr - binds <- takeHoisted - return $ (dfun, dict) : binds + buildPADict vect_tc prepr_tc pdata_tc repr vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM () vectDataConWorkers orig_tc vect_tc arr_tc @@ -781,53 +780,71 @@ vectDataConWorkers orig_tc vect_tc arr_tc def_worker data_con arg_tys mk_body = do + arity <- polyArity tyvars body <- closedV . inBind orig_worker - . polyAbstract tyvars $ \abstract -> - liftM (abstract . vectorised) + . polyAbstract tyvars $ \args -> + liftM (mkLams (tyvars ++ args) . vectorised) $ buildClosures tyvars [] arg_tys res_ty mk_body - vect_worker <- cloneId mkVectOcc orig_worker (exprType body) + raw_worker <- cloneId mkVectOcc orig_worker (exprType body) + let vect_worker = raw_worker `setIdUnfolding` + mkInlineRule InlSat body arity defGlobalVar orig_worker vect_worker return (vect_worker, body) where orig_worker = dataConWorkId data_con -buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr +buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM Var buildPADict vect_tc prepr_tc arr_tc repr - = polyAbstract tvs $ \abstract -> + = polyAbstract tvs $ \args -> do - meth_binds <- mapM mk_method paMethods - let meth_exprs = map (Var . fst) meth_binds + method_ids <- mapM (method args) paMethods + + pa_tc <- builtin paTyCon + pa_con <- builtin paDataCon + let dict = mkLams (tvs ++ args) + $ mkConApp pa_con + $ Type inst_ty : map (method_call args) method_ids + + dfun_ty = mkForAllTys tvs + $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty]) + + raw_dfun <- newExportedVar dfun_name dfun_ty + let dfun = raw_dfun `setIdUnfolding` mkDFunUnfolding pa_con method_ids + `setInlinePragma` dfunInlinePragma - pa_dc <- builtin paDataCon - let dict = mkConApp pa_dc (Type (mkTyConApp vect_tc arg_tys) : meth_exprs) - body = Let (Rec meth_binds) dict - return . mkInlineMe $ abstract body + hoistBinding dfun dict + return dfun where - tvs = tyConTyVars arr_tc + tvs = tyConTyVars vect_tc arg_tys = mkTyVarTys tvs + inst_ty = mkTyConApp vect_tc arg_tys - mk_method (name, build) + dfun_name = mkPADFunOcc (getOccName vect_tc) + + method args (name, build) = localV $ do - body <- build vect_tc prepr_tc arr_tc repr - var <- newLocalVar name (exprType body) - return (var, mkInlineMe body) - --- The InlineMe note has gone away. Instead, you need to use --- CoreUnfold.mkInlineRule to make an InlineRule for the thing, and --- attach *that* as the unfolding for the dictionary binder -mkInlineMe :: CoreExpr -> CoreExpr -mkInlineMe expr = pprTrace "VectType: Roman, you need to use the new InlineRule story" - (ppr expr) expr - -paMethods :: [(FastString, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)] -paMethods = [(fsLit "dictPRepr", buildPRDict), - (fsLit "toPRepr", buildToPRepr), - (fsLit "fromPRepr", buildFromPRepr), - (fsLit "toArrPRepr", buildToArrPRepr), - (fsLit "fromArrPRepr", buildFromArrPRepr)] + expr <- build vect_tc prepr_tc arr_tc repr + let body = mkLams (tvs ++ args) expr + raw_var <- newExportedVar (method_name name) (exprType body) + let var = raw_var + `setIdUnfolding` mkInlineRule InlSat body (length args) + hoistBinding var body + return var + + method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args) + + method_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name) + + +paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)] +paMethods = [("dictPRepr", buildPRDict), + ("toPRepr", buildToPRepr), + ("fromPRepr", buildFromPRepr), + ("toArrPRepr", buildToArrPRepr), + ("fromArrPRepr", buildFromArrPRepr)] -- | Split the given tycons into two sets depending on whether they have to be -- converted (first list) or not (second list). The first argument contains diff --git a/compiler/vectorise/VectUtils.hs b/compiler/vectorise/VectUtils.hs index e508424..9faa0ed 100644 --- a/compiler/vectorise/VectUtils.hs +++ b/compiler/vectorise/VectUtils.hs @@ -15,7 +15,8 @@ module VectUtils ( combinePD, liftPD, zipScalars, scalarClosure, - polyAbstract, polyApply, polyVApply, + polyAbstract, polyApply, polyVApply, polyArity, + Inline(..), addInlineArity, inlineMe, hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted, buildClosure, buildClosures, mkClosureApp @@ -27,6 +28,7 @@ import VectMonad import MkCore ( mkCoreTup, mkCoreTupTy, mkWildCase ) import CoreSyn import CoreUtils +import CoreUnfold ( mkInlineRule ) import Coercion import Type import TypeRep @@ -34,6 +36,7 @@ import TyCon import DataCon import Var import MkId ( unwrapFamInstScrut ) +import Id ( setIdUnfolding ) import TysWiredIn import BasicTypes ( Boxity(..) ) import Literal ( Literal, mkMachInt ) @@ -43,7 +46,6 @@ import FastString import Control.Monad - collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type]) collectAnnTypeArgs expr = go expr [] where @@ -315,13 +317,14 @@ newLocalVVar fs vty lv <- newLocalVar fs lty return (vv,lv) -polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a +polyAbstract :: [TyVar] -> ([Var] -> VM a) -> VM a polyAbstract tvs p = localV $ do mdicts <- mapM mk_dict_var tvs - zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts - p (mk_lams mdicts) + zipWithM_ (\tv -> maybe (defLocalTyVar tv) + (defLocalTyVarWithPA tv . Var)) tvs mdicts + p (mk_args mdicts) where mk_dict_var tv = do r <- paDictArgType tv @@ -329,7 +332,12 @@ polyAbstract tvs p Just ty -> liftM Just (newLocalVar (fsLit "dPA") ty) Nothing -> return Nothing - mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts]) + mk_args mdicts = [dict | Just dict <- mdicts] + +polyArity :: [TyVar] -> VM Int +polyArity tvs = do + tys <- mapM paDictArgType tvs + return $ length [() | Just _ <- tys] polyApply :: CoreExpr -> [Type] -> VM CoreExpr polyApply expr tys @@ -343,31 +351,48 @@ polyVApply expr tys dicts <- mapM paDictOfType tys return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr + +data Inline = Inline Int -- arity + | DontInline + +addInlineArity :: Inline -> Int -> Inline +addInlineArity (Inline m) n = Inline (m+n) +addInlineArity DontInline _ = DontInline + +inlineMe :: Inline +inlineMe = Inline 0 + hoistBinding :: Var -> CoreExpr -> VM () hoistBinding v e = updGEnv $ \env -> env { global_bindings = (v,e) : global_bindings env } -hoistExpr :: FastString -> CoreExpr -> VM Var -hoistExpr fs expr +hoistExpr :: FastString -> CoreExpr -> Inline -> VM Var +hoistExpr fs expr inl = do - var <- newLocalVar fs (exprType expr) + var <- mk_inline `liftM` newLocalVar fs (exprType expr) hoistBinding var expr return var + where + mk_inline var = case inl of + Inline arity -> var `setIdUnfolding` + mkInlineRule InlSat expr arity + DontInline -> var -hoistVExpr :: VExpr -> VM VVar -hoistVExpr (ve, le) +hoistVExpr :: VExpr -> Inline -> VM VVar +hoistVExpr (ve, le) inl = do fs <- getBindName - vv <- hoistExpr ('v' `consFS` fs) ve - lv <- hoistExpr ('l' `consFS` fs) le + vv <- hoistExpr ('v' `consFS` fs) ve inl + lv <- hoistExpr ('l' `consFS` fs) le (addInlineArity inl 1) return (vv, lv) -hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr -hoistPolyVExpr tvs p +hoistPolyVExpr :: [TyVar] -> Inline -> VM VExpr -> VM VExpr +hoistPolyVExpr tvs inline p = do - expr <- closedV . polyAbstract tvs $ \abstract -> - liftM (mapVect abstract) p - fn <- hoistVExpr expr + inline' <- liftM (addInlineArity inline) (polyArity tvs) + expr <- closedV . polyAbstract tvs $ \args -> + liftM (mapVect (mkLams $ tvs ++ args)) p + fn <- hoistVExpr expr inline' polyVApply (vVar fn) (mkTyVarTys tvs) takeHoisted :: VM [(Var, CoreExpr)] @@ -413,14 +438,15 @@ buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr buildClosures _ _ [] _ mk_body = mk_body buildClosures tvs vars [arg_ty] res_ty mk_body - = liftM vInlineMe (buildClosure tvs vars arg_ty res_ty mk_body) + = -- liftM vInlineMe $ + buildClosure tvs vars arg_ty res_ty mk_body buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body = do res_ty' <- mkClosureTypes arg_tys res_ty arg <- newLocalVVar (fsLit "x") arg_ty - liftM vInlineMe - . buildClosure tvs vars arg_ty res_ty' - . hoistPolyVExpr tvs + -- liftM vInlineMe + buildClosure tvs vars arg_ty res_ty' + . hoistPolyVExpr tvs (Inline (length vars + 1)) $ do lc <- builtin liftingContext clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body @@ -438,11 +464,11 @@ buildClosure tvs vars arg_ty res_ty mk_body env_bndr <- newLocalVVar (fsLit "env") env_ty arg_bndr <- newLocalVVar (fsLit "arg") arg_ty - fn <- hoistPolyVExpr tvs + fn <- hoistPolyVExpr tvs (Inline 2) $ do lc <- builtin liftingContext body <- mk_body - return . vInlineMe + return -- . vInlineMe . vLams lc [env_bndr, arg_bndr] $ bind (vVar env_bndr) (vVarApps lc body (vars ++ [arg_bndr])) diff --git a/compiler/vectorise/Vectorise.hs b/compiler/vectorise/Vectorise.hs index 2bce391..59fded3 100644 --- a/compiler/vectorise/Vectorise.hs +++ b/compiler/vectorise/Vectorise.hs @@ -12,6 +12,7 @@ import HscTypes hiding ( MonadThings(..) ) import Module ( PackageId ) import CoreSyn import CoreUtils +import CoreUnfold ( mkInlineRule ) import MkCore ( mkWildCase ) import CoreFVs import CoreMonad ( CoreM, getHscEnv ) @@ -24,6 +25,7 @@ import VarEnv import VarSet import Id import OccName +import BasicTypes ( isLoopBreaker ) import Literal ( Literal, mkMachInt ) import TysWiredIn @@ -31,7 +33,8 @@ import TysPrim ( intPrimTy ) import Outputable import FastString -import Control.Monad ( liftM, liftM2, zipWithM ) +import Util ( zipLazy ) +import Control.Monad import Data.List ( sortBy, unzip4 ) vectorise :: PackageId -> ModGuts -> CoreM ModGuts @@ -67,8 +70,8 @@ vectModule guts vectTopBind :: CoreBind -> VM CoreBind vectTopBind b@(NonRec var expr) = do - var' <- vectTopBinder var - expr' <- vectTopRhs var expr + (inline, expr') <- vectTopRhs var expr + var' <- vectTopBinder var inline expr' hs <- takeHoisted cexpr <- tryConvert var var' expr return . Rec $ (var, cexpr) : (var', expr') : hs @@ -77,8 +80,13 @@ vectTopBind b@(NonRec var expr) vectTopBind b@(Rec bs) = do - vars' <- mapM vectTopBinder vars - exprs' <- zipWithM vectTopRhs vars exprs + (vars', _, exprs') <- fixV $ \ ~(_, inlines, rhss) -> + do + vars' <- sequence [vectTopBinder var inline rhs + | (var, ~(inline, rhs)) + <- zipLazy vars (zip inlines rhss)] + (inlines', exprs') <- mapAndUnzipM (uncurry vectTopRhs) bs + return (vars', inlines', exprs') hs <- takeHoisted cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs @@ -87,20 +95,28 @@ vectTopBind b@(Rec bs) where (vars, exprs) = unzip bs -vectTopBinder :: Var -> VM Var -vectTopBinder var +-- NOTE: vectTopBinder *MUST* be lazy in inline and expr because of how it is +-- used inside of fixV in vectTopBind +vectTopBinder :: Var -> Inline -> CoreExpr -> VM Var +vectTopBinder var inline expr = do vty <- vectType (idType var) - var' <- cloneId mkVectOcc var vty + var' <- liftM (`setIdUnfolding` unfolding) $ cloneId mkVectOcc var vty defGlobalVar var var' return var' + where + unfolding = case inline of + Inline arity -> mkInlineRule InlSat expr arity + DontInline -> noUnfolding -vectTopRhs :: Var -> CoreExpr -> VM CoreExpr +vectTopRhs :: Var -> CoreExpr -> VM (Inline, CoreExpr) vectTopRhs var expr - = do - closedV . liftM vectorised - . inBind var - $ vectPolyExpr (freeVars expr) + = closedV + $ do + (inline, vexpr) <- inBind var + $ vectPolyExpr (isLoopBreaker $ idOccInfo var) + (freeVars expr) + return (inline, vectorised vexpr) tryConvert :: Var -> Var -> CoreExpr -> VM CoreExpr tryConvert var vect_var rhs @@ -187,14 +203,19 @@ vectLiteral lit lexpr <- liftPD (Lit lit) return (Lit lit, lexpr) -vectPolyExpr :: CoreExprWithFVs -> VM VExpr -vectPolyExpr (_, AnnNote note expr) - = liftM (vNote note) $ vectPolyExpr expr -vectPolyExpr expr - = polyAbstract tvs $ \abstract -> - do - mono' <- vectFnExpr False mono - return $ mapVect abstract mono' +vectPolyExpr :: Bool -> CoreExprWithFVs -> VM (Inline, VExpr) +vectPolyExpr loop_breaker (_, AnnNote note expr) + = do + (inline, expr') <- vectPolyExpr loop_breaker expr + return (inline, vNote note expr') +vectPolyExpr loop_breaker expr + = do + arity <- polyArity tvs + polyAbstract tvs $ \args -> + do + (inline, mono') <- vectFnExpr False loop_breaker mono + return (addInlineArity inline arity, + mapVect (mkLams $ tvs ++ args) mono') where (tvs, mono) = collectAnnTypeBinders expr @@ -245,7 +266,7 @@ vectExpr (_, AnnCase scrut bndr ty alts) vectExpr (_, AnnLet (AnnNonRec bndr rhs) body) = do - vrhs <- localV . inBind bndr $ vectPolyExpr rhs + vrhs <- localV . inBind bndr . liftM snd $ vectPolyExpr False rhs (vbndr, vbody) <- vectBndrIn bndr (vectExpr body) return $ vLet (vNonRec vbndr vrhs) vbody @@ -254,17 +275,18 @@ vectExpr (_, AnnLet (AnnRec bs) body) (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs $ liftM2 (,) (zipWithM vect_rhs bndrs rhss) - (vectPolyExpr body) + (vectExpr body) return $ vLet (vRec vbndrs vrhss) vbody where (bndrs, rhss) = unzip bs vect_rhs bndr rhs = localV . inBind bndr - $ vectExpr rhs + . liftM snd + $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs vectExpr e@(_, AnnLam bndr _) - | isId bndr = vectFnExpr True e + | isId bndr = liftM snd $ vectFnExpr True False e {- onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body) `orElseV` vectLam True fvs bs body @@ -274,14 +296,17 @@ onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body) vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e) -vectFnExpr :: Bool -> CoreExprWithFVs -> VM VExpr -vectFnExpr inline e@(fvs, AnnLam bndr _) - | isId bndr = onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body) - `orElseV` vectLam inline fvs bs body +vectFnExpr :: Bool -> Bool -> CoreExprWithFVs -> VM (Inline, VExpr) +vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _) + | isId bndr = onlyIfV (isEmptyVarSet fvs) + (mark DontInline . vectScalarLam bs $ deAnnotate body) + `orElseV` mark inlineMe (vectLam inline loop_breaker fvs bs body) where (bs,body) = collectAnnValBinders e -vectFnExpr _ e = vectExpr e +vectFnExpr _ _ e = mark DontInline $ vectExpr e +mark :: Inline -> VM a -> VM (Inline, a) +mark b p = do { x <- p; return (b,x) } vectScalarLam :: [Var] -> CoreExpr -> VM VExpr vectScalarLam args body @@ -291,11 +316,11 @@ vectScalarLam args body && is_scalar_ty res_ty && is_scalar (extendVarSetList scalars args) body) $ do - fn_var <- hoistExpr (fsLit "fn") (mkLams args body) + fn_var <- hoistExpr (fsLit "fn") (mkLams args body) DontInline zipf <- zipScalars arg_tys res_ty clo <- scalarClosure arg_tys res_ty (Var fn_var) (zipf `App` Var fn_var) - clo_var <- hoistExpr (fsLit "clo") clo + clo_var <- hoistExpr (fsLit "clo") clo DontInline lclo <- liftPD (Var clo_var) return (Var clo_var, lclo) where @@ -314,8 +339,8 @@ vectScalarLam args body is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2 is_scalar _ _ = False -vectLam :: Bool -> VarSet -> [Var] -> CoreExprWithFVs -> VM VExpr -vectLam inline fvs bs body +vectLam :: Bool -> Bool -> VarSet -> [Var] -> CoreExprWithFVs -> VM VExpr +vectLam inline loop_breaker fvs bs body = do tyvars <- localTyVars (vs, vvs) <- readLEnv $ \env -> @@ -326,14 +351,28 @@ vectLam inline fvs bs body res_ty <- vectType (exprType $ deAnnotate body) buildClosures tyvars vvs arg_tys res_ty - . hoistPolyVExpr tyvars + . hoistPolyVExpr tyvars (maybe_inline (length vs + length bs)) $ do lc <- builtin liftingContext (vbndrs, vbody) <- vectBndrsIn (vs ++ bs) (vectExpr body) - return . maybe_inline $ vLams lc vbndrs vbody + vbody' <- break_loop lc res_ty vbody + return $ vLams lc vbndrs vbody' where - maybe_inline = if inline then vInlineMe else id + maybe_inline n | inline = Inline n + | otherwise = DontInline + + break_loop lc ty (ve, le) + | loop_breaker + = do + empty <- emptyPD ty + lty <- mkPDataType ty + return (ve, mkWildCase (Var lc) intPrimTy lty + [(DEFAULT, [], le), + (LitAlt (mkMachInt 0), [], empty)]) + + | otherwise = return (ve, le) + vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr vectTyAppExpr (_, AnnVar v) tys = vectPolyVar v tys @@ -441,7 +480,7 @@ vectAlgCase tycon _ty_args scrut bndr ty alts cmp _ DEFAULT = GT cmp _ _ = panic "vectAlgCase/cmp" - proc_alt arity sel vty lty (DataAlt dc, bndrs, body) + proc_alt arity sel _ lty (DataAlt dc, bndrs, body) = do vect_dc <- maybeV (lookupDataCon dc) let ntag = dataConTagZ vect_dc