4c144cf63cb925747acf9f04c6c894f3757d0ad9
[ghc-hetmet.git] / compiler / cmm / CmmCommonBlockElimZ.hs
1 module CmmCommonBlockElimZ
2   ( elimCommonBlocks
3   )
4 where
5
6
7 import BlockId
8 import CmmExpr
9 import Prelude hiding (iterate, zip, unzip)
10 import ZipCfg
11 import ZipCfgCmmRep
12
13 import Data.Bits
14 import Data.Word
15 import FastString
16 import List hiding (iterate)
17 import Monad
18 import Outputable
19 import UniqFM
20 import Unique
21
22 my_trace :: String -> SDoc -> a -> a
23 my_trace = if False then pprTrace else \_ _ a -> a
24
25 -- Eliminate common blocks:
26 -- If two blocks are identical except for the label on the first node,
27 -- then we can eliminate one of the blocks. To ensure that the semantics
28 -- of the program are preserved, we have to rewrite each predecessor of the
29 -- eliminated block to proceed with the block we keep.
30
31 -- The algorithm iterates over the blocks in the graph,
32 -- checking whether it has seen another block that is equal modulo labels.
33 -- If so, then it adds an entry in a map indicating that the new block
34 -- is made redundant by the old block.
35 -- Otherwise, it is added to the useful blocks.
36
37 -- TODO: Use optimization fuel
38 elimCommonBlocks :: CmmGraph -> CmmGraph
39 elimCommonBlocks g =
40     upd_graph g . snd $ iterate common_block reset hashed_blocks
41                                 (emptyUFM, emptyBlockEnv)
42       where hashed_blocks    = map (\b -> (hash_block b, b)) (reverse (postorder_dfs g))
43             reset (_, subst) = (emptyUFM, subst)
44
45 -- Iterate over the blocks until convergence
46 iterate :: (t -> a -> (Bool, t)) -> (t -> t) -> [a] -> t -> t
47 iterate upd reset blocks state =
48   case foldl upd' (False, state) blocks of
49     (True,  state') -> iterate upd reset blocks (reset state')
50     (False, state') -> state'
51   where upd' (b, s) a = let (b', s') = upd s a in (b || b', s') -- lift to track changes
52
53 -- Try to find a block that is equal (or ``common'') to b.
54 type BidMap = BlockEnv BlockId
55 type State  = (UniqFM [CmmBlock], BidMap)
56 common_block :: (Outputable h, Uniquable h) =>  State -> (h, CmmBlock) -> (Bool, State)
57 common_block (bmap, subst) (hash, b) =
58   case lookupUFM bmap hash of
59     Just bs -> case (find (eqBlockBodyWith (eqBid subst) b) bs,
60                      lookupBlockEnv subst bid) of
61                  (Just b', Nothing)                      -> addSubst b'
62                  (Just b', Just b'') | blockId b' /= b'' -> addSubst b'
63                  _ -> (False, (addToUFM bmap hash (b : bs), subst))
64     Nothing -> (False, (addToUFM bmap hash [b], subst))
65   where bid = blockId b
66         addSubst b' = my_trace "found new common block" (ppr (blockId b')) $
67                       (True, (bmap, extendBlockEnv subst bid (blockId b')))
68
69 -- Given the map ``subst'' from BlockId -> BlockId, we rewrite the graph.
70 upd_graph :: CmmGraph -> BidMap -> CmmGraph
71 upd_graph g subst = map_nodes id middle last g
72   where middle = mapExpDeepMiddle exp
73         last l = last' (mapExpDeepLast exp l)
74         last' (LastBranch bid)            = LastBranch $ sub bid
75         last' (LastCondBranch p t f)      = cond p (sub t) (sub f)
76         last' (LastCall t (Just bid) args res u) = LastCall t (Just $ sub bid) args res u
77         last' l@(LastCall _ Nothing _ _ _)  = l
78         last' (LastSwitch e bs)           = LastSwitch e $ map (liftM sub) bs
79         cond p t f = if t == f then LastBranch t else LastCondBranch p t f
80         exp (CmmStackSlot (CallArea (Young id))       off) =
81              CmmStackSlot (CallArea (Young (sub id))) off
82         exp (CmmLit (CmmBlock id)) = CmmLit (CmmBlock (sub id))
83         exp e = e
84         sub = lookupBid subst
85
86 -- To speed up comparisons, we hash each basic block modulo labels.
87 -- The hashing is a bit arbitrary (the numbers are completely arbitrary),
88 -- but it should be fast and good enough.
89 hash_block :: CmmBlock -> Int
90 hash_block (Block _ t) =
91   fromIntegral (hash_tail t (0 :: Word32) .&. (0x7fffffff :: Word32))
92   -- UniqFM doesn't like negative Ints
93   where hash_mid   (MidComment (FastString u _ _ _ _)) = cvt u
94         hash_mid   (MidAssign r e) = hash_reg r + hash_e e
95         hash_mid   (MidStore e e') = hash_e e + hash_e e'
96         hash_mid   (MidForeignCall _ t _ as) = hash_tgt t + hash_lst hash_e as
97         hash_reg :: CmmReg -> Word32
98         hash_reg   (CmmLocal l) = hash_local l
99         hash_reg   (CmmGlobal _)    = 19
100         hash_local (LocalReg _ _) = 117
101         hash_e :: CmmExpr -> Word32
102         hash_e (CmmLit l) = hash_lit l
103         hash_e (CmmLoad e _) = 67 + hash_e e
104         hash_e (CmmReg r) = hash_reg r
105         hash_e (CmmMachOp _ es) = hash_lst hash_e es -- pessimal - no operator check
106         hash_e (CmmRegOff r i) = hash_reg r + cvt i
107         hash_e (CmmStackSlot _ _) = 13
108         hash_lit :: CmmLit -> Word32
109         hash_lit (CmmInt i _) = fromInteger i
110         hash_lit (CmmFloat r _) = truncate r
111         hash_lit (CmmLabel _) = 119 -- ugh
112         hash_lit (CmmLabelOff _ i) = cvt $ 199 + i
113         hash_lit (CmmLabelDiffOff _ _ i) = cvt $ 299 + i
114         hash_lit (CmmBlock _) = 191 -- ugh
115         hash_lit (CmmHighStackMark) = cvt 313
116         hash_tgt (ForeignTarget e _) = hash_e e
117         hash_tgt (PrimTarget _) = 31 -- lots of these
118         hash_lst f = foldl (\z x -> f x + z) (0::Word32)
119         hash_last (LastBranch _) = 23 -- would be great to hash these properly
120         hash_last (LastCondBranch p _ _) = hash_e p 
121         hash_last (LastCall e _ _ _ _) = hash_e e
122         hash_last (LastSwitch e _) = hash_e e
123         hash_tail (ZLast LastExit) v = 29 + v `shiftL` 1
124         hash_tail (ZLast (LastOther l)) v = hash_last l + (v `shiftL` 1)
125         hash_tail (ZTail m t) v = hash_tail t (hash_mid m + (v `shiftL` 1))
126         cvt = fromInteger . toInteger
127 -- Utilities: equality and substitution on the graph.
128
129 -- Given a map ``subst'' from BlockID -> BlockID, we define equality.
130 eqBid :: BidMap -> BlockId -> BlockId -> Bool
131 eqBid subst bid bid' = lookupBid subst bid == lookupBid subst bid'
132 lookupBid :: BidMap -> BlockId -> BlockId
133 lookupBid subst bid = case lookupBlockEnv subst bid of
134                         Just bid  -> lookupBid subst bid
135                         Nothing -> bid
136
137 -- Equality on the body of a block, modulo a function mapping block IDs to block IDs.
138 eqBlockBodyWith :: (BlockId -> BlockId -> Bool) -> CmmBlock -> CmmBlock -> Bool
139 eqBlockBodyWith eqBid (Block _ t) (Block _ t') = eqTailWith eqBid t t'
140
141 type CmmTail = ZTail Middle Last
142 eqTailWith :: (BlockId -> BlockId -> Bool) -> CmmTail -> CmmTail -> Bool
143 eqTailWith eqBid (ZTail m t) (ZTail m' t') = m == m' && eqTailWith eqBid t t'
144 eqTailWith _ (ZLast LastExit) (ZLast LastExit) = True
145 eqTailWith eqBid (ZLast (LastOther l)) (ZLast (LastOther l')) = eqLastWith eqBid l l'
146 eqTailWith _ _ _ = False
147
148 eqLastWith :: (BlockId -> BlockId -> Bool) -> Last -> Last -> Bool
149 eqLastWith eqBid (LastBranch bid1) (LastBranch bid2) = eqBid bid1 bid2
150 eqLastWith eqBid (LastCondBranch c1 t1 f1) (LastCondBranch c2 t2 f2) =
151   c1 == c2 && eqBid t1 t2 && eqBid f1 f2
152 eqLastWith eqBid (LastCall t1 c1 a1 r1 u1) (LastCall t2 c2 a2 r2 u2) =
153   t1 == t2 && eqMaybeWith eqBid c1 c2 && a1 == a2 && r1 == r2 && u1 == u2
154 eqLastWith eqBid (LastSwitch e1 bs1) (LastSwitch e2 bs2) =
155   e1 == e2 && eqLstWith (eqMaybeWith eqBid) bs1 bs2
156 eqLastWith _ _ _ = False
157
158 eqLstWith :: (a -> b -> Bool) -> [a] -> [b] -> Bool
159 eqLstWith eltEq es es' = all (uncurry eltEq) (List.zip es es')
160
161 eqMaybeWith :: (a -> b -> Bool) -> Maybe a -> Maybe b -> Bool
162 eqMaybeWith eltEq (Just e) (Just e') = eltEq e e'
163 eqMaybeWith _ Nothing Nothing = True
164 eqMaybeWith _ _ _ = False