-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMain.java
More file actions
112 lines (96 loc) · 4.48 KB
/
Main.java
File metadata and controls
112 lines (96 loc) · 4.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import java.io.IOException;
import java.util.Arrays;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
/**
 * Entry point and training loop for the network trained on the Pokemon "CEL" card images
 * under {@code ./Data/Pokemon}.
 *
 * <p>Loads training and test data through {@code DataLoader}, restores the network from
 * {@link #path} (or creates a fresh randomly-initialized one), then trains it in mini-batches
 * on a fixed-size thread pool, saving and evaluating it after every epoch.
 */
public class Main {
    /** File the network is serialized to after each epoch and restored from at startup. */
    public static String path = "./Poke_CEL_v7.ser";

    /**
     * {@code learnRate}: base learning rate; each batch applies {@code learnRate / batch.length}.
     * {@code dataSetCoverage}: forwarded to {@code DataLoader.loadTrainDataPoints}
     * (presumably the fraction of the training set to load — confirm against DataLoader).
     */
    public static float learnRate = 0.5f, dataSetCoverage = 1f;

    /** Training schedule: number of epochs, mini-batch size, and worker threads per epoch. */
    public static int epochNumbers = 1000, batchSize = 50, threadNumber = 4;

    /**
     * Loads the data sets, obtains the network and runs the full training schedule.
     *
     * @param args ignored
     * @throws IOException declared for the loading helpers (which of them throws it is not
     *     visible from this file)
     * @throws InterruptedException if interrupted while waiting for an epoch's batches
     */
    public static void main(String[] args) throws IOException, InterruptedException {
        // NOTE(review): presumably precomputes activation values over [-10, 10] using 10000
        // samples — confirm against Activation.
        Activation.computeValues(-10, 10, 10000);
        DataLoader.loadTrainDataPoints("./Data/Pokemon/AugmentedCardImages/CEL", "./Data/Pokemon/CardNames/CEL.txt", dataSetCoverage);
        DataLoader.loadTestDataPoints("./Data/Pokemon/ResizedCardImages/CEL", "./Data/Pokemon/CardNames/CEL.txt");
        NeuralNetwork network = getNetwork(path);
        loopEpochs(network, epochNumbers, batchSize, "./Data/Pokemon/ResizedCardImages/CEL", "./Data/Pokemon/CardNames/CEL.txt");
    }

    /**
     * Trains {@code network} on one mini-batch and then applies the resulting gradients.
     *
     * <p>NOTE(review): {@link #epoch} runs several of these tasks concurrently against the SAME
     * {@code NeuralNetwork} instance. Unless {@code learn}/{@code applyGradients} synchronize
     * internally (not visible from this file), gradient accumulation and weight updates from
     * different batches can interleave. Confirm that NeuralNetwork is thread-safe.
     */
    public static class BatchTask implements Runnable {
        NeuralNetwork network;
        DataPoint[] batch;

        /**
         * @param network network to train (shared with the other tasks of the same epoch)
         * @param batch   samples making up this mini-batch
         */
        public BatchTask(NeuralNetwork network, DataPoint[] batch) {
            this.network = network;
            this.batch = batch;
        }

        @Override
        public void run() {
            network.learn(batch);
            // Scale the step by the batch size. Presumably learn() accumulates per-sample
            // gradients, so this applies their average — confirm against NeuralNetwork.
            network.applyGradients(learnRate / batch.length);
        }
    }

    /**
     * Runs {@code epochNumbers} training epochs. After each epoch it advances the console
     * progress bar, saves the network to {@link #path} and prints test-set metrics.
     *
     * @param network      network to train
     * @param epochNumbers number of epochs to run
     * @param batchSize    mini-batch size passed to {@link #epoch}
     * @param testFolder   forwarded to {@link #test}, which does not use it
     * @param testLabels   forwarded to {@link #test}, which does not use it
     * @throws InterruptedException if interrupted while waiting for an epoch to finish
     */
    public static void loopEpochs(NeuralNetwork network, int epochNumbers, int batchSize, String testFolder, String testLabels) throws InterruptedException {
        Console.startProgressBar(3);
        for (int i = 1; i <= epochNumbers; i++) {
            epoch(network, batchSize);
            // Compute in float so i * 100 cannot overflow int on very long runs.
            Console.progress(i * 100f / epochNumbers, 3);
            ReadWrite.serializeObject(network, path);
            test(network, testFolder, testLabels);
        }
    }

    /**
     * Evaluates {@code network} on {@code DataLoader.testData} and prints the mean cost and
     * the accuracy. A sample counts as correct when {@code Result.predicted} of the network
     * output equals that of a {@code Result} built from the expected outputs.
     *
     * <p>If no test data is loaded, it prints a notice and returns instead of printing NaN
     * metrics.
     *
     * @param network    network to evaluate
     * @param testFolder unused; samples come from {@code DataLoader.testData}
     * @param testLabels unused; samples come from {@code DataLoader.testData}
     */
    public static void test(NeuralNetwork network, String testFolder, String testLabels) {
        DataPoint[] testData = DataLoader.testData;
        int total = testData.length;
        if (total == 0) {
            System.out.println("No test data loaded; skipping evaluation");
            return;
        }
        int correct = 0;
        float costSum = 0;
        for (DataPoint dataPoint : testData) {
            Result res = network.result(dataPoint.inputs);
            costSum += res.getCost(dataPoint.expectedOutputs);
            // Wrap the expected outputs in a Result to derive the expected prediction.
            Result expected = new Result(dataPoint.expectedOutputs);
            if (expected.predicted == res.predicted) {
                correct++;
            }
        }
        System.out.println("Cost : " + (costSum / total));
        System.out.println("Accuracy : " + correct + "/" + total + " -> " + correct / (float) total);
    }

    /**
     * Restores the network serialized at {@code path}. If no such file exists, it builds a
     * new randomly-initialized network.
     *
     * <p>NOTE(review): the {@code .ser} file is presumably read with Java native
     * deserialization. Only point this at files this program wrote itself.
     *
     * @param path file to restore from
     * @return the restored or newly created network
     */
    public static NeuralNetwork getNetwork(String path) {
        if (ReadWrite.exists(path)) {
            return (NeuralNetwork) ReadWrite.deserializeObject(path);
        }
        // NOTE(review): constructor arguments are presumably layer sizes — confirm.
        NeuralNetwork network = new NeuralNetwork(1104, 786, 400, 50);
        network.randomInit(-1, 1);
        return network;
    }

    /**
     * Runs one epoch: splits the training data into batches of {@code batchSize} and trains
     * on them using {@link #threadNumber} worker threads. Returns once every batch is done.
     *
     * @param network   network to train
     * @param batchSize mini-batch size
     * @throws InterruptedException if interrupted while waiting for the batches to finish
     */
    public static void epoch(NeuralNetwork network, int batchSize) throws InterruptedException {
        DataLoader.setupBatches(batchSize);
        ExecutorService executorService = Executors.newFixedThreadPool(threadNumber);
        try {
            while (DataLoader.hasNextBatch()) {
                // execute() rather than submit(): an exception thrown by a task reaches the
                // worker thread's uncaught-exception handler instead of being silently
                // captured in a Future that nobody reads.
                executorService.execute(new BatchTask(network, DataLoader.getBatch()));
            }
        } finally {
            // Always stop accepting work so the pool's non-daemon threads cannot outlive
            // the epoch, even if batch setup throws mid-loop.
            executorService.shutdown();
        }
        executorService.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
    }

    /**
     * Formats {@code outputs} as {@code [v0,v1,...]} using {@code Float.toString} for each
     * value. An empty array yields {@code []}.
     *
     * @param outputs values to format
     * @return bracketed, comma-separated representation
     */
    public static String format(float[] outputs) {
        StringBuilder sb = new StringBuilder("[");
        for (int i = 0; i < outputs.length; i++) {
            if (i > 0) {
                sb.append(',');
            }
            sb.append(outputs[i]);
        }
        return sb.append(']').toString();
    }

    /**
     * Computes the sample standard deviation (denominator {@code n - 1}) of {@code arr}.
     *
     * @param arr values to summarize
     * @return the sample standard deviation, or 0 when {@code arr} has fewer than two elements
     *     (where the sample statistic is undefined)
     */
    public static double calculateStandardDeviation(int[] arr) {
        if (arr.length <= 1) {
            return 0;
        }
        double mean = Arrays.stream(arr).average().orElse(0);
        double sumOfSquaredDeviations = Arrays.stream(arr)
                .mapToDouble(x -> {
                    double d = x - mean;
                    return d * d;
                })
                .sum();
        return Math.sqrt(sumOfSquaredDeviations / (arr.length - 1));
    }
}