66ee2b1249594f738f1cfe7bc2fe77191fcea3bf
[ghc-base.git] / Data / HashTable.hs
1 {-# OPTIONS -fno-implicit-prelude #-}
2
3 -----------------------------------------------------------------------------
4 -- |
5 -- Module      :  Data.HashTable
6 -- Copyright   :  (c) The University of Glasgow 2003
7 -- License     :  BSD-style (see the file libraries/base/LICENSE)
8 --
9 -- Maintainer  :  libraries@haskell.org
10 -- Stability   :  provisional
11 -- Portability :  portable
12 --
13 -- An implementation of extensible hash tables, as described in
14 -- Per-Ake Larson, /Dynamic Hash Tables/, CACM 31(4), April 1988,
15 -- pp. 446--457.  The implementation is also derived from the one
16 -- in GHC's runtime system (@ghc\/rts\/Hash.{c,h}@).
17 --
18 -----------------------------------------------------------------------------
19
20 module Data.HashTable (
21         -- * Basic hash table operations
22         HashTable, new, insert, delete, lookup,
23         -- * Converting to and from lists
24         fromList, toList,
25         -- * Hash functions
26         -- $hash_functions
27         hashInt, hashString,
28         prime,
29         -- * Diagnostics
30         longestChain
31  ) where
32
33 -- This module is imported by Data.Dynamic, which is pretty low down in the
34 -- module hierarchy, so don't import "high-level" modules
35
36 #ifdef __GLASGOW_HASKELL__
37 import GHC.Base
38 #else
39 import Prelude  hiding  ( lookup )
40 #endif
41 import Data.Tuple       ( fst )
42 import Data.Bits
43 import Data.Maybe
44 import Data.List        ( maximumBy, filter, length, concat )
45 import Data.Int         ( Int32 )
46
47 #if defined(__GLASGOW_HASKELL__)
48 import GHC.Num
49 import GHC.Real         ( Integral(..), fromIntegral )
50
51 import GHC.IOBase       ( IO, IOArray, newIOArray, readIOArray, writeIOArray,
52                           unsafeReadIOArray, unsafeWriteIOArray,
53                           IORef, newIORef, readIORef, writeIORef )
54 import GHC.Err          ( undefined )
55 #else
56 import Data.Char        ( ord )
57 import Data.IORef       ( IORef, newIORef, readIORef, writeIORef )
58 #  if defined(__HUGS__)
59 import Hugs.IOArray     ( IOArray, newIOArray, readIOArray, writeIOArray,
60                           unsafeReadIOArray, unsafeWriteIOArray )
61 #  elif defined(__NHC__)
62 import NHC.IOExtras     ( IOArray, newIOArray, readIOArray, writeIOArray)
63 #  endif
64 #endif
65 import Control.Monad    ( when, mapM, sequence_ )
66
67
68 -----------------------------------------------------------------------
69 myReadArray  :: IOArray Int32 a -> Int32 -> IO a
70 myWriteArray :: IOArray Int32 a -> Int32 -> a -> IO ()
71 #if defined(DEBUG) || defined(__NHC__)
72 myReadArray  = readIOArray
73 myWriteArray = writeIOArray
74 #else
75 myReadArray arr i = unsafeReadIOArray arr (fromIntegral i)
76 myWriteArray arr i x = unsafeWriteIOArray arr (fromIntegral i) x
77 #endif
78
79 -- | A hash table mapping keys of type @key@ to values of type @val@.
80 --
81 -- The implementation will grow the hash table as necessary, trying to
82 -- maintain a reasonable average load per bucket in the table.
83 --
84 newtype HashTable key val = HashTable (IORef (HT key val))
85 -- TODO: the IORef should really be an MVar.
86
87 data HT key val
88   = HT {
89         split  :: !Int32, -- Next bucket to split when expanding
90         max_bucket :: !Int32, -- Max bucket of smaller table
91         mask1  :: !Int32, -- Mask for doing the mod of h_1 (smaller table)
92         mask2  :: !Int32, -- Mask for doing the mod of h_2 (larger table)
93         kcount :: !Int32, -- Number of keys
94         bcount :: !Int32, -- Number of buckets
95         dir    :: !(IOArray Int32 (IOArray Int32 [(key,val)])),
96         hash_fn :: key -> Int32,
97         cmp    :: key -> key -> Bool
98    }
99
100 {-
101 ALTERNATIVE IMPLEMENTATION:
102
103 This works out slightly slower, because there's a tradeoff between
104 allocating a complete new HT structure each time a modification is
105 made (in the version above), and allocating new Int32s each time one
106 of them is modified, as below.  Using FastMutInt instead of IORef
107 Int32 helps, but yields an implementation which has about the same
108 performance as the version above (and is more complex).
109
110 data HashTable key val
111   = HashTable {
112         split  :: !(IORef Int32), -- Next bucket to split when expanding
113         max_bucket :: !(IORef Int32), -- Max bucket of smaller table
114         mask1  :: !(IORef Int32), -- Mask for doing the mod of h_1 (smaller table)
115         mask2  :: !(IORef Int32), -- Mask for doing the mod of h_2 (larger table)
116         kcount :: !(IORef Int32), -- Number of keys
117         bcount :: !(IORef Int32), -- Number of buckets
118         dir    :: !(IOArray Int32 (IOArray Int32 [(key,val)])),
119         hash_fn :: key -> Int32,
120         cmp    :: key -> key -> Bool
121    }
122 -}
123
124
125 -- -----------------------------------------------------------------------------
126 -- Sample hash functions
127
128 -- $hash_functions
129 --
130 -- This implementation of hash tables uses the low-order /n/ bits of the hash
131 -- value for a key, where /n/ varies as the hash table grows.  A good hash
132 -- function therefore will give an even distribution regardless of /n/.
133 --
134 -- If your keyspace is integrals such that the low-order bits between
135 -- keys are highly variable, then you could get away with using 'id'
136 -- as the hash function.
137 --
138 -- We provide some sample hash functions for 'Int' and 'String' below.
139
140 -- | A sample hash function for 'Int', implemented as simply @(x `mod` P)@
141 -- where P is a suitable prime (currently 1500007).  Should give
142 -- reasonable results for most distributions of 'Int' values, except
143 -- when the keys are all multiples of the prime!
144 --
145 hashInt :: Int -> Int32
146 hashInt = (`rem` prime) . fromIntegral
147
148 -- | A sample hash function for 'String's.  The implementation is:
149 --
150 -- >    hashString = fromIntegral . foldr f 0
151 -- >      where f c m = ord c + (m * 128) `rem` 1500007
152 --
153 -- which seems to give reasonable results.
154 --
155 hashString :: String -> Int32
156 hashString = fromIntegral . foldr f 0
157   where f c m = ord c + (m * 128) `rem` fromIntegral prime
158
159 -- | A prime larger than the maximum hash table size
160 prime :: Int32
161 prime = 1500007
162
163 -- -----------------------------------------------------------------------------
164 -- Parameters
165
166 sEGMENT_SIZE  = 1024  :: Int32  -- Size of a single hash table segment
167 sEGMENT_SHIFT = 10    :: Int  -- derived
168 sEGMENT_MASK  = 0x3ff :: Int32  -- derived
169
170 dIR_SIZE = 1024  :: Int32  -- Size of the segment directory
171         -- Maximum hash table size is sEGMENT_SIZE * dIR_SIZE
172
173 hLOAD = 4 :: Int32 -- Maximum average load of a single hash bucket
174
175 -- -----------------------------------------------------------------------------
176 -- Creating a new hash table
177
178 -- | Creates a new hash table
179 new
180   :: (key -> key -> Bool)    -- ^ An equality comparison on keys
181   -> (key -> Int32)          -- ^ A hash function on keys
182   -> IO (HashTable key val)  -- ^ Returns: an empty hash table
183
184 new cmp hash_fn = do
185   -- make a new hash table with a single, empty, segment
186   dir     <- newIOArray (0,dIR_SIZE) undefined
187   segment <- newIOArray (0,sEGMENT_SIZE-1) []
188   myWriteArray dir 0 segment
189
190   let
191     split  = 0
192     max    = sEGMENT_SIZE
193     mask1  = (sEGMENT_SIZE - 1)
194     mask2  = (2 * sEGMENT_SIZE - 1)
195     kcount = 0
196     bcount = sEGMENT_SIZE
197
198     ht = HT {  dir=dir, split=split, max_bucket=max, mask1=mask1, mask2=mask2,
199                kcount=kcount, bcount=bcount, hash_fn=hash_fn, cmp=cmp
200           }
201   
202   table <- newIORef ht
203   return (HashTable table)
204
205 -- -----------------------------------------------------------------------------
206 -- Inserting a key\/value pair into the hash table
207
208 -- | Inserts an key\/value mapping into the hash table.
209 insert :: HashTable key val -> key -> val -> IO ()
210
211 insert (HashTable ref) key val = do
212   table@HT{ kcount=k, bcount=b, dir=dir } <- readIORef ref
213   let table1 = table{ kcount = k+1 }
214   table2 <-
215         if (k > hLOAD * b)
216            then expandHashTable table1
217            else return table1
218   writeIORef ref table2
219   (segment_index,segment_offset) <- tableLocation table key
220   segment <- myReadArray dir segment_index
221   bucket <- myReadArray segment segment_offset
222   myWriteArray segment segment_offset ((key,val):bucket)
223   return ()
224
225 bucketIndex :: HT key val -> key -> IO Int32
226 bucketIndex HT{ hash_fn=hash_fn,
227                 split=split,
228                 mask1=mask1,
229                 mask2=mask2 } key = do
230   let
231     h = fromIntegral (hash_fn key)
232     small_bucket = h .&. mask1
233     large_bucket = h .&. mask2
234   --
235   if small_bucket < split
236         then return large_bucket
237         else return small_bucket
238
239 tableLocation :: HT key val -> key -> IO (Int32,Int32)
240 tableLocation table key = do
241   bucket_index <- bucketIndex table key
242   let
243     segment_index  = bucket_index `shiftR` sEGMENT_SHIFT
244     segment_offset = bucket_index .&. sEGMENT_MASK
245   --
246   return (segment_index,segment_offset)
247
248 expandHashTable :: HT key val -> IO (HT key val)
249 expandHashTable
250       table@HT{ dir=dir,
251                 split=split,
252                 max_bucket=max,
253                 mask2=mask2 } = do
254   let
255       oldsegment = split `shiftR` sEGMENT_SHIFT
256       oldindex   = split .&. sEGMENT_MASK
257
258       newbucket  = max + split
259       newsegment = newbucket `shiftR` sEGMENT_SHIFT
260       newindex   = newbucket .&. sEGMENT_MASK
261   --
262   when (newindex == 0) $
263         do segment <- newIOArray (0,sEGMENT_SIZE-1) []
264            myWriteArray dir newsegment segment
265   --
266   let table' =
267         if (split+1) < max
268             then table{ split = split+1 }
269                 -- we've expanded all the buckets in this table, so start from
270                 -- the beginning again.
271             else table{ split = 0,
272                         max_bucket = max * 2,
273                         mask1 = mask2,
274                         mask2 = mask2 `shiftL` 1 .|. 1 }
275   let
276     split_bucket old new [] = do
277         segment <- myReadArray dir oldsegment
278         myWriteArray segment oldindex old
279         segment <- myReadArray dir newsegment
280         myWriteArray segment newindex new
281     split_bucket old new ((k,v):xs) = do
282         h <- bucketIndex table' k
283         if h == newbucket
284                 then split_bucket old ((k,v):new) xs
285                 else split_bucket ((k,v):old) new xs
286   --
287   segment <- myReadArray dir oldsegment
288   bucket <- myReadArray segment oldindex
289   split_bucket [] [] bucket
290   return table'
291
292 -- -----------------------------------------------------------------------------
293 -- Deleting a mapping from the hash table
294
295 -- | Remove an entry from the hash table.
296 delete :: HashTable key val -> key -> IO ()
297
298 delete (HashTable ref) key = do
299   table@HT{ dir=dir, cmp=cmp } <- readIORef ref
300   (segment_index,segment_offset) <- tableLocation table key
301   segment <- myReadArray dir segment_index
302   bucket <- myReadArray segment segment_offset
303   myWriteArray segment segment_offset (filter (not.(key `cmp`).fst) bucket)
304   return ()
305
306 -- -----------------------------------------------------------------------------
307 -- Looking up an entry in the hash table
308
309 -- | Looks up the value of a key in the hash table.
310 lookup :: HashTable key val -> key -> IO (Maybe val)
311
312 lookup (HashTable ref) key = do
313   table@HT{ dir=dir, cmp=cmp } <- readIORef ref
314   (segment_index,segment_offset) <- tableLocation table key
315   segment <- myReadArray dir segment_index
316   bucket <- myReadArray segment segment_offset
317   case [ val | (key',val) <- bucket, cmp key key' ] of
318         [] -> return Nothing
319         (v:_) -> return (Just v)
320
321 -- -----------------------------------------------------------------------------
322 -- Converting to/from lists
323
324 -- | Convert a list of key\/value pairs into a hash table.  Equality on keys
325 -- is taken from the Eq instance for the key type.
326 --
327 fromList :: Eq key => (key -> Int32) -> [(key,val)] -> IO (HashTable key val)
328 fromList hash_fn list = do
329   table <- new (==) hash_fn
330   sequence_ [ insert table k v | (k,v) <- list ]
331   return table
332
333 -- | Converts a hash table to a list of key\/value pairs.
334 --
335 toList :: HashTable key val -> IO [(key,val)]
336 toList (HashTable ref) = do
337   HT{ dir=dir, max_bucket=max, split=split } <- readIORef ref
338   --
339   let
340     max_segment = (max + split - 1) `quot` sEGMENT_SIZE
341   --
342   segments <- mapM (segmentContents dir) [0 .. max_segment]
343   return (concat segments)
344  where
345    segmentContents dir seg_index = do
346      segment <- myReadArray dir seg_index
347      bs <- mapM (myReadArray segment) [0 .. sEGMENT_SIZE-1]
348      return (concat bs)
349
350 -- -----------------------------------------------------------------------------
351 -- Diagnostics
352
353 -- | This function is useful for determining whether your hash function
354 -- is working well for your data set.  It returns the longest chain
355 -- of key\/value pairs in the hash table for which all the keys hash to
356 -- the same bucket.  If this chain is particularly long (say, longer
357 -- than 10 elements), then it might be a good idea to try a different
358 -- hash function.
359 --
360 longestChain :: HashTable key val -> IO [(key,val)]
361 longestChain (HashTable ref) = do
362   HT{ dir=dir, max_bucket=max, split=split } <- readIORef ref
363   --
364   let
365     max_segment = (max + split - 1) `quot` sEGMENT_SIZE
366   --
367   --trace ("maxChainLength: max = " ++ show max ++ ", split = " ++ show split ++ ", max_segment = " ++ show max_segment) $ do
368   segments <- mapM (segmentMaxChainLength dir) [0 .. max_segment]
369   return (maximumBy lengthCmp segments)
370  where
371    segmentMaxChainLength dir seg_index = do
372      segment <- myReadArray dir seg_index
373      bs <- mapM (myReadArray segment) [0 .. sEGMENT_SIZE-1]
374      return (maximumBy lengthCmp bs)
375
376    lengthCmp x y = length x `compare` length y