diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelFunctionFromPrompt.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelFunctionFromPrompt.java index e3be61e8..5d124d84 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelFunctionFromPrompt.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/KernelFunctionFromPrompt.java @@ -20,6 +20,7 @@ import com.microsoft.semantickernel.services.chatcompletion.ChatCompletionService; import com.microsoft.semantickernel.services.textcompletion.TextGenerationService; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -439,6 +440,7 @@ public KernelFunction build() { name, template, templateFormat, + Collections.emptySet(), description, inputVariables, outputVariable, diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/PromptTemplateConfig.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/PromptTemplateConfig.java index 02adb289..5f69ad46 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/PromptTemplateConfig.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/PromptTemplateConfig.java @@ -11,9 +11,11 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import javax.annotation.Nullable; /** @@ -43,6 +45,7 @@ public class PromptTemplateConfig { @Nullable private final String template; private final String templateFormat; + private final Set promptTemplateOptions; @Nullable private final String description; private final List inputVariables; @@ -61,6 +64,7 @@ protected PromptTemplateConfig(String template) { DEFAULT_CONFIG_NAME, template, SEMANTIC_KERNEL_TEMPLATE_FORMAT, + Collections.emptySet(), "", Collections.emptyList(), new OutputVariable(String.class.getName(), "out"), @@ -70,14 +74,15 @@ protected PromptTemplateConfig(String template) { /** * Constructor for a prompt template config * - * @param schema Schema version - * @param name Name of the template - * @param template Template string - * @param templateFormat Template format - * @param description Description of the template - * @param inputVariables Input variables - * @param outputVariable Output variable - * @param executionSettings Execution settings + * @param schema Schema version + * @param name Name of the template + * @param template Template string + * @param templateFormat Template format + * @param promptTemplateOptions Prompt template options + * @param description Description of the template + * @param inputVariables Input variables + * @param outputVariable Output variable + * @param executionSettings Execution settings */ @JsonCreator public PromptTemplateConfig( @@ -85,6 +90,7 @@ public PromptTemplateConfig( @Nullable @JsonProperty("name") String name, @Nullable @JsonProperty("template") String template, @Nullable @JsonProperty(value = "template_format", defaultValue = SEMANTIC_KERNEL_TEMPLATE_FORMAT) String templateFormat, + @Nullable @JsonProperty(value = "prompt_template_options") Set promptTemplateOptions, @Nullable @JsonProperty("description") String description, @Nullable @JsonProperty("input_variables") List inputVariables, @Nullable @JsonProperty("output_variable") OutputVariable outputVariable, @@ -96,6 +102,10 @@ public PromptTemplateConfig( templateFormat = SEMANTIC_KERNEL_TEMPLATE_FORMAT; } this.templateFormat = templateFormat; + if (promptTemplateOptions == null) { + promptTemplateOptions = new HashSet<>(); + } + this.promptTemplateOptions = promptTemplateOptions; this.description = description; if (inputVariables == null) { this.inputVariables = new ArrayList<>(); @@ -127,6 +137,7 @@ protected PromptTemplateConfig( @Nullable String name, @Nullable String template, @Nullable String templateFormat, + @Nullable Set promptTemplateOptions, @Nullable String description, @Nullable List inputVariables, @Nullable OutputVariable outputVariable, @@ -136,6 +147,7 @@ protected PromptTemplateConfig( name, template, templateFormat, + promptTemplateOptions, description, inputVariables, outputVariable, @@ -152,6 +164,7 @@ public PromptTemplateConfig(PromptTemplateConfig promptTemplate) { promptTemplate.name, promptTemplate.template, promptTemplate.templateFormat, + promptTemplate.promptTemplateOptions, promptTemplate.description, promptTemplate.inputVariables, promptTemplate.outputVariable, @@ -300,6 +313,15 @@ public int getSchema() { return schema; } + /** + * Get the prompt template options of the prompt template config. + * + * @return The prompt template options of the prompt template config. + */ + public Set getPromptTemplateOptions() { + return Collections.unmodifiableSet(promptTemplateOptions); + } + /** * Create a builder for a prompt template config which is a clone of the current object. * @@ -358,6 +380,7 @@ public static class Builder { @Nullable private String template; private String templateFormat = SEMANTIC_KERNEL_TEMPLATE_FORMAT; + private final Set promptTemplateOptions = new HashSet<>(); @Nullable private String description = null; private List inputVariables = new ArrayList<>(); @@ -433,6 +456,11 @@ public Builder withTemplateFormat(String templateFormat) { return this; } + public Builder addPromptTemplateOption(PromptTemplateOption option) { + promptTemplateOptions.add(option); + return this; + } + /** * Set the inputVariables of the prompt template config. * @@ -477,6 +505,7 @@ public PromptTemplateConfig build() { name, template, templateFormat, + promptTemplateOptions, description, inputVariables, outputVariable, diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/PromptTemplateOption.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/PromptTemplateOption.java new file mode 100644 index 00000000..5d244613 --- /dev/null +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/semanticfunctions/PromptTemplateOption.java @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft. All rights reserved. +package com.microsoft.semantickernel.semanticfunctions; + +public enum PromptTemplateOption { + /** + * Allow methods on objects provided as arguments to an invocation, to be invoked when rendering + * a template and its return value used. Typically, this would be used to call a getter on an + * object i.e. {@code {{#each users}} {{userName}} {{/each}} } on a handlebars template will + * call the method {@code getUserName()} on each object in {@code users}. + *

+ * WARNING: If this option is used, ensure that your template is trusted, and that objects added + * as arguments to an invocation, do not contain methods that are unsafe to be invoked when + * rendering a template. + */ + ALLOW_CONTEXT_VARIABLE_METHOD_CALLS_UNSAFE +} \ No newline at end of file diff --git a/semantickernel-api/src/main/java/com/microsoft/semantickernel/templateengine/handlebars/HandlebarsPromptTemplate.java b/semantickernel-api/src/main/java/com/microsoft/semantickernel/templateengine/handlebars/HandlebarsPromptTemplate.java index 2e7c260e..ec701ce3 100644 --- a/semantickernel-api/src/main/java/com/microsoft/semantickernel/templateengine/handlebars/HandlebarsPromptTemplate.java +++ b/semantickernel-api/src/main/java/com/microsoft/semantickernel/templateengine/handlebars/HandlebarsPromptTemplate.java @@ -9,6 +9,7 @@ import com.github.jknack.handlebars.Helper; import com.github.jknack.handlebars.Options; import com.github.jknack.handlebars.ValueResolver; +import com.github.jknack.handlebars.context.JavaBeanValueResolver; import com.microsoft.semantickernel.Kernel; import com.microsoft.semantickernel.contextvariables.ContextVariable; import com.microsoft.semantickernel.contextvariables.ContextVariableType; @@ -21,6 +22,7 @@ import com.microsoft.semantickernel.semanticfunctions.KernelFunctionArguments; import com.microsoft.semantickernel.semanticfunctions.PromptTemplate; import com.microsoft.semantickernel.semanticfunctions.PromptTemplateConfig; +import com.microsoft.semantickernel.semanticfunctions.PromptTemplateOption; import com.microsoft.semantickernel.services.chatcompletion.ChatMessageContent; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import java.io.IOException; @@ -35,7 +37,6 @@ import java.util.stream.Collectors; import javax.annotation.Nonnull; import javax.annotation.Nullable; -import org.apache.commons.text.StringEscapeUtils; import reactor.core.publisher.Mono; /** @@ -168,7 +169,7 @@ public Set> propertySet(Object context) { } } - private static class HandleBarsPromptTemplateHandler { + private class HandleBarsPromptTemplateHandler { private final String template; private final Handlebars handlebars; @@ -181,7 +182,7 @@ public HandleBarsPromptTemplateHandler( this.template = template; this.handlebars = new Handlebars(); this.handlebars - .registerHelper("message", HandleBarsPromptTemplateHandler::handleMessage) + .registerHelper("message", this::handleMessage) .registerHelper("each", handleEach(context)) .with(EscapingStrategy.XML); @@ -190,7 +191,7 @@ public HandleBarsPromptTemplateHandler( // TODO: 1.0 Add more helpers } - private static Helper handleEach(InvocationContext invocationContext) { + private Helper handleEach(InvocationContext invocationContext) { return (context, options) -> { if (context instanceof ContextVariable) { return ((ContextVariable) context) @@ -227,7 +228,7 @@ private static Helper handleEach(InvocationContext invocationContext) { } @Nullable - private static CharSequence handleMessage(Object context, Options options) + private CharSequence handleMessage(Object context, Options options) throws IOException { String role = options.hash("role"); String content = (String) options.fn(context); @@ -258,7 +259,10 @@ public Mono render(KernelFunctionArguments variables) { resolvers.add(new MessageResolver()); resolvers.add(new ContextVariableResolver()); - // resolvers.addAll(ValueResolver.defaultValueResolvers()); + if (promptTemplate.getPromptTemplateOptions() + .contains(PromptTemplateOption.ALLOW_CONTEXT_VARIABLE_METHOD_CALLS_UNSAFE)) { + resolvers.add(JavaBeanValueResolver.INSTANCE); + } Context context = Context .newBuilder(variables) diff --git a/semantickernel-api/src/test/java/com/microsoft/semantickernel/semanticfunctions/PromptTemplateConfigTest.java b/semantickernel-api/src/test/java/com/microsoft/semantickernel/semanticfunctions/PromptTemplateConfigTest.java index d4141dcf..ea34400a 100644 --- a/semantickernel-api/src/test/java/com/microsoft/semantickernel/semanticfunctions/PromptTemplateConfigTest.java +++ b/semantickernel-api/src/test/java/com/microsoft/semantickernel/semanticfunctions/PromptTemplateConfigTest.java @@ -5,6 +5,7 @@ import com.microsoft.semantickernel.orchestration.PromptExecutionSettings; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import org.junit.jupiter.api.Test; @@ -28,6 +29,7 @@ void testInstanceMadeWithBuilderEqualsInstanceMadeWithConstructor() { name, template, "semantic-kernel", + Collections.emptySet(), description, inputVariables, outputVariable,