Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/release.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
changelog:
include_author: false
exclude:
labels:
- ignore-for-release
Expand Down
12 changes: 12 additions & 0 deletions src/main/java/dev/braintrust/Braintrust.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import dev.braintrust.api.BraintrustApiClient;
import dev.braintrust.config.BraintrustConfig;
import dev.braintrust.eval.Dataset;
import dev.braintrust.eval.Eval;
import dev.braintrust.prompt.BraintrustPromptLoader;
import dev.braintrust.trace.BraintrustTracing;
Expand All @@ -12,6 +13,7 @@
import java.net.URI;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import lombok.Getter;
import lombok.experimental.Accessors;
import lombok.extern.slf4j.Slf4j;
Expand Down Expand Up @@ -161,4 +163,14 @@ public <INPUT, OUTPUT> Eval.Builder<INPUT, OUTPUT> evalBuilder() {
return (Eval.Builder<INPUT, OUTPUT>)
Eval.builder().config(this.config).apiClient(this.apiClient);
}

public <INPUT, OUTPUT> Dataset<INPUT, OUTPUT> fetchDataset(String datasetName) {
return fetchDataset(datasetName, null);
}

public <INPUT, OUTPUT> Dataset<INPUT, OUTPUT> fetchDataset(
String datasetName, @Nullable String datasetVersion) {
var projectName = apiClient.getOrCreateProjectAndOrgInfo(config).project().name();
return Dataset.fetchFromBraintrust(apiClient(), projectName, datasetName, datasetVersion);
}
}
93 changes: 90 additions & 3 deletions src/main/java/dev/braintrust/api/BraintrustApiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ public interface BraintrustApiClient {
Optional<Prompt> getPrompt(
@Nonnull String projectName, @Nonnull String slug, @Nullable String version);

/** Fetch dataset events with pagination */
DatasetFetchResponse fetchDatasetEvents(String datasetId, DatasetFetchRequest request);

/** Get dataset metadata by ID */
Optional<Dataset> getDataset(String datasetId);

/** Query datasets by project name and dataset name */
List<Dataset> queryDatasets(String projectName, String datasetName);

static BraintrustApiClient of(BraintrustConfig config) {
return new HttpImpl(config);
}
Expand Down Expand Up @@ -235,6 +244,49 @@ public Optional<Prompt> getPrompt(
}
}

@Override
public DatasetFetchResponse fetchDatasetEvents(
String datasetId, DatasetFetchRequest request) {
try {
String path = "/v1/dataset/" + datasetId + "/fetch";
return postAsync(path, request, DatasetFetchResponse.class).get();
} catch (InterruptedException | ExecutionException e) {
throw new ApiException(e);
}
}

@Override
public Optional<Dataset> getDataset(String datasetId) {
try {
return getAsync("/v1/dataset/" + datasetId, Dataset.class)
.handle(
(dataset, error) -> {
if (error != null && isNotFound(error)) {
return Optional.<Dataset>empty();
}
if (error != null) {
throw new CompletionException(error);
}
return Optional.of(dataset);
})
.get();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
}

@Override
public List<Dataset> queryDatasets(String projectName, String datasetName) {
try {
String path =
"/v1/dataset?project_name=" + projectName + "&dataset_name=" + datasetName;
DatasetList response = getAsync(path, DatasetList.class).get();
return response.objects();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
}

private <T> CompletableFuture<T> getAsync(String path, Class<T> responseType) {
var request =
HttpRequest.newBuilder()
Expand Down Expand Up @@ -301,8 +353,14 @@ private <T> T handleResponse(HttpResponse<String> response, Class<T> responseTyp
}

private boolean isNotFound(Throwable error) {
if (error instanceof ApiException) {
return ((ApiException) error).getMessage().contains("404");
// Unwrap CompletionException if present
Throwable cause = error;
if (error instanceof CompletionException && error.getCause() != null) {
cause = error.getCause();
}

if (cause instanceof ApiException) {
return ((ApiException) cause).getMessage().contains("404");
}
return false;
}
Expand Down Expand Up @@ -493,6 +551,23 @@ public Optional<Prompt> getPrompt(

return Optional.of(matchingPrompts.get(0));
}

// Will add dataset support if needed in unit tests (this is unlikely to be needed though)
@Override
public DatasetFetchResponse fetchDatasetEvents(
String datasetId, DatasetFetchRequest request) {
return new DatasetFetchResponse(List.of(), null);
}

@Override
public Optional<Dataset> getDataset(String datasetId) {
return Optional.empty();
}

@Override
public List<Dataset> queryDatasets(String projectName, String datasetName) {
return List.of();
}
}

// Request/Response DTOs
Expand Down Expand Up @@ -538,7 +613,7 @@ record Dataset(
String createdAt,
String updatedAt) {}

record DatasetList(List<Dataset> datasets) {}
record DatasetList(List<Dataset> objects) {}

record DatasetEvent(Object input, Optional<Object> output, Optional<Object> metadata) {
public DatasetEvent(Object input) {
Expand All @@ -554,6 +629,18 @@ record InsertEventsRequest(List<DatasetEvent> events) {}

record InsertEventsResponse(int insertedCount) {}

record DatasetFetchRequest(int limit, @Nullable String cursor, @Nullable String version) {
public DatasetFetchRequest(int limit) {
this(limit, null, null);
}

public DatasetFetchRequest(int limit, @Nullable String cursor) {
this(limit, cursor, null);
}
}

record DatasetFetchResponse(List<Map<String, Object>> events, @Nullable String cursor) {}

// User and Organization models for login functionality
record OrganizationInfo(String id, String name) {}

Expand Down
47 changes: 46 additions & 1 deletion src/main/java/dev/braintrust/eval/Dataset.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package dev.braintrust.eval;

import dev.braintrust.api.BraintrustApiClient;
import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;
import javax.annotation.Nullable;
import javax.annotation.concurrent.NotThreadSafe;

/**
Expand All @@ -16,7 +19,19 @@ public interface Dataset<INPUT, OUTPUT> {

String id();

String version();
/** Dataset version. Empty means the dataset will fetch latest upon every cursor open */
Optional<String> version();

/** Convenience method to safely iterate all items in a dataset. */
default void forEach(Consumer<DatasetCase<INPUT, OUTPUT>> consumer) {
try (var cursor = openCursor()) {
Optional<DatasetCase<INPUT, OUTPUT>> cursorCase = cursor.next();
while (cursorCase.isPresent()) {
consumer.accept(cursorCase.get());
cursorCase = cursor.next();
}
}
}

@NotThreadSafe
interface Cursor<CASE> extends AutoCloseable {
Expand All @@ -39,4 +54,34 @@ interface Cursor<CASE> extends AutoCloseable {
static <INPUT, OUTPUT> Dataset<INPUT, OUTPUT> of(DatasetCase<INPUT, OUTPUT>... cases) {
return new DatasetInMemoryImpl<>(List.of(cases));
}

static <INPUT, OUTPUT> Dataset<INPUT, OUTPUT> fetchFromBraintrust(
BraintrustApiClient apiClient,
String projectName,
String datasetName,
@Nullable String datasetVersion) {
var datasets = apiClient.queryDatasets(projectName, datasetName);

if (datasets.isEmpty()) {
throw new RuntimeException(
"Dataset not found: project=" + projectName + ", dataset=" + datasetName);
}

if (datasets.size() > 1) {
throw new RuntimeException(
"Multiple datasets found for project="
+ projectName
+ ", dataset="
+ datasetName
+ ". Found "
+ datasets.size()
+ " datasets");
}

var dataset = datasets.get(0);
return new DatasetBrainstoreImpl<>(
apiClient,
dataset.id(),
datasetVersion != null ? datasetVersion : dataset.updatedAt());
}
}
120 changes: 120 additions & 0 deletions src/main/java/dev/braintrust/eval/DatasetBrainstoreImpl.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package dev.braintrust.eval;

import dev.braintrust.api.BraintrustApiClient;
import java.util.*;
import javax.annotation.Nullable;

/** A dataset loaded externally from Braintrust using paginated API fetches */
public class DatasetBrainstoreImpl<INPUT, OUTPUT> implements Dataset<INPUT, OUTPUT> {
private final BraintrustApiClient apiClient;
private final String datasetId;
private final @Nullable String pinnedVersion;
private final int batchSize;

public DatasetBrainstoreImpl(
BraintrustApiClient apiClient, String datasetId, @Nullable String datasetVersion) {
this(apiClient, datasetId, datasetVersion, 512);
}

DatasetBrainstoreImpl(
BraintrustApiClient apiClient,
String datasetId,
@Nullable String datasetVersion,
int batchSize) {
this.apiClient = apiClient;
this.datasetId = datasetId;
this.batchSize = batchSize;
this.pinnedVersion = datasetVersion;
}

@Override
public String id() {
return datasetId;
}

@Override
public Optional<String> version() {
return Optional.ofNullable(pinnedVersion);
}

@Override
public Cursor<DatasetCase<INPUT, OUTPUT>> openCursor() {
return new BrainstoreCursor();
}

private class BrainstoreCursor implements Cursor<DatasetCase<INPUT, OUTPUT>> {
private List<Map<String, Object>> currentBatch;
private int currentIndex;
private @Nullable String cursor;
private boolean exhausted;
private boolean closed;

BrainstoreCursor() {
this.currentBatch = new ArrayList<>();
this.currentIndex = 0;
this.cursor = null;
this.exhausted = false;
this.closed = false;
}

@Override
@SuppressWarnings("unchecked")
public Optional<DatasetCase<INPUT, OUTPUT>> next() {
if (closed) {
throw new IllegalStateException("Cursor is closed");
}

// Fetch next batch if we've consumed the current one
if (currentIndex >= currentBatch.size() && !exhausted) {
fetchNextBatch();
}

// Return empty if no more data
if (currentIndex >= currentBatch.size()) {
return Optional.empty();
}

// Parse the event into a DatasetCase
Map<String, Object> event = currentBatch.get(currentIndex++);

INPUT input = (INPUT) event.get("input");
OUTPUT expected = (OUTPUT) event.get("expected");

Map<String, Object> metadata = (Map<String, Object>) event.get("metadata");
if (metadata == null) {
metadata = Map.of();
}

List<String> tags = (List<String>) event.get("tags");
if (tags == null) {
tags = List.of();
}

DatasetCase<INPUT, OUTPUT> datasetCase =
new DatasetCase<>(input, expected, tags, metadata);

return Optional.of(datasetCase);
}

private void fetchNextBatch() {
var request =
new BraintrustApiClient.DatasetFetchRequest(batchSize, cursor, pinnedVersion);
var response = apiClient.fetchDatasetEvents(datasetId, request);

currentBatch = new ArrayList<>(response.events());
currentIndex = 0;
cursor = response.cursor();

// Mark as exhausted if no cursor or no events returned
if (cursor == null || cursor.isEmpty() || response.events().isEmpty()) {
exhausted = true;
}
}

@Override
public void close() {
closed = true;
currentBatch.clear();
}
}
}
4 changes: 2 additions & 2 deletions src/main/java/dev/braintrust/eval/DatasetInMemoryImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ public String id() {
}

@Override
public String version() {
return "0";
public Optional<String> version() {
return Optional.empty();
}

@Override
Expand Down
10 changes: 1 addition & 9 deletions src/main/java/dev/braintrust/eval/Eval.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,7 @@ public EvalResult run() {
experimentName,
Optional.empty(),
Optional.empty()));
var experimentID = experiment.id();

try (var cursor = dataset.openCursor()) {
for (var datsetCase = cursor.next();
datsetCase.isPresent();
datsetCase = cursor.next()) {
evalOne(experimentID, datsetCase.get());
}
}
dataset.forEach(datasetCase -> evalOne(experiment.id(), datasetCase));
var experimentUrl =
"%s/experiments/%s"
.formatted(
Expand Down
Loading