diff --git a/dotnet/samples/Concepts/Agents/ChatCompletion_FunctionTermination.cs b/dotnet/samples/Concepts/Agents/ChatCompletion_FunctionTermination.cs new file mode 100644 index 000000000000..f344dae432b9 --- /dev/null +++ b/dotnet/samples/Concepts/Agents/ChatCompletion_FunctionTermination.cs @@ -0,0 +1,158 @@ +// Copyright (c) Microsoft. All rights reserved. +using System.ComponentModel; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Agents; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.OpenAI; + +namespace Agents; + +/// +/// Demonstrate usage of for both direction invocation +/// of and via . +/// +public class ChatCompletion_FunctionTermination(ITestOutputHelper output) : BaseTest(output) +{ + [Fact] + public async Task UseAutoFunctionInvocationFilterWithAgentInvocationAsync() + { + // Define the agent + ChatCompletionAgent agent = + new() + { + Instructions = "Answer questions about the menu.", + Kernel = CreateKernelWithChatCompletion(), + ExecutionSettings = new OpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }, + }; + + KernelPlugin plugin = KernelPluginFactory.CreateFromType(); + agent.Kernel.Plugins.Add(plugin); + + /// Create the chat history to capture the agent interaction. + ChatHistory chat = []; + + // Respond to user input, invoking functions where appropriate. + await InvokeAgentAsync("Hello"); + await InvokeAgentAsync("What is the special soup?"); + await InvokeAgentAsync("What is the special drink?"); + await InvokeAgentAsync("Thank you"); + + // Display the chat history. + Console.WriteLine("================================"); + Console.WriteLine("CHAT HISTORY"); + Console.WriteLine("================================"); + foreach (ChatMessageContent message in chat) + { + this.WriteContent(message); + } + + // Local function to invoke agent and display the conversation messages. + async Task InvokeAgentAsync(string input) + { + ChatMessageContent userContent = new(AuthorRole.User, input); + chat.Add(userContent); + this.WriteContent(userContent); + + await foreach (ChatMessageContent content in agent.InvokeAsync(chat)) + { + // Do not add a message implicitly added to the history. + if (!content.Items.Any(i => i is FunctionCallContent || i is FunctionResultContent)) + { + chat.Add(content); + } + + this.WriteContent(content); + } + } + } + + [Fact] + public async Task UseAutoFunctionInvocationFilterWithAgentChatAsync() + { + // Define the agent + ChatCompletionAgent agent = + new() + { + Instructions = "Answer questions about the menu.", + Kernel = CreateKernelWithChatCompletion(), + ExecutionSettings = new OpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }, + }; + + KernelPlugin plugin = KernelPluginFactory.CreateFromType(); + agent.Kernel.Plugins.Add(plugin); + + // Create a chat for agent interaction. + AgentGroupChat chat = new(); + + // Respond to user input, invoking functions where appropriate. + await InvokeAgentAsync("Hello"); + await InvokeAgentAsync("What is the special soup?"); + await InvokeAgentAsync("What is the special drink?"); + await InvokeAgentAsync("Thank you"); + + // Display the chat history. + Console.WriteLine("================================"); + Console.WriteLine("CHAT HISTORY"); + Console.WriteLine("================================"); + ChatMessageContent[] history = await chat.GetChatMessagesAsync().ToArrayAsync(); + for (int index = history.Length; index > 0; --index) + { + this.WriteContent(history[index - 1]); + } + + // Local function to invoke agent and display the conversation messages. + async Task InvokeAgentAsync(string input) + { + ChatMessageContent userContent = new(AuthorRole.User, input); + chat.AddChatMessage(userContent); + this.WriteContent(userContent); + + await foreach (ChatMessageContent content in chat.InvokeAsync(agent)) + { + this.WriteContent(content); + } + } + } + + private void WriteContent(ChatMessageContent content) + { + Console.WriteLine($"[{content.Items.LastOrDefault()?.GetType().Name ?? "(empty)"}] {content.Role} : '{content.Content}'"); + } + + private sealed class MenuPlugin + { + [KernelFunction, Description("Provides a list of specials from the menu.")] + [System.Diagnostics.CodeAnalysis.SuppressMessage("Design", "CA1024:Use properties where appropriate", Justification = "Too smart")] + public string GetSpecials() + { + return @" +Special Soup: Clam Chowder +Special Salad: Cobb Salad +Special Drink: Chai Tea +"; + } + + [KernelFunction, Description("Provides the price of the requested menu item.")] + public string GetItemPrice( + [Description("The name of the menu item.")] + string menuItem) + { + return "$9.99"; + } + } + + private sealed class AutoInvocationFilter(bool terminate = true) : IAutoFunctionInvocationFilter + { + public async Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func next) + { + // Execution the function + await next(context); + + // Signal termination if the function is from the MenuPlugin + if (context.Function.PluginName == nameof(MenuPlugin)) + { + context.Terminate = terminate; + } + } + } +} diff --git a/dotnet/samples/Concepts/Agents/ChatCompletion_Streaming.cs b/dotnet/samples/Concepts/Agents/ChatCompletion_Streaming.cs index ee6fb9b38f2a..258e12166a6b 100644 --- a/dotnet/samples/Concepts/Agents/ChatCompletion_Streaming.cs +++ b/dotnet/samples/Concepts/Agents/ChatCompletion_Streaming.cs @@ -1,8 +1,10 @@ // Copyright (c) Microsoft. All rights reserved. +using System.ComponentModel; using System.Text; using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Agents; using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.OpenAI; namespace Agents; @@ -30,40 +32,88 @@ public async Task UseStreamingChatCompletionAgentAsync() ChatHistory chat = []; // Respond to user input - await InvokeAgentAsync("Fortune favors the bold."); - await InvokeAgentAsync("I came, I saw, I conquered."); - await InvokeAgentAsync("Practice makes perfect."); + await InvokeAgentAsync(agent, chat, "Fortune favors the bold."); + await InvokeAgentAsync(agent, chat, "I came, I saw, I conquered."); + await InvokeAgentAsync(agent, chat, "Practice makes perfect."); + } - // Local function to invoke agent and display the conversation messages. - async Task InvokeAgentAsync(string input) - { - chat.Add(new ChatMessageContent(AuthorRole.User, input)); + [Fact] + public async Task UseStreamingChatCompletionAgentWithPluginAsync() + { + const string MenuInstructions = "Answer questions about the menu."; + + // Define the agent + ChatCompletionAgent agent = + new() + { + Name = "Host", + Instructions = MenuInstructions, + Kernel = this.CreateKernelWithChatCompletion(), + ExecutionSettings = new OpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }, + }; + + // Initialize plugin and add to the agent's Kernel (same as direct Kernel usage). + KernelPlugin plugin = KernelPluginFactory.CreateFromType(); + agent.Kernel.Plugins.Add(plugin); + + ChatHistory chat = []; + + // Respond to user input + await InvokeAgentAsync(agent, chat, "What is the special soup?"); + await InvokeAgentAsync(agent, chat, "What is the special drink?"); + } + + // Local function to invoke agent and display the conversation messages. + private async Task InvokeAgentAsync(ChatCompletionAgent agent, ChatHistory chat, string input) + { + chat.Add(new ChatMessageContent(AuthorRole.User, input)); - Console.WriteLine($"# {AuthorRole.User}: '{input}'"); + Console.WriteLine($"# {AuthorRole.User}: '{input}'"); - StringBuilder builder = new(); - await foreach (StreamingChatMessageContent message in agent.InvokeStreamingAsync(chat)) + StringBuilder builder = new(); + await foreach (StreamingChatMessageContent message in agent.InvokeStreamingAsync(chat)) + { + if (string.IsNullOrEmpty(message.Content)) { - if (string.IsNullOrEmpty(message.Content)) - { - continue; - } - - if (builder.Length == 0) - { - Console.WriteLine($"# {message.Role} - {message.AuthorName ?? "*"}:"); - } - - Console.WriteLine($"\t > streamed: '{message.Content}'"); - builder.Append(message.Content); + continue; } - if (builder.Length > 0) + if (builder.Length == 0) { - // Display full response and capture in chat history - Console.WriteLine($"\t > complete: '{builder}'"); - chat.Add(new ChatMessageContent(AuthorRole.Assistant, builder.ToString()) { AuthorName = agent.Name }); + Console.WriteLine($"# {message.Role} - {message.AuthorName ?? "*"}:"); } + + Console.WriteLine($"\t > streamed: '{message.Content}'"); + builder.Append(message.Content); + } + + if (builder.Length > 0) + { + // Display full response and capture in chat history + Console.WriteLine($"\t > complete: '{builder}'"); + chat.Add(new ChatMessageContent(AuthorRole.Assistant, builder.ToString()) { AuthorName = agent.Name }); + } + } + + public sealed class MenuPlugin + { + [KernelFunction, Description("Provides a list of specials from the menu.")] + [System.Diagnostics.CodeAnalysis.SuppressMessage("Design", "CA1024:Use properties where appropriate", Justification = "Too smart")] + public string GetSpecials() + { + return @" +Special Soup: Clam Chowder +Special Salad: Cobb Salad +Special Drink: Chai Tea +"; + } + + [KernelFunction, Description("Provides the price of the requested menu item.")] + public string GetItemPrice( + [Description("The name of the menu item.")] + string menuItem) + { + return "$9.99"; } } } diff --git a/dotnet/src/Agents/Abstractions/AgentChannel.cs b/dotnet/src/Agents/Abstractions/AgentChannel.cs index ad58deedb017..9788464a2adb 100644 --- a/dotnet/src/Agents/Abstractions/AgentChannel.cs +++ b/dotnet/src/Agents/Abstractions/AgentChannel.cs @@ -23,7 +23,7 @@ public abstract class AgentChannel /// /// The chat history at the point the channel is created. /// The to monitor for cancellation requests. The default is . - protected internal abstract Task ReceiveAsync(IReadOnlyList history, CancellationToken cancellationToken = default); + protected internal abstract Task ReceiveAsync(IEnumerable history, CancellationToken cancellationToken = default); /// /// Perform a discrete incremental interaction between a single and . @@ -31,7 +31,7 @@ public abstract class AgentChannel /// The agent actively interacting with the chat. /// The to monitor for cancellation requests. The default is . /// Asynchronous enumeration of messages. - protected internal abstract IAsyncEnumerable InvokeAsync( + protected internal abstract IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync( Agent agent, CancellationToken cancellationToken = default); @@ -59,12 +59,12 @@ public abstract class AgentChannel : AgentChannel where TAgent : Agent /// The agent actively interacting with the chat. /// The to monitor for cancellation requests. The default is . /// Asynchronous enumeration of messages. - protected internal abstract IAsyncEnumerable InvokeAsync( + protected internal abstract IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync( TAgent agent, CancellationToken cancellationToken = default); /// - protected internal override IAsyncEnumerable InvokeAsync( + protected internal override IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync( Agent agent, CancellationToken cancellationToken = default) { diff --git a/dotnet/src/Agents/Abstractions/AgentChat.cs b/dotnet/src/Agents/Abstractions/AgentChat.cs index 9c834380a8f4..f4654963444e 100644 --- a/dotnet/src/Agents/Abstractions/AgentChat.cs +++ b/dotnet/src/Agents/Abstractions/AgentChat.cs @@ -209,22 +209,21 @@ protected async IAsyncEnumerable InvokeAgentAsync( // Invoke agent & process response List messages = []; - await foreach (ChatMessageContent message in channel.InvokeAsync(agent, cancellationToken).ConfigureAwait(false)) + + await foreach ((bool isVisible, ChatMessageContent message) in channel.InvokeAsync(agent, cancellationToken).ConfigureAwait(false)) { this.Logger.LogAgentChatInvokedAgentMessage(nameof(InvokeAgentAsync), agent.GetType(), agent.Id, message); + messages.Add(message); + // Add to primary history this.History.Add(message); - messages.Add(message); - // Don't expose function-call and function-result messages to caller. - if (message.Items.All(i => i is FunctionCallContent || i is FunctionResultContent)) + if (isVisible) { - continue; + // Yield message to caller + yield return message; } - - // Yield message to caller - yield return message; } // Broadcast message to other channels (in parallel) @@ -233,7 +232,7 @@ protected async IAsyncEnumerable InvokeAgentAsync( this._agentChannels .Where(kvp => kvp.Value != channel) .Select(kvp => new ChannelReference(kvp.Value, kvp.Key)); - this._broadcastQueue.Enqueue(channelRefs, messages.Where(m => m.Role != AuthorRole.Tool).ToArray()); + this._broadcastQueue.Enqueue(channelRefs, messages); this.Logger.LogAgentChatInvokedAgent(nameof(InvokeAgentAsync), agent.GetType(), agent.Id); } @@ -256,6 +255,7 @@ async Task GetOrCreateChannelAsync() if (this.History.Count > 0) { + // Sync channel with existing history await channel.ReceiveAsync(this.History, cancellationToken).ConfigureAwait(false); } diff --git a/dotnet/src/Agents/Abstractions/AggregatorChannel.cs b/dotnet/src/Agents/Abstractions/AggregatorChannel.cs index 60b1cd4367f6..73561a4eba8b 100644 --- a/dotnet/src/Agents/Abstractions/AggregatorChannel.cs +++ b/dotnet/src/Agents/Abstractions/AggregatorChannel.cs @@ -18,7 +18,7 @@ protected internal override IAsyncEnumerable GetHistoryAsync return this._chat.GetChatMessagesAsync(cancellationToken); } - protected internal override async IAsyncEnumerable InvokeAsync(AggregatorAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default) + protected internal override async IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync(AggregatorAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default) { ChatMessageContent? lastMessage = null; @@ -27,7 +27,7 @@ protected internal override async IAsyncEnumerable InvokeAsy // For AggregatorMode.Flat, the entire aggregated chat is merged into the owning chat. if (agent.Mode == AggregatorMode.Flat) { - yield return message; + yield return (IsVisible: true, message); } lastMessage = message; @@ -43,11 +43,11 @@ protected internal override async IAsyncEnumerable InvokeAsy AuthorName = agent.Name }; - yield return message; + yield return (IsVisible: true, message); } } - protected internal override Task ReceiveAsync(IReadOnlyList history, CancellationToken cancellationToken = default) + protected internal override Task ReceiveAsync(IEnumerable history, CancellationToken cancellationToken = default) { // Always receive the initial history from the owning chat. this._chat.AddChatMessages([.. history]); diff --git a/dotnet/src/Agents/Abstractions/ChatHistoryChannel.cs b/dotnet/src/Agents/Abstractions/ChatHistoryChannel.cs index 2bb5616ff959..5dcb6b9b0204 100644 --- a/dotnet/src/Agents/Abstractions/ChatHistoryChannel.cs +++ b/dotnet/src/Agents/Abstractions/ChatHistoryChannel.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; +using System.Linq; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -16,7 +17,7 @@ public class ChatHistoryChannel : AgentChannel private readonly ChatHistory _history; /// - protected internal sealed override async IAsyncEnumerable InvokeAsync( + protected internal sealed override async IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync( Agent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default) { @@ -25,16 +26,55 @@ protected internal sealed override async IAsyncEnumerable In throw new KernelException($"Invalid channel binding for agent: {agent.Id} ({agent.GetType().FullName})"); } - await foreach (ChatMessageContent message in historyHandler.InvokeAsync(this._history, cancellationToken).ConfigureAwait(false)) + // Capture the current message count to evaluate history mutation. + int messageCount = this._history.Count; + HashSet mutatedHistory = []; + + // Utilize a queue as a "read-ahead" cache to evaluate message sequencing (i.e., which message is final). + Queue messageQueue = []; + + ChatMessageContent? yieldMessage = null; + await foreach (ChatMessageContent responseMessage in historyHandler.InvokeAsync(this._history, cancellationToken).ConfigureAwait(false)) { - this._history.Add(message); + // Capture all messages that have been included in the mutated the history. + for (int messageIndex = messageCount; messageIndex < this._history.Count; messageIndex++) + { + ChatMessageContent mutatedMessage = this._history[messageIndex]; + mutatedHistory.Add(mutatedMessage); + messageQueue.Enqueue(mutatedMessage); + } + + // Update the message count pointer to reflect the current history. + messageCount = this._history.Count; - yield return message; + // Avoid duplicating any message included in the mutated history and also returned by the enumeration result. + if (!mutatedHistory.Contains(responseMessage)) + { + this._history.Add(responseMessage); + messageQueue.Enqueue(responseMessage); + } + + // Dequeue the next message to yield. + yieldMessage = messageQueue.Dequeue(); + yield return (IsMessageVisible(yieldMessage), yieldMessage); } + + // Dequeue any remaining messages to yield. + while (messageQueue.Count > 0) + { + yieldMessage = messageQueue.Dequeue(); + + yield return (IsMessageVisible(yieldMessage), yieldMessage); + } + + // Function content not visible, unless result is the final message. + bool IsMessageVisible(ChatMessageContent message) => + (!message.Items.Any(i => i is FunctionCallContent || i is FunctionResultContent) || + messageQueue.Count == 0); } /// - protected internal sealed override Task ReceiveAsync(IReadOnlyList history, CancellationToken cancellationToken) + protected internal sealed override Task ReceiveAsync(IEnumerable history, CancellationToken cancellationToken) { this._history.AddRange(history); diff --git a/dotnet/src/Agents/Core/ChatCompletionAgent.cs b/dotnet/src/Agents/Core/ChatCompletionAgent.cs index 990154b139e4..1e9ea3d3208e 100644 --- a/dotnet/src/Agents/Core/ChatCompletionAgent.cs +++ b/dotnet/src/Agents/Core/ChatCompletionAgent.cs @@ -84,22 +84,22 @@ public override async IAsyncEnumerable InvokeStream this.Logger.LogAgentChatServiceInvokedStreamingAgent(nameof(InvokeAsync), this.Id, chatCompletionService.GetType()); - // Capture mutated messages related function calling / tools - for (int messageIndex = messageCount; messageIndex < chat.Count; messageIndex++) + await foreach (StreamingChatMessageContent message in messages.ConfigureAwait(false)) { - ChatMessageContent message = chat[messageIndex]; - + // TODO: MESSAGE SOURCE - ISSUE #5731 message.AuthorName = this.Name; - history.Add(message); + yield return message; } - await foreach (StreamingChatMessageContent message in messages.ConfigureAwait(false)) + // Capture mutated messages related function calling / tools + for (int messageIndex = messageCount; messageIndex < chat.Count; messageIndex++) { - // TODO: MESSAGE SOURCE - ISSUE #5731 + ChatMessageContent message = chat[messageIndex]; + message.AuthorName = this.Name; - yield return message; + history.Add(message); } } diff --git a/dotnet/src/Agents/OpenAI/AssistantThreadActions.cs b/dotnet/src/Agents/OpenAI/AssistantThreadActions.cs index b1be5bb52765..f768d89a54bb 100644 --- a/dotnet/src/Agents/OpenAI/AssistantThreadActions.cs +++ b/dotnet/src/Agents/OpenAI/AssistantThreadActions.cs @@ -20,12 +20,6 @@ internal static class AssistantThreadActions { private const string FunctionDelimiter = "-"; - private static readonly HashSet s_messageRoles = - [ - AuthorRole.User, - AuthorRole.Assistant, - ]; - private static readonly HashSet s_pollingStatuses = [ RunStatus.Queued, @@ -50,12 +44,8 @@ internal static class AssistantThreadActions /// if a system message is present, without taking any other action public static async Task CreateMessageAsync(AssistantsClient client, string threadId, ChatMessageContent message, CancellationToken cancellationToken) { - if (!s_messageRoles.Contains(message.Role)) - { - throw new KernelException($"Invalid message role: {message.Role}"); - } - - if (string.IsNullOrWhiteSpace(message.Content)) + if (string.IsNullOrEmpty(message.Content) || + message.Items.Any(i => i is FunctionCallContent)) { return; } @@ -136,7 +126,7 @@ public static async IAsyncEnumerable GetMessagesAsync(Assist /// The logger to utilize (might be agent or channel scoped) /// The to monitor for cancellation requests. The default is . /// Asynchronous enumeration of messages. - public static async IAsyncEnumerable InvokeAsync( + public static async IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync( OpenAIAssistantAgent agent, AssistantsClient client, string threadId, @@ -190,7 +180,7 @@ public static async IAsyncEnumerable InvokeAsync( if (activeFunctionSteps.Length > 0) { // Emit function-call content - yield return GenerateFunctionCallContent(agent.GetName(), activeFunctionSteps); + yield return (IsVisible: false, Message: GenerateFunctionCallContent(agent.GetName(), activeFunctionSteps)); // Invoke functions for each tool-step IEnumerable> functionResultTasks = ExecuteFunctionSteps(agent, activeFunctionSteps, cancellationToken); @@ -224,12 +214,14 @@ public static async IAsyncEnumerable InvokeAsync( foreach (RunStepToolCall toolCall in toolCallDetails.ToolCalls) { + bool isVisible = false; ChatMessageContent? content = null; // Process code-interpreter content if (toolCall is RunStepCodeInterpreterToolCall toolCodeInterpreter) { content = GenerateCodeInterpreterContent(agent.GetName(), toolCodeInterpreter); + isVisible = true; } // Process function result content else if (toolCall is RunStepFunctionToolCall toolFunction) @@ -242,7 +234,7 @@ public static async IAsyncEnumerable InvokeAsync( { ++messageCount; - yield return content; + yield return (isVisible, Message: content); } } } @@ -276,7 +268,7 @@ public static async IAsyncEnumerable InvokeAsync( { ++messageCount; - yield return content; + yield return (IsVisible: true, Message: content); } } } diff --git a/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs b/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs index 31c0bb1c0de7..8e8797fa8885 100644 --- a/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs +++ b/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs @@ -240,13 +240,19 @@ public async Task DeleteAsync(CancellationToken cancellationToken = defaul /// The thread identifier /// The to monitor for cancellation requests. The default is . /// Asynchronous enumeration of messages. - public IAsyncEnumerable InvokeAsync( + public async IAsyncEnumerable InvokeAsync( string threadId, - CancellationToken cancellationToken = default) + [EnumeratorCancellation] CancellationToken cancellationToken = default) { this.ThrowIfDeleted(); - return AssistantThreadActions.InvokeAsync(this, this._client, threadId, this._config.Polling, this.Logger, cancellationToken); + await foreach ((bool isVisible, ChatMessageContent message) in AssistantThreadActions.InvokeAsync(this, this._client, threadId, this._config.Polling, this.Logger, cancellationToken).ConfigureAwait(false)) + { + if (isVisible) + { + yield return message; + } + } } /// diff --git a/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs b/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs index b84ef800ebd4..48fdefa65fe9 100644 --- a/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs +++ b/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs @@ -16,7 +16,7 @@ internal sealed class OpenAIAssistantChannel(AssistantsClient client, string thr private readonly string _threadId = threadId; /// - protected override async Task ReceiveAsync(IReadOnlyList history, CancellationToken cancellationToken) + protected override async Task ReceiveAsync(IEnumerable history, CancellationToken cancellationToken) { foreach (ChatMessageContent message in history) { @@ -25,7 +25,7 @@ protected override async Task ReceiveAsync(IReadOnlyList his } /// - protected override IAsyncEnumerable InvokeAsync( + protected override IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync( OpenAIAssistantAgent agent, CancellationToken cancellationToken) { diff --git a/dotnet/src/Agents/UnitTests/AgentChannelTests.cs b/dotnet/src/Agents/UnitTests/AgentChannelTests.cs index 7223b8d46805..2a680614a54f 100644 --- a/dotnet/src/Agents/UnitTests/AgentChannelTests.cs +++ b/dotnet/src/Agents/UnitTests/AgentChannelTests.cs @@ -40,11 +40,11 @@ private sealed class TestChannel : AgentChannel { public int InvokeCount { get; private set; } - public IAsyncEnumerable InvokeAgentAsync(Agent agent, CancellationToken cancellationToken = default) + public IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAgentAsync(Agent agent, CancellationToken cancellationToken = default) => base.InvokeAsync(agent, cancellationToken); #pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously - protected internal override async IAsyncEnumerable InvokeAsync(TestAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default) + protected internal override async IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync(TestAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default) #pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously { this.InvokeCount++; @@ -57,7 +57,7 @@ protected internal override IAsyncEnumerable GetHistoryAsync throw new NotImplementedException(); } - protected internal override Task ReceiveAsync(IReadOnlyList history, CancellationToken cancellationToken = default) + protected internal override Task ReceiveAsync(IEnumerable history, CancellationToken cancellationToken = default) { throw new NotImplementedException(); } diff --git a/dotnet/src/Agents/UnitTests/Internal/BroadcastQueueTests.cs b/dotnet/src/Agents/UnitTests/Internal/BroadcastQueueTests.cs index 482c4cfa09a3..452a0566e11f 100644 --- a/dotnet/src/Agents/UnitTests/Internal/BroadcastQueueTests.cs +++ b/dotnet/src/Agents/UnitTests/Internal/BroadcastQueueTests.cs @@ -136,12 +136,12 @@ protected internal override IAsyncEnumerable GetHistoryAsync throw new NotImplementedException(); } - protected internal override IAsyncEnumerable InvokeAsync(Agent agent, CancellationToken cancellationToken = default) + protected internal override IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync(Agent agent, CancellationToken cancellationToken = default) { throw new NotImplementedException(); } - protected internal override async Task ReceiveAsync(IReadOnlyList history, CancellationToken cancellationToken = default) + protected internal override async Task ReceiveAsync(IEnumerable history, CancellationToken cancellationToken = default) { this.ReceivedMessages.AddRange(history); this.ReceiveCount++; @@ -159,12 +159,12 @@ protected internal override IAsyncEnumerable GetHistoryAsync throw new NotImplementedException(); } - protected internal override IAsyncEnumerable InvokeAsync(Agent agent, CancellationToken cancellationToken = default) + protected internal override IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync(Agent agent, CancellationToken cancellationToken = default) { throw new NotImplementedException(); } - protected internal override async Task ReceiveAsync(IReadOnlyList history, CancellationToken cancellationToken = default) + protected internal override async Task ReceiveAsync(IEnumerable history, CancellationToken cancellationToken = default) { await Task.Delay(this.ReceiveDuration, cancellationToken); diff --git a/dotnet/src/IntegrationTests/Agents/ChatCompletionAgentTests.cs b/dotnet/src/IntegrationTests/Agents/ChatCompletionAgentTests.cs new file mode 100644 index 000000000000..91796c1970b0 --- /dev/null +++ b/dotnet/src/IntegrationTests/Agents/ChatCompletionAgentTests.cs @@ -0,0 +1,140 @@ +// Copyright (c) Microsoft. All rights reserved. +using System; +using System.ComponentModel; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Agents; +using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.OpenAI; +using SemanticKernel.IntegrationTests.TestSettings; +using Xunit; +using Xunit.Abstractions; + +namespace SemanticKernel.IntegrationTests.Agents.OpenAI; + +#pragma warning disable xUnit1004 // Contains test methods used in manual verification. Disable warning for this file only. + +public sealed class ChatCompletionAgentTests(ITestOutputHelper output) : IDisposable +{ + private readonly IKernelBuilder _kernelBuilder = Kernel.CreateBuilder(); + private readonly IConfigurationRoot _configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: false, reloadOnChange: true) + .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + + /// + /// Integration test for using function calling + /// and targeting Azure OpenAI services. + /// + [Theory] + [InlineData("What is the special soup?", "Clam Chowder", false)] + [InlineData("What is the special soup?", "Clam Chowder", true)] + public async Task AzureChatCompletionAgentAsync(string input, string expectedAnswerContains, bool useAutoFunctionTermination) + { + // Arrange + AzureOpenAIConfiguration configuration = this._configuration.GetSection("AzureOpenAI").Get()!; + + KernelPlugin plugin = KernelPluginFactory.CreateFromType(); + + this._kernelBuilder.Services.AddSingleton(this._logger); + + this._kernelBuilder.AddAzureOpenAIChatCompletion( + configuration.ChatDeploymentName!, + configuration.Endpoint, + configuration.ApiKey); + + if (useAutoFunctionTermination) + { + this._kernelBuilder.Services.AddSingleton(new AutoInvocationFilter()); + } + + this._kernelBuilder.Plugins.Add(plugin); + + Kernel kernel = this._kernelBuilder.Build(); + + ChatCompletionAgent agent = + new() + { + Kernel = kernel, + Instructions = "Answer questions about the menu.", + ExecutionSettings = new OpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions }, + }; + + AgentGroupChat chat = new(); + chat.AddChatMessage(new ChatMessageContent(AuthorRole.User, input)); + + // Act + ChatMessageContent[] messages = await chat.InvokeAsync(agent).ToArrayAsync(); + ChatMessageContent[] history = await chat.GetChatMessagesAsync().ToArrayAsync(); + + // Assert + Assert.Single(messages); + + ChatMessageContent response = messages.Single(); + + if (useAutoFunctionTermination) + { + Assert.Equal(3, history.Length); + Assert.Single(response.Items.OfType()); + Assert.Single(response.Items.OfType()); + } + else + { + Assert.Equal(4, history.Length); + Assert.Single(response.Items); + Assert.Single(response.Items.OfType()); + } + + Assert.Contains(expectedAnswerContains, messages.Single().Content, StringComparison.OrdinalIgnoreCase); + } + + private readonly XunitLogger _logger = new(output); + private readonly RedirectOutput _testOutputHelper = new(output); + + public void Dispose() + { + this._logger.Dispose(); + this._testOutputHelper.Dispose(); + } + + public sealed class MenuPlugin + { + [KernelFunction, Description("Provides a list of specials from the menu.")] + [System.Diagnostics.CodeAnalysis.SuppressMessage("Design", "CA1024:Use properties where appropriate", Justification = "Too smart")] + public string GetSpecials() + { + return @" +Special Soup: Clam Chowder +Special Salad: Cobb Salad +Special Drink: Chai Tea +"; + } + + [KernelFunction, Description("Provides the price of the requested menu item.")] + public string GetItemPrice( + [Description("The name of the menu item.")] + string menuItem) + { + return "$9.99"; + } + } + + private sealed class AutoInvocationFilter(bool terminate = true) : IAutoFunctionInvocationFilter + { + public async Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func next) + { + await next(context); + + if (context.Function.PluginName == nameof(MenuPlugin)) + { + context.Terminate = terminate; + } + } + } +}