[project @ 2006-01-06 15:51:23 by simonpj]
[haskell-directory.git] / Data / Map.hs
index 37a9695..d88ceb5 100644 (file)
 --    * J. Nievergelt and E.M. Reingold,
 --     \"/Binary search trees of bounded balance/\",
 --     SIAM journal of computing 2(1), March 1973.
+--
+-- Note that the implementation is /left-biased/ -- the elements of a
+-- first argument are always preferred to the second, for example in
+-- 'union' or 'insert'.
 -----------------------------------------------------------------------------
 
 module Data.Map  ( 
@@ -150,7 +154,11 @@ module Data.Map  (
 import Prelude hiding (lookup,map,filter,foldr,foldl,null)
 import qualified Data.Set as Set
 import qualified Data.List as List
+import Data.Monoid (Monoid(..))
 import Data.Typeable
+import Control.Applicative (Applicative(..))
+import Data.Traversable (Traversable(traverse))
+import Data.Foldable (Foldable(foldMap))
 
 {-
 -- for quick check
@@ -189,6 +197,11 @@ data Map k a  = Tip
 
 type Size     = Int
 
+instance (Ord k) => Monoid (Map k v) where
+    mempty  = empty
+    mappend = union
+    mconcat = unions
+
 #if __GLASGOW_HASKELL__
 
 {--------------------------------------------------------------------
@@ -203,6 +216,7 @@ instance (Data k, Data a, Ord k) => Data (Map k a) where
   toConstr _     = error "toConstr"
   gunfold _ _    = error "gunfold"
   dataTypeOf _   = mkNorepType "Data.Map.Map"
+  dataCast2 f    = gcast2 f
 
 #endif
 
@@ -239,6 +253,16 @@ lookup' k t
                GT -> lookup' k r
                EQ -> Just x       
 
+lookupAssoc :: Ord k => k -> Map k a -> Maybe (k,a)
+lookupAssoc  k t
+  = case t of
+      Tip -> Nothing
+      Bin sz kx x l r
+          -> case compare k kx of
+               LT -> lookupAssoc k l
+               GT -> lookupAssoc k r
+               EQ -> Just (kx,x)
+
 -- | /O(log n)/. Is the key a member of the map?
 member :: Ord k => k -> Map k a -> Bool
 member k m
@@ -295,11 +319,20 @@ insert kx x t
                EQ -> Bin sz kx x l r
 
 -- | /O(log n)/. Insert with a combining function.
+-- @'insertWith' f key value mp@ 
+-- will insert the pair (key, value) into @mp@ if key does
+-- not exist in the map. If the key does exist, the function will
+-- insert the pair @(key, f new_value old_value)@.
 insertWith :: Ord k => (a -> a -> a) -> k -> a -> Map k a -> Map k a
 insertWith f k x m          
   = insertWithKey (\k x y -> f x y) k x m
 
 -- | /O(log n)/. Insert with a combining function.
+-- @'insertWithKey' f key value mp@ 
+-- will insert the pair (key, value) into @mp@ if key does
+-- not exist in the map. If the key does exist, the function will
+-- insert the pair @(key,f key new_value old_value)@.
+-- Note that the key passed to f is the same key passed to 'insertWithKey'.
 insertWithKey :: Ord k => (k -> a -> a -> a) -> k -> a -> Map k a -> Map k a
 insertWithKey f kx x t
   = case t of
@@ -308,7 +341,7 @@ insertWithKey f kx x t
           -> case compare kx ky of
                LT -> balance ky y (insertWithKey f kx x l) r
                GT -> balance ky y l (insertWithKey f kx x r)
-               EQ -> Bin sy ky (f ky x y) l r
+               EQ -> Bin sy kx (f kx x y) l r
 
 -- | /O(log n)/. The expression (@'insertLookupWithKey' f k x map@)
 -- is a pair where the first element is equal to (@'lookup' k map@)
@@ -321,7 +354,7 @@ insertLookupWithKey f kx x t
           -> case compare kx ky of
                LT -> let (found,l') = insertLookupWithKey f kx x l in (found,balance ky y l' r)
                GT -> let (found,r') = insertLookupWithKey f kx x r in (found,balance ky y l r')
-               EQ -> (Just y, Bin sy ky (f ky x y) l r)
+               EQ -> (Just y, Bin sy kx (f kx x y) l r)
 
 {--------------------------------------------------------------------
   Deletion
@@ -525,13 +558,11 @@ unionsWith f ts
 -- It prefers @t1@ when duplicate keys are encountered,
 -- i.e. (@'union' == 'unionWith' 'const'@).
 -- The implementation uses the efficient /hedge-union/ algorithm.
--- Hedge-union is more efficient on (bigset `union` smallset)?
+-- Hedge-union is more efficient on (bigset `union` smallset)
 union :: Ord k => Map k a -> Map k a -> Map k a
 union Tip t2  = t2
 union t1 Tip  = t1
-union t1 t2
-   | size t1 >= size t2  = hedgeUnionL (const LT) (const GT) t1 t2
-   | otherwise           = hedgeUnionR (const LT) (const GT) t2 t1
+union t1 t2 = hedgeUnionL (const LT) (const GT) t1 t2
 
 -- left-biased hedge union
 hedgeUnionL cmplo cmphi t1 Tip 
@@ -558,7 +589,7 @@ hedgeUnionR cmplo cmphi (Bin _ kx x l r) t2
     (found,gt)  = trimLookupLo kx cmphi t2
     newx        = case found of
                     Nothing -> x
-                    Just y  -> y
+                    Just (_,y) -> y
 
 {--------------------------------------------------------------------
   Union with a combining function
@@ -574,11 +605,7 @@ unionWith f m1 m2
 unionWithKey :: Ord k => (k -> a -> a -> a) -> Map k a -> Map k a -> Map k a
 unionWithKey f Tip t2  = t2
 unionWithKey f t1 Tip  = t1
-unionWithKey f t1 t2
-  | size t1 >= size t2  = hedgeUnionWithKey f (const LT) (const GT) t1 t2
-  | otherwise           = hedgeUnionWithKey flipf (const LT) (const GT) t2 t1
-  where
-    flipf k x y   = f k y x
+unionWithKey f t1 t2 = hedgeUnionWithKey f (const LT) (const GT) t1 t2
 
 hedgeUnionWithKey f cmplo cmphi t1 Tip 
   = t1
@@ -593,7 +620,7 @@ hedgeUnionWithKey f cmplo cmphi (Bin _ kx x l r) t2
     (found,gt)  = trimLookupLo kx cmphi t2
     newx        = case found of
                     Nothing -> x
-                    Just y  -> f kx x y
+                    Just (_,y) -> f kx x y
 
 {--------------------------------------------------------------------
   Difference
@@ -638,9 +665,10 @@ hedgeDiffWithKey f cmplo cmphi (Bin _ kx x l r) Tip
 hedgeDiffWithKey f cmplo cmphi t (Bin _ kx x l r) 
   = case found of
       Nothing -> merge tl tr
-      Just y  -> case f kx y x of
-                   Nothing -> merge tl tr
-                   Just z  -> join kx z tl tr
+      Just (ky,y) -> 
+          case f ky y x of
+            Nothing -> merge tl tr
+            Just z  -> join ky z tl tr
   where
     cmpkx k     = compare kx k   
     lt          = trim cmplo cmpkx t
@@ -666,25 +694,40 @@ intersectionWith f m1 m2
 
 -- | /O(n+m)/. Intersection with a combining function.
 -- Intersection is more efficient on (bigset `intersection` smallset)
+--intersectionWithKey :: Ord k => (k -> a -> b -> c) -> Map k a -> Map k b -> Map k c
+--intersectionWithKey f Tip t = Tip
+--intersectionWithKey f t Tip = Tip
+--intersectionWithKey f t1 t2 = intersectWithKey f t1 t2
+--
+--intersectWithKey f Tip t = Tip
+--intersectWithKey f t Tip = Tip
+--intersectWithKey f t (Bin _ kx x l r)
+--  = case found of
+--      Nothing -> merge tl tr
+--      Just y  -> join kx (f kx y x) tl tr
+--  where
+--    (lt,found,gt) = splitLookup kx t
+--    tl            = intersectWithKey f lt l
+--    tr            = intersectWithKey f gt r
+
+
 intersectionWithKey :: Ord k => (k -> a -> b -> c) -> Map k a -> Map k b -> Map k c
 intersectionWithKey f Tip t = Tip
 intersectionWithKey f t Tip = Tip
-intersectionWithKey f t1 t2
-  | size t1 >= size t2  = intersectWithKey f t1 t2
-  | otherwise           = intersectWithKey flipf t2 t1
-  where
-    flipf k x y   = f k y x
-
-intersectWithKey f Tip t = Tip
-intersectWithKey f t Tip = Tip
-intersectWithKey f t (Bin _ kx x l r)
-  = case found of
+intersectionWithKey f t1@(Bin s1 k1 x1 l1 r1) t2@(Bin s2 k2 x2 l2 r2) =
+   if s1 >= s2 then
+      let (lt,found,gt) = splitLookupWithKey k2 t1
+          tl            = intersectionWithKey f lt l2
+          tr            = intersectionWithKey f gt r2
+      in case found of
+      Just (k,x) -> join k (f k x x2) tl tr
+      Nothing -> merge tl tr
+   else let (lt,found,gt) = splitLookup k1 t2
+            tl            = intersectionWithKey f l1 lt
+            tr            = intersectionWithKey f r1 gt
+      in case found of
+      Just x -> join k1 (f k1 x1 x) tl tr
       Nothing -> merge tl tr
-      Just y  -> join kx (f kx y x) tl tr
-  where
-    (lt,found,gt) = splitLookup kx t
-    tl            = intersectWithKey f lt l
-    tr            = intersectWithKey f gt r
 
 
 
@@ -1065,15 +1108,15 @@ trim cmplo cmphi t@(Bin sx kx x l r)
               le -> trim cmplo cmphi l
       ge -> trim cmplo cmphi r
               
-trimLookupLo :: Ord k => k -> (k -> Ordering) -> Map k a -> (Maybe a, Map k a)
+trimLookupLo :: Ord k => k -> (k -> Ordering) -> Map k a -> (Maybe (k,a), Map k a)
 trimLookupLo lo cmphi Tip = (Nothing,Tip)
 trimLookupLo lo cmphi t@(Bin sx kx x l r)
   = case compare lo kx of
       LT -> case cmphi kx of
-              GT -> (lookup lo t, t)
+              GT -> (lookupAssoc lo t, t)
               le -> trimLookupLo lo cmphi l
       GT -> trimLookupLo lo cmphi r
-      EQ -> (Just x,trim (compare lo) cmphi r)
+      EQ -> (Just (kx,x),trim (compare lo) cmphi r)
 
 
 {--------------------------------------------------------------------
@@ -1119,6 +1162,22 @@ splitLookup k (Bin sx kx x l r)
       GT -> let (lt,z,gt) = splitLookup k r in (join kx x l lt,z,gt)
       EQ -> (l,Just x,r)
 
+-- | /O(log n)/.
+splitLookupWithKey :: Ord k => k -> Map k a -> (Map k a,Maybe (k,a),Map k a)
+splitLookupWithKey k Tip = (Tip,Nothing,Tip)
+splitLookupWithKey k (Bin sx kx x l r)
+  = case compare k kx of
+      LT -> let (lt,z,gt) = splitLookupWithKey k l in (lt,z,join kx x gt r)
+      GT -> let (lt,z,gt) = splitLookupWithKey k r in (join kx x l lt,z,gt)
+      EQ -> (l,Just (kx, x),r)
+
+-- | /O(log n)/. Performs a 'split' but also returns whether the pivot
+-- element was found in the original set.
+splitMember :: Ord k => k -> Map k a -> (Map k a,Bool,Map k a)
+splitMember x t = let (l,m,r) = splitLookup x t in
+     (l,maybe False (const True) m,r)
+
+
 {--------------------------------------------------------------------
   Utility functions that maintain the balance properties of the tree.
   All constructors assume that all values in [l] < [k] and all values
@@ -1305,6 +1364,16 @@ instance (Ord k, Ord v) => Ord (Map k v) where
 instance Functor (Map k) where
   fmap f m  = map f m
 
+instance Traversable (Map k) where
+  traverse f Tip = pure Tip
+  traverse f (Bin s k v l r)
+    = flip (Bin s k) <$> traverse f l <*> f v <*> traverse f r
+
+instance Foldable (Map k) where
+  foldMap _f Tip = mempty
+  foldMap f (Bin _s _k v l r)
+    = foldMap f l `mappend` f v `mappend` foldMap f r
+
 {--------------------------------------------------------------------
   Read
 --------------------------------------------------------------------}