[project @ 1997-07-27 00:43:10 by sof]
[ghc-hetmet.git] / ghc / tests / programs / cholewo-eval / Arr.lhs
1
2 \begin{code}
3 module Arr (
4   module Array,
5
6   safezipWith, safezip,
7   row,
8   sum1, map2, map3,
9   mapat, mapat2, mapat3,
10   mapindexed, mapindexed2, mapindexed3,
11 --  zipArr, sumArr, scaleArr,
12   arraySize,
13
14   matvec, inner, 
15   outerVector,
16   
17   Vector (Vector), toVector, fromVector, listVector, vectorList, vector, 
18   zipVector, scaleVector, sumVector, vectorNorm2, vectorSize,
19   
20   Matrix (Matrix), toMatrix, fromMatrix, listMatrix, matrixList, matrix, 
21   zipMatrix, scaleMatrix, sumMatrix,
22
23   augment,
24   trMatrix,
25
26 --   showsVector,
27 --   showsMatrix,
28 -- showsVecList, showsMatList
29 --  spy,
30 ) where
31 import Array
32 import Numeric
33 --import Trace
34 --import IOExtensions(unsafePerformIO)
35 \end{code}
36
37 @Vector@ and @Matrix@ are 1-based arrays with read/show in form of Lists.
38
39 \begin{code}
40 data Vector a = Vector (Array Int a) deriving (Eq) --, Show)
41
42 toVector :: Array Int a -> Vector a
43 toVector x = Vector x
44
45 fromVector :: Vector a -> Array Int a
46 fromVector (Vector x) = x
47
48 instance Functor (Vector) where
49   map fn x = toVector (map fn (fromVector x))    
50
51 {-instance Eq a => Eq (Vector a) where
52 --  (Vector x) == (Vector y) = x == y
53 -}
54
55 instance Show a => Show (Vector a) where
56   showsPrec p x = showsPrec p (elems (fromVector x))
57
58 instance Read a => Read (Vector a) where
59   readsPrec p = readParen False 
60                   (\r -> [(listVector s, t) | (s, t) <- reads r])
61
62 instance Num b => Num (Vector b) where
63   (+) = zipVector "+" (+)
64   (-) = zipVector "-" (-)
65   negate = map negate
66   abs = map abs
67   signum = map signum
68 --   (*) = matMult -- works only for matrices!
69 --  fromInteger = map fromInteger
70 \end{code}
71
72
73 Convert a list to 1-based vector.
74
75 \begin{code}
76 listVector :: [a] -> Vector a
77 listVector x = toVector (listArray (1,length x) x)
78
79 vectorList :: Vector a -> [a]
80 vectorList = elems . fromVector
81
82 vector (l,u) x | l == 1 = toVector (array (l,u) x)
83                | otherwise = error "vector: l != 1"
84                
85 zipVector :: String -> (b -> c -> d) -> Vector b -> Vector c -> Vector d
86 zipVector s f (Vector a) (Vector b) 
87   | bounds a == bounds b = vector (bounds a) [(i, f (a!i) (b!i)) | i <- indices a]
88   | otherwise            = error ("zipVector: " ++ s ++ ": unconformable arrays")
89
90 scaleVector :: Num a => a -> Vector a -> Vector a
91 scaleVector a = map (* a)
92
93 sumVector :: Num a => Vector a -> a
94 sumVector = sum . elems . fromVector
95
96 vectorNorm2 :: Num a => Vector a -> a
97 vectorNorm2 x = inner x x
98
99 vectorSize :: Vector a -> Int
100 vectorSize (Vector x) = rangeSize (bounds x)
101
102 \end{code}
103
104 ==============
105
106 \begin{code}
107 data Matrix a = Matrix (Array (Int, Int) a) deriving Eq
108
109 toMatrix :: Array (Int, Int) a -> Matrix a
110 toMatrix x = Matrix x
111
112 fromMatrix :: Matrix a -> Array (Int, Int) a
113 fromMatrix (Matrix x) = x
114
115 instance Functor (Matrix) where
116   map fn x = toMatrix (map fn (fromMatrix x))    
117
118 --instance Eq a => Eq (Matrix a) where
119 --  (Matrix x) == (Matrix y) = x == y
120
121 instance Show a => Show (Matrix a) where
122   showsPrec p x = vertl (matrixList x)
123   
124 vertl [] = showString "[]"
125 vertl (x:xs) = showChar '[' . shows x . showl xs 
126     where showl [] = showChar ']'
127           showl (x:xs) = showString ",\n" . shows x . showl xs
128
129 instance Read a => Read (Matrix a) where
130     readsPrec p = readParen False
131                   (\r -> [(listMatrix s, t) | (s, t) <- reads r])
132
133 instance Num b => Num (Matrix b) where
134   (+) = zipMatrix "+" (+)
135   (-) = zipMatrix "-" (-)
136   negate = map negate
137   abs = map abs
138   signum = map signum
139   x * y = toMatrix (matMult (fromMatrix x) (fromMatrix y)) -- works only for matrices!
140 --  fromInteger = map fromInteger
141 \end{code}
142
143 Convert a nested list to a matrix.
144
145 \begin{code}
146 listMatrix :: [[a]] -> Matrix a
147 listMatrix x = Matrix (listArray ((1, 1), (length x, length (x!!0))) (concat x))
148
149 matrixList :: Matrix a -> [[a]]
150 matrixList (Matrix x) = [ [x!(i,j) | j <- range (l',u')] | i <- range (l,u)]
151          where ((l,l'),(u,u')) = bounds x
152
153 matrix ((l,l'),(u,u')) x | l == 1 && l' == 1 = toMatrix (array ((l,l'),(u,u')) x)
154                          | otherwise = error "matrix: l != 1"
155
156 zipMatrix :: String -> (b -> c -> d) -> Matrix b -> Matrix c -> Matrix d
157 zipMatrix s f (Matrix a) (Matrix b) 
158   | bounds a == bounds b = matrix (bounds a) [(i, f (a!i) (b!i)) | i <- indices a]
159   | otherwise            = error ("zipMatrix: " ++ s ++ ": unconformable arrays")
160
161 scaleMatrix :: Num a => a -> Matrix a -> Matrix a
162 scaleMatrix a = map (* a)
163
164 sumMatrix :: Num a => Matrix a -> a
165 sumMatrix = sum . elems . fromMatrix
166
167 \end{code}
168
169
170 ============
171
172 \begin{code}
173 safezipWith :: String -> (a -> b -> c) -> [a] -> [b] -> [c]
174 safezipWith _ _ [] [] = []
175 safezipWith s f (x:xs) (y:ys) = f x y : safezipWith s f xs ys
176 safezipWith s _ _ _ = error ("safezipWith: " ++ s ++ ": unconformable vectors")
177
178 safezip :: [a] -> [b] -> [(a,b)]
179 safezip = safezipWith "(,)" (,)
180
181 trMatrix :: Matrix a -> Matrix a
182 trMatrix (Matrix x) = matrix ((l,l'),(u',u)) [((j,i), x!(i,j)) | j <- range (l',u'), i <- range (l,u)]
183          where ((l,l'),(u,u')) = bounds x
184
185 row :: (Ix a, Ix b) => a -> Array (a,b) c -> Array b c
186 row i x = ixmap (l',u') (\j->(i,j)) x where ((l,l'),(u,u')) = bounds x
187
188 zipArr :: (Ix a) => String -> (b -> c -> d) -> Array a b -> Array a c -> Array a d
189 zipArr s f a b | bounds a == bounds b = array (bounds a) [(i, f (a!i) (b!i)) | i <- indices a]
190                | otherwise            = error ("zipArr: " ++ s ++ ": unconformable arrays")
191 \end{code}
192
193 Valid only for b -> c -> b functions.
194
195 \begin{code}
196 zipArr' :: (Ix a) => String -> (b -> c -> b) -> Array a b -> Array a c -> Array a b
197 zipArr' s f a b | bounds a == bounds b = accum f a (assocs b)
198                 | otherwise            = error ("zipArr': " ++ s ++ ": unconformable arrays")
199 \end{code}
200
201 Overload arithmetical operators to work on lists.
202
203 \begin{code}
204 instance Num a => Num [a] where
205   (+) = safezipWith "+" (+)
206   (-) = safezipWith "-" (-)
207   negate = map negate
208   abs = map abs
209   signum = map signum
210 --   (*) = undefined
211 --   fromInteger = undefined
212 \end{code}
213
214 \begin{code}
215 sum1 :: (Num a) => [a] -> a
216 sum1 = foldl1 (+)
217
218 --main = print (sum1 [[4,1,1], [5,1,2], [6,1,3,4]])
219 \end{code}
220
221 \begin{code}
222 map2 f = map (map f) 
223 map3 f = map (map2 f) 
224 \end{code}
225
226 Map function f at position n only.  Out of range indices are silently
227 ignored.
228
229 \begin{code}
230 mapat n f x = mapat1 0 f x where
231     mapat1 _ _ [] = []
232     mapat1 i f (x:xs) = (if i == n then f x else x) : mapat1 (i + 1) f xs
233
234 mapat2 (i,j) = mapat i . mapat j
235 mapat3 (i,j,k) = mapat i . mapat j . mapat k
236
237 -- main = print (mapat 2 (10+) [1,2,3,4])
238 -- main = print (mapat2 (1,0) (1000+) ginp)
239 -- main = print (mapat3 (1,0,1) (1000+) gw)
240 \end{code}
241
242 \begin{code}
243 mapindexed f x = mapindexed1 f 0 x where
244     mapindexed1 _ _ [] = []
245     mapindexed1 f n (x:xs) = f n x : mapindexed1 f (n + 1) xs
246
247 mapindexed2 f = mapindexed (\i -> mapindexed (\j -> f (i, j))) 
248 mapindexed3 f = mapindexed (\i -> mapindexed (\j -> mapindexed (\k -> f (i, j, k))))
249
250 -- main = print (mapindexed (\x y -> mapat (10+) [1,2,3,4] y) [1,2,3,4])
251 -- main = print (mapindexed2 (\(i,j) x -> 100*i + 10*j + x) ginp)
252 -- main = print (mapindexed3 (\(i,j,k) x -> 1000*i + 100*j + 10*k + x) gw)
253 \end{code}
254
255
256
257 Overload arithmetical operators to work on arrays.
258
259 \begin{code}
260 instance (Ix a, Show a, Num b) => Num (Array a b) where
261   (+) = zipArr "+" (+)
262   (-) = zipArr "-" (-)
263   negate = map negate
264   abs = map abs
265   signum = map signum
266 --   (*) = matMult -- works only for matrices!
267 --   fromInteger = map fromInteger
268 \end{code}
269
270 \begin{xcode}
271 scaleArr :: (Ix i, Num a) => a -> Array i a -> Array i a
272 scaleArr a = map (*a)
273
274 sumArr :: (Ix i, Num a) => Array i a -> a
275 sumArr = sum . elems
276 \end{xcode}
277
278 \begin{code}
279 arraySize :: (Ix i) => Array i a -> Int
280 arraySize = rangeSize . bounds
281 \end{code}
282
283 \begin{code}
284 matMult         :: (Ix a, Ix b, Ix c, Num d) =>
285                    Array (a,b) d -> Array (b,c) d -> Array (a,c) d
286 matMult x y     =  array resultBounds
287                          [((i,j), sum [x!(i,k) * y!(k,j) | k <- range (lj,uj)])
288                                        | i <- range (li,ui),
289                                          j <- range (lj',uj') ]
290         where ((li,lj),(ui,uj))         =  bounds x
291               ((li',lj'),(ui',uj'))     =  bounds y
292               resultBounds
293                 | (lj,uj)==(li',ui')    =  ((li,lj'),(ui,uj'))
294                 | otherwise             = error "matMult: incompatible bounds"
295 \end{code}
296
297
298 Inner product of two vectors.
299
300 \begin{code}
301 inner :: Num a => Vector a -> Vector a -> a
302 inner (Vector v) (Vector w) = if b == bounds w
303                then sum [v!i * w!i | i <- range b]
304                else error "nn.inner: inconformable vectors"
305             where b = bounds v
306 \end{code}
307
308 Outer product of two vectors $v \dot w^\mathrm{T}$.
309
310 \begin{code}
311 outerVector :: Num b => Vector b -> Vector b -> Matrix b
312 outerVector (Vector v) (Vector w) = if (l,u) == (l',u')
313                then matrix ((l,l'),(u,u')) [((i,j), v!i * w!j) | i <- range (l,u), j <- range (l',u')]
314                else error "nn.outer: inconformable vectors"
315             where ((l,u),(l',u')) = (bounds v, bounds w)
316 \end{code}
317
318 \begin{code}
319 outerArr :: (Ix a, Num b) => Array a b -> Array a b -> Array (a,a) b
320 outerArr v w = if (l,u) == (l',u')
321                then array ((l,l'),(u,u')) [((i,j), v!i * w!j) | i <- range (l,u), j <- range (l',u')]
322                else error "nn.outer: inconformable vectors"
323             where ((l,u),(l',u')) = (bounds v, bounds w)
324 \end{code}
325
326 Inner product of a matrix and a vector.
327
328 \begin{code}
329 matvec :: (Ix a, Num b) => Array (a,a) b -> Array a b -> Array a b
330 matvec w x | bounds x == (l',u') =
331                 array (l,u) [(i, sum [w!(i,j) * x!j | j <- range (l',u')]) 
332                                 | i <- range (l,u)]
333            | otherwise           = error "nn.matvec: inconformable arrays"
334          where ((l,l'),(u,u')) = bounds w
335 \end{code}
336
337 Append to a vector.
338
339 \begin{code}
340 augment :: (Num a) => Vector a -> a -> Vector a
341 augment (Vector x) y = Vector (array (a,b') ((b',y) : assocs x))
342             where (a,b) = bounds x
343                   b' = b + 1
344 \end{code}
345
346 Older approach (x!!i!!j fails in ghc-2.03).
347
348 \begin{code}
349 toMatrix' :: [[a]] -> Matrix a
350 toMatrix' x = Matrix (array ((1,1),(u,u')) [((i,j), (x!!(i-1))!!(j-1)) 
351                              | i <- range (1,u), j <- range (1,u')])
352           where (u,u') = (length x,length (x!!0))
353 \end{code}
354
355 Matrix 2D printout.
356
357 \begin{code}
358 padleft :: Int -> String -> String
359 padleft n x | n <= length x = x
360             | otherwise = replicate (n - length x) ' ' ++ x
361 \end{code}
362
363 \begin{code}
364 padMatrix :: RealFloat a => Int -> Matrix a -> Matrix String
365 padMatrix n x = let ss = map (\a -> showFFloat (Just n) a "") x 
366                     maxw = maximum (map length (elems (fromMatrix ss)))
367               in map (padleft maxw) ss
368 \end{code}
369
370 \begin{xcode}
371 showsVector :: (RealFloat a) => Int -> Vector a -> ShowS
372 showsVector n x z1 = let x' = padArr n x
373                          (l,u) = bounds x' in
374                   concat (map (\ (i, s) -> if i == u then s ++ "\n" else s ++ " ") (assocs x')) ++ z1
375 \end{xcode}
376
377 \begin{xcode}
378 showsMatrix :: RealFloat a => Int -> Matrix a -> ShowS
379 showsMatrix n x z1 = let x' = padMatrix n x
380                          ((l,l'),(u,u')) = bounds x' in
381                    concat (map (\ ((i,j), s) -> if j == u' then s ++ "\n" else s ++ " ") (assocs x')) ++ z1
382 \end{xcode}
383
384 {-
385 showsVecList n x s = foldr (showsVector n) s x
386 showsMatList n x s = foldr (showsMatrix n) s x
387 -}
388
389
390 \begin{code}
391 --spy :: Show a => String -> a -> a
392 --spy msg x = trace ('<' : msg ++ ": " ++ shows x ">\n") x
393 --spy x  = seq (unsafePerformIO (putStr ('<' : shows x ">\n"))) x
394 --spy x  = traceShow "z" x
395 \end{code}