Adapt vectoriser to new inlining mechanism
[ghc-hetmet.git] / compiler / vectorise / Vectorise.hs
index 2bce391..59fded3 100644 (file)
@@ -12,6 +12,7 @@ import HscTypes hiding      ( MonadThings(..) )
 import Module               ( PackageId )
 import CoreSyn
 import CoreUtils
+import CoreUnfold           ( mkInlineRule )
 import MkCore               ( mkWildCase )
 import CoreFVs
 import CoreMonad            ( CoreM, getHscEnv )
@@ -24,6 +25,7 @@ import VarEnv
 import VarSet
 import Id
 import OccName
+import BasicTypes           ( isLoopBreaker )
 
 import Literal              ( Literal, mkMachInt )
 import TysWiredIn
@@ -31,7 +33,8 @@ import TysPrim              ( intPrimTy )
 
 import Outputable
 import FastString
-import Control.Monad        ( liftM, liftM2, zipWithM )
+import Util                 ( zipLazy )
+import Control.Monad
 import Data.List            ( sortBy, unzip4 )
 
 vectorise :: PackageId -> ModGuts -> CoreM ModGuts
@@ -67,8 +70,8 @@ vectModule guts
 vectTopBind :: CoreBind -> VM CoreBind
 vectTopBind b@(NonRec var expr)
   = do
-      var'  <- vectTopBinder var
-      expr' <- vectTopRhs var expr
+      (inline, expr') <- vectTopRhs var expr
+      var' <- vectTopBinder var inline expr'
       hs    <- takeHoisted
       cexpr <- tryConvert var var' expr
       return . Rec $ (var, cexpr) : (var', expr') : hs
@@ -77,8 +80,13 @@ vectTopBind b@(NonRec var expr)
 
 vectTopBind b@(Rec bs)
   = do
-      vars'  <- mapM vectTopBinder vars
-      exprs' <- zipWithM vectTopRhs vars exprs
+      (vars', _, exprs') <- fixV $ \ ~(_, inlines, rhss) ->
+        do
+          vars' <- sequence [vectTopBinder var inline rhs
+                               | (var, ~(inline, rhs))
+                                 <- zipLazy vars (zip inlines rhss)]
+          (inlines', exprs') <- mapAndUnzipM (uncurry vectTopRhs) bs
+          return (vars', inlines', exprs')
       hs     <- takeHoisted
       cexprs <- sequence $ zipWith3 tryConvert vars vars' exprs
       return . Rec $ zip vars cexprs ++ zip vars' exprs' ++ hs
@@ -87,20 +95,28 @@ vectTopBind b@(Rec bs)
   where
     (vars, exprs) = unzip bs
 
-vectTopBinder :: Var -> VM Var
-vectTopBinder var
+-- NOTE: vectTopBinder *MUST* be lazy in inline and expr because of how it is
+-- used inside of fixV in vectTopBind
+vectTopBinder :: Var -> Inline -> CoreExpr -> VM Var
+vectTopBinder var inline expr
   = do
       vty  <- vectType (idType var)
-      var' <- cloneId mkVectOcc var vty
+      var' <- liftM (`setIdUnfolding` unfolding) $ cloneId mkVectOcc var vty
       defGlobalVar var var'
       return var'
+  where
+    unfolding = case inline of
+                  Inline arity -> mkInlineRule InlSat expr arity
+                  DontInline   -> noUnfolding
 
-vectTopRhs :: Var -> CoreExpr -> VM CoreExpr
+vectTopRhs :: Var -> CoreExpr -> VM (Inline, CoreExpr)
 vectTopRhs var expr
-  = do
-      closedV . liftM vectorised
-              . inBind var
-              $ vectPolyExpr (freeVars expr)
+  = closedV
+  $ do
+      (inline, vexpr) <- inBind var
+                       $ vectPolyExpr (isLoopBreaker $ idOccInfo var)
+                                      (freeVars expr)
+      return (inline, vectorised vexpr)
 
 tryConvert :: Var -> Var -> CoreExpr -> VM CoreExpr
 tryConvert var vect_var rhs
@@ -187,14 +203,19 @@ vectLiteral lit
       lexpr <- liftPD (Lit lit)
       return (Lit lit, lexpr)
 
-vectPolyExpr :: CoreExprWithFVs -> VM VExpr
-vectPolyExpr (_, AnnNote note expr)
-  = liftM (vNote note) $ vectPolyExpr expr
-vectPolyExpr expr
-  = polyAbstract tvs $ \abstract ->
-    do
-      mono' <- vectFnExpr False mono
-      return $ mapVect abstract mono'
+vectPolyExpr :: Bool -> CoreExprWithFVs -> VM (Inline, VExpr)
+vectPolyExpr loop_breaker (_, AnnNote note expr)
+  = do
+      (inline, expr') <- vectPolyExpr loop_breaker expr
+      return (inline, vNote note expr')
+vectPolyExpr loop_breaker expr
+  = do
+      arity <- polyArity tvs
+      polyAbstract tvs $ \args ->
+        do
+          (inline, mono') <- vectFnExpr False loop_breaker mono
+          return (addInlineArity inline arity,
+                  mapVect (mkLams $ tvs ++ args) mono')
   where
     (tvs, mono) = collectAnnTypeBinders expr
 
@@ -245,7 +266,7 @@ vectExpr (_, AnnCase scrut bndr ty alts)
 
 vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
   = do
-      vrhs <- localV . inBind bndr $ vectPolyExpr rhs
+      vrhs <- localV . inBind bndr . liftM snd $ vectPolyExpr False rhs
       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
       return $ vLet (vNonRec vbndr vrhs) vbody
 
@@ -254,17 +275,18 @@ vectExpr (_, AnnLet (AnnRec bs) body)
       (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
                                 $ liftM2 (,)
                                   (zipWithM vect_rhs bndrs rhss)
-                                  (vectPolyExpr body)
+                                  (vectExpr body)
       return $ vLet (vRec vbndrs vrhss) vbody
   where
     (bndrs, rhss) = unzip bs
 
     vect_rhs bndr rhs = localV
                       . inBind bndr
-                      $ vectExpr rhs
+                      . liftM snd
+                      $ vectPolyExpr (isLoopBreaker $ idOccInfo bndr) rhs
 
 vectExpr e@(_, AnnLam bndr _)
-  | isId bndr = vectFnExpr True e
+  | isId bndr = liftM snd $ vectFnExpr True False e
 {-
 onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
                 `orElseV` vectLam True fvs bs body
@@ -274,14 +296,17 @@ onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
 
 vectExpr e = cantVectorise "Can't vectorise expression" (ppr $ deAnnotate e)
 
-vectFnExpr :: Bool -> CoreExprWithFVs -> VM VExpr
-vectFnExpr inline e@(fvs, AnnLam bndr _)
-  | isId bndr = onlyIfV (isEmptyVarSet fvs) (vectScalarLam bs $ deAnnotate body)
-                `orElseV` vectLam inline fvs bs body
+vectFnExpr :: Bool -> Bool -> CoreExprWithFVs -> VM (Inline, VExpr)
+vectFnExpr inline loop_breaker e@(fvs, AnnLam bndr _)
+  | isId bndr = onlyIfV (isEmptyVarSet fvs)
+                        (mark DontInline . vectScalarLam bs $ deAnnotate body)
+                `orElseV` mark inlineMe (vectLam inline loop_breaker fvs bs body)
   where
     (bs,body) = collectAnnValBinders e
-vectFnExpr _ e = vectExpr e
+vectFnExpr _ _ e = mark DontInline $ vectExpr e
 
+mark :: Inline -> VM a -> VM (Inline, a)
+mark b p = do { x <- p; return (b,x) }
 
 vectScalarLam :: [Var] -> CoreExpr -> VM VExpr
 vectScalarLam args body
@@ -291,11 +316,11 @@ vectScalarLam args body
                && is_scalar_ty res_ty
                && is_scalar (extendVarSetList scalars args) body)
         $ do
-            fn_var <- hoistExpr (fsLit "fn") (mkLams args body)
+            fn_var <- hoistExpr (fsLit "fn") (mkLams args body) DontInline
             zipf <- zipScalars arg_tys res_ty
             clo <- scalarClosure arg_tys res_ty (Var fn_var)
                                                 (zipf `App` Var fn_var)
-            clo_var <- hoistExpr (fsLit "clo") clo
+            clo_var <- hoistExpr (fsLit "clo") clo DontInline
             lclo <- liftPD (Var clo_var)
             return (Var clo_var, lclo)
   where
@@ -314,8 +339,8 @@ vectScalarLam args body
     is_scalar vs (App e1 e2) = is_scalar vs e1 && is_scalar vs e2
     is_scalar _ _            = False
 
-vectLam :: Bool -> VarSet -> [Var] -> CoreExprWithFVs -> VM VExpr
-vectLam inline fvs bs body
+vectLam :: Bool -> Bool -> VarSet -> [Var] -> CoreExprWithFVs -> VM VExpr
+vectLam inline loop_breaker fvs bs body
   = do
       tyvars <- localTyVars
       (vs, vvs) <- readLEnv $ \env ->
@@ -326,14 +351,28 @@ vectLam inline fvs bs body
       res_ty  <- vectType (exprType $ deAnnotate body)
 
       buildClosures tyvars vvs arg_tys res_ty
-        . hoistPolyVExpr tyvars
+        . hoistPolyVExpr tyvars (maybe_inline (length vs + length bs))
         $ do
             lc <- builtin liftingContext
             (vbndrs, vbody) <- vectBndrsIn (vs ++ bs)
                                            (vectExpr body)
-            return . maybe_inline $ vLams lc vbndrs vbody
+            vbody' <- break_loop lc res_ty vbody
+            return $ vLams lc vbndrs vbody'
   where
-    maybe_inline = if inline then vInlineMe else id
+    maybe_inline n | inline    = Inline n
+                   | otherwise = DontInline
+
+    break_loop lc ty (ve, le)
+      | loop_breaker
+      = do
+          empty <- emptyPD ty
+          lty <- mkPDataType ty
+          return (ve, mkWildCase (Var lc) intPrimTy lty
+                        [(DEFAULT, [], le),
+                         (LitAlt (mkMachInt 0), [], empty)])
+
+      | otherwise = return (ve, le)
 
 vectTyAppExpr :: CoreExprWithFVs -> [Type] -> VM VExpr
 vectTyAppExpr (_, AnnVar v) tys = vectPolyVar v tys
@@ -441,7 +480,7 @@ vectAlgCase tycon _ty_args scrut bndr ty alts
     cmp _             DEFAULT       = GT
     cmp _             _             = panic "vectAlgCase/cmp"
 
-    proc_alt arity sel vty lty (DataAlt dc, bndrs, body)
+    proc_alt arity sel _ lty (DataAlt dc, bndrs, body)
       = do
           vect_dc <- maybeV (lookupDataCon dc)
           let ntag = dataConTagZ vect_dc