diff --git a/README.md b/README.md index b564a2154..c61bc04bd 100644 --- a/README.md +++ b/README.md @@ -159,6 +159,7 @@ For SAP internal development, you can also use `SNAPSHOT` builds from the [inter The AI SDK leverages the destination concept from the SAP Cloud SDK to manage the connection to AI Core. This opens up a wide range of possibilities to customize the connection, including adding custom headers. +The following shows how to add custom headers to all requests sent to AI Core. ```java var service = new AiCoreService(); @@ -170,15 +171,19 @@ var destination = // AI Core client service = service.withBaseDestination(destination); DeploymentApi client = new DeploymentApi(service); +``` + +For more information, please refer to the [AI Core connectivity guide](https://sap.github.io/ai-sdk/docs/java/guides/connecting-to-ai-core) and the [SAP Cloud SDK documentation](https://sap.github.io/cloud-sdk/docs/java/features/connectivity/http-destinations). -// Orchestration client -OrchestrationClient client = new OrchestrationClient(destination); +There is also a convenient method to add custom headers to single calls through the Orchestration or OpenAI client. -// OpenAI client -OpenAiClient client2 = OpenAiClient.withCustomDestination(destination); +```java +var client = new OrchestrationClient(); + +var result = client.withHeader("my-header-key", "my-header-value").chatCompletion(prompt, config); ``` -For more information, please refer to the [AI Core connectivity guide](https://sap.github.io/ai-sdk/docs/java/guides/connecting-to-ai-core) and the [SAP Cloud SDK documentation](https://sap.github.io/cloud-sdk/docs/java/features/connectivity/http-destinations). +For more information on this feature, see the respective documentation of the [OrchestrationClient](https://sap.github.io/ai-sdk/docs/java/orchestration/chat-completion#custom-headers) and [OpenAIClient](https://sap.github.io/ai-sdk/docs/java/foundation-models/openai/chat-completion#custom-headers). ### _"There's a vulnerability warning `CVE-2021-41251`?"_ diff --git a/docs/release_notes.md b/docs/release_notes.md index e6890713b..60a92966b 100644 --- a/docs/release_notes.md +++ b/docs/release_notes.md @@ -27,6 +27,8 @@ - Added `OpenAiChatModel` - [Prompt Registry] [Using Prompt Registry Templates in SpringAI.](https://sap.github.io/ai-sdk/docs/java/ai-core/prompt-registry#using-templates-in-springai) - Added `SpringAiConverter` +- [Orchestration] [Added convenience to add custom headers to individual orchestration calls.](https://sap.github.io/ai-sdk/docs/java/orchestration/chat-completion#custom-headers) +- [OpenAI] [Added convenience to add custom headers to individual LLM calls.](https://sap.github.io/ai-sdk/docs/java/foundation-models/openai/chat-completion#custom-headers) ### 📈 Improvements diff --git a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java index d87658d4c..e24eab3e5 100644 --- a/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java +++ b/foundation-models/openai/src/main/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClient.java @@ -25,8 +25,11 @@ import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor; import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; import com.sap.cloud.sdk.cloudplatform.connectivity.Destination; +import com.sap.cloud.sdk.cloudplatform.connectivity.Header; import com.sap.cloud.sdk.cloudplatform.connectivity.HttpDestination; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import java.util.stream.Stream; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -49,6 +52,7 @@ public final class OpenAiClient { @Nullable private String systemPrompt = null; @Nonnull private final Destination destination; + @Nonnull private final List
customHeaders = new ArrayList<>(); /** * Create a new OpenAI client for the given foundation model, using the default resource group. @@ -127,6 +131,23 @@ public OpenAiClient withSystemPrompt(@Nonnull final String systemPrompt) { return this; } + /** + * Create a new OpenAI client with a custom header added to every call made with this client + * + * @param key the key of the custom header to add + * @param value the value of the custom header to add + * @return a new client. + * @since 1.11.0 + */ + @Beta + @Nonnull + public OpenAiClient withHeader(@Nonnull final String key, @Nonnull final String value) { + final var newClient = new OpenAiClient(this.destination); + newClient.customHeaders.addAll(this.customHeaders); + newClient.customHeaders.add(new Header(key, value)); + return newClient; + } + /** * Generate a completion for the given string prompt as user. * @@ -395,7 +416,7 @@ private T execute( @Nonnull final Object payload, @Nonnull final Class responseType) { final var request = new HttpPost(path); - serializeAndSetHttpEntity(request, payload); + serializeAndSetHttpEntity(request, payload, this.customHeaders); return executeRequest(request, responseType); } @@ -405,15 +426,18 @@ private Stream executeStream( @Nonnull final Object payload, @Nonnull final Class deltaType) { final var request = new HttpPost(path); - serializeAndSetHttpEntity(request, payload); + serializeAndSetHttpEntity(request, payload, this.customHeaders); return streamRequest(request, deltaType); } private static void serializeAndSetHttpEntity( - @Nonnull final BasicClassicHttpRequest request, @Nonnull final Object payload) { + @Nonnull final BasicClassicHttpRequest request, + @Nonnull final Object payload, + @Nonnull final List
customHeaders) { try { final var json = JACKSON.writeValueAsString(payload); request.setEntity(new StringEntity(json, ContentType.APPLICATION_JSON)); + customHeaders.forEach(h -> request.addHeader(h.getName(), h.getValue())); } catch (final JsonProcessingException e) { throw new OpenAiClientException("Failed to serialize request parameters", e) .setHttpRequest(request); diff --git a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientTest.java b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientTest.java index abd9c85f6..1b157fbac 100644 --- a/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientTest.java +++ b/foundation-models/openai/src/test/java/com/sap/ai/sdk/foundationmodels/openai/OpenAiClientTest.java @@ -1,6 +1,7 @@ package com.sap.ai.sdk.foundationmodels.openai; import static com.github.tomakehurst.wiremock.client.WireMock.*; +import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl; import static com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionTool.ToolType.FUNCTION; import static com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.*; import static com.sap.ai.sdk.foundationmodels.openai.model.OpenAiContentFilterSeverityResult.Severity.SAFE; @@ -480,4 +481,30 @@ void chatCompletionTool() { } """))); } + + @Test + void testCustomHeaders() { + stubForChatCompletion(); + final var request = + new OpenAiChatCompletionRequest("Hello World! Why is this phrase so famous?"); + final var clientWithHeader = client.withHeader("Header-For-Both", "value"); + + final var result = clientWithHeader.withHeader("foo", "bar").chatCompletion(request); + assertThat(result).isNotNull(); + + var streamResult = + clientWithHeader + .withHeader("foot", "baz") + .streamChatCompletion("Hello World! Why is this phrase so famous?"); + assertThat(streamResult).isNotNull(); + + verify( + postRequestedFor(anyUrl()) + .withHeader("Header-For-Both", equalTo("value")) + .withHeader("foo", equalTo("bar"))); + verify( + postRequestedFor(anyUrl()) + .withHeader("Header-For-Both", equalTo("value")) + .withHeader("foot", equalTo("baz"))); + } } diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java index d7e165cdb..483503b85 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationClient.java @@ -15,18 +15,22 @@ import com.sap.ai.sdk.orchestration.model.GlobalStreamOptions; import com.sap.ai.sdk.orchestration.model.ModuleConfigs; import com.sap.ai.sdk.orchestration.model.OrchestrationConfig; +import com.sap.cloud.sdk.cloudplatform.connectivity.Header; import com.sap.cloud.sdk.cloudplatform.connectivity.HttpDestination; import io.vavr.control.Try; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.function.Supplier; import java.util.stream.Stream; import javax.annotation.Nonnull; +import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import lombok.val; /** Client to execute requests to the orchestration service. */ @Slf4j +@RequiredArgsConstructor(access = lombok.AccessLevel.PRIVATE) public class OrchestrationClient { private static final String DEFAULT_SCENARIO = "orchestration"; private static final String COMPLETION_ENDPOINT = "/v2/completion"; @@ -34,6 +38,7 @@ public class OrchestrationClient { static final ObjectMapper JACKSON = getOrchestrationObjectMapper(); private final OrchestrationHttpExecutor executor; + private final List
customHeaders = new ArrayList<>(); /** Default constructor. */ public OrchestrationClient() { @@ -156,7 +161,8 @@ private static Map getOutputFilteringChoices( @Nonnull public CompletionPostResponse executeRequest(@Nonnull final CompletionPostRequest request) throws OrchestrationClientException { - return executor.execute(COMPLETION_ENDPOINT, request, CompletionPostResponse.class); + return executor.execute( + COMPLETION_ENDPOINT, request, CompletionPostResponse.class, customHeaders); } /** @@ -198,7 +204,8 @@ public OrchestrationChatResponse executeRequestFromJsonModuleConfig( requestJson.set("orchestration_config", moduleConfigJson); return new OrchestrationChatResponse( - executor.execute(COMPLETION_ENDPOINT, requestJson, CompletionPostResponse.class)); + executor.execute( + COMPLETION_ENDPOINT, requestJson, CompletionPostResponse.class, customHeaders)); } /** @@ -214,7 +221,7 @@ public Stream streamChatCompletionDeltas( @Nonnull final CompletionPostRequest request) throws OrchestrationClientException { request.getConfig().setStream(GlobalStreamOptions.create().enabled(true).delimiters(null)); - return executor.stream(COMPLETION_ENDPOINT, request); + return executor.stream(COMPLETION_ENDPOINT, request, customHeaders); } /** @@ -228,6 +235,24 @@ public Stream streamChatCompletionDeltas( @Nonnull EmbeddingsPostResponse embed(@Nonnull final EmbeddingsPostRequest request) throws OrchestrationClientException { - return executor.execute("/v2/embeddings", request, EmbeddingsPostResponse.class); + return executor.execute("/v2/embeddings", request, EmbeddingsPostResponse.class, customHeaders); + } + + /** + * Create a new orchestration client with a custom header added to every call made with this + * client + * + * @param key the key of the custom header to add + * @param value the value of the custom header to add + * @return a new client. + * @since 1.11.0 + */ + @Beta + @Nonnull + public OrchestrationClient withHeader(@Nonnull final String key, @Nonnull final String value) { + final var newClient = new OrchestrationClient(this.executor); + newClient.customHeaders.addAll(this.customHeaders); + newClient.customHeaders.add(new Header(key, value)); + return newClient; } } diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationHttpExecutor.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationHttpExecutor.java index 59954c929..65f4e3ece 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationHttpExecutor.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationHttpExecutor.java @@ -9,11 +9,13 @@ import com.sap.ai.sdk.core.common.ClientResponseHandler; import com.sap.ai.sdk.core.common.ClientStreamingHandler; import com.sap.cloud.sdk.cloudplatform.connectivity.ApacheHttpClient5Accessor; +import com.sap.cloud.sdk.cloudplatform.connectivity.Header; import com.sap.cloud.sdk.cloudplatform.connectivity.HttpDestination; import com.sap.cloud.sdk.cloudplatform.connectivity.exception.DestinationAccessException; import com.sap.cloud.sdk.cloudplatform.connectivity.exception.DestinationNotFoundException; import com.sap.cloud.sdk.cloudplatform.connectivity.exception.HttpClientInstantiationException; import java.io.IOException; +import java.util.List; import java.util.function.Supplier; import java.util.stream.Stream; import javax.annotation.Nonnull; @@ -39,12 +41,14 @@ class OrchestrationHttpExecutor { T execute( @Nonnull final String path, @Nonnull final Object payload, - @Nonnull final Class responseType) { + @Nonnull final Class responseType, + @Nonnull final List
customHeaders) { try { val json = JACKSON.writeValueAsString(payload); log.debug("Successfully serialized request into JSON payload"); val request = new HttpPost(path); request.setEntity(new StringEntity(json, ContentType.APPLICATION_JSON)); + customHeaders.forEach(h -> request.addHeader(h.getName(), h.getValue())); val client = getHttpClient(); @@ -67,12 +71,15 @@ T execute( @Nonnull Stream stream( - @Nonnull final String path, @Nonnull final Object payload) { + @Nonnull final String path, + @Nonnull final Object payload, + @Nonnull final List
customHeaders) { try { - val json = JACKSON.writeValueAsString(payload); val request = new HttpPost(path); request.setEntity(new StringEntity(json, ContentType.APPLICATION_JSON)); + customHeaders.forEach(h -> request.addHeader(h.getName(), h.getValue())); + val client = getHttpClient(); return new ClientStreamingHandler<>( diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java index 9f0d8af0b..176809764 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationUnitTest.java @@ -3,6 +3,7 @@ import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl; import static com.github.tomakehurst.wiremock.client.WireMock.badRequest; +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson; import static com.github.tomakehurst.wiremock.client.WireMock.jsonResponse; import static com.github.tomakehurst.wiremock.client.WireMock.noContent; @@ -168,6 +169,33 @@ void testCompletionError() { "Request failed with status 500 (Server Error): Internal Server Error located in Masking Module - Masking"); } + @Test + void testCustomHeaders() { + stubFor( + post(urlPathEqualTo("/v2/completion")) + .willReturn( + aResponse() + .withBodyFile("templatingResponse.json") + .withHeader("Content-Type", "application/json"))); + + final var clientWithHeader = client.withHeader("Header-For-Both", "value"); + final var result = clientWithHeader.withHeader("foo", "bar").chatCompletion(prompt, config); + assertThat(result).isNotNull(); + + var streamResult = + clientWithHeader.withHeader("foot", "baz").streamChatCompletion(prompt, config); + assertThat(streamResult).isNotNull(); + + verify( + postRequestedFor(urlPathEqualTo("/v2/completion")) + .withHeader("Header-For-Both", equalTo("value")) + .withHeader("foo", equalTo("bar"))); + verify( + postRequestedFor(urlPathEqualTo("/v2/completion")) + .withHeader("Header-For-Both", equalTo("value")) + .withHeader("foot", equalTo("baz"))); + } + @Test void testGrounding() throws IOException { stubFor(