From caff8b665d098ee5e74973c97bf22559c391f3e0 Mon Sep 17 00:00:00 2001 From: John Oliver <1615532+johnoliver@users.noreply.github.com> Date: Fri, 26 Jul 2024 12:07:48 +0100 Subject: [PATCH] Fix some global kernel hooks not being executed --- .../chatcompletion/OpenAIChatCompletion.java | 33 ++++++++++++------- .../samples/demos/lights/App.java | 22 +++++++++++++ .../memory/InMemory_DataStorage.java | 3 +- .../semantickernel/hooks/KernelHooks.java | 27 ++++++++++++++- .../KernelFunctionFromMethod.java | 7 ++-- .../KernelFunctionFromPrompt.java | 7 ++-- 6 files changed, 78 insertions(+), 21 deletions(-) diff --git a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java index 6bdb4f1c..ce988126 100644 --- a/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java +++ b/aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java @@ -311,6 +311,7 @@ private Mono internalChatMessageContentsAsync( ChatCompletionsOptions options = executeHook( invocationContext, + kernel, new PreChatCompletionEvent( getCompletionsOptions( this, @@ -349,7 +350,7 @@ private Mono internalChatMessageContentsAsync( .collect(Collectors.toList()); // execute post chat completion hook - executeHook(invocationContext, new PostChatCompletionEvent(completions)); + executeHook(invocationContext, kernel, new PostChatCompletionEvent(completions)); // Just return the result: // If we don't want to attempt to invoke any functions @@ -517,11 +518,12 @@ private Mono> invokeFunctionTool( pluginName, openAIFunctionToolCall.getFunctionName()); - PreToolCallEvent hookResult = executeHook(invocationContext, new PreToolCallEvent( - openAIFunctionToolCall.getFunctionName(), - openAIFunctionToolCall.getArguments(), - function, - contextVariableTypes)); + PreToolCallEvent hookResult = executeHook(invocationContext, kernel, + new PreToolCallEvent( + openAIFunctionToolCall.getFunctionName(), + openAIFunctionToolCall.getArguments(), + function, + contextVariableTypes)); function = hookResult.getFunction(); KernelFunctionArguments arguments = hookResult.getArguments(); @@ -537,12 +539,21 @@ private Mono> invokeFunctionTool( private static T executeHook( @Nullable InvocationContext invocationContext, + @Nullable Kernel kernel, T event) { - KernelHooks kernelHooks = invocationContext != null - && invocationContext.getKernelHooks() != null - ? invocationContext.getKernelHooks() - : new KernelHooks(); - + KernelHooks kernelHooks = null; + if (kernel == null) { + if (invocationContext != null) { + kernelHooks = invocationContext.getKernelHooks(); + } + } else { + kernelHooks = KernelHooks.merge( + kernel.getGlobalKernelHooks(), + invocationContext != null ? invocationContext.getKernelHooks() : null); + } + if (kernelHooks == null) { + return event; + } return kernelHooks.executeHooks(event); } diff --git a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/demos/lights/App.java b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/demos/lights/App.java index 08fb2697..08a8f3b5 100644 --- a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/demos/lights/App.java +++ b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/demos/lights/App.java @@ -10,6 +10,7 @@ import com.microsoft.semantickernel.aiservices.openai.chatcompletion.OpenAIChatCompletion; import com.microsoft.semantickernel.contextvariables.ContextVariableTypeConverter; import com.microsoft.semantickernel.contextvariables.ContextVariableTypes; +import com.microsoft.semantickernel.hooks.KernelHooks; import com.microsoft.semantickernel.orchestration.InvocationContext; import com.microsoft.semantickernel.orchestration.InvocationContext.Builder; import com.microsoft.semantickernel.orchestration.InvocationReturnMode; @@ -73,6 +74,27 @@ public static void main(String[] args) throws Exception { .toPromptString(new Gson()::toJson) .build()); + KernelHooks hook = new KernelHooks(); + + hook.addPreToolCallHook((context) -> { + System.out.println("Pre-tool call hook"); + return context; + }); + + hook.addPreChatCompletionHook( + (context) -> { + System.out.println("Pre-chat completion hook"); + return context; + }); + + hook.addPostChatCompletionHook( + (context) -> { + System.out.println("Post-chat completion hook"); + return context; + }); + + kernel.getGlobalKernelHooks().addHooks(hook); + // Enable planning InvocationContext invocationContext = new Builder() .withReturnMode(InvocationReturnMode.LAST_MESSAGE_ONLY) diff --git a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/InMemory_DataStorage.java b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/InMemory_DataStorage.java index b189081b..d92c1e06 100644 --- a/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/InMemory_DataStorage.java +++ b/samples/semantickernel-concepts/semantickernel-syntax-examples/src/main/java/com/microsoft/semantickernel/samples/syntaxexamples/memory/InMemory_DataStorage.java @@ -105,7 +105,8 @@ public static void main(String[] args) { inMemoryDataStorage(embeddingGeneration); } - public static void inMemoryDataStorage(OpenAITextEmbeddingGenerationService embeddingGeneration) { + public static void inMemoryDataStorage( + OpenAITextEmbeddingGenerationService embeddingGeneration) { // Create a new Volatile vector store var volatileVectorStore = new VolatileVectorStore(); diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/hooks/KernelHooks.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/hooks/KernelHooks.java index f690fe26..959dda5a 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/hooks/KernelHooks.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/hooks/KernelHooks.java @@ -66,7 +66,7 @@ public UnmodifiableKernelHooks unmodifiableClone() { * * @return an unmodifiable map of the hooks */ - private Map> getHooks() { + protected Map> getHooks() { return Collections.unmodifiableMap(hooks); } @@ -224,6 +224,31 @@ public boolean isEmpty() { return hooks.isEmpty(); } + /** + * Builds the list of hooks to be invoked for the given context, by merging the hooks in this + * collection with the hooks in the context. Duplicate hooks in b will override hooks in a. + * + * @param a hooks to merge + * @param b hooks to merge + * @return the list of hooks to be invoked + */ + public static KernelHooks merge(@Nullable KernelHooks a, @Nullable KernelHooks b) { + KernelHooks hooks = a; + if (hooks == null) { + hooks = new KernelHooks(); + } + + if (b == null) { + return hooks; + } else if (hooks.isEmpty()) { + return b; + } else { + HashMap> merged = new HashMap<>(hooks.getHooks()); + merged.putAll(b.getHooks()); + return new KernelHooks(merged); + } + } + /** * A wrapper for KernelHooks that disables mutating methods. */ diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelFunctionFromMethod.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelFunctionFromMethod.java index 6d9d1166..9a7b09dc 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelFunctionFromMethod.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelFunctionFromMethod.java @@ -157,10 +157,9 @@ public static ImplementationFunc getFunction(Method method, Object instan } // kernelHooks must be effectively final for lambda - KernelHooks kernelHooks = context.getKernelHooks() != null - ? context.getKernelHooks() - : kernel.getGlobalKernelHooks(); - assert kernelHooks != null : "getGlobalKernelHooks() should never return null!"; + KernelHooks kernelHooks = KernelHooks.merge( + kernel.getGlobalKernelHooks(), + context.getKernelHooks()); FunctionInvokingEvent updatedState = kernelHooks .executeHooks( diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelFunctionFromPrompt.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelFunctionFromPrompt.java index 1d754d65..e3be61e8 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelFunctionFromPrompt.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelFunctionFromPrompt.java @@ -102,10 +102,9 @@ private Flux> invokeInternalAsync( : InvocationContext.builder().build(); // must be effectively final for lambda - KernelHooks kernelHooks = context.getKernelHooks() != null - ? context.getKernelHooks() - : kernel.getGlobalKernelHooks(); - assert kernelHooks != null : "getGlobalKernelHooks() should never return null"; + KernelHooks kernelHooks = KernelHooks.merge( + kernel.getGlobalKernelHooks(), + context.getKernelHooks()); PromptRenderingEvent preRenderingHookResult = kernelHooks .executeHooks(new PromptRenderingEvent(this, argumentsIn));