Make let-floating work even if there are big lambdas in the way
authorsimonpj@microsoft.com <unknown>
Mon, 7 May 2007 16:24:22 +0000 (16:24 +0000)
committersimonpj@microsoft.com <unknown>
Mon, 7 May 2007 16:24:22 +0000 (16:24 +0000)
This patch generalises the let-floating transformation in a way
suggested by Roman and Manuel when doing closure conversion.

There are extensive comments in Note [Floating and type abstraction],
which begins thus.  Consider this:
x = /\a. C e1 e2
We'd like to float this to
y1 = /\a. e1
y2 = /\a. e2
x = /\a. C (y1 a) (y2 a)
for the usual reasons: we want to inline x rather vigorously.

(Further commennts follow in SimplUtils.)

The implementation is not hard; indeed it used to be in GHC years ago.
I removed it thinking that full laziness would achieve the same
effect, but I'm not sure it does; and in any case it seems more direct
to do it here.

The transformation should not make anything worse, so yell if
you see anything unexpected happening.

compiler/simplCore/SimplEnv.lhs
compiler/simplCore/SimplUtils.lhs
compiler/simplCore/Simplify.lhs

index 3832f54..2fedf87 100644 (file)
@@ -32,9 +32,9 @@ module SimplEnv (
        substExpr, substTy, 
 
        -- Floats
-       Floats, emptyFloats, isEmptyFloats, addNonRec, addFloats, 
-       wrapFloats, floatBinds, setFloats, canFloat, zapFloats, addRecFloats,
-       getFloats
+       Floats, emptyFloats, isEmptyFloats, addNonRec, addFloats, extendFloats,
+       wrapFloats, floatBinds, setFloats, zapFloats, addRecFloats,
+       doFloatFromRhs, getFloats
     ) where
 
 #include "HsVersions.h"
@@ -44,14 +44,12 @@ import IdInfo
 import CoreSyn
 import Rules
 import CoreUtils
-import CoreFVs
 import CostCentre
 import Var
 import VarEnv
 import VarSet
 import OrdList
 import Id
-import NewDemand
 import qualified CoreSubst     ( Subst, mkSubst, substExpr, substSpec, substWorker )
 import qualified Type          ( substTy, substTyVarBndr )
 import Type hiding             ( substTy, substTyVarBndr )
@@ -59,7 +57,6 @@ import Coercion
 import BasicTypes      
 import DynFlags
 import Util
-import UniqFM
 import Outputable
 \end{code}
 
@@ -312,11 +309,13 @@ The Floats is a bunch of bindings, classified by a FloatFlag.
 
   NonRec x (y:ys)      FltLifted
   Rec [(x,rhs)]                FltLifted
-  NonRec x# (y +# 3)   FltOkSpec
+
+  NonRec x# (y +# 3)   FltOkSpec       -- Unboxed, but ok-for-spec'n
+
   NonRec x# (a /# b)   FltCareful
-  NonRec x* (f y)      FltCareful      -- Might fail or diverge
-  NonRec x# (f y)      FltCareful      -- Might fail or diverge
-                         (where f :: Int -> Int#)
+  NonRec x* (f y)      FltCareful      -- Strict binding; might fail or diverge
+  NonRec x# (f y)      FltCareful      -- Unboxed binding: might fail or diverge
+                                       --        (where f :: Int -> Int#)
 
 \begin{code}
 data Floats = Floats (OrdList OutBind) FloatFlag
@@ -359,14 +358,15 @@ classifyFF (NonRec bndr rhs)
   | exprOkForSpeculation rhs = FltOkSpec
   | otherwise               = FltCareful
 
-canFloat :: TopLevelFlag -> RecFlag -> Bool -> SimplEnv -> Bool
-canFloat lvl rec str (SimplEnv {seFloats = Floats _ ff}) 
-  = canFloatFlt lvl rec str ff
-
-canFloatFlt :: TopLevelFlag -> RecFlag -> Bool -> FloatFlag -> Bool
-canFloatFlt lvl rec str FltLifted  = True
-canFloatFlt lvl rec str FltOkSpec  = isNotTopLevel lvl && isNonRec rec
-canFloatFlt lvl rec str FltCareful = str && isNotTopLevel lvl && isNonRec rec
+doFloatFromRhs :: TopLevelFlag -> RecFlag -> Bool -> OutExpr -> SimplEnv -> Bool
+doFloatFromRhs lvl rec str rhs (SimplEnv {seFloats = Floats fs ff}) 
+  =  not (isNilOL fs) && want_to_float && can_float
+  where
+     want_to_float = isTopLevel lvl || exprIsCheap rhs
+     can_float = case ff of
+                  FltLifted  -> True
+                  FltOkSpec  -> isNotTopLevel lvl && isNonRec rec
+                  FltCareful -> isNotTopLevel lvl && isNonRec rec && str
 \end{code}
 
 
@@ -387,6 +387,16 @@ addNonRec env id rhs
   = env { seFloats = seFloats env `addFlts` unitFloat (NonRec id rhs),
          seInScope = extendInScopeSet (seInScope env) id }
 
+extendFloats :: SimplEnv -> [OutBind] -> SimplEnv
+-- Add these bindings to the floats, and extend the in-scope env too
+extendFloats env binds
+  = env { seFloats  = seFloats env `addFlts` new_floats,
+         seInScope = extendInScopeSetList (seInScope env) bndrs }
+  where
+    bndrs = bindersOfBinds binds
+    new_floats = Floats (toOL binds) 
+                       (foldr (andFF . classifyFF) FltLifted binds)
+
 addFloats :: SimplEnv -> SimplEnv -> SimplEnv
 -- Add the floats for env2 to env1; 
 -- *plus* the in-scope set for env2, which is bigger 
index 3b304c6..95aa89e 100644 (file)
@@ -19,7 +19,9 @@ module SimplUtils (
        mkBoringStop, mkLazyArgStop, mkRhsStop, contIsRhsOrArg,
        interestingCallContext, interestingArgContext,
 
-       interestingArg, mkArgInfo
+       interestingArg, mkArgInfo,
+       
+       abstractFloats
     ) where
 
 #include "HsVersions.h"
@@ -28,12 +30,14 @@ import SimplEnv
 import DynFlags
 import StaticFlags
 import CoreSyn
+import qualified CoreSubst
 import PprCore
 import CoreFVs
 import CoreUtils
 import Literal 
 import CoreUnfold
 import MkId
+import Name
 import Id
 import NewDemand
 import SimplMonad
@@ -149,14 +153,14 @@ mkLazyArgStop ty has_rules = Stop ty AnArg (canUpdateInPlace ty || has_rules)
 mkRhsStop :: OutType -> SimplCont
 mkRhsStop ty = Stop ty AnRhs (canUpdateInPlace ty)
 
-contIsRhsOrArg (Stop _ _ _)    = True
+contIsRhsOrArg (Stop {})       = True
 contIsRhsOrArg (StrictBind {}) = True
 contIsRhsOrArg (StrictArg {})  = True
 contIsRhsOrArg other          = False
 
 -------------------
 contIsDupable :: SimplCont -> Bool
-contIsDupable (Stop _ _ _)                      = True
+contIsDupable (Stop {})                 = True
 contIsDupable (ApplyTo  OkToDup _ _ _)   = True
 contIsDupable (Select   OkToDup _ _ _ _) = True
 contIsDupable (CoerceIt _ cont)          = contIsDupable cont
@@ -164,7 +168,7 @@ contIsDupable other                  = False
 
 -------------------
 contIsTrivial :: SimplCont -> Bool
-contIsTrivial (Stop _ _ _)               = True
+contIsTrivial (Stop {})                          = True
 contIsTrivial (ApplyTo _ (Type _) _ cont) = contIsTrivial cont
 contIsTrivial (CoerceIt _ cont)          = contIsTrivial cont
 contIsTrivial other                      = False
@@ -803,6 +807,8 @@ mkLam :: [OutBndr] -> OutExpr -> SimplM OutExpr
 --     a) eta reduction, if that gives a trivial expression
 --     b) eta expansion [only if there are some value lambdas]
 
+mkLam [] body 
+  = return body
 mkLam bndrs body
   = do { dflags <- getDOptsSmpl
        ; mkLam' dflags bndrs body }
@@ -941,8 +947,35 @@ tryEtaExpansion dflags body
 %*                                                                     *
 %************************************************************************
 
-tryRhsTyLam tries this transformation, when the big lambda appears as
-the RHS of a let(rec) binding:
+Note [Floating and type abstraction]
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Consider this:
+       x = /\a. C e1 e2
+We'd like to float this to 
+       y1 = /\a. e1
+       y2 = /\a. e2
+       x = /\a. C (y1 a) (y2 a)
+for the usual reasons: we want to inline x rather vigorously.
+
+You may think that this kind of thing is rare.  But in some programs it is
+common.  For example, if you do closure conversion you might get:
+
+       data a :-> b = forall e. (e -> a -> b) :$ e
+
+       f_cc :: forall a. a :-> a
+       f_cc = /\a. (\e. id a) :$ ()
+
+Now we really want to inline that f_cc thing so that the
+construction of the closure goes away. 
+
+So I have elaborated simplLazyBind to understand right-hand sides that look
+like
+       /\ a1..an. body
+
+and treat them specially. The real work is done in SimplUtils.abstractFloats,
+but there is quite a bit of plumbing in simplLazyBind as well.
+
+The same transformation is good when there are lets in the body:
 
        /\abc -> let(rec) x = e in b
    ==>
@@ -964,25 +997,6 @@ let-floating.
 This optimisation is CRUCIAL in eliminating the junk introduced by
 desugaring mutually recursive definitions.  Don't eliminate it lightly!
 
-So far as the implementation is concerned:
-
-       Invariant: go F e = /\tvs -> F e
-       
-       Equalities:
-               go F (Let x=e in b)
-               = Let x' = /\tvs -> F e 
-                 in 
-                 go G b
-               where
-                   G = F . Let x = x' tvs
-       
-               go F (Letrec xi=ei in b)
-               = Letrec {xi' = /\tvs -> G ei} 
-                 in
-                 go G b
-               where
-                 G = F . Let {xi = xi' tvs}
-
 [May 1999]  If we do this transformation *regardless* then we can
 end up with some pretty silly stuff.  For example, 
 
@@ -1004,43 +1018,30 @@ and is of the form
 If we abstract this wrt the tyvar we then can't do the case inline
 as we would normally do.
 
+That's why the whole transformation is part of the same process that
+floats let-bindings and constructor arguments out of RHSs.  In particular,
+it is guarded by the doFloatFromRhs call in simplLazyBind.
 
-\begin{code}
-{-     Trying to do this in full laziness
-
-tryRhsTyLam :: SimplEnv -> [OutTyVar] -> OutExpr -> SimplM FloatsWithExpr
--- Call ensures that all the binders are type variables
-
-tryRhsTyLam env tyvars body            -- Only does something if there's a let
-  |  not (all isTyVar tyvars)
-  || not (worth_it body)               -- inside a type lambda, 
-  = returnSmpl (emptyFloats env, body) -- and a WHNF inside that
-
-  | otherwise
-  = go env (\x -> x) body
 
+\begin{code}
+abstractFloats :: [OutTyVar] -> SimplEnv -> OutExpr -> SimplM ([OutBind], OutExpr)
+abstractFloats tvs body_env body
+  = ASSERT( notNull body_floats )
+    do { (subst, float_binds) <- mapAccumLSmpl abstract empty_subst body_floats
+       ; return (float_binds, CoreSubst.substExpr subst body) }
   where
-    worth_it e@(Let _ _) = whnf_in_middle e
-    worth_it e          = False
-
-    whnf_in_middle (Let (NonRec x rhs) e) | isUnLiftedType (idType x) = False
-    whnf_in_middle (Let _ e) = whnf_in_middle e
-    whnf_in_middle e        = exprIsCheap e
-
-    main_tyvar_set = mkVarSet tyvars
-
-    go env fn (Let bind@(NonRec var rhs) body)
-      | exprIsTrivial rhs
-      = go env (fn . Let bind) body
-
-    go env fn (Let (NonRec var rhs) body)
-      = mk_poly tyvars_here var                                                        `thenSmpl` \ (var', rhs') ->
-       addAuxiliaryBind env (NonRec var' (mkLams tyvars_here (fn rhs)))        $ \ env -> 
-       go env (fn . Let (mk_silly_bind var rhs')) body
-
+    main_tv_set = mkVarSet tvs
+    body_floats = getFloats body_env
+    empty_subst = CoreSubst.mkEmptySubst (seInScope body_env)
+
+    abstract :: CoreSubst.Subst -> OutBind -> SimplM (CoreSubst.Subst, OutBind)
+    abstract subst (NonRec id rhs)
+      = do { (poly_id, poly_app) <- mk_poly tvs_here id
+          ; let poly_rhs = mkLams tvs_here (CoreSubst.substExpr subst rhs)
+                subst'   = CoreSubst.extendIdSubst subst id poly_app
+          ; return (subst', (NonRec poly_id poly_rhs)) }
       where
-
-       tyvars_here = varSetElems (main_tyvar_set `intersectVarSet` exprSomeFreeVars isTyVar rhs)
+       tvs_here = varSetElems (main_tv_set `intersectVarSet` exprSomeFreeVars isTyVar rhs)
                -- Abstract only over the type variables free in the rhs
                -- wrt which the new binding is abstracted.  But the naive
                -- approach of abstract wrt the tyvars free in the Id's type
@@ -1057,28 +1058,26 @@ tryRhsTyLam env tyvars body             -- Only does something if there's a let
                -- abstracting wrt *all* the tyvars.  We'll see if that
                -- gives rise to problems.   SLPJ June 98
 
-    go env fn (Let (Rec prs) body)
-       = mapAndUnzipSmpl (mk_poly tyvars_here) vars    `thenSmpl` \ (vars', rhss') ->
-        let
-           gn body = fn (foldr Let body (zipWith mk_silly_bind vars rhss'))
-           pairs   = vars' `zip` [mkLams tyvars_here (gn rhs) | rhs <- rhss]
-        in
-        addAuxiliaryBind env (Rec pairs)               $ \ env ->
-        go env gn body 
+    abstract subst (Rec prs)
+       = do { (poly_ids, poly_apps) <- mapAndUnzipSmpl (mk_poly tvs_here) ids
+           ; let subst' = CoreSubst.extendSubstList subst (ids `zip` poly_apps)
+                 poly_rhss = [mkLams tvs_here (CoreSubst.substExpr subst' rhs) | rhs <- rhss]
+           ; return (subst', Rec (poly_ids `zip` poly_rhss)) }
        where
-        (vars,rhss) = unzip prs
-        tyvars_here = varSetElems (main_tyvar_set `intersectVarSet` exprsSomeFreeVars isTyVar (map snd prs))
-               -- See notes with tyvars_here above
-
-    go env fn body = returnSmpl (emptyFloats env, fn body)
-
-    mk_poly tyvars_here var
-      = getUniqueSmpl          `thenSmpl` \ uniq ->
-       let
-           poly_name = setNameUnique (idName var) uniq         -- Keep same name
-           poly_ty   = mkForAllTys tyvars_here (idType var)    -- But new type of course
-           poly_id   = mkLocalId poly_name poly_ty 
-
+        (ids,rhss) = unzip prs
+        
+        tvs_here = varSetElems (main_tv_set `intersectVarSet` bind_ftvs)
+        bind_ftvs = exprsSomeFreeVars isTyVar rhss `unionVarSet` tyVarsOfTypes (map idType ids)
+               -- Also nb that we must take the tyvars of the Id's type too:
+               --      x::a = x
+               -- Bizarre, I know
+
+    mk_poly tvs_here var
+      = do { uniq <- getUniqueSmpl
+          ; let  poly_name = setNameUnique (idName var) uniq           -- Keep same name
+                 poly_ty   = mkForAllTys tvs_here (idType var) -- But new type of course
+                 poly_id   = mkLocalId poly_name poly_ty 
+          ; return (poly_id, mkTyApps (Var poly_id) (mkTyVarTys tvs_here)) }
                -- In the olden days, it was crucial to copy the occInfo of the original var, 
                -- because we were looking at occurrence-analysed but as yet unsimplified code!
                -- In particular, we mustn't lose the loop breakers.  BUT NOW we are looking
@@ -1091,10 +1090,10 @@ tryRhsTyLam env tyvars body             -- Only does something if there's a let
                -- where x* has an INLINE prag on it.  Now, once x* is inlined,
                -- the occurrences of x' will be just the occurrences originally
                -- pinned on x.
-       in
-       returnSmpl (poly_id, mkTyApps (Var poly_id) (mkTyVarTys tyvars_here))
+\end{code}
+
+Historical note: if you use let-bindings instead of a substitution, beware of this:
 
-    mk_silly_bind var rhs = NonRec var (Note InlineMe rhs)
                -- Suppose we start with:
                --
                --      x = /\ a -> let g = G in E
@@ -1114,8 +1113,6 @@ tryRhsTyLam env tyvars body               -- Only does something if there's a let
                -- Solution: put an INLINE note on g's RHS, so that poly_g seems
                --           to appear many times.  (NB: mkInlineMe eliminates
                --           such notes on trivial RHSs, so do it manually.)
--}
-\end{code}
 
 %************************************************************************
 %*                                                                     *
index aab8925..5b8f304 100644 (file)
@@ -305,46 +305,38 @@ simplLazyBind :: SimplEnv
              -> SimplM SimplEnv
 
 simplLazyBind env top_lvl is_rec bndr bndr1 rhs rhs_se
-  = do { let   rhs_env  = rhs_se `setInScope` env
-               rhs_cont = mkRhsStop (idType bndr1)
+  = do { let   rhs_env     = rhs_se `setInScope` env
+               (tvs, body) = collectTyBinders rhs
+       ; (body_env, tvs') <- simplBinders rhs_env tvs
+               -- See Note [Floating and type abstraction]
+               -- in SimplUtils
 
        -- Simplify the RHS; note the mkRhsStop, which tells 
        -- the simplifier that this is the RHS of a let.
-       ; (rhs_env1, rhs1) <- simplExprF rhs_env rhs rhs_cont
-
-       -- If any of the floats can't be floated, give up now
-       -- (The canFloat predicate says True for empty floats.)
-       ; if (not (canFloat top_lvl is_rec False rhs_env1))
-         then  completeBind env top_lvl bndr bndr1
-                                (wrapFloats rhs_env1 rhs1)
-         else do
+       ; let rhs_cont = mkRhsStop (applyTys (idType bndr1) (mkTyVarTys tvs'))
+       ; (body_env1, body1) <- simplExprF body_env body rhs_cont
+
        -- ANF-ise a constructor or PAP rhs
-       { (rhs_env2, rhs2) <- prepareRhs rhs_env1 rhs1
-       ; (env', rhs3) <- chooseRhsFloats top_lvl is_rec False env rhs_env2 rhs2
-       ; completeBind env' top_lvl bndr bndr1 rhs3 } }
-
-chooseRhsFloats :: TopLevelFlag -> RecFlag -> Bool
-               -> SimplEnv     -- Env for the let
-               -> SimplEnv     -- Env for the RHS, with RHS floats in it
-               -> OutExpr              -- ..and the RHS itself
-               -> SimplM (SimplEnv, OutExpr)   -- New env for let, and RHS
-
-chooseRhsFloats top_lvl is_rec is_strict env rhs_env rhs
-  | not (isEmptyFloats rhs_env)                -- Something to float
-  , canFloat top_lvl is_rec is_strict rhs_env  -- ...that can float
-  , (isTopLevel top_lvl  || exprIsCheap rhs)   -- ...and we want to float      
-  = do { tick LetFloatFromLet  -- Float
-       ; return (addFloats env rhs_env, rhs) } -- Add the floats to the main env
-  | otherwise                  -- Don't float
-  = return (env, wrapFloats rhs_env rhs)       -- Wrap the floats around the RHS
-\end{code}
+       ; (body_env2, body2) <- prepareRhs body_env1 body1
 
+       ; (env', rhs')
+           <-  if not (doFloatFromRhs top_lvl is_rec False body2 body_env2)
+               then                            -- No floating, just wrap up!
+                    do { rhs' <- mkLam tvs' (wrapFloats body_env2 body2)
+                       ; return (env, rhs') }
 
-%************************************************************************
-%*                                                                     *
-\subsection{simplNonRec}
-%*                                                                     *
-%************************************************************************
+               else if null tvs then           -- Simple floating
+                    do { tick LetFloatFromLet
+                       ; return (addFloats env body_env2, body2) }
+
+               else                            -- Do type-abstraction first
+                    do { tick LetFloatFromLet
+                       ; (poly_binds, body3) <- abstractFloats tvs body_env2 body2
+                       ; rhs' <- mkLam tvs' body3
+                       ; return (extendFloats env poly_binds, rhs') }
+
+       ; completeBind env' top_lvl bndr bndr1 rhs' }
+\end{code}
 
 A specialised variant of simplNonRec used when the RHS is already simplified, 
 notably in knownCon.  It uses case-binding where necessary.
@@ -369,7 +361,11 @@ completeNonRecX :: SimplEnv
 
 completeNonRecX env top_lvl is_rec is_strict old_bndr new_bndr new_rhs
   = do         { (env1, rhs1) <- prepareRhs (zapFloats env) new_rhs
-       ; (env2, rhs2) <- chooseRhsFloats top_lvl is_rec is_strict env env1 rhs1
+       ; (env2, rhs2) <- 
+               if doFloatFromRhs top_lvl is_rec is_strict rhs1 env1
+               then do { tick LetFloatFromLet
+                       ; return (addFloats env env1, rhs1) }   -- Add the floats to the main env
+               else return (env, wrapFloats env1 rhs1)         -- Wrap the floats around the RHS
        ; completeBind env2 NotTopLevel old_bndr new_bndr rhs2 }
 \end{code}
 
@@ -447,6 +443,7 @@ prepareRhs env rhs
        = return (False, env, other)
 \end{code}
 
+
 Note [Float coercions]
 ~~~~~~~~~~~~~~~~~~~~~~
 When we find the binding