diff --git a/dotnet/samples/GettingStartedWithAgents/Step1_Agent.cs b/dotnet/samples/GettingStartedWithAgents/Step1_Agent.cs
index 682c96001deb..c9ffcdac8a84 100644
--- a/dotnet/samples/GettingStartedWithAgents/Step1_Agent.cs
+++ b/dotnet/samples/GettingStartedWithAgents/Step1_Agent.cs
@@ -27,7 +27,7 @@ public async Task UseSingleChatCompletionAgentAsync()
};
/// Create a chat for agent interaction. For more, .
- AgentGroupChat chat = new();
+ ChatHistory chat = new();
// Respond to user input
await InvokeAgentAsync("Fortune favors the bold.");
@@ -37,11 +37,11 @@ public async Task UseSingleChatCompletionAgentAsync()
// Local function to invoke agent and display the conversation messages.
async Task InvokeAgentAsync(string input)
{
- chat.AddChatMessage(new ChatMessageContent(AuthorRole.User, input));
+ chat.Add(new ChatMessageContent(AuthorRole.User, input));
Console.WriteLine($"# {AuthorRole.User}: '{input}'");
- await foreach (var content in chat.InvokeAsync(agent))
+ await foreach (var content in agent.InvokeAsync(chat))
{
Console.WriteLine($"# {content.Role} - {content.AuthorName ?? "*"}: '{content.Content}'");
}
diff --git a/dotnet/samples/GettingStartedWithAgents/Step8_OpenAIAssistant.cs b/dotnet/samples/GettingStartedWithAgents/Step8_OpenAIAssistant.cs
index 5f03ffb39c8f..09afcfc44826 100644
--- a/dotnet/samples/GettingStartedWithAgents/Step8_OpenAIAssistant.cs
+++ b/dotnet/samples/GettingStartedWithAgents/Step8_OpenAIAssistant.cs
@@ -37,7 +37,7 @@ await OpenAIAssistantAgent.CreateAsync(
agent.Kernel.Plugins.Add(plugin);
// Create a chat for agent interaction.
- var chat = new AgentGroupChat();
+ string threadId = await agent.CreateThreadAsync();
// Respond to user input
try
@@ -49,19 +49,23 @@ await OpenAIAssistantAgent.CreateAsync(
}
finally
{
+ await agent.DeleteThreadAsync(threadId);
await agent.DeleteAsync();
}
// Local function to invoke agent and display the conversation messages.
async Task InvokeAgentAsync(string input)
{
- chat.AddChatMessage(new ChatMessageContent(AuthorRole.User, input));
+ await agent.AddChatMessageAsync(threadId, new ChatMessageContent(AuthorRole.User, input));
Console.WriteLine($"# {AuthorRole.User}: '{input}'");
- await foreach (var content in chat.InvokeAsync(agent))
+ await foreach (var content in agent.InvokeAsync(threadId))
{
- Console.WriteLine($"# {content.Role} - {content.AuthorName ?? "*"}: '{content.Content}'");
+ if (content.Role != AuthorRole.Tool)
+ {
+ Console.WriteLine($"# {content.Role} - {content.AuthorName ?? "*"}: '{content.Content}'");
+ }
}
}
}
diff --git a/dotnet/src/Agents/OpenAI/AssistantThreadActions.cs b/dotnet/src/Agents/OpenAI/AssistantThreadActions.cs
new file mode 100644
index 000000000000..37649844a230
--- /dev/null
+++ b/dotnet/src/Agents/OpenAI/AssistantThreadActions.cs
@@ -0,0 +1,525 @@
+// Copyright (c) Microsoft. All rights reserved.
+using System.Collections.Generic;
+using System.Linq;
+using System.Net;
+using System.Runtime.CompilerServices;
+using System.Text.Json;
+using System.Threading;
+using System.Threading.Tasks;
+using Azure;
+using Azure.AI.OpenAI.Assistants;
+using Microsoft.Extensions.Logging;
+using Microsoft.SemanticKernel.ChatCompletion;
+
+namespace Microsoft.SemanticKernel.Agents.OpenAI;
+
+///
+/// Actions associated with an Open Assistant thread.
+///
+internal static class AssistantThreadActions
+{
+ /*AssistantsClient client, string threadId, OpenAIAssistantConfiguration.PollingConfiguration pollingConfiguration*/
+ private const string FunctionDelimiter = "-";
+
+ private static readonly HashSet s_messageRoles =
+ [
+ AuthorRole.User,
+ AuthorRole.Assistant,
+ ];
+
+ private static readonly HashSet s_pollingStatuses =
+ [
+ RunStatus.Queued,
+ RunStatus.InProgress,
+ RunStatus.Cancelling,
+ ];
+
+ private static readonly HashSet s_terminalStatuses =
+ [
+ RunStatus.Expired,
+ RunStatus.Failed,
+ RunStatus.Cancelled,
+ ];
+
+ ///
+ /// Create a message in the specified thread.
+ ///
+ /// The assistant client
+ /// The thread identifier
+ /// The message to add
+ /// The to monitor for cancellation requests. The default is .
+ /// if a system message is present, without taking any other action
+ public static async Task CreateMessageAsync(AssistantsClient client, string threadId, ChatMessageContent message, CancellationToken cancellationToken)
+ {
+ if (!s_messageRoles.Contains(message.Role))
+ {
+ throw new KernelException($"Invalid message role: {message.Role}");
+ }
+
+ if (string.IsNullOrWhiteSpace(message.Content))
+ {
+ return;
+ }
+
+ await client.CreateMessageAsync(
+ threadId,
+ message.Role.ToMessageRole(),
+ message.Content,
+ cancellationToken: cancellationToken).ConfigureAwait(false);
+ }
+
+ ///
+ /// Retrieves the thread messages.
+ ///
+ /// The assistant client
+ /// The thread identifier
+ /// The to monitor for cancellation requests. The default is .
+ /// Asynchronous enumeration of messages.
+ public static async IAsyncEnumerable GetMessagesAsync(AssistantsClient client, string threadId, [EnumeratorCancellation] CancellationToken cancellationToken)
+ {
+ Dictionary agentNames = []; // Cache agent names by their identifier
+
+ PageableList messages;
+
+ string? lastId = null;
+ do
+ {
+ messages = await client.GetMessagesAsync(threadId, limit: 100, ListSortOrder.Descending, after: lastId, null, cancellationToken).ConfigureAwait(false);
+ foreach (ThreadMessage message in messages)
+ {
+ AuthorRole role = new(message.Role.ToString());
+
+ string? assistantName = null;
+ if (!string.IsNullOrWhiteSpace(message.AssistantId) &&
+ !agentNames.TryGetValue(message.AssistantId, out assistantName))
+ {
+ Assistant assistant = await client.GetAssistantAsync(message.AssistantId, cancellationToken).ConfigureAwait(false);
+ if (!string.IsNullOrWhiteSpace(assistant.Name))
+ {
+ agentNames.Add(assistant.Id, assistant.Name);
+ }
+ }
+
+ assistantName ??= message.AssistantId;
+
+ foreach (MessageContent item in message.ContentItems)
+ {
+ ChatMessageContent? content = null;
+
+ if (item is MessageTextContent contentMessage)
+ {
+ content = GenerateTextMessageContent(assistantName, role, contentMessage);
+ }
+ else if (item is MessageImageFileContent contentImage)
+ {
+ content = GenerateImageFileContent(assistantName, role, contentImage);
+ }
+
+ if (content is not null)
+ {
+ yield return content;
+ }
+ }
+
+ lastId = message.Id;
+ }
+ }
+ while (messages.HasMore);
+ }
+
+ ///
+ /// Invoke the assistant on the specified thread.
+ ///
+ /// The assistant agent to interact with the thread.
+ /// The assistant client
+ /// The thread identifier
+ /// Config to utilize when polling for run state.
+ /// The logger to utilize (might be agent or channel scoped)
+ /// The to monitor for cancellation requests. The default is .
+ /// Asynchronous enumeration of messages.
+ public static async IAsyncEnumerable InvokeAsync(
+ OpenAIAssistantAgent agent,
+ AssistantsClient client,
+ string threadId,
+ OpenAIAssistantConfiguration.PollingConfiguration pollingConfiguration,
+ ILogger logger,
+ [EnumeratorCancellation] CancellationToken cancellationToken)
+ {
+ if (agent.IsDeleted)
+ {
+ throw new KernelException($"Agent Failure - {nameof(OpenAIAssistantAgent)} agent is deleted: {agent.Id}.");
+ }
+
+ ToolDefinition[]? tools = [.. agent.Tools, .. agent.Kernel.Plugins.SelectMany(p => p.Select(f => f.ToToolDefinition(p.Name, FunctionDelimiter)))];
+
+ logger.LogDebug("[{MethodName}] Creating run for agent/thrad: {AgentId}/{ThreadId}", nameof(InvokeAsync), agent.Id, threadId);
+
+ CreateRunOptions options =
+ new(agent.Id)
+ {
+ OverrideInstructions = agent.Instructions,
+ OverrideTools = tools,
+ };
+
+ // Create run
+ ThreadRun run = await client.CreateRunAsync(threadId, options, cancellationToken).ConfigureAwait(false);
+
+ logger.LogInformation("[{MethodName}] Created run: {RunId}", nameof(InvokeAsync), run.Id);
+
+ // Evaluate status and process steps and messages, as encountered.
+ HashSet processedStepIds = [];
+ Dictionary functionSteps = [];
+
+ do
+ {
+ // Poll run and steps until actionable
+ PageableList steps = await PollRunStatusAsync().ConfigureAwait(false);
+
+ // Is in terminal state?
+ if (s_terminalStatuses.Contains(run.Status))
+ {
+ throw new KernelException($"Agent Failure - Run terminated: {run.Status} [{run.Id}]: {run.LastError?.Message ?? "Unknown"}");
+ }
+
+ // Is tool action required?
+ if (run.Status == RunStatus.RequiresAction)
+ {
+ logger.LogDebug("[{MethodName}] Processing run steps: {RunId}", nameof(InvokeAsync), run.Id);
+
+ // Execute functions in parallel and post results at once.
+ FunctionCallContent[] activeFunctionSteps = steps.Data.SelectMany(step => ParseFunctionStep(agent, step)).ToArray();
+ if (activeFunctionSteps.Length > 0)
+ {
+ // Emit function-call content
+ yield return GenerateFunctionCallContent(agent.GetName(), activeFunctionSteps);
+
+ // Invoke functions for each tool-step
+ IEnumerable> functionResultTasks = ExecuteFunctionSteps(agent, activeFunctionSteps, cancellationToken);
+
+ // Block for function results
+ FunctionResultContent[] functionResults = await Task.WhenAll(functionResultTasks).ConfigureAwait(false);
+
+ // Process tool output
+ ToolOutput[] toolOutputs = GenerateToolOutputs(functionResults);
+
+ await client.SubmitToolOutputsToRunAsync(run, toolOutputs, cancellationToken).ConfigureAwait(false);
+ }
+
+ if (logger.IsEnabled(LogLevel.Information)) // Avoid boxing if not enabled
+ {
+ logger.LogInformation("[{MethodName}] Processed #{MessageCount} run steps: {RunId}", nameof(InvokeAsync), activeFunctionSteps.Length, run.Id);
+ }
+ }
+
+ // Enumerate completed messages
+ logger.LogDebug("[{MethodName}] Processing run messages: {RunId}", nameof(InvokeAsync), run.Id);
+
+ IEnumerable completedStepsToProcess =
+ steps
+ .Where(s => s.CompletedAt.HasValue && !processedStepIds.Contains(s.Id))
+ .OrderBy(s => s.CreatedAt);
+
+ int messageCount = 0;
+ foreach (RunStep completedStep in completedStepsToProcess)
+ {
+ if (completedStep.Type.Equals(RunStepType.ToolCalls))
+ {
+ RunStepToolCallDetails toolCallDetails = (RunStepToolCallDetails)completedStep.StepDetails;
+
+ foreach (RunStepToolCall toolCall in toolCallDetails.ToolCalls)
+ {
+ ChatMessageContent? content = null;
+
+ // Process code-interpreter content
+ if (toolCall is RunStepCodeInterpreterToolCall toolCodeInterpreter)
+ {
+ content = GenerateCodeInterpreterContent(agent.GetName(), toolCodeInterpreter);
+ }
+ // Process function result content
+ else if (toolCall is RunStepFunctionToolCall toolFunction)
+ {
+ FunctionCallContent functionStep = functionSteps[toolFunction.Id]; // Function step always captured on invocation
+ content = GenerateFunctionResultContent(agent.GetName(), functionStep, toolFunction.Output);
+ }
+
+ if (content is not null)
+ {
+ ++messageCount;
+
+ yield return content;
+ }
+ }
+ }
+ else if (completedStep.Type.Equals(RunStepType.MessageCreation))
+ {
+ RunStepMessageCreationDetails messageCreationDetails = (RunStepMessageCreationDetails)completedStep.StepDetails;
+
+ // Retrieve the message
+ ThreadMessage? message = await RetrieveMessageAsync(messageCreationDetails, cancellationToken).ConfigureAwait(false);
+
+ if (message is not null)
+ {
+ AuthorRole role = new(message.Role.ToString());
+
+ foreach (MessageContent itemContent in message.ContentItems)
+ {
+ ChatMessageContent? content = null;
+
+ // Process text content
+ if (itemContent is MessageTextContent contentMessage)
+ {
+ content = GenerateTextMessageContent(agent.GetName(), role, contentMessage);
+ }
+ // Process image content
+ else if (itemContent is MessageImageFileContent contentImage)
+ {
+ content = GenerateImageFileContent(agent.GetName(), role, contentImage);
+ }
+
+ if (content is not null)
+ {
+ ++messageCount;
+
+ yield return content;
+ }
+ }
+ }
+ }
+
+ processedStepIds.Add(completedStep.Id);
+ }
+
+ if (logger.IsEnabled(LogLevel.Information)) // Avoid boxing if not enabled
+ {
+ logger.LogInformation("[{MethodName}] Processed #{MessageCount} run messages: {RunId}", nameof(InvokeAsync), messageCount, run.Id);
+ }
+ }
+ while (RunStatus.Completed != run.Status);
+
+ logger.LogInformation("[{MethodName}] Completed run: {RunId}", nameof(InvokeAsync), run.Id);
+
+ // Local function to assist in run polling (participates in method closure).
+ async Task> PollRunStatusAsync()
+ {
+ logger.LogInformation("[{MethodName}] Polling run status: {RunId}", nameof(PollRunStatusAsync), run.Id);
+
+ int count = 0;
+
+ do
+ {
+ // Reduce polling frequency after a couple attempts
+ await Task.Delay(count >= 2 ? pollingConfiguration.RunPollingInterval : pollingConfiguration.RunPollingBackoff, cancellationToken).ConfigureAwait(false);
+ ++count;
+
+#pragma warning disable CA1031 // Do not catch general exception types
+ try
+ {
+ run = await client.GetRunAsync(threadId, run.Id, cancellationToken).ConfigureAwait(false);
+ }
+ catch
+ {
+ // Retry anyway..
+ }
+#pragma warning restore CA1031 // Do not catch general exception types
+ }
+ while (s_pollingStatuses.Contains(run.Status));
+
+ logger.LogInformation("[{MethodName}] Run status is {RunStatus}: {RunId}", nameof(PollRunStatusAsync), run.Status, run.Id);
+
+ return await client.GetRunStepsAsync(run, cancellationToken: cancellationToken).ConfigureAwait(false);
+ }
+
+ // Local function to capture kernel function state for further processing (participates in method closure).
+ IEnumerable ParseFunctionStep(OpenAIAssistantAgent agent, RunStep step)
+ {
+ if (step.Status == RunStepStatus.InProgress && step.StepDetails is RunStepToolCallDetails callDetails)
+ {
+ foreach (RunStepFunctionToolCall toolCall in callDetails.ToolCalls.OfType())
+ {
+ var nameParts = FunctionName.Parse(toolCall.Name, FunctionDelimiter);
+
+ KernelArguments functionArguments = [];
+ if (!string.IsNullOrWhiteSpace(toolCall.Arguments))
+ {
+ Dictionary arguments = JsonSerializer.Deserialize>(toolCall.Arguments)!;
+ foreach (var argumentKvp in arguments)
+ {
+ functionArguments[argumentKvp.Key] = argumentKvp.Value.ToString();
+ }
+ }
+
+ var content = new FunctionCallContent(nameParts.Name, nameParts.PluginName, toolCall.Id, functionArguments);
+
+ functionSteps.Add(toolCall.Id, content);
+
+ yield return content;
+ }
+ }
+ }
+
+ async Task RetrieveMessageAsync(RunStepMessageCreationDetails detail, CancellationToken cancellationToken)
+ {
+ ThreadMessage? message = null;
+
+ bool retry = false;
+ int count = 0;
+ do
+ {
+ try
+ {
+ message = await client.GetMessageAsync(threadId, detail.MessageCreation.MessageId, cancellationToken).ConfigureAwait(false);
+ }
+ catch (RequestFailedException exception)
+ {
+ // Step has provided the message-id. Retry on of NotFound/404 exists.
+ // Extremely rarely there might be a synchronization issue between the
+ // assistant response and message-service.
+ retry = exception.Status == (int)HttpStatusCode.NotFound && count < 3;
+ }
+
+ if (retry)
+ {
+ await Task.Delay(pollingConfiguration.MessageSynchronizationDelay, cancellationToken).ConfigureAwait(false);
+ }
+
+ ++count;
+ }
+ while (retry);
+
+ return message;
+ }
+ }
+
+ private static AnnotationContent GenerateAnnotationContent(MessageTextAnnotation annotation)
+ {
+ string? fileId = null;
+ if (annotation is MessageTextFileCitationAnnotation citationAnnotation)
+ {
+ fileId = citationAnnotation.FileId;
+ }
+ else if (annotation is MessageTextFilePathAnnotation pathAnnotation)
+ {
+ fileId = pathAnnotation.FileId;
+ }
+
+ return
+ new()
+ {
+ Quote = annotation.Text,
+ StartIndex = annotation.StartIndex,
+ EndIndex = annotation.EndIndex,
+ FileId = fileId,
+ };
+ }
+
+ private static ChatMessageContent GenerateImageFileContent(string agentName, AuthorRole role, MessageImageFileContent contentImage)
+ {
+ return
+ new ChatMessageContent(
+ role,
+ [
+ new FileReferenceContent(contentImage.FileId)
+ ])
+ {
+ AuthorName = agentName,
+ };
+ }
+
+ private static ChatMessageContent? GenerateTextMessageContent(string agentName, AuthorRole role, MessageTextContent contentMessage)
+ {
+ ChatMessageContent? messageContent = null;
+
+ string textContent = contentMessage.Text.Trim();
+
+ if (!string.IsNullOrWhiteSpace(textContent))
+ {
+ messageContent =
+ new(role, textContent)
+ {
+ AuthorName = agentName
+ };
+
+ foreach (MessageTextAnnotation annotation in contentMessage.Annotations)
+ {
+ messageContent.Items.Add(GenerateAnnotationContent(annotation));
+ }
+ }
+
+ return messageContent;
+ }
+
+ private static ChatMessageContent GenerateCodeInterpreterContent(string agentName, RunStepCodeInterpreterToolCall contentCodeInterpreter)
+ {
+ return
+ new ChatMessageContent(
+ AuthorRole.Tool,
+ [
+ new TextContent(contentCodeInterpreter.Input)
+ ])
+ {
+ AuthorName = agentName,
+ };
+ }
+
+ private static ChatMessageContent GenerateFunctionCallContent(string agentName, FunctionCallContent[] functionSteps)
+ {
+ ChatMessageContent functionCallContent = new(AuthorRole.Tool, content: null)
+ {
+ AuthorName = agentName
+ };
+
+ functionCallContent.Items.AddRange(functionSteps);
+
+ return functionCallContent;
+ }
+
+ private static ChatMessageContent GenerateFunctionResultContent(string agentName, FunctionCallContent functionStep, string result)
+ {
+ ChatMessageContent functionCallContent = new(AuthorRole.Tool, content: null)
+ {
+ AuthorName = agentName
+ };
+
+ functionCallContent.Items.Add(
+ new FunctionResultContent(
+ functionStep.FunctionName,
+ functionStep.PluginName,
+ functionStep.Id,
+ result));
+
+ return functionCallContent;
+ }
+
+ private static Task[] ExecuteFunctionSteps(OpenAIAssistantAgent agent, FunctionCallContent[] functionSteps, CancellationToken cancellationToken)
+ {
+ Task[] functionTasks = new Task[functionSteps.Length];
+
+ for (int index = 0; index < functionSteps.Length; ++index)
+ {
+ functionTasks[index] = functionSteps[index].InvokeAsync(agent.Kernel, cancellationToken);
+ }
+
+ return functionTasks;
+ }
+
+ private static ToolOutput[] GenerateToolOutputs(FunctionResultContent[] functionResults)
+ {
+ ToolOutput[] toolOutputs = new ToolOutput[functionResults.Length];
+
+ for (int index = 0; index < functionResults.Length; ++index)
+ {
+ FunctionResultContent functionResult = functionResults[index];
+
+ object resultValue = functionResult.Result ?? string.Empty;
+
+ if (resultValue is not string textResult)
+ {
+ textResult = JsonSerializer.Serialize(resultValue);
+ }
+
+ toolOutputs[index] = new ToolOutput(functionResult.CallId, textResult!);
+ }
+
+ return toolOutputs;
+ }
+}
diff --git a/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs b/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs
index f101aa9ffb83..b46cdb013c18 100644
--- a/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs
+++ b/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs
@@ -162,15 +162,91 @@ public static async Task RetrieveAsync(
};
}
- ///
- public async Task DeleteAsync(CancellationToken cancellationToken = default)
+ ///
+ /// Create a new assistant thread.
+ ///
+ /// The to monitor for cancellation requests. The default is .
+ /// The thread identifier
+ public async Task CreateThreadAsync(CancellationToken cancellationToken = default)
{
- if (this.IsDeleted)
+ AssistantThread thread = await this._client.CreateThreadAsync(cancellationToken).ConfigureAwait(false);
+
+ return thread.Id;
+ }
+
+ ///
+ /// Create a new assistant thread.
+ ///
+ /// The thread identifier
+ /// The to monitor for cancellation requests. The default is .
+ /// The thread identifier
+ public async Task DeleteThreadAsync(
+ string threadId,
+ CancellationToken cancellationToken = default)
+ {
+ // Validate input
+ Verify.NotNullOrWhiteSpace(threadId, nameof(threadId));
+
+ return await this._client.DeleteThreadAsync(threadId, cancellationToken).ConfigureAwait(false);
+ }
+
+ ///
+ /// Adds a message to the specified thread.
+ ///
+ /// The thread identifier
+ /// A non-system message with which to append to the conversation.
+ /// The to monitor for cancellation requests. The default is .
+ public Task AddChatMessageAsync(string threadId, ChatMessageContent message, CancellationToken cancellationToken = default)
+ {
+ this.ThrowIfDeleted();
+
+ return AssistantThreadActions.CreateMessageAsync(this._client, threadId, message, cancellationToken);
+ }
+
+ ///
+ /// Gets messages for a specified thread.
+ ///
+ /// The thread identifier
+ /// The to monitor for cancellation requests. The default is .
+ /// Asynchronous enumeration of messages.
+ public IAsyncEnumerable GetThreadMessagesAsync(string threadId, CancellationToken cancellationToken = default)
+ {
+ this.ThrowIfDeleted();
+
+ return AssistantThreadActions.GetMessagesAsync(this._client, threadId, cancellationToken);
+ }
+
+ ///
+ /// Delete the assistant definition.
+ ///
+ ///
+ /// True if assistant definition has been deleted
+ ///
+ /// Assistant based agent will not be useable after deletion.
+ ///
+ public async Task DeleteAsync(CancellationToken cancellationToken = default)
+ {
+ if (!this.IsDeleted)
{
- return;
+ this.IsDeleted = (await this._client.DeleteAssistantAsync(this.Id, cancellationToken).ConfigureAwait(false)).Value;
}
- this.IsDeleted = (await this._client.DeleteAssistantAsync(this.Id, cancellationToken).ConfigureAwait(false)).Value;
+ return this.IsDeleted;
+ }
+
+ ///
+ /// Invoke the assistant on the specified thread.
+ ///
+ /// The thread identifier
+ /// The to monitor for cancellation requests. The default is .
+ /// Asynchronous enumeration of messages.
+ public IAsyncEnumerable InvokeAsync(
+ string threadId,
+ CancellationToken cancellationToken = default)
+ {
+ this.ThrowIfDeleted();
+
+ return AssistantThreadActions.InvokeAsync(this, this._client, threadId, this._config.Polling, this.Logger, cancellationToken);
}
///
@@ -212,7 +288,19 @@ protected override async Task CreateChannelAsync(CancellationToken
this.Logger.LogInformation("[{MethodName}] Created assistant thread: {ThreadId}", nameof(CreateChannelAsync), thread.Id);
- return new OpenAIAssistantChannel(this._client, thread.Id, this._config.Polling);
+ return
+ new OpenAIAssistantChannel(this._client, thread.Id, this._config.Polling)
+ {
+ Logger = this.LoggerFactory.CreateLogger()
+ };
+ }
+
+ internal void ThrowIfDeleted()
+ {
+ if (this.IsDeleted)
+ {
+ throw new KernelException($"Agent Failure - {nameof(OpenAIAssistantAgent)} agent is deleted: {this.Id}.");
+ }
}
///
diff --git a/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs b/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs
index 166884cf7a11..b84ef800ebd4 100644
--- a/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs
+++ b/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs
@@ -1,15 +1,8 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
-using System.Linq;
-using System.Net;
-using System.Runtime.CompilerServices;
-using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
-using Azure;
using Azure.AI.OpenAI.Assistants;
-using Microsoft.Extensions.Logging;
-using Microsoft.SemanticKernel.ChatCompletion;
namespace Microsoft.SemanticKernel.Agents.OpenAI;
@@ -19,485 +12,31 @@ namespace Microsoft.SemanticKernel.Agents.OpenAI;
internal sealed class OpenAIAssistantChannel(AssistantsClient client, string threadId, OpenAIAssistantConfiguration.PollingConfiguration pollingConfiguration)
: AgentChannel
{
- private const string FunctionDelimiter = "-";
-
- private static readonly HashSet s_pollingStatuses =
- [
- RunStatus.Queued,
- RunStatus.InProgress,
- RunStatus.Cancelling,
- ];
-
- private static readonly HashSet s_terminalStatuses =
- [
- RunStatus.Expired,
- RunStatus.Failed,
- RunStatus.Cancelled,
- ];
-
private readonly AssistantsClient _client = client;
private readonly string _threadId = threadId;
- private readonly Dictionary _agentTools = [];
- private readonly Dictionary _agentNames = []; // Cache agent names by their identifier for GetHistoryAsync()
///
protected override async Task ReceiveAsync(IReadOnlyList history, CancellationToken cancellationToken)
{
foreach (ChatMessageContent message in history)
{
- if (string.IsNullOrWhiteSpace(message.Content))
- {
- continue;
- }
-
- await this._client.CreateMessageAsync(
- this._threadId,
- message.Role.ToMessageRole(),
- message.Content,
- cancellationToken: cancellationToken).ConfigureAwait(false);
+ await AssistantThreadActions.CreateMessageAsync(this._client, this._threadId, message, cancellationToken).ConfigureAwait(false);
}
}
///
- protected override async IAsyncEnumerable InvokeAsync(
+ protected override IAsyncEnumerable InvokeAsync(
OpenAIAssistantAgent agent,
- [EnumeratorCancellation] CancellationToken cancellationToken)
+ CancellationToken cancellationToken)
{
- if (agent.IsDeleted)
- {
- throw new KernelException($"Agent Failure - {nameof(OpenAIAssistantAgent)} agent is deleted: {agent.Id}.");
- }
-
- if (!this._agentTools.TryGetValue(agent.Id, out ToolDefinition[]? tools))
- {
- tools = [.. agent.Tools, .. agent.Kernel.Plugins.SelectMany(p => p.Select(f => f.ToToolDefinition(p.Name, FunctionDelimiter)))];
- this._agentTools.Add(agent.Id, tools);
- }
-
- if (!this._agentNames.ContainsKey(agent.Id) && !string.IsNullOrWhiteSpace(agent.Name))
- {
- this._agentNames.Add(agent.Id, agent.Name);
- }
-
- this.Logger.LogDebug("[{MethodName}] Creating run for agent/thrad: {AgentId}/{ThreadId}", nameof(InvokeAsync), agent.Id, this._threadId);
-
- CreateRunOptions options =
- new(agent.Id)
- {
- OverrideInstructions = agent.Instructions,
- OverrideTools = tools,
- };
-
- // Create run
- ThreadRun run = await this._client.CreateRunAsync(this._threadId, options, cancellationToken).ConfigureAwait(false);
-
- this.Logger.LogInformation("[{MethodName}] Created run: {RunId}", nameof(InvokeAsync), run.Id);
-
- // Evaluate status and process steps and messages, as encountered.
- HashSet processedStepIds = [];
- Dictionary functionSteps = [];
-
- do
- {
- // Poll run and steps until actionable
- PageableList steps = await PollRunStatusAsync().ConfigureAwait(false);
-
- // Is in terminal state?
- if (s_terminalStatuses.Contains(run.Status))
- {
- throw new KernelException($"Agent Failure - Run terminated: {run.Status} [{run.Id}]: {run.LastError?.Message ?? "Unknown"}");
- }
-
- // Is tool action required?
- if (run.Status == RunStatus.RequiresAction)
- {
- this.Logger.LogDebug("[{MethodName}] Processing run steps: {RunId}", nameof(InvokeAsync), run.Id);
-
- // Execute functions in parallel and post results at once.
- FunctionCallContent[] activeFunctionSteps = steps.Data.SelectMany(step => ParseFunctionStep(agent, step)).ToArray();
- if (activeFunctionSteps.Length > 0)
- {
- // Emit function-call content
- yield return GenerateFunctionCallContent(agent.GetName(), activeFunctionSteps);
-
- // Invoke functions for each tool-step
- IEnumerable> functionResultTasks = ExecuteFunctionSteps(agent, activeFunctionSteps, cancellationToken);
-
- // Block for function results
- FunctionResultContent[] functionResults = await Task.WhenAll(functionResultTasks).ConfigureAwait(false);
-
- // Process tool output
- ToolOutput[] toolOutputs = GenerateToolOutputs(functionResults);
-
- await this._client.SubmitToolOutputsToRunAsync(run, toolOutputs, cancellationToken).ConfigureAwait(false);
- }
-
- if (this.Logger.IsEnabled(LogLevel.Information)) // Avoid boxing if not enabled
- {
- this.Logger.LogInformation("[{MethodName}] Processed #{MessageCount} run steps: {RunId}", nameof(InvokeAsync), activeFunctionSteps.Length, run.Id);
- }
- }
-
- // Enumerate completed messages
- this.Logger.LogDebug("[{MethodName}] Processing run messages: {RunId}", nameof(InvokeAsync), run.Id);
-
- IEnumerable completedStepsToProcess =
- steps
- .Where(s => s.CompletedAt.HasValue && !processedStepIds.Contains(s.Id))
- .OrderBy(s => s.CreatedAt);
-
- int messageCount = 0;
- foreach (RunStep completedStep in completedStepsToProcess)
- {
- if (completedStep.Type.Equals(RunStepType.ToolCalls))
- {
- RunStepToolCallDetails toolCallDetails = (RunStepToolCallDetails)completedStep.StepDetails;
-
- foreach (RunStepToolCall toolCall in toolCallDetails.ToolCalls)
- {
- ChatMessageContent? content = null;
-
- // Process code-interpreter content
- if (toolCall is RunStepCodeInterpreterToolCall toolCodeInterpreter)
- {
- content = GenerateCodeInterpreterContent(agent.GetName(), toolCodeInterpreter);
- }
- // Process function result content
- else if (toolCall is RunStepFunctionToolCall toolFunction)
- {
- FunctionCallContent functionStep = functionSteps[toolFunction.Id]; // Function step always captured on invocation
- content = GenerateFunctionResultContent(agent.GetName(), functionStep, toolFunction.Output);
- }
+ agent.ThrowIfDeleted();
- if (content is not null)
- {
- ++messageCount;
-
- yield return content;
- }
- }
- }
- else if (completedStep.Type.Equals(RunStepType.MessageCreation))
- {
- RunStepMessageCreationDetails messageCreationDetails = (RunStepMessageCreationDetails)completedStep.StepDetails;
-
- // Retrieve the message
- ThreadMessage? message = await this.RetrieveMessageAsync(messageCreationDetails, cancellationToken).ConfigureAwait(false);
-
- if (message is not null)
- {
- AuthorRole role = new(message.Role.ToString());
-
- foreach (MessageContent itemContent in message.ContentItems)
- {
- ChatMessageContent? content = null;
-
- // Process text content
- if (itemContent is MessageTextContent contentMessage)
- {
- content = GenerateTextMessageContent(agent.GetName(), role, contentMessage);
- }
- // Process image content
- else if (itemContent is MessageImageFileContent contentImage)
- {
- content = GenerateImageFileContent(agent.GetName(), role, contentImage);
- }
-
- if (content is not null)
- {
- ++messageCount;
-
- yield return content;
- }
- }
- }
- }
-
- processedStepIds.Add(completedStep.Id);
- }
-
- if (this.Logger.IsEnabled(LogLevel.Information)) // Avoid boxing if not enabled
- {
- this.Logger.LogInformation("[{MethodName}] Processed #{MessageCount} run messages: {RunId}", nameof(InvokeAsync), messageCount, run.Id);
- }
- }
- while (RunStatus.Completed != run.Status);
-
- this.Logger.LogInformation("[{MethodName}] Completed run: {RunId}", nameof(InvokeAsync), run.Id);
-
- // Local function to assist in run polling (participates in method closure).
- async Task> PollRunStatusAsync()
- {
- this.Logger.LogInformation("[{MethodName}] Polling run status: {RunId}", nameof(PollRunStatusAsync), run.Id);
-
- int count = 0;
-
- do
- {
- // Reduce polling frequency after a couple attempts
- await Task.Delay(count >= 2 ? pollingConfiguration.RunPollingInterval : pollingConfiguration.RunPollingBackoff, cancellationToken).ConfigureAwait(false);
- ++count;
-
-#pragma warning disable CA1031 // Do not catch general exception types
- try
- {
- run = await this._client.GetRunAsync(this._threadId, run.Id, cancellationToken).ConfigureAwait(false);
- }
- catch
- {
- // Retry anyway..
- }
-#pragma warning restore CA1031 // Do not catch general exception types
- }
- while (s_pollingStatuses.Contains(run.Status));
-
- this.Logger.LogInformation("[{MethodName}] Run status is {RunStatus}: {RunId}", nameof(PollRunStatusAsync), run.Status, run.Id);
-
- return await this._client.GetRunStepsAsync(run, cancellationToken: cancellationToken).ConfigureAwait(false);
- }
-
- // Local function to capture kernel function state for further processing (participates in method closure).
- IEnumerable ParseFunctionStep(OpenAIAssistantAgent agent, RunStep step)
- {
- if (step.Status == RunStepStatus.InProgress && step.StepDetails is RunStepToolCallDetails callDetails)
- {
- foreach (RunStepFunctionToolCall toolCall in callDetails.ToolCalls.OfType())
- {
- var nameParts = FunctionName.Parse(toolCall.Name, FunctionDelimiter);
-
- KernelArguments functionArguments = [];
- if (!string.IsNullOrWhiteSpace(toolCall.Arguments))
- {
- Dictionary arguments = JsonSerializer.Deserialize>(toolCall.Arguments)!;
- foreach (var argumentKvp in arguments)
- {
- functionArguments[argumentKvp.Key] = argumentKvp.Value.ToString();
- }
- }
-
- var content = new FunctionCallContent(nameParts.Name, nameParts.PluginName, toolCall.Id, functionArguments);
-
- functionSteps.Add(toolCall.Id, content);
-
- yield return content;
- }
- }
- }
+ return AssistantThreadActions.InvokeAsync(agent, this._client, this._threadId, pollingConfiguration, this.Logger, cancellationToken);
}
///
- protected override async IAsyncEnumerable GetHistoryAsync([EnumeratorCancellation] CancellationToken cancellationToken)
- {
- PageableList messages;
-
- string? lastId = null;
- do
- {
- messages = await this._client.GetMessagesAsync(this._threadId, limit: 100, ListSortOrder.Descending, after: lastId, null, cancellationToken).ConfigureAwait(false);
- foreach (ThreadMessage message in messages)
- {
- AuthorRole role = new(message.Role.ToString());
-
- string? assistantName = null;
- if (!string.IsNullOrWhiteSpace(message.AssistantId) &&
- !this._agentNames.TryGetValue(message.AssistantId, out assistantName))
- {
- Assistant assistant = await this._client.GetAssistantAsync(message.AssistantId, cancellationToken).ConfigureAwait(false);
- if (!string.IsNullOrWhiteSpace(assistant.Name))
- {
- this._agentNames.Add(assistant.Id, assistant.Name);
- }
- }
-
- assistantName ??= message.AssistantId;
-
- foreach (MessageContent item in message.ContentItems)
- {
- ChatMessageContent? content = null;
-
- if (item is MessageTextContent contentMessage)
- {
- content = GenerateTextMessageContent(assistantName, role, contentMessage);
- }
- else if (item is MessageImageFileContent contentImage)
- {
- content = GenerateImageFileContent(assistantName, role, contentImage);
- }
-
- if (content is not null)
- {
- yield return content;
- }
- }
-
- lastId = message.Id;
- }
- }
- while (messages.HasMore);
- }
-
- private static AnnotationContent GenerateAnnotationContent(MessageTextAnnotation annotation)
- {
- string? fileId = null;
- if (annotation is MessageTextFileCitationAnnotation citationAnnotation)
- {
- fileId = citationAnnotation.FileId;
- }
- else if (annotation is MessageTextFilePathAnnotation pathAnnotation)
- {
- fileId = pathAnnotation.FileId;
- }
-
- return
- new()
- {
- Quote = annotation.Text,
- StartIndex = annotation.StartIndex,
- EndIndex = annotation.EndIndex,
- FileId = fileId,
- };
- }
-
- private static ChatMessageContent GenerateImageFileContent(string agentName, AuthorRole role, MessageImageFileContent contentImage)
- {
- return
- new ChatMessageContent(
- role,
- [
- new FileReferenceContent(contentImage.FileId)
- ])
- {
- AuthorName = agentName,
- };
- }
-
- private static ChatMessageContent? GenerateTextMessageContent(string agentName, AuthorRole role, MessageTextContent contentMessage)
- {
- ChatMessageContent? messageContent = null;
-
- string textContent = contentMessage.Text.Trim();
-
- if (!string.IsNullOrWhiteSpace(textContent))
- {
- messageContent =
- new(role, textContent)
- {
- AuthorName = agentName
- };
-
- foreach (MessageTextAnnotation annotation in contentMessage.Annotations)
- {
- messageContent.Items.Add(GenerateAnnotationContent(annotation));
- }
- }
-
- return messageContent;
- }
-
- private static ChatMessageContent GenerateCodeInterpreterContent(string agentName, RunStepCodeInterpreterToolCall contentCodeInterpreter)
- {
- return
- new ChatMessageContent(
- AuthorRole.Tool,
- [
- new TextContent(contentCodeInterpreter.Input)
- ])
- {
- AuthorName = agentName,
- };
- }
-
- private static ChatMessageContent GenerateFunctionCallContent(string agentName, FunctionCallContent[] functionSteps)
- {
- ChatMessageContent functionCallContent = new(AuthorRole.Tool, content: null)
- {
- AuthorName = agentName
- };
-
- functionCallContent.Items.AddRange(functionSteps);
-
- return functionCallContent;
- }
-
- private static ChatMessageContent GenerateFunctionResultContent(string agentName, FunctionCallContent functionStep, string result)
- {
- ChatMessageContent functionCallContent = new(AuthorRole.Tool, content: null)
- {
- AuthorName = agentName
- };
-
- functionCallContent.Items.Add(
- new FunctionResultContent(
- functionStep.FunctionName,
- functionStep.PluginName,
- functionStep.Id,
- result));
-
- return functionCallContent;
- }
-
- private static Task[] ExecuteFunctionSteps(OpenAIAssistantAgent agent, FunctionCallContent[] functionSteps, CancellationToken cancellationToken)
- {
- Task[] functionTasks = new Task[functionSteps.Length];
-
- for (int index = 0; index < functionSteps.Length; ++index)
- {
- functionTasks[index] = functionSteps[index].InvokeAsync(agent.Kernel, cancellationToken);
- }
-
- return functionTasks;
- }
-
- private static ToolOutput[] GenerateToolOutputs(FunctionResultContent[] functionResults)
+ protected override IAsyncEnumerable GetHistoryAsync(CancellationToken cancellationToken)
{
- ToolOutput[] toolOutputs = new ToolOutput[functionResults.Length];
-
- for (int index = 0; index < functionResults.Length; ++index)
- {
- FunctionResultContent functionResult = functionResults[index];
-
- object resultValue = functionResult.Result ?? string.Empty;
-
- if (resultValue is not string textResult)
- {
- textResult = JsonSerializer.Serialize(resultValue);
- }
-
- toolOutputs[index] = new ToolOutput(functionResult.CallId, textResult!);
- }
-
- return toolOutputs;
- }
-
- private async Task RetrieveMessageAsync(RunStepMessageCreationDetails detail, CancellationToken cancellationToken)
- {
- ThreadMessage? message = null;
-
- bool retry = false;
- int count = 0;
- do
- {
- try
- {
- message = await this._client.GetMessageAsync(this._threadId, detail.MessageCreation.MessageId, cancellationToken).ConfigureAwait(false);
- }
- catch (RequestFailedException exception)
- {
- // Step has provided the message-id. Retry on of NotFound/404 exists.
- // Extremely rarely there might be a synchronization issue between the
- // assistant response and message-service.
- retry = exception.Status == (int)HttpStatusCode.NotFound && count < 3;
- }
-
- if (retry)
- {
- await Task.Delay(pollingConfiguration.MessageSynchronizationDelay, cancellationToken).ConfigureAwait(false);
- }
-
- ++count;
- }
- while (retry);
-
- return message;
+ return AssistantThreadActions.GetMessagesAsync(this._client, this._threadId, cancellationToken);
}
}