Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,8 @@ var eval = braintrust.<String, String>evalBuilder()
.taskFunction(getFoodType)
.scorers(
Scorer.of(
"fruit_scorer",
result -> "fruit".equals(result) ? 1.0 : 0.0),
Scorer.of(
"vegetable_scorer",
result -> "vegetable".equals(result) ? 1.0 : 0.0))
"exact_match",
(expected, result) -> expected.equals(result) ? 1.0 : 0.0))
.build();
var result = eval.run();
System.out.println("\n\n" + result.createReportString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,10 @@ public static void main(String[] args) throws Exception {
var eval =
braintrust
.<String, String>evalBuilder()
.name("java-eval-x-" + System.currentTimeMillis()) // NOTE: if you use a
// constant, additional runs
// will append new cases to
// the same experiment
// NOTE: reusing an existing experiment name appends results to that experiment
.name("java-eval-x-" + System.currentTimeMillis())
.cases(
new DatasetCase<>(
DatasetCase.of(
"strawberry",
"fruit",
// custom tags which appear in Braintrust UI
Expand All @@ -49,14 +47,13 @@ public static void main(String[] args) throws Exception {
DatasetCase.of("asparagus", "vegetable"),
DatasetCase.of("apple", "fruit"),
DatasetCase.of("banana", "fruit"))
// Or, to fetch a remote dataset:
// .dataset(braintrust.fetchDataset("my-dataset-name"))
.taskFunction(getFoodType)
.scorers(
Scorer.of(
"fruit_scorer",
result -> "fruit".equals(result) ? 1.0 : 0.0),
Scorer.of(
"vegetable_scorer",
result -> "vegetable".equals(result) ? 1.0 : 0.0))
"exact_match",
(expected, result) -> expected.equals(result) ? 1.0 : 0.0))
.build();
var result = eval.run();
System.out.println("\n\n" + result.createReportString());
Expand Down
16 changes: 16 additions & 0 deletions src/main/java/dev/braintrust/Origin.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package dev.braintrust;

import com.fasterxml.jackson.annotation.JsonProperty;

/**
 * Generic pointer to an object in Braintrust.
 *
 * <p>Each component carries a {@code @JsonProperty} annotation naming the JSON field it maps to.
 * All components, including the creation timestamp, are held as plain strings.
 *
 * @param objectType origin type, e.g. dataset, playground_logs (JSON field {@code object_type})
 * @param objectId id of the object, e.g. dataset id (JSON field {@code object_id})
 * @param id id of the specific item within the origin, e.g. dataset row id (JSON field {@code id})
 * @param xactId origin xact id (JSON field {@code _xact_id})
 * @param createdTimestamp creation timestamp of the origin (JSON field {@code created})
 */
public record Origin(
@JsonProperty("object_type") String objectType,
@JsonProperty("object_id") String objectId,
@JsonProperty("id") String id,
@JsonProperty("_xact_id") String xactId,
@JsonProperty("created") String createdTimestamp) {}
15 changes: 14 additions & 1 deletion src/main/java/dev/braintrust/eval/DatasetBrainstoreImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,20 @@ public Optional<DatasetCase<INPUT, OUTPUT>> next() {
}

DatasetCase<INPUT, OUTPUT> datasetCase =
new DatasetCase<>(input, expected, tags, metadata);
new DatasetCase<>(
input,
expected,
tags,
metadata,
Optional.of(
new dev.braintrust.Origin(
"dataset",
Objects.requireNonNull(
(String) event.get("dataset_id")),
Objects.requireNonNull((String) event.get("id")),
Objects.requireNonNull((String) event.get("_xact_id")),
Objects.requireNonNull(
(String) event.get("created")))));

return Optional.of(datasetCase);
}
Expand Down
16 changes: 14 additions & 2 deletions src/main/java/dev/braintrust/eval/DatasetCase.java
Original file line number Diff line number Diff line change
@@ -1,17 +1,29 @@
package dev.braintrust.eval;

import dev.braintrust.Origin;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import javax.annotation.Nonnull;

/** A single row in a dataset. */
public record DatasetCase<INPUT, OUTPUT>(
INPUT input,
OUTPUT expected,
@Nonnull List<String> tags,
@Nonnull Map<String, Object> metadata) {
@Nonnull Map<String, Object> metadata,
/** origin information. empty for in-memory cases */
Optional<Origin> origin) {

public static <INPUT, OUTPUT> DatasetCase<INPUT, OUTPUT> of(INPUT input, OUTPUT expected) {
return new DatasetCase<>(input, expected, List.of(), Map.of());
return of(input, expected, List.of(), Map.of());
}

public static <INPUT, OUTPUT> DatasetCase<INPUT, OUTPUT> of(
INPUT input,
OUTPUT expected,
@Nonnull List<String> tags,
@Nonnull Map<String, Object> metadata) {
return new DatasetCase<>(input, expected, tags, metadata, Optional.empty());
}
}
4 changes: 3 additions & 1 deletion src/main/java/dev/braintrust/eval/Eval.java
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ public EvalResult run() {

@SneakyThrows
private void evalOne(String experimentId, DatasetCase<INPUT, OUTPUT> datasetCase) {
JSON_MAPPER.writeValueAsString(Map.of("type", "eval"));
var rootSpan =
tracer.spanBuilder("eval") // TODO: allow names for eval cases
.setNoParent() // each eval case is its own trace
Expand All @@ -87,6 +86,9 @@ private void evalOne(String experimentId, DatasetCase<INPUT, OUTPUT> datasetCase
"braintrust.input_json", json(Map.of("input", datasetCase.input())))
.setAttribute("braintrust.expected", json(datasetCase.expected()))
.startSpan();
if (datasetCase.origin().isPresent()) {
rootSpan.setAttribute("braintrust.origin", json(datasetCase.origin().get()));
}
if (!datasetCase.tags().isEmpty()) {
rootSpan.setAttribute(
AttributeKey.stringArrayKey("braintrust.tags"), datasetCase.tags());
Expand Down
11 changes: 4 additions & 7 deletions src/main/java/dev/braintrust/eval/Scorer.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public interface Scorer<INPUT, OUTPUT> {
List<Score> score(TaskResult<INPUT, OUTPUT> taskResult);

static <INPUT, OUTPUT> Scorer<INPUT, OUTPUT> of(
String scorerName, Function<OUTPUT, Double> scorerFn) {
String scorerName, Function<TaskResult<INPUT, OUTPUT>, Double> scorerFn) {
return new Scorer<>() {
@Override
public String getName() {
Expand All @@ -26,15 +26,13 @@ public String getName() {

@Override
public List<Score> score(TaskResult<INPUT, OUTPUT> taskResult) {
return List.of(new Score(scorerName, scorerFn.apply(taskResult.result())));
return List.of(new Score(scorerName, scorerFn.apply(taskResult)));
}
};
}

/** Deprecated. Use {@link #of(String, Function)} or implement the Scorer interface instead. */
@Deprecated
static <INPUT, OUTPUT> Scorer<INPUT, OUTPUT> of(
String scorerName, BiFunction<EvalCase<INPUT, OUTPUT>, OUTPUT, Double> scorerFn) {
String scorerName, BiFunction<OUTPUT, OUTPUT, Double> scorerFn) {
return new Scorer<>() {
@Override
public String getName() {
Expand All @@ -47,8 +45,7 @@ public List<Score> score(TaskResult<INPUT, OUTPUT> taskResult) {
new Score(
scorerName,
scorerFn.apply(
EvalCase.from(taskResult.datasetCase()),
taskResult.result())));
taskResult.datasetCase().expected(), taskResult.result())));
}
};
}
Expand Down
71 changes: 60 additions & 11 deletions src/test/java/dev/braintrust/eval/DatasetBrainstoreImplTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
import com.github.tomakehurst.wiremock.junit5.WireMockExtension;
import dev.braintrust.api.BraintrustApiClient;
import dev.braintrust.config.BraintrustConfig;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.*;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
Expand Down Expand Up @@ -52,12 +50,20 @@ void testFetchAll() {
{
"events": [
{
"id": "event-1",
"object_type": "dataset",
"dataset_id": "test-dataset-123",
"id": "123-1",
"created": "sometimestamp",
"_xact_id": "1",
"input": "Question 1",
"expected": "Answer 1"
},
{
"id": "event-2",
"object_type": "dataset",
"dataset_id": "test-dataset-123",
"id": "123-2",
"_xact_id": "1",
"created": "sometimestamp",
"input": "Question 2",
"expected": "Answer 2"
}
Expand All @@ -79,7 +85,11 @@ void testFetchAll() {
{
"events": [
{
"id": "event-3",
"object_type": "dataset",
"dataset_id": "test-dataset-123",
"id": "123-3",
"_xact_id": "1",
"created": "sometimestamp",
"input": "Question 3",
"expected": "Answer 3"
}
Expand All @@ -97,9 +107,40 @@ void testFetchAll() {

// Verify we got all 3 cases
assertEquals(3, cases.size());
assertEquals("Question 1", cases.get(0).input());
List<String> tags = List.of();
Map<String, Object> metadata = Map.of();
assertEquals(
new DatasetCase<>(
"Question 1",
"Answer 1",
tags,
metadata,
Optional.of(
new dev.braintrust.Origin(
"dataset", datasetId, "123-1", "1", "sometimestamp"))),
cases.get(0));
assertEquals("Question 2", cases.get(1).input());
assertEquals(
new DatasetCase<>(
"Question 2",
"Answer 2",
tags,
metadata,
Optional.of(
new dev.braintrust.Origin(
"dataset", datasetId, "123-2", "1", "sometimestamp"))),
cases.get(1));
assertEquals("Question 3", cases.get(2).input());
assertEquals(
new DatasetCase<>(
"Question 3",
"Answer 3",
tags,
metadata,
Optional.of(
new dev.braintrust.Origin(
"dataset", datasetId, "123-3", "1", "sometimestamp"))),
cases.get(2));

// Verify the API was called twice (once for each batch)
wireMock.verify(2, postRequestedFor(urlEqualTo("/v1/dataset/" + datasetId + "/fetch")));
Expand Down Expand Up @@ -155,12 +196,16 @@ void testFetchWithPinnedVersion() {
{
"objects": [
{
"object_type": "dataset",
"dataset_id": "test-dataset-123",
"id": "dataset-789",
"project_id": "proj-456",
"name": "test-dataset",
"description": "Test dataset",
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-15T12:30:00Z"
"_xact_id": "12345",
"input": "test input",
"expected": "test output",
"created": "sometimestamp"
}
]
}
Expand All @@ -179,11 +224,15 @@ void testFetchWithPinnedVersion() {
{
"events": [
{
"id": "event-1",
"object_type": "dataset",
"dataset_id": "test-dataset-123",
"id": "some-row-id",
"input": "test input",
"expected": "test output",
"metadata": {},
"tags": []
"tags": [],
"_xact_id": "12346",
"created": "sometimestamp"
}
],
"cursor": null
Expand Down
Loading