Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ repositories {

dependencies {
// Protobuf
compile group: 'com.google.protobuf', name: 'protobuf-java', version: '3.11.4'
compile group: 'com.google.protobuf', name: 'protobuf-javalite', version: '3.11.4'
// Swagger codegen dependencies
swaggerCodegen 'io.swagger:swagger-codegen-cli:2.4.2' // Swagger Codegen V2
// This dependency is exported to consumers, that is to say found on their compile classpath.
Expand Down Expand Up @@ -80,6 +80,11 @@ jacocoTestReport {
xml.enabled = true
csv.enabled = false
}
afterEvaluate {
classDirectories.setFrom(files(classDirectories.files.collect {
fileTree(dir: it, exclude: 'tech/donau/behaiv/proto**')
}))
}
}

task proto(type: Exec) {
Expand Down
2 changes: 1 addition & 1 deletion proto.sh
Original file line number Diff line number Diff line change
@@ -1 +1 @@
protoc -I=. --java_out=./src/main/java ./behaiv.proto
protoc -I=. --java_out=lite:./src/main/java ./behaiv.proto
103 changes: 103 additions & 0 deletions src/main/java/de/dmi3y/behaiv/kernel/BaseKernel.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package de.dmi3y.behaiv.kernel;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.protobuf.CodedOutputStream;
import de.dmi3y.behaiv.storage.BehaivStorage;
import de.dmi3y.behaiv.tools.DataMappingUtils;
import de.dmi3y.behaiv.tools.Pair;
import tech.donau.behaiv.proto.PredictionSet;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public abstract class BaseKernel implements Kernel {

protected String id;
protected Long treshold = 10L;
protected ObjectMapper objectMapper;
protected boolean partialFitAllowed = false;
protected boolean alwaysKeepData = true;

public BaseKernel(String id) {
this.id = id;
objectMapper = new ObjectMapper();
}


@Override
public void setId(String id) {
this.id = id;
}


//list<features>, label
protected List<Pair<List<Double>, String>> data = new ArrayList<>();


@Override
public void fit() {
fit(this.data);
}

@Override
public void setTreshold(Long treshold) {
this.treshold = treshold;
}

@Override
public boolean readyToPredict() {
return data.size() > treshold;
}

@Override
public void update(List<Pair<List<Double>, String>> data) {
}

@Override
public boolean isPartialFitAllowed() {
return partialFitAllowed;
}

@Override
public void updateSingle(List<Double> features, String label) {
data.add(new Pair<>(features, label));
}

@Override
public boolean isAlwaysKeepData() {
return alwaysKeepData;
}

@Override
public void setAlwaysKeepData(boolean alwaysKeepData) {
this.alwaysKeepData = alwaysKeepData;
}

@Override
public void save(BehaivStorage storage) throws IOException {
final File dataFile = storage.getDataFile(id);
final PredictionSet predictionSet = DataMappingUtils.createPredictionSet(data);
try(FileOutputStream fileOutputStream = new FileOutputStream(dataFile)) {
predictionSet.writeTo(fileOutputStream);
}
}

@Override
public void restore(BehaivStorage storage) throws IOException {
final File dataFile = storage.getDataFile(id);
try (FileInputStream fileInputStream = new FileInputStream(dataFile)) {
final PredictionSet predictionSet = PredictionSet.parseFrom(fileInputStream);
data = DataMappingUtils.dataFromPredictionSet(predictionSet);
}
}
}
88 changes: 15 additions & 73 deletions src/main/java/de/dmi3y/behaiv/kernel/Kernel.java
Original file line number Diff line number Diff line change
@@ -1,95 +1,37 @@
package de.dmi3y.behaiv.kernel;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import de.dmi3y.behaiv.storage.BehaivStorage;
import de.dmi3y.behaiv.tools.Pair;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public abstract class Kernel {
public interface Kernel {
void setId(String id);

protected String id;
protected Long treshold = 10L;
protected ObjectMapper objectMapper;
protected boolean partialFitAllowed = false;
protected boolean alwaysKeepData = true;
boolean isEmpty();

public Kernel(String id) {
this.id = id;
objectMapper = new ObjectMapper();
}
void fit(List<Pair<List<Double>, String>> data);

void fit();

public void setId(String id) {
this.id = id;
}
void setTreshold(Long treshold);

boolean readyToPredict();

//list<features>, label
protected List<Pair<List<Double>, String>> data = new ArrayList<>();
void update(List<Pair<List<Double>, String>> data);

boolean isPartialFitAllowed();

public abstract boolean isEmpty();
void updateSingle(List<Double> features, String label);

public abstract void fit(List<Pair<List<Double>, String>> data);
String predictOne(List<Double> features);

public void fit() {
fit(this.data);
}
boolean isAlwaysKeepData();

public void setTreshold(Long treshold) {
this.treshold = treshold;
}
void setAlwaysKeepData(boolean alwaysKeepData);

public boolean readyToPredict() {
return data.size() > treshold;
}
void save(BehaivStorage storage) throws IOException;

public void update(List<Pair<List<Double>, String>> data) {
}

public boolean isPartialFitAllowed() {
return partialFitAllowed;
}

public void updateSingle(List<Double> features, String label) {
data.add(new Pair<>(features, label));
}

public abstract String predictOne(List<Double> features);

public boolean isAlwaysKeepData() {
return alwaysKeepData;
}

public void setAlwaysKeepData(boolean alwaysKeepData) {
this.alwaysKeepData = alwaysKeepData;
}

public void save(BehaivStorage storage) throws IOException {

try (final BufferedWriter writer = new BufferedWriter(new FileWriter(storage.getDataFile(id)))) {
writer.write(objectMapper.writeValueAsString(data));
}
}

public void restore(BehaivStorage storage) throws IOException {
final TypeReference<List<Pair<List<Double>, String>>> typeReference = new TypeReference<List<Pair<List<Double>, String>>>() {
};
try (final BufferedReader reader = new BufferedReader(new FileReader(storage.getDataFile(id)))) {
final String content = reader.readLine();
if (content == null || content.isEmpty()) {
data = new ArrayList<>();
} else {
data = objectMapper.readValue(content, typeReference);
}
}
}
void restore(BehaivStorage storage) throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import static de.dmi3y.behaiv.tools.DataMappingUtils.toDistinctListOfPairValues;
import static de.dmi3y.behaiv.tools.DataMappingUtils.toInput2dArray;

public class LogisticRegressionKernel extends Kernel {
public class LogisticRegressionKernel extends BaseKernel {

protected List<String> labels = new ArrayList<>();
private Random rand;
Expand Down
35 changes: 35 additions & 0 deletions src/main/java/de/dmi3y/behaiv/tools/DataMappingUtils.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package de.dmi3y.behaiv.tools;

import org.apache.commons.lang3.ArrayUtils;
import tech.donau.behaiv.proto.Data;
import tech.donau.behaiv.proto.Prediction;
import tech.donau.behaiv.proto.PredictionSet;

import java.util.ArrayList;
import java.util.HashSet;
Expand All @@ -14,6 +17,38 @@ private DataMappingUtils() {
// Unused utility class
}

public static PredictionSet createPredictionSet(List<Pair<List<Double>, String>> data) {
final PredictionSet.Builder predictionSetBuilder = PredictionSet.newBuilder();
for (int i = 0; i < data.size(); i++) {
final List<Double> weights = data.get(i).getKey();
final ArrayList<Data> row = new ArrayList<>();
for (int j = 0; j < weights.size(); j++) {
row.add(Data.newBuilder().setKey("key"+j).setValue(weights.get(j)).build());
}
predictionSetBuilder.addPrediction(
Prediction.newBuilder()
.addAllData(row)
.setLabel(data.get(i).getValue())
.build()
);
}
return predictionSetBuilder.build();
}

public static List<Pair<List<Double>, String>> dataFromPredictionSet(PredictionSet predictionSet) {
final ArrayList<Pair<List<Double>, String>> data = new ArrayList<>();
for (int i = 0; i < predictionSet.getPredictionCount(); i++) {
final Prediction prediction = predictionSet.getPrediction(i);
final List<Data> dataList = prediction.getDataList();
final String label = prediction.getLabel();
final List<Double> entries = new ArrayList<>();
for (Data entry : dataList) {
entries.add(entry.getValue());
}
data.add(Pair.create(entries, label));
}
return data;
}

public static List<String> toDistinctListOfPairValues(List<Pair<List<Double>, String>> data) {
Set<String> setOfValues = new HashSet<>();
Expand Down
Loading