diff --git a/examples/src/main/java/dev/braintrust/examples/CustomOpenTelemetryExample.java b/examples/src/main/java/dev/braintrust/examples/CustomOpenTelemetryExample.java index cea869b..0d259db 100644 --- a/examples/src/main/java/dev/braintrust/examples/CustomOpenTelemetryExample.java +++ b/examples/src/main/java/dev/braintrust/examples/CustomOpenTelemetryExample.java @@ -2,6 +2,10 @@ import dev.braintrust.Braintrust; import io.opentelemetry.api.GlobalOpenTelemetry; +import io.opentelemetry.api.baggage.propagation.W3CBaggagePropagator; +import io.opentelemetry.api.trace.propagation.W3CTraceContextPropagator; +import io.opentelemetry.context.propagation.ContextPropagators; +import io.opentelemetry.context.propagation.TextMapPropagator; import io.opentelemetry.exporter.otlp.http.logs.OtlpHttpLogRecordExporter; import io.opentelemetry.exporter.otlp.http.metrics.OtlpHttpMetricExporter; import io.opentelemetry.exporter.otlp.http.trace.OtlpHttpSpanExporter; @@ -53,11 +57,18 @@ public static void main(String[] args) throws Exception { var braintrust = Braintrust.get(); braintrust.openTelemetryEnable(tracerBuilder, loggerBuilder, meterBuilder); + // context propagation is required only if you wish to see distributed traces in Braintrust + var contextPropagator = + ContextPropagators.create( + TextMapPropagator.composite( + W3CTraceContextPropagator.getInstance(), + W3CBaggagePropagator.getInstance())); var openTelemetry = OpenTelemetrySdk.builder() .setTracerProvider(tracerBuilder.build()) .setLoggerProvider(loggerBuilder.build()) .setMeterProvider(meterBuilder.build()) + .setPropagators(contextPropagator) .build(); GlobalOpenTelemetry.set(openTelemetry); registerShutdownHook(openTelemetry); diff --git a/src/main/java/dev/braintrust/BraintrustUtils.java b/src/main/java/dev/braintrust/BraintrustUtils.java index 1924952..b85ff42 100644 --- a/src/main/java/dev/braintrust/BraintrustUtils.java +++ b/src/main/java/dev/braintrust/BraintrustUtils.java @@ -3,6 +3,7 @@ import dev.braintrust.api.BraintrustApiClient; import java.net.URI; import java.net.URISyntaxException; +import javax.annotation.Nonnull; public class BraintrustUtils { /** construct a URI to link to a specific braintrust project within an org */ @@ -26,4 +27,19 @@ public static URI createProjectURI( throw new RuntimeException(e); } } + + static Parent parseParent(@Nonnull String parentStr) { + String[] parts = parentStr.split(":"); + if (parts.length != 2) { + throw new IllegalArgumentException("Invalid parent format: " + parentStr); + } + return new Parent(parts[0], parts[1]); + } + + /** Represents a parsed parent with type and ID. */ + public record Parent(String type, String id) { + public String toParentValue() { + return type + ":" + id; + } + } } diff --git a/src/main/java/dev/braintrust/trace/BraintrustContext.java b/src/main/java/dev/braintrust/trace/BraintrustContext.java index a3967e2..5911796 100644 --- a/src/main/java/dev/braintrust/trace/BraintrustContext.java +++ b/src/main/java/dev/braintrust/trace/BraintrustContext.java @@ -1,5 +1,8 @@ package dev.braintrust.trace; +import dev.braintrust.BraintrustUtils; +import io.opentelemetry.api.baggage.Baggage; +import io.opentelemetry.api.baggage.BaggageBuilder; import io.opentelemetry.api.trace.Span; import io.opentelemetry.context.Context; import io.opentelemetry.context.ContextKey; @@ -7,11 +10,13 @@ import java.util.Optional; import javax.annotation.Nonnull; import javax.annotation.Nullable; +import lombok.extern.slf4j.Slf4j; /** * Used to identify the braintrust parent for spans and experiments. SDK users probably don't want * to use this and instead should use {@link BraintrustTracing} or {@link dev.braintrust.eval.Eval} */ +@Slf4j public final class BraintrustContext { private static final ContextKey KEY = ContextKey.named("braintrust-context"); @@ -28,7 +33,55 @@ private BraintrustContext(@Nullable String projectId, @Nullable String experimen public static Context ofExperiment(@Nonnull String experimentId, @Nonnull Span span) { Objects.requireNonNull(experimentId); Objects.requireNonNull(span); - return Context.current().with(span).with(KEY, new BraintrustContext(null, experimentId)); + Context ctx = + Context.current().with(span).with(KEY, new BraintrustContext(null, experimentId)); + return setParentInBaggage(ctx, "experiment_id", experimentId); + } + + /** + * Sets the parent in baggage for distributed tracing. + * + *

Baggage propagates automatically via W3C headers when propagators are configured, allowing + * parent context to flow across process boundaries. + * + * @param ctx the context to update + * @param parentType the type of parent (e.g., "experiment_id", "project_name") + * @param parentId the ID of the parent + * @return updated context with baggage set + */ + static Context setParentInBaggage( + @Nonnull Context ctx, @Nonnull String parentType, @Nonnull String parentId) { + try { + String parentValue = (new BraintrustUtils.Parent(parentType, parentId)).toParentValue(); + Baggage existingBaggage = Baggage.fromContext(ctx); + BaggageBuilder builder = existingBaggage.toBuilder(); + builder.put(BraintrustTracing.PARENT_KEY, parentValue); + return ctx.with(builder.build()); + } catch (Exception e) { + log.warn("Failed to set parent in baggage: {}", e.getMessage(), e); + return ctx; + } + } + + /** + * Retrieves the parent value from baggage for distributed tracing. + * + *

This method checks the OpenTelemetry Baggage for the braintrust.parent attribute. This is + * used as a fallback when parent information is not available in the Context (e.g., when + * crossing process boundaries). + * + * @param ctx the context to check + * @return the parent value if present in baggage (format: "type:id") + */ + static Optional getParentFromBaggage(@Nonnull Context ctx) { + try { + Baggage baggage = Baggage.fromContext(ctx); + String parentValue = baggage.getEntryValue(BraintrustTracing.PARENT_KEY); + return Optional.ofNullable(parentValue).filter(s -> !s.isEmpty()); + } catch (Exception e) { + log.warn("Failed to get parent from baggage: {}", e.getMessage(), e); + return Optional.empty(); + } } /** Retrieves a BraintrustContext from the given Context. */ diff --git a/src/main/java/dev/braintrust/trace/BraintrustSpanProcessor.java b/src/main/java/dev/braintrust/trace/BraintrustSpanProcessor.java index dbc94da..9eda204 100644 --- a/src/main/java/dev/braintrust/trace/BraintrustSpanProcessor.java +++ b/src/main/java/dev/braintrust/trace/BraintrustSpanProcessor.java @@ -42,16 +42,26 @@ public void onStart(@NotNull Context parentContext, ReadWriteSpan span) { // Check if parent context has Braintrust attributes first var btContext = BraintrustContext.fromContext(parentContext); if (btContext == null) { - // Get parent from the config if otel doesn't have it - config.getBraintrustParentValue() - .ifPresent( - parentValue -> { - span.setAttribute(PARENT, parentValue); - log.debug( - "OnStart: set parent {} for span {}", - parentValue, - span.getName()); - }); + // Check baggage for distributed tracing (cross-process parent propagation) + var parentFromBaggage = BraintrustContext.getParentFromBaggage(parentContext); + if (parentFromBaggage.isPresent()) { + span.setAttribute(PARENT, parentFromBaggage.get()); + log.debug( + "OnStart: set parent {} from baggage for span {}", + parentFromBaggage.get(), + span.getName()); + } else { + // Get parent from the config if otel doesn't have it + config.getBraintrustParentValue() + .ifPresent( + parentValue -> { + span.setAttribute(PARENT, parentValue); + log.debug( + "OnStart: set parent {} for span {}", + parentValue, + span.getName()); + }); + } } else { btContext .projectId() diff --git a/src/main/java/dev/braintrust/trace/BraintrustTracing.java b/src/main/java/dev/braintrust/trace/BraintrustTracing.java index 0a8259d..f6d7eab 100644 --- a/src/main/java/dev/braintrust/trace/BraintrustTracing.java +++ b/src/main/java/dev/braintrust/trace/BraintrustTracing.java @@ -3,7 +3,11 @@ import dev.braintrust.config.BraintrustConfig; import io.opentelemetry.api.GlobalOpenTelemetry; import io.opentelemetry.api.OpenTelemetry; +import io.opentelemetry.api.baggage.propagation.W3CBaggagePropagator; import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.api.trace.propagation.W3CTraceContextPropagator; +import io.opentelemetry.context.propagation.ContextPropagators; +import io.opentelemetry.context.propagation.TextMapPropagator; import io.opentelemetry.sdk.OpenTelemetrySdk; import io.opentelemetry.sdk.common.CompletableResultCode; import io.opentelemetry.sdk.logs.SdkLoggerProvider; @@ -54,12 +58,18 @@ public static OpenTelemetry of(@Nonnull BraintrustConfig config, boolean registe var tracerBuilder = SdkTracerProvider.builder(); var loggerBuilder = SdkLoggerProvider.builder(); var meterBuilder = SdkMeterProvider.builder(); + var contextPropagator = + ContextPropagators.create( + TextMapPropagator.composite( + W3CTraceContextPropagator.getInstance(), + W3CBaggagePropagator.getInstance())); enable(config, tracerBuilder, loggerBuilder, meterBuilder); var openTelemetry = OpenTelemetrySdk.builder() .setTracerProvider(tracerBuilder.build()) .setLoggerProvider(loggerBuilder.build()) .setMeterProvider(meterBuilder.build()) + .setPropagators(contextPropagator) .build(); if (registerGlobal) { GlobalOpenTelemetry.set(openTelemetry); diff --git a/src/test/java/dev/braintrust/BraintrustUtilsTest.java b/src/test/java/dev/braintrust/BraintrustUtilsTest.java index fcf5468..d95c77f 100644 --- a/src/test/java/dev/braintrust/BraintrustUtilsTest.java +++ b/src/test/java/dev/braintrust/BraintrustUtilsTest.java @@ -1,6 +1,7 @@ package dev.braintrust; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import dev.braintrust.api.BraintrustApiClient; import java.net.URI; @@ -18,4 +19,28 @@ public void testBuildProjectUri() { URI.create("http://someserver:3009/app/some%20org/p/some%20project"), BraintrustUtils.createProjectURI("http://someserver:3009/", orgAndProject)); } + + @Test + void testParseParent() { + var parsed1 = BraintrustUtils.parseParent("experiment_id:abc123"); + assertEquals("experiment_id", parsed1.type()); + assertEquals("abc123", parsed1.id()); + + var parsed2 = BraintrustUtils.parseParent("project_name:my-project"); + assertEquals("project_name", parsed2.type()); + assertEquals("my-project", parsed2.id()); + + assertThrows( + Exception.class, + () -> BraintrustUtils.parseParent("invalid-no-colon"), + "Should throw on invalid format"); + assertThrows( + Exception.class, + () -> BraintrustUtils.parseParent("invalid:too:many:colons"), + "Should throw on invalid format"); + assertThrows( + Exception.class, + () -> BraintrustUtils.parseParent(""), + "Should throw on empty string"); + } } diff --git a/src/test/java/dev/braintrust/TestHarness.java b/src/test/java/dev/braintrust/TestHarness.java index 033563d..3597e12 100644 --- a/src/test/java/dev/braintrust/TestHarness.java +++ b/src/test/java/dev/braintrust/TestHarness.java @@ -6,6 +6,10 @@ import dev.braintrust.config.BraintrustConfig; import dev.braintrust.prompt.BraintrustPromptLoader; import io.opentelemetry.api.GlobalOpenTelemetry; +import io.opentelemetry.api.baggage.propagation.W3CBaggagePropagator; +import io.opentelemetry.api.trace.propagation.W3CTraceContextPropagator; +import io.opentelemetry.context.propagation.ContextPropagators; +import io.opentelemetry.context.propagation.TextMapPropagator; import io.opentelemetry.sdk.OpenTelemetrySdk; import io.opentelemetry.sdk.logs.SdkLoggerProvider; import io.opentelemetry.sdk.metrics.SdkMeterProvider; @@ -81,11 +85,17 @@ private TestHarness(@Nonnull Braintrust braintrust) { braintrust.openTelemetryEnable(tracerBuilder, loggerBuilder, meterBuilder); // Add the in-memory span exporter for testing tracerBuilder.addSpanProcessor(SimpleSpanProcessor.create(this.spanExporter)); + var contextPropagator = + ContextPropagators.create( + TextMapPropagator.composite( + W3CTraceContextPropagator.getInstance(), + W3CBaggagePropagator.getInstance())); var openTelemetry = OpenTelemetrySdk.builder() .setTracerProvider(tracerBuilder.build()) .setLoggerProvider(loggerBuilder.build()) .setMeterProvider(meterBuilder.build()) + .setPropagators(contextPropagator) .build(); this.openTelemetry = openTelemetry; } diff --git a/src/test/java/dev/braintrust/trace/DistributedTracingTest.java b/src/test/java/dev/braintrust/trace/DistributedTracingTest.java new file mode 100644 index 0000000..63fa460 --- /dev/null +++ b/src/test/java/dev/braintrust/trace/DistributedTracingTest.java @@ -0,0 +1,197 @@ +package dev.braintrust.trace; + +import static org.junit.jupiter.api.Assertions.*; + +import dev.braintrust.TestHarness; +import io.opentelemetry.api.common.AttributeKey; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.propagation.TextMapGetter; +import io.opentelemetry.context.propagation.TextMapPropagator; +import io.opentelemetry.sdk.trace.data.SpanData; +import java.util.HashMap; +import java.util.Map; +import javax.annotation.Nullable; +import org.junit.jupiter.api.Test; + +public class DistributedTracingTest { + + private static final AttributeKey PARENT_ATTR_KEY = + AttributeKey.stringKey(BraintrustTracing.PARENT_KEY); + + @Test + void testDistributedTracingPropagation() throws Exception { + TestHarness harness = TestHarness.setup(); + TextMapPropagator propagator = + harness.openTelemetry().getPropagators().getTextMapPropagator(); + + Tracer clientTracer = harness.openTelemetry().getTracer("test-client"); + Tracer serverTracer = harness.openTelemetry().getTracer("test-server"); + + com.sun.net.httpserver.HttpServer httpServer = + com.sun.net.httpserver.HttpServer.create( + new java.net.InetSocketAddress("localhost", 0), 0); + int port = httpServer.getAddress().getPort(); + + httpServer.createContext( + "/test", + exchange -> { + Map headers = new HashMap<>(); + exchange.getRequestHeaders() + .forEach( + (key, values) -> { + if (!values.isEmpty()) { + headers.put(key, values.get(0)); + } + }); + + Context serverContext = + propagator.extract(Context.root(), headers, MapGetter.INSTANCE); + Span serverSpan = + serverTracer + .spanBuilder("server-operation") + .setParent(serverContext) + .startSpan(); + + try (var scope = serverContext.with(serverSpan).makeCurrent()) { + String response = "OK"; + exchange.sendResponseHeaders(200, response.length()); + exchange.getResponseBody().write(response.getBytes()); + exchange.getResponseBody().close(); + } finally { + serverSpan.end(); + } + }); + + httpServer.start(); + + try { + String experimentId = "abc123-http-test"; + Context experimentContext = + BraintrustContext.setParentInBaggage( + Context.root(), "experiment_id", experimentId); + + Span clientSpan = + clientTracer + .spanBuilder("client-operation") + .setParent(experimentContext) + .startSpan(); + Context clientContext = experimentContext.with(clientSpan); + + try (var scope = clientContext.makeCurrent()) { + java.net.http.HttpClient httpClient = java.net.http.HttpClient.newHttpClient(); + java.net.http.HttpRequest.Builder requestBuilder = + java.net.http.HttpRequest.newBuilder() + .uri(java.net.URI.create("http://localhost:" + port + "/test")) + .GET(); + propagator.inject( + clientContext, + requestBuilder, + (builder, key, value) -> builder.header(key, value)); + + java.net.http.HttpRequest request = requestBuilder.build(); + java.net.http.HttpResponse response = + httpClient.send( + request, java.net.http.HttpResponse.BodyHandlers.ofString()); + + assertEquals(200, response.statusCode(), "HTTP request should succeed"); + } finally { + clientSpan.end(); + } + + var allSpans = harness.awaitExportedSpans(); + assertEquals(2, allSpans.size(), "Expected two spans (client + server)"); + + SpanData clientSpanData = + allSpans.stream() + .filter(s -> s.getName().equals("client-operation")) + .findFirst() + .orElseThrow(); + SpanData serverSpanData = + allSpans.stream() + .filter(s -> s.getName().equals("server-operation")) + .findFirst() + .orElseThrow(); + + String clientParentAttr = clientSpanData.getAttributes().get(PARENT_ATTR_KEY); + assertNotNull(clientParentAttr, "Client span should have braintrust.parent attribute"); + assertEquals( + "experiment_id:" + experimentId, + clientParentAttr, + "Client parent attribute should match experiment"); + + String serverParentAttr = serverSpanData.getAttributes().get(PARENT_ATTR_KEY); + assertNotNull( + serverParentAttr, + "Server span should have braintrust.parent attribute propagated via HTTP"); + assertEquals( + "experiment_id:" + experimentId, + serverParentAttr, + "Server parent attribute should match client experiment"); + + assertEquals( + clientSpanData.getTraceId(), + serverSpanData.getTraceId(), + "Trace IDs should match across HTTP boundary"); + + assertEquals( + clientSpanData.getSpanId(), + serverSpanData.getParentSpanId(), + "Server span should be a child of client span"); + + } finally { + httpServer.stop(0); + } + } + + /** + * Tests that parent can be retrieved from baggage when context doesn't have it. + * + *

This verifies the fallback mechanism in BraintrustSpanProcessor. + */ + @Test + void testGetParentFromBaggage() { + String experimentId = "test-experiment-123"; + String parentValue = "experiment_id:" + experimentId; + + // Create a context with parent in baggage + Context ctx = + BraintrustContext.setParentInBaggage(Context.root(), "experiment_id", experimentId); + + // Verify we can retrieve it + var retrieved = BraintrustContext.getParentFromBaggage(ctx); + assertTrue(retrieved.isPresent(), "Should retrieve parent from baggage"); + assertEquals(parentValue, retrieved.get(), "Parent value should match"); + } + + /** TextMapGetter for extracting headers from a Map (case-insensitive for HTTP headers). */ + private enum MapGetter implements TextMapGetter> { + INSTANCE; + + @Override + public Iterable keys(Map carrier) { + return carrier.keySet(); + } + + @Override + @Nullable + public String get(@Nullable Map carrier, String key) { + if (carrier == null) { + return null; + } + // Try exact match first + String value = carrier.get(key); + if (value != null) { + return value; + } + // Fall back to case-insensitive search for HTTP headers + for (Map.Entry entry : carrier.entrySet()) { + if (entry.getKey().equalsIgnoreCase(key)) { + return entry.getValue(); + } + } + return null; + } + } +}