Make arrays safer (e.g. trac #1046)
authorIan Lynagh <igloo@earth.li>
Fri, 10 Aug 2007 16:34:05 +0000 (16:34 +0000)
committerIan Lynagh <igloo@earth.li>
Fri, 10 Aug 2007 16:34:05 +0000 (16:34 +0000)
GHC/Arr.lhs

index 8f439cd..25505fc 100644 (file)
@@ -287,7 +287,14 @@ type IPr = (Int, Int)
 
 -- | The type of immutable non-strict (boxed) arrays
 -- with indices in @i@ and elements in @e@.
-data Ix i => Array     i e = Array   !i !i (Array# e)
+-- The Int is the number of elements in the Array.
+data Ix i => Array i e
+                 = Array !i         -- the lower bound, l
+                         !i         -- the upper bound, u
+                         !Int       -- a cache of (rangeSize (l,u))
+                                    -- used to make sure an index is
+                                    -- really in range
+                         (Array# e) -- The actual elements
 
 -- | Mutable, boxed, non-strict arrays in the 'ST' monad.  The type
 -- arguments are as follows:
@@ -298,13 +305,19 @@ data Ix i => Array     i e = Array   !i !i (Array# e)
 --
 --  * @e@: the element type of the array.
 --
-data         STArray s i e = STArray !i !i (MutableArray# s e)
+data STArray s i e
+         = STArray !i                  -- the lower bound, l
+                   !i                  -- the upper bound, u
+                   !Int                -- a cache of (rangeSize (l,u))
+                                       -- used to make sure an index is
+                                       -- really in range
+                   (MutableArray# s e) -- The actual elements
        -- No Ix context for STArray.  They are stupid,
        -- and force an Ix context on the equality instance.
 
 -- Just pointer equality on mutable arrays:
 instance Eq (STArray s i e) where
-    STArray _ _ arr1# == STArray _ _ arr2# =
+    STArray _ _ _ arr1# == STArray _ _ _ arr2# =
         sameMutableArray# arr1# arr2#
 \end{code}
 
@@ -359,14 +372,21 @@ array :: Ix i
                        -- association '(i, x)' defines the value of
                        -- the array at index 'i' to be 'x'.
        -> Array i e
-array (l,u) ies = unsafeArray (l,u) [(index (l,u) i, e) | (i, e) <- ies]
+array (l,u) ies
+    = let n = safeRangeSize (l,u)
+      in unsafeArray' (l,u) n
+                      [(safeIndex (l,u) n i, e) | (i, e) <- ies]
 
 {-# INLINE unsafeArray #-}
 unsafeArray :: Ix i => (i,i) -> [(Int, e)] -> Array i e
-unsafeArray (l,u) ies = runST (ST $ \s1# ->
-    case rangeSize (l,u)                of { I# n# ->
-    case newArray# n# arrEleBottom s1#  of { (# s2#, marr# #) ->
-    foldr (fill marr#) (done l u marr#) ies s2# }})
+unsafeArray b ies = unsafeArray' b (rangeSize b) ies
+
+{-# INLINE unsafeArray' #-}
+unsafeArray' :: Ix i => (i,i) -> Int -> [(Int, e)] -> Array i e
+unsafeArray' (l,u) n@(I# n#) ies = runST (ST $ \s1# ->
+    case newArray# n# arrEleBottom s1# of
+        (# s2#, marr# #) ->
+            foldr (fill marr#) (done l u n marr#) ies s2#)
 
 {-# INLINE fill #-}
 fill :: MutableArray# s e -> (Int, e) -> STRep s a -> STRep s a
@@ -375,10 +395,10 @@ fill marr# (I# i#, e) next s1# =
     next s2# }
 
 {-# INLINE done #-}
-done :: Ix i => i -> i -> MutableArray# s e -> STRep s (Array i e)
-done l u marr# s1# =
-    case unsafeFreezeArray# marr# s1#   of { (# s2#, arr# #) ->
-    (# s2#, Array l u arr# #) }
+done :: Ix i => i -> i -> Int -> MutableArray# s e -> STRep s (Array i e)
+done l u n marr# s1# =
+    case unsafeFreezeArray# marr# s1# of
+        (# s2#, arr# #) -> (# s2#, Array l u n arr# #)
 
 -- This is inefficient and I'm not sure why:
 -- listArray (l,u) es = unsafeArray (l,u) (zip [0 .. rangeSize (l,u) - 1] es)
@@ -391,7 +411,7 @@ done l u marr# s1# =
 {-# INLINE listArray #-}
 listArray :: Ix i => (i,i) -> [e] -> Array i e
 listArray (l,u) es = runST (ST $ \s1# ->
-    case rangeSize (l,u)                of { I# n# ->
+    case safeRangeSize (l,u)            of { n@(I# n#) ->
     case newArray# n# arrEleBottom s1#  of { (# s2#, marr# #) ->
     let fillFromList i# xs s3# | i# ==# n# = s3#
                                | otherwise = case xs of
@@ -399,39 +419,57 @@ listArray (l,u) es = runST (ST $ \s1# ->
             y:ys -> case writeArray# marr# i# y s3# of { s4# ->
                     fillFromList (i# +# 1#) ys s4# } in
     case fillFromList 0# es s2#         of { s3# ->
-    done l u marr# s3# }}})
+    done l u n marr# s3# }}})
 
 -- | The value at the given index in an array.
 {-# INLINE (!) #-}
 (!) :: Ix i => Array i e -> i -> e
-arr@(Array l u _) ! i = unsafeAt arr (index (l,u) i)
+arr@(Array l u n _) ! i = unsafeAt arr $ safeIndex (l,u) n i
+
+{-# INLINE safeRangeSize #-}
+safeRangeSize :: Ix i => (i, i) -> Int
+safeRangeSize (l,u) = let r = rangeSize (l, u)
+                      in if r < 0 then error "Negative range size"
+                                  else r
+
+{-# INLINE safeIndex #-}
+safeIndex :: Ix i => (i, i) -> Int -> i -> Int
+safeIndex (l,u) n i = let i' = unsafeIndex (l,u) i
+                      in if (0 <= i') && (i' < n)
+                         then i'
+                         else error "Error in array index"
 
 {-# INLINE unsafeAt #-}
 unsafeAt :: Ix i => Array i e -> Int -> e
-unsafeAt (Array _ _ arr#) (I# i#) =
+unsafeAt (Array _ _ _ arr#) (I# i#) =
     case indexArray# arr# i# of (# e #) -> e
 
 -- | The bounds with which an array was constructed.
 {-# INLINE bounds #-}
 bounds :: Ix i => Array i e -> (i,i)
-bounds (Array l u _) = (l,u)
+bounds (Array l u _ _) = (l,u)
+
+-- | The number of elements in the array.
+{-# INLINE numElements #-}
+numElements :: Ix i => Array i e -> Int
+numElements (Array _ _ n _) = n
 
 -- | The list of indices of an array in ascending order.
 {-# INLINE indices #-}
 indices :: Ix i => Array i e -> [i]
-indices (Array l u _) = range (l,u)
+indices (Array l u _ _) = range (l,u)
 
 -- | The list of elements of an array in index order.
 {-# INLINE elems #-}
 elems :: Ix i => Array i e -> [e]
-elems arr@(Array l u _) =
-    [unsafeAt arr i | i <- [0 .. rangeSize (l,u) - 1]]
+elems arr@(Array l u n _) =
+    [unsafeAt arr i | i <- [0 .. n - 1]]
 
 -- | The list of associations of an array in index order.
 {-# INLINE assocs #-}
 assocs :: Ix i => Array i e -> [(i, e)]
-assocs arr@(Array l u _) =
-    [(i, unsafeAt arr (unsafeIndex (l,u) i)) | i <- range (l,u)]
+assocs arr@(Array l u _ _) =
+    [(i, arr ! i) | i <- range (l,u)]
 
 -- | The 'accumArray' deals with repeated indices in the association
 -- list using an /accumulating function/ which combines the values of
@@ -455,21 +493,27 @@ accumArray :: Ix i
        -> [(i, a)]             -- ^ association list
        -> Array i e
 accumArray f init (l,u) ies =
-    unsafeAccumArray f init (l,u) [(index (l,u) i, e) | (i, e) <- ies]
+    let n = safeRangeSize (l,u)
+    in unsafeAccumArray' f init (l,u) n
+                         [(safeIndex (l,u) n i, e) | (i, e) <- ies]
 
 {-# INLINE unsafeAccumArray #-}
 unsafeAccumArray :: Ix i => (e -> a -> e) -> e -> (i,i) -> [(Int, a)] -> Array i e
-unsafeAccumArray f init (l,u) ies = runST (ST $ \s1# ->
-    case rangeSize (l,u)                of { I# n# ->
+unsafeAccumArray f init b ies = unsafeAccumArray' f init b (rangeSize b) ies
+
+{-# INLINE unsafeAccumArray' #-}
+unsafeAccumArray' :: Ix i => (e -> a -> e) -> e -> (i,i) -> Int -> [(Int, a)] -> Array i e
+unsafeAccumArray' f init (l,u) n@(I# n#) ies = runST (ST $ \s1# ->
     case newArray# n# init s1#          of { (# s2#, marr# #) ->
-    foldr (adjust f marr#) (done l u marr#) ies s2# }})
+    foldr (adjust f marr#) (done l u n marr#) ies s2# })
 
 {-# INLINE adjust #-}
 adjust :: (e -> a -> e) -> MutableArray# s e -> (Int, a) -> STRep s b -> STRep s b
 adjust f marr# (I# i#, new) next s1# =
-    case readArray# marr# i# s1#        of { (# s2#, old #) ->
-    case writeArray# marr# i# (f old new) s2# of { s3# ->
-    next s3# }}
+    case readArray# marr# i# s1# of
+        (# s2#, old #) ->
+            case writeArray# marr# i# (f old new) s2# of
+                s3# -> next s3#
 
 -- | Constructs an array identical to the first argument except that it has
 -- been updated by the associations in the right argument.
@@ -484,14 +528,14 @@ adjust f marr# (I# i#, new) next s1# =
 -- but GHC's implementation uses the last association for each index.
 {-# INLINE (//) #-}
 (//) :: Ix i => Array i e -> [(i, e)] -> Array i e
-arr@(Array l u _) // ies =
-    unsafeReplace arr [(index (l,u) i, e) | (i, e) <- ies]
+arr@(Array l u n _) // ies =
+    unsafeReplace arr [(safeIndex (l,u) n i, e) | (i, e) <- ies]
 
 {-# INLINE unsafeReplace #-}
 unsafeReplace :: Ix i => Array i e -> [(Int, e)] -> Array i e
-unsafeReplace arr@(Array l u _) ies = runST (do
-    STArray _ _ marr# <- thawSTArray arr
-    ST (foldr (fill marr#) (done l u marr#) ies))
+unsafeReplace arr ies = runST (do
+    STArray l u n marr# <- thawSTArray arr
+    ST (foldr (fill marr#) (done l u n marr#) ies))
 
 -- | @'accum' f@ takes an array and an association list and accumulates
 -- pairs from the list into the array with the accumulating function @f@.
@@ -501,19 +545,19 @@ unsafeReplace arr@(Array l u _) ies = runST (do
 --
 {-# INLINE accum #-}
 accum :: Ix i => (e -> a -> e) -> Array i e -> [(i, a)] -> Array i e
-accum f arr@(Array l u _) ies =
-    unsafeAccum f arr [(index (l,u) i, e) | (i, e) <- ies]
+accum f arr@(Array l u n _) ies =
+    unsafeAccum f arr [(safeIndex (l,u) n i, e) | (i, e) <- ies]
 
 {-# INLINE unsafeAccum #-}
 unsafeAccum :: Ix i => (e -> a -> e) -> Array i e -> [(Int, a)] -> Array i e
-unsafeAccum f arr@(Array l u _) ies = runST (do
-    STArray _ _ marr# <- thawSTArray arr
-    ST (foldr (adjust f marr#) (done l u marr#) ies))
+unsafeAccum f arr ies = runST (do
+    STArray l u n marr# <- thawSTArray arr
+    ST (foldr (adjust f marr#) (done l u n marr#) ies))
 
 {-# INLINE amap #-}
 amap :: Ix i => (a -> b) -> Array i a -> Array i b
-amap f arr@(Array l u _) =
-    unsafeArray (l,u) [(i, f (unsafeAt arr i)) | i <- [0 .. rangeSize (l,u) - 1]]
+amap f arr@(Array l u n _) =
+    unsafeArray' (l,u) n [(i, f (unsafeAt arr i)) | i <- [0 .. n - 1]]
 
 -- | 'ixmap' allows for transformations on array indices.
 -- It may be thought of as providing function composition on the right
@@ -524,14 +568,14 @@ amap f arr@(Array l u _) =
 {-# INLINE ixmap #-}
 ixmap :: (Ix i, Ix j) => (i,i) -> (i -> j) -> Array j e -> Array i e
 ixmap (l,u) f arr =
-    unsafeArray (l,u) [(unsafeIndex (l,u) i, arr ! f i) | i <- range (l,u)]
+    array (l,u) [(i, arr ! f i) | i <- range (l,u)]
 
 {-# INLINE eqArray #-}
 eqArray :: (Ix i, Eq e) => Array i e -> Array i e -> Bool
-eqArray arr1@(Array l1 u1 _) arr2@(Array l2 u2 _) =
-    if rangeSize (l1,u1) == 0 then rangeSize (l2,u2) == 0 else
+eqArray arr1@(Array l1 u1 n1 _) arr2@(Array l2 u2 n2 _) =
+    if n1 == 0 then n2 == 0 else
     l1 == l2 && u1 == u2 &&
-    and [unsafeAt arr1 i == unsafeAt arr2 i | i <- [0 .. rangeSize (l1,u1) - 1]]
+    and [unsafeAt arr1 i == unsafeAt arr2 i | i <- [0 .. n1 - 1]]
 
 {-# INLINE cmpArray #-}
 cmpArray :: (Ix i, Ord e) => Array i e -> Array i e -> Ordering
@@ -539,13 +583,14 @@ cmpArray arr1 arr2 = compare (assocs arr1) (assocs arr2)
 
 {-# INLINE cmpIntArray #-}
 cmpIntArray :: Ord e => Array Int e -> Array Int e -> Ordering
-cmpIntArray arr1@(Array l1 u1 _) arr2@(Array l2 u2 _) =
-    if rangeSize (l1,u1) == 0 then if rangeSize (l2,u2) == 0 then EQ else LT else
-    if rangeSize (l2,u2) == 0 then GT else
-    case compare l1 l2 of
-        EQ    -> foldr cmp (compare u1 u2) [0 .. rangeSize (l1, min u1 u2) - 1]
-        other -> other
-    where
+cmpIntArray arr1@(Array l1 u1 n1 _) arr2@(Array l2 u2 n2 _) =
+    if n1 == 0 then
+        if n2 == 0 then EQ else LT
+    else if n2 == 0 then GT
+    else case compare l1 l2 of
+             EQ    -> foldr cmp (compare u1 u2) [0 .. (n1 `min` n2) - 1]
+             other -> other
+  where
     cmp i rest = case compare (unsafeAt arr1 i) (unsafeAt arr2 i) of
         EQ    -> rest
         other -> other
@@ -606,34 +651,38 @@ might be different, though.
 {-# INLINE newSTArray #-}
 newSTArray :: Ix i => (i,i) -> e -> ST s (STArray s i e)
 newSTArray (l,u) init = ST $ \s1# ->
-    case rangeSize (l,u)                of { I# n# ->
+    case safeRangeSize (l,u)            of { n@(I# n#) ->
     case newArray# n# init s1#          of { (# s2#, marr# #) ->
-    (# s2#, STArray l u marr# #) }}
+    (# s2#, STArray l u n marr# #) }}
 
 {-# INLINE boundsSTArray #-}
 boundsSTArray :: STArray s i e -> (i,i)  
-boundsSTArray (STArray l u _) = (l,u)
+boundsSTArray (STArray l u _ _) = (l,u)
+
+{-# INLINE numElementsSTArray #-}
+numElementsSTArray :: STArray s i e -> Int
+numElementsSTArray (STArray _ _ n _) = n
 
 {-# INLINE readSTArray #-}
 readSTArray :: Ix i => STArray s i e -> i -> ST s e
-readSTArray marr@(STArray l u _) i =
-    unsafeReadSTArray marr (index (l,u) i)
+readSTArray marr@(STArray l u n _) i =
+    unsafeReadSTArray marr (safeIndex (l,u) n i)
 
 {-# INLINE unsafeReadSTArray #-}
 unsafeReadSTArray :: Ix i => STArray s i e -> Int -> ST s e
-unsafeReadSTArray (STArray _ _ marr#) (I# i#) = ST $ \s1# ->
-    readArray# marr# i# s1#
+unsafeReadSTArray (STArray _ _ _ marr#) (I# i#)
+    = ST $ \s1# -> readArray# marr# i# s1#
 
 {-# INLINE writeSTArray #-}
 writeSTArray :: Ix i => STArray s i e -> i -> e -> ST s () 
-writeSTArray marr@(STArray l u _) i e =
-    unsafeWriteSTArray marr (index (l,u) i) e
+writeSTArray marr@(STArray l u n _) i e =
+    unsafeWriteSTArray marr (safeIndex (l,u) n i) e
 
 {-# INLINE unsafeWriteSTArray #-}
 unsafeWriteSTArray :: Ix i => STArray s i e -> Int -> e -> ST s () 
-unsafeWriteSTArray (STArray _ _ marr#) (I# i#) e = ST $ \s1# ->
-    case writeArray# marr# i# e s1#     of { s2# ->
-    (# s2#, () #) }
+unsafeWriteSTArray (STArray _ _ _ marr#) (I# i#) e = ST $ \s1# ->
+    case writeArray# marr# i# e s1# of
+        s2# -> (# s2#, () #)
 \end{code}
 
 
@@ -645,8 +694,7 @@ unsafeWriteSTArray (STArray _ _ marr#) (I# i#) e = ST $ \s1# ->
 
 \begin{code}
 freezeSTArray :: Ix i => STArray s i e -> ST s (Array i e)
-freezeSTArray (STArray l u marr#) = ST $ \s1# ->
-    case rangeSize (l,u)                of { I# n# ->
+freezeSTArray (STArray l u n@(I# n#) marr#) = ST $ \s1# ->
     case newArray# n# arrEleBottom s1#  of { (# s2#, marr'# #) ->
     let copy i# s3# | i# ==# n# = s3#
                     | otherwise =
@@ -655,17 +703,16 @@ freezeSTArray (STArray l u marr#) = ST $ \s1# ->
             copy (i# +# 1#) s5# }} in
     case copy 0# s2#                    of { s3# ->
     case unsafeFreezeArray# marr'# s3#  of { (# s4#, arr# #) ->
-    (# s4#, Array l u arr# #) }}}}
+    (# s4#, Array l u n arr# #) }}}
 
 {-# INLINE unsafeFreezeSTArray #-}
 unsafeFreezeSTArray :: Ix i => STArray s i e -> ST s (Array i e)
-unsafeFreezeSTArray (STArray l u marr#) = ST $ \s1# ->
+unsafeFreezeSTArray (STArray l u n marr#) = ST $ \s1# ->
     case unsafeFreezeArray# marr# s1#   of { (# s2#, arr# #) ->
-    (# s2#, Array l u arr# #) }
+    (# s2#, Array l u n arr# #) }
 
 thawSTArray :: Ix i => Array i e -> ST s (STArray s i e)
-thawSTArray (Array l u arr#) = ST $ \s1# ->
-    case rangeSize (l,u)                of { I# n# ->
+thawSTArray (Array l u n@(I# n#) arr#) = ST $ \s1# ->
     case newArray# n# arrEleBottom s1#  of { (# s2#, marr# #) ->
     let copy i# s3# | i# ==# n# = s3#
                     | otherwise =
@@ -673,11 +720,11 @@ thawSTArray (Array l u arr#) = ST $ \s1# ->
             case writeArray# marr# i# e s3# of { s4# ->
             copy (i# +# 1#) s4# }} in
     case copy 0# s2#                    of { s3# ->
-    (# s3#, STArray l u marr# #) }}}
+    (# s3#, STArray l u n marr# #) }}
 
 {-# INLINE unsafeThawSTArray #-}
 unsafeThawSTArray :: Ix i => Array i e -> ST s (STArray s i e)
-unsafeThawSTArray (Array l u arr#) = ST $ \s1# ->
+unsafeThawSTArray (Array l u n arr#) = ST $ \s1# ->
     case unsafeThawArray# arr# s1#      of { (# s2#, marr# #) ->
-    (# s2#, STArray l u marr# #) }
+    (# s2#, STArray l u n marr# #) }
 \end{code}