Add array fusion versions of map, filter and foldl
authordons@cse.unsw.edu.au <unknown>
Fri, 5 May 2006 06:08:58 +0000 (06:08 +0000)
committerdons@cse.unsw.edu.au <unknown>
Fri, 5 May 2006 06:08:58 +0000 (06:08 +0000)
This patch adds fusable map, filter and foldl, using the array fusion
code for unlifted, flat arrays, from the Data Parallel Haskell branch,
after kind help from Roman Leshchinskiy,

Pipelines of maps, filters and folds should now need to walk the
bytestring once only, and intermediate bytestrings won't be constructed.

Data/ByteString.hs
Data/ByteString/Char8.hs

index 76e84d5..86ec26a 100644 (file)
@@ -6,6 +6,11 @@
 --               (c) Simon Marlow 2005
 --               (c) Don Stewart 2005-2006
 --               (c) Bjorn Bringert 2006
+--
+-- Array fusion code:
+--               (c) 2001,2002 Manuel M T Chakravarty & Gabriele Keller
+--                      (c) 2006      Manuel M T Chakravarty & Roman Leshchinskiy
+--
 -- License     : BSD-style
 --
 -- Maintainer  : dons@cse.unsw.edu.au
@@ -18,8 +23,8 @@
 -- | A time and space-efficient implementation of byte vectors using
 -- packed Word8 arrays, suitable for high performance use, both in terms
 -- of large data quantities, or high speed requirements. Byte vectors
--- are encoded as Word8 arrays of bytes, held in a ForeignPtr, and can
--- be passed between C and Haskell with little effort.
+-- are encoded as strict Word8 arrays of bytes, held in a ForeignPtr,
+-- and can be passed between C and Haskell with little effort.
 --
 -- This module is intended to be imported @qualified@, to avoid name
 -- clashes with Prelude functions.  eg.
@@ -216,11 +221,15 @@ module Data.ByteString (
         hGet,                   -- :: Handle -> Int -> IO ByteString
         hPut,                   -- :: Handle -> ByteString -> IO ()
 
+        -- * Fusion utilities
 #if defined(__GLASGOW_HASKELL__)
-        -- * Miscellaneous
         unpackList, -- eek, otherwise it gets thrown away by the simplifier
 #endif
 
+        noAL, NoAL, loopArr, loopAcc, loopSndAcc,
+        loopU, mapEFL, filterEFL, foldEFL,
+        filterF, mapF
+
   ) where
 
 import qualified Prelude as P
@@ -330,19 +339,24 @@ instance Arbitrary PackedString where
 
 -- | /O(n)/ Equality on the 'ByteString' type.
 eq :: ByteString -> ByteString -> Bool
-eq a b = (compareBytes a b) == EQ
+eq a@(PS p s l) b@(PS p' s' l')
+    | l /= l'            = False    -- short cut on length
+    | p == p' && s == s' = True     -- short cut for the same string
+    | otherwise          = compareBytes a b == EQ
 {-# INLINE eq #-}
 
 -- | /O(n)/ 'compareBytes' provides an 'Ordering' for 'ByteStrings' supporting slices. 
 compareBytes :: ByteString -> ByteString -> Ordering
-compareBytes (PS _ _ 0) (PS _ _ 0)       = EQ    -- short cut for empty strings
-compareBytes (PS x1 s1 l1) (PS x2 s2 l2) = inlinePerformIO $
-    withForeignPtr x1 $ \p1 ->
-    withForeignPtr x2 $ \p2 -> do
-        i <- memcmp (p1 `plusPtr` s1) (p2 `plusPtr` s2) (min l1 l2)
-        return $ case i `compare` 0 of
-                    EQ  -> l1 `compare` l2
-                    x   -> x
+compareBytes (PS x1 s1 l1) (PS x2 s2 l2)
+    | l1 == 0  && l2 == 0               = EQ  -- short cut for empty strings
+    | x1 == x2 && s1 == s2 && l1 == l2  = EQ  -- short cut for the same string
+    | otherwise                         = inlinePerformIO $
+        withForeignPtr x1 $ \p1 ->
+        withForeignPtr x2 $ \p2 -> do
+            i <- memcmp (p1 `plusPtr` s1) (p2 `plusPtr` s2) (min l1 l2)
+            return $ case i `compare` 0 of
+                        EQ  -> l1 `compare` l2
+                        x   -> x
 {-# INLINE compareBytes #-}
 
 {-
@@ -501,6 +515,8 @@ cons c (PS x s l) = create (l+1) $ \p -> withForeignPtr x $ \f -> do
         poke p c
 {-# INLINE cons #-}
 
+-- todo fuse
+
 -- | /O(n)/ Append a byte to the end of a 'ByteString'
 snoc :: ByteString -> Word8 -> ByteString
 snoc (PS x s l) c = create (l+1) $ \p -> withForeignPtr x $ \f -> do
@@ -508,6 +524,8 @@ snoc (PS x s l) c = create (l+1) $ \p -> withForeignPtr x $ \f -> do
         poke (p `plusPtr` l) c
 {-# INLINE snoc #-}
 
+-- todo fuse
+
 -- | /O(1)/ Extract the first element of a ByteString, which must be non-empty.
 head :: ByteString -> Word8
 head ps@(PS x s _)
@@ -563,25 +581,30 @@ append xs@(PS ffp s l) ys@(PS fgp t m)
 -- Transformations
 
 -- | /O(n)/ 'map' @f xs@ is the ByteString obtained by applying @f@ to each
--- element of @xs@
---
+-- element of @xs@. This function is subject to array fusion.
 map :: (Word8 -> Word8) -> ByteString -> ByteString
-map f (PS fp start len) = inlinePerformIO $ withForeignPtr fp $ \p -> do
-    new_fp <- mallocByteString len
-    withForeignPtr new_fp $ \new_p -> do
-        map_ f (len-1) (p `plusPtr` start) new_p
-        return (PS new_fp 0 len)
+map f = loopArr . loopU (mapEFL f) noAL
 {-# INLINE map #-}
 
-map_ :: (Word8 -> Word8) -> Int -> Ptr Word8 -> Ptr Word8 -> IO ()
-STRICT4(map_)
-map_ f n p1 p2
-   | n < 0 = return ()
-   | otherwise = do
-        x <- peekByteOff p1 n
-        pokeByteOff p2 n (f x)
-        map_ f (n-1) p1 p2
-{-# INLINE map_ #-}
+-- | /O(n)/ Like 'map', but not fuseable. The benefit is that it is
+-- slightly faster for one-shot cases.
+mapF :: (Word8 -> Word8) -> ByteString -> ByteString
+STRICT2(mapF)
+mapF f (PS fp s len) = inlinePerformIO $ withForeignPtr fp $ \a -> do
+    np <- mallocByteString (len+1)
+    withForeignPtr np $ \p -> do
+        map_ 0 (a `plusPtr` s) p
+        return (PS np 0 len)
+  where
+    map_ :: Int -> Ptr Word8 -> Ptr Word8 -> IO ()
+    STRICT3(map_)
+    map_ n p1 p2
+       | n >= len = return ()
+       | otherwise = do
+            x <- peekByteOff p1 n
+            pokeByteOff p2 n (f x)
+            map_ (n+1) p1 p2
+{-# INLINE mapF #-}
 
 -- | /O(n)/ 'reverse' @xs@ efficiently returns the elements of @xs@ in reverse order.
 reverse :: ByteString -> ByteString
@@ -617,7 +640,16 @@ transpose ps = P.map pack (List.transpose (P.map unpack ps))
 -- | 'foldl', applied to a binary operator, a starting value (typically
 -- the left-identity of the operator), and a ByteString, reduces the
 -- ByteString using the binary operator, from left to right.
+-- This function is subject to array fusion.
 foldl :: (a -> Word8 -> a) -> a -> ByteString -> a
+foldl f z = loopAcc . loopU (foldEFL f) z
+{-# INLINE foldl #-}
+
+{-
+--
+-- About twice as fast with 6.4.1, but not fuseable
+-- A simple fold . map is enough to make it worth while.
+--
 foldl f v (PS x s l) = inlinePerformIO $ withForeignPtr x $ \ptr ->
         lgo v (ptr `plusPtr` s) (ptr `plusPtr` (s+l))
     where
@@ -625,6 +657,7 @@ foldl f v (PS x s l) = inlinePerformIO $ withForeignPtr x $ \ptr ->
         lgo z p q | p == q    = return z
                   | otherwise = do c <- peek p
                                    lgo (f z c) (p `plusPtr` 1) q
+-}
 
 -- | 'foldr', applied to a binary operator, a starting value
 -- (typically the right-identity of the operator), and a ByteString,
@@ -641,6 +674,7 @@ foldr k z (PS x s l) = inlinePerformIO $ withForeignPtr x $ \ptr ->
 
 -- | 'foldl1' is a variant of 'foldl' that has no starting value
 -- argument, and thus must be applied to non-empty 'ByteStrings'.
+-- This function is subject to array fusion.
 foldl1 :: (Word8 -> Word8 -> Word8) -> ByteString -> Word8
 foldl1 f ps
     | null ps   = errorEmptyList "foldl1"
@@ -698,6 +732,8 @@ any f (PS x s l) = inlinePerformIO $ withForeignPtr x $ \ptr ->
                                 if f c then return True
                                        else go (p `plusPtr` 1) q
 
+-- todo fuse
+
 -- | /O(n)/ Applied to a predicate and a 'ByteString', 'all' determines
 -- if all elements of the 'ByteString' satisfy the predicate.
 all :: (Word8 -> Bool) -> ByteString -> Bool
@@ -711,6 +747,7 @@ all f (PS x s l) = inlinePerformIO $ withForeignPtr x $ \ptr ->
                                  if f c
                                     then go (p `plusPtr` 1) q
                                     else return False
+-- todo fuse
 
 -- | /O(n)/ 'maximum' returns the maximum value from a 'ByteString'
 maximum :: ByteString -> Word8
@@ -728,6 +765,8 @@ minimum xs@(PS x s l)
                     return $ c_minimum (p `plusPtr` s) l
 {-# INLINE minimum #-}
 
+-- fusion is too slow here (10x)
+
 {-
 maximum xs@(PS x s l)
     | null xs   = errorEmptyList "maximum"
@@ -1232,7 +1271,7 @@ elem c ps = case elemIndex c ps of Nothing -> False ; _ -> True
 
 -- | /O(n)/ 'notElem' is the inverse of 'elem'
 notElem :: Word8 -> ByteString -> Bool
-notElem c ps = case elemIndex c ps of Nothing -> True ; _ -> False
+notElem c ps = not (elem c ps)
 {-# INLINE notElem #-}
 
 --
@@ -1247,23 +1286,6 @@ notElem c ps = case elemIndex c ps of Nothing -> True ; _ -> False
 filterByte :: Word8 -> ByteString -> ByteString
 filterByte w ps = replicate (count w ps) w
 
-{-
--- slower than the replicate version
-
-filterByte ch ps@(PS x s l)
-    | null ps   = ps
-    | otherwise = inlinePerformIO $ generate l $ \p -> withForeignPtr x $ \f -> do
-        t <- go (f `plusPtr` s) p l
-        return (t `minusPtr` p) -- actual length
-    where
-        STRICT3(go)
-        go _ t 0 = return t
-        go f t e = do w <- peek f
-                      if w == ch
-                        then poke t w >> go (f `plusPtr` 1) (t `plusPtr` 1) (e-1)
-                        else             go (f `plusPtr` 1) t               (e-1)
--}
-
 --
 -- | /O(n)/ A first order equivalent of /filter . (\/=)/, for the common
 -- case of filtering a single byte out of a list. It is more efficient
@@ -1289,9 +1311,13 @@ filterNotByte ch ps@(PS x s l)
 
 -- | /O(n)/ 'filter', applied to a predicate and a ByteString,
 -- returns a ByteString containing those characters that satisfy the
--- predicate.
+-- predicate. This function is subject to array fusion.
 filter :: (Word8 -> Bool) -> ByteString -> ByteString
-filter k ps@(PS x s l)
+filter p  = loopArr . loopU (filterEFL p) noAL
+{-# INLINE filter #-}
+
+filterF :: (Word8 -> Bool) -> ByteString -> ByteString
+filterF k ps@(PS x s l)
     | null ps   = ps
     | otherwise = inlinePerformIO $ generate l $ \p -> withForeignPtr x $ \f -> do
         t <- go (f `plusPtr` s) p l
@@ -1303,6 +1329,7 @@ filter k ps@(PS x s l)
                       if k w
                         then poke t w >> go (f `plusPtr` 1) (t `plusPtr` 1) (e - 1)
                         else             go (f `plusPtr` 1) t               (e - 1)
+{-# INLINE filterF #-}
 
 -- Almost as good: pack $ foldl (\xs c -> if f c then c : xs else xs) [] ps
 
@@ -1673,7 +1700,7 @@ unsafeUseAsCStringLen (PS ps s l) ac = withForeignPtr ps $ \p -> ac (castPtr p `
 --
 generate :: Int -> (Ptr Word8 -> IO Int) -> IO ByteString
 generate i f = do
-    p <- mallocArray i
+    p <- mallocArray (i+1)
     i' <- f p
     p' <- reallocArray p (i'+1)
     poke (p' `plusPtr` i') (0::Word8)    -- XXX so CStrings work
@@ -2081,3 +2108,118 @@ foreign import ccall unsafe "RtsAPI.h getProgArgv"
 foreign import ccall unsafe "__hscore_memcpy_src_off"
    memcpy_ptr_baoff :: Ptr a -> RawBuffer -> Int -> CSize -> IO (Ptr ())
 #endif
+
+-- ---------------------------------------------------------------------
+--
+-- Functional array fusion for ByteStrings. 
+--
+-- From the Data Parallel Haskell project, 
+--      http://www.cse.unsw.edu.au/~chak/project/dph/
+--
+
+-- |Data type for accumulators which can be ignored. The rewrite rules rely on
+-- the fact that no bottoms of this type are ever constructed; hence, we can
+-- assume @(_ :: NoAL) `seq` x = x@.
+--
+data NoAL = NoAL
+
+-- | Special forms of loop arguments
+--
+-- * These are common special cases for the three function arguments of gen
+--   and loop; we give them special names to make it easier to trigger RULES
+--   applying in the special cases represented by these arguments.  The
+--   "INLINE [1]" makes sure that these functions are only inlined in the last
+--   two simplifier phases.
+--
+-- * In the case where the accumulator is not needed, it is better to always
+--   explicitly return a value `()', rather than just copy the input to the
+--   output, as the former gives GHC better local information.
+-- 
+
+-- | Element function expressing a mapping only
+mapEFL :: (Word8 -> Word8) -> (NoAL -> Word8 -> (NoAL, Maybe Word8))
+mapEFL f = \_ e -> (noAL, (Just $ f e))
+{-# INLINE [1] mapEFL #-}
+
+-- | Element function implementing a filter function only
+filterEFL :: (Word8 -> Bool) -> (NoAL -> Word8 -> (NoAL, Maybe Word8))
+filterEFL p = \_ e -> if p e then (noAL, Just e) else (noAL, Nothing)
+{-# INLINE [1] filterEFL #-}
+
+-- |Element function expressing a reduction only
+foldEFL :: (acc -> Word8 -> acc) -> (acc -> Word8 -> (acc, Maybe Word8))
+foldEFL f = \a e -> (f a e, Nothing)
+{-# INLINE [1] foldEFL #-}
+
+-- | No accumulator
+noAL :: NoAL
+noAL = NoAL
+{-# INLINE [1] noAL #-}
+
+-- | Projection functions that are fusion friendly (as in, we determine when
+-- they are inlined)
+loopArr :: (ByteString, acc) -> ByteString
+loopArr (arr, _) = arr
+{-# INLINE [1] loopArr #-}
+
+loopAcc :: (ByteString, acc) -> acc
+loopAcc (_, acc) = acc
+{-# INLINE [1] loopAcc #-}
+
+loopSndAcc :: (ByteString, (acc1, acc2)) -> (ByteString, acc2)
+loopSndAcc (arr, (_, acc)) = (arr, acc)
+{-# INLINE [1] loopSndAcc #-}
+
+------------------------------------------------------------------------
+
+-- | Iteration over over ByteStrings
+loopU :: (acc -> Word8 -> (acc, Maybe Word8))  -- ^ mapping & folding, once per elem
+      -> acc                                   -- ^ initial acc value
+      -> ByteString                            -- ^ input ByteString
+      -> (ByteString, acc)
+
+loopU f start (PS fp s i) = inlinePerformIO $ withForeignPtr fp $ \a -> do
+    p <- mallocArray (i+1)
+    (acc, i') <- go (a `plusPtr` s) p start
+    p' <- if i == i' then return p else reallocArray p (i'+1) -- avoid realloc for maps
+    poke (p' `plusPtr` i') (0::Word8)
+    fp' <- newForeignFreePtr p'
+    return (PS fp' 0 i', acc)
+  where
+    go p ma = trans 0 0
+        where
+            STRICT3(trans)
+            trans a_off ma_off acc
+                | a_off >= i = return (acc, ma_off)
+                | otherwise  = do
+                    x <- peekByteOff p a_off
+                    let (acc', oe) = f acc x
+                    ma_off' <- case oe of
+                        Nothing  -> return ma_off
+                        Just e   -> do pokeByteOff ma ma_off e
+                                       return $ ma_off + 1
+                    trans (a_off+1) ma_off' acc'
+
+{-# INLINE [1] loopU #-}
+
+{-# RULES
+
+"array fusion!" forall em1 em2 start1 start2 arr.
+  loopU em2 start2 (loopArr (loopU em1 start1 arr)) =
+    let em (acc1, acc2) e =
+            case em1 acc1 e of
+                (acc1', Nothing) -> ((acc1', acc2), Nothing)
+                (acc1', Just e') ->
+                    case em2 acc2 e' of
+                        (acc2', res) -> ((acc1', acc2'), res)
+    in loopSndAcc (loopU em (start1, start2) arr)
+
+"loopArr/loopSndAcc" forall x.
+  loopArr (loopSndAcc x) = loopArr x
+
+-- orphan?
+-- "seq/NoAL" forall (u::NoAL) e.
+--   u `seq` e = e
+
+ #-}
+
index 530dda8..24190db 100644 (file)
@@ -215,6 +215,14 @@ module Data.ByteString.Char8 (
         unsafePackAddress,      -- :: Int -> Addr# -> ByteString
 #endif
 
+        -- * Utilities (needed for array fusion)
+#if defined(__GLASGOW_HASKELL__)
+        unpackList,
+#endif
+        noAL, NoAL, loopArr, loopAcc, loopSndAcc,
+        loopU, mapEFL, filterEFL,
+        filterF, mapF
+
     ) where
 
 import qualified Prelude as P
@@ -243,7 +251,10 @@ import Data.ByteString (ByteString(..)
 #if defined(__GLASGOW_HASKELL__)
                        ,getLine, getArgs, hGetLine, hGetNonBlocking
                        ,packAddress, unsafePackAddress
+                       ,unpackList
 #endif
+                       ,noAL, NoAL, loopArr, loopAcc, loopSndAcc
+                       ,loopU, mapEFL, filterEFL, filterF, mapF
                        ,useAsCString, unsafeUseAsCString
                        )