06e283164790123c829eae17d9f2528afb366eb2
[ghc-hetmet.git] / compiler / cmm / CmmCommonBlockElimZ.hs
1 module CmmCommonBlockElimZ
2   ( elimCommonBlocks
3   )
4 where
5
6
7 import Cmm hiding (blockId)
8 import CmmExpr
9 import Prelude hiding (iterate, zip, unzip)
10 import ZipCfg
11 import ZipCfgCmmRep
12
13 import FastString
14 import FiniteMap
15 import List hiding (iterate)
16 import Monad
17 import Outputable
18 import UniqFM
19 import Unique
20
21 my_trace :: String -> SDoc -> a -> a
22 my_trace = if True then pprTrace else \_ _ a -> a
23
24 -- Eliminate common blocks:
25 -- If two blocks are identical except for the label on the first node,
26 -- then we can eliminate one of the blocks. To ensure that the semantics
27 -- of the program are preserved, we have to rewrite each predecessor of the
28 -- eliminated block to proceed with the block we keep.
29
30 -- The algorithm iterates over the blocks in the graph,
31 -- checking whether it has seen another block that is equal modulo labels.
32 -- If so, then it adds an entry in a map indicating that the new block
33 -- is made redundant by the old block.
34 -- Otherwise, it is added to the useful blocks.
35
36 -- TODO: Use optimization fuel
37 elimCommonBlocks :: CmmGraph -> CmmGraph
38 elimCommonBlocks g =
39     upd_graph g . snd $ iterate common_block reset hashed_blocks (emptyUFM, emptyFM)
40       where hashed_blocks    = map (\b -> (hash_block b, b)) (reverse (postorder_dfs g))
41             reset (_, subst) = (emptyUFM, subst)
42
43 -- Iterate over the blocks until convergence
44 iterate :: (t -> a -> (Bool, t)) -> (t -> t) -> [a] -> t -> t
45 iterate upd reset blocks state =
46   case foldl upd' (False, state) blocks of
47     (True,  state') -> iterate upd reset blocks (reset state')
48     (False, state') -> state'
49   where upd' (b, s) a = let (b', s') = upd s a in (b || b', s') -- lift to track changes
50
51 -- Try to find a block that is equal (or ``common'') to b.
52 type BidMap = FiniteMap BlockId BlockId
53 type State  = (UniqFM [CmmBlock], BidMap)
54 common_block :: (Outputable h, Uniquable h) =>  State -> (h, CmmBlock) -> (Bool, State)
55 common_block (bmap, subst) (hash, b) =
56   case lookupUFM bmap $ my_trace "common_block" (ppr bid <+> ppr subst <+> ppr hash) $ hash of
57     Just bs -> case (find (eqBlockBodyWith (eqBid subst) b) bs, lookupFM subst bid) of
58                  (Just b', Nothing)                      -> addSubst b'
59                  (Just b', Just b'') | blockId b' /= b'' -> addSubst b'
60                  _ -> (False, (addToUFM bmap hash (b : bs), subst))
61     Nothing -> (False, (addToUFM bmap hash [b], subst))
62   where bid = blockId b
63         addSubst b' = my_trace "found new common block" (ppr (blockId b')) $
64                       (True, (bmap, addToFM subst bid (blockId b')))
65
66 -- Given the map ``subst'' from BlockId -> BlockId, we rewrite the graph.
67 upd_graph :: CmmGraph -> BidMap -> CmmGraph
68 upd_graph g subst = map_nodes id middle last g
69   where middle m = m
70         last (LastBranch bid)       = LastBranch $ sub bid
71         last (LastCondBranch p t f) = cond p (sub t) (sub f)
72         last (LastCall t bid)       = LastCall   t $ liftM sub bid
73         last (LastSwitch e bs)      = LastSwitch e $ map (liftM sub) bs
74         last l = l
75         cond p t f = if t == f then LastBranch t else LastCondBranch p t f
76         sub = lookupBid subst
77
78 -- To speed up comparisons, we hash each basic block modulo labels.
79 -- The hashing is a bit arbitrary (the numbers are completely arbitrary),
80 -- but it should be fast and good enough.
81 hash_block :: CmmBlock -> Int
82 hash_block (Block _ t) = hash_tail t 0
83   where hash_mid   (MidComment (FastString u _ _ _ _)) = u
84         hash_mid   (MidAssign r e) = hash_reg r + hash_e e
85         hash_mid   (MidStore e e') = hash_e e + hash_e e'
86         hash_mid   (MidUnsafeCall t _ as) = hash_tgt t + hash_as as
87         hash_mid   (MidAddToContext e es) = hash_e e + hash_lst hash_e es
88         hash_mid   (CopyIn _ fs _) = hash_fs fs
89         hash_mid   (CopyOut _ as) = hash_as as
90         hash_reg   (CmmLocal l) = hash_local l
91         hash_reg   (CmmGlobal _)    = 19
92         hash_reg   (CmmStack _)    = 13
93         hash_local (LocalReg _ _ _) = 117
94         hash_e (CmmLit l) = hash_lit l
95         hash_e (CmmLoad e _) = 67 + hash_e e
96         hash_e (CmmReg r) = hash_reg r
97         hash_e (CmmMachOp _ es) = hash_lst hash_e es -- pessimal - no operator check
98         hash_e (CmmRegOff r i) = hash_reg r + i
99         hash_lit (CmmInt i _) = fromInteger i
100         hash_lit (CmmFloat r _) = truncate r
101         hash_lit (CmmLabel _) = 119 -- ugh
102         hash_lit (CmmLabelOff _ i) = 199 + i
103         hash_lit (CmmLabelDiffOff _ _ i) = 299 + i
104         hash_tgt (CmmCallee e _) = hash_e e
105         hash_tgt (CmmPrim _) = 31 -- lots of these
106         hash_as = hash_lst $ hash_kinded hash_e
107         hash_fs = hash_lst $ hash_kinded hash_local
108         hash_kinded f (CmmKinded x _) = f x
109         hash_lst f = foldl (\z x -> f x + z) 0
110         hash_last (LastBranch _) = 23 -- would be great to hash these properly
111         hash_last (LastCondBranch p _ _) = hash_e p 
112         hash_last LastReturn = 17 -- better ideas?
113         hash_last (LastJump e) = hash_e e
114         hash_last (LastCall e _) = hash_e e
115         hash_last (LastSwitch e _) = hash_e e
116         hash_tail (ZLast LastExit) v = 29 + v * 2
117         hash_tail (ZLast (LastOther l)) v = hash_last l + (v * 2)
118         hash_tail (ZTail m t) v = hash_tail t (hash_mid m + (v * 2))
119
120 -- Utilities: equality and substitution on the graph.
121
122 -- Given a map ``subst'' from BlockID -> BlockID, we define equality.
123 eqBid :: BidMap -> BlockId -> BlockId -> Bool
124 eqBid subst bid bid' = lookupBid subst bid == lookupBid subst bid'
125 lookupBid :: BidMap -> BlockId -> BlockId
126 lookupBid subst bid = case lookupFM subst bid of
127                         Just bid  -> lookupBid subst bid
128                         Nothing -> bid
129
130 -- Equality on the body of a block, modulo a function mapping block IDs to block IDs.
131 eqBlockBodyWith :: (BlockId -> BlockId -> Bool) -> CmmBlock -> CmmBlock -> Bool
132 eqBlockBodyWith eqBid (Block _ t) (Block _ t') = eqTailWith eqBid t t'
133
134 type CmmTail = ZTail Middle Last
135 eqTailWith :: (BlockId -> BlockId -> Bool) -> CmmTail -> CmmTail -> Bool
136 eqTailWith eqBid (ZTail m t) (ZTail m' t') = m == m' && eqTailWith eqBid t t'
137 eqTailWith _ (ZLast LastExit) (ZLast LastExit) = True
138 eqTailWith eqBid (ZLast (LastOther l)) (ZLast (LastOther l')) = eqLastWith eqBid l l'
139 eqTailWith _ _ _ = False
140
141 eqLastWith :: (BlockId -> BlockId -> Bool) -> Last -> Last -> Bool
142 eqLastWith eqBid (LastBranch bid) (LastBranch bid') = eqBid bid bid'
143 eqLastWith eqBid c@(LastCondBranch _ _ _) c'@(LastCondBranch _ _ _) =
144   eqBid (cml_true c) (cml_true c')  && eqBid (cml_false c) (cml_false c') 
145 eqLastWith _ LastReturn LastReturn = True
146 eqLastWith _ (LastJump e) (LastJump e') = e == e'
147 eqLastWith eqBid c@(LastCall _ _) c'@(LastCall _ _) =
148   cml_target c == cml_target c' && eqMaybeWith eqBid (cml_cont c) (cml_cont c')
149 eqLastWith eqBid (LastSwitch e bs) (LastSwitch e' bs') =
150   e == e' && eqLstWith (eqMaybeWith eqBid) bs bs'
151 eqLastWith _ _ _ = False
152
153 eqLstWith :: (a -> b -> Bool) -> [a] -> [b] -> Bool
154 eqLstWith eltEq es es' = all (uncurry eltEq) (List.zip es es')
155
156 eqMaybeWith :: (a -> b -> Bool) -> Maybe a -> Maybe b -> Bool
157 eqMaybeWith eltEq (Just e) (Just e') = eltEq e e'
158 eqMaybeWith _ Nothing Nothing = True
159 eqMaybeWith _ _ _ = False