diff --git a/dotnet/samples/Concepts/Agents/ChatCompletion_FunctionTermination.cs b/dotnet/samples/Concepts/Agents/ChatCompletion_FunctionTermination.cs
new file mode 100644
index 000000000000..f344dae432b9
--- /dev/null
+++ b/dotnet/samples/Concepts/Agents/ChatCompletion_FunctionTermination.cs
@@ -0,0 +1,158 @@
+// Copyright (c) Microsoft. All rights reserved.
+using System.ComponentModel;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.Agents;
+using Microsoft.SemanticKernel.ChatCompletion;
+using Microsoft.SemanticKernel.Connectors.OpenAI;
+
+namespace Agents;
+
+///
+/// Demonstrate usage of for both direction invocation
+/// of and via .
+///
+public class ChatCompletion_FunctionTermination(ITestOutputHelper output) : BaseTest(output)
+{
+ [Fact]
+ public async Task UseAutoFunctionInvocationFilterWithAgentInvocationAsync()
+ {
+ // Define the agent
+ ChatCompletionAgent agent =
+ new()
+ {
+ Instructions = "Answer questions about the menu.",
+ Kernel = CreateKernelWithChatCompletion(),
+ ExecutionSettings = new OpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions },
+ };
+
+ KernelPlugin plugin = KernelPluginFactory.CreateFromType();
+ agent.Kernel.Plugins.Add(plugin);
+
+ /// Create the chat history to capture the agent interaction.
+ ChatHistory chat = [];
+
+ // Respond to user input, invoking functions where appropriate.
+ await InvokeAgentAsync("Hello");
+ await InvokeAgentAsync("What is the special soup?");
+ await InvokeAgentAsync("What is the special drink?");
+ await InvokeAgentAsync("Thank you");
+
+ // Display the chat history.
+ Console.WriteLine("================================");
+ Console.WriteLine("CHAT HISTORY");
+ Console.WriteLine("================================");
+ foreach (ChatMessageContent message in chat)
+ {
+ this.WriteContent(message);
+ }
+
+ // Local function to invoke agent and display the conversation messages.
+ async Task InvokeAgentAsync(string input)
+ {
+ ChatMessageContent userContent = new(AuthorRole.User, input);
+ chat.Add(userContent);
+ this.WriteContent(userContent);
+
+ await foreach (ChatMessageContent content in agent.InvokeAsync(chat))
+ {
+ // Do not add a message implicitly added to the history.
+ if (!content.Items.Any(i => i is FunctionCallContent || i is FunctionResultContent))
+ {
+ chat.Add(content);
+ }
+
+ this.WriteContent(content);
+ }
+ }
+ }
+
+ [Fact]
+ public async Task UseAutoFunctionInvocationFilterWithAgentChatAsync()
+ {
+ // Define the agent
+ ChatCompletionAgent agent =
+ new()
+ {
+ Instructions = "Answer questions about the menu.",
+ Kernel = CreateKernelWithChatCompletion(),
+ ExecutionSettings = new OpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions },
+ };
+
+ KernelPlugin plugin = KernelPluginFactory.CreateFromType();
+ agent.Kernel.Plugins.Add(plugin);
+
+ // Create a chat for agent interaction.
+ AgentGroupChat chat = new();
+
+ // Respond to user input, invoking functions where appropriate.
+ await InvokeAgentAsync("Hello");
+ await InvokeAgentAsync("What is the special soup?");
+ await InvokeAgentAsync("What is the special drink?");
+ await InvokeAgentAsync("Thank you");
+
+ // Display the chat history.
+ Console.WriteLine("================================");
+ Console.WriteLine("CHAT HISTORY");
+ Console.WriteLine("================================");
+ ChatMessageContent[] history = await chat.GetChatMessagesAsync().ToArrayAsync();
+ for (int index = history.Length; index > 0; --index)
+ {
+ this.WriteContent(history[index - 1]);
+ }
+
+ // Local function to invoke agent and display the conversation messages.
+ async Task InvokeAgentAsync(string input)
+ {
+ ChatMessageContent userContent = new(AuthorRole.User, input);
+ chat.AddChatMessage(userContent);
+ this.WriteContent(userContent);
+
+ await foreach (ChatMessageContent content in chat.InvokeAsync(agent))
+ {
+ this.WriteContent(content);
+ }
+ }
+ }
+
+ private void WriteContent(ChatMessageContent content)
+ {
+ Console.WriteLine($"[{content.Items.LastOrDefault()?.GetType().Name ?? "(empty)"}] {content.Role} : '{content.Content}'");
+ }
+
+ private sealed class MenuPlugin
+ {
+ [KernelFunction, Description("Provides a list of specials from the menu.")]
+ [System.Diagnostics.CodeAnalysis.SuppressMessage("Design", "CA1024:Use properties where appropriate", Justification = "Too smart")]
+ public string GetSpecials()
+ {
+ return @"
+Special Soup: Clam Chowder
+Special Salad: Cobb Salad
+Special Drink: Chai Tea
+";
+ }
+
+ [KernelFunction, Description("Provides the price of the requested menu item.")]
+ public string GetItemPrice(
+ [Description("The name of the menu item.")]
+ string menuItem)
+ {
+ return "$9.99";
+ }
+ }
+
+ private sealed class AutoInvocationFilter(bool terminate = true) : IAutoFunctionInvocationFilter
+ {
+ public async Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func next)
+ {
+ // Execution the function
+ await next(context);
+
+ // Signal termination if the function is from the MenuPlugin
+ if (context.Function.PluginName == nameof(MenuPlugin))
+ {
+ context.Terminate = terminate;
+ }
+ }
+ }
+}
diff --git a/dotnet/samples/Concepts/Agents/ChatCompletion_Streaming.cs b/dotnet/samples/Concepts/Agents/ChatCompletion_Streaming.cs
index ee6fb9b38f2a..258e12166a6b 100644
--- a/dotnet/samples/Concepts/Agents/ChatCompletion_Streaming.cs
+++ b/dotnet/samples/Concepts/Agents/ChatCompletion_Streaming.cs
@@ -1,8 +1,10 @@
// Copyright (c) Microsoft. All rights reserved.
+using System.ComponentModel;
using System.Text;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents;
using Microsoft.SemanticKernel.ChatCompletion;
+using Microsoft.SemanticKernel.Connectors.OpenAI;
namespace Agents;
@@ -30,40 +32,88 @@ public async Task UseStreamingChatCompletionAgentAsync()
ChatHistory chat = [];
// Respond to user input
- await InvokeAgentAsync("Fortune favors the bold.");
- await InvokeAgentAsync("I came, I saw, I conquered.");
- await InvokeAgentAsync("Practice makes perfect.");
+ await InvokeAgentAsync(agent, chat, "Fortune favors the bold.");
+ await InvokeAgentAsync(agent, chat, "I came, I saw, I conquered.");
+ await InvokeAgentAsync(agent, chat, "Practice makes perfect.");
+ }
- // Local function to invoke agent and display the conversation messages.
- async Task InvokeAgentAsync(string input)
- {
- chat.Add(new ChatMessageContent(AuthorRole.User, input));
+ [Fact]
+ public async Task UseStreamingChatCompletionAgentWithPluginAsync()
+ {
+ const string MenuInstructions = "Answer questions about the menu.";
+
+ // Define the agent
+ ChatCompletionAgent agent =
+ new()
+ {
+ Name = "Host",
+ Instructions = MenuInstructions,
+ Kernel = this.CreateKernelWithChatCompletion(),
+ ExecutionSettings = new OpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions },
+ };
+
+ // Initialize plugin and add to the agent's Kernel (same as direct Kernel usage).
+ KernelPlugin plugin = KernelPluginFactory.CreateFromType();
+ agent.Kernel.Plugins.Add(plugin);
+
+ ChatHistory chat = [];
+
+ // Respond to user input
+ await InvokeAgentAsync(agent, chat, "What is the special soup?");
+ await InvokeAgentAsync(agent, chat, "What is the special drink?");
+ }
+
+ // Local function to invoke agent and display the conversation messages.
+ private async Task InvokeAgentAsync(ChatCompletionAgent agent, ChatHistory chat, string input)
+ {
+ chat.Add(new ChatMessageContent(AuthorRole.User, input));
- Console.WriteLine($"# {AuthorRole.User}: '{input}'");
+ Console.WriteLine($"# {AuthorRole.User}: '{input}'");
- StringBuilder builder = new();
- await foreach (StreamingChatMessageContent message in agent.InvokeStreamingAsync(chat))
+ StringBuilder builder = new();
+ await foreach (StreamingChatMessageContent message in agent.InvokeStreamingAsync(chat))
+ {
+ if (string.IsNullOrEmpty(message.Content))
{
- if (string.IsNullOrEmpty(message.Content))
- {
- continue;
- }
-
- if (builder.Length == 0)
- {
- Console.WriteLine($"# {message.Role} - {message.AuthorName ?? "*"}:");
- }
-
- Console.WriteLine($"\t > streamed: '{message.Content}'");
- builder.Append(message.Content);
+ continue;
}
- if (builder.Length > 0)
+ if (builder.Length == 0)
{
- // Display full response and capture in chat history
- Console.WriteLine($"\t > complete: '{builder}'");
- chat.Add(new ChatMessageContent(AuthorRole.Assistant, builder.ToString()) { AuthorName = agent.Name });
+ Console.WriteLine($"# {message.Role} - {message.AuthorName ?? "*"}:");
}
+
+ Console.WriteLine($"\t > streamed: '{message.Content}'");
+ builder.Append(message.Content);
+ }
+
+ if (builder.Length > 0)
+ {
+ // Display full response and capture in chat history
+ Console.WriteLine($"\t > complete: '{builder}'");
+ chat.Add(new ChatMessageContent(AuthorRole.Assistant, builder.ToString()) { AuthorName = agent.Name });
+ }
+ }
+
+ public sealed class MenuPlugin
+ {
+ [KernelFunction, Description("Provides a list of specials from the menu.")]
+ [System.Diagnostics.CodeAnalysis.SuppressMessage("Design", "CA1024:Use properties where appropriate", Justification = "Too smart")]
+ public string GetSpecials()
+ {
+ return @"
+Special Soup: Clam Chowder
+Special Salad: Cobb Salad
+Special Drink: Chai Tea
+";
+ }
+
+ [KernelFunction, Description("Provides the price of the requested menu item.")]
+ public string GetItemPrice(
+ [Description("The name of the menu item.")]
+ string menuItem)
+ {
+ return "$9.99";
}
}
}
diff --git a/dotnet/src/Agents/Abstractions/AgentChannel.cs b/dotnet/src/Agents/Abstractions/AgentChannel.cs
index ad58deedb017..9788464a2adb 100644
--- a/dotnet/src/Agents/Abstractions/AgentChannel.cs
+++ b/dotnet/src/Agents/Abstractions/AgentChannel.cs
@@ -23,7 +23,7 @@ public abstract class AgentChannel
///
/// The chat history at the point the channel is created.
/// The to monitor for cancellation requests. The default is .
- protected internal abstract Task ReceiveAsync(IReadOnlyList history, CancellationToken cancellationToken = default);
+ protected internal abstract Task ReceiveAsync(IEnumerable history, CancellationToken cancellationToken = default);
///
/// Perform a discrete incremental interaction between a single and .
@@ -31,7 +31,7 @@ public abstract class AgentChannel
/// The agent actively interacting with the chat.
/// The to monitor for cancellation requests. The default is .
/// Asynchronous enumeration of messages.
- protected internal abstract IAsyncEnumerable InvokeAsync(
+ protected internal abstract IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync(
Agent agent,
CancellationToken cancellationToken = default);
@@ -59,12 +59,12 @@ public abstract class AgentChannel : AgentChannel where TAgent : Agent
/// The agent actively interacting with the chat.
/// The to monitor for cancellation requests. The default is .
/// Asynchronous enumeration of messages.
- protected internal abstract IAsyncEnumerable InvokeAsync(
+ protected internal abstract IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync(
TAgent agent,
CancellationToken cancellationToken = default);
///
- protected internal override IAsyncEnumerable InvokeAsync(
+ protected internal override IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync(
Agent agent,
CancellationToken cancellationToken = default)
{
diff --git a/dotnet/src/Agents/Abstractions/AgentChat.cs b/dotnet/src/Agents/Abstractions/AgentChat.cs
index 9c834380a8f4..f4654963444e 100644
--- a/dotnet/src/Agents/Abstractions/AgentChat.cs
+++ b/dotnet/src/Agents/Abstractions/AgentChat.cs
@@ -209,22 +209,21 @@ protected async IAsyncEnumerable InvokeAgentAsync(
// Invoke agent & process response
List messages = [];
- await foreach (ChatMessageContent message in channel.InvokeAsync(agent, cancellationToken).ConfigureAwait(false))
+
+ await foreach ((bool isVisible, ChatMessageContent message) in channel.InvokeAsync(agent, cancellationToken).ConfigureAwait(false))
{
this.Logger.LogAgentChatInvokedAgentMessage(nameof(InvokeAgentAsync), agent.GetType(), agent.Id, message);
+ messages.Add(message);
+
// Add to primary history
this.History.Add(message);
- messages.Add(message);
- // Don't expose function-call and function-result messages to caller.
- if (message.Items.All(i => i is FunctionCallContent || i is FunctionResultContent))
+ if (isVisible)
{
- continue;
+ // Yield message to caller
+ yield return message;
}
-
- // Yield message to caller
- yield return message;
}
// Broadcast message to other channels (in parallel)
@@ -233,7 +232,7 @@ protected async IAsyncEnumerable InvokeAgentAsync(
this._agentChannels
.Where(kvp => kvp.Value != channel)
.Select(kvp => new ChannelReference(kvp.Value, kvp.Key));
- this._broadcastQueue.Enqueue(channelRefs, messages.Where(m => m.Role != AuthorRole.Tool).ToArray());
+ this._broadcastQueue.Enqueue(channelRefs, messages);
this.Logger.LogAgentChatInvokedAgent(nameof(InvokeAgentAsync), agent.GetType(), agent.Id);
}
@@ -256,6 +255,7 @@ async Task GetOrCreateChannelAsync()
if (this.History.Count > 0)
{
+ // Sync channel with existing history
await channel.ReceiveAsync(this.History, cancellationToken).ConfigureAwait(false);
}
diff --git a/dotnet/src/Agents/Abstractions/AggregatorChannel.cs b/dotnet/src/Agents/Abstractions/AggregatorChannel.cs
index 60b1cd4367f6..73561a4eba8b 100644
--- a/dotnet/src/Agents/Abstractions/AggregatorChannel.cs
+++ b/dotnet/src/Agents/Abstractions/AggregatorChannel.cs
@@ -18,7 +18,7 @@ protected internal override IAsyncEnumerable GetHistoryAsync
return this._chat.GetChatMessagesAsync(cancellationToken);
}
- protected internal override async IAsyncEnumerable InvokeAsync(AggregatorAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ protected internal override async IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync(AggregatorAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
ChatMessageContent? lastMessage = null;
@@ -27,7 +27,7 @@ protected internal override async IAsyncEnumerable InvokeAsy
// For AggregatorMode.Flat, the entire aggregated chat is merged into the owning chat.
if (agent.Mode == AggregatorMode.Flat)
{
- yield return message;
+ yield return (IsVisible: true, message);
}
lastMessage = message;
@@ -43,11 +43,11 @@ protected internal override async IAsyncEnumerable InvokeAsy
AuthorName = agent.Name
};
- yield return message;
+ yield return (IsVisible: true, message);
}
}
- protected internal override Task ReceiveAsync(IReadOnlyList history, CancellationToken cancellationToken = default)
+ protected internal override Task ReceiveAsync(IEnumerable history, CancellationToken cancellationToken = default)
{
// Always receive the initial history from the owning chat.
this._chat.AddChatMessages([.. history]);
diff --git a/dotnet/src/Agents/Abstractions/ChatHistoryChannel.cs b/dotnet/src/Agents/Abstractions/ChatHistoryChannel.cs
index 2bb5616ff959..5dcb6b9b0204 100644
--- a/dotnet/src/Agents/Abstractions/ChatHistoryChannel.cs
+++ b/dotnet/src/Agents/Abstractions/ChatHistoryChannel.cs
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
+using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
@@ -16,7 +17,7 @@ public class ChatHistoryChannel : AgentChannel
private readonly ChatHistory _history;
///
- protected internal sealed override async IAsyncEnumerable InvokeAsync(
+ protected internal sealed override async IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync(
Agent agent,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
@@ -25,16 +26,55 @@ protected internal sealed override async IAsyncEnumerable In
throw new KernelException($"Invalid channel binding for agent: {agent.Id} ({agent.GetType().FullName})");
}
- await foreach (ChatMessageContent message in historyHandler.InvokeAsync(this._history, cancellationToken).ConfigureAwait(false))
+ // Capture the current message count to evaluate history mutation.
+ int messageCount = this._history.Count;
+ HashSet mutatedHistory = [];
+
+ // Utilize a queue as a "read-ahead" cache to evaluate message sequencing (i.e., which message is final).
+ Queue messageQueue = [];
+
+ ChatMessageContent? yieldMessage = null;
+ await foreach (ChatMessageContent responseMessage in historyHandler.InvokeAsync(this._history, cancellationToken).ConfigureAwait(false))
{
- this._history.Add(message);
+ // Capture all messages that have been included in the mutated the history.
+ for (int messageIndex = messageCount; messageIndex < this._history.Count; messageIndex++)
+ {
+ ChatMessageContent mutatedMessage = this._history[messageIndex];
+ mutatedHistory.Add(mutatedMessage);
+ messageQueue.Enqueue(mutatedMessage);
+ }
+
+ // Update the message count pointer to reflect the current history.
+ messageCount = this._history.Count;
- yield return message;
+ // Avoid duplicating any message included in the mutated history and also returned by the enumeration result.
+ if (!mutatedHistory.Contains(responseMessage))
+ {
+ this._history.Add(responseMessage);
+ messageQueue.Enqueue(responseMessage);
+ }
+
+ // Dequeue the next message to yield.
+ yieldMessage = messageQueue.Dequeue();
+ yield return (IsMessageVisible(yieldMessage), yieldMessage);
}
+
+ // Dequeue any remaining messages to yield.
+ while (messageQueue.Count > 0)
+ {
+ yieldMessage = messageQueue.Dequeue();
+
+ yield return (IsMessageVisible(yieldMessage), yieldMessage);
+ }
+
+ // Function content not visible, unless result is the final message.
+ bool IsMessageVisible(ChatMessageContent message) =>
+ (!message.Items.Any(i => i is FunctionCallContent || i is FunctionResultContent) ||
+ messageQueue.Count == 0);
}
///
- protected internal sealed override Task ReceiveAsync(IReadOnlyList history, CancellationToken cancellationToken)
+ protected internal sealed override Task ReceiveAsync(IEnumerable history, CancellationToken cancellationToken)
{
this._history.AddRange(history);
diff --git a/dotnet/src/Agents/Core/ChatCompletionAgent.cs b/dotnet/src/Agents/Core/ChatCompletionAgent.cs
index 990154b139e4..1e9ea3d3208e 100644
--- a/dotnet/src/Agents/Core/ChatCompletionAgent.cs
+++ b/dotnet/src/Agents/Core/ChatCompletionAgent.cs
@@ -84,22 +84,22 @@ public override async IAsyncEnumerable InvokeStream
this.Logger.LogAgentChatServiceInvokedStreamingAgent(nameof(InvokeAsync), this.Id, chatCompletionService.GetType());
- // Capture mutated messages related function calling / tools
- for (int messageIndex = messageCount; messageIndex < chat.Count; messageIndex++)
+ await foreach (StreamingChatMessageContent message in messages.ConfigureAwait(false))
{
- ChatMessageContent message = chat[messageIndex];
-
+ // TODO: MESSAGE SOURCE - ISSUE #5731
message.AuthorName = this.Name;
- history.Add(message);
+ yield return message;
}
- await foreach (StreamingChatMessageContent message in messages.ConfigureAwait(false))
+ // Capture mutated messages related function calling / tools
+ for (int messageIndex = messageCount; messageIndex < chat.Count; messageIndex++)
{
- // TODO: MESSAGE SOURCE - ISSUE #5731
+ ChatMessageContent message = chat[messageIndex];
+
message.AuthorName = this.Name;
- yield return message;
+ history.Add(message);
}
}
diff --git a/dotnet/src/Agents/OpenAI/AssistantThreadActions.cs b/dotnet/src/Agents/OpenAI/AssistantThreadActions.cs
index b1be5bb52765..f768d89a54bb 100644
--- a/dotnet/src/Agents/OpenAI/AssistantThreadActions.cs
+++ b/dotnet/src/Agents/OpenAI/AssistantThreadActions.cs
@@ -20,12 +20,6 @@ internal static class AssistantThreadActions
{
private const string FunctionDelimiter = "-";
- private static readonly HashSet s_messageRoles =
- [
- AuthorRole.User,
- AuthorRole.Assistant,
- ];
-
private static readonly HashSet s_pollingStatuses =
[
RunStatus.Queued,
@@ -50,12 +44,8 @@ internal static class AssistantThreadActions
/// if a system message is present, without taking any other action
public static async Task CreateMessageAsync(AssistantsClient client, string threadId, ChatMessageContent message, CancellationToken cancellationToken)
{
- if (!s_messageRoles.Contains(message.Role))
- {
- throw new KernelException($"Invalid message role: {message.Role}");
- }
-
- if (string.IsNullOrWhiteSpace(message.Content))
+ if (string.IsNullOrEmpty(message.Content) ||
+ message.Items.Any(i => i is FunctionCallContent))
{
return;
}
@@ -136,7 +126,7 @@ public static async IAsyncEnumerable GetMessagesAsync(Assist
/// The logger to utilize (might be agent or channel scoped)
/// The to monitor for cancellation requests. The default is .
/// Asynchronous enumeration of messages.
- public static async IAsyncEnumerable InvokeAsync(
+ public static async IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync(
OpenAIAssistantAgent agent,
AssistantsClient client,
string threadId,
@@ -190,7 +180,7 @@ public static async IAsyncEnumerable InvokeAsync(
if (activeFunctionSteps.Length > 0)
{
// Emit function-call content
- yield return GenerateFunctionCallContent(agent.GetName(), activeFunctionSteps);
+ yield return (IsVisible: false, Message: GenerateFunctionCallContent(agent.GetName(), activeFunctionSteps));
// Invoke functions for each tool-step
IEnumerable> functionResultTasks = ExecuteFunctionSteps(agent, activeFunctionSteps, cancellationToken);
@@ -224,12 +214,14 @@ public static async IAsyncEnumerable InvokeAsync(
foreach (RunStepToolCall toolCall in toolCallDetails.ToolCalls)
{
+ bool isVisible = false;
ChatMessageContent? content = null;
// Process code-interpreter content
if (toolCall is RunStepCodeInterpreterToolCall toolCodeInterpreter)
{
content = GenerateCodeInterpreterContent(agent.GetName(), toolCodeInterpreter);
+ isVisible = true;
}
// Process function result content
else if (toolCall is RunStepFunctionToolCall toolFunction)
@@ -242,7 +234,7 @@ public static async IAsyncEnumerable InvokeAsync(
{
++messageCount;
- yield return content;
+ yield return (isVisible, Message: content);
}
}
}
@@ -276,7 +268,7 @@ public static async IAsyncEnumerable InvokeAsync(
{
++messageCount;
- yield return content;
+ yield return (IsVisible: true, Message: content);
}
}
}
diff --git a/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs b/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs
index 31c0bb1c0de7..8e8797fa8885 100644
--- a/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs
+++ b/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs
@@ -240,13 +240,19 @@ public async Task DeleteAsync(CancellationToken cancellationToken = defaul
/// The thread identifier
/// The to monitor for cancellation requests. The default is .
/// Asynchronous enumeration of messages.
- public IAsyncEnumerable InvokeAsync(
+ public async IAsyncEnumerable InvokeAsync(
string threadId,
- CancellationToken cancellationToken = default)
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
this.ThrowIfDeleted();
- return AssistantThreadActions.InvokeAsync(this, this._client, threadId, this._config.Polling, this.Logger, cancellationToken);
+ await foreach ((bool isVisible, ChatMessageContent message) in AssistantThreadActions.InvokeAsync(this, this._client, threadId, this._config.Polling, this.Logger, cancellationToken).ConfigureAwait(false))
+ {
+ if (isVisible)
+ {
+ yield return message;
+ }
+ }
}
///
diff --git a/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs b/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs
index b84ef800ebd4..48fdefa65fe9 100644
--- a/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs
+++ b/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs
@@ -16,7 +16,7 @@ internal sealed class OpenAIAssistantChannel(AssistantsClient client, string thr
private readonly string _threadId = threadId;
///
- protected override async Task ReceiveAsync(IReadOnlyList history, CancellationToken cancellationToken)
+ protected override async Task ReceiveAsync(IEnumerable history, CancellationToken cancellationToken)
{
foreach (ChatMessageContent message in history)
{
@@ -25,7 +25,7 @@ protected override async Task ReceiveAsync(IReadOnlyList his
}
///
- protected override IAsyncEnumerable InvokeAsync(
+ protected override IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync(
OpenAIAssistantAgent agent,
CancellationToken cancellationToken)
{
diff --git a/dotnet/src/Agents/UnitTests/AgentChannelTests.cs b/dotnet/src/Agents/UnitTests/AgentChannelTests.cs
index 7223b8d46805..2a680614a54f 100644
--- a/dotnet/src/Agents/UnitTests/AgentChannelTests.cs
+++ b/dotnet/src/Agents/UnitTests/AgentChannelTests.cs
@@ -40,11 +40,11 @@ private sealed class TestChannel : AgentChannel
{
public int InvokeCount { get; private set; }
- public IAsyncEnumerable InvokeAgentAsync(Agent agent, CancellationToken cancellationToken = default)
+ public IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAgentAsync(Agent agent, CancellationToken cancellationToken = default)
=> base.InvokeAsync(agent, cancellationToken);
#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
- protected internal override async IAsyncEnumerable InvokeAsync(TestAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ protected internal override async IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync(TestAgent agent, [EnumeratorCancellation] CancellationToken cancellationToken = default)
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
{
this.InvokeCount++;
@@ -57,7 +57,7 @@ protected internal override IAsyncEnumerable GetHistoryAsync
throw new NotImplementedException();
}
- protected internal override Task ReceiveAsync(IReadOnlyList history, CancellationToken cancellationToken = default)
+ protected internal override Task ReceiveAsync(IEnumerable history, CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
}
diff --git a/dotnet/src/Agents/UnitTests/Internal/BroadcastQueueTests.cs b/dotnet/src/Agents/UnitTests/Internal/BroadcastQueueTests.cs
index 482c4cfa09a3..452a0566e11f 100644
--- a/dotnet/src/Agents/UnitTests/Internal/BroadcastQueueTests.cs
+++ b/dotnet/src/Agents/UnitTests/Internal/BroadcastQueueTests.cs
@@ -136,12 +136,12 @@ protected internal override IAsyncEnumerable GetHistoryAsync
throw new NotImplementedException();
}
- protected internal override IAsyncEnumerable InvokeAsync(Agent agent, CancellationToken cancellationToken = default)
+ protected internal override IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync(Agent agent, CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
}
- protected internal override async Task ReceiveAsync(IReadOnlyList history, CancellationToken cancellationToken = default)
+ protected internal override async Task ReceiveAsync(IEnumerable history, CancellationToken cancellationToken = default)
{
this.ReceivedMessages.AddRange(history);
this.ReceiveCount++;
@@ -159,12 +159,12 @@ protected internal override IAsyncEnumerable GetHistoryAsync
throw new NotImplementedException();
}
- protected internal override IAsyncEnumerable InvokeAsync(Agent agent, CancellationToken cancellationToken = default)
+ protected internal override IAsyncEnumerable<(bool IsVisible, ChatMessageContent Message)> InvokeAsync(Agent agent, CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
}
- protected internal override async Task ReceiveAsync(IReadOnlyList history, CancellationToken cancellationToken = default)
+ protected internal override async Task ReceiveAsync(IEnumerable history, CancellationToken cancellationToken = default)
{
await Task.Delay(this.ReceiveDuration, cancellationToken);
diff --git a/dotnet/src/IntegrationTests/Agents/ChatCompletionAgentTests.cs b/dotnet/src/IntegrationTests/Agents/ChatCompletionAgentTests.cs
new file mode 100644
index 000000000000..91796c1970b0
--- /dev/null
+++ b/dotnet/src/IntegrationTests/Agents/ChatCompletionAgentTests.cs
@@ -0,0 +1,140 @@
+// Copyright (c) Microsoft. All rights reserved.
+using System;
+using System.ComponentModel;
+using System.Linq;
+using System.Threading.Tasks;
+using Microsoft.Extensions.Configuration;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Logging;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.Agents;
+using Microsoft.SemanticKernel.ChatCompletion;
+using Microsoft.SemanticKernel.Connectors.OpenAI;
+using SemanticKernel.IntegrationTests.TestSettings;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace SemanticKernel.IntegrationTests.Agents.OpenAI;
+
+#pragma warning disable xUnit1004 // Contains test methods used in manual verification. Disable warning for this file only.
+
+public sealed class ChatCompletionAgentTests(ITestOutputHelper output) : IDisposable
+{
+ private readonly IKernelBuilder _kernelBuilder = Kernel.CreateBuilder();
+ private readonly IConfigurationRoot _configuration = new ConfigurationBuilder()
+ .AddJsonFile(path: "testsettings.json", optional: false, reloadOnChange: true)
+ .AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true)
+ .AddEnvironmentVariables()
+ .AddUserSecrets()
+ .Build();
+
+ ///
+ /// Integration test for using function calling
+ /// and targeting Azure OpenAI services.
+ ///
+ [Theory]
+ [InlineData("What is the special soup?", "Clam Chowder", false)]
+ [InlineData("What is the special soup?", "Clam Chowder", true)]
+ public async Task AzureChatCompletionAgentAsync(string input, string expectedAnswerContains, bool useAutoFunctionTermination)
+ {
+ // Arrange
+ AzureOpenAIConfiguration configuration = this._configuration.GetSection("AzureOpenAI").Get()!;
+
+ KernelPlugin plugin = KernelPluginFactory.CreateFromType();
+
+ this._kernelBuilder.Services.AddSingleton(this._logger);
+
+ this._kernelBuilder.AddAzureOpenAIChatCompletion(
+ configuration.ChatDeploymentName!,
+ configuration.Endpoint,
+ configuration.ApiKey);
+
+ if (useAutoFunctionTermination)
+ {
+ this._kernelBuilder.Services.AddSingleton(new AutoInvocationFilter());
+ }
+
+ this._kernelBuilder.Plugins.Add(plugin);
+
+ Kernel kernel = this._kernelBuilder.Build();
+
+ ChatCompletionAgent agent =
+ new()
+ {
+ Kernel = kernel,
+ Instructions = "Answer questions about the menu.",
+ ExecutionSettings = new OpenAIPromptExecutionSettings() { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions },
+ };
+
+ AgentGroupChat chat = new();
+ chat.AddChatMessage(new ChatMessageContent(AuthorRole.User, input));
+
+ // Act
+ ChatMessageContent[] messages = await chat.InvokeAsync(agent).ToArrayAsync();
+ ChatMessageContent[] history = await chat.GetChatMessagesAsync().ToArrayAsync();
+
+ // Assert
+ Assert.Single(messages);
+
+ ChatMessageContent response = messages.Single();
+
+ if (useAutoFunctionTermination)
+ {
+ Assert.Equal(3, history.Length);
+ Assert.Single(response.Items.OfType());
+ Assert.Single(response.Items.OfType());
+ }
+ else
+ {
+ Assert.Equal(4, history.Length);
+ Assert.Single(response.Items);
+ Assert.Single(response.Items.OfType());
+ }
+
+ Assert.Contains(expectedAnswerContains, messages.Single().Content, StringComparison.OrdinalIgnoreCase);
+ }
+
+ private readonly XunitLogger _logger = new(output);
+ private readonly RedirectOutput _testOutputHelper = new(output);
+
+ public void Dispose()
+ {
+ this._logger.Dispose();
+ this._testOutputHelper.Dispose();
+ }
+
+ public sealed class MenuPlugin
+ {
+ [KernelFunction, Description("Provides a list of specials from the menu.")]
+ [System.Diagnostics.CodeAnalysis.SuppressMessage("Design", "CA1024:Use properties where appropriate", Justification = "Too smart")]
+ public string GetSpecials()
+ {
+ return @"
+Special Soup: Clam Chowder
+Special Salad: Cobb Salad
+Special Drink: Chai Tea
+";
+ }
+
+ [KernelFunction, Description("Provides the price of the requested menu item.")]
+ public string GetItemPrice(
+ [Description("The name of the menu item.")]
+ string menuItem)
+ {
+ return "$9.99";
+ }
+ }
+
+ private sealed class AutoInvocationFilter(bool terminate = true) : IAutoFunctionInvocationFilter
+ {
+ public async Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func next)
+ {
+ await next(context);
+
+ if (context.Function.PluginName == nameof(MenuPlugin))
+ {
+ context.Terminate = terminate;
+ }
+ }
+ }
+}