Improve dead block calculation, per Simon Marlow's advice.
authorEdward Z. Yang <ezyang@mit.edu>
Thu, 5 May 2011 18:37:23 +0000 (19:37 +0100)
committerEdward Z. Yang <ezyang@mit.edu>
Thu, 5 May 2011 20:37:40 +0000 (21:37 +0100)
Signed-off-by: Edward Z. Yang <ezyang@mit.edu>

compiler/cmm/CmmOpt.hs

index 1c7e7e5..a2eecd5 100644 (file)
@@ -37,6 +37,7 @@ import Data.Bits
 import Data.Word
 import Data.Int
 import Data.Maybe
+import Data.List
 
 import Compiler.Hoopl hiding (Unique)
 
@@ -57,11 +58,9 @@ cmmEliminateDeadBlocks :: [CmmBasicBlock] -> [CmmBasicBlock]
 cmmEliminateDeadBlocks [] = []
 cmmEliminateDeadBlocks blocks@(BasicBlock base_id _:_) =
     let -- Calculate what's reachable from what block
-        -- We have to do a deep fold into CmmExpr because
-        -- there may be a BlockId in the CmmBlock literal.
-        reachableMap = foldl f emptyBlockMap blocks
-            where f m (BasicBlock block_id stmts) = mapInsert block_id (reachableFrom stmts) m
-        reachableFrom stmts = foldl stmt emptyBlockSet stmts
+        reachableMap = foldl' f emptyUFM blocks -- lazy in values
+            where f m (BasicBlock block_id stmts) = addToUFM m block_id (reachableFrom stmts)
+        reachableFrom stmts = foldl stmt [] stmts
             where
                 stmt m CmmNop = m
                 stmt m (CmmComment _) = m
@@ -70,30 +69,30 @@ cmmEliminateDeadBlocks blocks@(BasicBlock base_id _:_) =
                 stmt m (CmmCall c _ as _ _) = f (actuals m as) c
                     where f m (CmmCallee e _) = expr m e
                           f m (CmmPrim _) = m
-                stmt m (CmmBranch b) = setInsert b m
-                stmt m (CmmCondBranch e b) = setInsert b (expr m e)
-                stmt m (CmmSwitch e bs) = foldl (flip setInsert) (expr m e) (catMaybes bs)
+                stmt m (CmmBranch b) = b:m
+                stmt m (CmmCondBranch e b) = b:(expr m e)
+                stmt m (CmmSwitch e bs) = catMaybes bs ++ expr m e
                 stmt m (CmmJump e as) = expr (actuals m as) e
                 stmt m (CmmReturn as) = actuals m as
-                actuals m as = foldl (\m h -> expr m (hintlessCmm h)) m as
+                actuals m as = foldl' (\m h -> expr m (hintlessCmm h)) m as
+                -- We have to do a deep fold into CmmExpr because
+                -- there may be a BlockId in the CmmBlock literal.
                 expr m (CmmLit l) = lit m l
                 expr m (CmmLoad e _) = expr m e
                 expr m (CmmReg _) = m
-                expr m (CmmMachOp _ es) = foldl expr m es
+                expr m (CmmMachOp _ es) = foldl' expr m es
                 expr m (CmmStackSlot _ _) = m
                 expr m (CmmRegOff _ _) = m
-                lit m (CmmBlock b) = setInsert b m
+                lit m (CmmBlock b) = b:m
                 lit m _ = m
-        -- Expand reachable set until you hit fixpoint
-        initReachable = setSingleton base_id :: BlockSet
-        expandReachable old_set new_set =
-            if setSize new_set > setSize old_set
-                then expandReachable new_set $ setFold
-                        (\x s -> maybe setEmpty id (mapLookup x reachableMap) `setUnion` s)
-                        new_set
-                        (setDifference new_set old_set)
-                else new_set -- fixpoint achieved
-        reachable = expandReachable setEmpty initReachable
+        -- go todo done
+        reachable = go [base_id] (setEmpty :: BlockSet)
+          where go []     m = m
+                go (x:xs) m
+                    | setMember x m = go xs m
+                    | otherwise     = go (add ++ xs) (setInsert x m)
+                        where add = fromMaybe (panic "cmmEliminateDeadBlocks: unknown block")
+                                              (lookupUFM reachableMap x)
     in filter (\(BasicBlock block_id _) -> setMember block_id reachable) blocks
 
 -- -----------------------------------------------------------------------------