[project @ 2006-01-02 19:38:01 by jpbernardy]
[haskell-directory.git] / Data / Map.hs
index 96dc045..beddb7b 100644 (file)
 --    * J. Nievergelt and E.M. Reingold,
 --     \"/Binary search trees of bounded balance/\",
 --     SIAM journal of computing 2(1), March 1973.
+--
+-- Note that the implementation is /left-biased/ -- the elements of a
+-- first argument are always preferred to the second, for example in
+-- 'union' or 'insert'.
 -----------------------------------------------------------------------------
 
 module Data.Map  ( 
@@ -249,6 +253,16 @@ lookup' k t
                GT -> lookup' k r
                EQ -> Just x       
 
+lookupAssoc :: Ord k => k -> Map k a -> Maybe (k,a)
+lookupAssoc  k t
+  = case t of
+      Tip -> Nothing
+      Bin sz kx x l r
+          -> case compare k kx of
+               LT -> lookupAssoc k l
+               GT -> lookupAssoc k r
+               EQ -> Just (kx,x)
+
 -- | /O(log n)/. Is the key a member of the map?
 member :: Ord k => k -> Map k a -> Bool
 member k m
@@ -308,7 +322,7 @@ insert kx x t
 -- @'insertWith' f key value mp@ 
 -- will insert the pair (key, value) into @mp@ if key does
 -- not exist in the map. If the key does exist, the function will
--- insert @f new_value old_value@.
+-- insert the pair @(key, f new_value old_value)@.
 insertWith :: Ord k => (a -> a -> a) -> k -> a -> Map k a -> Map k a
 insertWith f k x m          
   = insertWithKey (\k x y -> f x y) k x m
@@ -317,7 +331,8 @@ insertWith f k x m
 -- @'insertWithKey' f key value mp@ 
 -- will insert the pair (key, value) into @mp@ if key does
 -- not exist in the map. If the key does exist, the function will
--- insert @f key new_value old_value@.
+-- insert the pair @(key,f key new_value old_value)@.
+-- Note that the key passed to f is the same key passed to 'insertWithKey'.
 insertWithKey :: Ord k => (k -> a -> a -> a) -> k -> a -> Map k a -> Map k a
 insertWithKey f kx x t
   = case t of
@@ -326,7 +341,7 @@ insertWithKey f kx x t
           -> case compare kx ky of
                LT -> balance ky y (insertWithKey f kx x l) r
                GT -> balance ky y l (insertWithKey f kx x r)
-               EQ -> Bin sy ky (f ky x y) l r
+               EQ -> Bin sy kx (f kx x y) l r
 
 -- | /O(log n)/. The expression (@'insertLookupWithKey' f k x map@)
 -- is a pair where the first element is equal to (@'lookup' k map@)
@@ -339,7 +354,7 @@ insertLookupWithKey f kx x t
           -> case compare kx ky of
                LT -> let (found,l') = insertLookupWithKey f kx x l in (found,balance ky y l' r)
                GT -> let (found,r') = insertLookupWithKey f kx x r in (found,balance ky y l r')
-               EQ -> (Just y, Bin sy ky (f ky x y) l r)
+               EQ -> (Just y, Bin sy kx (f kx x y) l r)
 
 {--------------------------------------------------------------------
   Deletion
@@ -543,13 +558,11 @@ unionsWith f ts
 -- It prefers @t1@ when duplicate keys are encountered,
 -- i.e. (@'union' == 'unionWith' 'const'@).
 -- The implementation uses the efficient /hedge-union/ algorithm.
--- Hedge-union is more efficient on (bigset `union` smallset)?
+-- Hedge-union is more efficient on (bigset `union` smallset)
 union :: Ord k => Map k a -> Map k a -> Map k a
 union Tip t2  = t2
 union t1 Tip  = t1
-union t1 t2
-   | size t1 >= size t2  = hedgeUnionL (const LT) (const GT) t1 t2
-   | otherwise           = hedgeUnionR (const LT) (const GT) t2 t1
+union t1 t2 = hedgeUnionL (const LT) (const GT) t1 t2
 
 -- left-biased hedge union
 hedgeUnionL cmplo cmphi t1 Tip 
@@ -576,7 +589,7 @@ hedgeUnionR cmplo cmphi (Bin _ kx x l r) t2
     (found,gt)  = trimLookupLo kx cmphi t2
     newx        = case found of
                     Nothing -> x
-                    Just y  -> y
+                    Just (_,y) -> y
 
 {--------------------------------------------------------------------
   Union with a combining function
@@ -592,11 +605,7 @@ unionWith f m1 m2
 unionWithKey :: Ord k => (k -> a -> a -> a) -> Map k a -> Map k a -> Map k a
 unionWithKey f Tip t2  = t2
 unionWithKey f t1 Tip  = t1
-unionWithKey f t1 t2
-  | size t1 >= size t2  = hedgeUnionWithKey f (const LT) (const GT) t1 t2
-  | otherwise           = hedgeUnionWithKey flipf (const LT) (const GT) t2 t1
-  where
-    flipf k x y   = f k y x
+unionWithKey f t1 t2 = hedgeUnionWithKey f (const LT) (const GT) t1 t2
 
 hedgeUnionWithKey f cmplo cmphi t1 Tip 
   = t1
@@ -611,7 +620,7 @@ hedgeUnionWithKey f cmplo cmphi (Bin _ kx x l r) t2
     (found,gt)  = trimLookupLo kx cmphi t2
     newx        = case found of
                     Nothing -> x
-                    Just y  -> f kx x y
+                    Just (_,y) -> f kx x y
 
 {--------------------------------------------------------------------
   Difference
@@ -656,9 +665,10 @@ hedgeDiffWithKey f cmplo cmphi (Bin _ kx x l r) Tip
 hedgeDiffWithKey f cmplo cmphi t (Bin _ kx x l r) 
   = case found of
       Nothing -> merge tl tr
-      Just y  -> case f kx y x of
-                   Nothing -> merge tl tr
-                   Just z  -> join kx z tl tr
+      Just (ky,y) -> 
+          case f ky y x of
+            Nothing -> merge tl tr
+            Just z  -> join ky z tl tr
   where
     cmpkx k     = compare kx k   
     lt          = trim cmplo cmpkx t
@@ -684,25 +694,40 @@ intersectionWith f m1 m2
 
 -- | /O(n+m)/. Intersection with a combining function.
 -- Intersection is more efficient on (bigset `intersection` smallset)
+--intersectionWithKey :: Ord k => (k -> a -> b -> c) -> Map k a -> Map k b -> Map k c
+--intersectionWithKey f Tip t = Tip
+--intersectionWithKey f t Tip = Tip
+--intersectionWithKey f t1 t2 = intersectWithKey f t1 t2
+--
+--intersectWithKey f Tip t = Tip
+--intersectWithKey f t Tip = Tip
+--intersectWithKey f t (Bin _ kx x l r)
+--  = case found of
+--      Nothing -> merge tl tr
+--      Just y  -> join kx (f kx y x) tl tr
+--  where
+--    (lt,found,gt) = splitLookup kx t
+--    tl            = intersectWithKey f lt l
+--    tr            = intersectWithKey f gt r
+
+
 intersectionWithKey :: Ord k => (k -> a -> b -> c) -> Map k a -> Map k b -> Map k c
 intersectionWithKey f Tip t = Tip
 intersectionWithKey f t Tip = Tip
-intersectionWithKey f t1 t2
-  | size t1 >= size t2  = intersectWithKey f t1 t2
-  | otherwise           = intersectWithKey flipf t2 t1
-  where
-    flipf k x y   = f k y x
-
-intersectWithKey f Tip t = Tip
-intersectWithKey f t Tip = Tip
-intersectWithKey f t (Bin _ kx x l r)
-  = case found of
+intersectionWithKey f t1@(Bin s1 k1 x1 l1 r1) t2@(Bin s2 k2 x2 l2 r2) =
+   if s1 >= s2 then
+      let (lt,found,gt) = splitLookupWithKey k2 t1
+          tl            = intersectionWithKey f lt l2
+          tr            = intersectionWithKey f gt r2
+      in case found of
+      Just (k,x) -> join k (f k x x2) tl tr
+      Nothing -> merge tl tr
+   else let (lt,found,gt) = splitLookup k1 t2
+            tl            = intersectionWithKey f l1 lt
+            tr            = intersectionWithKey f r1 gt
+      in case found of
+      Just x -> join k1 (f k1 x1 x) tl tr
       Nothing -> merge tl tr
-      Just y  -> join kx (f kx y x) tl tr
-  where
-    (lt,found,gt) = splitLookup kx t
-    tl            = intersectWithKey f lt l
-    tr            = intersectWithKey f gt r
 
 
 
@@ -1083,15 +1108,15 @@ trim cmplo cmphi t@(Bin sx kx x l r)
               le -> trim cmplo cmphi l
       ge -> trim cmplo cmphi r
               
-trimLookupLo :: Ord k => k -> (k -> Ordering) -> Map k a -> (Maybe a, Map k a)
+trimLookupLo :: Ord k => k -> (k -> Ordering) -> Map k a -> (Maybe (k,a), Map k a)
 trimLookupLo lo cmphi Tip = (Nothing,Tip)
 trimLookupLo lo cmphi t@(Bin sx kx x l r)
   = case compare lo kx of
       LT -> case cmphi kx of
-              GT -> (lookup lo t, t)
+              GT -> (lookupAssoc lo t, t)
               le -> trimLookupLo lo cmphi l
       GT -> trimLookupLo lo cmphi r
-      EQ -> (Just x,trim (compare lo) cmphi r)
+      EQ -> (Just (kx,x),trim (compare lo) cmphi r)
 
 
 {--------------------------------------------------------------------
@@ -1137,6 +1162,22 @@ splitLookup k (Bin sx kx x l r)
       GT -> let (lt,z,gt) = splitLookup k r in (join kx x l lt,z,gt)
       EQ -> (l,Just x,r)
 
+-- | /O(log n)/.
+splitLookupWithKey :: Ord k => k -> Map k a -> (Map k a,Maybe (k,a),Map k a)
+splitLookupWithKey k Tip = (Tip,Nothing,Tip)
+splitLookupWithKey k (Bin sx kx x l r)
+  = case compare k kx of
+      LT -> let (lt,z,gt) = splitLookupWithKey k l in (lt,z,join kx x gt r)
+      GT -> let (lt,z,gt) = splitLookupWithKey k r in (join kx x l lt,z,gt)
+      EQ -> (l,Just (kx, x),r)
+
+-- | /O(log n)/. Performs a 'split' but also returns whether the pivot
+-- element was found in the original set.
+splitMember :: Ord k => k -> Map k a -> (Map k a,Bool,Map k a)
+splitMember x t = let (l,m,r) = splitLookup x t in
+     (l,maybe False (const True) m,r)
+
+
 {--------------------------------------------------------------------
   Utility functions that maintain the balance properties of the tree.
   All constructors assume that all values in [l] < [k] and all values