diff --git a/build.gradle b/build.gradle index 5c2876b..f722964 100644 --- a/build.gradle +++ b/build.gradle @@ -82,6 +82,10 @@ dependencies { // Anthropic Instrumentation compileOnly "com.anthropic:anthropic-java:2.8.1" testImplementation "com.anthropic:anthropic-java:2.8.1" + + // Google GenAI Instrumentation + compileOnly "com.google.genai:google-genai:1.20.0" + testImplementation "com.google.genai:google-genai:1.20.0" } /** diff --git a/examples/build.gradle b/examples/build.gradle index 932fb30..2244681 100644 --- a/examples/build.gradle +++ b/examples/build.gradle @@ -27,6 +27,14 @@ dependencies { implementation 'com.openai:openai-java:2.8.1' // to run anthropic examples implementation "com.anthropic:anthropic-java:2.8.1" + // to run gemini examples + implementation 'com.google.genai:google-genai:1.20.0' + // spring ai examples + implementation 'org.springframework.ai:spring-ai-google-genai:1.1.0' + // spring boot for SpringAIExample (exclude logback, use slf4j-simple like other examples) + implementation('org.springframework.boot:spring-boot-starter:3.4.1') { + exclude group: 'org.springframework.boot', module: 'spring-boot-starter-logging' + } } application { @@ -105,3 +113,31 @@ task runPromptFetching(type: JavaExec) { suspend = false } } + +task runGeminiInstrumentation(type: JavaExec) { + group = 'Braintrust SDK Examples' + description = 'Run the Gemini instrumentation example. NOTE: this requires GOOGLE_API_KEY or GEMINI_API_KEY to be exported and will make a small call to google, using your tokens' + classpath = sourceSets.main.runtimeClasspath + mainClass = 'dev.braintrust.examples.GeminiInstrumentationExample' + systemProperty 'org.slf4j.simpleLogger.log.dev.braintrust', braintrustLogLevel + debugOptions { + enabled = true + port = 5566 + server = true + suspend = false + } +} + +task runSpringAI(type: JavaExec) { + group = 'Braintrust SDK Examples' + description = 'Run the Spring Boot + Spring AI + Gemini example.' 
+ classpath = sourceSets.main.runtimeClasspath + mainClass = 'dev.braintrust.examples.SpringAIExample' + systemProperty 'org.slf4j.simpleLogger.log.dev.braintrust', braintrustLogLevel + debugOptions { + enabled = true + port = 5566 + server = true + suspend = false + } +} diff --git a/examples/src/main/java/dev/braintrust/examples/GeminiInstrumentationExample.java b/examples/src/main/java/dev/braintrust/examples/GeminiInstrumentationExample.java new file mode 100644 index 0000000..da138df --- /dev/null +++ b/examples/src/main/java/dev/braintrust/examples/GeminiInstrumentationExample.java @@ -0,0 +1,76 @@ +package dev.braintrust.examples; + +import com.google.genai.Client; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.GenerateContentResponse; +import dev.braintrust.Braintrust; +import dev.braintrust.instrumentation.genai.BraintrustGenAI; +import io.opentelemetry.api.OpenTelemetry; + +/** Basic OTel + Gemini instrumentation example */ +public class GeminiInstrumentationExample { + public static void main(String[] args) throws Exception { + if (null == System.getenv("GOOGLE_API_KEY") && null == System.getenv("GEMINI_API_KEY")) { + System.err.println( + "\n" + + "WARNING: Neither GOOGLE_API_KEY nor GEMINI_API_KEY found. This" + + " example will likely fail.\n" + + "Set either: export GOOGLE_API_KEY='your-key' (recommended) or export" + + " GEMINI_API_KEY='your-key'\n"); + } + + Braintrust braintrust = Braintrust.get(); + OpenTelemetry openTelemetry = braintrust.openTelemetryCreate(); + // CLAUDE: don't change the type of geminiClient -- sdk users must use the google genai + // client in their signature, not our instrumented client. 
+ Client geminiClient = BraintrustGenAI.wrap(openTelemetry, new Client.Builder()); + + var tracer = openTelemetry.getTracer("my-instrumentation"); + var rootSpan = tracer.spanBuilder("gemini-java-instrumentation-example").startSpan(); + try (var ignored = rootSpan.makeCurrent()) { + generateContentExample(geminiClient); + // generateContentStreamingExample(client); + } finally { + rootSpan.end(); + } + + var url = + braintrust.projectUri() + + "/logs?r=%s&s=%s" + .formatted( + rootSpan.getSpanContext().getTraceId(), + rootSpan.getSpanContext().getSpanId()); + + System.out.println( + "\n\n Example complete! View your data in Braintrust: %s\n".formatted(url)); + } + + private static void generateContentExample(Client client) { + var config = GenerateContentConfig.builder().temperature(0.0f).maxOutputTokens(50).build(); + + var response = + client.models.generateContent( + "gemini-2.0-flash-lite", "What is the third planet from the sun?", config); + + System.out.println("\n~~~ GENERATE CONTENT RESPONSE: %s\n".formatted(response.text())); + } + + private static void generateContentStreamingExample(Client client) { + var config = GenerateContentConfig.builder().temperature(0.0f).maxOutputTokens(50).build(); + + System.out.println("\n~~~ STREAMING RESPONSE:"); + var stream = + client.models.generateContentStream( + "gemini-2.0-flash-exp", + "Who was the first president of the United States?", + config); + + for (GenerateContentResponse chunk : stream) { + String text = chunk.text(); + if (text != null && !text.isEmpty()) { + System.out.print(text); + } + } + System.out.println("\n"); + } +} diff --git a/examples/src/main/java/dev/braintrust/examples/SpringAIExample.java b/examples/src/main/java/dev/braintrust/examples/SpringAIExample.java new file mode 100644 index 0000000..dfdfc3a --- /dev/null +++ b/examples/src/main/java/dev/braintrust/examples/SpringAIExample.java @@ -0,0 +1,115 @@ +package dev.braintrust.examples; + +import com.google.genai.Client; +import 
dev.braintrust.Braintrust; +import dev.braintrust.config.BraintrustConfig; +import dev.braintrust.instrumentation.genai.BraintrustGenAI; +import io.opentelemetry.api.OpenTelemetry; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Scope; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.google.genai.GoogleGenAiChatModel; +import org.springframework.ai.google.genai.GoogleGenAiChatOptions; +import org.springframework.boot.CommandLineRunner; +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.boot.autoconfigure.http.client.HttpClientAutoConfiguration; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.context.annotation.Bean; + +/** Spring Boot application demonstrating Braintrust + Spring AI integration */ +@SpringBootApplication( + // NOTE: these excludes are specific to the Braintrust examples project to play nice with + // other examples' classpaths. 
Excludes are not required for production spring apps + exclude = {HttpClientAutoConfiguration.class, RestClientAutoConfiguration.class}) +public class SpringAIExample { + + public static void main(String[] args) { + SpringApplication.run(SpringAIExample.class, args); + } + + @Bean + public CommandLineRunner run(ChatModel chatModel, Tracer tracer, Braintrust braintrust) { + return args -> { + Span rootSpan = tracer.spanBuilder("spring-ai-example").startSpan(); + try (Scope scope = rootSpan.makeCurrent()) { + System.out.println("\n=== Running Spring Boot Example ===\n"); + + // Make a simple chat call + var prompt = new Prompt("what's the name of the most popular java DI framework?"); + var response = chatModel.call(prompt); + + System.out.println( + "~~~ SPRING AI CHAT RESPONSE: %s\n" + .formatted(response.getResult().getOutput().getText())); + } finally { + rootSpan.end(); + } + + var url = + braintrust.projectUri() + + "/logs?r=%s&s=%s" + .formatted( + rootSpan.getSpanContext().getTraceId(), + rootSpan.getSpanContext().getSpanId()); + + System.out.println( + "\n Example complete! 
View your data in Braintrust: %s\n".formatted(url)); + }; + } + + @Bean + public Braintrust braintrust() { + return Braintrust.get(BraintrustConfig.fromEnvironment()); + } + + @Bean + public OpenTelemetry openTelemetry(Braintrust braintrust) { + return braintrust.openTelemetryCreate(); + } + + @Bean + public Tracer tracer(OpenTelemetry openTelemetry) { + return openTelemetry.getTracer("spring-ai-instrumentation"); + } + + @Bean + public String aiProvider() { + // return "openai"; + // return "anthropic"; + return "google"; + } + + @Bean + public ChatModel chatModel(String aiProvider, OpenTelemetry openTelemetry) { + return switch (aiProvider) { + case "openai", "anthropic" -> { + throw new RuntimeException("TODO: " + aiProvider); + } + case "google" -> { + if (null == System.getenv("GOOGLE_API_KEY") + && null == System.getenv("GEMINI_API_KEY")) { + System.err.println( + "\n" + + "WARNING: Neither GOOGLE_API_KEY nor GEMINI_API_KEY found. This" + + " example will likely fail.\n" + + "Set either: export GOOGLE_API_KEY='your-key' (recommended) or" + + " export GEMINI_API_KEY='your-key'\n"); + } + Client genAIClient = BraintrustGenAI.wrap(openTelemetry, new Client.Builder()); + yield GoogleGenAiChatModel.builder() + .genAiClient(genAIClient) + .defaultOptions( + GoogleGenAiChatOptions.builder() + .model("gemini-2.0-flash-lite") + .temperature(0.0) + .maxOutputTokens(50) + .build()) + .build(); + } + default -> throw new RuntimeException("unsupported provider: " + aiProvider); + }; + } +} diff --git a/src/main/java/com/google/genai/BraintrustApiClient.java b/src/main/java/com/google/genai/BraintrustApiClient.java new file mode 100644 index 0000000..5a0bc87 --- /dev/null +++ b/src/main/java/com/google/genai/BraintrustApiClient.java @@ -0,0 +1,416 @@ +package com.google.genai; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.genai.types.HttpOptions; +import io.opentelemetry.api.OpenTelemetry; +import io.opentelemetry.api.trace.Span; +import 
io.opentelemetry.api.trace.SpanKind; +import io.opentelemetry.api.trace.StatusCode; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.Scope; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import javax.annotation.Nullable; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import okhttp3.Headers; +import okhttp3.MediaType; +import okhttp3.ResponseBody; + +/** + * Instrumented wrapper for ApiClient that adds OpenTelemetry spans. + * + *

This class lives in com.google.genai package to access package-private ApiClient class. + */ +@Slf4j +class BraintrustApiClient extends ApiClient { + private static final ObjectMapper JSON_MAPPER = new ObjectMapper(); + + private final ApiClient delegate; + private final Tracer tracer; + + public BraintrustApiClient(ApiClient delegate, OpenTelemetry openTelemetry) { + // We must call super(), but we'll override all methods to delegate + // Pass the delegate's config to minimize differences + super( + delegate.apiKey != null ? delegate.apiKey : Optional.empty(), + delegate.project != null ? delegate.project : Optional.empty(), + delegate.location != null ? delegate.location : Optional.empty(), + delegate.credentials != null ? delegate.credentials : Optional.empty(), + delegate.httpOptions != null ? Optional.of(delegate.httpOptions) : Optional.empty(), + delegate.clientOptions != null ? delegate.clientOptions : Optional.empty()); + this.delegate = delegate; + this.tracer = openTelemetry.getTracer("io.opentelemetry.gemini-java-1.20"); + } + + private void tagSpan( + Span span, + @Nullable String genAIEndpoint, + @Nullable String requestMethod, + @Nullable String requestBody, + @Nullable String responseBody) { + try { + Map metadata = new java.util.HashMap<>(); + metadata.put("provider", "gemini"); + + // Parse request + if (requestBody != null) { + var requestJson = JSON_MAPPER.readValue(requestBody, Map.class); + + // Extract metadata fields + for (String field : + List.of( + "model", + "systemInstruction", + "tools", + "toolConfig", + "safetySettings", + "cachedContent")) { + if (requestJson.containsKey(field)) { + metadata.put(field, requestJson.get(field)); + } + } + + // Extract generationConfig fields into metadata + if (requestJson.get("generationConfig") instanceof Map) { + var genConfig = (Map) requestJson.get("generationConfig"); + for (String field : + List.of( + "temperature", + "topP", + "topK", + "candidateCount", + "maxOutputTokens", + "stopSequences", 
+ "responseMimeType", + "responseSchema")) { + if (genConfig.containsKey(field)) { + metadata.put(field, genConfig.get(field)); + } + } + } + + // Build input_json + Map inputJson = new java.util.HashMap<>(); + String model = getModel(genAIEndpoint); + if (requestJson.containsKey("model")) { + inputJson.put("model", requestJson.get("model")); + } else if (model != null) { + inputJson.put("model", model); + } + if (requestJson.containsKey("contents")) { + inputJson.put("contents", requestJson.get("contents")); + } + if (requestJson.containsKey("generationConfig")) { + inputJson.put("config", requestJson.get("generationConfig")); + } + + span.setAttribute( + "braintrust.input_json", JSON_MAPPER.writeValueAsString(inputJson)); + } + + // Parse response + if (responseBody != null) { + var responseJson = JSON_MAPPER.readValue(responseBody, Map.class); + + // Extract model version from response + if (responseJson.containsKey("modelVersion")) { + metadata.put("model", responseJson.get("modelVersion")); + } + + // Set full response as output_json + span.setAttribute( + "braintrust.output_json", JSON_MAPPER.writeValueAsString(responseJson)); + + // Parse usage metadata for metrics + if (responseJson.get("usageMetadata") instanceof Map) { + var usage = (Map) responseJson.get("usageMetadata"); + Map metrics = new java.util.HashMap<>(); + + if (usage.containsKey("promptTokenCount")) { + metrics.put("prompt_tokens", (Number) usage.get("promptTokenCount")); + } + if (usage.containsKey("candidatesTokenCount")) { + metrics.put( + "completion_tokens", (Number) usage.get("candidatesTokenCount")); + } + if (usage.containsKey("totalTokenCount")) { + metrics.put("tokens", (Number) usage.get("totalTokenCount")); + } + if (usage.containsKey("cachedContentTokenCount")) { + metrics.put( + "prompt_cached_tokens", + (Number) usage.get("cachedContentTokenCount")); + } + + span.setAttribute( + "braintrust.metrics", JSON_MAPPER.writeValueAsString(metrics)); + } + } + + // Set metadata + 
span.setAttribute("braintrust.metadata", JSON_MAPPER.writeValueAsString(metadata));
+
+            // Set span_attributes to mark as LLM span
+            span.setAttribute(
+                    "braintrust.span_attributes",
+                    JSON_MAPPER.writeValueAsString(Map.of("type", "llm")));
+
+        } catch (Throwable t) {
+            // Tagging is best-effort: a failure here is logged and never reaches the caller.
+            log.warn("failed to tag gemini span", t);
+        }
+    }
+
+    // Override accessor methods to delegate to original client
+    @Override
+    public boolean vertexAI() {
+        return delegate.vertexAI();
+    }
+
+    @Override
+    public String project() {
+        return delegate.project();
+    }
+
+    @Override
+    public String location() {
+        return delegate.location();
+    }
+
+    @Override
+    public String apiKey() {
+        return delegate.apiKey();
+    }
+
+    /**
+     * Runs a synchronous request through the delegate inside a CLIENT span.
+     *
+     * <p>The span is named by {@code getOperation(genAIUrl)}. On success the response body is
+     * buffered, the span is tagged from the request/response JSON and the buffered response is
+     * returned. On failure the span is marked ERROR, the exception is recorded and rethrown. The
+     * span is always ended.
+     */
+    @Override
+    @SneakyThrows
+    public ApiResponse request(
+            String requestMethod,
+            String genAIUrl,
+            String requestBody,
+            Optional<HttpOptions> options) {
+        Span span =
+                tracer.spanBuilder(getOperation(genAIUrl)).setSpanKind(SpanKind.CLIENT).startSpan();
+        try (Scope scope = span.makeCurrent()) {
+            ApiResponse response = delegate.request(requestMethod, genAIUrl, requestBody, options);
+            BufferedApiResponse bufferedResponse = new BufferedApiResponse(response);
+            span.setStatus(StatusCode.OK);
+            tagSpan(span, genAIUrl, requestMethod, requestBody, bufferedResponse.getBodyAsString());
+            return bufferedResponse;
+        } catch (Throwable t) {
+            span.setStatus(StatusCode.ERROR, t.getMessage());
+            span.recordException(t);
+            throw t;
+        } finally {
+            span.end();
+        }
+    }
+
+    /**
+     * Byte-array variant of {@link #request(String, String, String, Optional)}.
+     *
+     * <p>The request bytes are decoded as UTF-8 only for span tagging; the delegate receives the
+     * original array. A null array is tagged as "no request body" instead of raising a
+     * NullPointerException after the delegate call has already succeeded.
+     */
+    @Override
+    @SneakyThrows
+    public ApiResponse request(
+            String requestMethod,
+            String genAIUrl,
+            byte[] requestBodyBytes,
+            Optional<HttpOptions> options) {
+        Span span =
+                tracer.spanBuilder(getOperation(genAIUrl)).setSpanKind(SpanKind.CLIENT).startSpan();
+        try (Scope scope = span.makeCurrent()) {
+            ApiResponse response =
+                    delegate.request(requestMethod, genAIUrl, requestBodyBytes, options);
+            BufferedApiResponse bufferedResponse = new BufferedApiResponse(response);
+            span.setStatus(StatusCode.OK);
+            // Explicit charset: new String(byte[]) would use the platform default.
+            String requestBody =
+                    requestBodyBytes == null
+                            ? null
+                            : new String(requestBodyBytes, java.nio.charset.StandardCharsets.UTF_8);
+            tagSpan(span, genAIUrl, requestMethod, requestBody, bufferedResponse.getBodyAsString());
+            return bufferedResponse;
+        } catch (Throwable t) {
+            span.setStatus(StatusCode.ERROR, t.getMessage());
+            span.recordException(t);
+            throw t;
+        } finally {
+            span.end();
+        }
+    }
+
+    /**
+     * Runs an asynchronous request through the delegate and completes a CLIENT span when the
+     * delegate's future completes.
+     *
+     * <p>If the delegate future failed, the span is marked ERROR and the failure is propagated as
+     * a CompletionException whose cause is the delegate's failure, without an extra
+     * RuntimeException layer. Otherwise the response body is buffered, the span is tagged and the
+     * buffered response becomes the result.
+     */
+    @Override
+    public CompletableFuture<ApiResponse> asyncRequest(
+            String method, String url, String body, Optional<HttpOptions> options) {
+        Span span = tracer.spanBuilder(getOperation(url)).setSpanKind(SpanKind.CLIENT).startSpan();
+        Context context = Context.current().with(span);
+
+        return delegate.asyncRequest(method, url, body, options)
+                .handle(
+                        (response, throwable) -> {
+                            try (Scope scope = context.makeCurrent()) {
+                                if (throwable != null) {
+                                    span.setStatus(StatusCode.ERROR, throwable.getMessage());
+                                    span.recordException(throwable);
+                                    // Propagate the delegate's failure itself rather than a
+                                    // RuntimeException wrapper around it.
+                                    if (throwable
+                                            instanceof java.util.concurrent.CompletionException) {
+                                        throw (java.util.concurrent.CompletionException) throwable;
+                                    }
+                                    throw new java.util.concurrent.CompletionException(throwable);
+                                }
+
+                                try {
+                                    // Buffer the response so we can read it for instrumentation
+                                    BufferedApiResponse bufferedResponse =
+                                            new BufferedApiResponse(response);
+                                    span.setStatus(StatusCode.OK);
+                                    tagSpan(
+                                            span,
+                                            url,
+                                            method,
+                                            body,
+                                            bufferedResponse.getBodyAsString());
+                                    return (ApiResponse) bufferedResponse;
+                                } catch (Exception e) {
+                                    span.setStatus(StatusCode.ERROR, e.getMessage());
+                                    span.recordException(e);
+                                    throw new RuntimeException(e);
+                                }
+                            } finally {
+                                span.end();
+                            }
+                        });
+    }
+
+    /** Byte-array variant of {@link #asyncRequest(String, String, String, Optional)}. */
+    @Override
+    public CompletableFuture<ApiResponse> asyncRequest(
+            String method, String url, byte[] body, Optional<HttpOptions> options) {
+        Span span = tracer.spanBuilder(getOperation(url)).setSpanKind(SpanKind.CLIENT).startSpan();
+        Context context = Context.current().with(span);
+
+        return delegate.asyncRequest(method, url, body, options)
+                .handle(
+                        (response, throwable) -> {
+                            try (Scope scope = context.makeCurrent()) {
+                                if (throwable != null) {
+                                    span.setStatus(StatusCode.ERROR, throwable.getMessage());
+                                    span.recordException(throwable);
+                                    // Propagate the delegate's failure itself rather than a
+                                    // RuntimeException wrapper around it.
+                                    if (throwable
+                                            instanceof java.util.concurrent.CompletionException) {
+                                        throw (java.util.concurrent.CompletionException) throwable;
+                                    }
+                                    throw new java.util.concurrent.CompletionException(throwable);
+                                }
+
+                                try {
+                                    // Buffer the response so we can read it for instrumentation
+                                    BufferedApiResponse
bufferedResponse =
+                                            new BufferedApiResponse(response);
+                                    span.setStatus(StatusCode.OK);
+                                    // Decode with an explicit charset (new String(byte[]) uses
+                                    // the platform default) and tolerate a null request body.
+                                    var utf8 = java.nio.charset.StandardCharsets.UTF_8;
+                                    String requestJson =
+                                            body == null ? null : new String(body, utf8);
+                                    tagSpan(
+                                            span,
+                                            url,
+                                            method,
+                                            requestJson,
+                                            bufferedResponse.getBodyAsString());
+                                    return (ApiResponse) bufferedResponse;
+                                } catch (Exception e) {
+                                    span.setStatus(StatusCode.ERROR, e.getMessage());
+                                    span.recordException(e);
+                                    throw new RuntimeException(e);
+                                }
+                            } finally {
+                                span.end();
+                            }
+                        });
+    }
+
+    /**
+     * Returns the part of the endpoint's last path segment that precedes the first ':'. Falls
+     * back to "gemini" when parsing throws (for example on a null endpoint).
+     */
+    private static String getModel(String genAIEndpoint) {
+        try {
+            // endpoint has model and request type. Example:
+            // models/gemini-2.0-flash-lite:generateContent
+            var segments = genAIEndpoint.split("/");
+            var lastSegment = segments[segments.length - 1].split(":");
+            return lastSegment[0];
+        } catch (Exception e) {
+            log.debug("unable to determine model name", e);
+            return "gemini";
+        }
+    }
+
+    /**
+     * Returns the snake_case form of the text after the first ':' in the endpoint's last path
+     * segment. Falls back to "gemini.api.call" when parsing throws (for example when the segment
+     * has no ':').
+     */
+    private static String getOperation(String genAIEndpoint) {
+        try {
+            // endpoint has model and request type. Example:
+            // models/gemini-2.0-flash-lite:generateContent
+            var segments = genAIEndpoint.split("/");
+            var lastSegment = segments[segments.length - 1].split(":");
+            return toSnakeCase(lastSegment[1]);
+        } catch (Exception e) {
+            log.debug("unable to determine operation name", e);
+            return "gemini.api.call";
+        }
+    }
+
+    /** convert a camelCaseString to a snake_case_string */
+    private static String toSnakeCase(String camelCase) {
+        if (camelCase == null || camelCase.isEmpty()) return camelCase;
+
+        StringBuilder sb = new StringBuilder(camelCase.length() + 5);
+
+        for (int i = 0; i < camelCase.length(); i++) {
+            char c = camelCase.charAt(i);
+            if (Character.isUpperCase(c)) {
+                if (i > 0) sb.append('_');
+                sb.append(Character.toLowerCase(c));
+            } else {
+                sb.append(c);
+            }
+        }
+
+        return sb.toString();
+    }
+
+    /**
+     * Wrapper for ApiResponse that buffers the response body so it can be read multiple times.
+     *
+     * <p>This allows us to capture the response body for instrumentation while still allowing the
+     * delegate to read it.
+     */
+    static class BufferedApiResponse extends ApiResponse {
+        private final ApiResponse delegate;
+        private final byte[] bufferedBody;
+        // Captured once at construction so getBody() never touches the consumed delegate body.
+        private final MediaType contentType;
+
+        /**
+         * Reads the delegate's body fully into memory.
+         *
+         * @param delegate the response whose body is buffered; headers and close() still go to it
+         * @throws Exception if obtaining or reading the body fails
+         */
+        public BufferedApiResponse(ApiResponse delegate) throws Exception {
+            this.delegate = delegate;
+            ResponseBody body = delegate.getBody();
+            if (body == null) {
+                this.contentType = null;
+                this.bufferedBody = null;
+                return;
+            }
+            MediaType type = null;
+            try {
+                type = body.contentType();
+            } catch (RuntimeException ignored) {
+                // Content type is optional for replaying the body; fall back to null.
+            }
+            this.contentType = type;
+            // Read the body once and buffer it
+            this.bufferedBody = body.bytes();
+        }
+
+        /** Returns a new ResponseBody over the buffered bytes, or null if there was no body. */
+        @Override
+        public ResponseBody getBody() {
+            if (bufferedBody == null) {
+                return null;
+            }
+            return ResponseBody.create(bufferedBody, contentType);
+        }
+
+        @Override
+        public Headers getHeaders() {
+            return delegate.getHeaders();
+        }
+
+        @Override
+        public void close() {
+            delegate.close();
+        }
+
+        /** Get the buffered body as a UTF-8 string for instrumentation, or null if no body. */
+        public String getBodyAsString() {
+            return bufferedBody != null
+                    ? new String(bufferedBody, java.nio.charset.StandardCharsets.UTF_8)
+                    : null;
+        }
+    }
+}
diff --git a/src/main/java/com/google/genai/BraintrustInstrumentation.java b/src/main/java/com/google/genai/BraintrustInstrumentation.java
new file mode 100644
index 0000000..8b2754f
--- /dev/null
+++ b/src/main/java/com/google/genai/BraintrustInstrumentation.java
@@ -0,0 +1,97 @@
+package com.google.genai;
+
+import io.opentelemetry.api.OpenTelemetry;
+import java.lang.reflect.Field;
+import java.lang.reflect.Modifier;
+import java.util.logging.Logger;
+import lombok.extern.slf4j.Slf4j;
+
+/**
+ * Helper class for instrumenting Gemini Client by replacing its internal ApiClient.
+ *
+ *

This class lives in com.google.genai package to access package-private ApiClient class.
+ */
+@Slf4j
+public class BraintrustInstrumentation {
+    /**
+     * Wraps a Client's internal ApiClient with an instrumented version.
+     *
+     * @param client the client to instrument
+     * @param openTelemetry the OpenTelemetry instance
+     * @return the same client instance, but with instrumented ApiClient
+     * @throws Exception if the apiClient field cannot be read or replaced reflectively
+     */
+    public static Client wrapClient(Client client, OpenTelemetry openTelemetry) throws Exception {
+        // Get the apiClient field from Client
+        Field clientApiClientField = Client.class.getDeclaredField("apiClient");
+        clientApiClientField.setAccessible(true);
+        ApiClient originalApiClient = (ApiClient) clientApiClientField.get(client);
+
+        // Create instrumented wrapper
+        BraintrustApiClient instrumentedApiClient =
+                new BraintrustApiClient(originalApiClient, openTelemetry);
+
+        // Replace apiClient in Client
+        setFinalField(client, clientApiClientField, instrumentedApiClient);
+
+        // Replace apiClient in all Client service fields
+        replaceApiClientInService(client.models, instrumentedApiClient);
+        replaceApiClientInService(client.batches, instrumentedApiClient);
+        replaceApiClientInService(client.caches, instrumentedApiClient);
+        replaceApiClientInService(client.operations, instrumentedApiClient);
+        replaceApiClientInService(client.chats, instrumentedApiClient);
+        replaceApiClientInService(client.files, instrumentedApiClient);
+        replaceApiClientInService(client.tunings, instrumentedApiClient);
+
+        // Replace apiClient in all Client.async service fields
+        if (client.async != null) {
+            replaceApiClientInService(client.async.models, instrumentedApiClient);
+            replaceApiClientInService(client.async.batches, instrumentedApiClient);
+            replaceApiClientInService(client.async.caches, instrumentedApiClient);
+            replaceApiClientInService(client.async.operations, instrumentedApiClient);
+            replaceApiClientInService(client.async.chats, instrumentedApiClient);
+            replaceApiClientInService(client.async.files, instrumentedApiClient);
+            replaceApiClientInService(client.async.tunings, instrumentedApiClient);
+        }
+
+        // Log through the Lombok-generated slf4j logger (the class is annotated @Slf4j) rather
+        // than a separate java.util.logging logger.
+        log.info("Successfully instrumented Gemini client");
+        return client;
+    }
+
+    /** Replaces the apiClient field in a service object (Models, Batches, etc). */
+    private static void replaceApiClientInService(Object service, ApiClient instrumentedApiClient)
+            throws Exception {
+        if (service == null) {
+            return;
+        }
+        try {
+            Field apiClientField = service.getClass().getDeclaredField("apiClient");
+            apiClientField.setAccessible(true);
+            setFinalField(service, apiClientField, instrumentedApiClient);
+        } catch (NoSuchFieldException e) {
+            // Some services might not have an apiClient field
+            log.debug("No apiClient field found in {}", service.getClass().getSimpleName());
+        }
+    }
+
+    /**
+     * Sets a final field using reflection.
+     *
+     *

This works by making the field accessible and, on older Java versions, removing the final
+     * modifier.
+     */
+    private static void setFinalField(Object target, Field field, Object value) throws Exception {
+        field.setAccessible(true);
+        // Try to remove final modifier
+        try {
+            Field modifiersField = Field.class.getDeclaredField("modifiers");
+            modifiersField.setAccessible(true);
+            modifiersField.setInt(field, field.getModifiers() & ~Modifier.FINAL);
+        } catch (ReflectiveOperationException | RuntimeException ignored) {
+            // Best-effort step only: whatever goes wrong here, still attempt field.set below.
+        }
+        field.set(target, value);
+    }
+}
diff --git a/src/main/java/dev/braintrust/instrumentation/genai/BraintrustGenAI.java b/src/main/java/dev/braintrust/instrumentation/genai/BraintrustGenAI.java
new file mode 100644
index 0000000..3de2800
--- /dev/null
+++ b/src/main/java/dev/braintrust/instrumentation/genai/BraintrustGenAI.java
@@ -0,0 +1,35 @@
+package dev.braintrust.instrumentation.genai;
+
+import com.google.genai.BraintrustInstrumentation;
+import com.google.genai.Client;
+import io.opentelemetry.api.OpenTelemetry;
+import lombok.extern.slf4j.Slf4j;
+
+/** Braintrust Google GenAI client instrumentation. */
+@Slf4j
+public class BraintrustGenAI {
+    /**
+     * Instrument Google GenAI Client with Braintrust traces.
+     *
+     * <p>This wraps the client's internal HTTP layer to capture all API calls with OpenTelemetry
+     * spans.
+     *
+     * <p>The builder is built exactly once. If instrumentation throws, the error is logged and
+     * the client that was already built is returned.
+     *
+     * @param openTelemetry the OpenTelemetry instance
+     * @param genAIClientBuilder the Gemini client builder
+     * @return the built Gemini client, instrumented when instrumentation succeeded
+     */
+    public static Client wrap(OpenTelemetry openTelemetry, Client.Builder genAIClientBuilder) {
+        // Build outside the try block: a second build() in the fallback path would create another
+        // client and abandon the first one.
+        Client client = genAIClientBuilder.build();
+        try {
+            return BraintrustInstrumentation.wrapClient(client, openTelemetry);
+        } catch (Throwable t) {
+            log.error("failed to instrument gemini client", t);
+            return client;
+        }
+    }
+}
diff --git a/src/test/java/dev/braintrust/instrumentation/genai/BraintrustGenAITest.java b/src/test/java/dev/braintrust/instrumentation/genai/BraintrustGenAITest.java
new file mode 100644
index 0000000..dd406fc
--- /dev/null
+++ b/src/test/java/dev/braintrust/instrumentation/genai/BraintrustGenAITest.java
@@ -0,0 +1,268 @@
+package dev.braintrust.instrumentation.genai;
+
+import static com.github.tomakehurst.wiremock.client.WireMock.*;
+import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;
+import static org.junit.jupiter.api.Assertions.*;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.github.tomakehurst.wiremock.junit5.WireMockExtension;
+import com.google.genai.Client;
+import com.google.genai.types.GenerateContentConfig;
+import com.google.genai.types.HttpOptions;
+import dev.braintrust.TestHarness;
+import io.opentelemetry.api.common.AttributeKey;
+import lombok.SneakyThrows;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+public class BraintrustGenAITest {
+    private static final ObjectMapper JSON_MAPPER = new ObjectMapper();
+
+    @RegisterExtension
+    static WireMockExtension wireMock =
+            WireMockExtension.newInstance().options(wireMockConfig().dynamicPort()).build();
+
+    private TestHarness testHarness;
+
+    @BeforeEach
+    void beforeEach() {
+
testHarness = TestHarness.setup(); + wireMock.resetAll(); + }
+
+    /** URL path regex matched by every stubbed and verified {@code generateContent} request. */
+    private static final String GENERATE_CONTENT_PATH = "/v1beta/models/.*:generateContent";
+
+    /** Model name sent in each request and expected back in the span metadata and input. */
+    private static final String MODEL = "gemini-2.0-flash-lite";
+
+    /** Upper bound, in seconds, on how long the async test waits for its future. */
+    private static final long ASYNC_TIMEOUT_SECONDS = 30;
+
+    /**
+     * Calls {@code models.generateContent} synchronously through a wrapped client and checks the
+     * stubbed response text, the single HTTP request, and the single exported span.
+     */
+    @Test
+    @SneakyThrows
+    void testWrapGemini() {
+        String expectedText = "The capital of France is Paris.";
+        stubGenerateContent(expectedText);
+        var geminiClient = newInstrumentedClient();
+
+        var response =
+                geminiClient.models.generateContent(
+                        MODEL, "What is the capital of France?", newConfig());
+
+        // Verify the response
+        assertNotNull(response);
+        wireMock.verify(1, postRequestedFor(urlPathMatching(GENERATE_CONTENT_PATH)));
+        assertEquals(expectedText, response.text());
+
+        assertGenerateContentSpan(expectedText);
+    }
+
+    /**
+     * Same scenario as {@link #testWrapGemini()} but issued through {@code async.models}; the
+     * returned future is awaited before the response and span are checked.
+     */
+    @Test
+    @SneakyThrows
+    void testWrapGeminiAsync() {
+        String expectedText = "The capital of Germany is Berlin.";
+        stubGenerateContent(expectedText);
+        var geminiClient = newInstrumentedClient();
+
+        var responseFuture =
+                geminiClient.async.models.generateContent(
+                        MODEL, "What is the capital of Germany?", newConfig());
+
+        // Bound the wait: a future that never completes should fail this test with a
+        // TimeoutException instead of hanging the whole build.
+        var response =
+                responseFuture.get(
+                        ASYNC_TIMEOUT_SECONDS, java.util.concurrent.TimeUnit.SECONDS);
+
+        // Verify the response
+        assertNotNull(response);
+        wireMock.verify(1, postRequestedFor(urlPathMatching(GENERATE_CONTENT_PATH)));
+        assertEquals(expectedText, response.text());
+
+        assertGenerateContentSpan(expectedText);
+    }
+
+    /**
+     * Stubs the WireMock server so that any POST matching {@link #GENERATE_CONTENT_PATH} returns
+     * a 200 JSON body with one candidate, finish reason {@code STOP}, and usage counts 10/8/18.
+     *
+     * @param replyText text placed in the candidate's first part; it is inserted into the JSON
+     *     verbatim, so it must not contain characters that need JSON or format-string escaping
+     */
+    private void stubGenerateContent(String replyText) {
+        wireMock.stubFor(
+                post(urlPathMatching(GENERATE_CONTENT_PATH))
+                        .willReturn(
+                                aResponse()
+                                        .withStatus(200)
+                                        .withHeader("Content-Type", "application/json")
+                                        .withBody(
+                                                """
+                                                {
+                                                  "candidates": [
+                                                    {
+                                                      "content": {
+                                                        "parts": [
+                                                          {
+                                                            "text": "%s"
+                                                          }
+                                                        ],
+                                                        "role": "model"
+                                                      },
+                                                      "finishReason": "STOP"
+                                                    }
+                                                  ],
+                                                  "usageMetadata": {
+                                                    "promptTokenCount": 10,
+                                                    "candidatesTokenCount": 8,
+                                                    "totalTokenCount": 18
+                                                  },
+                                                  "modelVersion": "gemini-2.0-flash-lite"
+                                                }
+                                                """
+                                                        .formatted(replyText))));
+    }
+
+    /**
+     * Builds a Gemini client whose base URL points at the WireMock server and wraps it with
+     * Braintrust instrumentation using the test harness's OpenTelemetry instance.
+     *
+     * @return the client returned by {@code BraintrustGenAI.wrap}
+     */
+    private Client newInstrumentedClient() {
+        HttpOptions httpOptions =
+                HttpOptions.builder().baseUrl("http://localhost:" + wireMock.getPort()).build();
+        return BraintrustGenAI.wrap(
+                testHarness.openTelemetry(),
+                new Client.Builder().apiKey("test-api-key").httpOptions(httpOptions));
+    }
+
+    /**
+     * Builds the request config shared by both tests; {@link #assertGenerateContentSpan(String)}
+     * expects these exact values (temperature 0.0, max output tokens 50) in the span metadata.
+     */
+    private static GenerateContentConfig newConfig() {
+        return GenerateContentConfig.builder().temperature(0.0f).maxOutputTokens(50).build();
+    }
+
+    /**
+     * Asserts that exactly one span was exported and that its name and {@code braintrust.*} JSON
+     * attributes describe the stubbed {@code generateContent} exchange.
+     *
+     * @param expectedText text expected in the first part of the first candidate in
+     *     {@code braintrust.output_json}
+     */
+    @SneakyThrows
+    private void assertGenerateContentSpan(String expectedText) {
+        // Verify spans were exported
+        var spans = testHarness.awaitExportedSpans();
+        assertEquals(1, spans.size(), "Expected exactly 1 span to be created");
+        var span = spans.get(0);
+
+        // Verify span name matches the operation
+        assertEquals("generate_content", span.getName());
+
+        // Verify braintrust.metadata contains provider, model and the request config values
+        String metadataJson =
+                span.getAttributes().get(AttributeKey.stringKey("braintrust.metadata"));
+        assertNotNull(metadataJson, "braintrust.metadata should be set");
+        var metadata = JSON_MAPPER.readTree(metadataJson);
+        assertEquals("gemini", metadata.get("provider").asText());
+        assertEquals(MODEL, metadata.get("model").asText());
+        assertEquals(0.0, metadata.get("temperature").asDouble());
+        assertEquals(50, metadata.get("maxOutputTokens").asInt());
+
+        // Verify braintrust.metrics contains the token counts from the stubbed usageMetadata
+        String metricsJson = span.getAttributes().get(AttributeKey.stringKey("braintrust.metrics"));
+        assertNotNull(metricsJson, "braintrust.metrics should be set");
+        var metrics = JSON_MAPPER.readTree(metricsJson);
+        assertEquals(10, metrics.get("prompt_tokens").asInt());
+        assertEquals(8, metrics.get("completion_tokens").asInt());
+        assertEquals(18, metrics.get("tokens").asInt());
+
+        // Verify braintrust.span_attributes marks this as an LLM span
+        String spanAttributesJson =
+                span.getAttributes().get(AttributeKey.stringKey("braintrust.span_attributes"));
+        assertNotNull(spanAttributesJson, "braintrust.span_attributes should be set");
+        var spanAttributes = JSON_MAPPER.readTree(spanAttributesJson);
+        assertEquals("llm", spanAttributes.get("type").asText());
+
+        // Verify braintrust.input_json contains the request
+        String inputJson =
+                span.getAttributes().get(AttributeKey.stringKey("braintrust.input_json"));
+        assertNotNull(inputJson, "braintrust.input_json should be set");
+        var input = JSON_MAPPER.readTree(inputJson);
+        assertEquals(MODEL, input.get("model").asText());
+        assertTrue(input.has("contents"), "input should have contents");
+        assertTrue(input.has("config"), "input should have config");
+
+        // Verify braintrust.output_json contains the response
+        String outputJson =
+                span.getAttributes().get(AttributeKey.stringKey("braintrust.output_json"));
+        assertNotNull(outputJson, "braintrust.output_json should be set");
+        var output = JSON_MAPPER.readTree(outputJson);
+        assertTrue(output.has("candidates"), "output should have candidates");
+        var firstCandidate = output.get("candidates").get(0);
+        assertEquals("STOP", firstCandidate.get("finishReason").asText());
+        assertEquals(
+                expectedText,
+                firstCandidate.get("content").get("parts").get(0).get("text").asText());
+    }
+}