Skip to content
5 changes: 5 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ jacocoTestReport {
xml.enabled = true
csv.enabled = false
}
afterEvaluate {
classDirectories.setFrom(files(classDirectories.files.collect {
fileTree(dir: it, exclude: 'tech/donau/behaiv/proto**')
}))
}
}

task proto(type: Exec) {
Expand Down
97 changes: 97 additions & 0 deletions src/main/java/de/dmi3y/behaiv/kernel/BaseKernel.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package de.dmi3y.behaiv.kernel;

import com.fasterxml.jackson.databind.ObjectMapper;
import de.dmi3y.behaiv.storage.BehaivStorage;
import de.dmi3y.behaiv.tools.Pair;
import tech.donau.behaiv.proto.Data;
import tech.donau.behaiv.proto.Prediction;
import tech.donau.behaiv.proto.PredictionSet;

import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public abstract class BaseKernel implements Kernel {

protected String id;
protected Long treshold = 10L;
protected ObjectMapper objectMapper;
protected boolean partialFitAllowed = false;
protected boolean alwaysKeepData = true;

public BaseKernel(String id) {
this.id = id;
objectMapper = new ObjectMapper();
}


@Override
public void setId(String id) {
this.id = id;
}


//list<features>, label
protected PredictionSet data = PredictionSet.newBuilder().addAllPrediction(new ArrayList<Prediction>()).build();


@Override
public void fit() {
fit(this.data);
}

@Override
public void setTreshold(Long treshold) {
this.treshold = treshold;
}

@Override
public boolean readyToPredict() {
return data.getPredictionList().size() > treshold;
}

@Override
public void update(List<Pair<List<Double>, String>> data) {
}

@Override
public boolean isPartialFitAllowed() {
return partialFitAllowed;
}

@Override
public void updateSingle(List<Double> features, String label) {
final Prediction.Builder predictionBuilder = Prediction.newBuilder();
for (int i = 0; i < features.size(); i++) {
predictionBuilder.addData(Data.newBuilder().setKey("key" + i).setValue(features.get(i)).build());
}
predictionBuilder.setLabel(label);
data = data.toBuilder().addPrediction(predictionBuilder.build()).build();
}

@Override
public boolean isAlwaysKeepData() {
return alwaysKeepData;
}

@Override
public void setAlwaysKeepData(boolean alwaysKeepData) {
this.alwaysKeepData = alwaysKeepData;
}

@Override
public void save(BehaivStorage storage) throws IOException {
try (final FileOutputStream writer = new FileOutputStream(storage.getDataFile(id))) {
data.writeTo(writer);
}
}

@Override
public void restore(BehaivStorage storage) throws IOException {
try (final FileInputStream reader = new FileInputStream(storage.getDataFile(id))) {
data = PredictionSet.parseFrom(reader);
}
}
}
90 changes: 18 additions & 72 deletions src/main/java/de/dmi3y/behaiv/kernel/Kernel.java
Original file line number Diff line number Diff line change
@@ -1,95 +1,41 @@
package de.dmi3y.behaiv.kernel;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import de.dmi3y.behaiv.storage.BehaivStorage;
import de.dmi3y.behaiv.tools.Pair;
import tech.donau.behaiv.proto.PredictionSet;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public abstract class Kernel {
public interface Kernel {
void setId(String id);

protected String id;
protected Long treshold = 10L;
protected ObjectMapper objectMapper;
protected boolean partialFitAllowed = false;
protected boolean alwaysKeepData = true;
boolean isEmpty();

public Kernel(String id) {
this.id = id;
objectMapper = new ObjectMapper();
}
@Deprecated
void fit(List<Pair<List<Double>, String>> data);

void fit(PredictionSet data);

public void setId(String id) {
this.id = id;
}
void fit();

void setTreshold(Long treshold);

//list<features>, label
protected List<Pair<List<Double>, String>> data = new ArrayList<>();
boolean readyToPredict();

void update(List<Pair<List<Double>, String>> data);

public abstract boolean isEmpty();
boolean isPartialFitAllowed();

public abstract void fit(List<Pair<List<Double>, String>> data);
void updateSingle(List<Double> features, String label);

public void fit() {
fit(this.data);
}
String predictOne(List<Double> features);

public void setTreshold(Long treshold) {
this.treshold = treshold;
}
boolean isAlwaysKeepData();

public boolean readyToPredict() {
return data.size() > treshold;
}
void setAlwaysKeepData(boolean alwaysKeepData);

public void update(List<Pair<List<Double>, String>> data) {
}
void save(BehaivStorage storage) throws IOException;

public boolean isPartialFitAllowed() {
return partialFitAllowed;
}

public void updateSingle(List<Double> features, String label) {
data.add(new Pair<>(features, label));
}

public abstract String predictOne(List<Double> features);

public boolean isAlwaysKeepData() {
return alwaysKeepData;
}

public void setAlwaysKeepData(boolean alwaysKeepData) {
this.alwaysKeepData = alwaysKeepData;
}

public void save(BehaivStorage storage) throws IOException {

try (final BufferedWriter writer = new BufferedWriter(new FileWriter(storage.getDataFile(id)))) {
writer.write(objectMapper.writeValueAsString(data));
}
}

public void restore(BehaivStorage storage) throws IOException {
final TypeReference<List<Pair<List<Double>, String>>> typeReference = new TypeReference<List<Pair<List<Double>, String>>>() {
};
try (final BufferedReader reader = new BufferedReader(new FileReader(storage.getDataFile(id)))) {
final String content = reader.readLine();
if (content == null || content.isEmpty()) {
data = new ArrayList<>();
} else {
data = objectMapper.readValue(content, typeReference);
}
}
}
void restore(BehaivStorage storage) throws IOException;
}
87 changes: 66 additions & 21 deletions src/main/java/de/dmi3y/behaiv/kernel/LogisticRegressionKernel.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,27 @@
import com.fasterxml.jackson.core.type.TypeReference;
import de.dmi3y.behaiv.kernel.logistic.LogisticUtils;
import de.dmi3y.behaiv.storage.BehaivStorage;
import de.dmi3y.behaiv.tools.DataMappingUtils;
import de.dmi3y.behaiv.tools.Pair;
import org.apache.commons.lang3.ArrayUtils;
import org.ejml.simple.SimpleMatrix;
import tech.donau.behaiv.proto.PredictionSet;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import static de.dmi3y.behaiv.tools.DataMappingUtils.toDistinctListOfPairValues;
import static de.dmi3y.behaiv.tools.DataMappingUtils.toInput2dArray;

public class LogisticRegressionKernel extends Kernel {
public class LogisticRegressionKernel extends BaseKernel {

protected List<String> labels = new ArrayList<>();
protected List<String> cachedLables = new ArrayList<>();
private Random rand;
protected SimpleMatrix theta;

Expand All @@ -35,24 +39,64 @@ public LogisticRegressionKernel(String id) {

@Override
public boolean isEmpty() {
return theta == null && data.size() == 0;
return theta == null && data.getPredictionCount() == 0;
}


@Override
public void fit(List<Pair<List<Double>, String>> data) {
public void fit(PredictionSet data) {
this.data = data;
labels = toDistinctListOfPairValues(data);
if (data.getDynamicColumns()) {
throw new UnsupportedOperationException("LogisticRegressionKernel doesn't support dynamic fields");
}
for (int i = 0; i < data.getPredictionList().size(); i++) {
cachedLables.add(data.getPredictionList().get(i).getLabel());
}
if (readyToPredict()) {
//features
double[][] inputs = toInput2dArray(data);

//labels
double[][] labelArray = new double[data.getPredictionCount()][cachedLables.size()];
for (int i = 0; i < data.getPredictionCount(); i++) {
int dummyPos = cachedLables.indexOf(data.getPrediction(i).getLabel());
labelArray[i][dummyPos] = 1.0;
}

//output layer
final SimpleMatrix inputMatrix = new SimpleMatrix(inputs);
final SimpleMatrix outputMatrix = new SimpleMatrix(labelArray);
//3x4?

//TODO dilemma on if we need to re-do theta or keep it as-is, if new features arrising we'll have a problem
if (theta == null || (theta.numCols() != cachedLables.size() && alwaysKeepData)) {
theta = SimpleMatrix.random_DDRM(inputMatrix.numCols(), outputMatrix.numCols(), 0, 1, rand);
} else if (theta.numCols() != cachedLables.size() && !alwaysKeepData) {
throw new UnsupportedOperationException(
"Partial fit of LogisticRegressionKernel is not supported. " +
"Number of labels differs from trained model." +
" Consider setting alwaysKeepData to true or changing Kernel that supports partial fit."
);
}

for (int i = 0; i < 10000; i++) {
theta = LogisticUtils.gradientDescent(inputMatrix, theta, outputMatrix, 0.1);
}

}
}

@Override
public void fit(List<Pair<List<Double>, String>> data) {
this.data = DataMappingUtils.createPredictionSet(data);
this.cachedLables = toDistinctListOfPairValues(data);
if (readyToPredict()) {
//features
double[][] inputs = toInput2dArray(data);

//labels
double[][] labelArray = new double[data.size()][labels.size()];
double[][] labelArray = new double[data.size()][cachedLables.size()];
for (int i = 0; i < data.size(); i++) {
int dummyPos = labels.indexOf(data.get(i).getValue());
int dummyPos = cachedLables.indexOf(data.get(i).getValue());
labelArray[i][dummyPos] = 1.0;
}

Expand All @@ -62,9 +106,9 @@ public void fit(List<Pair<List<Double>, String>> data) {
//3x4?

//TODO dilemma on if we need to re-do theta or keep it as-is, if new features arrising we'll have a problem
if (theta == null || (theta.numCols() != labels.size() && alwaysKeepData)) {
if (theta == null || (theta.numCols() != cachedLables.size() && alwaysKeepData)) {
theta = SimpleMatrix.random_DDRM(inputMatrix.numCols(), outputMatrix.numCols(), 0, 1, rand);
} else if (theta.numCols() != labels.size() && !alwaysKeepData) {
} else if (theta.numCols() != cachedLables.size() && !alwaysKeepData) {
throw new UnsupportedOperationException(
"Partial fit of LogisticRegressionKernel is not supported. " +
"Number of labels differs from trained model." +
Expand Down Expand Up @@ -92,8 +136,6 @@ public void updateSingle(List<Double> features, String label) {

@Override
public String predictOne(List<Double> features) {


final double[] doubles = ArrayUtils.toPrimitive(features.toArray(new Double[0]));

final SimpleMatrix inputs = new SimpleMatrix(new double[][]{doubles});
Expand All @@ -106,18 +148,21 @@ public String predictOne(List<Double> features) {
maxPosition = i;
}
}
return labels.get(maxPosition);
return cachedLables.get(maxPosition);
}

@Override
public void save(BehaivStorage storage) throws IOException {
if (theta == null && (data == null || data.isEmpty())) {
if (theta == null && (data == null || data.getPredictionList().isEmpty())) {
throw new IOException("Not enough data to save, network data is empty");
}
if(labels == null || labels.isEmpty()) {
labels = toDistinctListOfPairValues(data);
if (cachedLables == null || cachedLables.isEmpty()) {
cachedLables = new ArrayList<>();
for (int i = 0; i < data.getPredictionList().size(); i++) {
cachedLables.add(data.getPredictionList().get(i).getLabel());
}
}
if (labels.isEmpty()) {
if (cachedLables.isEmpty()) {
String message;
message = "Kernel collected data but failed to get labels, couldn't save network.";
throw new IOException(message);
Expand All @@ -131,7 +176,7 @@ public void save(BehaivStorage storage) throws IOException {
theta.saveToFileBinary(storage.getNetworkFile(id).toString());
try (final BufferedWriter writer = new BufferedWriter(new FileWriter(storage.getNetworkMetadataFile(id)))) {

writer.write(objectMapper.writeValueAsString(labels));
writer.write(objectMapper.writeValueAsString(cachedLables));
} catch (Exception e) {
e.printStackTrace();
}
Expand All @@ -157,9 +202,9 @@ public void restore(BehaivStorage storage) throws IOException {
};
final String labelsData = reader.readLine();
if (labelsData == null) {
labels = new ArrayList<>();
cachedLables = new ArrayList<>();
} else {
labels = objectMapper.readValue(labelsData, typeReference);
cachedLables = objectMapper.readValue(labelsData, typeReference);
}
} catch (IOException e) {
e.printStackTrace();
Expand Down
Loading