[project @ 2003-04-23 14:29:51 by malcolm]
[ghc-base.git] / Data / HashTable.hs
1 {-# OPTIONS -fno-implicit-prelude #-}
2
3 -----------------------------------------------------------------------------
4 -- |
5 -- Module      :  Data.HashTable
6 -- Copyright   :  (c) The University of Glasgow 2003
7 -- License     :  BSD-style (see the file libraries/base/LICENSE)
8 --
9 -- Maintainer  :  libraries@haskell.org
10 -- Stability   :  provisional
11 -- Portability :  portable
12 --
13 -- An implementation of extensible hash tables, as described in
14 -- Per-Ake Larson, /Dynamic Hash Tables/, CACM 31(4), April 1988,
15 -- pp. 446--457.  The implementation is also derived from the one
16 -- in GHC's runtime system (@ghc\/rts\/Hash.{c,h}@).
17 --
18 -----------------------------------------------------------------------------
19
20 module Data.HashTable (
21         -- * Basic hash table operations
22         HashTable, new, insert, delete, lookup,
23         -- * Converting to and from lists
24         fromList, toList,
25         -- * Hash functions
26         -- $hash_functions
27         hashInt, hashString,
28         prime,
29         -- * Diagnostics
30         longestChain
31  ) where
32
33 -- This module is imported by Data.Dynamic, which is pretty low down in the
34 -- module hierarchy, so don't import "high-level" modules
35
36 #ifdef __GLASGOW_HASKELL__
37 import GHC.Base
38 #else
39 import Prelude  hiding  ( lookup )
40 #endif
41 import Data.Tuple       ( fst )
42 import Data.Bits
43 import Data.Maybe
44 import Data.List        ( maximumBy, filter, length, concat )
45 import Data.Int         ( Int32 )
46
47 #if defined(__GLASGOW_HASKELL__)
48 import GHC.Num
49 import GHC.Real         ( Integral(..), fromIntegral )
50
51 import GHC.IOBase       ( IO, IOArray, newIOArray, readIOArray, writeIOArray,
52                           unsafeReadIOArray, unsafeWriteIOArray,
53                           IORef, newIORef, readIORef, writeIORef )
54 import GHC.Err          ( undefined )
55 #else
56 import Data.Char        ( ord )
57 import Data.IORef       ( IORef, newIORef, readIORef, writeIORef )
58 #  if defined(__HUGS__)
59 import Hugs.IOArray     ( IOArray, newIOArray, readIOArray, writeIOArray,
60                           unsafeReadIOArray, unsafeWriteIOArray )
61 #  elif defined(__NHC__)
62 import NHC.IOExtras     ( IOArray, newIOArray, readIOArray, writeIOArray)
63 #  endif
64 #endif
65 import Control.Monad    ( when, mapM, sequence_ )
66
67
68 -----------------------------------------------------------------------
69 myReadArray  :: IOArray Int32 a -> Int32 -> IO a
70 myWriteArray :: IOArray Int32 a -> Int32 -> a -> IO ()
71 #if defined(DEBUG) || defined(__NHC__)
72 myReadArray  = readIOArray
73 myWriteArray = writeIOArray
74 #else
75 myReadArray arr i = unsafeReadIOArray arr (fromIntegral i)
76 myWriteArray arr i x = unsafeWriteIOArray arr (fromIntegral i) x
77 #endif
78
79 -- | A hash table mapping keys of type @key@ to values of type @val@.
80 --
81 -- The implementation will grow the hash table as necessary, trying to
82 -- maintain a reasonable average load per bucket in the table.
83 --
84 newtype HashTable key val = HashTable (IORef (HT key val))
85 -- TODO: the IORef should really be an MVar.
86
87 data HT key val
88   = HT {
89         split  :: !Int32, -- Next bucket to split when expanding
90         max_bucket :: !Int32, -- Max bucket of smaller table
91         mask1  :: !Int32, -- Mask for doing the mod of h_1 (smaller table)
92         mask2  :: !Int32, -- Mask for doing the mod of h_2 (larger table)
93         kcount :: !Int32, -- Number of keys
94         bcount :: !Int32, -- Number of buckets
95         dir    :: !(IOArray Int32 (IOArray Int32 [(key,val)])),
96         hash_fn :: key -> Int32,
97         cmp    :: key -> key -> Bool
98    }
99
100 {-
101 ALTERNATIVE IMPLEMENTATION:
102
103 This works out slightly slower, because there's a tradeoff between
104 allocating a complete new HT structure each time a modification is
105 made (in the version above), and allocating new Int32s each time one
106 of them is modified, as below.  Using FastMutInt instead of IORef
107 Int32 helps, but yields an implementation which has about the same
108 performance as the version above (and is more complex).
109
110 data HashTable key val
111   = HashTable {
112         split  :: !(IORef Int32), -- Next bucket to split when expanding
113         max_bucket :: !(IORef Int32), -- Max bucket of smaller table
114         mask1  :: !(IORef Int32), -- Mask for doing the mod of h_1 (smaller table)
115         mask2  :: !(IORef Int32), -- Mask for doing the mod of h_2 (larger table)
116         kcount :: !(IORef Int32), -- Number of keys
117         bcount :: !(IORef Int32), -- Number of buckets
118         dir    :: !(IOArray Int32 (IOArray Int32 [(key,val)])),
119         hash_fn :: key -> Int32,
120         cmp    :: key -> key -> Bool
121    }
122 -}
123
124
125 -- -----------------------------------------------------------------------------
126 -- Sample hash functions
127
128 -- $hash_functions
129 --
130 -- This implementation of hash tables uses the low-order /n/ bits of the hash
131 -- value for a key, where /n/ varies as the hash table grows.  A good hash
132 -- function therefore will give an even distribution regardless of /n/.
133 --
134 -- If your keyspace is integrals such that the low-order bits between
135 -- keys are highly variable, then you could get away with using 'id'
136 -- as the hash function.
137 --
138 -- We provide some sample hash functions for 'Int' and 'String' below.
139
140 -- | A sample hash function for 'Int', implemented as simply @(x `mod` P)@
141 -- where P is a suitable prime (currently 1500007).  Should give
142 -- reasonable results for most distributions of 'Int' values, except
143 -- when the keys are all multiples of the prime!
144 --
145 hashInt :: Int -> Int32
146 hashInt = (`rem` prime) . fromIntegral
147
148 -- | A sample hash function for 'String's.  The implementation is:
149 --
150 -- >    hashString = fromIntegral . foldr f 0
151 -- >      where f c m = ord c + (m * 128) `rem` 1500007
152 --
153 -- which seems to give reasonable results.
154 --
155 hashString :: String -> Int32
156 hashString = fromIntegral . foldr f 0
157   where f c m = ord c + (m * 128) `rem` fromIntegral prime
158
159 -- | A prime larger than the maximum hash table size
160 prime = 1500007 :: Int32
161
162 -- -----------------------------------------------------------------------------
163 -- Parameters
164
165 sEGMENT_SIZE  = 1024  :: Int32  -- Size of a single hash table segment
166 sEGMENT_SHIFT = 10    :: Int  -- derived
167 sEGMENT_MASK  = 0x3ff :: Int32  -- derived
168
169 dIR_SIZE = 1024  :: Int32  -- Size of the segment directory
170         -- Maximum hash table size is sEGMENT_SIZE * dIR_SIZE
171
172 hLOAD = 4 :: Int32 -- Maximum average load of a single hash bucket
173
174 -- -----------------------------------------------------------------------------
175 -- Creating a new hash table
176
177 -- | Creates a new hash table
178 new
179   :: (key -> key -> Bool)    -- ^ An equality comparison on keys
180   -> (key -> Int32)          -- ^ A hash function on keys
181   -> IO (HashTable key val)  -- ^ Returns: an empty hash table
182
183 new cmp hash_fn = do
184   -- make a new hash table with a single, empty, segment
185   dir     <- newIOArray (0,dIR_SIZE) undefined
186   segment <- newIOArray (0,sEGMENT_SIZE-1) []
187   myWriteArray dir 0 segment
188
189   let
190     split  = 0
191     max    = sEGMENT_SIZE
192     mask1  = (sEGMENT_SIZE - 1)
193     mask2  = (2 * sEGMENT_SIZE - 1)
194     kcount = 0
195     bcount = sEGMENT_SIZE
196
197     ht = HT {  dir=dir, split=split, max_bucket=max, mask1=mask1, mask2=mask2,
198                kcount=kcount, bcount=bcount, hash_fn=hash_fn, cmp=cmp
199           }
200   
201   table <- newIORef ht
202   return (HashTable table)
203
204 -- -----------------------------------------------------------------------------
205 -- Inserting a key\/value pair into the hash table
206
207 -- | Inserts an key\/value mapping into the hash table.
208 insert :: HashTable key val -> key -> val -> IO ()
209
210 insert (HashTable ref) key val = do
211   table@HT{ kcount=k, bcount=b, dir=dir } <- readIORef ref
212   let table1 = table{ kcount = k+1 }
213   table2 <-
214         if (k > hLOAD * b)
215            then expandHashTable table1
216            else return table1
217   writeIORef ref table2
218   (segment_index,segment_offset) <- tableLocation table key
219   segment <- myReadArray dir segment_index
220   bucket <- myReadArray segment segment_offset
221   myWriteArray segment segment_offset ((key,val):bucket)
222   return ()
223
224 bucketIndex :: HT key val -> key -> IO Int32
225 bucketIndex HT{ hash_fn=hash_fn,
226                 split=split,
227                 mask1=mask1,
228                 mask2=mask2 } key = do
229   let
230     h = fromIntegral (hash_fn key)
231     small_bucket = h .&. mask1
232     large_bucket = h .&. mask2
233   --
234   if small_bucket < split
235         then return large_bucket
236         else return small_bucket
237
238 tableLocation :: HT key val -> key -> IO (Int32,Int32)
239 tableLocation table key = do
240   bucket_index <- bucketIndex table key
241   let
242     segment_index  = bucket_index `shiftR` sEGMENT_SHIFT
243     segment_offset = bucket_index .&. sEGMENT_MASK
244   --
245   return (segment_index,segment_offset)
246
247 expandHashTable :: HT key val -> IO (HT key val)
248 expandHashTable
249       table@HT{ dir=dir,
250                 split=split,
251                 max_bucket=max,
252                 mask2=mask2 } = do
253   let
254       oldsegment = split `shiftR` sEGMENT_SHIFT
255       oldindex   = split .&. sEGMENT_MASK
256
257       newbucket  = max + split
258       newsegment = newbucket `shiftR` sEGMENT_SHIFT
259       newindex   = newbucket .&. sEGMENT_MASK
260   --
261   when (newindex == 0) $
262         do segment <- newIOArray (0,sEGMENT_SIZE-1) []
263            myWriteArray dir newsegment segment
264   --
265   let table' =
266         if (split+1) < max
267             then table{ split = split+1 }
268                 -- we've expanded all the buckets in this table, so start from
269                 -- the beginning again.
270             else table{ split = 0,
271                         max_bucket = max * 2,
272                         mask1 = mask2,
273                         mask2 = mask2 `shiftL` 1 .|. 1 }
274   let
275     split_bucket old new [] = do
276         segment <- myReadArray dir oldsegment
277         myWriteArray segment oldindex old
278         segment <- myReadArray dir newsegment
279         myWriteArray segment newindex new
280     split_bucket old new ((k,v):xs) = do
281         h <- bucketIndex table' k
282         if h == newbucket
283                 then split_bucket old ((k,v):new) xs
284                 else split_bucket ((k,v):old) new xs
285   --
286   segment <- myReadArray dir oldsegment
287   bucket <- myReadArray segment oldindex
288   split_bucket [] [] bucket
289   return table'
290
291 -- -----------------------------------------------------------------------------
292 -- Deleting a mapping from the hash table
293
294 -- | Remove an entry from the hash table.
295 delete :: HashTable key val -> key -> IO ()
296
297 delete (HashTable ref) key = do
298   table@HT{ dir=dir, cmp=cmp } <- readIORef ref
299   (segment_index,segment_offset) <- tableLocation table key
300   segment <- myReadArray dir segment_index
301   bucket <- myReadArray segment segment_offset
302   myWriteArray segment segment_offset (filter (not.(key `cmp`).fst) bucket)
303   return ()
304
305 -- -----------------------------------------------------------------------------
306 -- Looking up an entry in the hash table
307
308 -- | Looks up the value of a key in the hash table.
309 lookup :: HashTable key val -> key -> IO (Maybe val)
310
311 lookup (HashTable ref) key = do
312   table@HT{ dir=dir, cmp=cmp } <- readIORef ref
313   (segment_index,segment_offset) <- tableLocation table key
314   segment <- myReadArray dir segment_index
315   bucket <- myReadArray segment segment_offset
316   case [ val | (key',val) <- bucket, cmp key key' ] of
317         [] -> return Nothing
318         (v:_) -> return (Just v)
319
320 -- -----------------------------------------------------------------------------
321 -- Converting to/from lists
322
323 -- | Convert a list of key\/value pairs into a hash table.  Equality on keys
324 -- is taken from the Eq instance for the key type.
325 --
326 fromList :: Eq key => (key -> Int32) -> [(key,val)] -> IO (HashTable key val)
327 fromList hash_fn list = do
328   table <- new (==) hash_fn
329   sequence_ [ insert table k v | (k,v) <- list ]
330   return table
331
332 -- | Converts a hash table to a list of key\/value pairs.
333 --
334 toList :: HashTable key val -> IO [(key,val)]
335 toList (HashTable ref) = do
336   HT{ dir=dir, max_bucket=max, split=split } <- readIORef ref
337   --
338   let
339     max_segment = (max + split - 1) `quot` sEGMENT_SIZE
340   --
341   segments <- mapM (segmentContents dir) [0 .. max_segment]
342   return (concat segments)
343  where
344    segmentContents dir seg_index = do
345      segment <- myReadArray dir seg_index
346      bs <- mapM (myReadArray segment) [0 .. sEGMENT_SIZE-1]
347      return (concat bs)
348
349 -- -----------------------------------------------------------------------------
350 -- Diagnostics
351
352 -- | This function is useful for determining whether your hash function
353 -- is working well for your data set.  It returns the longest chain
354 -- of key\/value pairs in the hash table for which all the keys hash to
355 -- the same bucket.  If this chain is particularly long (say, longer
356 -- than 10 elements), then it might be a good idea to try a different
357 -- hash function.
358 --
359 longestChain :: HashTable key val -> IO [(key,val)]
360 longestChain (HashTable ref) = do
361   HT{ dir=dir, max_bucket=max, split=split } <- readIORef ref
362   --
363   let
364     max_segment = (max + split - 1) `quot` sEGMENT_SIZE
365   --
366   --trace ("maxChainLength: max = " ++ show max ++ ", split = " ++ show split ++ ", max_segment = " ++ show max_segment) $ do
367   segments <- mapM (segmentMaxChainLength dir) [0 .. max_segment]
368   return (maximumBy lengthCmp segments)
369  where
370    segmentMaxChainLength dir seg_index = do
371      segment <- myReadArray dir seg_index
372      bs <- mapM (myReadArray segment) [0 .. sEGMENT_SIZE-1]
373      return (maximumBy lengthCmp bs)
374
375    lengthCmp x y = length x `compare` length y