diff --git a/lib/nn.js b/lib/nn.js index 5cce746..7c20490 100644 --- a/lib/nn.js +++ b/lib/nn.js @@ -1,7 +1,7 @@ // Other techniques for learning -class ActivationFunction{ - constructor(func, dfunc){ +class ActivationFunction { + constructor(func, dfunc) { this.func = func; this.dfunc = dfunc; } @@ -9,12 +9,12 @@ class ActivationFunction{ let sigmoid = new ActivationFunction( x => 1 / (1 + Math.exp(-x)), - y => y * (1- y) + y => y * (1 - y) ); let tanh = new ActivationFunction( x => Math.tanh(x), - y => 1-(y*y) + y => 1 - (y * y) ); @@ -33,10 +33,15 @@ class NeuralNetwork { this.bias_o = new Matrix(this.output_nodes, 1); this.bias_h.randomize(); this.bias_o.randomize(); - this.setLearningRate(); - this.setActivationFunction(); + this.learningRateDecay = false; + this.decayRatio = 0; + this.numberOfTrainings = 0; + this.countOfTrainings = 0; + + this.setLearningRate(); + this.setActivationFunction(); } predict(input_array) { @@ -57,12 +62,35 @@ class NeuralNetwork { return output.toArray(); } + setActivationFunction(func = sigmoid) { + this.activation_function = func; + } + setLearningRate(learning_rate = 0.1) { this.learning_rate = learning_rate; } - setActivationFunction(func = sigmoid) { - this.activation_function = func; + enableLearningRateDecay(number_of_trainings = 1000, decay_ratio = 0.01) { + this.learningRateDecay = true; + this.numberOfTrainings = number_of_trainings; + this.decayRatio = decay_ratio; + } + + disableLearningRateDecay() { + this.learningRateDecay = false; + } + + checkLearningRateDecay() { + if (this.countOfTrainings > 0 && this.numberOfTrainings > 0 && this.learningRateDecay && this.learning_rate > 0) { + if (this.countOfTrainings % this.numberOfTrainings === 0) { + let newLearningRate = this.learning_rate * (1 - this.decayRatio); + if (newLearningRate <= 0) { + disableLearningRateDecay(); + return; + } + this.setLearningRate(newLearningRate); + } + } } train(input_array, target_array) { @@ -118,6 +146,9 @@ class NeuralNetwork { // Adjust the bias by its deltas (which is just the gradients) this.bias_h.add(hidden_gradient); + this.countOfTrainings++; + this.checkLearningRateDecay(); + // outputs.print(); // targets.print(); // error.print(); @@ -128,8 +159,7 @@ class NeuralNetwork { } static deserialize(data) { - if(typeof data == 'string') - { + if (typeof data == 'string') { data = JSON.parse(data); } let nn = new NeuralNetwork(data.input_nodes, data.hidden_nodes, data.output_nodes);