Merge remote branch 'origin/master'
[ghc-hetmet.git] / compiler / simplCore / FloatOut.lhs
index fba88e7..e5db7d9 100644 (file)
@@ -24,9 +24,11 @@ import UniqSupply       ( UniqSupply )
 import Bag
 import Util
 import Maybes
-import UniqFM
 import Outputable
 import FastString
+import qualified Data.IntMap as M
+
+#include "HsVersions.h"
 \end{code}
 
        -----------------
@@ -223,6 +225,7 @@ floatRhs lvl arg    -- Used for nested non-rec rhss, and fn args
 -----------------
 floatExpr _ (Var v)   = (zeroStats, emptyFloats, Var v)
 floatExpr _ (Type ty) = (zeroStats, emptyFloats, Type ty)
+floatExpr _ (Coercion co) = (zeroStats, emptyFloats, Coercion co)
 floatExpr _ (Lit lit) = (zeroStats, emptyFloats, Lit lit)
          
 floatExpr lvl (App e a)
@@ -230,27 +233,20 @@ floatExpr lvl (App e a)
     case (floatRhs lvl a)      of { (fsa, floats_a, a') ->
     (fse `add_stats` fsa, floats_e `plusFloats` floats_a, App e' a') }}
 
-floatExpr _ lam@(Lam _ _)
-  = let
-       (bndrs_w_lvls, body) = collectBinders lam
+floatExpr _ lam@(Lam (TB _ lam_lvl) _)
+  = let (bndrs_w_lvls, body) = collectBinders lam
        bndrs                = [b | TB b _ <- bndrs_w_lvls]
-       lvls                 = [l | TB _ l <- bndrs_w_lvls]
-
-       -- For the all-tyvar case we are prepared to pull 
-       -- the lets out, to implement the float-out-of-big-lambda
-       -- transform; but otherwise we only float bindings that are
-       -- going to escape a value lambda.
-       -- In particular, for one-shot lambdas we don't float things
-       -- out; we get no saving by so doing.
-       partition_fn | all isTyCoVar bndrs = partitionByLevel
-                    | otherwise         = partitionByMajorLevel
+       -- All the binders have the same level
+       -- See SetLevels.lvlLamBndrs
     in
-    case (floatExpr (last lvls) body) of { (fs, floats, body') ->
-
-       -- Dump any bindings which absolutely cannot go any further
-    case (partition_fn (head lvls) floats)     of { (floats', heres) ->
-
-    (add_to_stats fs floats', floats', mkLams bndrs (install heres body'))
+    case (floatExpr lam_lvl body) of { (fs, floats, body1) ->
+
+        -- Dump anything that is captured by this lambda
+       -- Eg  \x -> ...(\y -> let v = <blah> in ...)...
+       -- We'll have the binding (v = <blah>) in the floats,
+       -- but must dump it at the lambda-x
+    case (partitionByLevel lam_lvl floats)     of { (floats1, heres) ->
+    (add_to_stats fs floats1, floats1, mkLams bndrs (install heres body1))
     }}
 
 floatExpr lvl (Note note@(SCC cc) expr)
@@ -401,34 +397,39 @@ data FloatBinds  = FB !(Bag FloatBind)            -- Destined for top level
                      !MajorEnv                 -- Levels other than top
      -- See Note [Representation of FloatBinds]
 
-type MajorEnv = UniqFM MinorEnv                        -- Keyed by major level
-type MinorEnv = UniqFM (Bag FloatBind)         -- Keyed by minor level
+instance Outputable FloatBinds where
+  ppr (FB fbs env) = ptext (sLit "FB") <+> (braces $ vcat
+                       [ ptext (sLit "binds =") <+> ppr fbs
+                       , ptext (sLit "env =") <+> ppr env ])
+
+type MajorEnv = M.IntMap MinorEnv                      -- Keyed by major level
+type MinorEnv = M.IntMap (Bag FloatBind)               -- Keyed by minor level
 
 flattenFloats :: FloatBinds -> Bag FloatBind
 flattenFloats (FB tops others) = tops `unionBags` flattenMajor others
 
 flattenMajor :: MajorEnv -> Bag FloatBind
-flattenMajor = foldUFM (unionBags . flattenMinor) emptyBag
+flattenMajor = M.fold (unionBags . flattenMinor) emptyBag
 
 flattenMinor :: MinorEnv -> Bag FloatBind
-flattenMinor = foldUFM unionBags emptyBag
+flattenMinor = M.fold unionBags emptyBag
 
 emptyFloats :: FloatBinds
-emptyFloats = FB emptyBag emptyUFM
+emptyFloats = FB emptyBag M.empty
 
 unitFloat :: Level -> FloatBind -> FloatBinds
 unitFloat lvl@(Level major minor) b 
-  | isTopLvl lvl = FB (unitBag b) emptyUFM
-  | otherwise    = FB emptyBag (unitUFM major (unitUFM minor (unitBag b)))
+  | isTopLvl lvl = FB (unitBag b) M.empty
+  | otherwise    = FB emptyBag (M.singleton major (M.singleton minor (unitBag b)))
 
 plusFloats :: FloatBinds -> FloatBinds -> FloatBinds
 plusFloats (FB t1 b1) (FB t2 b2) = FB (t1 `unionBags` t2) (b1 `plusMajor` b2)
 
 plusMajor :: MajorEnv -> MajorEnv -> MajorEnv
-plusMajor = plusUFM_C plusMinor
+plusMajor = M.unionWith plusMinor
 
 plusMinor :: MinorEnv -> MinorEnv -> MinorEnv
-plusMinor = plusUFM_C unionBags
+plusMinor = M.unionWith unionBags
 
 floatsToBindPairs :: Bag FloatBind -> [(Id,CoreExpr)] -> [(Id,CoreExpr)]
 floatsToBindPairs floats binds = foldrBag add binds floats
@@ -460,26 +461,26 @@ partitionByMajorLevel, partitionByLevel
 partitionByMajorLevel (Level major _) (FB tops defns)
   = (FB tops outer, heres `unionBags` flattenMajor inner)
   where
-    (outer, mb_heres, inner) = splitUFM defns major
+    (outer, mb_heres, inner) = M.splitLookup major defns
     heres = case mb_heres of 
                Nothing -> emptyBag
                Just h  -> flattenMinor h
 
 partitionByLevel (Level major minor) (FB tops defns)
-  = (FB tops (outer_maj `plusMajor` unitUFM major outer_min),
+  = (FB tops (outer_maj `plusMajor` M.singleton major outer_min),
      here_min `unionBags` flattenMinor inner_min 
               `unionBags` flattenMajor inner_maj)
 
   where
-    (outer_maj, mb_here_maj, inner_maj) = splitUFM defns major
+    (outer_maj, mb_here_maj, inner_maj) = M.splitLookup major defns
     (outer_min, mb_here_min, inner_min) = case mb_here_maj of
-                                            Nothing -> (emptyUFM, Nothing, emptyUFM)
-                                            Just min_defns -> splitUFM min_defns minor
+                                            Nothing -> (M.empty, Nothing, M.empty)
+                                            Just min_defns -> M.splitLookup minor min_defns
     here_min = mb_here_min `orElse` emptyBag
 
 wrapCostCentre :: CostCentre -> FloatBinds -> FloatBinds
 wrapCostCentre cc (FB tops defns)
-  = FB (wrap_defns tops) (mapUFM (mapUFM wrap_defns) defns)
+  = FB (wrap_defns tops) (M.map (M.map wrap_defns) defns)
   where
     wrap_defns = mapBag wrap_one 
     wrap_one (NonRec binder rhs) = NonRec binder (mkSCC cc rhs)