Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ private static class ChatMessages {

private final List<ChatRequestMessage> newMessages;
private final List<ChatRequestMessage> allMessages;
private final List<OpenAIChatMessageContent> newChatMessageContent;
private final List<OpenAIChatMessageContent<?>> newChatMessageContent;

public ChatMessages(List<ChatRequestMessage> allMessages) {
this.allMessages = Collections.unmodifiableList(allMessages);
Expand All @@ -195,7 +195,7 @@ public ChatMessages(List<ChatRequestMessage> allMessages) {
private ChatMessages(
List<ChatRequestMessage> allMessages,
List<ChatRequestMessage> newMessages,
List<OpenAIChatMessageContent> newChatMessageContent) {
List<OpenAIChatMessageContent<?>> newChatMessageContent) {
this.allMessages = Collections.unmodifiableList(allMessages);
this.newMessages = Collections.unmodifiableList(newMessages);
this.newChatMessageContent = Collections.unmodifiableList(newChatMessageContent);
Expand All @@ -219,8 +219,8 @@ public ChatMessages add(ChatRequestMessage requestMessage) {
}

@CheckReturnValue
public ChatMessages addChatMessage(List<OpenAIChatMessageContent> chatMessageContent) {
ArrayList<OpenAIChatMessageContent> tmpChatMessageContent = new ArrayList<>(
public ChatMessages addChatMessage(List<OpenAIChatMessageContent<?>> chatMessageContent) {
ArrayList<OpenAIChatMessageContent<?>> tmpChatMessageContent = new ArrayList<>(
newChatMessageContent);
tmpChatMessageContent.addAll(chatMessageContent);

Expand Down Expand Up @@ -357,19 +357,16 @@ private Mono<ChatMessages> internalChatMessageContentsAsync(
// If we don't want to attempt to invoke any functions
// Or if we are auto-invoking, but we somehow end up with other than 1 choice even though only 1 was requested
if (autoInvokeAttempts == 0 || responseMessages.size() != 1) {
return getChatMessageContentsAsync(completions)
.flatMap(m -> {
return Mono.just(messages.addChatMessage(m));
});
List<OpenAIChatMessageContent<?>> chatMessageContents = getChatMessageContentsAsync(completions);
return Mono.just(messages.addChatMessage(chatMessageContents));
}
// Or if there are no tool calls to be done
ChatResponseMessage response = responseMessages.get(0);
List<ChatCompletionsToolCall> toolCalls = response.getToolCalls();
if (toolCalls == null || toolCalls.isEmpty()) {
return getChatMessageContentsAsync(completions)
.flatMap(m -> {
return Mono.just(messages.addChatMessage(m));
});
List<OpenAIChatMessageContent<?>> chatMessageContents = getChatMessageContentsAsync(
completions);
return Mono.just(messages.addChatMessage(chatMessageContents));
}

ChatRequestAssistantMessage requestMessage = new ChatRequestAssistantMessage(
Expand Down Expand Up @@ -592,7 +589,7 @@ private OpenAIFunctionToolCall extractOpenAIFunctionToolCall(
arguments);
}

private Mono<List<OpenAIChatMessageContent>> getChatMessageContentsAsync(
private List<OpenAIChatMessageContent<?>> getChatMessageContentsAsync(
ChatCompletions completions) {
FunctionResultMetadata<CompletionsUsage> completionMetadata = FunctionResultMetadata.build(
completions.getId(),
Expand All @@ -606,22 +603,28 @@ private Mono<List<OpenAIChatMessageContent>> getChatMessageContentsAsync(
.filter(Objects::nonNull)
.collect(Collectors.toList());

return Flux.fromIterable(responseMessages)
.flatMap(response -> {
List<OpenAIChatMessageContent<?>> chatMessageContent =
responseMessages
.stream()
.map(response -> {
try {
return Mono.just(new OpenAIChatMessageContent(
return new OpenAIChatMessageContent<>(
AuthorRole.ASSISTANT,
response.getContent(),
this.getModelId(),
null,
null,
completionMetadata,
formOpenAiToolCalls(response)));
} catch (Exception e) {
return Mono.error(e);
formOpenAiToolCalls(response));
} catch (SKCheckedException e) {
LOGGER.warn("Failed to form chat message content", e);
return null;
}
})
.collectList();
.filter(Objects::nonNull)
.collect(Collectors.toList());

return chatMessageContent;
}

private List<ChatMessageContent<?>> toOpenAIChatMessageContent(
Expand Down Expand Up @@ -931,7 +934,7 @@ private static boolean hasToolCallBeenExecuted(List<ChatRequestMessage> chatRequ
}

private static List<ChatRequestMessage> getChatRequestMessages(
List<? extends ChatMessageContent> messages) {
List<? extends ChatMessageContent<?>> messages) {
if (messages == null || messages.isEmpty()) {
return new ArrayList<>();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public OpenAIChatMessageContent(
@Nullable String modelId,
@Nullable T innerContent,
@Nullable Charset encoding,
@Nullable FunctionResultMetadata metadata,
@Nullable FunctionResultMetadata<?> metadata,
@Nullable List<OpenAIFunctionToolCall> toolCall) {
super(authorRole, content, modelId, innerContent, encoding, metadata);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import com.microsoft.semantickernel.services.chatcompletion.message.ChatMessageTextContent;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Spliterator;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.function.Consumer;
import javax.annotation.Nullable;

Expand All @@ -18,7 +20,7 @@
*/
public class ChatHistory implements Iterable<ChatMessageContent<?>> {

private final List<ChatMessageContent<?>> chatMessageContents;
private final Collection<ChatMessageContent<?>> chatMessageContents;

/**
* The default constructor
Expand All @@ -33,7 +35,7 @@ public ChatHistory() {
* @param instructions The instructions to add to the chat history
*/
public ChatHistory(@Nullable String instructions) {
this.chatMessageContents = new ArrayList<>();
this.chatMessageContents = new ConcurrentLinkedQueue<>();
if (instructions != null) {
this.chatMessageContents.add(
ChatMessageTextContent.systemMessage(instructions));
Expand All @@ -45,8 +47,8 @@ public ChatHistory(@Nullable String instructions) {
*
* @param chatMessageContents The chat message contents to add to the chat history
*/
public ChatHistory(List<? extends ChatMessageContent> chatMessageContents) {
this.chatMessageContents = new ArrayList(chatMessageContents);
public ChatHistory(List<? extends ChatMessageContent<?>> chatMessageContents) {
this.chatMessageContents = new ConcurrentLinkedQueue<>(chatMessageContents);
}

/**
Expand All @@ -55,7 +57,7 @@ public ChatHistory(List<? extends ChatMessageContent> chatMessageContents) {
* @return List of messages in the chat
*/
public List<ChatMessageContent<?>> getMessages() {
return Collections.unmodifiableList(chatMessageContents);
return Collections.unmodifiableList(new ArrayList<>(chatMessageContents));
}

/**
Expand All @@ -67,7 +69,7 @@ public Optional<ChatMessageContent<?>> getLastMessage() {
if (chatMessageContents.isEmpty()) {
return Optional.empty();
}
return Optional.of(chatMessageContents.get(chatMessageContents.size() - 1));
return Optional.of(((ConcurrentLinkedQueue<ChatMessageContent<?>>)chatMessageContents).peek());
}

/**
Expand Down Expand Up @@ -114,7 +116,7 @@ public Spliterator<ChatMessageContent<?>> spliterator() {
* @param metadata The metadata of the message
*/
public ChatHistory addMessage(AuthorRole authorRole, String content, Charset encoding,
FunctionResultMetadata metadata) {
FunctionResultMetadata<?> metadata) {
chatMessageContents.add(
ChatMessageTextContent.builder()
.withAuthorRole(authorRole)
Expand Down