diff --git a/dotnet/samples/GettingStartedWithAgents/Step1_Agent.cs b/dotnet/samples/GettingStartedWithAgents/Step1_Agent.cs index 682c96001deb..c9ffcdac8a84 100644 --- a/dotnet/samples/GettingStartedWithAgents/Step1_Agent.cs +++ b/dotnet/samples/GettingStartedWithAgents/Step1_Agent.cs @@ -27,7 +27,7 @@ public async Task UseSingleChatCompletionAgentAsync() }; /// Create a chat for agent interaction. For more, . - AgentGroupChat chat = new(); + ChatHistory chat = new(); // Respond to user input await InvokeAgentAsync("Fortune favors the bold."); @@ -37,11 +37,11 @@ public async Task UseSingleChatCompletionAgentAsync() // Local function to invoke agent and display the conversation messages. async Task InvokeAgentAsync(string input) { - chat.AddChatMessage(new ChatMessageContent(AuthorRole.User, input)); + chat.Add(new ChatMessageContent(AuthorRole.User, input)); Console.WriteLine($"# {AuthorRole.User}: '{input}'"); - await foreach (var content in chat.InvokeAsync(agent)) + await foreach (var content in agent.InvokeAsync(chat)) { Console.WriteLine($"# {content.Role} - {content.AuthorName ?? "*"}: '{content.Content}'"); } diff --git a/dotnet/samples/GettingStartedWithAgents/Step8_OpenAIAssistant.cs b/dotnet/samples/GettingStartedWithAgents/Step8_OpenAIAssistant.cs index 5f03ffb39c8f..09afcfc44826 100644 --- a/dotnet/samples/GettingStartedWithAgents/Step8_OpenAIAssistant.cs +++ b/dotnet/samples/GettingStartedWithAgents/Step8_OpenAIAssistant.cs @@ -37,7 +37,7 @@ await OpenAIAssistantAgent.CreateAsync( agent.Kernel.Plugins.Add(plugin); // Create a chat for agent interaction. - var chat = new AgentGroupChat(); + string threadId = await agent.CreateThreadAsync(); // Respond to user input try @@ -49,19 +49,23 @@ await OpenAIAssistantAgent.CreateAsync( } finally { + await agent.DeleteThreadAsync(threadId); await agent.DeleteAsync(); } // Local function to invoke agent and display the conversation messages. async Task InvokeAgentAsync(string input) { - chat.AddChatMessage(new ChatMessageContent(AuthorRole.User, input)); + await agent.AddChatMessageAsync(threadId, new ChatMessageContent(AuthorRole.User, input)); Console.WriteLine($"# {AuthorRole.User}: '{input}'"); - await foreach (var content in chat.InvokeAsync(agent)) + await foreach (var content in agent.InvokeAsync(threadId)) { - Console.WriteLine($"# {content.Role} - {content.AuthorName ?? "*"}: '{content.Content}'"); + if (content.Role != AuthorRole.Tool) + { + Console.WriteLine($"# {content.Role} - {content.AuthorName ?? "*"}: '{content.Content}'"); + } } } } diff --git a/dotnet/src/Agents/OpenAI/AssistantThreadActions.cs b/dotnet/src/Agents/OpenAI/AssistantThreadActions.cs new file mode 100644 index 000000000000..37649844a230 --- /dev/null +++ b/dotnet/src/Agents/OpenAI/AssistantThreadActions.cs @@ -0,0 +1,525 @@ +// Copyright (c) Microsoft. All rights reserved. +using System.Collections.Generic; +using System.Linq; +using System.Net; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Azure; +using Azure.AI.OpenAI.Assistants; +using Microsoft.Extensions.Logging; +using Microsoft.SemanticKernel.ChatCompletion; + +namespace Microsoft.SemanticKernel.Agents.OpenAI; + +/// +/// Actions associated with an Open Assistant thread. +/// +internal static class AssistantThreadActions +{ + /*AssistantsClient client, string threadId, OpenAIAssistantConfiguration.PollingConfiguration pollingConfiguration*/ + private const string FunctionDelimiter = "-"; + + private static readonly HashSet s_messageRoles = + [ + AuthorRole.User, + AuthorRole.Assistant, + ]; + + private static readonly HashSet s_pollingStatuses = + [ + RunStatus.Queued, + RunStatus.InProgress, + RunStatus.Cancelling, + ]; + + private static readonly HashSet s_terminalStatuses = + [ + RunStatus.Expired, + RunStatus.Failed, + RunStatus.Cancelled, + ]; + + /// + /// Create a message in the specified thread. + /// + /// The assistant client + /// The thread identifier + /// The message to add + /// The to monitor for cancellation requests. The default is . + /// if a system message is present, without taking any other action + public static async Task CreateMessageAsync(AssistantsClient client, string threadId, ChatMessageContent message, CancellationToken cancellationToken) + { + if (!s_messageRoles.Contains(message.Role)) + { + throw new KernelException($"Invalid message role: {message.Role}"); + } + + if (string.IsNullOrWhiteSpace(message.Content)) + { + return; + } + + await client.CreateMessageAsync( + threadId, + message.Role.ToMessageRole(), + message.Content, + cancellationToken: cancellationToken).ConfigureAwait(false); + } + + /// + /// Retrieves the thread messages. + /// + /// The assistant client + /// The thread identifier + /// The to monitor for cancellation requests. The default is . + /// Asynchronous enumeration of messages. + public static async IAsyncEnumerable GetMessagesAsync(AssistantsClient client, string threadId, [EnumeratorCancellation] CancellationToken cancellationToken) + { + Dictionary agentNames = []; // Cache agent names by their identifier + + PageableList messages; + + string? lastId = null; + do + { + messages = await client.GetMessagesAsync(threadId, limit: 100, ListSortOrder.Descending, after: lastId, null, cancellationToken).ConfigureAwait(false); + foreach (ThreadMessage message in messages) + { + AuthorRole role = new(message.Role.ToString()); + + string? assistantName = null; + if (!string.IsNullOrWhiteSpace(message.AssistantId) && + !agentNames.TryGetValue(message.AssistantId, out assistantName)) + { + Assistant assistant = await client.GetAssistantAsync(message.AssistantId, cancellationToken).ConfigureAwait(false); + if (!string.IsNullOrWhiteSpace(assistant.Name)) + { + agentNames.Add(assistant.Id, assistant.Name); + } + } + + assistantName ??= message.AssistantId; + + foreach (MessageContent item in message.ContentItems) + { + ChatMessageContent? content = null; + + if (item is MessageTextContent contentMessage) + { + content = GenerateTextMessageContent(assistantName, role, contentMessage); + } + else if (item is MessageImageFileContent contentImage) + { + content = GenerateImageFileContent(assistantName, role, contentImage); + } + + if (content is not null) + { + yield return content; + } + } + + lastId = message.Id; + } + } + while (messages.HasMore); + } + + /// + /// Invoke the assistant on the specified thread. + /// + /// The assistant agent to interact with the thread. + /// The assistant client + /// The thread identifier + /// Config to utilize when polling for run state. + /// The logger to utilize (might be agent or channel scoped) + /// The to monitor for cancellation requests. The default is . + /// Asynchronous enumeration of messages. + public static async IAsyncEnumerable InvokeAsync( + OpenAIAssistantAgent agent, + AssistantsClient client, + string threadId, + OpenAIAssistantConfiguration.PollingConfiguration pollingConfiguration, + ILogger logger, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + if (agent.IsDeleted) + { + throw new KernelException($"Agent Failure - {nameof(OpenAIAssistantAgent)} agent is deleted: {agent.Id}."); + } + + ToolDefinition[]? tools = [.. agent.Tools, .. agent.Kernel.Plugins.SelectMany(p => p.Select(f => f.ToToolDefinition(p.Name, FunctionDelimiter)))]; + + logger.LogDebug("[{MethodName}] Creating run for agent/thrad: {AgentId}/{ThreadId}", nameof(InvokeAsync), agent.Id, threadId); + + CreateRunOptions options = + new(agent.Id) + { + OverrideInstructions = agent.Instructions, + OverrideTools = tools, + }; + + // Create run + ThreadRun run = await client.CreateRunAsync(threadId, options, cancellationToken).ConfigureAwait(false); + + logger.LogInformation("[{MethodName}] Created run: {RunId}", nameof(InvokeAsync), run.Id); + + // Evaluate status and process steps and messages, as encountered. + HashSet processedStepIds = []; + Dictionary functionSteps = []; + + do + { + // Poll run and steps until actionable + PageableList steps = await PollRunStatusAsync().ConfigureAwait(false); + + // Is in terminal state? + if (s_terminalStatuses.Contains(run.Status)) + { + throw new KernelException($"Agent Failure - Run terminated: {run.Status} [{run.Id}]: {run.LastError?.Message ?? "Unknown"}"); + } + + // Is tool action required? + if (run.Status == RunStatus.RequiresAction) + { + logger.LogDebug("[{MethodName}] Processing run steps: {RunId}", nameof(InvokeAsync), run.Id); + + // Execute functions in parallel and post results at once. + FunctionCallContent[] activeFunctionSteps = steps.Data.SelectMany(step => ParseFunctionStep(agent, step)).ToArray(); + if (activeFunctionSteps.Length > 0) + { + // Emit function-call content + yield return GenerateFunctionCallContent(agent.GetName(), activeFunctionSteps); + + // Invoke functions for each tool-step + IEnumerable> functionResultTasks = ExecuteFunctionSteps(agent, activeFunctionSteps, cancellationToken); + + // Block for function results + FunctionResultContent[] functionResults = await Task.WhenAll(functionResultTasks).ConfigureAwait(false); + + // Process tool output + ToolOutput[] toolOutputs = GenerateToolOutputs(functionResults); + + await client.SubmitToolOutputsToRunAsync(run, toolOutputs, cancellationToken).ConfigureAwait(false); + } + + if (logger.IsEnabled(LogLevel.Information)) // Avoid boxing if not enabled + { + logger.LogInformation("[{MethodName}] Processed #{MessageCount} run steps: {RunId}", nameof(InvokeAsync), activeFunctionSteps.Length, run.Id); + } + } + + // Enumerate completed messages + logger.LogDebug("[{MethodName}] Processing run messages: {RunId}", nameof(InvokeAsync), run.Id); + + IEnumerable completedStepsToProcess = + steps + .Where(s => s.CompletedAt.HasValue && !processedStepIds.Contains(s.Id)) + .OrderBy(s => s.CreatedAt); + + int messageCount = 0; + foreach (RunStep completedStep in completedStepsToProcess) + { + if (completedStep.Type.Equals(RunStepType.ToolCalls)) + { + RunStepToolCallDetails toolCallDetails = (RunStepToolCallDetails)completedStep.StepDetails; + + foreach (RunStepToolCall toolCall in toolCallDetails.ToolCalls) + { + ChatMessageContent? content = null; + + // Process code-interpreter content + if (toolCall is RunStepCodeInterpreterToolCall toolCodeInterpreter) + { + content = GenerateCodeInterpreterContent(agent.GetName(), toolCodeInterpreter); + } + // Process function result content + else if (toolCall is RunStepFunctionToolCall toolFunction) + { + FunctionCallContent functionStep = functionSteps[toolFunction.Id]; // Function step always captured on invocation + content = GenerateFunctionResultContent(agent.GetName(), functionStep, toolFunction.Output); + } + + if (content is not null) + { + ++messageCount; + + yield return content; + } + } + } + else if (completedStep.Type.Equals(RunStepType.MessageCreation)) + { + RunStepMessageCreationDetails messageCreationDetails = (RunStepMessageCreationDetails)completedStep.StepDetails; + + // Retrieve the message + ThreadMessage? message = await RetrieveMessageAsync(messageCreationDetails, cancellationToken).ConfigureAwait(false); + + if (message is not null) + { + AuthorRole role = new(message.Role.ToString()); + + foreach (MessageContent itemContent in message.ContentItems) + { + ChatMessageContent? content = null; + + // Process text content + if (itemContent is MessageTextContent contentMessage) + { + content = GenerateTextMessageContent(agent.GetName(), role, contentMessage); + } + // Process image content + else if (itemContent is MessageImageFileContent contentImage) + { + content = GenerateImageFileContent(agent.GetName(), role, contentImage); + } + + if (content is not null) + { + ++messageCount; + + yield return content; + } + } + } + } + + processedStepIds.Add(completedStep.Id); + } + + if (logger.IsEnabled(LogLevel.Information)) // Avoid boxing if not enabled + { + logger.LogInformation("[{MethodName}] Processed #{MessageCount} run messages: {RunId}", nameof(InvokeAsync), messageCount, run.Id); + } + } + while (RunStatus.Completed != run.Status); + + logger.LogInformation("[{MethodName}] Completed run: {RunId}", nameof(InvokeAsync), run.Id); + + // Local function to assist in run polling (participates in method closure). + async Task> PollRunStatusAsync() + { + logger.LogInformation("[{MethodName}] Polling run status: {RunId}", nameof(PollRunStatusAsync), run.Id); + + int count = 0; + + do + { + // Reduce polling frequency after a couple attempts + await Task.Delay(count >= 2 ? pollingConfiguration.RunPollingInterval : pollingConfiguration.RunPollingBackoff, cancellationToken).ConfigureAwait(false); + ++count; + +#pragma warning disable CA1031 // Do not catch general exception types + try + { + run = await client.GetRunAsync(threadId, run.Id, cancellationToken).ConfigureAwait(false); + } + catch + { + // Retry anyway.. + } +#pragma warning restore CA1031 // Do not catch general exception types + } + while (s_pollingStatuses.Contains(run.Status)); + + logger.LogInformation("[{MethodName}] Run status is {RunStatus}: {RunId}", nameof(PollRunStatusAsync), run.Status, run.Id); + + return await client.GetRunStepsAsync(run, cancellationToken: cancellationToken).ConfigureAwait(false); + } + + // Local function to capture kernel function state for further processing (participates in method closure). + IEnumerable ParseFunctionStep(OpenAIAssistantAgent agent, RunStep step) + { + if (step.Status == RunStepStatus.InProgress && step.StepDetails is RunStepToolCallDetails callDetails) + { + foreach (RunStepFunctionToolCall toolCall in callDetails.ToolCalls.OfType()) + { + var nameParts = FunctionName.Parse(toolCall.Name, FunctionDelimiter); + + KernelArguments functionArguments = []; + if (!string.IsNullOrWhiteSpace(toolCall.Arguments)) + { + Dictionary arguments = JsonSerializer.Deserialize>(toolCall.Arguments)!; + foreach (var argumentKvp in arguments) + { + functionArguments[argumentKvp.Key] = argumentKvp.Value.ToString(); + } + } + + var content = new FunctionCallContent(nameParts.Name, nameParts.PluginName, toolCall.Id, functionArguments); + + functionSteps.Add(toolCall.Id, content); + + yield return content; + } + } + } + + async Task RetrieveMessageAsync(RunStepMessageCreationDetails detail, CancellationToken cancellationToken) + { + ThreadMessage? message = null; + + bool retry = false; + int count = 0; + do + { + try + { + message = await client.GetMessageAsync(threadId, detail.MessageCreation.MessageId, cancellationToken).ConfigureAwait(false); + } + catch (RequestFailedException exception) + { + // Step has provided the message-id. Retry on of NotFound/404 exists. + // Extremely rarely there might be a synchronization issue between the + // assistant response and message-service. + retry = exception.Status == (int)HttpStatusCode.NotFound && count < 3; + } + + if (retry) + { + await Task.Delay(pollingConfiguration.MessageSynchronizationDelay, cancellationToken).ConfigureAwait(false); + } + + ++count; + } + while (retry); + + return message; + } + } + + private static AnnotationContent GenerateAnnotationContent(MessageTextAnnotation annotation) + { + string? fileId = null; + if (annotation is MessageTextFileCitationAnnotation citationAnnotation) + { + fileId = citationAnnotation.FileId; + } + else if (annotation is MessageTextFilePathAnnotation pathAnnotation) + { + fileId = pathAnnotation.FileId; + } + + return + new() + { + Quote = annotation.Text, + StartIndex = annotation.StartIndex, + EndIndex = annotation.EndIndex, + FileId = fileId, + }; + } + + private static ChatMessageContent GenerateImageFileContent(string agentName, AuthorRole role, MessageImageFileContent contentImage) + { + return + new ChatMessageContent( + role, + [ + new FileReferenceContent(contentImage.FileId) + ]) + { + AuthorName = agentName, + }; + } + + private static ChatMessageContent? GenerateTextMessageContent(string agentName, AuthorRole role, MessageTextContent contentMessage) + { + ChatMessageContent? messageContent = null; + + string textContent = contentMessage.Text.Trim(); + + if (!string.IsNullOrWhiteSpace(textContent)) + { + messageContent = + new(role, textContent) + { + AuthorName = agentName + }; + + foreach (MessageTextAnnotation annotation in contentMessage.Annotations) + { + messageContent.Items.Add(GenerateAnnotationContent(annotation)); + } + } + + return messageContent; + } + + private static ChatMessageContent GenerateCodeInterpreterContent(string agentName, RunStepCodeInterpreterToolCall contentCodeInterpreter) + { + return + new ChatMessageContent( + AuthorRole.Tool, + [ + new TextContent(contentCodeInterpreter.Input) + ]) + { + AuthorName = agentName, + }; + } + + private static ChatMessageContent GenerateFunctionCallContent(string agentName, FunctionCallContent[] functionSteps) + { + ChatMessageContent functionCallContent = new(AuthorRole.Tool, content: null) + { + AuthorName = agentName + }; + + functionCallContent.Items.AddRange(functionSteps); + + return functionCallContent; + } + + private static ChatMessageContent GenerateFunctionResultContent(string agentName, FunctionCallContent functionStep, string result) + { + ChatMessageContent functionCallContent = new(AuthorRole.Tool, content: null) + { + AuthorName = agentName + }; + + functionCallContent.Items.Add( + new FunctionResultContent( + functionStep.FunctionName, + functionStep.PluginName, + functionStep.Id, + result)); + + return functionCallContent; + } + + private static Task[] ExecuteFunctionSteps(OpenAIAssistantAgent agent, FunctionCallContent[] functionSteps, CancellationToken cancellationToken) + { + Task[] functionTasks = new Task[functionSteps.Length]; + + for (int index = 0; index < functionSteps.Length; ++index) + { + functionTasks[index] = functionSteps[index].InvokeAsync(agent.Kernel, cancellationToken); + } + + return functionTasks; + } + + private static ToolOutput[] GenerateToolOutputs(FunctionResultContent[] functionResults) + { + ToolOutput[] toolOutputs = new ToolOutput[functionResults.Length]; + + for (int index = 0; index < functionResults.Length; ++index) + { + FunctionResultContent functionResult = functionResults[index]; + + object resultValue = functionResult.Result ?? string.Empty; + + if (resultValue is not string textResult) + { + textResult = JsonSerializer.Serialize(resultValue); + } + + toolOutputs[index] = new ToolOutput(functionResult.CallId, textResult!); + } + + return toolOutputs; + } +} diff --git a/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs b/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs index f101aa9ffb83..b46cdb013c18 100644 --- a/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs +++ b/dotnet/src/Agents/OpenAI/OpenAIAssistantAgent.cs @@ -162,15 +162,91 @@ public static async Task RetrieveAsync( }; } - /// - public async Task DeleteAsync(CancellationToken cancellationToken = default) + /// + /// Create a new assistant thread. + /// + /// The to monitor for cancellation requests. The default is . + /// The thread identifier + public async Task CreateThreadAsync(CancellationToken cancellationToken = default) { - if (this.IsDeleted) + AssistantThread thread = await this._client.CreateThreadAsync(cancellationToken).ConfigureAwait(false); + + return thread.Id; + } + + /// + /// Create a new assistant thread. + /// + /// The thread identifier + /// The to monitor for cancellation requests. The default is . + /// The thread identifier + public async Task DeleteThreadAsync( + string threadId, + CancellationToken cancellationToken = default) + { + // Validate input + Verify.NotNullOrWhiteSpace(threadId, nameof(threadId)); + + return await this._client.DeleteThreadAsync(threadId, cancellationToken).ConfigureAwait(false); + } + + /// + /// Adds a message to the specified thread. + /// + /// The thread identifier + /// A non-system message with which to append to the conversation. + /// The to monitor for cancellation requests. The default is . + public Task AddChatMessageAsync(string threadId, ChatMessageContent message, CancellationToken cancellationToken = default) + { + this.ThrowIfDeleted(); + + return AssistantThreadActions.CreateMessageAsync(this._client, threadId, message, cancellationToken); + } + + /// + /// Gets messages for a specified thread. + /// + /// The thread identifier + /// The to monitor for cancellation requests. The default is . + /// Asynchronous enumeration of messages. + public IAsyncEnumerable GetThreadMessagesAsync(string threadId, CancellationToken cancellationToken = default) + { + this.ThrowIfDeleted(); + + return AssistantThreadActions.GetMessagesAsync(this._client, threadId, cancellationToken); + } + + /// + /// Delete the assistant definition. + /// + /// + /// True if assistant definition has been deleted + /// + /// Assistant based agent will not be useable after deletion. + /// + public async Task DeleteAsync(CancellationToken cancellationToken = default) + { + if (!this.IsDeleted) { - return; + this.IsDeleted = (await this._client.DeleteAssistantAsync(this.Id, cancellationToken).ConfigureAwait(false)).Value; } - this.IsDeleted = (await this._client.DeleteAssistantAsync(this.Id, cancellationToken).ConfigureAwait(false)).Value; + return this.IsDeleted; + } + + /// + /// Invoke the assistant on the specified thread. + /// + /// The thread identifier + /// The to monitor for cancellation requests. The default is . + /// Asynchronous enumeration of messages. + public IAsyncEnumerable InvokeAsync( + string threadId, + CancellationToken cancellationToken = default) + { + this.ThrowIfDeleted(); + + return AssistantThreadActions.InvokeAsync(this, this._client, threadId, this._config.Polling, this.Logger, cancellationToken); } /// @@ -212,7 +288,19 @@ protected override async Task CreateChannelAsync(CancellationToken this.Logger.LogInformation("[{MethodName}] Created assistant thread: {ThreadId}", nameof(CreateChannelAsync), thread.Id); - return new OpenAIAssistantChannel(this._client, thread.Id, this._config.Polling); + return + new OpenAIAssistantChannel(this._client, thread.Id, this._config.Polling) + { + Logger = this.LoggerFactory.CreateLogger() + }; + } + + internal void ThrowIfDeleted() + { + if (this.IsDeleted) + { + throw new KernelException($"Agent Failure - {nameof(OpenAIAssistantAgent)} agent is deleted: {this.Id}."); + } } /// diff --git a/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs b/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs index 166884cf7a11..b84ef800ebd4 100644 --- a/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs +++ b/dotnet/src/Agents/OpenAI/OpenAIAssistantChannel.cs @@ -1,15 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; -using System.Linq; -using System.Net; -using System.Runtime.CompilerServices; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; -using Azure; using Azure.AI.OpenAI.Assistants; -using Microsoft.Extensions.Logging; -using Microsoft.SemanticKernel.ChatCompletion; namespace Microsoft.SemanticKernel.Agents.OpenAI; @@ -19,485 +12,31 @@ namespace Microsoft.SemanticKernel.Agents.OpenAI; internal sealed class OpenAIAssistantChannel(AssistantsClient client, string threadId, OpenAIAssistantConfiguration.PollingConfiguration pollingConfiguration) : AgentChannel { - private const string FunctionDelimiter = "-"; - - private static readonly HashSet s_pollingStatuses = - [ - RunStatus.Queued, - RunStatus.InProgress, - RunStatus.Cancelling, - ]; - - private static readonly HashSet s_terminalStatuses = - [ - RunStatus.Expired, - RunStatus.Failed, - RunStatus.Cancelled, - ]; - private readonly AssistantsClient _client = client; private readonly string _threadId = threadId; - private readonly Dictionary _agentTools = []; - private readonly Dictionary _agentNames = []; // Cache agent names by their identifier for GetHistoryAsync() /// protected override async Task ReceiveAsync(IReadOnlyList history, CancellationToken cancellationToken) { foreach (ChatMessageContent message in history) { - if (string.IsNullOrWhiteSpace(message.Content)) - { - continue; - } - - await this._client.CreateMessageAsync( - this._threadId, - message.Role.ToMessageRole(), - message.Content, - cancellationToken: cancellationToken).ConfigureAwait(false); + await AssistantThreadActions.CreateMessageAsync(this._client, this._threadId, message, cancellationToken).ConfigureAwait(false); } } /// - protected override async IAsyncEnumerable InvokeAsync( + protected override IAsyncEnumerable InvokeAsync( OpenAIAssistantAgent agent, - [EnumeratorCancellation] CancellationToken cancellationToken) + CancellationToken cancellationToken) { - if (agent.IsDeleted) - { - throw new KernelException($"Agent Failure - {nameof(OpenAIAssistantAgent)} agent is deleted: {agent.Id}."); - } - - if (!this._agentTools.TryGetValue(agent.Id, out ToolDefinition[]? tools)) - { - tools = [.. agent.Tools, .. agent.Kernel.Plugins.SelectMany(p => p.Select(f => f.ToToolDefinition(p.Name, FunctionDelimiter)))]; - this._agentTools.Add(agent.Id, tools); - } - - if (!this._agentNames.ContainsKey(agent.Id) && !string.IsNullOrWhiteSpace(agent.Name)) - { - this._agentNames.Add(agent.Id, agent.Name); - } - - this.Logger.LogDebug("[{MethodName}] Creating run for agent/thrad: {AgentId}/{ThreadId}", nameof(InvokeAsync), agent.Id, this._threadId); - - CreateRunOptions options = - new(agent.Id) - { - OverrideInstructions = agent.Instructions, - OverrideTools = tools, - }; - - // Create run - ThreadRun run = await this._client.CreateRunAsync(this._threadId, options, cancellationToken).ConfigureAwait(false); - - this.Logger.LogInformation("[{MethodName}] Created run: {RunId}", nameof(InvokeAsync), run.Id); - - // Evaluate status and process steps and messages, as encountered. - HashSet processedStepIds = []; - Dictionary functionSteps = []; - - do - { - // Poll run and steps until actionable - PageableList steps = await PollRunStatusAsync().ConfigureAwait(false); - - // Is in terminal state? - if (s_terminalStatuses.Contains(run.Status)) - { - throw new KernelException($"Agent Failure - Run terminated: {run.Status} [{run.Id}]: {run.LastError?.Message ?? "Unknown"}"); - } - - // Is tool action required? - if (run.Status == RunStatus.RequiresAction) - { - this.Logger.LogDebug("[{MethodName}] Processing run steps: {RunId}", nameof(InvokeAsync), run.Id); - - // Execute functions in parallel and post results at once. - FunctionCallContent[] activeFunctionSteps = steps.Data.SelectMany(step => ParseFunctionStep(agent, step)).ToArray(); - if (activeFunctionSteps.Length > 0) - { - // Emit function-call content - yield return GenerateFunctionCallContent(agent.GetName(), activeFunctionSteps); - - // Invoke functions for each tool-step - IEnumerable> functionResultTasks = ExecuteFunctionSteps(agent, activeFunctionSteps, cancellationToken); - - // Block for function results - FunctionResultContent[] functionResults = await Task.WhenAll(functionResultTasks).ConfigureAwait(false); - - // Process tool output - ToolOutput[] toolOutputs = GenerateToolOutputs(functionResults); - - await this._client.SubmitToolOutputsToRunAsync(run, toolOutputs, cancellationToken).ConfigureAwait(false); - } - - if (this.Logger.IsEnabled(LogLevel.Information)) // Avoid boxing if not enabled - { - this.Logger.LogInformation("[{MethodName}] Processed #{MessageCount} run steps: {RunId}", nameof(InvokeAsync), activeFunctionSteps.Length, run.Id); - } - } - - // Enumerate completed messages - this.Logger.LogDebug("[{MethodName}] Processing run messages: {RunId}", nameof(InvokeAsync), run.Id); - - IEnumerable completedStepsToProcess = - steps - .Where(s => s.CompletedAt.HasValue && !processedStepIds.Contains(s.Id)) - .OrderBy(s => s.CreatedAt); - - int messageCount = 0; - foreach (RunStep completedStep in completedStepsToProcess) - { - if (completedStep.Type.Equals(RunStepType.ToolCalls)) - { - RunStepToolCallDetails toolCallDetails = (RunStepToolCallDetails)completedStep.StepDetails; - - foreach (RunStepToolCall toolCall in toolCallDetails.ToolCalls) - { - ChatMessageContent? content = null; - - // Process code-interpreter content - if (toolCall is RunStepCodeInterpreterToolCall toolCodeInterpreter) - { - content = GenerateCodeInterpreterContent(agent.GetName(), toolCodeInterpreter); - } - // Process function result content - else if (toolCall is RunStepFunctionToolCall toolFunction) - { - FunctionCallContent functionStep = functionSteps[toolFunction.Id]; // Function step always captured on invocation - content = GenerateFunctionResultContent(agent.GetName(), functionStep, toolFunction.Output); - } + agent.ThrowIfDeleted(); - if (content is not null) - { - ++messageCount; - - yield return content; - } - } - } - else if (completedStep.Type.Equals(RunStepType.MessageCreation)) - { - RunStepMessageCreationDetails messageCreationDetails = (RunStepMessageCreationDetails)completedStep.StepDetails; - - // Retrieve the message - ThreadMessage? message = await this.RetrieveMessageAsync(messageCreationDetails, cancellationToken).ConfigureAwait(false); - - if (message is not null) - { - AuthorRole role = new(message.Role.ToString()); - - foreach (MessageContent itemContent in message.ContentItems) - { - ChatMessageContent? content = null; - - // Process text content - if (itemContent is MessageTextContent contentMessage) - { - content = GenerateTextMessageContent(agent.GetName(), role, contentMessage); - } - // Process image content - else if (itemContent is MessageImageFileContent contentImage) - { - content = GenerateImageFileContent(agent.GetName(), role, contentImage); - } - - if (content is not null) - { - ++messageCount; - - yield return content; - } - } - } - } - - processedStepIds.Add(completedStep.Id); - } - - if (this.Logger.IsEnabled(LogLevel.Information)) // Avoid boxing if not enabled - { - this.Logger.LogInformation("[{MethodName}] Processed #{MessageCount} run messages: {RunId}", nameof(InvokeAsync), messageCount, run.Id); - } - } - while (RunStatus.Completed != run.Status); - - this.Logger.LogInformation("[{MethodName}] Completed run: {RunId}", nameof(InvokeAsync), run.Id); - - // Local function to assist in run polling (participates in method closure). - async Task> PollRunStatusAsync() - { - this.Logger.LogInformation("[{MethodName}] Polling run status: {RunId}", nameof(PollRunStatusAsync), run.Id); - - int count = 0; - - do - { - // Reduce polling frequency after a couple attempts - await Task.Delay(count >= 2 ? pollingConfiguration.RunPollingInterval : pollingConfiguration.RunPollingBackoff, cancellationToken).ConfigureAwait(false); - ++count; - -#pragma warning disable CA1031 // Do not catch general exception types - try - { - run = await this._client.GetRunAsync(this._threadId, run.Id, cancellationToken).ConfigureAwait(false); - } - catch - { - // Retry anyway.. - } -#pragma warning restore CA1031 // Do not catch general exception types - } - while (s_pollingStatuses.Contains(run.Status)); - - this.Logger.LogInformation("[{MethodName}] Run status is {RunStatus}: {RunId}", nameof(PollRunStatusAsync), run.Status, run.Id); - - return await this._client.GetRunStepsAsync(run, cancellationToken: cancellationToken).ConfigureAwait(false); - } - - // Local function to capture kernel function state for further processing (participates in method closure). - IEnumerable ParseFunctionStep(OpenAIAssistantAgent agent, RunStep step) - { - if (step.Status == RunStepStatus.InProgress && step.StepDetails is RunStepToolCallDetails callDetails) - { - foreach (RunStepFunctionToolCall toolCall in callDetails.ToolCalls.OfType()) - { - var nameParts = FunctionName.Parse(toolCall.Name, FunctionDelimiter); - - KernelArguments functionArguments = []; - if (!string.IsNullOrWhiteSpace(toolCall.Arguments)) - { - Dictionary arguments = JsonSerializer.Deserialize>(toolCall.Arguments)!; - foreach (var argumentKvp in arguments) - { - functionArguments[argumentKvp.Key] = argumentKvp.Value.ToString(); - } - } - - var content = new FunctionCallContent(nameParts.Name, nameParts.PluginName, toolCall.Id, functionArguments); - - functionSteps.Add(toolCall.Id, content); - - yield return content; - } - } - } + return AssistantThreadActions.InvokeAsync(agent, this._client, this._threadId, pollingConfiguration, this.Logger, cancellationToken); } /// - protected override async IAsyncEnumerable GetHistoryAsync([EnumeratorCancellation] CancellationToken cancellationToken) - { - PageableList messages; - - string? lastId = null; - do - { - messages = await this._client.GetMessagesAsync(this._threadId, limit: 100, ListSortOrder.Descending, after: lastId, null, cancellationToken).ConfigureAwait(false); - foreach (ThreadMessage message in messages) - { - AuthorRole role = new(message.Role.ToString()); - - string? assistantName = null; - if (!string.IsNullOrWhiteSpace(message.AssistantId) && - !this._agentNames.TryGetValue(message.AssistantId, out assistantName)) - { - Assistant assistant = await this._client.GetAssistantAsync(message.AssistantId, cancellationToken).ConfigureAwait(false); - if (!string.IsNullOrWhiteSpace(assistant.Name)) - { - this._agentNames.Add(assistant.Id, assistant.Name); - } - } - - assistantName ??= message.AssistantId; - - foreach (MessageContent item in message.ContentItems) - { - ChatMessageContent? content = null; - - if (item is MessageTextContent contentMessage) - { - content = GenerateTextMessageContent(assistantName, role, contentMessage); - } - else if (item is MessageImageFileContent contentImage) - { - content = GenerateImageFileContent(assistantName, role, contentImage); - } - - if (content is not null) - { - yield return content; - } - } - - lastId = message.Id; - } - } - while (messages.HasMore); - } - - private static AnnotationContent GenerateAnnotationContent(MessageTextAnnotation annotation) - { - string? fileId = null; - if (annotation is MessageTextFileCitationAnnotation citationAnnotation) - { - fileId = citationAnnotation.FileId; - } - else if (annotation is MessageTextFilePathAnnotation pathAnnotation) - { - fileId = pathAnnotation.FileId; - } - - return - new() - { - Quote = annotation.Text, - StartIndex = annotation.StartIndex, - EndIndex = annotation.EndIndex, - FileId = fileId, - }; - } - - private static ChatMessageContent GenerateImageFileContent(string agentName, AuthorRole role, MessageImageFileContent contentImage) - { - return - new ChatMessageContent( - role, - [ - new FileReferenceContent(contentImage.FileId) - ]) - { - AuthorName = agentName, - }; - } - - private static ChatMessageContent? GenerateTextMessageContent(string agentName, AuthorRole role, MessageTextContent contentMessage) - { - ChatMessageContent? messageContent = null; - - string textContent = contentMessage.Text.Trim(); - - if (!string.IsNullOrWhiteSpace(textContent)) - { - messageContent = - new(role, textContent) - { - AuthorName = agentName - }; - - foreach (MessageTextAnnotation annotation in contentMessage.Annotations) - { - messageContent.Items.Add(GenerateAnnotationContent(annotation)); - } - } - - return messageContent; - } - - private static ChatMessageContent GenerateCodeInterpreterContent(string agentName, RunStepCodeInterpreterToolCall contentCodeInterpreter) - { - return - new ChatMessageContent( - AuthorRole.Tool, - [ - new TextContent(contentCodeInterpreter.Input) - ]) - { - AuthorName = agentName, - }; - } - - private static ChatMessageContent GenerateFunctionCallContent(string agentName, FunctionCallContent[] functionSteps) - { - ChatMessageContent functionCallContent = new(AuthorRole.Tool, content: null) - { - AuthorName = agentName - }; - - functionCallContent.Items.AddRange(functionSteps); - - return functionCallContent; - } - - private static ChatMessageContent GenerateFunctionResultContent(string agentName, FunctionCallContent functionStep, string result) - { - ChatMessageContent functionCallContent = new(AuthorRole.Tool, content: null) - { - AuthorName = agentName - }; - - functionCallContent.Items.Add( - new FunctionResultContent( - functionStep.FunctionName, - functionStep.PluginName, - functionStep.Id, - result)); - - return functionCallContent; - } - - private static Task[] ExecuteFunctionSteps(OpenAIAssistantAgent agent, FunctionCallContent[] functionSteps, CancellationToken cancellationToken) - { - Task[] functionTasks = new Task[functionSteps.Length]; - - for (int index = 0; index < functionSteps.Length; ++index) - { - functionTasks[index] = functionSteps[index].InvokeAsync(agent.Kernel, cancellationToken); - } - - return functionTasks; - } - - private static ToolOutput[] GenerateToolOutputs(FunctionResultContent[] functionResults) + protected override IAsyncEnumerable GetHistoryAsync(CancellationToken cancellationToken) { - ToolOutput[] toolOutputs = new ToolOutput[functionResults.Length]; - - for (int index = 0; index < functionResults.Length; ++index) - { - FunctionResultContent functionResult = functionResults[index]; - - object resultValue = functionResult.Result ?? string.Empty; - - if (resultValue is not string textResult) - { - textResult = JsonSerializer.Serialize(resultValue); - } - - toolOutputs[index] = new ToolOutput(functionResult.CallId, textResult!); - } - - return toolOutputs; - } - - private async Task RetrieveMessageAsync(RunStepMessageCreationDetails detail, CancellationToken cancellationToken) - { - ThreadMessage? message = null; - - bool retry = false; - int count = 0; - do - { - try - { - message = await this._client.GetMessageAsync(this._threadId, detail.MessageCreation.MessageId, cancellationToken).ConfigureAwait(false); - } - catch (RequestFailedException exception) - { - // Step has provided the message-id. Retry on of NotFound/404 exists. - // Extremely rarely there might be a synchronization issue between the - // assistant response and message-service. - retry = exception.Status == (int)HttpStatusCode.NotFound && count < 3; - } - - if (retry) - { - await Task.Delay(pollingConfiguration.MessageSynchronizationDelay, cancellationToken).ConfigureAwait(false); - } - - ++count; - } - while (retry); - - return message; + return AssistantThreadActions.GetMessagesAsync(this._client, this._threadId, cancellationToken); } }