Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.microsoft.semantickernel.services.chatcompletion.ChatCompletionService;
import com.microsoft.semantickernel.services.textcompletion.TextGenerationService;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -439,6 +440,7 @@ public KernelFunction<T> build() {
name,
template,
templateFormat,
Collections.emptySet(),
description,
inputVariables,
outputVariable,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import javax.annotation.Nullable;

/**
Expand Down Expand Up @@ -43,6 +45,7 @@ public class PromptTemplateConfig {
@Nullable
private final String template;
private final String templateFormat;
private final Set<PromptTemplateOption> promptTemplateOptions;
@Nullable
private final String description;
private final List<InputVariable> inputVariables;
Expand All @@ -61,6 +64,7 @@ protected PromptTemplateConfig(String template) {
DEFAULT_CONFIG_NAME,
template,
SEMANTIC_KERNEL_TEMPLATE_FORMAT,
Collections.emptySet(),
"",
Collections.emptyList(),
new OutputVariable(String.class.getName(), "out"),
Expand All @@ -70,21 +74,23 @@ protected PromptTemplateConfig(String template) {
/**
* Constructor for a prompt template config
*
* @param schema Schema version
* @param name Name of the template
* @param template Template string
* @param templateFormat Template format
* @param description Description of the template
* @param inputVariables Input variables
* @param outputVariable Output variable
* @param executionSettings Execution settings
* @param schema Schema version
* @param name Name of the template
* @param template Template string
* @param templateFormat Template format
* @param promptTemplateOptions Prompt template options
* @param description Description of the template
* @param inputVariables Input variables
* @param outputVariable Output variable
* @param executionSettings Execution settings
*/
@JsonCreator
public PromptTemplateConfig(
@JsonProperty("schema") int schema,
@Nullable @JsonProperty("name") String name,
@Nullable @JsonProperty("template") String template,
@Nullable @JsonProperty(value = "template_format", defaultValue = SEMANTIC_KERNEL_TEMPLATE_FORMAT) String templateFormat,
@Nullable @JsonProperty(value = "prompt_template_options") Set<PromptTemplateOption> promptTemplateOptions,
@Nullable @JsonProperty("description") String description,
@Nullable @JsonProperty("input_variables") List<InputVariable> inputVariables,
@Nullable @JsonProperty("output_variable") OutputVariable outputVariable,
Expand All @@ -96,6 +102,10 @@ public PromptTemplateConfig(
templateFormat = SEMANTIC_KERNEL_TEMPLATE_FORMAT;
}
this.templateFormat = templateFormat;
if (promptTemplateOptions == null) {
promptTemplateOptions = new HashSet<>();
}
this.promptTemplateOptions = promptTemplateOptions;
this.description = description;
if (inputVariables == null) {
this.inputVariables = new ArrayList<>();
Expand Down Expand Up @@ -127,6 +137,7 @@ protected PromptTemplateConfig(
@Nullable String name,
@Nullable String template,
@Nullable String templateFormat,
@Nullable Set<PromptTemplateOption> promptTemplateOptions,
@Nullable String description,
@Nullable List<InputVariable> inputVariables,
@Nullable OutputVariable outputVariable,
Expand All @@ -136,6 +147,7 @@ protected PromptTemplateConfig(
name,
template,
templateFormat,
promptTemplateOptions,
description,
inputVariables,
outputVariable,
Expand All @@ -152,6 +164,7 @@ public PromptTemplateConfig(PromptTemplateConfig promptTemplate) {
promptTemplate.name,
promptTemplate.template,
promptTemplate.templateFormat,
promptTemplate.promptTemplateOptions,
promptTemplate.description,
promptTemplate.inputVariables,
promptTemplate.outputVariable,
Expand Down Expand Up @@ -300,6 +313,15 @@ public int getSchema() {
return schema;
}

/**
* Get the prompt template options of the prompt template config.
*
* @return The prompt template options of the prompt template config.
*/
public Set<PromptTemplateOption> getPromptTemplateOptions() {
return Collections.unmodifiableSet(promptTemplateOptions);
}

/**
* Create a builder for a prompt template config which is a clone of the current object.
*
Expand Down Expand Up @@ -358,6 +380,7 @@ public static class Builder {
@Nullable
private String template;
private String templateFormat = SEMANTIC_KERNEL_TEMPLATE_FORMAT;
private final Set<PromptTemplateOption> promptTemplateOptions = new HashSet<>();
@Nullable
private String description = null;
private List<InputVariable> inputVariables = new ArrayList<>();
Expand Down Expand Up @@ -433,6 +456,11 @@ public Builder withTemplateFormat(String templateFormat) {
return this;
}

public Builder addPromptTemplateOption(PromptTemplateOption option) {
promptTemplateOptions.add(option);
return this;
}

/**
* Set the inputVariables of the prompt template config.
*
Expand Down Expand Up @@ -477,6 +505,7 @@ public PromptTemplateConfig build() {
name,
template,
templateFormat,
promptTemplateOptions,
description,
inputVariables,
outputVariable,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.semanticfunctions;

public enum PromptTemplateOption {
/**
* Allow methods on objects provided as arguments to an invocation, to be invoked when rendering
* a template and its return value used. Typically, this would be used to call a getter on an
* object i.e. {@code {{#each users}} {{userName}} {{/each}} } on a handlebars template will
* call the method {@code getUserName()} on each object in {@code users}.
* <p>
* WARNING: If this option is used, ensure that your template is trusted, and that objects added
* as arguments to an invocation, do not contain methods that are unsafe to be invoked when
* rendering a template.
*/
ALLOW_CONTEXT_VARIABLE_METHOD_CALLS_UNSAFE
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.github.jknack.handlebars.Helper;
import com.github.jknack.handlebars.Options;
import com.github.jknack.handlebars.ValueResolver;
import com.github.jknack.handlebars.context.JavaBeanValueResolver;
import com.microsoft.semantickernel.Kernel;
import com.microsoft.semantickernel.contextvariables.ContextVariable;
import com.microsoft.semantickernel.contextvariables.ContextVariableType;
Expand All @@ -21,6 +22,7 @@
import com.microsoft.semantickernel.semanticfunctions.KernelFunctionArguments;
import com.microsoft.semantickernel.semanticfunctions.PromptTemplate;
import com.microsoft.semantickernel.semanticfunctions.PromptTemplateConfig;
import com.microsoft.semantickernel.semanticfunctions.PromptTemplateOption;
import com.microsoft.semantickernel.services.chatcompletion.ChatMessageContent;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.io.IOException;
Expand All @@ -35,7 +37,6 @@
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.text.StringEscapeUtils;
import reactor.core.publisher.Mono;

/**
Expand Down Expand Up @@ -168,7 +169,7 @@ public Set<Entry<String, Object>> propertySet(Object context) {
}
}

private static class HandleBarsPromptTemplateHandler {
private class HandleBarsPromptTemplateHandler {

private final String template;
private final Handlebars handlebars;
Expand All @@ -181,7 +182,7 @@ public HandleBarsPromptTemplateHandler(
this.template = template;
this.handlebars = new Handlebars();
this.handlebars
.registerHelper("message", HandleBarsPromptTemplateHandler::handleMessage)
.registerHelper("message", this::handleMessage)
.registerHelper("each", handleEach(context))
.with(EscapingStrategy.XML);

Expand All @@ -190,7 +191,7 @@ public HandleBarsPromptTemplateHandler(
// TODO: 1.0 Add more helpers
}

private static Helper<Object> handleEach(InvocationContext invocationContext) {
private Helper<Object> handleEach(InvocationContext invocationContext) {
return (context, options) -> {
if (context instanceof ContextVariable) {
return ((ContextVariable<?>) context)
Expand Down Expand Up @@ -227,7 +228,7 @@ private static Helper<Object> handleEach(InvocationContext invocationContext) {
}

@Nullable
private static CharSequence handleMessage(Object context, Options options)
private CharSequence handleMessage(Object context, Options options)
throws IOException {
String role = options.hash("role");
String content = (String) options.fn(context);
Expand Down Expand Up @@ -258,7 +259,10 @@ public Mono<String> render(KernelFunctionArguments variables) {
resolvers.add(new MessageResolver());
resolvers.add(new ContextVariableResolver());

// resolvers.addAll(ValueResolver.defaultValueResolvers());
if (promptTemplate.getPromptTemplateOptions()
.contains(PromptTemplateOption.ALLOW_CONTEXT_VARIABLE_METHOD_CALLS_UNSAFE)) {
resolvers.add(JavaBeanValueResolver.INSTANCE);
}

Context context = Context
.newBuilder(variables)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import com.microsoft.semantickernel.orchestration.PromptExecutionSettings;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import org.junit.jupiter.api.Test;
Expand All @@ -28,6 +29,7 @@ void testInstanceMadeWithBuilderEqualsInstanceMadeWithConstructor() {
name,
template,
"semantic-kernel",
Collections.emptySet(),
description,
inputVariables,
outputVariable,
Expand Down