Adapt vectoriser to new inlining mechanism
authorRoman Leshchinskiy <rl@cse.unsw.edu.au>
Fri, 30 Oct 2009 00:41:37 +0000 (00:41 +0000)
committerRoman Leshchinskiy <rl@cse.unsw.edu.au>
Fri, 30 Oct 2009 00:41:37 +0000 (00:41 +0000)
compiler/vectorise/VectCore.hs
compiler/vectorise/VectType.hs
compiler/vectorise/VectUtils.hs
compiler/vectorise/Vectorise.hs

index d651526..cdae4dd 100644 (file)
@@ -10,7 +10,7 @@ module VectCore (
 
   vVar, vType, vNote, vLet,
   vLams, vLamsWithoutLC, vVarApps,
-  vCaseDEFAULT, vInlineMe
+  vCaseDEFAULT
 ) where
 
 #include "HsVersions.h"
@@ -18,7 +18,6 @@ module VectCore (
 import CoreSyn
 import Type           ( Type )
 import Var
-import Outputable
 
 type Vect a = (a,a)
 type VVar   = Vect Var
@@ -83,8 +82,3 @@ vCaseDEFAULT (vscrut, lscrut) (vbndr, lbndr) vty lty (vbody, lbody)
   where
     mkDEFAULT e = [(DEFAULT, [], e)]
 
-vInlineMe :: VExpr -> VExpr
-vInlineMe (vexpr, lexpr) = (mkInlineMe vexpr, mkInlineMe lexpr)
-
-mkInlineMe :: CoreExpr -> CoreExpr
-mkInlineMe = pprTrace "VectCore.mkInlineMe" (text "Roman: need to replace mkInlineMe with an InlineRule somehow")
index 7b9ec50..6e7557e 100644 (file)
@@ -11,6 +11,7 @@ import VectCore
 import HscTypes          ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
 import CoreSyn
 import CoreUtils
+import CoreUnfold
 import MkCore           ( mkWildCase )
 import BuildTyCl
 import DataCon
@@ -20,9 +21,11 @@ import TypeRep
 import Coercion
 import FamInstEnv        ( FamInst, mkLocalFamInst )
 import OccName
+import Id
 import MkId
-import BasicTypes        ( StrictnessMark(..), boolToRecFlag )
-import Var               ( Var, TyVar )
+import BasicTypes        ( StrictnessMark(..), boolToRecFlag,
+                           dfunInlinePragma )
+import Var               ( Var, TyVar, varType )
 import Name              ( Name, getOccName )
 import NameEnv
 
@@ -37,7 +40,7 @@ import FastString
 
 import MonadUtils     ( zipWith3M, foldrM, concatMapM )
 import Control.Monad  ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM )
-import Data.List      ( inits, tails, zipWith4, zipWith6 )
+import Data.List      ( inits, tails, zipWith4, zipWith5 )
 
 -- ----------------------------------------------------------------------------
 -- Types
@@ -119,26 +122,28 @@ vectTypeEnv env
       let orig_tcs = keep_tcs ++ conv_tcs
           vect_tcs = keep_tcs ++ new_tcs
 
-      dfuns <- mapM mkPADFun vect_tcs
-      defTyConPAs (zip vect_tcs dfuns)
-      reprs <- mapM tyConRepr vect_tcs
-      repr_tcs  <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
-      pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
-      binds    <- sequence (zipWith6 buildTyConBindings orig_tcs
-                                                        vect_tcs
-                                                        repr_tcs
-                                                        pdata_tcs
-                                                        dfuns
-                                                        reprs)
-
-      let all_new_tcs = new_tcs ++ repr_tcs ++ pdata_tcs
+      (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
+        do
+          defTyConPAs (zipLazy vect_tcs dfuns')
+          reprs <- mapM tyConRepr vect_tcs
+          repr_tcs  <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
+          pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
+          dfuns <- sequence $ zipWith5 buildTyConBindings orig_tcs
+                                                          vect_tcs
+                                                          repr_tcs
+                                                          pdata_tcs
+                                                          reprs
+          binds <- takeHoisted
+          return (dfuns, binds, repr_tcs ++ pdata_tcs)
+
+      let all_new_tcs = new_tcs ++ inst_tcs
 
       let new_env = extendTypeEnvList env
                        (map ATyCon all_new_tcs
                         ++ [ADataCon dc | tc <- all_new_tcs
                                         , dc <- tyConDataCons tc])
 
-      return (new_env, map mkLocalFamInst (repr_tcs ++ pdata_tcs), concat binds)
+      return (new_env, map mkLocalFamInst inst_tcs, binds)
   where
     tycons = typeEnvTyCons env
     groups = tyConGroups tycons
@@ -715,18 +720,12 @@ buildPDataDataCon orig_name vect_tc repr_tc repr
     comp_ty r = mkPDataType (compOrigType r)
 
 
-mkPADFun :: TyCon -> VM Var
-mkPADFun vect_tc
-  = newExportedVar (mkPADFunOcc $ getOccName vect_tc) =<< paDFunType vect_tc
-
-buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> Var -> SumRepr 
-                   -> VM [(Var, CoreExpr)]
-buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc dfun repr
+buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr 
+                   -> VM Var
+buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
   = do
       vectDataConWorkers orig_tc vect_tc pdata_tc
-      dict <- buildPADict vect_tc prepr_tc pdata_tc repr
-      binds <- takeHoisted
-      return $ (dfun, dict) : binds
+      buildPADict vect_tc prepr_tc pdata_tc repr
 
 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
 vectDataConWorkers orig_tc vect_tc arr_tc
@@ -781,53 +780,71 @@ vectDataConWorkers orig_tc vect_tc arr_tc
 
     def_worker data_con arg_tys mk_body
       = do
+          arity <- polyArity tyvars
           body <- closedV
                 . inBind orig_worker
-                . polyAbstract tyvars $ \abstract ->
-                  liftM (abstract . vectorised)
+                . polyAbstract tyvars $ \args ->
+                  liftM (mkLams (tyvars ++ args) . vectorised)
                 $ buildClosures tyvars [] arg_tys res_ty mk_body
 
-          vect_worker <- cloneId mkVectOcc orig_worker (exprType body)
+          raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
+          let vect_worker = raw_worker `setIdUnfolding`
+                              mkInlineRule InlSat body arity
           defGlobalVar orig_worker vect_worker
           return (vect_worker, body)
       where
         orig_worker = dataConWorkId data_con
 
-buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
+buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
 buildPADict vect_tc prepr_tc arr_tc repr
-  = polyAbstract tvs $ \abstract ->
+  = polyAbstract tvs $ \args ->
     do
-      meth_binds <- mapM mk_method paMethods
-      let meth_exprs = map (Var . fst) meth_binds
+      method_ids <- mapM (method args) paMethods
+
+      pa_tc  <- builtin paTyCon
+      pa_con <- builtin paDataCon
+      let dict = mkLams (tvs ++ args)
+               $ mkConApp pa_con
+               $ Type inst_ty : map (method_call args) method_ids
+
+          dfun_ty = mkForAllTys tvs
+                  $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty])
+
+      raw_dfun <- newExportedVar dfun_name dfun_ty
+      let dfun = raw_dfun `setIdUnfolding` mkDFunUnfolding pa_con method_ids
+                          `setInlinePragma` dfunInlinePragma
 
-      pa_dc <- builtin paDataCon
-      let dict = mkConApp pa_dc (Type (mkTyConApp vect_tc arg_tys) : meth_exprs)
-          body = Let (Rec meth_binds) dict
-      return . mkInlineMe $ abstract body
+      hoistBinding dfun dict
+      return dfun
   where
-    tvs = tyConTyVars arr_tc
+    tvs = tyConTyVars vect_tc
     arg_tys = mkTyVarTys tvs
+    inst_ty = mkTyConApp vect_tc arg_tys
 
-    mk_method (name, build)
+    dfun_name = mkPADFunOcc (getOccName vect_tc)
+
+    method args (name, build)
       = localV
       $ do
-          body <- build vect_tc prepr_tc arr_tc repr
-          var  <- newLocalVar name (exprType body)
-          return (var, mkInlineMe body)
-
--- The InlineMe note has gone away.  Instead, you need to use
--- CoreUnfold.mkInlineRule to make an InlineRule for the thing, and
--- attach *that* as the unfolding for the dictionary binder
-mkInlineMe :: CoreExpr -> CoreExpr
-mkInlineMe expr = pprTrace "VectType: Roman, you need to use the new InlineRule story" 
-                          (ppr expr) expr
-
-paMethods :: [(FastString, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
-paMethods = [(fsLit "dictPRepr",    buildPRDict),
-             (fsLit "toPRepr",      buildToPRepr),
-             (fsLit "fromPRepr",    buildFromPRepr),
-             (fsLit "toArrPRepr",   buildToArrPRepr),
-             (fsLit "fromArrPRepr", buildFromArrPRepr)]
+          expr <- build vect_tc prepr_tc arr_tc repr
+          let body = mkLams (tvs ++ args) expr
+          raw_var <- newExportedVar (method_name name) (exprType body)
+          let var = raw_var
+                      `setIdUnfolding` mkInlineRule InlSat body (length args)
+          hoistBinding var body
+          return var
+
+    method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
+
+    method_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
+
+
+paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
+paMethods = [("dictPRepr",    buildPRDict),
+             ("toPRepr",      buildToPRepr),
+             ("fromPRepr",    buildFromPRepr),
+             ("toArrPRepr",   buildToArrPRepr),
+             ("fromArrPRepr", buildFromArrPRepr)]
 
 -- | Split the given tycons into two sets depending on whether they have to be
 -- converted (first list) or not (second list). The first argument contains
index e508424..9faa0ed 100644 (file)
@@ -15,7 +15,8 @@ module VectUtils (
   combinePD,
   liftPD,
   zipScalars, scalarClosure,
-  polyAbstract, polyApply, polyVApply,
+  polyAbstract, polyApply, polyVApply, polyArity,
+  Inline(..), addInlineArity, inlineMe,
   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
   buildClosure, buildClosures,
   mkClosureApp
@@ -27,6 +28,7 @@ import VectMonad
 import MkCore ( mkCoreTup, mkCoreTupTy, mkWildCase )
 import CoreSyn
 import CoreUtils
+import CoreUnfold         ( mkInlineRule )
 import Coercion
 import Type
 import TypeRep
@@ -34,6 +36,7 @@ import TyCon
 import DataCon
 import Var
 import MkId               ( unwrapFamInstScrut )
+import Id                 ( setIdUnfolding )
 import TysWiredIn
 import BasicTypes         ( Boxity(..) )
 import Literal            ( Literal, mkMachInt )
@@ -43,7 +46,6 @@ import FastString
 
 import Control.Monad
 
-
 collectAnnTypeArgs :: AnnExpr b ann -> (AnnExpr b ann, [Type])
 collectAnnTypeArgs expr = go expr []
   where
@@ -315,13 +317,14 @@ newLocalVVar fs vty
       lv  <- newLocalVar fs lty
       return (vv,lv)
 
-polyAbstract :: [TyVar] -> ((CoreExpr -> CoreExpr) -> VM a) -> VM a
+polyAbstract :: [TyVar] -> ([Var] -> VM a) -> VM a
 polyAbstract tvs p
   = localV
   $ do
       mdicts <- mapM mk_dict_var tvs
-      zipWithM_ (\tv -> maybe (defLocalTyVar tv) (defLocalTyVarWithPA tv . Var)) tvs mdicts
-      p (mk_lams mdicts)
+      zipWithM_ (\tv -> maybe (defLocalTyVar tv)
+                              (defLocalTyVarWithPA tv . Var)) tvs mdicts
+      p (mk_args mdicts)
   where
     mk_dict_var tv = do
                        r <- paDictArgType tv
@@ -329,7 +332,12 @@ polyAbstract tvs p
                          Just ty -> liftM Just (newLocalVar (fsLit "dPA") ty)
                          Nothing -> return Nothing
 
-    mk_lams mdicts = mkLams (tvs ++ [dict | Just dict <- mdicts])
+    mk_args mdicts = [dict | Just dict <- mdicts]
+
+polyArity :: [TyVar] -> VM Int
+polyArity tvs = do
+                  tys <- mapM paDictArgType tvs
+                  return $ length [() | Just _ <- tys]
 
 polyApply :: CoreExpr -> [Type] -> VM CoreExpr
 polyApply expr tys
@@ -343,31 +351,48 @@ polyVApply expr tys
       dicts <- mapM paDictOfType tys
       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
 
+
+data Inline = Inline Int -- arity
+            | DontInline
+
+addInlineArity :: Inline -> Int -> Inline
+addInlineArity (Inline m) n = Inline (m+n)
+addInlineArity DontInline _ = DontInline
+
+inlineMe :: Inline
+inlineMe = Inline 0
+
 hoistBinding :: Var -> CoreExpr -> VM ()
 hoistBinding v e = updGEnv $ \env ->
   env { global_bindings = (v,e) : global_bindings env }
 
-hoistExpr :: FastString -> CoreExpr -> VM Var
-hoistExpr fs expr
+hoistExpr :: FastString -> CoreExpr -> Inline -> VM Var
+hoistExpr fs expr inl
   = do
-      var <- newLocalVar fs (exprType expr)
+      var <- mk_inline `liftM` newLocalVar fs (exprType expr)
       hoistBinding var expr
       return var
+  where
+    mk_inline var = case inl of
+                      Inline arity -> var `setIdUnfolding`
+                                      mkInlineRule InlSat expr arity
+                      DontInline   -> var
 
-hoistVExpr :: VExpr -> VM VVar
-hoistVExpr (ve, le)
+hoistVExpr :: VExpr -> Inline -> VM VVar
+hoistVExpr (ve, le) inl
   = do
       fs <- getBindName
-      vv <- hoistExpr ('v' `consFS` fs) ve
-      lv <- hoistExpr ('l' `consFS` fs) le
+      vv <- hoistExpr ('v' `consFS` fs) ve inl
+      lv <- hoistExpr ('l' `consFS` fs) le (addInlineArity inl 1)
       return (vv, lv)
 
-hoistPolyVExpr :: [TyVar] -> VM VExpr -> VM VExpr
-hoistPolyVExpr tvs p
+hoistPolyVExpr :: [TyVar] -> Inline -> VM VExpr -> VM VExpr
+hoistPolyVExpr tvs inline p
   = do
-      expr <- closedV . polyAbstract tvs $ \abstract ->
-              liftM (mapVect abstract) p
-      fn   <- hoistVExpr expr
+      inline' <- liftM (addInlineArity inline) (polyArity tvs)
+      expr <- closedV . polyAbstract tvs $ \args ->
+              liftM (mapVect (mkLams $ tvs ++ args)) p
+      fn   <- hoistVExpr expr inline'
       polyVApply (vVar fn) (mkTyVarTys tvs)
 
 takeHoisted :: VM [(Var, CoreExpr)]
@@ -413,14 +438,15 @@ buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
 buildClosures _   _    [] _ mk_body
   = mk_body
 buildClosures tvs vars [arg_ty] res_ty mk_body
-  = liftM vInlineMe (buildClosure tvs vars arg_ty res_ty mk_body)
+  = -- liftM vInlineMe $
+      buildClosure tvs vars arg_ty res_ty mk_body
 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
   = do
       res_ty' <- mkClosureTypes arg_tys res_ty
       arg <- newLocalVVar (fsLit "x") arg_ty
-      liftM vInlineMe
-        . buildClosure tvs vars arg_ty res_ty'
-        . hoistPolyVExpr tvs
+      -- liftM vInlineMe
+      buildClosure tvs vars arg_ty res_ty'
+        . hoistPolyVExpr tvs (Inline (length vars + 1))
         $ do
             lc <- builtin liftingContext
             clo <- buildClosures tvs (vars ++ [arg]) arg_tys res_ty mk_body
@@ -438,11 +464,11 @@ buildClosure tvs vars arg_ty res_ty mk_body
       env_bndr <- newLocalVVar (fsLit "env") env_ty
       arg_bndr <- newLocalVVar (fsLit "arg") arg_ty
 
-      fn <- hoistPolyVExpr tvs
+      fn <- hoistPolyVExpr tvs (Inline 2)
           $ do
               lc    <- builtin liftingContext
               body  <- mk_body
-              return . vInlineMe
+              return -- . vInlineMe
                      . vLams lc [env_bndr, arg_bndr]
                      $ bind (vVar env_bndr)
                             (vVarApps lc body (vars ++ [arg_bndr]))
index 2bce391..59fded3 100644 (file)
@@ -12,6 +12,7 @@ import HscTypes hiding      ( MonadThings(..) )
 import Module               ( PackageId )
 import CoreSyn
 import CoreUtils
+import CoreUnfold           ( mkInlineRule )
 import MkCore               ( mkWildCase )
 import CoreFVs
 import CoreMonad            ( CoreM, getHscEnv )
@@ -24,6 +25,7 @@ import VarEnv
 import VarSet
 import Id
 import OccName
+import BasicTypes           ( isLoopBreaker )
 
 import Literal              ( Literal, mkMachInt )
 import TysWiredIn
@@ -31,7 +33,8 @@ import TysPrim              ( intPrimTy )
 
 import Outputable
 import FastString
-import Control.Monad        ( liftM, liftM2, zipWithM )
+import Util                 ( zipLazy )
+import Control.Monad
 import Data.List            ( sortBy, unzip4 )
 
 vectorise :: PackageId -> ModGuts -> CoreM ModGuts
@@ -67,8 +70,8 @@ vectModule guts
 vectTopBind :: CoreBind -> VM CoreBind
 vectTopBind b@(NonRec var expr)
   = do
-      var'  <- vectTopBinder var
-      expr' <- vectTopRhs var expr
+      (inline, expr') <- vectTopRhs var expr
+      var' <- vectTopBinder var inline expr'
       hs    <- takeHoisted
       cexpr <- tryConvert var var' expr
       return . Rec $ (var, cexpr) : (var', expr') : hs
@@ -77,8 +80,13 @@ vectTopBind b@(NonRec var expr)
 
 vectTopBind b@(Rec bs)
   = do
-      vars'  <- mapM vectTopBinder vars
-      exprs' <- zipWithM vectTopRhs vars exprs
+      (vars', _, exprs') <- fixV $ \ ~(_, inlines, rhss) ->
+        do
+          vars' <- sequence [vectTopBinder var inline rhs
+                               | (var, ~(inline, rhs))
+                                 <- zipLazy vars (zip inlines rhss)]
+          (inlines', exprs') <- mapAndUnzipM (uncurry vectTopRhs) bs
+          return (vars', inlines', exprs')
       hs     <- takeHoisted
       cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
       return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
@@ -87,20 +95,28 @@ vectTopBind b@(Rec bs)
   where
     (vars, exprs) = unzip bs
 
-vectTopBinder :: Var -> VM Var
-vectTopBinder var
+-- NOTE: vectTopBinder *MUST* be lazy in inline and expr because of how it is
+-- used inside of fixV in vectTopBind
+vectTopBinder :: Var -> Inline -> CoreExpr -> VM Var
+vectTopBinder var inline expr
   = do
       vty  <- vectType (idType var)
-      var' <- cloneId mkVectOcc var vty
+      var' <- liftM (`setIdUnfolding` unfolding) $ cloneId mkVectOcc var vty
       defGlobalVar var var'
       return var'
+  where
+    unfolding = case inline of
+                  Inline arity -> mkInlineRule InlSat expr arity
+                  DontInline   -> noUnfolding
 
-vectTopRhs :: Var -> CoreExpr -> VM CoreExpr
+vectTopRhs :: Var -> CoreExpr -> VM (Inline, CoreExpr)
 vectTopRhs var expr
-  = do
-      closedV . liftM vectorised
-              . inBind var
-              $ vectPolyExpr (freeVars expr)
+  = closedV
+  $ do
+      (inline, vexpr) <- inBind var
+                       $ vectPolyExpr (isLoopBreaker $ idOccInfo var)
+                                      (freeVars expr)
+      return (inline, vectorised vexpr)
 
 tryConvert :: Var -> Var -> CoreExpr -> VM CoreExpr
 tryConvert var vect_var rhs
@@ -187,14 +203,19 @@ vectLiteral lit
       lexpr <- liftPD (Lit lit)
       return (Lit lit, lexpr)
 
-vectPolyExpr :: CoreExprWithFVs -> VM VExpr
-vectPolyExpr (_, AnnNote note expr)
-  = liftM (vNote note) $ vectPolyExpr expr
-vectPolyExpr expr
-  = polyAbstract tvs $ \abstract ->
-    do
-      mono' <- vectFnExpr False mono
-      return $ mapVect abstract mono'
+vectPolyExpr :: Bool -> CoreExprWithFVs -> VM (Inline, VExpr)
+vectPolyExpr loop_breaker (_, AnnNote note expr)
+  = do
+      (inline, expr') <- vectPolyExpr loop_breaker expr
+      return (inline, vNote note expr')
+vectPolyExpr loop_breaker expr
+  = do
+      arity <- polyArity tvs
+      polyAbstract tvs $ \args ->
+        do
+          (inline, mono') <- vectFnExpr False loop_breaker mono
+          return (addInlineArity inline arity,
+                  mapVect (mkLams $ tvs ++ args) mono')
   where
     (tvs, mono) = collectAnnTypeBinders expr
 
@@ -245,7 +266,7 @@ vectExpr (_, AnnCase scrut bndr ty alts)
 
 vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
   = do
-      vrhs <- localV . inBind bndr $ vectPolyExpr rhs
+      vrhs <- localV . inBind bndr . liftM snd $ vectPolyExpr False rhs
       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
       return $ vLet (vNonRec vbndr vrhs) vbody
 
@@ -254,17 +275,18 @@ vectExpr (_, AnnLet (AnnRec bs) body)
       (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
                                 $ liftM2 (,)
                                   (zipWithM vect_rhs bndrs rhss)
-                                  (vectPolyExpr body)
+                                  (vectExpr body)
       return $ vLet (vRec vbndrs vrhss) vbody
   where
     (bndrs, rhss) = unzip bs
 
     vect_rhs bndr rhs = localV
                       . inBind bndr
-                      $ vectExpr rhs
+                      . liftM snd
+                      $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs
 
 vectExpr e@(_, AnnLam bndr _)
-  | isId bndr = vectFnExpr True e
+  | isId bndr = liftM snd $ vectFnExpr True False e
 {-
 onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
                 `orElseV` vectLam True fvs bs body
@@ -274,14 +296,17 @@ onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
 
 vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e)
 
-vectFnExpr :: Bool -> CoreExprWithFVs -> VM VExpr
-vectFnExpr inline e@(fvs, AnnLam bndr _)
-  | isId bndr = onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
-                `orElseV` vectLam inline fvs bs body
+vectFnExpr :: Bool -> Bool -> CoreExprWithFVs -> VM (Inline, VExpr)
+vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _)
+  | isId bndr = onlyIfV (isEmptyVarSet fvs)
+                        (mark DontInline . vectScalarLam bs $ deAnnotate body)
+                `orElseV` mark inlineMe (vectLam inline loop_breaker fvs bs body)
   where
     (bs,body) = collectAnnValBinders e
-vectFnExpr _ e = vectExpr e
+vectFnExpr _ _ e = mark DontInline $ vectExpr e
 
+mark :: Inline -> VM a -> VM (Inline, a)
+mark b p = do { x <- p; return (b,x) }
 
 vectScalarLam :: [Var] -> CoreExpr -> VM VExpr
 vectScalarLam args body
@@ -291,11 +316,11 @@ vectScalarLam args body
                && is_scalar_ty res_ty
                && is_scalar (extendVarSetList scalars args) body)
         $ do
-            fn_var <- hoistExpr (fsLit "fn") (mkLams args body)
+            fn_var <- hoistExpr (fsLit "fn") (mkLams args body) DontInline
             zipf <- zipScalars arg_tys res_ty
             clo <- scalarClosure arg_tys res_ty (Var fn_var)
                                                 (zipf `App` Var fn_var)
-            clo_var <- hoistExpr (fsLit "clo") clo
+            clo_var <- hoistExpr (fsLit "clo") clo DontInline
             lclo <- liftPD (Var clo_var)
             return (Var clo_var, lclo)
   where
@@ -314,8 +339,8 @@ vectScalarLam args body
     is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2
     is_scalar _ _            = False
 
-vectLam :: Bool -> VarSet -> [Var] -> CoreExprWithFVs -> VM VExpr
-vectLam inline fvs bs body
+vectLam :: Bool -> Bool -> VarSet -> [Var] -> CoreExprWithFVs -> VM VExpr
+vectLam inline loop_breaker fvs bs body
   = do
       tyvars <- localTyVars
       (vs, vvs) <- readLEnv $ \env ->
@@ -326,14 +351,28 @@ vectLam inline fvs bs body
       res_ty  <- vectType (exprType $ deAnnotate body)
 
       buildClosures tyvars vvs arg_tys res_ty
-        . hoistPolyVExpr tyvars
+        . hoistPolyVExpr tyvars (maybe_inline (length vs + length bs))
         $ do
             lc <- builtin liftingContext
             (vbndrs, vbody) <- vectBndrsIn (vs ++ bs)
                                            (vectExpr body)
-            return . maybe_inline $ vLams lc vbndrs vbody
+            vbody' <- break_loop lc res_ty vbody
+            return $ vLams lc vbndrs vbody'
   where
-    maybe_inline = if inline then vInlineMe else id
+    maybe_inline n | inline    = Inline n
+                   | otherwise = DontInline
+
+    break_loop lc ty (ve, le)
+      | loop_breaker
+      = do
+          empty <- emptyPD ty
+          lty <- mkPDataType ty
+          return (ve, mkWildCase (Var lc) intPrimTy lty
+                        [(DEFAULT, [], le),
+                         (LitAlt (mkMachInt 0), [], empty)])
+
+      | otherwise = return (ve, le)
 
 vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
 vectTyAppExpr (_, AnnVar v) tys = vectPolyVar v tys
@@ -441,7 +480,7 @@ vectAlgCase tycon _ty_args scrut bndr ty alts
     cmp _             DEFAULT       = GT
     cmp _             _             = panic "vectAlgCase/cmp"
 
-    proc_alt arity sel vty lty (DataAlt dc, bndrs, body)
+    proc_alt arity sel _ lty (DataAlt dc, bndrs, body)
       = do
           vect_dc <- maybeV (lookupDataCon dc)
           let ntag = dataConTagZ vect_dc