Don't dump floated bindings just outside a lambda
authorsimonpj@microsoft.com <unknown>
Fri, 2 Feb 2007 15:54:52 +0000 (15:54 +0000)
committersimonpj@microsoft.com <unknown>
Fri, 2 Feb 2007 15:54:52 +0000 (15:54 +0000)
We do not want the FloatOut pass to transform
f = \x. e
to
f = let lvl = ... in \x.e

The arity pinned on f isn't right any more; and see
Note [Floating out of RHSs].

Core Lint is now spotting the arity lossage (for a letrec), which is
how I spotted this bug.

I also re-jigged the code around floatBind; it's a bit tidier now.

compiler/simplCore/FloatOut.lhs

index 3477467..c97bbce 100644 (file)
@@ -24,7 +24,6 @@ import SetLevels      ( Level(..), LevelledExpr, LevelledBind,
 import UniqSupply       ( UniqSupply )
 import List            ( partition )
 import Outputable
-import Util             ( notNull )
 \end{code}
 
        -----------------
@@ -144,15 +143,10 @@ floatOutwards float_sws dflags us pgm
     pp_not True  = empty
     pp_not False = text "not"
 
-floatTopBind bind@(NonRec _ _)
-  = case (floatBind bind) of { (fs, floats, bind') ->
-    (fs, floatsToBinds floats ++ [bind'])
+floatTopBind bind
+  = case (floatBind bind) of { (fs, floats) ->
+    (fs, floatsToBinds floats)
     }
-
-floatTopBind bind@(Rec _)
-  = case (floatBind bind) of { (fs, floats, Rec pairs') ->
-    WARN( notNull floats, ppr bind $$ ppr floats )
-    (fs, [Rec (floatsToBindPairs floats ++ pairs')]) }
 \end{code}
 
 %************************************************************************
@@ -163,21 +157,25 @@ floatTopBind bind@(Rec _)
 
 
 \begin{code}
-floatBind :: LevelledBind
-         -> (FloatStats, FloatBinds, CoreBind)
+floatBind :: LevelledBind -> (FloatStats, FloatBinds)
 
 floatBind (NonRec (TB name level) rhs)
-  = case (floatNonRecRhs level rhs) of { (fs, rhs_floats, rhs') ->
-    (fs, rhs_floats, NonRec name rhs') }
+  = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
+    (fs, rhs_floats ++ [(level, NonRec name rhs')]) }
 
 floatBind bind@(Rec pairs)
   = case (unzip3 (map do_pair pairs)) of { (fss, rhss_floats, new_pairs) ->
-
-    if not (isTopLvl bind_dest_level) then
-       -- Standard case; the floated bindings can't mention the
-       -- binders, because they couldn't be escaping a major level
-       -- if so.
-       (sum_stats fss, concat rhss_floats, Rec new_pairs)
+    let rhs_floats = concat rhss_floats in
+
+    if not (isTopLvl bind_dest_lvl) then
+       -- Find which bindings float out at least one lambda beyond this one
+       -- These ones can't mention the binders, because they couldn't 
+       -- be escaping a major level if so.
+       -- The ones that are not going further can join the letrec;
+       -- they may not be mutually recursive but the occurrence analyser will
+       -- find that out.
+       case (partitionByMajorLevel bind_dest_lvl rhs_floats) of { (floats', heres) ->
+       (sum_stats fss, floats' ++ [(bind_dest_lvl, Rec (floatsToBindPairs heres ++ new_pairs))]) }
     else
        -- In a recursive binding, *destined for* the top level
        -- (only), the rhs floats may contain references to the 
@@ -192,11 +190,10 @@ floatBind bind@(Rec pairs)
        -- This can only happen for bindings destined for the top level,
        -- because only then will partitionByMajorLevel allow through a binding
        -- that only differs in its minor level
-       (sum_stats fss, [],
-        Rec (new_pairs ++ floatsToBindPairs (concat rhss_floats)))
+       (sum_stats fss, [(bind_dest_lvl, Rec (new_pairs ++ floatsToBindPairs rhs_floats))])
     }
   where
-    bind_dest_level = getBindLevel bind
+    bind_dest_lvl = getBindLevel bind
 
     do_pair (TB name level, rhs)
       = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
@@ -211,12 +208,12 @@ floatBind bind@(Rec pairs)
 %************************************************************************
 
 \begin{code}
-floatExpr, floatRhs, floatNonRecRhs
+floatExpr, floatRhs, floatCaseAlt
         :: Level
         -> LevelledExpr
         -> (FloatStats, FloatBinds, CoreExpr)
 
-floatRhs lvl arg       -- Used rec rhss, and case-alternative rhss
+floatCaseAlt lvl arg   -- Used rec rhss, and case-alternative rhss
   = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
        -- Dump bindings that aren't going to escape from a lambda;
@@ -224,38 +221,43 @@ floatRhs lvl arg  -- Used rec rhss, and case-alternative rhss
        -- the rec or case alternative
     (fsa, floats', install heres arg') }}
 
-floatNonRecRhs lvl arg -- Used for nested non-rec rhss, and fn args
+floatRhs lvl arg       -- Used for nested non-rec rhss, and fn args
+                       -- See Note [Floating out of RHS]
   = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
-       -- Dump bindings that aren't going to escape from a lambda
-       -- This isn't a scoping issue (the binder isn't in scope in the RHS of a non-rec binding)
-       -- Rather, it is to avoid floating the x binding out of
-       --      f (let x = e in b)
-       -- unnecessarily.  But we first test for values or trival rhss,
-       -- because (in particular) we don't want to insert new bindings between
-       -- the "=" and the "\".  E.g.
-       --      f = \x -> let <bind> in <body>
-       -- We do not want
-       --      f = let <bind> in \x -> <body>
-       -- (a) The simplifier will immediately float it further out, so we may
-       --      as well do so right now; in general, keeping rhss as manifest 
-       --      values is good
-       -- (b) If a float-in pass follows immediately, it might add yet more
-       --      bindings just after the '='.  And some of them might (correctly)
-       --      be strict even though the 'let f' is lazy, because f, being a value,
-       --      gets its demand-info zapped by the simplifier.
     if exprIsHNF arg' || exprIsTrivial arg' then
        (fsa, floats, arg')
     else
     case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
     (fsa, floats', install heres arg') }}
 
+-- Note [Floating out of RHSs]
+-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~
+-- Dump bindings that aren't going to escape from a lambda
+-- This isn't a scoping issue (the binder isn't in scope in the RHS 
+--     of a non-rec binding)
+-- Rather, it is to avoid floating the x binding out of
+--     f (let x = e in b)
+-- unnecessarily.  But we first test for values or trival rhss,
+-- because (in particular) we don't want to insert new bindings between
+-- the "=" and the "\".  E.g.
+--     f = \x -> let <bind> in <body>
+-- We do not want
+--     f = let <bind> in \x -> <body>
+-- (a) The simplifier will immediately float it further out, so we may
+--     as well do so right now; in general, keeping rhss as manifest 
+--     values is good
+-- (b) If a float-in pass follows immediately, it might add yet more
+--     bindings just after the '='.  And some of them might (correctly)
+--     be strict even though the 'let f' is lazy, because f, being a value,
+--     gets its demand-info zapped by the simplifier.
+
 floatExpr _ (Var v)   = (zeroStats, [], Var v)
 floatExpr _ (Type ty) = (zeroStats, [], Type ty)
 floatExpr _ (Lit lit) = (zeroStats, [], Lit lit)
          
 floatExpr lvl (App e a)
   = case (floatExpr      lvl e) of { (fse, floats_e, e') ->
-    case (floatNonRecRhs lvl a) of { (fsa, floats_a, a') ->
+    case (floatRhs lvl a)      of { (fsa, floats_a, a') ->
     (fse `add_stats` fsa, floats_e ++ floats_a, App e' a') }}
 
 floatExpr lvl lam@(Lam _ _)
@@ -326,13 +328,11 @@ floatExpr lvl (Let (NonRec (TB bndr bndr_lvl) rhs) body)
     (fs, rhs_floats ++ body_floats, Let (NonRec bndr rhs') body') }}
 
 floatExpr lvl (Let bind body)
-  = case (floatBind bind)     of { (fsb, rhs_floats,  bind') ->
+  = case (floatBind bind)     of { (fsb, bind_floats) ->
     case (floatExpr lvl body) of { (fse, body_floats, body') ->
     (add_stats fsb fse,
-     rhs_floats ++ [(bind_lvl, bind')] ++ body_floats,
+     bind_floats ++ body_floats,
      body')  }}
-  where
-    bind_lvl = getBindLevel bind
 
 floatExpr lvl (Case scrut (TB case_bndr case_lvl) ty alts)
   = case floatExpr lvl scrut   of { (fse, fde, scrut') ->
@@ -340,10 +340,10 @@ floatExpr lvl (Case scrut (TB case_bndr case_lvl) ty alts)
     (add_stats fse fsa, fda ++ fde, Case scrut' case_bndr ty alts')
     }}
   where
-       -- Use floatRhs for the alternatives, so that we
+       -- Use floatCaseAlt for the alternatives, so that we
        -- don't gratuitiously float bindings out of the RHSs
     float_alt (con, bs, rhs)
-       = case (floatRhs case_lvl rhs)  of { (fs, rhs_floats, rhs') ->
+       = case (floatCaseAlt case_lvl rhs)      of { (fs, rhs_floats, rhs') ->
          (fs, rhs_floats, (con, [b | TB b _ <- bs], rhs')) }