Replacing copyins and copyouts with data-movement instructions
[ghc-hetmet.git] / compiler / cmm / CmmCommonBlockElimZ.hs
1 module CmmCommonBlockElimZ
2   ( elimCommonBlocks
3   )
4 where
5
6
7 import BlockId
8 import Cmm hiding (blockId)
9 import CmmExpr
10 import Prelude hiding (iterate, zip, unzip)
11 import ZipCfg
12 import ZipCfgCmmRep
13
14 import FastString
15 import FiniteMap
16 import List hiding (iterate)
17 import Monad
18 import Outputable
19 import UniqFM
20 import Unique
21
22 my_trace :: String -> SDoc -> a -> a
23 my_trace = if True then pprTrace else \_ _ a -> a
24
25 -- Eliminate common blocks:
26 -- If two blocks are identical except for the label on the first node,
27 -- then we can eliminate one of the blocks. To ensure that the semantics
28 -- of the program are preserved, we have to rewrite each predecessor of the
29 -- eliminated block to proceed with the block we keep.
30
31 -- The algorithm iterates over the blocks in the graph,
32 -- checking whether it has seen another block that is equal modulo labels.
33 -- If so, then it adds an entry in a map indicating that the new block
34 -- is made redundant by the old block.
35 -- Otherwise, it is added to the useful blocks.
36
37 -- TODO: Use optimization fuel
38 elimCommonBlocks :: CmmGraph -> CmmGraph
39 elimCommonBlocks g =
40     upd_graph g . snd $ iterate common_block reset hashed_blocks (emptyUFM, emptyFM)
41       where hashed_blocks    = map (\b -> (hash_block b, b)) (reverse (postorder_dfs g))
42             reset (_, subst) = (emptyUFM, subst)
43
44 -- Iterate over the blocks until convergence
45 iterate :: (t -> a -> (Bool, t)) -> (t -> t) -> [a] -> t -> t
46 iterate upd reset blocks state =
47   case foldl upd' (False, state) blocks of
48     (True,  state') -> iterate upd reset blocks (reset state')
49     (False, state') -> state'
50   where upd' (b, s) a = let (b', s') = upd s a in (b || b', s') -- lift to track changes
51
52 -- Try to find a block that is equal (or ``common'') to b.
53 type BidMap = FiniteMap BlockId BlockId
54 type State  = (UniqFM [CmmBlock], BidMap)
55 common_block :: (Outputable h, Uniquable h) =>  State -> (h, CmmBlock) -> (Bool, State)
56 common_block (bmap, subst) (hash, b) =
57   case lookupUFM bmap $ my_trace "common_block" (ppr bid <+> ppr subst <+> ppr hash) $ hash of
58     Just bs -> case (find (eqBlockBodyWith (eqBid subst) b) bs, lookupFM subst bid) of
59                  (Just b', Nothing)                      -> addSubst b'
60                  (Just b', Just b'') | blockId b' /= b'' -> addSubst b'
61                  _ -> (False, (addToUFM bmap hash (b : bs), subst))
62     Nothing -> (False, (addToUFM bmap hash [b], subst))
63   where bid = blockId b
64         addSubst b' = my_trace "found new common block" (ppr (blockId b')) $
65                       (True, (bmap, addToFM subst bid (blockId b')))
66
67 -- Given the map ``subst'' from BlockId -> BlockId, we rewrite the graph.
68 upd_graph :: CmmGraph -> BidMap -> CmmGraph
69 upd_graph g subst = map_nodes id middle last g
70   where middle m = m
71         last (LastBranch bid)       = LastBranch $ sub bid
72         last (LastCondBranch p t f) = cond p (sub t) (sub f)
73         last (LastCall t bid)       = LastCall   t $ liftM sub bid
74         last (LastSwitch e bs)      = LastSwitch e $ map (liftM sub) bs
75         last l = l
76         cond p t f = if t == f then LastBranch t else LastCondBranch p t f
77         sub = lookupBid subst
78
79 -- To speed up comparisons, we hash each basic block modulo labels.
80 -- The hashing is a bit arbitrary (the numbers are completely arbitrary),
81 -- but it should be fast and good enough.
82 hash_block :: CmmBlock -> Int
83 hash_block (Block _ t) = hash_tail t 0
84   where hash_mid   (MidComment (FastString u _ _ _ _)) = u
85         hash_mid   (MidAssign r e) = hash_reg r + hash_e e
86         hash_mid   (MidStore e e') = hash_e e + hash_e e'
87         hash_mid   (MidUnsafeCall t _ as) = hash_tgt t + hash_as as
88         hash_mid   (MidAddToContext e es) = hash_e e + hash_lst hash_e es
89         hash_mid   (CopyIn _ fs _) = hash_fs fs
90         hash_mid   (CopyOut _ as) = hash_as as
91         hash_reg   (CmmLocal l) = hash_local l
92         hash_reg   (CmmGlobal _)    = 19
93         hash_local (LocalReg _ _ _) = 117
94         hash_e (CmmLit l) = hash_lit l
95         hash_e (CmmLoad e _) = 67 + hash_e e
96         hash_e (CmmReg r) = hash_reg r
97         hash_e (CmmMachOp _ es) = hash_lst hash_e es -- pessimal - no operator check
98         hash_e (CmmRegOff r i) = hash_reg r + i
99         hash_e (CmmStackSlot _ _) = 13
100         hash_lit (CmmInt i _) = fromInteger i
101         hash_lit (CmmFloat r _) = truncate r
102         hash_lit (CmmLabel _) = 119 -- ugh
103         hash_lit (CmmLabelOff _ i) = 199 + i
104         hash_lit (CmmLabelDiffOff _ _ i) = 299 + i
105         hash_tgt (CmmCallee e _) = hash_e e
106         hash_tgt (CmmPrim _) = 31 -- lots of these
107         hash_as = hash_lst $ hash_kinded hash_e
108         hash_fs = hash_lst $ hash_kinded hash_local
109         hash_kinded f (CmmKinded x _) = f x
110         hash_lst f = foldl (\z x -> f x + z) 0
111         hash_last (LastBranch _) = 23 -- would be great to hash these properly
112         hash_last (LastCondBranch p _ _) = hash_e p 
113         hash_last LastReturn = 17 -- better ideas?
114         hash_last (LastJump e) = hash_e e
115         hash_last (LastCall e _) = hash_e e
116         hash_last (LastSwitch e _) = hash_e e
117         hash_tail (ZLast LastExit) v = 29 + v * 2
118         hash_tail (ZLast (LastOther l)) v = hash_last l + (v * 2)
119         hash_tail (ZTail m t) v = hash_tail t (hash_mid m + (v * 2))
120
121 -- Utilities: equality and substitution on the graph.
122
123 -- Given a map ``subst'' from BlockID -> BlockID, we define equality.
124 eqBid :: BidMap -> BlockId -> BlockId -> Bool
125 eqBid subst bid bid' = lookupBid subst bid == lookupBid subst bid'
126 lookupBid :: BidMap -> BlockId -> BlockId
127 lookupBid subst bid = case lookupFM subst bid of
128                         Just bid  -> lookupBid subst bid
129                         Nothing -> bid
130
131 -- Equality on the body of a block, modulo a function mapping block IDs to block IDs.
132 eqBlockBodyWith :: (BlockId -> BlockId -> Bool) -> CmmBlock -> CmmBlock -> Bool
133 eqBlockBodyWith eqBid (Block _ t) (Block _ t') = eqTailWith eqBid t t'
134
135 type CmmTail = ZTail Middle Last
136 eqTailWith :: (BlockId -> BlockId -> Bool) -> CmmTail -> CmmTail -> Bool
137 eqTailWith eqBid (ZTail m t) (ZTail m' t') = m == m' && eqTailWith eqBid t t'
138 eqTailWith _ (ZLast LastExit) (ZLast LastExit) = True
139 eqTailWith eqBid (ZLast (LastOther l)) (ZLast (LastOther l')) = eqLastWith eqBid l l'
140 eqTailWith _ _ _ = False
141
142 eqLastWith :: (BlockId -> BlockId -> Bool) -> Last -> Last -> Bool
143 eqLastWith eqBid (LastBranch bid) (LastBranch bid') = eqBid bid bid'
144 eqLastWith eqBid c@(LastCondBranch _ _ _) c'@(LastCondBranch _ _ _) =
145   eqBid (cml_true c) (cml_true c')  && eqBid (cml_false c) (cml_false c') 
146 eqLastWith _ LastReturn LastReturn = True
147 eqLastWith _ (LastJump e) (LastJump e') = e == e'
148 eqLastWith eqBid c@(LastCall _ _) c'@(LastCall _ _) =
149   cml_target c == cml_target c' && eqMaybeWith eqBid (cml_cont c) (cml_cont c')
150 eqLastWith eqBid (LastSwitch e bs) (LastSwitch e' bs') =
151   e == e' && eqLstWith (eqMaybeWith eqBid) bs bs'
152 eqLastWith _ _ _ = False
153
154 eqLstWith :: (a -> b -> Bool) -> [a] -> [b] -> Bool
155 eqLstWith eltEq es es' = all (uncurry eltEq) (List.zip es es')
156
157 eqMaybeWith :: (a -> b -> Bool) -> Maybe a -> Maybe b -> Bool
158 eqMaybeWith eltEq (Just e) (Just e') = eltEq e e'
159 eqMaybeWith _ Nothing Nothing = True
160 eqMaybeWith _ _ _ = False