[project @ 2003-08-15 11:31:02 by simonmar]
[ghc-hetmet.git] / ghc / compiler / coreSyn / CorePrep.lhs
index d2515c9..18444b6 100644 (file)
@@ -109,7 +109,7 @@ corePrepPgm dflags binds types
            binds_out = initUs_ us (
                          corePrepTopBinds binds        `thenUs` \ floats1 ->
                          corePrepTopBinds implicit_binds       `thenUs` \ floats2 ->
-                         returnUs (deFloatTop (floats1 `appOL` floats2))
+                         returnUs (deFloatTop (floats1 `appendFloats` floats2))
                        )
            
         endPass dflags "CorePrep" Opt_D_dump_prep binds_out
@@ -179,45 +179,81 @@ data FloatingBind = FloatLet CoreBind
                  | FloatCase Id CoreExpr Bool
                        -- The bool indicates "ok-for-speculation"
 
+data Floats = Floats OkToSpec (OrdList FloatingBind)
+
+-- Can we float these binds out of the rhs of a let?  We cache this decision
+-- to avoid having to recompute it in a non-linear way when there are
+-- deeply nested lets.
+data OkToSpec
+   = NotOkToSpec       -- definitely not
+   | OkToSpec          -- yes
+   | IfUnboxedOk       -- only if floating an unboxed binding is ok
+
+emptyFloats :: Floats
+emptyFloats = Floats OkToSpec nilOL
+
+addFloat :: Floats -> FloatingBind -> Floats
+addFloat (Floats ok_to_spec floats) new_float
+  = Floats (combine ok_to_spec (check new_float)) (floats `snocOL` new_float)
+  where
+    check (FloatLet _)               = OkToSpec
+    check (FloatCase _ _ ok_for_spec) 
+       | ok_for_spec  =  IfUnboxedOk
+       | otherwise    =  NotOkToSpec
+       -- The ok-for-speculation flag says that it's safe to
+       -- float this Case out of a let, and thereby do it more eagerly
+       -- We need the top-level flag because it's never ok to float
+       -- an unboxed binding to the top level
+
+unitFloat :: FloatingBind -> Floats
+unitFloat = addFloat emptyFloats
+
+appendFloats :: Floats -> Floats -> Floats
+appendFloats (Floats spec1 floats1) (Floats spec2 floats2)
+  = Floats (combine spec1 spec2) (floats1 `appOL` floats2)
+
+concatFloats :: [Floats] -> Floats
+concatFloats = foldr appendFloats emptyFloats
+
+combine NotOkToSpec _ = NotOkToSpec
+combine _ NotOkToSpec = NotOkToSpec
+combine IfUnboxedOk _ = IfUnboxedOk
+combine _ IfUnboxedOk = IfUnboxedOk
+combine _ _           = OkToSpec
+    
 instance Outputable FloatingBind where
   ppr (FloatLet bind)        = text "FloatLet" <+> ppr bind
   ppr (FloatCase b rhs spec) = text "FloatCase" <+> ppr b <+> ppr spec <+> equals <+> ppr rhs
 
 type CloneEnv = IdEnv Id       -- Clone local Ids
 
-deFloatTop :: OrdList FloatingBind -> [CoreBind]
+deFloatTop :: Floats -> [CoreBind]
 -- For top level only; we don't expect any FloatCases
-deFloatTop floats
+deFloatTop (Floats _ floats)
   = foldrOL get [] floats
   where
     get (FloatLet b) bs = b:bs
     get b           bs = pprPanic "corePrepPgm" (ppr b)
 
-allLazy :: TopLevelFlag -> RecFlag -> OrdList FloatingBind -> Bool
-allLazy top_lvl is_rec floats 
-  = foldrOL check True floats
-  where
-    unboxed_ok = isNotTopLevel top_lvl && isNonRec is_rec
-
-    check (FloatLet _)               y = y
-    check (FloatCase _ _ ok_for_spec) y = unboxed_ok && ok_for_spec && y
-       -- The ok-for-speculation flag says that it's safe to
-       -- float this Case out of a let, and thereby do it more eagerly
-       -- We need the top-level flag because it's never ok to float
-       -- an unboxed binding to the top level
+allLazy :: TopLevelFlag -> RecFlag -> Floats -> Bool
+allLazy top_lvl is_rec (Floats ok_to_spec _)
+  = case ok_to_spec of
+       OkToSpec -> True
+       NotOkToSpec -> False
+       IfUnboxedOk -> isNotTopLevel top_lvl && isNonRec is_rec
 
 -- ---------------------------------------------------------------------------
 --                     Bindings
 -- ---------------------------------------------------------------------------
 
-corePrepTopBinds :: [CoreBind] -> UniqSM (OrdList FloatingBind)
+corePrepTopBinds :: [CoreBind] -> UniqSM Floats
 corePrepTopBinds binds 
   = go emptyVarEnv binds
   where
-    go env []            = returnUs nilOL
+    go env []            = returnUs emptyFloats
     go env (bind : binds) = corePrepTopBind env bind   `thenUs` \ (env', bind') ->
                            go env' binds               `thenUs` \ binds' ->
-                           returnUs (bind' `appOL` binds')
+                           returnUs (bind' `appendFloats` binds')
 
 -- NB: we do need to float out of top-level bindings
 -- Consider    x = length [True,False]
@@ -247,16 +283,16 @@ corePrepTopBinds binds
 -- it looks difficult.
 
 --------------------------------
-corePrepTopBind :: CloneEnv -> CoreBind -> UniqSM (CloneEnv, OrdList FloatingBind)
+corePrepTopBind :: CloneEnv -> CoreBind -> UniqSM (CloneEnv, Floats)
 corePrepTopBind env (NonRec bndr rhs) 
   = cloneBndr env bndr                                 `thenUs` \ (env', bndr') ->
     corePrepRhs TopLevel NonRecursive env (bndr, rhs)  `thenUs` \ (floats, rhs') -> 
-    returnUs (env', floats `snocOL` FloatLet (NonRec bndr' rhs'))
+    returnUs (env', addFloat floats (FloatLet (NonRec bndr' rhs')))
 
 corePrepTopBind env (Rec pairs) = corePrepRecPairs TopLevel env pairs
 
 --------------------------------
-corePrepBind ::  CloneEnv -> CoreBind -> UniqSM (CloneEnv, OrdList FloatingBind)
+corePrepBind ::  CloneEnv -> CoreBind -> UniqSM (CloneEnv, Floats)
        -- This one is used for *local* bindings
 corePrepBind env (NonRec bndr rhs)
   = etaExpandRhs bndr rhs                              `thenUs` \ rhs1 ->
@@ -270,16 +306,16 @@ corePrepBind env (Rec pairs) = corePrepRecPairs NotTopLevel env pairs
 --------------------------------
 corePrepRecPairs :: TopLevelFlag -> CloneEnv
                 -> [(Id,CoreExpr)]     -- Recursive bindings
-                -> UniqSM (CloneEnv, OrdList FloatingBind)
+                -> UniqSM (CloneEnv, Floats)
 -- Used for all recursive bindings, top level and otherwise
 corePrepRecPairs lvl env pairs
   = cloneBndrs env (map fst pairs)                             `thenUs` \ (env', bndrs') ->
     mapAndUnzipUs (corePrepRhs lvl Recursive env') pairs       `thenUs` \ (floats_s, rhss') ->
-    returnUs (env', unitOL (FloatLet (Rec (flatten (concatOL floats_s) bndrs' rhss'))))
+    returnUs (env', unitFloat (FloatLet (Rec (flatten (concatFloats floats_s) bndrs' rhss'))))
   where
        -- Flatten all the floats, and the currrent
        -- group into a single giant Rec
-    flatten floats bndrs rhss = foldrOL get (bndrs `zip` rhss) floats
+    flatten (Floats _ floats) bndrs rhss = foldrOL get (bndrs `zip` rhss) floats
 
     get (FloatLet (NonRec b r)) prs2 = (b,r) : prs2
     get (FloatLet (Rec prs1))   prs2 = prs1 ++ prs2
@@ -287,7 +323,7 @@ corePrepRecPairs lvl env pairs
 --------------------------------
 corePrepRhs :: TopLevelFlag -> RecFlag
            -> CloneEnv -> (Id, CoreExpr)
-           -> UniqSM (OrdList FloatingBind, CoreExpr)
+           -> UniqSM (Floats, CoreExpr)
 -- Used for top-level bindings, and local recursive bindings
 corePrepRhs top_lvl is_rec env (bndr, rhs)
   = etaExpandRhs bndr rhs      `thenUs` \ rhs' ->
@@ -301,7 +337,7 @@ corePrepRhs top_lvl is_rec env (bndr, rhs)
 
 -- This is where we arrange that a non-trivial argument is let-bound
 corePrepArg :: CloneEnv -> CoreArg -> RhsDemand
-          -> UniqSM (OrdList FloatingBind, CoreArg)
+          -> UniqSM (Floats, CoreArg)
 corePrepArg env arg dem
   = corePrepExprFloat env arg          `thenUs` \ (floats, arg') ->
     if exprIsTrivial arg'
@@ -330,7 +366,7 @@ corePrepAnExpr env expr
     mkBinds floats expr
 
 
-corePrepExprFloat :: CloneEnv -> CoreExpr -> UniqSM (OrdList FloatingBind, CoreExpr)
+corePrepExprFloat :: CloneEnv -> CoreExpr -> UniqSM (Floats, CoreExpr)
 -- If
 --     e  ===>  (bs, e')
 -- then        
@@ -343,18 +379,18 @@ corePrepExprFloat env (Var v)
   = fiddleCCall v                              `thenUs` \ v1 ->
     let v2 = lookupVarEnv env v1 `orElse` v1 in
     maybeSaturate v2 (Var v2) 0 (idType v2)    `thenUs` \ app ->
-    returnUs (nilOL, app)
+    returnUs (emptyFloats, app)
 
 corePrepExprFloat env expr@(Type _)
-  = returnUs (nilOL, expr)
+  = returnUs (emptyFloats, expr)
 
 corePrepExprFloat env expr@(Lit lit)
-  = returnUs (nilOL, expr)
+  = returnUs (emptyFloats, expr)
 
 corePrepExprFloat env (Let bind body)
   = corePrepBind env bind              `thenUs` \ (env', new_binds) ->
     corePrepExprFloat env' body                `thenUs` \ (floats, new_body) ->
-    returnUs (new_binds `appOL` floats, new_body)
+    returnUs (new_binds `appendFloats` floats, new_body)
 
 corePrepExprFloat env (Note n@(SCC _) expr)
   = corePrepAnExpr env expr            `thenUs` \ expr1 ->
@@ -368,7 +404,7 @@ corePrepExprFloat env (Note other_note expr)
 corePrepExprFloat env expr@(Lam _ _)
   = cloneBndrs env bndrs               `thenUs` \ (env', bndrs') ->
     corePrepAnExpr env' body           `thenUs` \ body' ->
-    returnUs (nilOL, mkLams bndrs' body')
+    returnUs (emptyFloats, mkLams bndrs' body')
   where
     (bndrs,body) = collectBinders expr
 
@@ -377,7 +413,7 @@ corePrepExprFloat env (Case scrut bndr alts)
     deLamFloat scrut1                  `thenUs` \ (floats2, scrut2) ->
     cloneBndr env bndr                 `thenUs` \ (env', bndr') ->
     mapUs (sat_alt env') alts          `thenUs` \ alts' ->
-    returnUs (floats1 `appOL` floats2 , Case scrut2 bndr' alts')
+    returnUs (floats1 `appendFloats` floats2 , Case scrut2 bndr' alts')
   where
     sat_alt env (con, bs, rhs)
          = cloneBndrs env bs           `thenUs` \ (env', bs') ->
@@ -411,7 +447,7 @@ corePrepExprFloat env expr@(App _ _)
                   (CoreExpr,Int),        -- the head of the application,
                                          -- and no. of args it was applied to
                   Type,                  -- type of the whole expr
-                  OrdList FloatingBind,  -- any floats we pulled out
+                  Floats,                -- any floats we pulled out
                   [Demand])              -- remaining argument demands
 
     collect_args (App fun arg@(Type arg_ty)) depth
@@ -428,12 +464,12 @@ corePrepExprFloat env expr@(App _ _)
                                  splitFunTy_maybe fun_ty
          in
          corePrepArg env arg (mkDemTy ss1 arg_ty)      `thenUs` \ (fs, arg') ->
-         returnUs (App fun' arg', hd, res_ty, fs `appOL` floats, ss_rest)
+         returnUs (App fun' arg', hd, res_ty, fs `appendFloats` floats, ss_rest)
 
     collect_args (Var v) depth
        = fiddleCCall v `thenUs` \ v1 ->
          let v2 = lookupVarEnv env v1 `orElse` v1 in
-         returnUs (Var v2, (Var v2, depth), idType v2, nilOL, stricts)
+         returnUs (Var v2, (Var v2, depth), idType v2, emptyFloats, stricts)
        where
          stricts = case idNewStrictness v of
                        StrictSig (DmdType _ demands _)
@@ -495,9 +531,9 @@ maybeSaturate fn expr n_args ty
 
 floatRhs :: TopLevelFlag -> RecFlag
         -> Id
-        -> (OrdList FloatingBind, CoreExpr)    -- Rhs: let binds in body
-        -> UniqSM (OrdList FloatingBind,       -- Floats out of this bind
-                   CoreExpr)                   -- Final Rhs
+        -> (Floats, CoreExpr)  -- Rhs: let binds in body
+        -> UniqSM (Floats,     -- Floats out of this bind
+                   CoreExpr)   -- Final Rhs
 
 floatRhs top_lvl is_rec bndr (floats, rhs)
   | isTopLevel top_lvl || exprIsValue rhs,     -- Float to expose value or 
@@ -513,12 +549,12 @@ floatRhs top_lvl is_rec bndr (floats, rhs)
   | otherwise
        -- Don't float; the RHS isn't a value
   = mkBinds floats rhs         `thenUs` \ rhs' ->
-    returnUs (nilOL, rhs')
+    returnUs (emptyFloats, rhs')
 
 -- mkLocalNonRec is used only for *nested*, *non-recursive* bindings
-mkLocalNonRec :: Id  -> RhsDemand                      -- Lhs: id with demand
-             -> OrdList FloatingBind -> CoreExpr       -- Rhs: let binds in body
-             -> UniqSM (OrdList FloatingBind)
+mkLocalNonRec :: Id  -> RhsDemand      -- Lhs: id with demand
+             -> Floats -> CoreExpr     -- Rhs: let binds in body
+             -> UniqSM Floats
 
 mkLocalNonRec bndr dem floats rhs
   | isUnLiftedType (idType bndr)
@@ -527,7 +563,7 @@ mkLocalNonRec bndr dem floats rhs
     let
        float = FloatCase bndr rhs (exprOkForSpeculation rhs)
     in
-    returnUs (floats `snocOL` float)
+    returnUs (addFloat floats float)
 
   | isStrict dem 
        -- It's a strict let so we definitely float all the bindings
@@ -537,18 +573,18 @@ mkLocalNonRec bndr dem floats rhs
        float | exprIsValue rhs = FloatLet (NonRec bndr rhs)
              | otherwise       = FloatCase bndr rhs (exprOkForSpeculation rhs)
     in
-    returnUs (floats `snocOL` float)
+    returnUs (addFloat floats float)
 
   | otherwise
   = floatRhs NotTopLevel NonRecursive bndr (floats, rhs)       `thenUs` \ (floats', rhs') ->
-    returnUs (floats' `snocOL` FloatLet (NonRec bndr rhs'))
+    returnUs (addFloat floats' (FloatLet (NonRec bndr rhs')))
 
   where
     bndr_ty     = idType bndr
 
 
-mkBinds :: OrdList FloatingBind -> CoreExpr -> UniqSM CoreExpr
-mkBinds binds body 
+mkBinds :: Floats -> CoreExpr -> UniqSM CoreExpr
+mkBinds (Floats _ binds) body 
   | isNilOL binds = returnUs body
   | otherwise    = deLam body          `thenUs` \ body' ->
                    returnUs (foldrOL mk_bind body' binds)
@@ -606,7 +642,7 @@ deLam expr =
   mkBinds floats expr
 
 
-deLamFloat :: CoreExpr -> UniqSM (OrdList FloatingBind, CoreExpr)
+deLamFloat :: CoreExpr -> UniqSM (Floats, CoreExpr)
 -- Remove top level lambdas by let-bindinig
 
 deLamFloat (Note n expr)
@@ -616,12 +652,12 @@ deLamFloat (Note n expr)
     returnUs (floats, Note n expr')
 
 deLamFloat expr 
-  | null bndrs = returnUs (nilOL, expr)
+  | null bndrs = returnUs (emptyFloats, expr)
   | otherwise 
   = case tryEta bndrs body of
-      Just no_lam_result -> returnUs (nilOL, no_lam_result)
+      Just no_lam_result -> returnUs (emptyFloats, no_lam_result)
       Nothing           -> newVar (exprType expr)      `thenUs` \ fn ->
-                           returnUs (unitOL (FloatLet (NonRec fn expr)), 
+                           returnUs (unitFloat (FloatLet (NonRec fn expr)), 
                                      Var fn)
   where
     (bndrs,body) = collectBinders expr