#include "HsVersions.h"
import CoreSyn
-import CoreUtils ( mkSCC )
+import CoreUtils ( mkSCC, exprIsValue, exprIsTrivial )
import CmdLineOpts ( DynFlags, DynFlag(..), FloatOutSwitches(..) )
import ErrUtils ( dumpIfSet_dyn )
import CostCentre ( dupifyCC, CostCentre )
-import Id ( Id )
+import Id ( Id, idType )
+import Type ( isUnLiftedType )
import CoreLint ( showPass, endPass )
-import SetLevels ( setLevels, Level(..), ltMajLvl, ltLvl, isTopLvl )
+import SetLevels ( Level(..), LevelledExpr, LevelledBind,
+ setLevels, ltMajLvl, ltLvl, isTopLvl )
import UniqSupply ( UniqSupply )
import List ( partition )
import Outputable
+import Util ( notNull )
\end{code}
-----------------
Well, maybe. We don't do this at the moment.
\begin{code}
-type LevelledExpr = TaggedExpr Level
-type LevelledBind = TaggedBind Level
-type FloatBind = (Level, CoreBind)
-type FloatBinds = [FloatBind]
+type FloatBind = (Level, CoreBind) -- INVARIANT: a FloatBind is always lifted
+type FloatBinds = [FloatBind]
\end{code}
%************************************************************************
floatTopBind bind@(Rec _)
= case (floatBind bind) of { (fs, floats, Rec pairs') ->
- WARN( not (null floats), ppr bind $$ ppr floats )
+ WARN( notNull floats, ppr bind $$ ppr floats )
(fs, [Rec (floatsToBindPairs floats ++ pairs')]) }
\end{code}
floatBind :: LevelledBind
-> (FloatStats, FloatBinds, CoreBind)
-floatBind (NonRec (name,level) rhs)
- = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
+floatBind (NonRec (TB name level) rhs)
+ = case (floatNonRecRhs level rhs) of { (fs, rhs_floats, rhs') ->
(fs, rhs_floats, NonRec name rhs') }
floatBind bind@(Rec pairs)
= case (unzip3 (map do_pair pairs)) of { (fss, rhss_floats, new_pairs) ->
- if not (isTopLvl bind_level) then
- -- Standard case
+ if not (isTopLvl bind_dest_level) then
+ -- Standard case; the floated bindings can't mention the
+ -- binders, because they couldn't be escaping a major level
+ -- if so.
(sum_stats fss, concat rhss_floats, Rec new_pairs)
else
-- In a recursive binding, *destined for* the top level
-- (only), the rhs floats may contain references to the
-- bound things. For example
- --
-- f = ...(let v = ...f... in b) ...
- --
-- might get floated to
- --
-- v = ...f...
-- f = ... b ...
- --
-- and hence we must (pessimistically) make all the floats recursive
-- with the top binding. Later dependency analysis will unravel it.
--
- -- Can't happen on nested bindings because floatRhs will dump
- -- the bindings in the RHS (partitionByMajorLevel treats top specially)
+ -- This can only happen for bindings destined for the top level,
+ -- because only then will partitionByMajorLevel allow through a binding
+ -- that only differs in its minor level
(sum_stats fss, [],
Rec (new_pairs ++ floatsToBindPairs (concat rhss_floats)))
}
where
- bind_level = getBindLevel bind
+ bind_dest_level = getBindLevel bind
- do_pair ((name, level), rhs)
+ do_pair (TB name level, rhs)
= case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
(fs, rhs_floats, (name, rhs'))
}
%************************************************************************
\begin{code}
-floatExpr, floatRhs
+floatExpr, floatRhs, floatNonRecRhs
:: Level
-> LevelledExpr
-> (FloatStats, FloatBinds, CoreExpr)
-floatRhs lvl arg
+floatRhs lvl arg -- Used rec rhss, and case-alternative rhss
= case (floatExpr lvl arg) of { (fsa, floats, arg') ->
case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
+ -- Dump bindings that aren't going to escape from a lambda;
+ -- in particular, we must dump the ones that are bound by
+ -- the rec or case alternative
+ (fsa, floats', install heres arg') }}
+
+floatNonRecRhs lvl arg -- Used for nested non-rec rhss, and fn args
+ = case (floatExpr lvl arg) of { (fsa, floats, arg') ->
-- Dump bindings that aren't going to escape from a lambda
- -- This is to avoid floating the x binding out of
+ -- This isn't a scoping issue (the binder isn't in scope in the RHS of a non-rec binding)
+ -- Rather, it is to avoid floating the x binding out of
-- f (let x = e in b)
- -- unnecessarily. It even causes a bug to do so if we have
- -- y = writeArr# a n (let x = e in b)
- -- because the y binding is an expr-ok-for-speculation one.
- -- [SLPJ Dec 01: I don't understand this last comment;
- -- writeArr# is not ok-for-spec because of its side effect]
+ -- unnecessarily. But we first test for values or trival rhss,
+ -- because (in particular) we don't want to insert new bindings between
+ -- the "=" and the "\". E.g.
+ -- f = \x -> let <bind> in <body>
+ -- We do not want
+ -- f = let <bind> in \x -> <body>
+ -- (a) The simplifier will immediately float it further out, so we may
+ -- as well do so right now; in general, keeping rhss as manifest
+ -- values is good
+ -- (b) If a float-in pass follows immediately, it might add yet more
+ -- bindings just after the '='. And some of them might (correctly)
+ -- be strict even though the 'let f' is lazy, because f, being a value,
+ -- gets its demand-info zapped by the simplifier.
+ if exprIsValue arg' || exprIsTrivial arg' then
+ (fsa, floats, arg')
+ else
+ case (partitionByMajorLevel lvl floats) of { (floats', heres) ->
(fsa, floats', install heres arg') }}
floatExpr _ (Var v) = (zeroStats, [], Var v)
floatExpr _ (Lit lit) = (zeroStats, [], Lit lit)
floatExpr lvl (App e a)
- = case (floatExpr lvl e) of { (fse, floats_e, e') ->
- case (floatRhs lvl a) of { (fsa, floats_a, a') ->
+ = case (floatExpr lvl e) of { (fse, floats_e, e') ->
+ case (floatNonRecRhs lvl a) of { (fsa, floats_a, a') ->
(fse `add_stats` fsa, floats_e ++ floats_a, App e' a') }}
floatExpr lvl lam@(Lam _ _)
= let
(bndrs_w_lvls, body) = collectBinders lam
- (bndrs, lvls) = unzip bndrs_w_lvls
+ bndrs = [b | TB b _ <- bndrs_w_lvls]
+ lvls = [l | TB b l <- bndrs_w_lvls]
-- For the all-tyvar case we are prepared to pull
-- the lets out, to implement the float-out-of-big-lambda
floatExpr lvl (Note InlineMe expr) -- Other than SCCs
= case floatExpr InlineCtxt expr of { (fs, floating_defns, expr') ->
- WARN( not (null floating_defns),
- ppr expr $$ ppr floating_defns ) -- We do no floating out of Inlines
- (fs, [], Note InlineMe expr') } -- See notes in SetLevels
+ -- There can be some floating_defns, arising from
+ -- ordinary lets that were there all the time. It seems
+ -- more efficient to test once here than to avoid putting
+ -- them into floating_defns (which would mean testing for
+ -- inlineCtxt at every let)
+ (fs, [], Note InlineMe (install floating_defns expr')) } -- See notes in SetLevels
floatExpr lvl (Note note expr) -- Other than SCCs
= case (floatExpr lvl expr) of { (fs, floating_defns, expr') ->
(fs, floating_defns, Note note expr') }
+floatExpr lvl (Let (NonRec (TB bndr bndr_lvl) rhs) body)
+ | isUnLiftedType (idType bndr) -- Treat unlifted lets just like a case
+ = case floatExpr lvl rhs of { (fs, rhs_floats, rhs') ->
+ case floatRhs bndr_lvl body of { (fs, body_floats, body') ->
+ (fs, rhs_floats ++ body_floats, Let (NonRec bndr rhs') body') }}
+
floatExpr lvl (Let bind body)
= case (floatBind bind) of { (fsb, rhs_floats, bind') ->
case (floatExpr lvl body) of { (fse, body_floats, body') ->
--- if isInlineCtxt lvl then -- No floating inside an InlineMe
--- ASSERT( null rhs_floats && null body_floats )
--- (add_stats fsb fse, [], Let bind' body')
--- else
- (add_stats fsb fse,
- rhs_floats ++ [(bind_lvl, bind')] ++ body_floats,
- body')
- }}
+ (add_stats fsb fse,
+ rhs_floats ++ [(bind_lvl, bind')] ++ body_floats,
+ body') }}
where
bind_lvl = getBindLevel bind
-floatExpr lvl (Case scrut (case_bndr, case_lvl) alts)
+floatExpr lvl (Case scrut (TB case_bndr case_lvl) alts)
= case floatExpr lvl scrut of { (fse, fde, scrut') ->
case floatList float_alt alts of { (fsa, fda, alts') ->
(add_stats fse fsa, fda ++ fde, Case scrut' case_bndr alts')
-- don't gratuitiously float bindings out of the RHSs
float_alt (con, bs, rhs)
= case (floatRhs case_lvl rhs) of { (fs, rhs_floats, rhs') ->
- (fs, rhs_floats, (con, map fst bs, rhs')) }
+ (fs, rhs_floats, (con, [b | TB b _ <- bs], rhs')) }
floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
%************************************************************************
\begin{code}
-getBindLevel (NonRec (_, lvl) _) = lvl
-getBindLevel (Rec (((_,lvl), _) : _)) = lvl
+getBindLevel (NonRec (TB _ lvl) _) = lvl
+getBindLevel (Rec (((TB _ lvl), _) : _)) = lvl
\end{code}
\begin{code}