10 mapindexed, mapindexed2, mapindexed3,
11 -- zipArr, sumArr, scaleArr,
17 Vector (Vector), toVector, fromVector, listVector, vectorList, vector,
18 zipVector, scaleVector, sumVector, vectorNorm2, vectorSize,
20 Matrix (Matrix), toMatrix, fromMatrix, listMatrix, matrixList, matrix,
21 zipMatrix, scaleMatrix, sumMatrix,
28 -- showsVecList, showsMatList
34 --import IOExtensions(unsafePerformIO)
@Vector@ and @Matrix@ are 1-based arrays that are read and shown in the form of lists.
-- | A one-dimensional array indexed by 'Int', wrapped in its own type.
-- Only 'Eq' is derived; 'Show' is deliberately not derived here.
data Vector a = Vector (Array Int a)
    deriving (Eq)
42 toVector :: Array Int a -> Vector a
-- | Unwrap a 'Vector', exposing the underlying array.
fromVector :: Vector a -> Array Int a
fromVector v = case v of
    Vector arr -> arr
-- | Mapping over a vector maps over the wrapped array.
instance Functor Vector where
    fmap f = toVector . fmap f . fromVector
51 {-instance Eq a => Eq (Vector a) where
52 -- (Vector x) == (Vector y) = x == y
-- | A vector is shown exactly like the list of its elements.
instance Show a => Show (Vector a) where
    showsPrec prec = showsPrec prec . elems . fromVector
-- | A vector is read from list syntax and built with 'listVector'.
-- The precedence argument is not used.
instance Read a => Read (Vector a) where
    readsPrec _ = readParen False parseAsList
      where
        parseAsList input = do
            (xs, rest) <- reads input
            return (listVector xs, rest)
62 instance Num b => Num (Vector b) where
63 (+) = zipVector "+" (+)
64 (-) = zipVector "-" (-)
68 -- (*) = matMult -- works only for matrices!
69 -- fromInteger = fmap fromInteger
Convert a list to a 1-based vector.
-- | Build a vector from a list.  The indices run from 1 to the length of
-- the list.
listVector :: [a] -> Vector a
listVector xs = toVector (listArray (1, n) xs)
  where
    n = length xs
-- | The elements of a vector as a list, in index order.
vectorList :: Vector a -> [a]
vectorList (Vector arr) = elems arr
-- | Build a vector from bounds and an association list, in the manner of
-- 'array'.  The lower bound must be 1; any other value raises an error.
vector :: (Int, Int) -> [(Int, a)] -> Vector a
vector (lo, hi) assocList =
    if lo == 1
        then toVector (array (lo, hi) assocList)
        else error "vector: l != 1"
-- | Combine two vectors element by element with @f@.
--
-- The first argument is a label for the operation; it appears only in the
-- error raised when the two vectors do not have identical bounds.
zipVector :: String -> (b -> c -> d) -> Vector b -> Vector c -> Vector d
zipVector opName f (Vector xs) (Vector ys) =
    if shape == bounds ys
        then vector shape [(i, f (xs ! i) (ys ! i)) | i <- indices xs]
        else error ("zipVector: " ++ opName ++ ": unconformable arrays")
  where
    shape = bounds xs
-- | Multiply every element of a vector by the given factor (the factor is
-- the right-hand operand of each multiplication).
scaleVector :: Num a => a -> Vector a -> Vector a
scaleVector factor v = fmap (\e -> e * factor) v
-- | Sum of all elements of a vector.
sumVector :: Num a => Vector a -> a
sumVector (Vector arr) = sum (elems arr)
-- | Inner product of a vector with itself, i.e. the sum of the squared
-- elements.  No square root is taken.
vectorNorm2 :: Num a => Vector a -> a
vectorNorm2 v = v `inner` v
-- | Number of elements in a vector, computed from its bounds.
vectorSize :: Vector a -> Int
vectorSize = rangeSize . bounds . fromVector
-- | A two-dimensional array indexed by @(row, column)@ pairs of 'Int',
-- wrapped in its own type.  Only 'Eq' is derived.
data Matrix a = Matrix (Array (Int, Int) a)
    deriving (Eq)
-- | Wrap a two-dimensional array as a 'Matrix'.  No bounds check is made.
toMatrix :: Array (Int, Int) a -> Matrix a
toMatrix = Matrix
-- | Unwrap a 'Matrix', exposing the underlying array.
fromMatrix :: Matrix a -> Array (Int, Int) a
fromMatrix m = case m of
    Matrix arr -> arr
-- | Mapping over a matrix maps over the wrapped array.
instance Functor Matrix where
    fmap f = toMatrix . fmap f . fromMatrix
118 --instance Eq a => Eq (Matrix a) where
119 -- (Matrix x) == (Matrix y) = x == y
-- | A matrix is shown as the list of its rows, rendered by 'vertl'.
-- The precedence argument is ignored.
instance Show a => Show (Matrix a) where
    showsPrec _ = vertl . matrixList
-- | Render a list vertically: the items are enclosed in square brackets and
-- separated by a comma followed by a line break.  An empty list is
-- rendered as @[]@.
vertl :: Show a => [a] -> ShowS
vertl items = case items of
    []       -> showString "[]"
    (r : rs) -> showChar '[' . shows r . foldr nextItem (showChar ']') rs
  where
    -- Prefix one further item with the ",\n" separator.
    nextItem r rest = showString ",\n" . shows r . rest
-- | A matrix is read from nested-list syntax and built with 'listMatrix'.
-- The precedence argument is not used.
instance Read a => Read (Matrix a) where
    readsPrec _ = readParen False parseAsRows
      where
        parseAsRows input = do
            (rows, rest) <- reads input
            return (listMatrix rows, rest)
133 instance Num b => Num (Matrix b) where
134 (+) = zipMatrix "+" (+)
135 (-) = zipMatrix "-" (-)
139 x * y = toMatrix (matMult (fromMatrix x) (fromMatrix y)) -- works only for matrices!
140 -- fromInteger = fmap fromInteger
143 Convert a nested list to a matrix.
-- | Build a 1-based matrix from a list of rows.
--
-- The number of columns is taken from the first row; the remaining rows
-- are not checked against it, so callers must supply rows of equal length.
--
-- An empty list of rows yields an empty matrix with bounds
-- @((1,1),(0,0))@.  Indexing the first row unconditionally used to fail on
-- empty input, which also made it impossible to read back the @[]@ that
-- 'matrixList' produces for a matrix without rows.
listMatrix :: [[a]] -> Matrix a
listMatrix [] = Matrix (listArray ((1, 1), (0, 0)) [])
listMatrix rows@(firstRow : _) =
    Matrix (listArray ((1, 1), (length rows, length firstRow)) (concat rows))
-- | Convert a matrix to a list of rows, each row being the list of its
-- elements in column order.
matrixList :: Matrix a -> [[a]]
matrixList (Matrix arr) = map rowAt (range (rowLo, rowHi))
  where
    ((rowLo, colLo), (rowHi, colHi)) = bounds arr
    rowAt i = map (\j -> arr ! (i, j)) (range (colLo, colHi))
-- | Build a matrix from bounds and an association list, in the manner of
-- 'array'.  Both lower bounds must be 1; anything else raises an error.
matrix :: ((Int, Int), (Int, Int)) -> [((Int, Int), a)] -> Matrix a
matrix ((rowLo, colLo), (rowHi, colHi)) assocList =
    if rowLo == 1 && colLo == 1
        then toMatrix (array ((rowLo, colLo), (rowHi, colHi)) assocList)
        else error "matrix: l != 1"
-- | Combine two matrices element by element with @f@.
--
-- The first argument is a label for the operation; it appears only in the
-- error raised when the two matrices do not have identical bounds.
zipMatrix :: String -> (b -> c -> d) -> Matrix b -> Matrix c -> Matrix d
zipMatrix opName f (Matrix xs) (Matrix ys) =
    if shape == bounds ys
        then matrix shape [(ij, f (xs ! ij) (ys ! ij)) | ij <- indices xs]
        else error ("zipMatrix: " ++ opName ++ ": unconformable arrays")
  where
    shape = bounds xs
-- | Multiply every element of a matrix by the given factor (the factor is
-- the right-hand operand of each multiplication).
scaleMatrix :: Num a => a -> Matrix a -> Matrix a
scaleMatrix factor m = fmap (\e -> e * factor) m
-- | Sum of all elements of a matrix.
sumMatrix :: Num a => Matrix a -> a
sumMatrix (Matrix arr) = sum (elems arr)
-- | Like 'zipWith', but the two lists must have the same length.
--
-- When one list runs out before the other an error is raised; the first
-- argument is a label that is included in the error message.
safezipWith :: String -> (a -> b -> c) -> [a] -> [b] -> [c]
safezipWith label f = go
  where
    go [] [] = []
    go (a : as) (b : bs) = f a b : go as bs
    go _ _ = error ("safezipWith: " ++ label ++ ": unconformable vectors")
-- | Pair up the elements of two lists via 'safezipWith', labelled @"(,)"@.
safezip :: [a] -> [b] -> [(a,b)]
safezip xs ys = safezipWith "(,)" (\a b -> (a, b)) xs ys
-- | Transpose a matrix: element @(i,j)@ of the input becomes element
-- @(j,i)@ of the result.
--
-- The result is built with 'matrix'.  The lower bounds are handed over in
-- their original order and only the upper bounds are swapped.
trMatrix :: Matrix a -> Matrix a
trMatrix (Matrix arr) = matrix ((rowLo, colLo), (colHi, rowHi)) flipped
  where
    ((rowLo, colLo), (rowHi, colHi)) = bounds arr
    flipped = [((j, i), e) | ((i, j), e) <- assocs arr]
-- | Extract row @r@ of a two-dimensional array as a one-dimensional array
-- indexed by the column range of the input.
row :: (Ix a, Ix b) => a -> Array (a,b) c -> Array b c
row r arr =
    let ((_, colLo), (_, colHi)) = bounds arr
    in ixmap (colLo, colHi) (\c -> (r, c)) arr
-- | Combine two arrays with identical bounds element by element.
--
-- The first argument is a label for the operation; it appears only in the
-- error raised when the bounds differ.
zipArr :: (Ix a) => String -> (b -> c -> d) -> Array a b -> Array a c -> Array a d
zipArr opName f xs ys
    | shape /= bounds ys = error ("zipArr: " ++ opName ++ ": unconformable arrays")
    | otherwise = array shape (map combine (indices xs))
  where
    shape = bounds xs
    combine i = (i, f (xs ! i) (ys ! i))
193 Valid only for b -> c -> b functions.
-- | Variant of 'zipArr' built on 'accum': every element of the second array
-- is folded into the element of the first array at the same index.  The
-- result therefore has the element type of the first array.
--
-- The first argument is a label used only in the error raised when the
-- bounds of the two arrays differ.
zipArr' :: (Ix a) => String -> (b -> c -> b) -> Array a b -> Array a c -> Array a b
zipArr' opName f xs ys =
    if bounds xs == bounds ys
        then accum f xs (assocs ys)
        else error ("zipArr': " ++ opName ++ ": unconformable arrays")
201 Overload arithmetical operators to work on lists.
204 instance Num a => Num [a] where
205 (+) = safezipWith "+" (+)
206 (-) = safezipWith "-" (-)
211 -- fromInteger = undefined
215 sum1 :: (Num a) => [a] -> a
218 --main = print (sum1 [[4,1,1], [5,1,2], [6,1,3,4]])
-- | Map a function over the elements of a doubly nested functor.
map2 :: (Functor f, Functor g) => (a -> b) -> f (g a) -> f (g b)
map2 = fmap . fmap
-- | Map a function over the elements of a triply nested functor.
map3 :: (Functor f, Functor g, Functor h) => (a -> b) -> f (g (h a)) -> f (g (h b))
map3 = fmap . fmap . fmap
226 Map function f at position n only. Out of range indices are silently
230 mapat n f x = mapat1 0 f x where
232 mapat1 i f (x:xs) = (if i == n then f x else x) : mapat1 (i + 1) f xs
-- | Apply @f@ only at position @(i, j)@ of a doubly nested list, by
-- nesting two 'mapat' calls.
mapat2 (i, j) f = mapat i (mapat j f)
-- | Apply @f@ only at position @(i, j, k)@ of a triply nested list, by
-- nesting three 'mapat' calls.
mapat3 (i, j, k) f = mapat i (mapat j (mapat k f))
237 -- main = print (mapat 2 (10+) [1,2,3,4])
238 -- main = print (mapat2 (1,0) (1000+) ginp)
239 -- main = print (mapat3 (1,0,1) (1000+) gw)
-- | Map over a list, passing each element's position to the function as
-- well.  Positions are counted from 0.
mapindexed :: Num n => (n -> a -> b) -> [a] -> [b]
mapindexed f xs = zipWith f (iterate (+ 1) 0) xs
-- | Two-level version of 'mapindexed': the function receives the pair
-- @(i, j)@ of outer and inner positions along with the element.
mapindexed2 f = mapindexed overRow
  where
    overRow i = mapindexed (\j -> f (i, j))
-- | Three-level version of 'mapindexed': the function receives the triple
-- @(i, j, k)@ of positions, outermost first, along with the element.
mapindexed3 f = mapindexed overPlane
  where
    overPlane i = mapindexed (overRow i)
    overRow i j = mapindexed (\k -> f (i, j, k))
250 -- main = print (mapindexed (\x y -> mapat (10+) [1,2,3,4] y) [1,2,3,4])
251 -- main = print (mapindexed2 (\(i,j) x -> 100*i + 10*j + x) ginp)
252 -- main = print (mapindexed3 (\(i,j,k) x -> 1000*i + 100*j + 10*k + x) gw)
257 Overload arithmetical operators to work on arrays.
260 instance (Ix a, Show a, Num b) => Num (Array a b) where
266 -- (*) = matMult -- works only for matrices!
267 -- fromInteger = map fromInteger
-- | Multiply every element of an array by the given factor (the factor is
-- the right-hand operand of each multiplication).
scaleArr :: (Ix i, Num a) => a -> Array i a -> Array i a
scaleArr factor arr = fmap (\e -> e * factor) arr
274 sumArr :: (Ix i, Num a) => Array i a -> a
-- | Number of elements in an array, computed from its bounds.
arraySize :: (Ix i) => Array i a -> Int
arraySize arr = rangeSize (bounds arr)
284 matMult :: (Ix a, Ix b, Ix c, Num d) =>
285 Array (a,b) d -> Array (b,c) d -> Array (a,c) d
286 matMult x y = array resultBounds
287 [((i,j), sum [x!(i,k) * y!(k,j) | k <- range (lj,uj)])
288 | i <- range (li,ui),
289 j <- range (lj',uj') ]
290 where ((li,lj),(ui,uj)) = bounds x
291 ((li',lj'),(ui',uj')) = bounds y
293 | (lj,uj)==(li',ui') = ((li,lj'),(ui,uj'))
294 | otherwise = error "matMult: incompatible bounds"
298 Inner product of two vectors.
301 inner :: Num a => Vector a -> Vector a -> a
302 inner (Vector v) (Vector w) = if b == bounds w
303 then sum [v!i * w!i | i <- range b]
304 else error "nn.inner: inconformable vectors"
Outer product of two vectors $v \cdot w^\mathrm{T}$.
-- | Outer product of two vectors: element @(i,j)@ of the result is
-- @v!i * w!j@.
--
-- Both vectors must have identical bounds, otherwise an error is raised.
-- The result is built with 'matrix'.
outerVector :: Num b => Vector b -> Vector b -> Matrix b
outerVector (Vector v) (Vector w)
    | bounds v == bounds w = matrix ((lo, lo'), (hi, hi')) products
    | otherwise = error "nn.outer: inconformable vectors"
  where
    (lo, hi) = bounds v
    (lo', hi') = bounds w
    products = [((i, j), v ! i * w ! j) | i <- range (lo, hi), j <- range (lo', hi')]
-- | Outer product of two arrays: element @(i,j)@ of the result is
-- @v!i * w!j@.
--
-- Both arrays must have identical bounds, otherwise an error is raised.
outerArr :: (Ix a, Num b) => Array a b -> Array a b -> Array (a,a) b
outerArr v w
    | bounds v == bounds w = array ((lo, lo'), (hi, hi')) products
    | otherwise = error "nn.outer: inconformable vectors"
  where
    (lo, hi) = bounds v
    (lo', hi') = bounds w
    products = [((i, j), v ! i * w ! j) | i <- range (lo, hi), j <- range (lo', hi')]
326 Inner product of a matrix and a vector.
329 matvec :: (Ix a, Num b) => Array (a,a) b -> Array a b -> Array a b
330 matvec w x | bounds x == (l',u') =
331 array (l,u) [(i, sum [w!(i,j) * x!j | j <- range (l',u')])
333 | otherwise = error "nn.matvec: inconformable arrays"
334 where ((l,l'),(u,u')) = bounds w
340 augment :: (Num a) => Vector a -> a -> Vector a
341 augment (Vector x) y = Vector (array (a,b') ((b',y) : assocs x))
342 where (a,b) = bounds x
346 Older approach (x!!i!!j fails in ghc-2.03).
-- | Build a 1-based matrix from a list of rows by indexing into the nested
-- list for every element (the older, slower counterpart of 'listMatrix').
--
-- The number of columns is taken from the first row.  An empty list of
-- rows yields an empty matrix with bounds @((1,1),(0,0))@; previously the
-- unconditional @x!!0@ failed on empty input.
toMatrix' :: [[a]] -> Matrix a
toMatrix' [] = Matrix (array ((1, 1), (0, 0)) [])
toMatrix' x@(firstRow : _) =
    Matrix (array ((1, 1), (u, u')) cells)
  where
    (u, u') = (length x, length firstRow)
    -- List positions are 0-based while matrix indices start at 1.
    cells = [((i, j), (x !! (i - 1)) !! (j - 1)) | i <- range (1, u), j <- range (1, u')]
-- | Pad a string on the left with spaces up to the given width.  Strings
-- that are already at least that wide are returned unchanged.
padleft :: Int -> String -> String
padleft width s = replicate (width - length s) ' ' ++ s
-- | Render every element with 'showFFloat' using @n@ digits after the
-- decimal point, then left-pad all strings to the width of the longest one.
padMatrix :: RealFloat a => Int -> Matrix a -> Matrix String
padMatrix n x = fmap (padleft width) rendered
  where
    rendered = fmap (\v -> showFFloat (Just n) v "") x
    width = maximum (map length (elems (fromMatrix rendered)))
371 showsVector :: (RealFloat a) => Int -> Vector a -> ShowS
372 showsVector n x z1 = let x' = padArr n x
374 concat (fmap (\ (i, s) -> if i == u then s ++ "\n" else s ++ " ") (assocs x')) ++ z1
-- | Render a matrix as text: the elements are formatted by 'padMatrix' with
-- @n@ digits after the decimal point, separated by single spaces, and each
-- row ends with a newline.  The string @z1@ is appended after the rendering.
--
-- 'padMatrix' returns a 'Matrix', so it is unwrapped with 'fromMatrix'
-- before 'bounds' and 'assocs' are applied; these functions work on the
-- underlying array, not on the wrapper.
showsMatrix :: RealFloat a => Int -> Matrix a -> ShowS
showsMatrix n x z1 = concatMap cell (assocs cells) ++ z1
  where
    cells = fromMatrix (padMatrix n x)
    (_, (_, lastCol)) = bounds cells
    -- The last column of a row is terminated by a newline, any other
    -- column by a space.
    cell ((_, j), s)
        | j == lastCol = s ++ "\n"
        | otherwise = s ++ " "
-- | Render a collection of vectors one after another with 'showsVector',
-- using @n@ as its first argument, in front of the given string.
showsVecList n = flip (foldr (showsVector n))
-- | Render a collection of matrices one after another with 'showsMatrix',
-- using @n@ as its first argument, in front of the given string.
showsMatList n = flip (foldr (showsMatrix n))
391 --spy :: Show a => String -> a -> a
392 --spy msg x = trace ('<' : msg ++ ": " ++ shows x ">\n") x
393 --spy x = seq (unsafePerformIO (putStr ('<' : shows x ">\n"))) x
394 --spy x = traceShow "z" x