Add array fusion versions of map, filter and foldl
[haskell-directory.git] / Data / ByteString.hs
index 76e84d5..86ec26a 100644 (file)
@@ -6,6 +6,11 @@
 --               (c) Simon Marlow 2005
 --               (c) Don Stewart 2005-2006
 --               (c) Bjorn Bringert 2006
+--
+-- Array fusion code:
+--               (c) 2001,2002 Manuel M T Chakravarty & Gabriele Keller
+--                      (c) 2006      Manuel M T Chakravarty & Roman Leshchinskiy
+--
 -- License     : BSD-style
 --
 -- Maintainer  : dons@cse.unsw.edu.au
@@ -18,8 +23,8 @@
 -- | A time and space-efficient implementation of byte vectors using
 -- packed Word8 arrays, suitable for high performance use, both in terms
 -- of large data quantities, or high speed requirements. Byte vectors
--- are encoded as Word8 arrays of bytes, held in a ForeignPtr, and can
--- be passed between C and Haskell with little effort.
+-- are encoded as strict Word8 arrays of bytes, held in a ForeignPtr,
+-- and can be passed between C and Haskell with little effort.
 --
 -- This module is intended to be imported @qualified@, to avoid name
 -- clashes with Prelude functions.  eg.
@@ -216,11 +221,15 @@ module Data.ByteString (
         hGet,                   -- :: Handle -> Int -> IO ByteString
         hPut,                   -- :: Handle -> ByteString -> IO ()
 
+        -- * Fusion utilities
 #if defined(__GLASGOW_HASKELL__)
-        -- * Miscellaneous
         unpackList, -- eek, otherwise it gets thrown away by the simplifier
 #endif
 
+        noAL, NoAL, loopArr, loopAcc, loopSndAcc,
+        loopU, mapEFL, filterEFL, foldEFL,
+        filterF, mapF
+
   ) where
 
 import qualified Prelude as P
@@ -330,19 +339,24 @@ instance Arbitrary PackedString where
 
 -- | /O(n)/ Equality on the 'ByteString' type.
 eq :: ByteString -> ByteString -> Bool
-eq a b = (compareBytes a b) == EQ
+eq a@(PS p s l) b@(PS p' s' l')
+    | l /= l'            = False    -- short cut on length
+    | p == p' && s == s' = True     -- short cut for the same string
+    | otherwise          = compareBytes a b == EQ
 {-# INLINE eq #-}
 
 -- | /O(n)/ 'compareBytes' provides an 'Ordering' for 'ByteStrings' supporting slices. 
 compareBytes :: ByteString -> ByteString -> Ordering
-compareBytes (PS _ _ 0) (PS _ _ 0)       = EQ    -- short cut for empty strings
-compareBytes (PS x1 s1 l1) (PS x2 s2 l2) = inlinePerformIO $
-    withForeignPtr x1 $ \p1 ->
-    withForeignPtr x2 $ \p2 -> do
-        i <- memcmp (p1 `plusPtr` s1) (p2 `plusPtr` s2) (min l1 l2)
-        return $ case i `compare` 0 of
-                    EQ  -> l1 `compare` l2
-                    x   -> x
+compareBytes (PS x1 s1 l1) (PS x2 s2 l2)
+    | l1 == 0  && l2 == 0               = EQ  -- short cut for empty strings
+    | x1 == x2 && s1 == s2 && l1 == l2  = EQ  -- short cut for the same string
+    | otherwise                         = inlinePerformIO $
+        withForeignPtr x1 $ \p1 ->
+        withForeignPtr x2 $ \p2 -> do
+            i <- memcmp (p1 `plusPtr` s1) (p2 `plusPtr` s2) (min l1 l2)
+            return $ case i `compare` 0 of
+                        EQ  -> l1 `compare` l2
+                        x   -> x
 {-# INLINE compareBytes #-}
 
 {-
@@ -501,6 +515,8 @@ cons c (PS x s l) = create (l+1) $ \p -> withForeignPtr x $ \f -> do
         poke p c
 {-# INLINE cons #-}
 
+-- todo fuse
+
 -- | /O(n)/ Append a byte to the end of a 'ByteString'
 snoc :: ByteString -> Word8 -> ByteString
 snoc (PS x s l) c = create (l+1) $ \p -> withForeignPtr x $ \f -> do
@@ -508,6 +524,8 @@ snoc (PS x s l) c = create (l+1) $ \p -> withForeignPtr x $ \f -> do
         poke (p `plusPtr` l) c
 {-# INLINE snoc #-}
 
+-- todo fuse
+
 -- | /O(1)/ Extract the first element of a ByteString, which must be non-empty.
 head :: ByteString -> Word8
 head ps@(PS x s _)
@@ -563,25 +581,30 @@ append xs@(PS ffp s l) ys@(PS fgp t m)
 -- Transformations
 
 -- | /O(n)/ 'map' @f xs@ is the ByteString obtained by applying @f@ to each
--- element of @xs@
---
+-- element of @xs@. This function is subject to array fusion.
 map :: (Word8 -> Word8) -> ByteString -> ByteString
-map f (PS fp start len) = inlinePerformIO $ withForeignPtr fp $ \p -> do
-    new_fp <- mallocByteString len
-    withForeignPtr new_fp $ \new_p -> do
-        map_ f (len-1) (p `plusPtr` start) new_p
-        return (PS new_fp 0 len)
+map f = loopArr . loopU (mapEFL f) noAL
 {-# INLINE map #-}
 
-map_ :: (Word8 -> Word8) -> Int -> Ptr Word8 -> Ptr Word8 -> IO ()
-STRICT4(map_)
-map_ f n p1 p2
-   | n < 0 = return ()
-   | otherwise = do
-        x <- peekByteOff p1 n
-        pokeByteOff p2 n (f x)
-        map_ f (n-1) p1 p2
-{-# INLINE map_ #-}
+-- | /O(n)/ Like 'map', but not fuseable. The benefit is that it is
+-- slightly faster for one-shot cases.
+mapF :: (Word8 -> Word8) -> ByteString -> ByteString
+STRICT2(mapF)
+mapF f (PS fp s len) = inlinePerformIO $ withForeignPtr fp $ \a -> do
+    np <- mallocByteString (len+1)
+    withForeignPtr np $ \p -> do
+        map_ 0 (a `plusPtr` s) p
+        return (PS np 0 len)
+  where
+    map_ :: Int -> Ptr Word8 -> Ptr Word8 -> IO ()
+    STRICT3(map_)
+    map_ n p1 p2
+       | n >= len = return ()
+       | otherwise = do
+            x <- peekByteOff p1 n
+            pokeByteOff p2 n (f x)
+            map_ (n+1) p1 p2
+{-# INLINE mapF #-}
 
 -- | /O(n)/ 'reverse' @xs@ efficiently returns the elements of @xs@ in reverse order.
 reverse :: ByteString -> ByteString
@@ -617,7 +640,16 @@ transpose ps = P.map pack (List.transpose (P.map unpack ps))
 -- | 'foldl', applied to a binary operator, a starting value (typically
 -- the left-identity of the operator), and a ByteString, reduces the
 -- ByteString using the binary operator, from left to right.
+-- This function is subject to array fusion.
 foldl :: (a -> Word8 -> a) -> a -> ByteString -> a
+foldl f z = loopAcc . loopU (foldEFL f) z
+{-# INLINE foldl #-}
+
+{-
+--
+-- About twice as fast with 6.4.1, but not fuseable
+-- A simple fold . map is enough to make it worth while.
+--
 foldl f v (PS x s l) = inlinePerformIO $ withForeignPtr x $ \ptr ->
         lgo v (ptr `plusPtr` s) (ptr `plusPtr` (s+l))
     where
@@ -625,6 +657,7 @@ foldl f v (PS x s l) = inlinePerformIO $ withForeignPtr x $ \ptr ->
         lgo z p q | p == q    = return z
                   | otherwise = do c <- peek p
                                    lgo (f z c) (p `plusPtr` 1) q
+-}
 
 -- | 'foldr', applied to a binary operator, a starting value
 -- (typically the right-identity of the operator), and a ByteString,
@@ -641,6 +674,7 @@ foldr k z (PS x s l) = inlinePerformIO $ withForeignPtr x $ \ptr ->
 
 -- | 'foldl1' is a variant of 'foldl' that has no starting value
 -- argument, and thus must be applied to non-empty 'ByteStrings'.
+-- This function is subject to array fusion.
 foldl1 :: (Word8 -> Word8 -> Word8) -> ByteString -> Word8
 foldl1 f ps
     | null ps   = errorEmptyList "foldl1"
@@ -698,6 +732,8 @@ any f (PS x s l) = inlinePerformIO $ withForeignPtr x $ \ptr ->
                                 if f c then return True
                                        else go (p `plusPtr` 1) q
 
+-- todo fuse
+
 -- | /O(n)/ Applied to a predicate and a 'ByteString', 'all' determines
 -- if all elements of the 'ByteString' satisfy the predicate.
 all :: (Word8 -> Bool) -> ByteString -> Bool
@@ -711,6 +747,7 @@ all f (PS x s l) = inlinePerformIO $ withForeignPtr x $ \ptr ->
                                  if f c
                                     then go (p `plusPtr` 1) q
                                     else return False
+-- todo fuse
 
 -- | /O(n)/ 'maximum' returns the maximum value from a 'ByteString'
 maximum :: ByteString -> Word8
@@ -728,6 +765,8 @@ minimum xs@(PS x s l)
                     return $ c_minimum (p `plusPtr` s) l
 {-# INLINE minimum #-}
 
+-- fusion is too slow here (10x)
+
 {-
 maximum xs@(PS x s l)
     | null xs   = errorEmptyList "maximum"
@@ -1232,7 +1271,7 @@ elem c ps = case elemIndex c ps of Nothing -> False ; _ -> True
 
 -- | /O(n)/ 'notElem' is the inverse of 'elem'
 notElem :: Word8 -> ByteString -> Bool
-notElem c ps = case elemIndex c ps of Nothing -> True ; _ -> False
+notElem c ps = not (elem c ps)
 {-# INLINE notElem #-}
 
 --
@@ -1247,23 +1286,6 @@ notElem c ps = case elemIndex c ps of Nothing -> True ; _ -> False
 filterByte :: Word8 -> ByteString -> ByteString
 filterByte w ps = replicate (count w ps) w
 
-{-
--- slower than the replicate version
-
-filterByte ch ps@(PS x s l)
-    | null ps   = ps
-    | otherwise = inlinePerformIO $ generate l $ \p -> withForeignPtr x $ \f -> do
-        t <- go (f `plusPtr` s) p l
-        return (t `minusPtr` p) -- actual length
-    where
-        STRICT3(go)
-        go _ t 0 = return t
-        go f t e = do w <- peek f
-                      if w == ch
-                        then poke t w >> go (f `plusPtr` 1) (t `plusPtr` 1) (e-1)
-                        else             go (f `plusPtr` 1) t               (e-1)
--}
-
 --
 -- | /O(n)/ A first order equivalent of /filter . (\/=)/, for the common
 -- case of filtering a single byte out of a list. It is more efficient
@@ -1289,9 +1311,13 @@ filterNotByte ch ps@(PS x s l)
 
 -- | /O(n)/ 'filter', applied to a predicate and a ByteString,
 -- returns a ByteString containing those characters that satisfy the
--- predicate.
+-- predicate. This function is subject to array fusion.
 filter :: (Word8 -> Bool) -> ByteString -> ByteString
-filter k ps@(PS x s l)
+filter p  = loopArr . loopU (filterEFL p) noAL
+{-# INLINE filter #-}
+
+filterF :: (Word8 -> Bool) -> ByteString -> ByteString
+filterF k ps@(PS x s l)
     | null ps   = ps
     | otherwise = inlinePerformIO $ generate l $ \p -> withForeignPtr x $ \f -> do
         t <- go (f `plusPtr` s) p l
@@ -1303,6 +1329,7 @@ filter k ps@(PS x s l)
                       if k w
                         then poke t w >> go (f `plusPtr` 1) (t `plusPtr` 1) (e - 1)
                         else             go (f `plusPtr` 1) t               (e - 1)
+{-# INLINE filterF #-}
 
 -- Almost as good: pack $ foldl (\xs c -> if f c then c : xs else xs) [] ps
 
@@ -1673,7 +1700,7 @@ unsafeUseAsCStringLen (PS ps s l) ac = withForeignPtr ps $ \p -> ac (castPtr p `
 --
 generate :: Int -> (Ptr Word8 -> IO Int) -> IO ByteString
 generate i f = do
-    p <- mallocArray i
+    p <- mallocArray (i+1)
     i' <- f p
     p' <- reallocArray p (i'+1)
     poke (p' `plusPtr` i') (0::Word8)    -- XXX so CStrings work
@@ -2081,3 +2108,118 @@ foreign import ccall unsafe "RtsAPI.h getProgArgv"
 foreign import ccall unsafe "__hscore_memcpy_src_off"
    memcpy_ptr_baoff :: Ptr a -> RawBuffer -> Int -> CSize -> IO (Ptr ())
 #endif
+
+-- ---------------------------------------------------------------------
+--
+-- Functional array fusion for ByteStrings. 
+--
+-- From the Data Parallel Haskell project, 
+--      http://www.cse.unsw.edu.au/~chak/project/dph/
+--
+
+-- |Data type for accumulators which can be ignored. The rewrite rules rely on
+-- the fact that no bottoms of this type are ever constructed; hence, we can
+-- assume @(_ :: NoAL) `seq` x = x@.
+--
+data NoAL = NoAL
+
+-- | Special forms of loop arguments
+--
+-- * These are common special cases for the three function arguments of gen
+--   and loop; we give them special names to make it easier to trigger RULES
+--   applying in the special cases represented by these arguments.  The
+--   "INLINE [1]" makes sure that these functions are only inlined in the last
+--   two simplifier phases.
+--
+-- * In the case where the accumulator is not needed, it is better to always
+--   explicitly return a value `()', rather than just copy the input to the
+--   output, as the former gives GHC better local information.
+-- 
+
+-- | Element function expressing a mapping only
+mapEFL :: (Word8 -> Word8) -> (NoAL -> Word8 -> (NoAL, Maybe Word8))
+mapEFL f = \_ e -> (noAL, (Just $ f e))
+{-# INLINE [1] mapEFL #-}
+
+-- | Element function implementing a filter function only
+filterEFL :: (Word8 -> Bool) -> (NoAL -> Word8 -> (NoAL, Maybe Word8))
+filterEFL p = \_ e -> if p e then (noAL, Just e) else (noAL, Nothing)
+{-# INLINE [1] filterEFL #-}
+
+-- |Element function expressing a reduction only
+foldEFL :: (acc -> Word8 -> acc) -> (acc -> Word8 -> (acc, Maybe Word8))
+foldEFL f = \a e -> (f a e, Nothing)
+{-# INLINE [1] foldEFL #-}
+
+-- | No accumulator
+noAL :: NoAL
+noAL = NoAL
+{-# INLINE [1] noAL #-}
+
+-- | Projection functions that are fusion friendly (as in, we determine when
+-- they are inlined)
+loopArr :: (ByteString, acc) -> ByteString
+loopArr (arr, _) = arr
+{-# INLINE [1] loopArr #-}
+
+loopAcc :: (ByteString, acc) -> acc
+loopAcc (_, acc) = acc
+{-# INLINE [1] loopAcc #-}
+
+loopSndAcc :: (ByteString, (acc1, acc2)) -> (ByteString, acc2)
+loopSndAcc (arr, (_, acc)) = (arr, acc)
+{-# INLINE [1] loopSndAcc #-}
+
+------------------------------------------------------------------------
+
+-- | Iteration over over ByteStrings
+loopU :: (acc -> Word8 -> (acc, Maybe Word8))  -- ^ mapping & folding, once per elem
+      -> acc                                   -- ^ initial acc value
+      -> ByteString                            -- ^ input ByteString
+      -> (ByteString, acc)
+
+loopU f start (PS fp s i) = inlinePerformIO $ withForeignPtr fp $ \a -> do
+    p <- mallocArray (i+1)
+    (acc, i') <- go (a `plusPtr` s) p start
+    p' <- if i == i' then return p else reallocArray p (i'+1) -- avoid realloc for maps
+    poke (p' `plusPtr` i') (0::Word8)
+    fp' <- newForeignFreePtr p'
+    return (PS fp' 0 i', acc)
+  where
+    go p ma = trans 0 0
+        where
+            STRICT3(trans)
+            trans a_off ma_off acc
+                | a_off >= i = return (acc, ma_off)
+                | otherwise  = do
+                    x <- peekByteOff p a_off
+                    let (acc', oe) = f acc x
+                    ma_off' <- case oe of
+                        Nothing  -> return ma_off
+                        Just e   -> do pokeByteOff ma ma_off e
+                                       return $ ma_off + 1
+                    trans (a_off+1) ma_off' acc'
+
+{-# INLINE [1] loopU #-}
+
+{-# RULES
+
+"array fusion!" forall em1 em2 start1 start2 arr.
+  loopU em2 start2 (loopArr (loopU em1 start1 arr)) =
+    let em (acc1, acc2) e =
+            case em1 acc1 e of
+                (acc1', Nothing) -> ((acc1', acc2), Nothing)
+                (acc1', Just e') ->
+                    case em2 acc2 e' of
+                        (acc2', res) -> ((acc1', acc2'), res)
+    in loopSndAcc (loopU em (start1, start2) arr)
+
+"loopArr/loopSndAcc" forall x.
+  loopArr (loopSndAcc x) = loopArr x
+
+-- orphan?
+-- "seq/NoAL" forall (u::NoAL) e.
+--   u `seq` e = e
+
+ #-}
+