using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace BCMToolbox
{
    /// <summary>
    /// Local message tagged with <see cref="WEIGHT_UPDATE"/>.
    /// </summary>
    [Serializable]
    public class WeightUpdateMessage : LocalMessage
    {
        /// <summary>
        /// Two-element index array, initialised to { 0, 0 }.
        /// NOTE(review): presumably the (row, column) of the updated weight — TODO confirm.
        /// </summary>
        public int[] Index = new int[] { 0, 0 };

        /// <summary>Message-type tag passed to the <see cref="LocalMessage"/> base constructor.</summary>
        public const string WEIGHT_UPDATE = "W Update";

        /// <summary>Creates a message whose type is <see cref="WEIGHT_UPDATE"/>.</summary>
        public WeightUpdateMessage() : base(WEIGHT_UPDATE) { }
    }

    /// <summary>
    /// Local message tagged with <see cref="OUTPUT_UPDATE"/>.
    /// </summary>
    [Serializable]
    public class OutputUpdateMessage : LocalMessage
    {
        /// <summary>
        /// Single index, initialised to 0.
        /// NOTE(review): presumably the neuron whose output changed — TODO confirm.
        /// </summary>
        public int Index = 0;

        /// <summary>Message-type tag passed to the <see cref="LocalMessage"/> base constructor.</summary>
        public const string OUTPUT_UPDATE = "OUT Update";

        /// <summary>Creates a message whose type is <see cref="OUTPUT_UPDATE"/>.</summary>
        public OutputUpdateMessage() : base(OUTPUT_UPDATE) { }
    }

    /// <summary>
    /// Base class for a network layer that performs a synchronous Hebbian-style step on
    /// every <see cref="Propagate"/> call. First, all outputs are computed from the current
    /// potentials. Next, every off-diagonal, non-NaN weight is rewritten through
    /// <see cref="WeightModificationKernel"/>. Finally, the outputs scaled by
    /// <see cref="OutputDecay"/> become the next potentials.
    /// </summary>
    public abstract class SynchronousHebbian3Network : FixedTraceableNetwork
    {
        // Scratch buffer (one entry per neuron) holding the activations computed in Propagate.
        double[] __intermediaryOutputs = null;

        // Neuron indices supplied to the constructor and range-checked there;
        // not otherwise referenced in this class.
        int[] __liveOutputs = null;

        // Outbound pipes created by RegisterOtherNetwork.
        List<LocalPipe> __outputPipes = new List<LocalPipe>();

        // Activation applied to each neuron's weighted input sum in Propagate.
        protected Func<double, double> __activation = MathToolbox.Sigmoid;

        // Inverse of the activation; not referenced in this class (presumably for subclasses).
        protected Func<double, double> __activation_inverse = MathToolbox.InverseSigmoid;

        /// <summary>
        /// Creates a layer of <paramref name="neuronCount"/> neurons.
        /// </summary>
        /// <param name="neuronCount">Number of neurons; forwarded to the base class.</param>
        /// <param name="newID">Identifier forwarded to the base class.</param>
        /// <param name="liveOutputs">Neuron indices; each must lie in [0, neuronCount).</param>
        /// <exception cref="ArgumentNullException"><paramref name="liveOutputs"/> is null.</exception>
        /// <exception cref="IndexOutOfRangeException">An entry of <paramref name="liveOutputs"/> is out of range.</exception>
        internal SynchronousHebbian3Network(int neuronCount, int newID, int[] liveOutputs)
            : base(neuronCount, newID)
        {
            __intermediaryOutputs = new double[neuronCount];

            if (liveOutputs == null)
            {
                throw new ArgumentNullException("liveOutputs");
            }

            // Both ends of the range are checked; the original only rejected indices that
            // were too large, so negative indices slipped through.
            foreach (int lo in liveOutputs)
            {
                if (lo < 0 || lo >= neuronCount)
                {
                    throw new IndexOutOfRangeException(String.Format(
                        "Live output index {0} is outside the range [0, {1}).", lo, neuronCount));
                }
            }

            __liveOutputs = liveOutputs;
        }

        /// <summary>
        /// Subscribes this network's message handler to <paramref name="inbound"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="inbound"/> is null.</exception>
        internal void LinkPipe(LocalPipe inbound)
        {
            if (inbound == null)
            {
                throw new ArgumentNullException("inbound");
            }

            inbound.MessageReceive += new EventHandler<LocalMessageEventArgs>(inbound_MessageReceive);
        }

        /// <summary>
        /// Creates a new outbound pipe, records it in this network's pipe list and returns it.
        /// Returns null when <paramref name="newNetwork"/> is this network itself.
        /// </summary>
        /// <remarks>
        /// NOTE(review): <paramref name="newNetwork"/> is only compared against <c>this</c>.
        /// Connecting the returned pipe to it (e.g. via its LinkPipe) is presumably the
        /// caller's job — TODO confirm.
        /// </remarks>
        internal LocalPipe RegisterOtherNetwork(SynchronousHebbian3Network newNetwork)
        {
            if (newNetwork == this)
            {
                return null;
            }

            LocalPipe pipe = new LocalPipe();
            __outputPipes.Add(pipe);
            return pipe;
        }

        // Handler for messages arriving on a linked inbound pipe. Anything that is not a
        // weight-update message is ignored.
        // NOTE(review): weight-update handling is not implemented; the method currently
        // takes no action for any message.
        void inbound_MessageReceive(object sender, LocalMessageEventArgs e)
        {
            if (e.Message.MessageType != WeightUpdateMessage.WEIGHT_UPDATE)
            {
                return;
            }

            // ref message
        }

        /// <summary>Human-readable name of the form "HebbM Layer #&lt;UID&gt;".</summary>
        public override string Identification
        {
            get { return String.Format("HebbM Layer #{0}", this.UID); }
        }

        /// <summary>Inertia coefficient (0.02 by default); not referenced in this class.</summary>
        public virtual double Inertia
        {
            get { return 0.02; }
        }

        /// <summary>
        /// Factor (0.85 by default) applied to each output when it becomes the next potential
        /// at the end of <see cref="Propagate"/>.
        /// </summary>
        public virtual double OutputDecay
        {
            get { return 0.85; }
        }

        /// <summary>
        /// Returns the new value of a single weight.
        /// </summary>
        /// <param name="weight">Current weight value.</param>
        /// <param name="input">Potential of the source neuron (row index of the weight).</param>
        /// <param name="output">Freshly computed output of the target neuron (column index).</param>
        /// <returns>The updated weight. Must be neither NaN nor infinite (asserted in debug builds).</returns>
        protected abstract double WeightModificationKernel(double weight, double input, double output);

        /// <summary>When true, <c>NormalizeWeights(2.0)</c> runs after each weight update and after initialisation.</summary>
        public virtual bool HasNormalizedWeights
        {
            get { return false; }
        }

        /// <summary>
        /// When true, <c>NormalizePotentials(2.0)</c> runs after inputs are applied, on the
        /// freshly computed outputs in <see cref="Propagate"/>, and after initialisation.
        /// </summary>
        public virtual bool HasNormalizedInputs
        {
            get { return false; }
        }

        /// <summary>
        /// Adds <paramref name="inputs"/> element-wise to the current potentials, optionally
        /// normalises them, then forwards to the base implementation.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="inputs"/> is null.</exception>
        /// <exception cref="ArgumentException">The length of <paramref name="inputs"/> differs from Count.</exception>
        public override void ApplyInputs(double[] inputs)
        {
            if (inputs == null)
            {
                throw new ArgumentNullException("inputs");
            }

            // The second ArgumentException argument is the parameter name, not part of the
            // message. The expected size therefore belongs in the message text.
            if (inputs.Length != this.Count)
            {
                throw new ArgumentException("Invalid input size. Expected: " + this.Count.ToString(), "inputs");
            }

            lock (SyncRoot)
            {
                for (int i = 0; i < this.Count; i++)
                {
                    __potentials[i] += inputs[i];
                }

                // Normalise under the same lock, so a concurrent Propagate never observes
                // potentials that are only partly normalised.
                if (HasNormalizedInputs)
                {
                    NormalizePotentials(2.0);
                }
            }

            base.ApplyInputs(inputs);
        }

        /// <summary>
        /// Performs one synchronous step: computes all outputs, optionally normalises them,
        /// updates the weights, and stores the decayed outputs as the new potentials.
        /// The whole step runs under <c>SyncRoot</c>. It reads and temporarily swaps
        /// <c>__potentials</c>, which <see cref="ApplyInputs"/> writes under the same lock.
        /// </summary>
        public override void Propagate()
        {
            lock (SyncRoot)
            {
                // 1. Neuron i receives potential j through weight [j, i]. Self-connections
                //    and NaN weights (presumably absent connections) are skipped.
                for (int i = 0; i < this.Count; i++)
                {
                    double sum = 0.0;
                    for (int j = 0; j < this.Count; j++)
                    {
                        if ((i != j) && !double.IsNaN(__weights[j, i]))
                        {
                            sum += __weights[j, i] * __potentials[j];
                        }
                    }
                    __intermediaryOutputs[i] = __activation(sum);
                }

                // 2. NormalizePotentials works on the __potentials field, so the output buffer
                //    is swapped in temporarily. The finally block restores the real potentials
                //    even if normalisation throws.
                //    NOTE(review): this assumes NormalizePotentials normalises in place rather
                //    than reassigning __potentials — confirm against the base class.
                if (HasNormalizedInputs)
                {
                    double[] savedPotentials = __potentials;
                    __potentials = __intermediaryOutputs;
                    try
                    {
                        NormalizePotentials(2.0);
                    }
                    finally
                    {
                        __potentials = savedPotentials;
                    }
                }

                // 3. Update each off-diagonal, non-NaN weight [i, j] from input potential i
                //    and output j.
                for (int i = 0; i < this.Count; i++)
                {
                    for (int j = 0; j < this.Count; j++)
                    {
                        if ((i == j) || double.IsNaN(__weights[i, j]))
                        {
                            continue;
                        }

                        __weights[i, j] = WeightModificationKernel(__weights[i, j], __potentials[i], __intermediaryOutputs[j]);
                        System.Diagnostics.Debug.Assert(!double.IsNaN(__weights[i, j]));
                        System.Diagnostics.Debug.Assert(!double.IsInfinity(__weights[i, j]));
                    }
                }

                if (HasNormalizedWeights)
                {
                    NormalizeWeights(2.0);
                }

                // 4. The decayed outputs become the potentials for the next step.
                for (int i = 0; i < this.Count; i++)
                {
                    __potentials[i] = __intermediaryOutputs[i] * OutputDecay;
                }
            }

            base.Propagate();
        }

        /// <summary>
        /// Runs the base initialisation, then applies the optional potential and weight
        /// normalisation configured by <see cref="HasNormalizedInputs"/> and
        /// <see cref="HasNormalizedWeights"/>.
        /// </summary>
        // NOTE(review): `Func` appears here without type arguments (System.Func is only
        // generic). They must match the signature of FixedTraceableNetwork.Initialize —
        // TODO restore them from the base class.
        public override void Initialize(Func initializerWeights, Func initializerPotentials)
        {
            base.Initialize(initializerWeights, initializerPotentials);

            if (HasNormalizedInputs)
            {
                NormalizePotentials(2.0);
            }

            if (HasNormalizedWeights)
            {
                NormalizeWeights(2.0);
            }
        }
    }
}