Bottom extraction: float out bottoming expressions to top level
[ghc-hetmet.git] / compiler / simplCore / FloatOut.lhs
index d65f7bd..f5f8946 100644 (file)
@@ -10,11 +10,12 @@ module FloatOut ( floatOutwards ) where
 
 import CoreSyn
 import CoreUtils
+import CoreArity       ( etaExpand )
 
 import DynFlags        ( DynFlags, DynFlag(..), FloatOutSwitches(..) )
 import ErrUtils                ( dumpIfSet_dyn )
 import CostCentre      ( dupifyCC, CostCentre )
-import Id              ( Id, idType )
+import Id              ( Id, idType, idArity, isBottomingId )
 import Type            ( isUnLiftedType )
 import SetLevels       ( Level(..), LevelledExpr, LevelledBind,
                          setLevels, isTopLvl, tOP_LEVEL )
@@ -144,13 +145,18 @@ floatTopBind bind
 %*                                                                     *
 %************************************************************************
 
-
 \begin{code}
 floatBind :: LevelledBind -> (FloatStats, FloatBinds)
 
-floatBind (NonRec (TB name level) rhs)
+floatBind (NonRec (TB var level) rhs)
   = case (floatRhs level rhs) of { (fs, rhs_floats, rhs') ->
-    (fs, rhs_floats `plusFloats` unitFloat level (NonRec name rhs')) }
+
+       -- A tiresome hack: 
+       -- see Note [Bottoming floats: eta expansion] in SetLevels
+    let rhs'' | isBottomingId var = etaExpand (idArity var) rhs'
+             | otherwise         = rhs'
+
+    in (fs, rhs_floats `plusFloats` unitFloat level (NonRec var rhs'')) }
 
 floatBind bind@(Rec pairs)
   = case (unzip3 (map do_pair pairs)) of { (fss, rhss_floats, new_pairs) ->
@@ -297,8 +303,8 @@ floatExpr lvl (Cast expr co)
     (fs, floating_defns, Cast expr' co) }
 
 floatExpr lvl (Let (NonRec (TB bndr bndr_lvl) rhs) body)
-  | isUnLiftedType (idType bndr)       -- Treat unlifted lets just like a case
-                               -- I.e. floatExpr for rhs, floatCaseAlt for body
+  | isUnLiftedType (idType bndr)  -- Treat unlifted lets just like a case
+                                 -- I.e. floatExpr for rhs, floatCaseAlt for body
   = case floatExpr lvl rhs         of { (_, rhs_floats, rhs') ->
     case floatCaseAlt bndr_lvl body of { (fs, body_floats, body') ->
     (fs, rhs_floats `plusFloats` body_floats, Let (NonRec bndr rhs') body') }}