Fix Trac #2321: bug in SAT
authorsimonpj@microsoft.com <unknown>
Mon, 16 Jun 2008 20:17:00 +0000 (20:17 +0000)
committersimonpj@microsoft.com <unknown>
Mon, 16 Jun 2008 20:17:00 +0000 (20:17 +0000)
  This is a fairly substantial rewrite of the Static Argument Transformatoin,
  done by Max Bolingbroke and reviewed and modified by Simon PJ.

  * Fix a subtle scoping problem; see Note [Binder type capture]
  * Redo the analysis to use environments
  * Run gentle simlification just before the transformation

compiler/main/DynFlags.hs
compiler/simplCore/SAT.lhs

index ed2fdc0..931d384 100644 (file)
@@ -849,7 +849,7 @@ getCoreToDo dflags
     liberate_case = dopt Opt_LiberateCase dflags
     rule_check    = ruleCheck dflags
     vectorisation = dopt Opt_Vectorise dflags
-    -- static_args   = dopt Opt_StaticArgumentTransformation dflags
+    static_args   = dopt Opt_StaticArgumentTransformation dflags
 
     maybe_rule_check phase = runMaybe rule_check (CoreDoRuleCheck phase)
 
@@ -903,8 +903,7 @@ getCoreToDo dflags
     -- may expose extra opportunities to float things outwards. However, to fix
     -- up the output of the transformation we need at do at least one simplify
     -- after this before anything else
-            -- runWhen static_args CoreDoStaticArgs,
-            -- XXX disabled, see #2321
+        runWhen static_args (CoreDoPasses [ simpl_gently, CoreDoStaticArgs ]),
 
         -- initial simplify: mk specialiser happy: minimum effort please
         simpl_gently,
index 3022f3c..e6e5ff1 100644 (file)
@@ -4,7 +4,7 @@
 
 %************************************************************************
 
-               Static Argument Transformation pass
+               Static Argument Transformation pass
 
 %************************************************************************
 
@@ -35,7 +35,7 @@ therefore there is no penalty in keeping them.
 
 We only apply the SAT when the number of static args is > 2. This
 produces few bad cases.  See
-       should_transform 
+                should_transform
 in saTransform.
 
 Here are the headline nofib results:
@@ -53,19 +53,25 @@ essential to make this work well!
 module SAT ( doStaticArgs ) where
 
 import DynFlags
-import Var
-import VarEnv
+import Var hiding (mkLocalId)
 import CoreSyn
 import CoreLint
+import CoreUtils
 import Type
 import TcType
 import Id
+import Name
+import OccName
+import VarEnv
 import UniqSupply
-import Unique
 import Util
+import UniqFM
+import VarSet
+import Unique
+import UniqSet
+import Outputable
 
 import Data.List
-import Panic
 import FastString
 
 #include "HsVersions.h"
@@ -78,350 +84,348 @@ doStaticArgs dflags us binds = do
     let binds' = snd $ mapAccumL sat_bind_threaded_us us binds
     endPass dflags "Static argument" Opt_D_verbose_core2core binds'
   where
-    sat_bind_threaded_us us bind = 
-        let (us1, us2) = splitUniqSupply us 
-        in (us1, runSAT (satBind bind) us2)
+    sat_bind_threaded_us us bind =
+        let (us1, us2) = splitUniqSupply us
+        in (us1, fst $ runSAT us2 (satBind bind emptyUniqSet))
 \end{code}
 \begin{code}
 -- We don't bother to SAT recursive groups since it can lead
 -- to massive code expansion: see Andre Santos' thesis for details.
 -- This means we only apply the actual SAT to Rec groups of one element,
 -- but we want to recurse into the others anyway to discover other binds
-satBind :: CoreBind -> SatM CoreBind
-satBind (NonRec binder expr) = do
-    expr' <- satExpr expr
-    return (NonRec binder expr')
-satBind (Rec [(binder, rhs)]) = do
-    insSAEnvFromBinding binder rhs
-    rhs' <- satExpr rhs
-    saTransform binder rhs'
-satBind (Rec pairs) = do
+satBind :: CoreBind -> IdSet -> SatM (CoreBind, IdSATInfo)
+satBind (NonRec binder expr) interesting_ids = do
+    (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
+    return (NonRec binder expr', finalizeApp expr_app sat_info_expr)
+satBind (Rec [(binder, rhs)]) interesting_ids = do
+    let interesting_ids' = interesting_ids `addOneToUniqSet` binder
+        (rhs_binders, rhs_body) = collectBinders rhs
+    (rhs_body', sat_info_rhs_body) <- satTopLevelExpr rhs_body interesting_ids'
+    let sat_info_rhs_from_args = unitVarEnv binder (bindersToSATInfo rhs_binders)
+        sat_info_rhs' = mergeIdSATInfo sat_info_rhs_from_args sat_info_rhs_body
+        
+        shadowing = binder `elementOfUniqSet` interesting_ids
+        sat_info_rhs'' = if shadowing
+                        then sat_info_rhs' `delFromUFM` binder -- For safety
+                        else sat_info_rhs'
+    
+    bind' <- saTransformMaybe binder (lookupUFM sat_info_rhs' binder) 
+                             rhs_binders rhs_body'
+    return (bind', sat_info_rhs'')
+satBind (Rec pairs) interesting_ids = do
     let (binders, rhss) = unzip pairs
-    rhss' <- mapM satExpr rhss
-    return (Rec (zipEqual "satBind" binders rhss'))
+    rhss_SATed <- mapM (\e -> satTopLevelExpr e interesting_ids) rhss
+    let (rhss', sat_info_rhss') = unzip rhss_SATed
+    return (Rec (zipEqual "satBind" binders rhss'), mergeIdSATInfos sat_info_rhss')
 \end{code}
 \begin{code}
-emptySATInfo :: Id -> Maybe (Id, SATInfo)
-emptySATInfo v = Just (v, ([], []))
-
-satExpr :: CoreExpr -> SatM CoreExpr
-satExpr var@(Var v) = do
-    updSAEnv (emptySATInfo v)
-    return var
-
-satExpr lit@(Lit _) = do
-    return lit
-
-satExpr (Lam binders body) = do
-    body' <- satExpr body
-    return (Lam binders body')
-
-satExpr app@(App _ _) = do
-    getAppArgs app
+data App = VarApp Id | TypeApp Type
+data Staticness a = Static a | NotStatic
 
-satExpr (Case expr bndr ty alts) = do
-    expr' <- satExpr expr
-    alts' <- mapM satAlt alts
-    return (Case expr' bndr ty alts')
+type IdAppInfo = (Id, SATInfo)
+
+type SATInfo = [Staticness App]
+type IdSATInfo = IdEnv SATInfo
+emptyIdSATInfo :: IdSATInfo
+emptyIdSATInfo = emptyUFM
+
+{-
+pprIdSATInfo id_sat_info = vcat (map pprIdAndSATInfo (fmToList id_sat_info))
+  where pprIdAndSATInfo (v, sat_info) = hang (ppr v <> colon) 4 (pprSATInfo sat_info)
+-}
+
+pprSATInfo :: SATInfo -> SDoc
+pprSATInfo staticness = hcat $ map pprStaticness staticness
+
+pprStaticness :: Staticness App -> SDoc
+pprStaticness (Static (VarApp _))  = ptext (sLit "SV") 
+pprStaticness (Static (TypeApp _)) = ptext (sLit "ST") 
+pprStaticness NotStatic            = ptext (sLit "NS")
+
+
+mergeSATInfo :: SATInfo -> SATInfo -> SATInfo
+mergeSATInfo [] _  = []
+mergeSATInfo _  [] = []
+mergeSATInfo (NotStatic:statics) (_:apps) = NotStatic : mergeSATInfo statics apps
+mergeSATInfo (_:statics) (NotStatic:apps) = NotStatic : mergeSATInfo statics apps
+mergeSATInfo ((Static (VarApp v)):statics)  ((Static (VarApp v')):apps)  = (if v == v' then Static (VarApp v) else NotStatic) : mergeSATInfo statics apps
+mergeSATInfo ((Static (TypeApp t)):statics) ((Static (TypeApp t')):apps) = (if t `coreEqType` t' then Static (TypeApp t) else NotStatic) : mergeSATInfo statics apps
+mergeSATInfo l  r  = pprPanic "mergeSATInfo" $ ptext (sLit "Left:") <> pprSATInfo l <> ptext (sLit ", ")
+                                            <> ptext (sLit "Right:") <> pprSATInfo r
+
+mergeIdSATInfo :: IdSATInfo -> IdSATInfo -> IdSATInfo
+mergeIdSATInfo = plusUFM_C mergeSATInfo
+
+mergeIdSATInfos :: [IdSATInfo] -> IdSATInfo
+mergeIdSATInfos = foldl' mergeIdSATInfo emptyIdSATInfo
+
+bindersToSATInfo :: [Id] -> SATInfo
+bindersToSATInfo vs = map (Static . binderToApp) vs
+    where binderToApp v = if isId v
+                          then VarApp v
+                          else TypeApp $ mkTyVarTy v
+
+finalizeApp :: Maybe IdAppInfo -> IdSATInfo -> IdSATInfo
+finalizeApp Nothing id_sat_info = id_sat_info
+finalizeApp (Just (v, sat_info')) id_sat_info = 
+    let sat_info'' = case lookupUFM id_sat_info v of
+                        Nothing -> sat_info'
+                        Just sat_info -> mergeSATInfo sat_info sat_info'
+    in extendVarEnv id_sat_info v sat_info''
+\end{code}
+\begin{code}
+satTopLevelExpr :: CoreExpr -> IdSet -> SatM (CoreExpr, IdSATInfo)
+satTopLevelExpr expr interesting_ids = do
+    (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
+    return (expr', finalizeApp expr_app sat_info_expr)
+
+satExpr :: CoreExpr -> IdSet -> SatM (CoreExpr, IdSATInfo, Maybe IdAppInfo)
+satExpr var@(Var v) interesting_ids = do
+    let app_info = if v `elementOfUniqSet` interesting_ids
+                   then Just (v, [])
+                   else Nothing
+    return (var, emptyIdSATInfo, app_info)
+
+satExpr lit@(Lit _) _ = do
+    return (lit, emptyIdSATInfo, Nothing)
+
+satExpr (Lam binders body) interesting_ids = do
+    (body', sat_info, this_app) <- satExpr body interesting_ids
+    return (Lam binders body', finalizeApp this_app sat_info, Nothing)
+
+satExpr (App fn arg) interesting_ids = do
+    (fn', sat_info_fn, fn_app) <- satExpr fn interesting_ids
+    let satRemainder = boring fn' sat_info_fn
+    case fn_app of
+        Nothing -> satRemainder Nothing
+        Just (fn_id, fn_app_info) ->
+            -- TODO: remove this use of append somehow (use a data structure with O(1) append but a left-to-right kind of interface)
+            let satRemainderWithStaticness arg_staticness = satRemainder $ Just (fn_id, fn_app_info ++ [arg_staticness])
+            in case arg of
+                Type t -> satRemainderWithStaticness $ Static (TypeApp t)
+                Var v  -> satRemainderWithStaticness $ Static (VarApp v)
+                _      -> satRemainderWithStaticness $ NotStatic
+  where
+    boring :: CoreExpr -> IdSATInfo -> Maybe IdAppInfo -> SatM (CoreExpr, IdSATInfo, Maybe IdAppInfo)
+    boring fn' sat_info_fn app_info = 
+        do (arg', sat_info_arg, arg_app) <- satExpr arg interesting_ids
+           let sat_info_arg' = finalizeApp arg_app sat_info_arg
+               sat_info = mergeIdSATInfo sat_info_fn sat_info_arg'
+           return (App fn' arg', sat_info, app_info)
+
+satExpr (Case expr bndr ty alts) interesting_ids = do
+    (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
+    let sat_info_expr' = finalizeApp expr_app sat_info_expr
+    
+    zipped_alts' <- mapM satAlt alts
+    let (alts', sat_infos_alts) = unzip zipped_alts'
+    return (Case expr' bndr ty alts', mergeIdSATInfo sat_info_expr' (mergeIdSATInfos sat_infos_alts), Nothing)
   where
     satAlt (con, bndrs, expr) = do
-        expr' <- satExpr expr
-        return (con, bndrs, expr')
+        (expr', sat_info_expr) <- satTopLevelExpr expr interesting_ids
+        return ((con, bndrs, expr'), sat_info_expr)
 
-satExpr (Let bind body) = do
-    body' <- satExpr body
-    bind' <- satBind bind
-    return (Let bind' body')
+satExpr (Let bind body) interesting_ids = do
+    (body', sat_info_body, body_app) <- satExpr body interesting_ids
+    (bind', sat_info_bind) <- satBind bind interesting_ids
+    return (Let bind' body', mergeIdSATInfo sat_info_body sat_info_bind, body_app)
 
-satExpr (Note note expr) = do
-    expr' <- satExpr expr
-    return (Note note expr')
+satExpr (Note note expr) interesting_ids = do
+    (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
+    return (Note note expr', sat_info_expr, expr_app)
 
-satExpr ty@(Type _) = do
-    return ty
+satExpr ty@(Type _) _ = do
+    return (ty, emptyIdSATInfo, Nothing)
 
-satExpr (Cast expr coercion) = do
-    expr' <- satExpr expr
-    return (Cast expr' coercion)
-\end{code}
-
-\begin{code}
-getAppArgs :: CoreExpr -> SatM CoreExpr
-getAppArgs app = do
-    (app', result) <- get app
-    updSAEnv result
-    return app'
-  where
-    get :: CoreExpr -> SatM (CoreExpr, Maybe (Id, SATInfo))
-    get (App e (Type ty)) = do
-        (e', result) <- get e
-        return
-            (App e' (Type ty),
-            case result of
-                Nothing            -> Nothing
-                Just (v, (tv, lv)) -> Just (v, (tv ++ [Static ty], lv)))
-
-    get (App e a) = do
-        (e', result) <- get e
-        a' <- satExpr a
-        
-        let si = case a' of
-                    Var v -> Static v
-                    _     -> NotStatic
-        return
-            (App e' a',
-            case result of
-                Just (v, (tv, lv))  -> Just (v, (tv, lv ++ [si]))
-                Nothing             -> Nothing)
-
-    get var@(Var v) = do
-        return (var, emptySATInfo v)
-
-    get e = do
-        e' <- satExpr e
-        return (e', Nothing)
+satExpr (Cast expr coercion) interesting_ids = do
+    (expr', sat_info_expr, expr_app) <- satExpr expr interesting_ids
+    return (Cast expr' coercion, sat_info_expr, expr_app)
 \end{code}
 
 %************************************************************************
 
-       Environment
+                Static Argument Transformation Monad
 
 %************************************************************************
 
 \begin{code}
-data SATEnv = SatEnv { idSATInfo :: IdEnv SATInfo }
-
-emptyEnv :: SATEnv
-emptyEnv = SatEnv { idSATInfo = emptyVarEnv }
+type SatM result = UniqSM result
 
-type SATInfo = ([Staticness Type], [Staticness Id])
+runSAT :: UniqSupply -> SatM a -> a
+runSAT = initUs_
 
-data Staticness a = Static a | NotStatic
-
-delOneFromSAEnv :: Id -> SatM ()
-delOneFromSAEnv v = modifyEnv $ \env -> env { idSATInfo = delVarEnv (idSATInfo env) v }
-
-updSAEnv :: Maybe (Id, SATInfo) -> SatM ()
-updSAEnv Nothing = do
-    return ()
-updSAEnv (Just (b, (tyargs, args))) = do
-    r <- getSATInfo b
-    case r of
-      Nothing               -> return ()
-      Just (tyargs', args') -> do
-          delOneFromSAEnv b
-          insSAEnv b (checkArgs (eqWith coreEqType) tyargs tyargs',
-                      checkArgs (eqWith (==)) args args')
-  where eqWith _  NotStatic  NotStatic  = True
-        eqWith eq (Static x) (Static y) = x `eq` y
-        eqWith _  _          _          = False
-
-checkArgs :: (Staticness a -> Staticness a -> Bool) -> [Staticness a] -> [Staticness a] -> [Staticness a]
-checkArgs _  as [] = notStatics (length as)
-checkArgs _  [] as = notStatics (length as)
-checkArgs eq (a:as) (a':as') | a `eq` a' = a:checkArgs eq as as'
-checkArgs eq (_:as) (_:as') = NotStatic:checkArgs eq as as'
-
-notStatics :: Int -> [Staticness a]
-notStatics n = nOfThem n NotStatic
-
-insSAEnv :: Id -> SATInfo -> SatM ()
-insSAEnv b info = modifyEnv $ \env -> env { idSATInfo = extendVarEnv (idSATInfo env) b info }
-
-insSAEnvFromBinding :: Id -> CoreExpr -> SatM ()
-insSAEnvFromBinding bndr e = insSAEnv bndr (getArgLists e)
+newUnique :: SatM Unique
+newUnique = getUniqueUs
 \end{code}
 
-%************************************************************************
-
-       Static Argument Transformation Monad
 
 %************************************************************************
 
-Two items of state to thread around: a UniqueSupply and a SATEnv.
-
-\begin{code}
-newtype SatM result
-  = SatM (UniqSupply -> SATEnv -> (result, SATEnv))
-
-instance Monad SatM where
-    (>>=) = thenSAT
-    (>>) = thenSAT_
-    return = returnSAT
-
-runSAT :: SatM a -> UniqSupply -> a
-runSAT (SatM f) us = fst $ f us emptyEnv
-
-thenSAT :: SatM a -> (a -> SatM b) -> SatM b
-thenSAT (SatM m) k
-  = SatM $ \us env -> 
-    case splitUniqSupply us    of { (s1, s2) ->
-    case m s1 env              of { (m_result, menv) ->
-    case k m_result            of { (SatM k') ->
-    k' s2 menv }}}
-
-thenSAT_ :: SatM a -> SatM b -> SatM b
-thenSAT_ (SatM m) (SatM k)
-  = SatM $ \us env ->
-    case splitUniqSupply us    of { (s1, s2) ->
-    case m s1 env               of { (_, menv) ->
-    k s2 menv }}
-
-returnSAT :: a -> SatM a
-returnSAT v = withEnv $ \env -> (v, env)
-
-modifyEnv :: (SATEnv -> SATEnv) -> SatM ()
-modifyEnv f = SatM $ \_ env -> ((), f env)
-
-withEnv :: (SATEnv -> (b, SATEnv)) -> SatM b
-withEnv f = SatM $ \_ env -> f env
-
-projectFromEnv :: (SATEnv -> a) -> SatM a
-projectFromEnv f = withEnv (\env -> (f env, env))
-\end{code}
+                Static Argument Transformation Monad
 
 %************************************************************************
 
-               Utility Functions
+To do the transformation, the game plan is to:
 
-%************************************************************************
+1. Create a small nonrecursive RHS that takes the
+   original arguments to the function but discards
+   the ones that are static and makes a call to the
+   SATed version with the remainder. We intend that
+   this will be inlined later, removing the overhead
 
-\begin{code}
-getSATInfo :: Id -> SatM (Maybe SATInfo)
-getSATInfo var = projectFromEnv $ \env -> lookupVarEnv (idSATInfo env) var
+2. Bind this nonrecursive RHS over the original body
+   WITH THE SAME UNIQUE as the original body so that
+   any recursive calls to the original now go via
+   the small wrapper
 
-newSATName :: Id -> Type -> SatM Id
-newSATName _ ty
-  = SatM $ \us env -> (mkSysLocal (fsLit "$sat") (uniqFromSupply us) ty, env)
+3. Rebind the original function to a new one which contains
+   our SATed function and just makes a call to it:
+   we call the thing making this call the local body
 
-getArgLists :: CoreExpr -> ([Staticness Type], [Staticness Id])
-getArgLists expr
-  = let
-    (tvs, lambda_bounds, _) = collectTyAndValBinders expr
-    in
-    ([ Static (mkTyVarTy tv) | tv <- tvs ],
-     [ Static v              | v <- lambda_bounds ])
-
-\end{code}
+Example: transform this
 
-We implement saTransform using shadowing of binders, that is
-we transform
-map = \f as -> case as of
-         [] -> []
-         (a':as') -> let x = f a'
-                 y = map f as'
-                 in x:y
+    map :: forall a b. (a->b) -> [a] -> [b]
+    map = /\ab. \(f:a->b) (as:[a]) -> body[map]
 to
-map = \f as -> let map = \f as -> map' as
-           in let rec map' = \as -> case as of
-                      [] -> []
-                      (a':as') -> let x = f a'
-                              y = map f as'
-                              in x:y
-          in map' as
-
-the inner map should get inlined and eliminated.
+    map :: forall a b. (a->b) -> [a] -> [b]
+    map = /\ab. \(f:a->b) (as:[a]) ->
+         letrec map' :: [a] -> [b]
+                   -- The "worker function
+                map' = \(as:[a]) -> 
+                        let map :: forall a' b'. (a -> b) -> [a] -> [b]
+                               -- The "shadow function
+                            map = /\a'b'. \(f':(a->b) (as:[a]).
+                                  map' as
+                        in body[map]
+         in map' as
+
+Note [Shadow binding]
+~~~~~~~~~~~~~~~~~~~~~
+The calls to the inner map inside body[map] should get inlined
+by the local re-binding of 'map'.  We call this the "shadow binding".
+
+But we can't use the original binder 'map' unchanged, because
+it might be exported, in which case the shadow binding won't be
+discarded as dead code after it is inlined.
+
+So we use a hack: we make a new SysLocal binder with the *same* unique
+as binder.  (Another alternative would be to reset the export flag.)
+
+Note [Binder type capture]
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+Notice that in the inner map (the "shadow function"), the static arguments
+are discarded -- it's as if they were underscores.  Instead, mentions
+of these arguments (notably in the types of dynamic arguments) are bound
+by the *outer* lambdas of the main function.  So we must make up fresh
+names for the static arguments so that they do not capture variables 
+mentioned in the types of dynamic args.  
+
+In the map example, the shadow function must clone the static type
+argument a,b, giving a',b', to ensure that in the \(as:[a]), the 'a'
+is bound by the outer forall.  We clone f' too for consistency, but
+that doesn't matter either way because static Id arguments aren't 
+mentioned in the shadow binding at all.
+
+If we don't we get something like this:
+
+[Exported]
+[Arity 3]
+GHC.Base.until =
+  \ (@ a_aiK)
+    (p_a6T :: a_aiK -> GHC.Bool.Bool)
+    (f_a6V :: a_aiK -> a_aiK)
+    (x_a6X :: a_aiK) ->
+    letrec {
+      sat_worker_s1aU :: a_aiK -> a_aiK
+      []
+      sat_worker_s1aU =
+        \ (x_a6X :: a_aiK) ->
+          let {
+            sat_shadow_r17 :: forall a_a3O.
+                              (a_a3O -> GHC.Bool.Bool) -> (a_a3O -> a_a3O) -> a_a3O -> a_a3O
+            []
+            sat_shadow_r17 =
+              \ (@ a_aiK)
+                (p_a6T :: a_aiK -> GHC.Bool.Bool)
+                (f_a6V :: a_aiK -> a_aiK)
+                (x_a6X :: a_aiK) ->
+                sat_worker_s1aU x_a6X } in
+          case p_a6T x_a6X of wild_X3y [ALWAYS Dead Nothing] {
+            GHC.Bool.False -> GHC.Base.until @ a_aiK p_a6T f_a6V (f_a6V x_a6X);
+            GHC.Bool.True -> x_a6X
+          }; } in
+    sat_worker_s1aU x_a6X
+    
+Where sat_shadow has captured the type variables of x_a6X etc as it has a a_aiK 
+type argument. This is bad because it means the application sat_worker_s1aU x_a6X
+is not well typed.
 
 \begin{code}
-saTransform :: Id -> CoreExpr -> SatM CoreBind
-saTransform binder rhs = do
-    r <- getSATInfo binder
-    case r of
-      Just (tyargs, args) | should_transform args
-        -> do
-            -- In order to get strictness information on this new binder
-            -- we need to make sure this stage happens >before< the analysis
-            binder' <- newSATName binder (mkSATLamTy tyargs args)
-            new_rhs <- mkNewRhs binder binder' args rhs
-            return (NonRec binder new_rhs)
-      _ -> return (Rec [(binder, rhs)])
+saTransformMaybe :: Id -> Maybe SATInfo -> [Id] -> CoreExpr -> SatM CoreBind
+saTransformMaybe binder maybe_arg_staticness rhs_binders rhs_body
+  | Just arg_staticness <- maybe_arg_staticness
+  , should_transform arg_staticness
+  = saTransform binder arg_staticness rhs_binders rhs_body
+  | otherwise
+  = return (Rec [(binder, mkLams rhs_binders rhs_body)])
+  where 
+    should_transform staticness = n_static_args > 1 -- THIS IS THE DECISION POINT
+      where
+       n_static_args = length (filter isStaticValue staticness)
+
+saTransform :: Id -> SATInfo -> [Id] -> CoreExpr -> SatM CoreBind
+saTransform binder arg_staticness rhs_binders rhs_body
+  = do { shadow_lam_bndrs <- mapM clone binders_w_staticness
+       ; uniq             <- newUnique
+       ; return (NonRec binder (mk_new_rhs uniq shadow_lam_bndrs)) }
   where
-    should_transform args
-      = staticArgsLength > 1           -- THIS IS THE DECISION POINT
-      where staticArgsLength = length (filter isStatic args)
+    -- Running example: foldr
+    -- foldr \alpha \beta c n xs = e, for some e
+    -- arg_staticness = [Static TypeApp, Static TypeApp, Static VarApp, Static VarApp, NonStatic]
+    -- rhs_binders = [\alpha, \beta, c, n, xs]
+    -- rhs_body = e
     
-    mkNewRhs binder binder' args rhs = let
-        non_static_args :: [Id]
-        non_static_args = get_nsa args rhs_val_binders
-          where
-            get_nsa :: [Staticness a] -> [a] -> [a]
-            get_nsa [] _ = []
-            get_nsa _ [] = []
-            get_nsa (NotStatic:args) (v:as) = v:get_nsa args as
-            get_nsa (_:args)         (_:as) =   get_nsa args as
-
-        -- To do the transformation, the game plan is to:
-        -- 1. Create a small nonrecursive RHS that takes the
-        --    original arguments to the function but discards
-        --    the ones that are static and makes a call to the
-        --    SATed version with the remainder. We intend that
-        --    this will be inlined later, removing the overhead
-        -- 2. Bind this nonrecursive RHS over the original body
-        --    WITH THE SAME UNIQUE as the original body so that
-        --    any recursive calls to the original now go via
-        --    the small wrapper
-        -- 3. Rebind the original function to a new one which contains
-        --    our SATed function and just makes a call to it:
-        --    we call the thing making this call the local body
-
-        local_body = mkApps (Var binder') [Var a | a <- non_static_args]
-
-        nonrec_rhs = mkOrigLam local_body
-
-        -- HACK! The following is a fake SysLocal binder with
-        --  *the same* unique as binder.
-        -- the reason for this is the following:
-        -- this binder *will* get inlined but if it happen to be
-        -- a top level binder it is never removed as dead code,
-        -- therefore we have to remove that information (of it being
-        -- top-level or exported somehow.)
-        -- A better fix is to use binder directly but with the TopLevel
-        -- tag (or Exported tag) modified.
-        fake_binder = mkSysLocal (fsLit "sat")
-                (getUnique binder)
-                (idType binder)
-        rec_body = mkLams non_static_args
-                   (Let (NonRec fake_binder nonrec_rhs) {-in-} rhs_body)
-        in return (mkOrigLam (Let (Rec [(binder', rec_body)]) {-in-} local_body))
-      where
-        (rhs_binders, rhs_body) = collectBinders rhs
-        rhs_val_binders = filter isId rhs_binders
-        
-        mkOrigLam = mkLams rhs_binders
+    binders_w_staticness = rhs_binders `zip` (arg_staticness ++ repeat NotStatic)
+                                       -- Any extra args are assumed NotStatic
+
+    non_static_args :: [Var]
+           -- non_static_args = [xs]
+           -- rhs_binders_without_type_capture = [\alpha', \beta', c, n, xs]
+    non_static_args = [v | (v, NotStatic) <- binders_w_staticness]
+
+    clone (bndr, NotStatic) = return bndr
+    clone (bndr, _        ) = do { uniq <- newUnique
+                                ; return (setVarUnique bndr uniq) }
+
+    -- new_rhs = \alpha beta c n xs -> 
+    --           let sat_worker = \xs -> let sat_shadow = \alpha' beta' c n xs -> 
+    --                                              sat_worker xs 
+    --                                   in e
+    --           in sat_worker xs
+    mk_new_rhs uniq shadow_lam_bndrs 
+       = mkLams rhs_binders $ 
+         Let (Rec [(rec_body_bndr, rec_body)]) 
+         local_body
+       where
+         local_body = mkVarApps (Var rec_body_bndr) non_static_args
+
+         rec_body = mkLams non_static_args $
+                     Let (NonRec shadow_bndr shadow_rhs) rhs_body
+
+           -- See Note [Binder type capture]
+         shadow_rhs = mkLams shadow_lam_bndrs local_body
+           -- nonrec_rhs = \alpha' beta' c n xs -> sat_worker xs
+
+         rec_body_bndr = mkSysLocal (fsLit "sat_worker") uniq (exprType rec_body)
+           -- rec_body_bndr = sat_worker
+    
+           -- See Note [Shadow binding]; make a SysLocal
+         shadow_bndr = mkSysLocal (occNameFS (getOccName binder)) 
+                                  (idUnique binder)
+                                  (exprType shadow_rhs)
 
-    mkSATLamTy tyargs args
-      = substTy (mk_inst_tyenv tyargs tv_tmpl)
-                (mkSigmaTy tv_tmpl' theta_tys' tau_ty')
-      where
-          -- get type info for the local function:
-          (tv_tmpl, theta_tys, tau_ty) = (tcSplitSigmaTy . idType) binder
-          (reg_arg_tys, res_type)      = splitFunTys tau_ty
-
-          -- now, we drop the ones that are
-          -- static, that is, the ones we will not pass to the local function
-          tv_tmpl'     = dropStatics tyargs tv_tmpl
-
-          -- Extract the args that correspond to the theta tys (e.g. dictionaries) and argument tys (normal values)
-          (args1, args2) = splitAtList theta_tys args
-          theta_tys'     = dropStatics args1 theta_tys
-          reg_arg_tys'   = dropStatics args2 reg_arg_tys
-
-          -- Piece the function type back together from our static-filtered components
-          tau_ty'        = mkFunTys reg_arg_tys' res_type
-
-          mk_inst_tyenv :: [Staticness Type] -> [TyVar] -> TvSubst
-          mk_inst_tyenv []              _      = emptyTvSubst
-          mk_inst_tyenv (Static s:args) (t:ts) = extendTvSubst (mk_inst_tyenv args ts) t s
-          mk_inst_tyenv (_:args)        (_:ts) = mk_inst_tyenv args ts
-          mk_inst_tyenv _               _      = panic "mk_inst_tyenv"
-
-dropStatics :: [Staticness a] -> [b] -> [b]
-dropStatics [] t = t
-dropStatics (Static _:args) (_:ts) = dropStatics args ts
-dropStatics (_:args)        (t:ts) = t:dropStatics args ts
-dropStatics _               _      = panic "dropStatics"
-
-isStatic :: Staticness a -> Bool
-isStatic NotStatic = False
-isStatic _         = True
-\end{code}
+isStaticValue :: Staticness App -> Bool
+isStaticValue (Static (VarApp _)) = True
+isStaticValue _                   = False
+
+\end{code}
\ No newline at end of file