3 Date started : 9th November 1992
This module implements backpropagation in pattern-presentation (online)
style, allowing an arbitrary number of layers. No sigmoid is applied on the
last layer, and 0.1 is added to the sigmoid derivative. Momentum is not
implemented.
TODO: move the matrix and vector operations into their own modules.
11 > module BpGen {-partain:(Dimensions(..),
12 > Layer(..), Layers(..),
14 > Weight(..), Weights(..),
15 > maxplace, classeg, calcerror, selectegs,
16 > trainweights, randweights)-} where
18 > import {-fool mkdependHS-}
20 > import List(transpose)
> -- | Infinite stream of pseudo-random Ints from a 'StdGen' seeded with the
> -- second argument.  The first argument is ignored.
> -- NOTE(review): presumably kept for call-site compatibility — confirm.
> randomInts :: a -> Int -> [Int]
> randomInts _ = randoms . mkStdGen
> -- | Infinite stream of pseudo-random Doubles from a 'StdGen' seeded with the
> -- second argument.  The first argument is ignored.
> randomDoubles :: a -> Int -> [Double]
> randomDoubles _ = randoms . mkStdGen
28 -------------------------------------------------------------------------------
30 -------------------------------------------------------------------------------
> -- | Network topology: the number of units in each layer.
> type Dimensions = [Int]
> -- | Activation vector of one layer (input and output layers included).
> type Layer  = [Double]
> -- | Activations of every layer of a network.
> type Layers = [Layer]
> -- | Connection matrix between two adjacent layers: one row per unit of the
> -- receiving layer.
> type Weight  = [[Double]]
> -- | All connection matrices of a network.
> type Weights = [Weight]
> -- | A training example: (attribute vector, class/target vector).
> type Eg = (Layer, Layer)
41 -------------------------------------------------------------------------------
43 -------------------------------------------------------------------------------
maxplace finds the position of the (first) maximum element in a list.
sublist subtracts two vectors element-wise; $$ performs element-wise vector
multiplication.
weivecmult multiplies a matrix by a vector.
classeg takes the weights of a network and an input vector, and produces
a list of the layers of the network after classification.
calcerror calculates the root mean squared error of the data set.
Also implemented: sqr and sig (the sigmoid function).
> -- | Zero-based index of the first occurrence of the largest element.
> -- Fails on an empty list (via 'maximum').
> maxplace :: (Ord a) => [a] -> Int
> maxplace xs = length before
>   where
>     best   = maximum xs
>     before = takeWhile (/= best) xs
56 > sqr :: (Num a) => a -> a
> -- | Logistic sigmoid, @1 / (1 + e^(-x))@.
> sig :: (Floating a) => a -> a
> sig = (1.0 /) . (1.0 +) . exp . negate
62 > sublist, ($$) :: (Num a) => [a] -> [a] -> [a]
63 > sublist = zipWith (-)
> -- | Matrix-vector product: every row of the weight matrix is combined with
> -- the vector using '$$' and summed, giving one entry per row.
> weivecmult :: Weight -> Layer -> Layer
> weivecmult w v = map (sum . ($$ v)) w
70 > classeg :: Weights -> Layer -> Layers
73 > = let l' = if null ws then weivecmult w templ
74 > else map sig (weivecmult w templ)
75 > templ = if null ws then l
77 > in templ : (classeg ws l')
> -- | Network error over a data set: the square root of the squared error
> -- accumulated by 'calcerror1'.
> calcerror :: Weights -> Egs -> Double
> calcerror ws = sqrt . calcerror1 ws
84 > calcerror1 :: Weights -> Egs -> Double
85 > calcerror1 _ [] = 0.0
86 > calcerror1 ws ((x,t):egs)
87 > = (sum.(map sqr).(sublist t).last) (classeg ws x)
91 -------------------------------------------------------------------------------
92 | Network Training Functions |
93 -------------------------------------------------------------------------------
selectegs produces an infinite list of random example indices, giving the
order in which examples are presented during training. Its argument is the
number of examples.
> -- | Infinite stream of example indices in the range [0, n-1], used to pick
> -- the next training example.  'randomInts' draws from the whole 'Int'
> -- range, so 'mod' is used rather than 'rem': 'rem' keeps the sign of a
> -- negative draw, producing negative indices on which @egs !! r@ fails.
> -- n must be positive; with n == 0 forcing an element divides by zero.
> selectegs :: Int -> [Int]
> selectegs n = map (`mod` n) (randomInts n n)
102 trainweights calls trainepoch to iteratively train the network. It
103 also checks the error at the end of each call to see if it has fallen to
106 > trainweights :: Egs -> Weights -> Int -> Double -> Double
107 > -> [Int] -> (Weights, [Double])
108 > trainweights _ ws 0 _ _ _ = (ws, [])
109 > --should be:trainweights egs ws (eps+1) err eta rs
110 > trainweights egs ws eps err eta rs
111 > | eps < 0 = error "BpGen.trainweights"
113 > = let (ws',rs') = trainepoch egs ws (length egs) eta rs
114 > newerr = calcerror ws' egs
115 > (ws'', errs) = trainweights egs ws' (eps-1) err eta rs'
116 > in if newerr < err then (ws', [newerr])
117 > else (ws'', newerr:errs)
120 trainepoch iteratively calls classeg and backprop to train the network,
121 as well as selecting an example.
123 > trainepoch :: Egs -> Weights -> Int -> Double -> [Int] -> (Weights, [Int])
124 > trainepoch _ ws 0 _ rs = (ws,rs)
125 > --should be: trainepoch egs ws (egno+1) eta (r:rs)
126 > trainepoch egs ws egno eta (r:rs)
127 > | egno < 0 = error "BpGen.trainepoch"
129 > = let (x,t) = egs !! r
130 > ws' = backprop eta (classeg ws x) ws t
131 > in trainepoch egs ws' (egno-1) eta rs
backprop updates the weights after calculating the required changes.
> -- | One backpropagation update for a single training example.
> -- Arguments: learning rate, the layer activations (as produced by
> -- 'classeg'), the current weights, and the target output vector.
> -- Deltas are computed by 'calcchange' from everything after the first
> -- activation vector and the first weight matrix; 'changeweights' then
> -- applies them using all activations and weights.  Both lists must be
> -- non-empty (there is no equation for the empty case).
> backprop :: Double -> Layers -> Weights -> Layer -> Weights
> backprop eta outs@(_ : laterOuts) wts@(_ : laterWts) target =
>   changeweights eta outs deltas wts
>   where
>     deltas = calcchange laterOuts laterWts target
calcchange calculates the delta (error) vectors from which the weight changes are computed.
> -- | Delta vectors, one per layer in the first argument.
> -- For the last layer the delta is target minus output (no sigmoid is
> -- applied there).  Each earlier delta propagates the following layer's
> -- delta back through the transposed weight matrix and combines it with
> -- the layer's output via 'sigop' (the sigmoid-derivative step).
> calcchange :: Layers -> Weights -> Layer -> Layers
> calcchange [out] [] target = [sublist target out]
> calcchange (out : outs) (w : ws) target = sigop out backErr : later
>   where
>     -- every successful equation returns a non-empty list, so this lazy
>     -- pattern plays the role of the original `head`
>     later@(next : _) = calcchange outs ws target
>     backErr          = weivecmult (transpose w) next
150 sigop performs the calculations involving the derivative of the sigmoid.
151 This uses a constant to eliminate flat spots [Fahlman, 1988]
153 > sigop :: Layer -> Layer -> Layer
155 > = let sig' x = x * (1.0 - x) + 0.1
156 > in (map sig' out) $$ change
159 changeweights makes the actual changes to weights
> -- | Apply the weight updates: each weight w_ji becomes
> -- @w_ji + eta * d_j * o_i@, where d_j is the delta of receiving unit j and
> -- o_i the activation of sending unit i.  Matrices, deltas and activations
> -- are paired positionally and truncated to the shortest (zip semantics).
> changeweights :: Double -> Layers -> Layers -> Weights -> Weights
> changeweights eta os ds ws = zipWith3 updateMatrix ws ds os
>   where
>     updateMatrix w d o = zipWith (updateRow o) d w
>     updateRow o dj wj  = zipWith (\oi wji -> wji + eta * dj * oi) o wj
168 -------------------------------------------------------------------------------
169 | Weight Manipulation |
170 -------------------------------------------------------------------------------
172 randweights generates random weights in the range -1.0 to +1.0
> -- | Random initial weights for a network with the given layer sizes.
> -- Values come from 'randomDoubles' with the fixed seed 1 (so every call
> -- yields the same network) and are rescaled with @2x - 1@.
> randweights :: Dimensions -> Weights
> randweights dimensions = genweights dimensions (map rescale uniforms)
>   where
>     uniforms  = randomDoubles 1 1
>     rescale x = 2.0 * x - 1.0
179 Generates weights, taking the values from the list of Doubles.
180 The weight sizes are taken from the list of dimensions.
> -- | Cut a flat stream of values into the weight matrices for a topology.
> -- For consecutive layer sizes x and y the matrix has y rows.  Rows have
> -- x+1 entries, except in the matrix feeding the last layer, whose rows
> -- have x entries.  NOTE(review): the extra column is presumably the bias
> -- weight for a constant unit added by 'classeg' on hidden layers — confirm.
> -- Topologies with fewer than two layers have no connections and yield []
> -- (previously an empty topology was a pattern-match failure).
> genweights :: Dimensions -> [Double] -> Weights
> genweights (inner : outer : rest) rs = w : genweights (outer : rest) rs'
>   where
>     rowWidth
>       | null rest = inner
>       | otherwise = inner + 1
>     (w, rs') = multSplitAt rowWidth outer rs
> genweights _ _ = []
190 > multSplitAt :: Int -> Int -> [a] -> ([[a]],[a])
191 > multSplitAt inner 0 xs = ([], xs)
192 > --should be:multSplitAt inner (outer + 1) xs
193 > multSplitAt inner outer xs
194 > | outer < 0 = error "BpGen.multSplitAt"
196 > = let (l, xs') = splitAt inner xs
197 > (ls, xs'') = multSplitAt inner (outer-1) xs'