Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,6 @@ jacocoTestReport {
xml.enabled = true
csv.enabled = false
}
afterEvaluate {
classDirectories.setFrom(files(classDirectories.files.collect {
fileTree(dir: it, exclude: 'tech/donau/behaiv/proto**')
}))
}
}

task proto(type: Exec) {
Expand Down
97 changes: 0 additions & 97 deletions src/main/java/de/dmi3y/behaiv/kernel/BaseKernel.java

This file was deleted.

90 changes: 72 additions & 18 deletions src/main/java/de/dmi3y/behaiv/kernel/Kernel.java
Original file line number Diff line number Diff line change
@@ -1,41 +1,95 @@
package de.dmi3y.behaiv.kernel;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import de.dmi3y.behaiv.storage.BehaivStorage;
import de.dmi3y.behaiv.tools.Pair;
import tech.donau.behaiv.proto.PredictionSet;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

public interface Kernel {
void setId(String id);
public abstract class Kernel {

boolean isEmpty();
// Identifier handed to the storage to locate this kernel's data file.
protected String id;
// Sample-count threshold: readyToPredict() requires strictly more samples than this.
protected Long treshold = 10L;
// JSON mapper used to serialise and deserialise the collected samples.
protected ObjectMapper objectMapper;
// Value reported by isPartialFitAllowed(); defaults to false.
protected boolean partialFitAllowed = false;
// Value reported by isAlwaysKeepData(); how it influences training is up to the subclass.
protected boolean alwaysKeepData = true;

@Deprecated
void fit(List<Pair<List<Double>, String>> data);
/**
 * Creates a kernel bound to the given identifier and prepares the JSON mapper
 * that {@code save} and {@code restore} rely on.
 *
 * @param id identifier passed to the storage when resolving this kernel's files
 */
public Kernel(String id) {
    this.objectMapper = new ObjectMapper();
    this.id = id;
}

void fit(PredictionSet data);

void fit();
/**
 * Replaces the identifier of this kernel.
 *
 * @param id new identifier; subsequent save/restore calls resolve files with it
 */
public void setId(String id) {
this.id = id;
}

void setTreshold(Long treshold);

boolean readyToPredict();
// Collected samples: each entry pairs a feature vector (List<Double>) with its label (String).
protected List<Pair<List<Double>, String>> data = new ArrayList<>();

void update(List<Pair<List<Double>, String>> data);

boolean isPartialFitAllowed();
/**
 * Tells whether this kernel is empty; the exact criterion is left to each subclass.
 *
 * @return {@code true} when the subclass considers the kernel empty
 */
public abstract boolean isEmpty();

void updateSingle(List<Double> features, String label);
/**
 * Fits the kernel on the supplied samples; the algorithm is defined by the subclass.
 *
 * @param data samples to fit on, each pairing a feature vector with its label
 */
public abstract void fit(List<Pair<List<Double>, String>> data);

String predictOne(List<Double> features);
/**
 * Fits the kernel on the samples collected so far by delegating to
 * {@code fit(List)} with the internal {@code data} list.
 */
public void fit() {
fit(this.data);
}

boolean isAlwaysKeepData();
/**
 * Sets the sample-count threshold checked by {@code readyToPredict()}.
 *
 * @param treshold new threshold; should not be {@code null} because
 *                 {@code readyToPredict()} unboxes it in a comparison
 */
public void setTreshold(Long treshold) {
this.treshold = treshold;
}

void setAlwaysKeepData(boolean alwaysKeepData);
/**
 * Reports whether enough samples have been collected to make predictions.
 *
 * @return {@code true} when the number of collected samples is strictly greater
 *         than the configured threshold
 */
public boolean readyToPredict() {
    final int collectedSamples = data.size();
    return collectedSamples > treshold;
}

void save(BehaivStorage storage) throws IOException;
/**
 * No-op in this base class: the supplied samples are ignored.
 * Subclasses may override this method to react to new data.
 *
 * @param data samples offered to the kernel; unused here
 */
public void update(List<Pair<List<Double>, String>> data) {
}

void restore(BehaivStorage storage) throws IOException;
/**
 * Returns the {@code partialFitAllowed} flag of this kernel.
 *
 * @return current value of {@code partialFitAllowed}
 */
public boolean isPartialFitAllowed() {
return partialFitAllowed;
}

/**
 * Appends one sample to the collected data.
 *
 * @param features feature vector of the sample
 * @param label    label associated with the feature vector
 */
public void updateSingle(List<Double> features, String label) {
    final Pair<List<Double>, String> sample = new Pair<>(features, label);
    data.add(sample);
}

/**
 * Produces a label for a single feature vector; the prediction logic is defined by the subclass.
 *
 * @param features feature vector to classify
 * @return the label chosen by the subclass
 */
public abstract String predictOne(List<Double> features);

/**
 * Returns the {@code alwaysKeepData} flag of this kernel.
 *
 * @return current value of {@code alwaysKeepData}
 */
public boolean isAlwaysKeepData() {
return alwaysKeepData;
}

/**
 * Sets the {@code alwaysKeepData} flag of this kernel.
 *
 * @param alwaysKeepData new value of the flag
 */
public void setAlwaysKeepData(boolean alwaysKeepData) {
this.alwaysKeepData = alwaysKeepData;
}

/**
 * Serialises the collected samples to JSON and writes them to the data file
 * that {@code storage} resolves for this kernel's id.
 *
 * <p>The file is always written as UTF-8 so that non-ASCII labels survive a
 * round trip regardless of the JVM's default charset. {@link #restore} reads
 * with the same charset; the two methods must stay in sync.
 *
 * @param storage storage that resolves the data file for this kernel
 * @throws IOException if the file cannot be opened or written, or serialisation fails
 */
public void save(BehaivStorage storage) throws IOException {
    try (final BufferedWriter writer = new BufferedWriter(
            new OutputStreamWriter(new FileOutputStream(storage.getDataFile(id)), StandardCharsets.UTF_8))) {
        writer.write(objectMapper.writeValueAsString(data));
    }
}

/**
 * Replaces the collected samples with the content of the data file that
 * {@code storage} resolves for this kernel's id.
 *
 * <p>Only the first line of the file is read. A missing or empty first line
 * resets the samples to an empty list. The file is decoded as UTF-8, matching
 * {@link #save}.
 * NOTE(review): files written by an older build on a JVM whose default charset
 * was not UTF-8 may decode non-ASCII labels differently — confirm whether such
 * files exist before rollout.
 *
 * @param storage storage that resolves the data file for this kernel
 * @throws IOException if the file cannot be opened or read, or its content cannot be parsed
 */
public void restore(BehaivStorage storage) throws IOException {
    final TypeReference<List<Pair<List<Double>, String>>> typeReference =
            new TypeReference<List<Pair<List<Double>, String>>>() {
            };
    try (final BufferedReader reader = new BufferedReader(
            new InputStreamReader(new FileInputStream(storage.getDataFile(id)), StandardCharsets.UTF_8))) {
        final String content = reader.readLine();
        if (content == null || content.isEmpty()) {
            data = new ArrayList<>();
        } else {
            data = objectMapper.readValue(content, typeReference);
        }
    }
}
}
87 changes: 21 additions & 66 deletions src/main/java/de/dmi3y/behaiv/kernel/LogisticRegressionKernel.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,23 @@
import com.fasterxml.jackson.core.type.TypeReference;
import de.dmi3y.behaiv.kernel.logistic.LogisticUtils;
import de.dmi3y.behaiv.storage.BehaivStorage;
import de.dmi3y.behaiv.tools.DataMappingUtils;
import de.dmi3y.behaiv.tools.Pair;
import org.apache.commons.lang3.ArrayUtils;
import org.ejml.simple.SimpleMatrix;
import tech.donau.behaiv.proto.PredictionSet;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.*;

import static de.dmi3y.behaiv.tools.DataMappingUtils.toDistinctListOfPairValues;
import static de.dmi3y.behaiv.tools.DataMappingUtils.toInput2dArray;

public class LogisticRegressionKernel extends BaseKernel {
public class LogisticRegressionKernel extends Kernel {

protected List<String> cachedLables = new ArrayList<>();
protected List<String> labels = new ArrayList<>();
private Random rand;
protected SimpleMatrix theta;

Expand All @@ -39,64 +35,24 @@ public LogisticRegressionKernel(String id) {

@Override
public boolean isEmpty() {
return theta == null && data.getPredictionCount() == 0;
return theta == null && data.size() == 0;
}


@Override
public void fit(PredictionSet data) {
public void fit(List<Pair<List<Double>, String>> data) {
this.data = data;
if (data.getDynamicColumns()) {
throw new UnsupportedOperationException("LogisticRegressionKernel doesn't support dynamic fields");
}
for (int i = 0; i < data.getPredictionList().size(); i++) {
cachedLables.add(data.getPredictionList().get(i).getLabel());
}
labels = toDistinctListOfPairValues(data);
if (readyToPredict()) {
//features
double[][] inputs = toInput2dArray(data);

//labels
double[][] labelArray = new double[data.getPredictionCount()][cachedLables.size()];
for (int i = 0; i < data.getPredictionCount(); i++) {
int dummyPos = cachedLables.indexOf(data.getPrediction(i).getLabel());
labelArray[i][dummyPos] = 1.0;
}

//output layer
final SimpleMatrix inputMatrix = new SimpleMatrix(inputs);
final SimpleMatrix outputMatrix = new SimpleMatrix(labelArray);
//3x4?

//TODO dilemma on if we need to re-do theta or keep it as-is, if new features arrising we'll have a problem
if (theta == null || (theta.numCols() != cachedLables.size() && alwaysKeepData)) {
theta = SimpleMatrix.random_DDRM(inputMatrix.numCols(), outputMatrix.numCols(), 0, 1, rand);
} else if (theta.numCols() != cachedLables.size() && !alwaysKeepData) {
throw new UnsupportedOperationException(
"Partial fit of LogisticRegressionKernel is not supported. " +
"Number of labels differs from trained model." +
" Consider setting alwaysKeepData to true or changing Kernel that supports partial fit."
);
}

for (int i = 0; i < 10000; i++) {
theta = LogisticUtils.gradientDescent(inputMatrix, theta, outputMatrix, 0.1);
}

}
}

@Override
public void fit(List<Pair<List<Double>, String>> data) {
this.data = DataMappingUtils.createPredictionSet(data);
this.cachedLables = toDistinctListOfPairValues(data);
if (readyToPredict()) {
//features
double[][] inputs = toInput2dArray(data);

//labels
double[][] labelArray = new double[data.size()][cachedLables.size()];
double[][] labelArray = new double[data.size()][labels.size()];
for (int i = 0; i < data.size(); i++) {
int dummyPos = cachedLables.indexOf(data.get(i).getValue());
int dummyPos = labels.indexOf(data.get(i).getValue());
labelArray[i][dummyPos] = 1.0;
}

Expand All @@ -106,9 +62,9 @@ public void fit(List<Pair<List<Double>, String>> data) {
//3x4?

//TODO dilemma on if we need to re-do theta or keep it as-is, if new features arrising we'll have a problem
if (theta == null || (theta.numCols() != cachedLables.size() && alwaysKeepData)) {
if (theta == null || (theta.numCols() != labels.size() && alwaysKeepData)) {
theta = SimpleMatrix.random_DDRM(inputMatrix.numCols(), outputMatrix.numCols(), 0, 1, rand);
} else if (theta.numCols() != cachedLables.size() && !alwaysKeepData) {
} else if (theta.numCols() != labels.size() && !alwaysKeepData) {
throw new UnsupportedOperationException(
"Partial fit of LogisticRegressionKernel is not supported. " +
"Number of labels differs from trained model." +
Expand Down Expand Up @@ -136,6 +92,8 @@ public void updateSingle(List<Double> features, String label) {

@Override
public String predictOne(List<Double> features) {


final double[] doubles = ArrayUtils.toPrimitive(features.toArray(new Double[0]));

final SimpleMatrix inputs = new SimpleMatrix(new double[][]{doubles});
Expand All @@ -148,21 +106,18 @@ public String predictOne(List<Double> features) {
maxPosition = i;
}
}
return cachedLables.get(maxPosition);
return labels.get(maxPosition);
}

@Override
public void save(BehaivStorage storage) throws IOException {
if (theta == null && (data == null || data.getPredictionList().isEmpty())) {
if (theta == null && (data == null || data.isEmpty())) {
throw new IOException("Not enough data to save, network data is empty");
}
if (cachedLables == null || cachedLables.isEmpty()) {
cachedLables = new ArrayList<>();
for (int i = 0; i < data.getPredictionList().size(); i++) {
cachedLables.add(data.getPredictionList().get(i).getLabel());
}
if(labels == null || labels.isEmpty()) {
labels = toDistinctListOfPairValues(data);
}
if (cachedLables.isEmpty()) {
if (labels.isEmpty()) {
String message;
message = "Kernel collected data but failed to get labels, couldn't save network.";
throw new IOException(message);
Expand All @@ -176,7 +131,7 @@ public void save(BehaivStorage storage) throws IOException {
theta.saveToFileBinary(storage.getNetworkFile(id).toString());
try (final BufferedWriter writer = new BufferedWriter(new FileWriter(storage.getNetworkMetadataFile(id)))) {

writer.write(objectMapper.writeValueAsString(cachedLables));
writer.write(objectMapper.writeValueAsString(labels));
} catch (Exception e) {
e.printStackTrace();
}
Expand All @@ -202,9 +157,9 @@ public void restore(BehaivStorage storage) throws IOException {
};
final String labelsData = reader.readLine();
if (labelsData == null) {
cachedLables = new ArrayList<>();
labels = new ArrayList<>();
} else {
cachedLables = objectMapper.readValue(labelsData, typeReference);
labels = objectMapper.readValue(labelsData, typeReference);
}
} catch (IOException e) {
e.printStackTrace();
Expand Down
Loading