Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 32 additions & 18 deletions lib/nn.js
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,36 @@ let tanh = new ActivationFunction(


class NeuralNetwork {
// TODO: document what a, b, c are
constructor(a, b, c) {
if (a instanceof NeuralNetwork) {
this.input_nodes = a.input_nodes;
this.hidden_nodes = a.hidden_nodes;
this.output_nodes = a.output_nodes;

this.weights_ih = a.weights_ih.copy();
this.weights_ho = a.weights_ho.copy();

this.bias_h = a.bias_h.copy();
this.bias_o = a.bias_o.copy();
} else {
this.input_nodes = a;
this.hidden_nodes = b;
this.output_nodes = c;
/**
* Construct method.
*
* The user can enter 3 parameters (integer) to create a new NeuralNetwork,
* where: the first parameter represents the number of inputs, the second
* parameter represents the number of hidden nodes and the third parameter
* represents the number of outputs of the network.
* The user can copy an instance of NeuralNetwork by passing the same as an
* argument to the constructor.
*
* @param {NeuralNetwork|integer} args (Rest parameters)
*/
constructor(...args) {
if (args.length === 1 && args[0] instanceof NeuralNetwork) {
this.input_nodes = args[0].input_nodes;
this.hidden_nodes = args[0].hidden_nodes;
this.output_nodes = args[0].output_nodes;

this.weights_ih = args[0].weights_ih.copy();
this.weights_ho = args[0].weights_ho.copy();

this.bias_h = args[0].bias_h.copy();
this.bias_o = args[0].bias_o.copy();
} else if(args.length === 3) {
if(!args.every( v => Number.isInteger(v) )) {
throw new Error('All arguments must be integer.');
}
this.input_nodes = args[0];
this.hidden_nodes = args[1];
this.output_nodes = args[2];

this.weights_ih = new Matrix(this.hidden_nodes, this.input_nodes);
this.weights_ho = new Matrix(this.output_nodes, this.hidden_nodes);
Expand All @@ -45,13 +59,13 @@ class NeuralNetwork {
this.bias_o = new Matrix(this.output_nodes, 1);
this.bias_h.randomize();
this.bias_o.randomize();
} else {
throw new Error('Invalid arguments. Read the documentation!');
}

// TODO: copy these as well
this.setLearningRate();
this.setActivationFunction();


}

predict(input_array) {
Expand Down