From 54a17aeb3d6a4bf12fd25e79c39a2037e3829a00 Mon Sep 17 00:00:00 2001 From: Andrew Kent Date: Mon, 1 Dec 2025 14:15:00 -0700 Subject: [PATCH 1/2] rm useless line in release.yml --- .github/release.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/release.yml b/.github/release.yml index bfe7aac..6d198a7 100644 --- a/.github/release.yml +++ b/.github/release.yml @@ -1,5 +1,4 @@ changelog: - include_author: false exclude: labels: - ignore-for-release From 096f8bb841080104e55e3dca3ac2a6557853034d Mon Sep 17 00:00:00 2001 From: Andrew Kent Date: Wed, 3 Dec 2025 23:41:09 -0700 Subject: [PATCH 2/2] Support fetching datasets from Braintrust --- src/main/java/dev/braintrust/Braintrust.java | 12 + .../braintrust/api/BraintrustApiClient.java | 93 ++++++- .../java/dev/braintrust/eval/Dataset.java | 47 +++- .../eval/DatasetBrainstoreImpl.java | 120 +++++++++ .../braintrust/eval/DatasetInMemoryImpl.java | 4 +- src/main/java/dev/braintrust/eval/Eval.java | 10 +- .../eval/DatasetBrainstoreImplTest.java | 243 ++++++++++++++++++ 7 files changed, 514 insertions(+), 15 deletions(-) create mode 100644 src/main/java/dev/braintrust/eval/DatasetBrainstoreImpl.java create mode 100644 src/test/java/dev/braintrust/eval/DatasetBrainstoreImplTest.java diff --git a/src/main/java/dev/braintrust/Braintrust.java b/src/main/java/dev/braintrust/Braintrust.java index d0bc5b3..4b723f3 100644 --- a/src/main/java/dev/braintrust/Braintrust.java +++ b/src/main/java/dev/braintrust/Braintrust.java @@ -2,6 +2,7 @@ import dev.braintrust.api.BraintrustApiClient; import dev.braintrust.config.BraintrustConfig; +import dev.braintrust.eval.Dataset; import dev.braintrust.eval.Eval; import dev.braintrust.prompt.BraintrustPromptLoader; import dev.braintrust.trace.BraintrustTracing; @@ -12,6 +13,7 @@ import java.net.URI; import java.util.concurrent.atomic.AtomicReference; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import lombok.Getter; import lombok.experimental.Accessors; import lombok.extern.slf4j.Slf4j; @@ -161,4 +163,14 @@ public Eval.Builder evalBuilder() { return (Eval.Builder) Eval.builder().config(this.config).apiClient(this.apiClient); } + + public Dataset fetchDataset(String datasetName) { + return fetchDataset(datasetName, null); + } + + public Dataset fetchDataset( + String datasetName, @Nullable String datasetVersion) { + var projectName = apiClient.getOrCreateProjectAndOrgInfo(config).project().name(); + return Dataset.fetchFromBraintrust(apiClient(), projectName, datasetName, datasetVersion); + } } diff --git a/src/main/java/dev/braintrust/api/BraintrustApiClient.java b/src/main/java/dev/braintrust/api/BraintrustApiClient.java index 51561d4..a8f8965 100644 --- a/src/main/java/dev/braintrust/api/BraintrustApiClient.java +++ b/src/main/java/dev/braintrust/api/BraintrustApiClient.java @@ -49,6 +49,15 @@ public interface BraintrustApiClient { Optional getPrompt( @Nonnull String projectName, @Nonnull String slug, @Nullable String version); + /** Fetch dataset events with pagination */ + DatasetFetchResponse fetchDatasetEvents(String datasetId, DatasetFetchRequest request); + + /** Get dataset metadata by ID */ + Optional getDataset(String datasetId); + + /** Query datasets by project name and dataset name */ + List queryDatasets(String projectName, String datasetName); + static BraintrustApiClient of(BraintrustConfig config) { return new HttpImpl(config); } @@ -235,6 +244,49 @@ public Optional getPrompt( } } + @Override + public DatasetFetchResponse fetchDatasetEvents( + String datasetId, DatasetFetchRequest request) { + try { + String path = "/v1/dataset/" + datasetId + "/fetch"; + return postAsync(path, request, DatasetFetchResponse.class).get(); + } catch (InterruptedException | ExecutionException e) { + throw new ApiException(e); + } + } + + @Override + public Optional getDataset(String datasetId) { + try { + return getAsync("/v1/dataset/" + datasetId, Dataset.class) + .handle( + (dataset, error) -> { + if (error != null && isNotFound(error)) { + return Optional.empty(); + } + if (error != null) { + throw new CompletionException(error); + } + return Optional.of(dataset); + }) + .get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + } + + @Override + public List queryDatasets(String projectName, String datasetName) { + try { + String path = + "/v1/dataset?project_name=" + projectName + "&dataset_name=" + datasetName; + DatasetList response = getAsync(path, DatasetList.class).get(); + return response.objects(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + } + private CompletableFuture getAsync(String path, Class responseType) { var request = HttpRequest.newBuilder() @@ -301,8 +353,14 @@ private T handleResponse(HttpResponse response, Class responseTyp } private boolean isNotFound(Throwable error) { - if (error instanceof ApiException) { - return ((ApiException) error).getMessage().contains("404"); + // Unwrap CompletionException if present + Throwable cause = error; + if (error instanceof CompletionException && error.getCause() != null) { + cause = error.getCause(); + } + + if (cause instanceof ApiException) { + return ((ApiException) cause).getMessage().contains("404"); } return false; } @@ -493,6 +551,23 @@ public Optional getPrompt( return Optional.of(matchingPrompts.get(0)); } + + // Will add dataset support if needed in unit tests (this is unlikely to be needed though) + @Override + public DatasetFetchResponse fetchDatasetEvents( + String datasetId, DatasetFetchRequest request) { + return new DatasetFetchResponse(List.of(), null); + } + + @Override + public Optional getDataset(String datasetId) { + return Optional.empty(); + } + + @Override + public List queryDatasets(String projectName, String datasetName) { + return List.of(); + } } // Request/Response DTOs @@ -538,7 +613,7 @@ record Dataset( String createdAt, String updatedAt) {} - record DatasetList(List datasets) {} + record DatasetList(List objects) {} record DatasetEvent(Object input, Optional output, Optional metadata) { public DatasetEvent(Object input) { @@ -554,6 +629,18 @@ record InsertEventsRequest(List events) {} record InsertEventsResponse(int insertedCount) {} + record DatasetFetchRequest(int limit, @Nullable String cursor, @Nullable String version) { + public DatasetFetchRequest(int limit) { + this(limit, null, null); + } + + public DatasetFetchRequest(int limit, @Nullable String cursor) { + this(limit, cursor, null); + } + } + + record DatasetFetchResponse(List> events, @Nullable String cursor) {} + // User and Organization models for login functionality record OrganizationInfo(String id, String name) {} diff --git a/src/main/java/dev/braintrust/eval/Dataset.java b/src/main/java/dev/braintrust/eval/Dataset.java index f07b0d7..381f5ad 100644 --- a/src/main/java/dev/braintrust/eval/Dataset.java +++ b/src/main/java/dev/braintrust/eval/Dataset.java @@ -1,7 +1,10 @@ package dev.braintrust.eval; +import dev.braintrust.api.BraintrustApiClient; import java.util.List; import java.util.Optional; +import java.util.function.Consumer; +import javax.annotation.Nullable; import javax.annotation.concurrent.NotThreadSafe; /** @@ -16,7 +19,19 @@ public interface Dataset { String id(); - String version(); + /** Dataset version. Empty means the dataset will fetch latest upon every cursor open */ + Optional version(); + + /** Convenience method to safely iterate all items in a dataset. */ + default void forEach(Consumer> consumer) { + try (var cursor = openCursor()) { + Optional> cursorCase = cursor.next(); + while (cursorCase.isPresent()) { + consumer.accept(cursorCase.get()); + cursorCase = cursor.next(); + } + } + } @NotThreadSafe interface Cursor extends AutoCloseable { @@ -39,4 +54,34 @@ interface Cursor extends AutoCloseable { static Dataset of(DatasetCase... cases) { return new DatasetInMemoryImpl<>(List.of(cases)); } + + static Dataset fetchFromBraintrust( + BraintrustApiClient apiClient, + String projectName, + String datasetName, + @Nullable String datasetVersion) { + var datasets = apiClient.queryDatasets(projectName, datasetName); + + if (datasets.isEmpty()) { + throw new RuntimeException( + "Dataset not found: project=" + projectName + ", dataset=" + datasetName); + } + + if (datasets.size() > 1) { + throw new RuntimeException( + "Multiple datasets found for project=" + + projectName + + ", dataset=" + + datasetName + + ". Found " + + datasets.size() + + " datasets"); + } + + var dataset = datasets.get(0); + return new DatasetBrainstoreImpl<>( + apiClient, + dataset.id(), + datasetVersion != null ? datasetVersion : dataset.updatedAt()); + } } diff --git a/src/main/java/dev/braintrust/eval/DatasetBrainstoreImpl.java b/src/main/java/dev/braintrust/eval/DatasetBrainstoreImpl.java new file mode 100644 index 0000000..606c8c0 --- /dev/null +++ b/src/main/java/dev/braintrust/eval/DatasetBrainstoreImpl.java @@ -0,0 +1,120 @@ +package dev.braintrust.eval; + +import dev.braintrust.api.BraintrustApiClient; +import java.util.*; +import javax.annotation.Nullable; + +/** A dataset loaded externally from Braintrust using paginated API fetches */ +public class DatasetBrainstoreImpl implements Dataset { + private final BraintrustApiClient apiClient; + private final String datasetId; + private final @Nullable String pinnedVersion; + private final int batchSize; + + public DatasetBrainstoreImpl( + BraintrustApiClient apiClient, String datasetId, @Nullable String datasetVersion) { + this(apiClient, datasetId, datasetVersion, 512); + } + + DatasetBrainstoreImpl( + BraintrustApiClient apiClient, + String datasetId, + @Nullable String datasetVersion, + int batchSize) { + this.apiClient = apiClient; + this.datasetId = datasetId; + this.batchSize = batchSize; + this.pinnedVersion = datasetVersion; + } + + @Override + public String id() { + return datasetId; + } + + @Override + public Optional version() { + return Optional.ofNullable(pinnedVersion); + } + + @Override + public Cursor> openCursor() { + return new BrainstoreCursor(); + } + + private class BrainstoreCursor implements Cursor> { + private List> currentBatch; + private int currentIndex; + private @Nullable String cursor; + private boolean exhausted; + private boolean closed; + + BrainstoreCursor() { + this.currentBatch = new ArrayList<>(); + this.currentIndex = 0; + this.cursor = null; + this.exhausted = false; + this.closed = false; + } + + @Override + @SuppressWarnings("unchecked") + public Optional> next() { + if (closed) { + throw new IllegalStateException("Cursor is closed"); + } + + // Fetch next batch if we've consumed the current one + if (currentIndex >= currentBatch.size() && !exhausted) { + fetchNextBatch(); + } + + // Return empty if no more data + if (currentIndex >= currentBatch.size()) { + return Optional.empty(); + } + + // Parse the event into a DatasetCase + Map event = currentBatch.get(currentIndex++); + + INPUT input = (INPUT) event.get("input"); + OUTPUT expected = (OUTPUT) event.get("expected"); + + Map metadata = (Map) event.get("metadata"); + if (metadata == null) { + metadata = Map.of(); + } + + List tags = (List) event.get("tags"); + if (tags == null) { + tags = List.of(); + } + + DatasetCase datasetCase = + new DatasetCase<>(input, expected, tags, metadata); + + return Optional.of(datasetCase); + } + + private void fetchNextBatch() { + var request = + new BraintrustApiClient.DatasetFetchRequest(batchSize, cursor, pinnedVersion); + var response = apiClient.fetchDatasetEvents(datasetId, request); + + currentBatch = new ArrayList<>(response.events()); + currentIndex = 0; + cursor = response.cursor(); + + // Mark as exhausted if no cursor or no events returned + if (cursor == null || cursor.isEmpty() || response.events().isEmpty()) { + exhausted = true; + } + } + + @Override + public void close() { + closed = true; + currentBatch.clear(); + } + } +} diff --git a/src/main/java/dev/braintrust/eval/DatasetInMemoryImpl.java b/src/main/java/dev/braintrust/eval/DatasetInMemoryImpl.java index 554fdd8..9d2d868 100644 --- a/src/main/java/dev/braintrust/eval/DatasetInMemoryImpl.java +++ b/src/main/java/dev/braintrust/eval/DatasetInMemoryImpl.java @@ -19,8 +19,8 @@ public String id() { } @Override - public String version() { - return "0"; + public Optional version() { + return Optional.empty(); } @Override diff --git a/src/main/java/dev/braintrust/eval/Eval.java b/src/main/java/dev/braintrust/eval/Eval.java index b5fb9d3..d08b4cd 100644 --- a/src/main/java/dev/braintrust/eval/Eval.java +++ b/src/main/java/dev/braintrust/eval/Eval.java @@ -64,15 +64,7 @@ public EvalResult run() { experimentName, Optional.empty(), Optional.empty())); - var experimentID = experiment.id(); - - try (var cursor = dataset.openCursor()) { - for (var datsetCase = cursor.next(); - datsetCase.isPresent(); - datsetCase = cursor.next()) { - evalOne(experimentID, datsetCase.get()); - } - } + dataset.forEach(datasetCase -> evalOne(experiment.id(), datasetCase)); var experimentUrl = "%s/experiments/%s" .formatted( diff --git a/src/test/java/dev/braintrust/eval/DatasetBrainstoreImplTest.java b/src/test/java/dev/braintrust/eval/DatasetBrainstoreImplTest.java new file mode 100644 index 0000000..6c536c5 --- /dev/null +++ b/src/test/java/dev/braintrust/eval/DatasetBrainstoreImplTest.java @@ -0,0 +1,243 @@ +package dev.braintrust.eval; + +import static com.github.tomakehurst.wiremock.client.WireMock.*; +import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig; +import static org.junit.jupiter.api.Assertions.*; + +import com.github.tomakehurst.wiremock.junit5.WireMockExtension; +import dev.braintrust.api.BraintrustApiClient; +import dev.braintrust.config.BraintrustConfig; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +public class DatasetBrainstoreImplTest { + + @RegisterExtension + static WireMockExtension wireMock = + WireMockExtension.newInstance().options(wireMockConfig().dynamicPort()).build(); + + private BraintrustApiClient apiClient; + private String datasetId; + + @BeforeEach + void beforeEach() { + wireMock.resetAll(); + datasetId = "test-dataset-123"; + + // Create API client pointing to WireMock server + var config = + BraintrustConfig.builder() + .apiKey("test-api-key") + .apiUrl("http://localhost:" + wireMock.getPort()) + .build(); + apiClient = BraintrustApiClient.of(config); + } + + @Test + void testFetchAll() { + // Mock the first batch with a cursor + wireMock.stubFor( + post(urlEqualTo("/v1/dataset/" + datasetId + "/fetch")) + .withRequestBody(matchingJsonPath("$.cursor", absent())) + .willReturn( + aResponse() + .withStatus(200) + .withHeader("Content-Type", "application/json") + .withBody( + """ + { + "events": [ + { + "id": "event-1", + "input": "Question 1", + "expected": "Answer 1" + }, + { + "id": "event-2", + "input": "Question 2", + "expected": "Answer 2" + } + ], + "cursor": "next-page-token" + } + """))); + + // Mock the second batch without a cursor (last page) + wireMock.stubFor( + post(urlEqualTo("/v1/dataset/" + datasetId + "/fetch")) + .withRequestBody(matchingJsonPath("$.cursor", equalTo("next-page-token"))) + .willReturn( + aResponse() + .withStatus(200) + .withHeader("Content-Type", "application/json") + .withBody( + """ + { + "events": [ + { + "id": "event-3", + "input": "Question 3", + "expected": "Answer 3" + } + ], + "cursor": null + } + """))); + + // Create dataset with smaller batch size + DatasetBrainstoreImpl dataset = + new DatasetBrainstoreImpl<>(apiClient, datasetId, "test-version", 2); + + List> cases = new ArrayList<>(); + dataset.forEach(cases::add); + + // Verify we got all 3 cases + assertEquals(3, cases.size()); + assertEquals("Question 1", cases.get(0).input()); + assertEquals("Question 2", cases.get(1).input()); + assertEquals("Question 3", cases.get(2).input()); + + // Verify the API was called twice (once for each batch) + wireMock.verify(2, postRequestedFor(urlEqualTo("/v1/dataset/" + datasetId + "/fetch"))); + } + + @Test + void testEmptyDataset() { + // Mock empty dataset + wireMock.stubFor( + post(urlEqualTo("/v1/dataset/" + datasetId + "/fetch")) + .willReturn( + aResponse() + .withStatus(200) + .withHeader("Content-Type", "application/json") + .withBody( + """ + { + "events": [], + "cursor": null + } + """))); + + DatasetBrainstoreImpl dataset = + new DatasetBrainstoreImpl<>(apiClient, datasetId, "test-version"); + + List> cases = new ArrayList<>(); + dataset.forEach(cases::add); + + // Verify we got no cases + assertEquals(0, cases.size()); + + // Verify the API was called once + wireMock.verify(1, postRequestedFor(urlEqualTo("/v1/dataset/" + datasetId + "/fetch"))); + } + + @Test + void testFetchWithPinnedVersion() { + String projectName = "test-project"; + String datasetName = "test-dataset"; + String pinnedVersion = "12345"; + + // Mock the query endpoint + wireMock.stubFor( + get(urlPathEqualTo("/v1/dataset")) + .withQueryParam("project_name", equalTo(projectName)) + .withQueryParam("dataset_name", equalTo(datasetName)) + .willReturn( + aResponse() + .withStatus(200) + .withHeader("Content-Type", "application/json") + .withBody( + """ + { + "objects": [ + { + "id": "dataset-789", + "project_id": "proj-456", + "name": "test-dataset", + "description": "Test dataset", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-15T12:30:00Z" + } + ] + } + """))); + + // Mock the fetch endpoint - only succeeds if version is passed correctly + wireMock.stubFor( + post(urlEqualTo("/v1/dataset/dataset-789/fetch")) + .withRequestBody(matchingJsonPath("$.version", equalTo(pinnedVersion))) + .willReturn( + aResponse() + .withStatus(200) + .withHeader("Content-Type", "application/json") + .withBody( + """ + { + "events": [ + { + "id": "event-1", + "input": "test input", + "expected": "test output", + "metadata": {}, + "tags": [] + } + ], + "cursor": null + } + """))); + + Dataset dataset = + Dataset.fetchFromBraintrust(apiClient, projectName, datasetName, pinnedVersion); + + assertEquals("dataset-789", dataset.id()); + assertEquals(Optional.of(pinnedVersion), dataset.version()); + + // Open cursor and fetch data to trigger the API call with version + List> cases = new ArrayList<>(); + dataset.forEach(cases::add); + + // Verify we got the expected case + assertEquals(1, cases.size()); + assertEquals("test input", cases.get(0).input()); + assertEquals("test output", cases.get(0).expected()); + + // Verify the fetch endpoint was called with the version + wireMock.verify( + 1, + postRequestedFor(urlEqualTo("/v1/dataset/dataset-789/fetch")) + .withRequestBody(matchingJsonPath("$.version", equalTo(pinnedVersion)))); + } + + @Test + void testFetchFromBraintrustNotFound() { + String projectName = "test-project"; + String datasetName = "nonexistent"; + + // Mock empty response + wireMock.stubFor( + get(urlPathEqualTo("/v1/dataset")) + .withQueryParam("project_name", equalTo(projectName)) + .withQueryParam("dataset_name", equalTo(datasetName)) + .willReturn( + aResponse() + .withStatus(200) + .withHeader("Content-Type", "application/json") + .withBody( + """ + {"objects": []} + """))); + + RuntimeException exception = + assertThrows( + RuntimeException.class, + () -> + Dataset.fetchFromBraintrust( + apiClient, projectName, datasetName, null)); + + assertTrue(exception.getMessage().contains("Dataset not found")); + } +}