Merging in the new codegen branch
[ghc-hetmet.git] / compiler / cmm / CmmCommonBlockElimZ.hs
1 module CmmCommonBlockElimZ
2   ( elimCommonBlocks
3   )
4 where
5
6
7 import BlockId
8 import CmmExpr
9 import Prelude hiding (iterate, zip, unzip)
10 import ZipCfg
11 import ZipCfgCmmRep
12
13 import FastString
14 import FiniteMap
15 import List hiding (iterate)
16 import Monad
17 import Outputable
18 import UniqFM
19 import Unique
20
21 my_trace :: String -> SDoc -> a -> a
22 my_trace = if True then pprTrace else \_ _ a -> a
23
24 -- Eliminate common blocks:
25 -- If two blocks are identical except for the label on the first node,
26 -- then we can eliminate one of the blocks. To ensure that the semantics
27 -- of the program are preserved, we have to rewrite each predecessor of the
28 -- eliminated block to proceed with the block we keep.
29
30 -- The algorithm iterates over the blocks in the graph,
31 -- checking whether it has seen another block that is equal modulo labels.
32 -- If so, then it adds an entry in a map indicating that the new block
33 -- is made redundant by the old block.
34 -- Otherwise, it is added to the useful blocks.
35
36 -- TODO: Use optimization fuel
37 elimCommonBlocks :: CmmGraph -> CmmGraph
38 elimCommonBlocks g =
39     upd_graph g . snd $ iterate common_block reset hashed_blocks (emptyUFM, emptyFM)
40       where hashed_blocks    = map (\b -> (hash_block b, b)) (reverse (postorder_dfs g))
41             reset (_, subst) = (emptyUFM, subst)
42
43 -- Iterate over the blocks until convergence
44 iterate :: (t -> a -> (Bool, t)) -> (t -> t) -> [a] -> t -> t
45 iterate upd reset blocks state =
46   case foldl upd' (False, state) blocks of
47     (True,  state') -> iterate upd reset blocks (reset state')
48     (False, state') -> state'
49   where upd' (b, s) a = let (b', s') = upd s a in (b || b', s') -- lift to track changes
50
51 -- Try to find a block that is equal (or ``common'') to b.
52 type BidMap = FiniteMap BlockId BlockId
53 type State  = (UniqFM [CmmBlock], BidMap)
54 common_block :: (Outputable h, Uniquable h) =>  State -> (h, CmmBlock) -> (Bool, State)
55 common_block (bmap, subst) (hash, b) =
56   case lookupUFM bmap $ my_trace "common_block" (ppr bid <+> ppr subst <+> ppr hash) $ hash of
57     Just bs -> case (find (eqBlockBodyWith (eqBid subst) b) bs, lookupFM subst bid) of
58                  (Just b', Nothing)                      -> addSubst b'
59                  (Just b', Just b'') | blockId b' /= b'' -> addSubst b'
60                  _ -> (False, (addToUFM bmap hash (b : bs), subst))
61     Nothing -> (False, (addToUFM bmap hash [b], subst))
62   where bid = blockId b
63         addSubst b' = my_trace "found new common block" (ppr (blockId b')) $
64                       (True, (bmap, addToFM subst bid (blockId b')))
65
66 -- Given the map ``subst'' from BlockId -> BlockId, we rewrite the graph.
67 upd_graph :: CmmGraph -> BidMap -> CmmGraph
68 upd_graph g subst = map_nodes id middle last g
69   where middle m = m
70         last (LastBranch bid)       = LastBranch $ sub bid
71         last (LastCondBranch p t f) = cond p (sub t) (sub f)
72         last (LastCall t bid s)     = LastCall   t (liftM sub bid) s
73         last (LastSwitch e bs)      = LastSwitch e $ map (liftM sub) bs
74         last l = l
75         cond p t f = if t == f then LastBranch t else LastCondBranch p t f
76         sub = lookupBid subst
77
78 -- To speed up comparisons, we hash each basic block modulo labels.
79 -- The hashing is a bit arbitrary (the numbers are completely arbitrary),
80 -- but it should be fast and good enough.
81 hash_block :: CmmBlock -> Int
82 hash_block (Block _ _ t) = hash_tail t 0
83   where hash_mid   (MidComment (FastString u _ _ _ _)) = u
84         hash_mid   (MidAssign r e) = hash_reg r + hash_e e
85         hash_mid   (MidStore e e') = hash_e e + hash_e e'
86         hash_mid   (MidUnsafeCall t _ as) = hash_tgt t + hash_lst hash_e as
87         hash_mid   (MidAddToContext e es) = hash_e e + hash_lst hash_e es
88         hash_reg   (CmmLocal l) = hash_local l
89         hash_reg   (CmmGlobal _)    = 19
90         hash_local (LocalReg _ _) = 117
91         hash_e (CmmLit l) = hash_lit l
92         hash_e (CmmLoad e _) = 67 + hash_e e
93         hash_e (CmmReg r) = hash_reg r
94         hash_e (CmmMachOp _ es) = hash_lst hash_e es -- pessimal - no operator check
95         hash_e (CmmRegOff r i) = hash_reg r + i
96         hash_e (CmmStackSlot _ _) = 13
97         hash_lit (CmmInt i _) = fromInteger i
98         hash_lit (CmmFloat r _) = truncate r
99         hash_lit (CmmLabel _) = 119 -- ugh
100         hash_lit (CmmLabelOff _ i) = 199 + i
101         hash_lit (CmmLabelDiffOff _ _ i) = 299 + i
102         hash_tgt (ForeignTarget e _) = hash_e e
103         hash_tgt (PrimTarget _) = 31 -- lots of these
104         hash_lst f = foldl (\z x -> f x + z) (0::Int)
105         hash_last (LastBranch _) = 23 -- would be great to hash these properly
106         hash_last (LastCondBranch p _ _) = hash_e p 
107         hash_last (LastReturn _) = 17 -- better ideas?
108         hash_last (LastJump e _) = hash_e e
109         hash_last (LastCall e _ _) = hash_e e
110         hash_last (LastSwitch e _) = hash_e e
111         hash_tail (ZLast LastExit) v = 29 + v * 2
112         hash_tail (ZLast (LastOther l)) v = hash_last l + (v * 2)
113         hash_tail (ZTail m t) v = hash_tail t (hash_mid m + (v * 2))
114
115 -- Utilities: equality and substitution on the graph.
116
117 -- Given a map ``subst'' from BlockID -> BlockID, we define equality.
118 eqBid :: BidMap -> BlockId -> BlockId -> Bool
119 eqBid subst bid bid' = lookupBid subst bid == lookupBid subst bid'
120 lookupBid :: BidMap -> BlockId -> BlockId
121 lookupBid subst bid = case lookupFM subst bid of
122                         Just bid  -> lookupBid subst bid
123                         Nothing -> bid
124
125 -- Equality on the body of a block, modulo a function mapping block IDs to block IDs.
126 eqBlockBodyWith :: (BlockId -> BlockId -> Bool) -> CmmBlock -> CmmBlock -> Bool
127 eqBlockBodyWith eqBid (Block _ Nothing t) (Block _ Nothing t') = eqTailWith eqBid t t'
128 eqBlockBodyWith _ _ _ = False
129
130 type CmmTail = ZTail Middle Last
131 eqTailWith :: (BlockId -> BlockId -> Bool) -> CmmTail -> CmmTail -> Bool
132 eqTailWith eqBid (ZTail m t) (ZTail m' t') = m == m' && eqTailWith eqBid t t'
133 eqTailWith _ (ZLast LastExit) (ZLast LastExit) = True
134 eqTailWith eqBid (ZLast (LastOther l)) (ZLast (LastOther l')) = eqLastWith eqBid l l'
135 eqTailWith _ _ _ = False
136
137 eqLastWith :: (BlockId -> BlockId -> Bool) -> Last -> Last -> Bool
138 eqLastWith eqBid (LastBranch bid) (LastBranch bid') = eqBid bid bid'
139 eqLastWith eqBid c@(LastCondBranch _ _ _) c'@(LastCondBranch _ _ _) =
140   eqBid (cml_true c) (cml_true c')  && eqBid (cml_false c) (cml_false c') 
141 eqLastWith _ (LastReturn s) (LastReturn s') = s == s'
142 eqLastWith _ (LastJump e s) (LastJump e' s') = e == e' && s == s'
143 eqLastWith eqBid c@(LastCall _ _ s) c'@(LastCall _ _ s') =
144   cml_target c == cml_target c' && eqMaybeWith eqBid (cml_cont c) (cml_cont c') &&
145   s == s'
146 eqLastWith eqBid (LastSwitch e bs) (LastSwitch e' bs') =
147   e == e' && eqLstWith (eqMaybeWith eqBid) bs bs'
148 eqLastWith _ _ _ = False
149
150 eqLstWith :: (a -> b -> Bool) -> [a] -> [b] -> Bool
151 eqLstWith eltEq es es' = all (uncurry eltEq) (List.zip es es')
152
153 eqMaybeWith :: (a -> b -> Bool) -> Maybe a -> Maybe b -> Bool
154 eqMaybeWith eltEq (Just e) (Just e') = eltEq e e'
155 eqMaybeWith _ Nothing Nothing = True
156 eqMaybeWith _ _ _ = False