[project @ 2003-04-21 16:32:39 by ross]
[ghc-base.git] / Data / HashTable.hs
1 {-# OPTIONS -fno-implicit-prelude #-}
2
3 -----------------------------------------------------------------------------
4 -- |
5 -- Module      :  Data.HashTable
6 -- Copyright   :  (c) The University of Glasgow 2003
7 -- License     :  BSD-style (see the file libraries/base/LICENSE)
8 --
9 -- Maintainer  :  libraries@haskell.org
10 -- Stability   :  provisional
11 -- Portability :  portable
12 --
13 -- An implementation of extensible hash tables, as described in
14 -- Per-Ake Larson, /Dynamic Hash Tables/, CACM 31(4), April 1988,
15 -- pp. 446--457.  The implementation is also derived from the one
16 -- in GHC's runtime system (@ghc\/rts\/Hash.{c,h}@).
17 --
18 -----------------------------------------------------------------------------
19
20 module Data.HashTable (
21         -- * Basic hash table operations
22         HashTable, new, insert, delete, lookup,
23         -- * Converting to and from lists
24         fromList, toList,
25         -- * Hash functions
26         -- $hash_functions
27         hashInt, hashString,
28         prime,
29         -- * Diagnostics
30         longestChain
31  ) where
32
33 -- This module is imported by Data.Dynamic, which is pretty low down in the
34 -- module hierarchy, so don't import "high-level" modules
35
36 #ifdef __GLASGOW_HASKELL__
37 import GHC.Base
38 #else
39 import Prelude  hiding  ( lookup )
40 #endif
41 import Data.Tuple       ( fst )
42 import Data.Bits
43 import Data.Maybe
44 import Data.List        ( maximumBy, filter, length, concat )
45 import Data.Int         ( Int32 )
46
47 #if defined(__GLASGOW_HASKELL__)
48 import GHC.Num
49 import GHC.Real         ( Integral(..), fromIntegral )
50
51 import GHC.IOBase       ( IO, IOArray, newIOArray, readIOArray, writeIOArray,
52                           unsafeReadIOArray, unsafeWriteIOArray,
53                           IORef, newIORef, readIORef, writeIORef )
54 import GHC.Err          ( undefined )
55 #elif defined(__HUGS__)
56 import Data.Char        ( ord )
57 import Hugs.IOArray     ( IOArray, newIOArray, readIOArray, writeIOArray,
58                           unsafeReadIOArray, unsafeWriteIOArray )
59 import Data.IORef       ( IORef, newIORef, readIORef, writeIORef )
60 #endif
61 import Control.Monad    ( when, mapM, sequence_ )
62
63 -----------------------------------------------------------------------
64 myReadArray  :: IOArray Int32 a -> Int32 -> IO a
65 myWriteArray :: IOArray Int32 a -> Int32 -> a -> IO ()
66 #ifdef DEBUG
67 myReadArray  = readIOArray
68 myWriteArray = writeIOArray
69 #else
70 myReadArray arr i = unsafeReadIOArray arr (fromIntegral i)
71 myWriteArray arr i x = unsafeWriteIOArray arr (fromIntegral i) x
72 #endif
73
74 -- | A hash table mapping keys of type @key@ to values of type @val@.
75 --
76 -- The implementation will grow the hash table as necessary, trying to
77 -- maintain a reasonable average load per bucket in the table.
78 --
79 newtype HashTable key val = HashTable (IORef (HT key val))
80 -- TODO: the IORef should really be an MVar.
81
82 data HT key val
83   = HT {
84         split  :: !Int32, -- Next bucket to split when expanding
85         max_bucket :: !Int32, -- Max bucket of smaller table
86         mask1  :: !Int32, -- Mask for doing the mod of h_1 (smaller table)
87         mask2  :: !Int32, -- Mask for doing the mod of h_2 (larger table)
88         kcount :: !Int32, -- Number of keys
89         bcount :: !Int32, -- Number of buckets
90         dir    :: !(IOArray Int32 (IOArray Int32 [(key,val)])),
91         hash_fn :: key -> Int32,
92         cmp    :: key -> key -> Bool
93    }
94
95 {-
96 ALTERNATIVE IMPLEMENTATION:
97
98 This works out slightly slower, because there's a tradeoff between
99 allocating a complete new HT structure each time a modification is
100 made (in the version above), and allocating new Int32s each time one
101 of them is modified, as below.  Using FastMutInt instead of IORef
102 Int32 helps, but yields an implementation which has about the same
103 performance as the version above (and is more complex).
104
105 data HashTable key val
106   = HashTable {
107         split  :: !(IORef Int32), -- Next bucket to split when expanding
108         max_bucket :: !(IORef Int32), -- Max bucket of smaller table
109         mask1  :: !(IORef Int32), -- Mask for doing the mod of h_1 (smaller table)
110         mask2  :: !(IORef Int32), -- Mask for doing the mod of h_2 (larger table)
111         kcount :: !(IORef Int32), -- Number of keys
112         bcount :: !(IORef Int32), -- Number of buckets
113         dir    :: !(IOArray Int32 (IOArray Int32 [(key,val)])),
114         hash_fn :: key -> Int32,
115         cmp    :: key -> key -> Bool
116    }
117 -}
118
119
120 -- -----------------------------------------------------------------------------
121 -- Sample hash functions
122
123 -- $hash_functions
124 --
125 -- This implementation of hash tables uses the low-order /n/ bits of the hash
126 -- value for a key, where /n/ varies as the hash table grows.  A good hash
127 -- function therefore will give an even distribution regardless of /n/.
128 --
129 -- If your keyspace is integrals such that the low-order bits between
130 -- keys are highly variable, then you could get away with using 'id'
131 -- as the hash function.
132 --
133 -- We provide some sample hash functions for 'Int' and 'String' below.
134
135 -- | A sample hash function for 'Int', implemented as simply @(x `mod` P)@
136 -- where P is a suitable prime (currently 1500007).  Should give
137 -- reasonable results for most distributions of 'Int' values, except
138 -- when the keys are all multiples of the prime!
139 --
140 hashInt :: Int -> Int32
141 hashInt = (`rem` prime) . fromIntegral
142
143 -- | A sample hash function for 'String's.  The implementation is:
144 --
145 -- >    hashString = fromIntegral . foldr f 0
146 -- >      where f c m = ord c + (m * 128) `rem` 1500007
147 --
148 -- which seems to give reasonable results.
149 --
150 hashString :: String -> Int32
151 hashString = fromIntegral . foldr f 0
152   where f c m = ord c + (m * 128) `rem` fromIntegral prime
153
154 -- | A prime larger than the maximum hash table size
155 prime = 1500007 :: Int32
156
157 -- -----------------------------------------------------------------------------
158 -- Parameters
159
160 sEGMENT_SIZE  = 1024  :: Int32  -- Size of a single hash table segment
161 sEGMENT_SHIFT = 10    :: Int  -- derived
162 sEGMENT_MASK  = 0x3ff :: Int32  -- derived
163
164 dIR_SIZE = 1024  :: Int32  -- Size of the segment directory
165         -- Maximum hash table size is sEGMENT_SIZE * dIR_SIZE
166
167 hLOAD = 4 :: Int32 -- Maximum average load of a single hash bucket
168
169 -- -----------------------------------------------------------------------------
170 -- Creating a new hash table
171
172 -- | Creates a new hash table
173 new
174   :: (key -> key -> Bool)    -- ^ An equality comparison on keys
175   -> (key -> Int32)          -- ^ A hash function on keys
176   -> IO (HashTable key val)  -- ^ Returns: an empty hash table
177
178 new cmp hash_fn = do
179   -- make a new hash table with a single, empty, segment
180   dir     <- newIOArray (0,dIR_SIZE) undefined
181   segment <- newIOArray (0,sEGMENT_SIZE-1) []
182   myWriteArray dir 0 segment
183
184   let
185     split  = 0
186     max    = sEGMENT_SIZE
187     mask1  = (sEGMENT_SIZE - 1)
188     mask2  = (2 * sEGMENT_SIZE - 1)
189     kcount = 0
190     bcount = sEGMENT_SIZE
191
192     ht = HT {  dir=dir, split=split, max_bucket=max, mask1=mask1, mask2=mask2,
193                kcount=kcount, bcount=bcount, hash_fn=hash_fn, cmp=cmp
194           }
195   
196   table <- newIORef ht
197   return (HashTable table)
198
199 -- -----------------------------------------------------------------------------
200 -- Inserting a key\/value pair into the hash table
201
202 -- | Inserts an key\/value mapping into the hash table.
203 insert :: HashTable key val -> key -> val -> IO ()
204
205 insert (HashTable ref) key val = do
206   table@HT{ kcount=k, bcount=b, dir=dir } <- readIORef ref
207   let table1 = table{ kcount = k+1 }
208   table2 <-
209         if (k > hLOAD * b)
210            then expandHashTable table1
211            else return table1
212   writeIORef ref table2
213   (segment_index,segment_offset) <- tableLocation table key
214   segment <- myReadArray dir segment_index
215   bucket <- myReadArray segment segment_offset
216   myWriteArray segment segment_offset ((key,val):bucket)
217   return ()
218
219 bucketIndex :: HT key val -> key -> IO Int32
220 bucketIndex HT{ hash_fn=hash_fn,
221                 split=split,
222                 mask1=mask1,
223                 mask2=mask2 } key = do
224   let
225     h = fromIntegral (hash_fn key)
226     small_bucket = h .&. mask1
227     large_bucket = h .&. mask2
228   --
229   if small_bucket < split
230         then return large_bucket
231         else return small_bucket
232
233 tableLocation :: HT key val -> key -> IO (Int32,Int32)
234 tableLocation table key = do
235   bucket_index <- bucketIndex table key
236   let
237     segment_index  = bucket_index `shiftR` sEGMENT_SHIFT
238     segment_offset = bucket_index .&. sEGMENT_MASK
239   --
240   return (segment_index,segment_offset)
241
242 expandHashTable :: HT key val -> IO (HT key val)
243 expandHashTable
244       table@HT{ dir=dir,
245                 split=split,
246                 max_bucket=max,
247                 mask2=mask2 } = do
248   let
249       oldsegment = split `shiftR` sEGMENT_SHIFT
250       oldindex   = split .&. sEGMENT_MASK
251
252       newbucket  = max + split
253       newsegment = newbucket `shiftR` sEGMENT_SHIFT
254       newindex   = newbucket .&. sEGMENT_MASK
255   --
256   when (newindex == 0) $
257         do segment <- newIOArray (0,sEGMENT_SIZE-1) []
258            myWriteArray dir newsegment segment
259   --
260   let table' =
261         if (split+1) < max
262             then table{ split = split+1 }
263                 -- we've expanded all the buckets in this table, so start from
264                 -- the beginning again.
265             else table{ split = 0,
266                         max_bucket = max * 2,
267                         mask1 = mask2,
268                         mask2 = mask2 `shiftL` 1 .|. 1 }
269   let
270     split_bucket old new [] = do
271         segment <- myReadArray dir oldsegment
272         myWriteArray segment oldindex old
273         segment <- myReadArray dir newsegment
274         myWriteArray segment newindex new
275     split_bucket old new ((k,v):xs) = do
276         h <- bucketIndex table' k
277         if h == newbucket
278                 then split_bucket old ((k,v):new) xs
279                 else split_bucket ((k,v):old) new xs
280   --
281   segment <- myReadArray dir oldsegment
282   bucket <- myReadArray segment oldindex
283   split_bucket [] [] bucket
284   return table'
285
286 -- -----------------------------------------------------------------------------
287 -- Deleting a mapping from the hash table
288
289 -- | Remove an entry from the hash table.
290 delete :: HashTable key val -> key -> IO ()
291
292 delete (HashTable ref) key = do
293   table@HT{ dir=dir, cmp=cmp } <- readIORef ref
294   (segment_index,segment_offset) <- tableLocation table key
295   segment <- myReadArray dir segment_index
296   bucket <- myReadArray segment segment_offset
297   myWriteArray segment segment_offset (filter (not.(key `cmp`).fst) bucket)
298   return ()
299
300 -- -----------------------------------------------------------------------------
301 -- Looking up an entry in the hash table
302
303 -- | Looks up the value of a key in the hash table.
304 lookup :: HashTable key val -> key -> IO (Maybe val)
305
306 lookup (HashTable ref) key = do
307   table@HT{ dir=dir, cmp=cmp } <- readIORef ref
308   (segment_index,segment_offset) <- tableLocation table key
309   segment <- myReadArray dir segment_index
310   bucket <- myReadArray segment segment_offset
311   case [ val | (key',val) <- bucket, cmp key key' ] of
312         [] -> return Nothing
313         (v:_) -> return (Just v)
314
315 -- -----------------------------------------------------------------------------
316 -- Converting to/from lists
317
318 -- | Convert a list of key\/value pairs into a hash table.  Equality on keys
319 -- is taken from the Eq instance for the key type.
320 --
321 fromList :: Eq key => (key -> Int32) -> [(key,val)] -> IO (HashTable key val)
322 fromList hash_fn list = do
323   table <- new (==) hash_fn
324   sequence_ [ insert table k v | (k,v) <- list ]
325   return table
326
327 -- | Converts a hash table to a list of key\/value pairs.
328 --
329 toList :: HashTable key val -> IO [(key,val)]
330 toList (HashTable ref) = do
331   HT{ dir=dir, max_bucket=max, split=split } <- readIORef ref
332   --
333   let
334     max_segment = (max + split - 1) `quot` sEGMENT_SIZE
335   --
336   segments <- mapM (segmentContents dir) [0 .. max_segment]
337   return (concat segments)
338  where
339    segmentContents dir seg_index = do
340      segment <- myReadArray dir seg_index
341      bs <- mapM (myReadArray segment) [0 .. sEGMENT_SIZE-1]
342      return (concat bs)
343
344 -- -----------------------------------------------------------------------------
345 -- Diagnostics
346
347 -- | This function is useful for determining whether your hash function
348 -- is working well for your data set.  It returns the longest chain
349 -- of key\/value pairs in the hash table for which all the keys hash to
350 -- the same bucket.  If this chain is particularly long (say, longer
351 -- than 10 elements), then it might be a good idea to try a different
352 -- hash function.
353 --
354 longestChain :: HashTable key val -> IO [(key,val)]
355 longestChain (HashTable ref) = do
356   HT{ dir=dir, max_bucket=max, split=split } <- readIORef ref
357   --
358   let
359     max_segment = (max + split - 1) `quot` sEGMENT_SIZE
360   --
361   --trace ("maxChainLength: max = " ++ show max ++ ", split = " ++ show split ++ ", max_segment = " ++ show max_segment) $ do
362   segments <- mapM (segmentMaxChainLength dir) [0 .. max_segment]
363   return (maximumBy lengthCmp segments)
364  where
365    segmentMaxChainLength dir seg_index = do
366      segment <- myReadArray dir seg_index
367      bs <- mapM (myReadArray segment) [0 .. sEGMENT_SIZE-1]
368      return (maximumBy lengthCmp bs)
369
370    lengthCmp x y = length x `compare` length y