[project @ 2001-02-20 13:15:11 by simonpj]
authorsimonpj <unknown>
Tue, 20 Feb 2001 13:15:11 +0000 (13:15 +0000)
committersimonpj <unknown>
Tue, 20 Feb 2001 13:15:11 +0000 (13:15 +0000)
Fix top level float

ghc/compiler/coreSyn/CoreSat.lhs

index b26f3a8..62eda2e 100644 (file)
@@ -10,7 +10,7 @@ module CoreSat (
 
 #include "HsVersions.h"
 
-import CoreUtils( exprIsTrivial, exprIsAtom, exprType, exprIsValue, etaExpand )
+import CoreUtils( exprIsTrivial, exprIsAtom, exprType, exprIsValue, etaExpand, exprArity )
 import CoreFVs ( exprFreeVars )
 import CoreLint        ( endPass )
 import CoreSyn
@@ -100,20 +100,64 @@ coreSatExpr dflags expr
 -- Dealing with bindings
 -- ---------------------------------------------------------------------------
 
-data FloatingBind = FloatBind CoreBind
+data FloatingBind = FloatLet CoreBind
                  | FloatCase Id CoreExpr
 
+allLazy :: OrdList FloatingBind -> Bool
+allLazy floats = foldOL check True floats
+              where
+                check (FloatLet _)    y = y
+                check (FloatCase _ _) y = False
+
 coreSatTopBinds :: [CoreBind] -> UniqSM [CoreBind]
 -- Very careful to preserve the arity of top-level functions
-coreSatTopBinds bs
-  = mapUs do_bind bs
+coreSatTopBinds [] = returnUs []
+
+coreSatTopBinds (NonRec b r : binds)
+  = coreSatTopRhs b r          `thenUs` \ (floats, r') ->
+    coreSatTopBinds binds      `thenUs` \ binds' ->
+    returnUs (floats ++ NonRec b r' : binds')
+
+coreSatTopBinds (Rec prs : binds)
+  = mapAndUnzipUs do_pair prs  `thenUs` \ (floats_s, prs') ->
+    coreSatTopBinds binds      `thenUs` \ binds' ->
+    returnUs (Rec (flattenBinds (concat floats_s) ++ prs') : binds')
   where
-    do_bind (NonRec b r) = coreSatAnExpr r     `thenUs` \ r' ->
-                          returnUs (NonRec b r')
-    do_bind (Rec prs)   = mapUs do_pair prs    `thenUs` \ prs' ->
-                          returnUs (Rec prs')
-    do_pair (b,r)       = coreSatAnExpr r      `thenUs` \ r' ->
-                          returnUs (b, r')
+    do_pair (b,r) = coreSatTopRhs b r  `thenUs` \ (floats, r') ->
+                   returnUs (floats, (b, r'))
+
+coreSatTopRhs :: Id -> CoreExpr -> UniqSM ([CoreBind], CoreExpr)
+-- The trick here is that if we see
+--     x = $wC p $wJust q
+-- we want to transform to
+--     sat = \a -> $wJust a
+--     x = $wC p sat q
+-- and NOT to
+--     x = let sat = \a -> $wJust a in $wC p sat q
+--
+-- The latter is bad because the thing was a value before, but
+-- is a thunk now, and that's wrong because now x may need to
+-- be in other bindings' SRTs.
+-- This has to be right for recursive as well as non-recursive bindings
+--
+-- Notice that it's right to give sat vanilla IdInfo; in particular NoCafRefs
+--
+-- You might worry that arity might increase, thus
+--     x = $wC a  ==>  x = \ b c -> $wC a b c
+-- but the simpifier does eta expansion vigorously, so I don't think this 
+-- can occur.  If it did, it would be a problem, because x's arity changes,
+-- so we have an ASSERT to check.  (I use WARN so we can see the output.)
+
+coreSatTopRhs b rhs
+  = coreSatExprFloat rhs       `thenUs` \ (floats, rhs1) ->
+    if exprIsValue rhs then
+       ASSERT( allLazy floats )
+        WARN( idArity b /= exprArity rhs1, ptext SLIT("Disaster!") <+> ppr b )
+       returnUs ([bind | FloatLet bind <- fromOL floats], rhs1)
+    else
+       mkBinds floats rhs1     `thenUs` \ rhs2 ->
+        WARN( idArity b /= exprArity rhs2, ptext SLIT("Disaster!") <+> ppr b )
+       returnUs ([], rhs2)
 
 
 coreSatBind :: CoreBind -> UniqSM (OrdList FloatingBind)
@@ -127,13 +171,15 @@ coreSatBind :: CoreBind -> UniqSM (OrdList FloatingBind)
 
 coreSatBind (NonRec binder rhs)
   = coreSatExprFloat rhs       `thenUs` \ (floats, new_rhs) ->
-    mkNonRec binder new_rhs (bdrDem binder) floats
+    mkNonRec binder (bdrDem binder) floats new_rhs
        -- NB: if there are any lambdas at the top of the RHS,
        -- the floats will be empty, so the arity won't be affected
 
 coreSatBind (Rec pairs)
+       -- Don't bother to try to float bindings out of RHSs
+       -- (compare mkNonRec, which does try)
   = mapUs do_rhs pairs                         `thenUs` \ new_pairs ->
-    returnUs (unitOL (FloatBind (Rec new_pairs)))
+    returnUs (unitOL (FloatLet (Rec new_pairs)))
   where
     do_rhs (bndr,rhs) =        coreSatAnExpr rhs       `thenUs` \ new_rhs' ->
                        returnUs (bndr,new_rhs')
@@ -150,7 +196,7 @@ coreSatArg arg dem
     if needs_binding arg'
        then returnUs (floats, arg')
        else newVar (exprType arg')     `thenUs` \ v ->
-            mkNonRec v arg' dem floats `thenUs` \ floats' -> 
+            mkNonRec v dem floats arg' `thenUs` \ floats' -> 
             returnUs (floats', Var v)
 
 needs_binding | opt_KeepStgTypes = exprIsAtom
@@ -287,7 +333,7 @@ coreSatExprFloat expr@(App _ _)
     collect_args fun depth
        = coreSatExprFloat fun                  `thenUs` \ (fun_floats, fun) ->
          newVar ty                             `thenUs` \ fn_id ->
-          mkNonRec fn_id fun onceDem fun_floats        `thenUs` \ floats ->
+          mkNonRec fn_id onceDem fun_floats fun        `thenUs` \ floats ->
          returnUs (Var fn_id, (Var fn_id, depth), ty, floats, [])
         where
          ty = exprType fun
@@ -334,18 +380,27 @@ maybeSaturate fn expr n_args ty
 -- Precipitating the floating bindings
 -- ---------------------------------------------------------------------------
 
--- mkNonrec is used for local bindings only, not top level
-mkNonRec bndr rhs dem floats
-  |  isUnLiftedType bndr_rep_ty
-  || isStrictDem dem && not (exprIsValue rhs)
+-- mkNonRec is used for local bindings only, not top level
+mkNonRec :: Id  -> RhsDemand                   -- Lhs: id with demand
+        -> OrdList FloatingBind -> CoreExpr    -- Rhs: let binds in body
+        -> UniqSM (OrdList FloatingBind)
+mkNonRec bndr dem floats rhs
+  | exprIsValue rhs            -- Notably constructor applications
+  = ASSERT( allLazy floats )   -- The only floats we can get out of a value are eta expansions 
+                               -- e.g.  C $wJust ==> let s = \x -> $wJust x in C s
+                               -- Here we want to float the s binding.
+    returnUs (floats `snocOL` FloatLet (NonRec bndr rhs))
+    
+  |  isUnLiftedType bndr_rep_ty        || isStrictDem dem 
   = ASSERT( not (isUnboxedTupleType bndr_rep_ty) )
     returnUs (floats `snocOL` FloatCase bndr rhs)
-  where
-    bndr_rep_ty = repType (idType bndr)
 
-mkNonRec bndr rhs dem floats
+  | otherwise
   = mkBinds floats rhs `thenUs` \ rhs' ->
-    returnUs (unitOL (FloatBind (NonRec bndr rhs')))
+    returnUs (unitOL (FloatLet (NonRec bndr rhs')))
+
+  where
+    bndr_rep_ty  = repType (idType bndr)
 
 mkBinds :: OrdList FloatingBind -> CoreExpr -> UniqSM CoreExpr
 mkBinds binds body 
@@ -354,7 +409,7 @@ mkBinds binds body
                    returnUs (foldOL mk_bind body' binds)
   where
     mk_bind (FloatCase bndr rhs) body = Case rhs bndr [(DEFAULT, [], body)]
-    mk_bind (FloatBind bind)     body = Let bind body
+    mk_bind (FloatLet bind)      body = Let bind body
 
 -- ---------------------------------------------------------------------------
 -- Eliminate Lam as a non-rhs (STG doesn't have such a thing)