source: proiecte/ParallelANN/BCMToolbox/MultiHebbian.cs @ 171

Last change on this file since 171 was 171, checked in by (none), 14 years ago
File size: 5.7 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5
6namespace BCMToolbox
7{
8    [Serializable]
9    public class WeightUpdateMessage : LocalMessage
10    {
11        public int[] Index = new int[] { 0, 0 };
12
13        public WeightUpdateMessage()
14            : base(WEIGHT_UPDATE)
15        {
16        }
17
18        public const string WEIGHT_UPDATE = "W Update";
19    }
20
21    [Serializable]
22    public class OutputUpdateMessage : LocalMessage
23    {
24        public int Index = 0;
25        public const string OUTPUT_UPDATE = "OUT Update";
26
27        public OutputUpdateMessage()
28            : base(OUTPUT_UPDATE)
29        {
30        }
31    }
32
33    public abstract class SynchronousHebbian3Network : FixedTraceableNetwork
34    {
35        double[] __intermediaryOutputs = null;
36        int[] __liveOutputs = null;
37        List<LocalPipe> __outputPipes = new List<LocalPipe>();
38
39        protected Func<double, double> __activation = MathToolbox.Sigmoid;
40        protected Func<double, double> __activation_inverse = MathToolbox.InverseSigmoid;
41
42        internal SynchronousHebbian3Network(int neuronCount, int newID, int[] liveOutputs)
43            : base(neuronCount, newID)
44        {
45            __intermediaryOutputs = new double[neuronCount];
46            __liveOutputs = liveOutputs;
47            foreach (int lo in __liveOutputs)
48                if (lo >= neuronCount)
49                    throw new IndexOutOfRangeException();
50        }
51
52        internal void LinkPipe(LocalPipe inbound)
53        {
54            inbound.MessageReceive += new EventHandler<LocalMessageEventArgs>(inbound_MessageReceive);
55        }       
56
57        internal LocalPipe RegisterOtherNetwork(SynchronousHebbian3Network newNetwork)
58        {
59            if (newNetwork == this)
60                return null;
61
62            LocalPipe pipe = new LocalPipe();
63            __outputPipes.Add(pipe);
64            return pipe;
65        }
66
67        void inbound_MessageReceive(object sender, LocalMessageEventArgs e)
68        {
69            if (e.Message.MessageType != WeightUpdateMessage.WEIGHT_UPDATE)
70                return;
71
72            // ref message
73        }
74
75        public override string Identification
76        {
77            get
78            {
79                return String.Format("HebbM Layer #{0}", this.UID);
80            }
81        }
82
83        public virtual double Inertia
84        {
85            get
86            {
87                return 0.02;
88            }
89        }
90
91        public virtual double OutputDecay
92        {
93            get
94            {
95                return 0.85;
96            }
97        }
98
99        protected abstract double WeightModificationKernel(double weight, double input, double output);
100
101
102        public virtual bool HasNormalizedWeights
103        {
104            get
105            {
106                return false;
107            }
108        }
109
110        public virtual bool HasNormalizedInputs
111        {
112            get
113            {
114                return false;
115            }
116        }
117
118
119
120        public override void ApplyInputs(double[] inputs)
121        {
122            lock (SyncRoot)
123            {
124                if (inputs == null)
125                    throw new ArgumentNullException();
126
127                if (inputs.Length != this.Count)
128                    throw new ArgumentException("Invalid input size. Expected: ", this.Count.ToString());
129
130                for (int i = 0; i < this.Count; i++)
131                    __potentials[i] += inputs[i];
132            }
133
134            if (HasNormalizedInputs)
135                NormalizePotentials(2.0);
136
137            base.ApplyInputs(inputs);
138        }
139
140
141        public override void Propagate()
142        {
143            // calculate outputs
144            for (int i = 0; i < this.Count; i++)
145            {
146                __intermediaryOutputs[i] = 0.0;
147                for (int j = 0; j < this.Count; j++)
148                    if ((i != j) && !double.IsNaN(__weights[j, i]))
149                        __intermediaryOutputs[i] += __weights[j, i] * __potentials[j];
150                __intermediaryOutputs[i] = __activation(__intermediaryOutputs[i]);
151            }
152
153            if (HasNormalizedInputs)
154            {
155                double[] temp = __potentials;
156                __potentials = __intermediaryOutputs;
157                NormalizePotentials(2.0);
158                __potentials = temp;
159            }
160
161            // update weights and lock outputs
162            lock (SyncRoot)
163            {
164                for (int i = 0; i < this.Count; i++)
165                    for (int j = 0; j < this.Count; j++)
166                    {
167                        if ((i == j) || double.IsNaN(__weights[i, j]))
168                            continue;
169
170                        __weights[i, j] = WeightModificationKernel(__weights[i, j], __potentials[i], __intermediaryOutputs[j]);
171
172
173                        System.Diagnostics.Debug.Assert(!double.IsNaN(__weights[i, j]));
174                        System.Diagnostics.Debug.Assert(!double.IsInfinity(__weights[i, j]));
175
176
177
178                    }
179
180                if (HasNormalizedWeights)
181                    NormalizeWeights(2.0);
182
183                for (int i = 0; i < this.Count; i++)
184                    __potentials[i] = __intermediaryOutputs[i] * OutputDecay;
185            }
186
187            base.Propagate();
188        }
189
190        public override void Initialize(Func<int, int, double> initializerWeights, Func<int, double> initializerPotentials)
191        {
192            base.Initialize(initializerWeights, initializerPotentials);
193            if (HasNormalizedInputs)
194                NormalizePotentials(2.0);
195
196            if (HasNormalizedWeights)
197                NormalizeWeights(2.0);
198        }
199    }
200}
Note: See TracBrowser for help on using the repository browser.