1 | using System; |
---|
2 | using System.Collections.Generic; |
---|
3 | using System.Linq; |
---|
4 | using System.Text; |
---|
5 | |
---|
6 | namespace BCMToolbox |
---|
7 | { |
---|
/// <summary>
/// Local message whose message type is <see cref="WEIGHT_UPDATE"/>, carrying a two-element
/// <see cref="Index"/> payload.
/// </summary>
[Serializable]
public class WeightUpdateMessage : LocalMessage
{
    /// <summary>Message-type tag handed to the <see cref="LocalMessage"/> base constructor.</summary>
    public const string WEIGHT_UPDATE = "W Update";

    /// <summary>
    /// Two-element index payload; both entries start at 0.
    /// NOTE(review): presumably the (row, column) coordinates of the updated weight — confirm against senders.
    /// </summary>
    public int[] Index = new int[2];

    /// <summary>Creates a message tagged with <see cref="WEIGHT_UPDATE"/>.</summary>
    public WeightUpdateMessage()
        : base(WEIGHT_UPDATE)
    {
    }
}
---|
20 | |
---|
/// <summary>
/// Local message whose message type is <see cref="OUTPUT_UPDATE"/>, carrying a single
/// integer <see cref="Index"/> payload.
/// </summary>
[Serializable]
public class OutputUpdateMessage : LocalMessage
{
    /// <summary>Message-type tag handed to the <see cref="LocalMessage"/> base constructor.</summary>
    public const string OUTPUT_UPDATE = "OUT Update";

    /// <summary>
    /// Integer payload; defaults to 0.
    /// NOTE(review): presumably the index of the neuron whose output changed — confirm against senders.
    /// </summary>
    public int Index;

    /// <summary>Creates a message tagged with <see cref="OUTPUT_UPDATE"/>.</summary>
    public OutputUpdateMessage() : base(OUTPUT_UPDATE) { }
}
---|
32 | |
---|
/// <summary>
/// Fully connected network layer that, on each <see cref="Propagate"/> step, computes new neuron
/// outputs from the current potentials and then adjusts every non-NaN, off-diagonal weight via the
/// subclass-supplied <see cref="WeightModificationKernel"/>. The input potentials, the computed
/// outputs and the weight update are all handled while holding <c>SyncRoot</c>.
/// </summary>
public abstract class SynchronousHebbian3Network : FixedTraceableNetwork
{
    // Scratch buffer with the activations computed during the current Propagate() step.
    double[] __intermediaryOutputs = null;

    // Neuron indices designated as live outputs; range-checked in the constructor.
    // NOTE(review): not read anywhere in this class — presumably intended for output-pipe logic.
    int[] __liveOutputs = null;

    // Pipes handed out by RegisterOtherNetwork(), one per registered peer network.
    List<LocalPipe> __outputPipes = new List<LocalPipe>();

    /// <summary>Activation applied to each neuron's weighted input sum in <see cref="Propagate"/>.</summary>
    protected Func<double, double> __activation = MathToolbox.Sigmoid;

    /// <summary>Inverse of <see cref="__activation"/>. Not used by this class; available to subclasses.</summary>
    protected Func<double, double> __activation_inverse = MathToolbox.InverseSigmoid;

    /// <summary>
    /// Creates a layer with <paramref name="neuronCount"/> neurons.
    /// </summary>
    /// <param name="neuronCount">Number of neurons; also the size of the output buffer.</param>
    /// <param name="newID">Identifier forwarded to the base network.</param>
    /// <param name="liveOutputs">Neuron indices designated as live outputs; copied defensively.</param>
    /// <exception cref="ArgumentNullException"><paramref name="liveOutputs"/> is null.</exception>
    /// <exception cref="IndexOutOfRangeException">An entry of <paramref name="liveOutputs"/> is outside [0, neuronCount).</exception>
    internal SynchronousHebbian3Network(int neuronCount, int newID, int[] liveOutputs)
        : base(neuronCount, newID)
    {
        if (liveOutputs == null)
            throw new ArgumentNullException("liveOutputs");

        foreach (int lo in liveOutputs)
        {
            // Negative indices were previously accepted silently.
            if (lo < 0 || lo >= neuronCount)
                throw new IndexOutOfRangeException("Live output index " + lo + " is outside [0, " + neuronCount + ").");
        }

        __intermediaryOutputs = new double[neuronCount];
        // Copy so that later mutation of the caller's array cannot bypass the range check above.
        __liveOutputs = (int[])liveOutputs.Clone();
    }

    /// <summary>Subscribes this network to messages arriving on <paramref name="inbound"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="inbound"/> is null.</exception>
    internal void LinkPipe(LocalPipe inbound)
    {
        if (inbound == null)
            throw new ArgumentNullException("inbound");

        inbound.MessageReceive += new EventHandler<LocalMessageEventArgs>(inbound_MessageReceive);
    }

    /// <summary>
    /// Creates a new outbound pipe, records it in this network's output-pipe list and returns it.
    /// Returns null when asked to register this network with itself.
    /// NOTE(review): <paramref name="newNetwork"/> is only compared against <c>this</c>; the pipe is not
    /// linked to it here — presumably the caller passes the pipe to <c>newNetwork.LinkPipe</c>. Confirm.
    /// </summary>
    internal LocalPipe RegisterOtherNetwork(SynchronousHebbian3Network newNetwork)
    {
        if (newNetwork == this)
            return null;

        LocalPipe pipe = new LocalPipe();
        __outputPipes.Add(pipe);
        return pipe;
    }

    // Inbound message handler: ignores everything except weight-update messages.
    // TODO(review): handling of WeightUpdateMessage is not implemented yet (original note: "ref message").
    void inbound_MessageReceive(object sender, LocalMessageEventArgs e)
    {
        if (e.Message.MessageType != WeightUpdateMessage.WEIGHT_UPDATE)
            return;
    }

    /// <summary>Human-readable identifier of the form "HebbM Layer #&lt;UID&gt;".</summary>
    public override string Identification
    {
        get
        {
            return String.Format("HebbM Layer #{0}", this.UID);
        }
    }

    /// <summary>
    /// Inertia constant (0.02 by default).
    /// NOTE(review): not read within this class — presumably used by subclass kernels.
    /// </summary>
    public virtual double Inertia
    {
        get
        {
            return 0.02;
        }
    }

    /// <summary>
    /// Factor applied to each neuron's new output when it is carried over into the potentials
    /// at the end of <see cref="Propagate"/> (0.85 by default).
    /// </summary>
    public virtual double OutputDecay
    {
        get
        {
            return 0.85;
        }
    }

    /// <summary>
    /// Computes the updated value of one weight.
    /// </summary>
    /// <param name="weight">Current weight from neuron i to neuron j.</param>
    /// <param name="input">Potential of the source neuron i before this step.</param>
    /// <param name="output">Newly computed output of the target neuron j.</param>
    /// <returns>The new weight; expected to be finite (checked by Debug.Assert).</returns>
    protected abstract double WeightModificationKernel(double weight, double input, double output);

    /// <summary>
    /// When true, <c>NormalizeWeights(2.0)</c> runs after <see cref="Initialize"/> and after every
    /// weight update in <see cref="Propagate"/>.
    /// </summary>
    public virtual bool HasNormalizedWeights
    {
        get
        {
            return false;
        }
    }

    /// <summary>
    /// When true, <c>NormalizePotentials(2.0)</c> runs after inputs are applied and after
    /// <see cref="Initialize"/>, and the freshly computed outputs are normalized in
    /// <see cref="Propagate"/> before the weight update.
    /// </summary>
    public virtual bool HasNormalizedInputs
    {
        get
        {
            return false;
        }
    }

    /// <summary>
    /// Adds <paramref name="inputs"/> element-wise to the current potentials (optionally normalizing
    /// them), then forwards the call to the base network.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="inputs"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="inputs"/> length differs from <c>Count</c>.</exception>
    public override void ApplyInputs(double[] inputs)
    {
        if (inputs == null)
            throw new ArgumentNullException("inputs");

        lock (SyncRoot)
        {
            if (inputs.Length != this.Count)
            {
                // The second ctor argument is the parameter name, not part of the message.
                throw new ArgumentException(
                    "Invalid input size. Expected: " + this.Count + ", got: " + inputs.Length + ".",
                    "inputs");
            }

            for (int i = 0; i < this.Count; i++)
                __potentials[i] += inputs[i];

            // Normalize under the same lock that guards the potentials being normalized.
            if (HasNormalizedInputs)
                NormalizePotentials(2.0);
        }

        base.ApplyInputs(inputs);
    }

    /// <summary>
    /// Performs one synchronous step while holding <c>SyncRoot</c>:
    /// 1. Compute new outputs from the current potentials.
    /// 2. Optionally normalize those outputs.
    /// 3. Update the weights with <see cref="WeightModificationKernel"/>.
    /// 4. Optionally normalize the weights.
    /// 5. Store the decayed outputs as the next potentials.
    /// Finally, call the base implementation.
    /// </summary>
    public override void Propagate()
    {
        // Previously steps 1–2 ran unlocked and temporarily re-pointed __potentials, so a concurrent
        // ApplyInputs could write into the output buffer; the whole step is now atomic w.r.t. SyncRoot.
        lock (SyncRoot)
        {
            ComputeHebbianOutputs();

            if (HasNormalizedInputs)
                NormalizeIntermediaryOutputs();

            ApplyHebbianWeightUpdate();

            if (HasNormalizedWeights)
                NormalizeWeights(2.0);

            double decay = OutputDecay;
            for (int i = 0; i < this.Count; i++)
                __potentials[i] = __intermediaryOutputs[i] * decay;
        }

        base.Propagate();
    }

    // Forward pass: output[i] = activation(sum over j != i of weights[j, i] * potentials[j]).
    // NaN weights are skipped (presumably marking absent connections).
    private void ComputeHebbianOutputs()
    {
        int n = this.Count;
        for (int i = 0; i < n; i++)
        {
            double sum = 0.0;
            for (int j = 0; j < n; j++)
            {
                if (i != j && !double.IsNaN(__weights[j, i]))
                    sum += __weights[j, i] * __potentials[j];
            }
            __intermediaryOutputs[i] = __activation(sum);
        }
    }

    // Reuses the base-class NormalizePotentials on the output buffer by temporarily aliasing
    // __potentials to it; the original array is restored even if normalization throws.
    // NOTE(review): assumes NormalizePotentials normalizes __potentials in place — if it reassigns
    // the field instead, the normalized result is discarded here. Confirm in FixedTraceableNetwork.
    private void NormalizeIntermediaryOutputs()
    {
        double[] savedPotentials = __potentials;
        __potentials = __intermediaryOutputs;
        try
        {
            NormalizePotentials(2.0);
        }
        finally
        {
            __potentials = savedPotentials;
        }
    }

    // Updates every off-diagonal, non-NaN weight[i, j] from the pre-step potential of i and the
    // new output of j.
    private void ApplyHebbianWeightUpdate()
    {
        int n = this.Count;
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < n; j++)
            {
                if (i == j || double.IsNaN(__weights[i, j]))
                    continue;

                __weights[i, j] = WeightModificationKernel(__weights[i, j], __potentials[i], __intermediaryOutputs[j]);

                System.Diagnostics.Debug.Assert(!double.IsNaN(__weights[i, j]));
                System.Diagnostics.Debug.Assert(!double.IsInfinity(__weights[i, j]));
            }
        }
    }

    /// <summary>
    /// Initializes weights and potentials through the base network, then applies the configured
    /// normalizations (potentials first, then weights). The 2.0 argument presumably selects the
    /// L2 norm — TODO confirm in FixedTraceableNetwork.
    /// </summary>
    public override void Initialize(Func<int, int, double> initializerWeights, Func<int, double> initializerPotentials)
    {
        base.Initialize(initializerWeights, initializerPotentials);

        if (HasNormalizedInputs)
            NormalizePotentials(2.0);

        if (HasNormalizedWeights)
            NormalizeWeights(2.0);
    }
}
---|
200 | } |
---|