-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathNeuralNetwork.java
More file actions
57 lines (47 loc) · 1.75 KB
/
NeuralNetwork.java
File metadata and controls
57 lines (47 loc) · 1.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import java.io.Serializable;
import java.util.Objects;
/**
 * A feed-forward neural network made of a chain of {@link Layer}s.
 *
 * <p>The network is configured by a list of sizes: the first entry is the input size and every
 * following entry is the node count of one layer, so {@code n} sizes produce {@code n - 1}
 * layers. The last layer is treated as the output layer during back-propagation.
 *
 * <p>NOTE(review): this class relies on the default, structure-derived {@code serialVersionUID}.
 * Adding a {@code serialVersionUID}, adding non-private members, or changing field modifiers
 * changes that value and makes previously serialized networks unreadable. Serialization also
 * presumably requires {@link Layer} to be {@link Serializable} — confirm.
 */
public class NeuralNetwork implements Serializable {
    /** Number of input values, i.e. the first size passed to the constructor. */
    public int inputSize;
    /** Layers in forward order; the last element is used as the output layer. */
    public Layer[] layers;

    /**
     * Creates a network from a list of sizes.
     *
     * @param layersNodes the input size followed by the node count of each layer; must contain
     *     at least two entries, all positive
     * @throws NullPointerException if {@code layersNodes} is null
     * @throws IllegalArgumentException if fewer than two sizes are given or any size is not
     *     positive
     */
    public NeuralNetwork(int... layersNodes) {
        Objects.requireNonNull(layersNodes, "layersNodes");
        // With fewer than two sizes there would be no layers, and backPropagate would later
        // index layers[-1]; fail fast with a clear message instead.
        if (layersNodes.length < 2) {
            throw new IllegalArgumentException(
                    "need an input size and at least one layer size, got "
                            + layersNodes.length + " value(s)");
        }
        for (int index = 0; index < layersNodes.length; index++) {
            if (layersNodes[index] <= 0) {
                throw new IllegalArgumentException(
                        "size at index " + index + " must be positive, got " + layersNodes[index]);
            }
        }
        this.inputSize = layersNodes[0];
        this.layers = new Layer[layersNodes.length - 1];
        for (int index = 0; index < this.layers.length; index++) {
            // Built from consecutive sizes; presumably Layer(inputCount, nodeCount) — confirm.
            this.layers[index] = new Layer(layersNodes[index], layersNodes[index + 1]);
        }
    }

    /**
     * Runs the network on {@code inputs} and wraps the final layer's output in a {@link Result}.
     *
     * @param inputs the input values
     * @return a {@link Result} built from the output of {@link #forwardPropagate(float[])}
     */
    public Result result(float[] inputs) {
        return new Result(forwardPropagate(inputs));
    }

    /**
     * Back-propagates every data point of {@code trainingBatch}, in order.
     *
     * <p>This only calls {@link #backPropagate(float[], float[])}; it does not call
     * {@link #applyGradients(float)}, so the caller must apply the gradients afterwards.
     *
     * @param trainingBatch the data points to learn from
     */
    public void learn(DataPoint[] trainingBatch) {
        for (DataPoint dataPoint : trainingBatch) {
            backPropagate(dataPoint.inputs, dataPoint.expectedOutputs);
        }
    }

    /**
     * Runs one forward pass, then walks the layers from output to first, updating each layer's
     * gradients.
     *
     * @param inputs the input values of one sample
     * @param expectedOutputs the target output values for that sample
     */
    public void backPropagate(float[] inputs, float[] expectedOutputs) {
        // Forward pass first; presumably each Layer caches the values its gradient step needs.
        forwardPropagate(inputs);

        Layer outputLayer = layers[layers.length - 1];
        float[] nodeValues = outputLayer.computeLLNodeValues(expectedOutputs);
        outputLayer.updateGradients(nodeValues);

        // Each earlier layer derives its node values from the layer after it and that layer's
        // node values, so iteration must run backwards.
        for (int hlIndex = layers.length - 2; hlIndex >= 0; hlIndex--) {
            Layer hiddenLayer = layers[hlIndex];
            nodeValues = hiddenLayer.computeHLNodeValues(layers[hlIndex + 1], nodeValues);
            hiddenLayer.updateGradients(nodeValues);
        }
    }

    /**
     * Calls {@code applyGradients(learnRate)} on every layer, in forward order.
     *
     * @param learnRate the learning rate passed to each layer
     */
    public void applyGradients(float learnRate) {
        for (Layer layer : layers) {
            layer.applyGradients(learnRate);
        }
    }

    /**
     * Feeds {@code inputs} through every layer in order and returns the last layer's output.
     *
     * <p>NOTE(review): {@code inputs.length} is not checked against {@link #inputSize} here.
     *
     * @param inputs the input values
     * @return the output of the final layer
     */
    public float[] forwardPropagate(float[] inputs) {
        float[] activations = inputs;
        for (Layer layer : layers) {
            activations = layer.computeOutput(activations);
        }
        return activations;
    }

    /**
     * Calls {@code randomInit(a, b)} on every layer.
     *
     * @param a first bound passed to each layer (presumably the lower bound — confirm)
     * @param b second bound passed to each layer (presumably the upper bound — confirm)
     */
    public void randomInit(float a, float b) {
        for (Layer layer : this.layers) {
            layer.randomInit(a, b);
        }
    }
}