diff --git a/core-services/prompt-registry/pom.xml b/core-services/prompt-registry/pom.xml index 8fd41e9c9..53949ebb8 100644 --- a/core-services/prompt-registry/pom.xml +++ b/core-services/prompt-registry/pom.xml @@ -38,10 +38,10 @@ ${project.basedir}/../../ - 75% + 73% 87% 89% - 100% + 75% 75% 100% @@ -64,6 +64,11 @@ org.springframework spring-web + + org.springframework.ai + spring-ai-model + true + com.sap.cloud.sdk.cloudplatform cloudplatform-connectivity diff --git a/core-services/prompt-registry/src/main/java/com/sap/ai/sdk/prompt/registry/spring/SpringAiConverter.java b/core-services/prompt-registry/src/main/java/com/sap/ai/sdk/prompt/registry/spring/SpringAiConverter.java new file mode 100644 index 000000000..1d3a6926d --- /dev/null +++ b/core-services/prompt-registry/src/main/java/com/sap/ai/sdk/prompt/registry/spring/SpringAiConverter.java @@ -0,0 +1,49 @@ +package com.sap.ai.sdk.prompt.registry.spring; + +import com.sap.ai.sdk.prompt.registry.model.PromptTemplateSubstitutionResponse; +import com.sap.ai.sdk.prompt.registry.model.SingleChatTemplate; +import com.sap.ai.sdk.prompt.registry.model.Template; +import java.util.List; +import javax.annotation.Nonnull; +import lombok.val; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; + +/** Utility class for prompt registry related operations in a Spring context. */ +public class SpringAiConverter { + + private SpringAiConverter() { + // Utility class, no instantiation allowed + } + + /** + * Get a SpringAI list of messages from a Prompt Registry Response. + * + * @param promptResponse the response from Prompt Registry. + * @return list of SpringAI messages. + */ + @Nonnull + public static List promptTemplateToMessages( + @Nonnull final PromptTemplateSubstitutionResponse promptResponse) { + + val res = promptResponse.getParsedPrompt(); + + // TRANSFORM TEMPLATE TO SPRING AI MESSAGES + return res.stream() + .map( + (Template t) -> { + final SingleChatTemplate message = (SingleChatTemplate) t; + return (Message) + switch (message.getRole()) { + case "system" -> new SystemMessage(message.getContent()); + case "user" -> new UserMessage(message.getContent()); + case "assistant" -> new AssistantMessage(message.getContent()); + default -> + throw new IllegalArgumentException("Unknown role: " + message.getRole()); + }; + }) + .toList(); + } +} diff --git a/core-services/prompt-registry/src/test/java/com/sap/ai/sdk/prompt/registry/spring/SpringAiConverterTest.java b/core-services/prompt-registry/src/test/java/com/sap/ai/sdk/prompt/registry/spring/SpringAiConverterTest.java new file mode 100644 index 000000000..07d117f34 --- /dev/null +++ b/core-services/prompt-registry/src/test/java/com/sap/ai/sdk/prompt/registry/spring/SpringAiConverterTest.java @@ -0,0 +1,69 @@ +package com.sap.ai.sdk.prompt.registry.spring; + +import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import com.github.tomakehurst.wiremock.junit5.WireMockExtension; +import com.sap.ai.sdk.core.AiCoreService; +import com.sap.ai.sdk.prompt.registry.PromptClient; +import com.sap.ai.sdk.prompt.registry.model.PromptTemplateSubstitutionRequest; +import com.sap.cloud.sdk.cloudplatform.connectivity.DefaultHttpDestination; +import com.sap.cloud.sdk.cloudplatform.connectivity.HttpDestination; +import java.util.List; +import java.util.Map; +import lombok.val; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; + +public class SpringAiConverterTest { + @RegisterExtension + private static final WireMockExtension WM = + WireMockExtension.newInstance().options(wireMockConfig().dynamicPort()).build(); + + private final HttpDestination DESTINATION = DefaultHttpDestination.builder(WM.baseUrl()).build(); + private final AiCoreService SERVICE = new AiCoreService().withBaseDestination(DESTINATION); + + @Test + void testPromptRegistryToSpringAi() { + var client = new PromptClient(SERVICE); + val promptResponse = + client.parsePromptTemplateByNameVersion( + "categorization", + "0.0.1", + "java-e2e-test", + "default", + false, + PromptTemplateSubstitutionRequest.create() + .inputParams(Map.of("inputExample", "I love football"))); + + List messages = SpringAiConverter.promptTemplateToMessages(promptResponse); + assertThat(messages) + .isEqualTo( + List.of( + new SystemMessage( + "You classify input text into the two following categories: Finance, Tech, Sports, Politics"), + new UserMessage("I love football"))); + } + + @Test + void testInvalidRoleThrowsException() { + var client = new PromptClient(SERVICE); + val errorPrompt = + client.parsePromptTemplateByNameVersion( + "categorization", + "0.0.1", + "error", + "default", + false, + PromptTemplateSubstitutionRequest.create() + .inputParams(Map.of("inputExample", "I love football"))); + + assertThatThrownBy(() -> SpringAiConverter.promptTemplateToMessages(errorPrompt)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Unknown role: error"); + } +} diff --git a/core-services/prompt-registry/src/test/resources/mappings/error.json b/core-services/prompt-registry/src/test/resources/mappings/error.json new file mode 100644 index 000000000..50f131be0 --- /dev/null +++ b/core-services/prompt-registry/src/test/resources/mappings/error.json @@ -0,0 +1,26 @@ +{ + "request": { + "method": "POST", + "url": "/v2/lm/scenarios/categorization/promptTemplates/error/versions/0.0.1/substitution?metadata=false" + }, + "response": { + "status": 200, + "headers": { + "Content-Type": "application/json" + }, + "jsonBody": { + "parsedPrompt": [ + { + "role": "assistant", + "content": "What can I help you with?" + }, + { + "role": "error", + "content": "What is this?" + } + ] + } + } +} + + diff --git a/core-services/prompt-registry/src/test/resources/mappings/templatesInputParams.json b/core-services/prompt-registry/src/test/resources/mappings/templatesInputParams.json new file mode 100644 index 000000000..ea3a55fc0 --- /dev/null +++ b/core-services/prompt-registry/src/test/resources/mappings/templatesInputParams.json @@ -0,0 +1,26 @@ +{ + "request": { + "method": "POST", + "url": "/v2/lm/scenarios/categorization/promptTemplates/java-e2e-test/versions/0.0.1/substitution?metadata=false" + }, + "response": { + "status": 200, + "headers": { + "Content-Type": "application/json" + }, + "jsonBody": { + "parsedPrompt": [ + { + "role": "system", + "content": "You classify input text into the two following categories: Finance, Tech, Sports, Politics" + }, + { + "role": "user", + "content": "I love football" + } + ] + } + } +} + + diff --git a/docs/release_notes.md b/docs/release_notes.md index f872d0927..6d043220b 100644 --- a/docs/release_notes.md +++ b/docs/release_notes.md @@ -25,6 +25,8 @@ - [Orchestration] Deprecated `OrchestrationAiModel.IBM_GRANITE_13B_CHAT` with no replacement. - [OpenAI] [Introduced SpringAI integration with our OpenAI client.](https://sap.github.io/ai-sdk/docs/java/spring-ai/openai) - Added `OpenAiChatModel` +- [Prompt Registry] [Using Prompt Registry Templates in SpringAI.](https://sap.github.io/ai-sdk/docs/java/ai-core/prompt-registry#using-templates-in-springai) + - Added `SpringAiConverter` ### 📈 Improvements diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/PromptRegistryController.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/PromptRegistryController.java index cf8fdb513..52b4b391f 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/PromptRegistryController.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/controllers/PromptRegistryController.java @@ -1,5 +1,8 @@ package com.sap.ai.sdk.app.controllers; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient; +import com.sap.ai.sdk.foundationmodels.openai.OpenAiModel; +import com.sap.ai.sdk.foundationmodels.openai.spring.OpenAiChatModel; import com.sap.ai.sdk.prompt.registry.PromptClient; import com.sap.ai.sdk.prompt.registry.model.PromptTemplateDeleteResponse; import com.sap.ai.sdk.prompt.registry.model.PromptTemplateListResponse; @@ -9,10 +12,19 @@ import com.sap.ai.sdk.prompt.registry.model.PromptTemplateSubstitutionRequest; import com.sap.ai.sdk.prompt.registry.model.PromptTemplateSubstitutionResponse; import com.sap.ai.sdk.prompt.registry.model.SingleChatTemplate; +import com.sap.ai.sdk.prompt.registry.spring.SpringAiConverter; import java.io.File; import java.io.IOException; import java.util.List; import java.util.Map; +import lombok.val; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor; +import org.springframework.ai.chat.memory.InMemoryChatMemoryRepository; +import org.springframework.ai.chat.memory.MessageWindowChatMemory; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; import org.springframework.web.bind.annotation.GetMapping; @@ -99,4 +111,29 @@ List deleteTemplate() { .map(template -> client.deletePromptTemplate(template.getId())) .toList(); } + + @GetMapping("/promptRegistryToSpringAi") + Generation promptRegistryToSpringAi() { + val openAiClient = new OpenAiChatModel(OpenAiClient.forModel(OpenAiModel.GPT_4O_MINI)); + val repository = new InMemoryChatMemoryRepository(); + val memory = MessageWindowChatMemory.builder().chatMemoryRepository(repository).build(); + val advisor = MessageChatMemoryAdvisor.builder(memory).build(); + val cl = ChatClient.builder(openAiClient).defaultAdvisors(advisor).build(); + + val promptResponse = + new PromptClient() + .parsePromptTemplateByNameVersion( + "categorization", + "0.0.1", + "java-e2e-test", + "default", + false, + PromptTemplateSubstitutionRequest.create() + .inputParams(Map.of("inputExample", "I love football"))); + + final List messages = SpringAiConverter.promptTemplateToMessages(promptResponse); + val prompt = new Prompt(messages); + val response = cl.prompt(prompt).call().chatResponse(); + return response != null ? response.getResult() : null; + } } diff --git a/sample-code/spring-app/src/main/resources/static/index.html b/sample-code/spring-app/src/main/resources/static/index.html index 61e2a9f68..62a5f9052 100644 --- a/sample-code/spring-app/src/main/resources/static/index.html +++ b/sample-code/spring-app/src/main/resources/static/index.html @@ -1051,6 +1051,18 @@

📚 Prompt Registry

+
  • +
    + +
    + Get a SpringAI list of messages from a Prompt Registry Response. +
    +
    +
  • diff --git a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/PromptRegistryTest.java b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/PromptRegistryTest.java index 1758fc8dc..2854a5a97 100644 --- a/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/PromptRegistryTest.java +++ b/sample-code/spring-app/src/test/java/com/sap/ai/sdk/app/controllers/PromptRegistryTest.java @@ -85,4 +85,12 @@ void history() { assertThat(deletedTemplate).hasSize(1); assertThat(deletedTemplate.get(0).getMessage()).contains("successful"); } + + @Test + void promptRegistryToSpringAi() { + var controller = new PromptRegistryController(); + var ChatResponse = controller.promptRegistryToSpringAi(); + assertThat(ChatResponse).isNotNull(); + assertThat(ChatResponse.getOutput().getText()).contains("Sports"); + } }