From 40e469ff0b451103465fbcdd6897d5359aab865f Mon Sep 17 00:00:00 2001 From: Maksuel Boni Date: Wed, 28 Mar 2018 11:25:58 +0200 Subject: [PATCH] handle errors and prepare for the implementation of hidden multi-layers Added "Rest parameters" for the future implementation of hidden multi-layers, where you will pass arguments like this: new NeuralNetwork(3,6,8,6,4) new NeuralNetwork(input, hidden_1, hidden_2, hidden_3..., output) Added "handle errors" in constructor method. --- lib/nn.js | 50 ++++++++++++++++++++++++++++++++------------------ 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/lib/nn.js b/lib/nn.js index 0260280..f864d7b 100644 --- a/lib/nn.js +++ b/lib/nn.js @@ -19,22 +19,36 @@ let tanh = new ActivationFunction( class NeuralNetwork { - // TODO: document what a, b, c are - constructor(a, b, c) { - if (a instanceof NeuralNetwork) { - this.input_nodes = a.input_nodes; - this.hidden_nodes = a.hidden_nodes; - this.output_nodes = a.output_nodes; - - this.weights_ih = a.weights_ih.copy(); - this.weights_ho = a.weights_ho.copy(); - - this.bias_h = a.bias_h.copy(); - this.bias_o = a.bias_o.copy(); - } else { - this.input_nodes = a; - this.hidden_nodes = b; - this.output_nodes = c; + /** + * Construct method. + * + * The user can enter 3 parameters (integer) to create a new NeuralNetwork, + * where: the first parameter represents the number of inputs, the second + * parameter represents the number of hidden nodes and the third parameter + * represents the number of outputs of the network. + * The user can copy an instance of NeuralNetwork by passing the same as an + * argument to the constructor. + * + * @param {NeuralNetwork|integer} args (Rest parameters) + */ + constructor(...args) { + if (args.length === 1 && args[0] instanceof NeuralNetwork) { + this.input_nodes = args[0].input_nodes; + this.hidden_nodes = args[0].hidden_nodes; + this.output_nodes = args[0].output_nodes; + + this.weights_ih = args[0].weights_ih.copy(); + this.weights_ho = args[0].weights_ho.copy(); + + this.bias_h = args[0].bias_h.copy(); + this.bias_o = args[0].bias_o.copy(); + } else if(args.length === 3) { + if(!args.every( v => Number.isInteger(v) )) { + throw new Error('All arguments must be integer.'); + } + this.input_nodes = args[0]; + this.hidden_nodes = args[1]; + this.output_nodes = args[2]; this.weights_ih = new Matrix(this.hidden_nodes, this.input_nodes); this.weights_ho = new Matrix(this.output_nodes, this.hidden_nodes); @@ -45,13 +59,13 @@ class NeuralNetwork { this.bias_o = new Matrix(this.output_nodes, 1); this.bias_h.randomize(); this.bias_o.randomize(); + } else { + throw new Error('Invalid arguments. Read the documentation!'); } // TODO: copy these as well this.setLearningRate(); this.setActivationFunction(); - - } predict(input_array) {