Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 100 additions & 41 deletions dotnet/src/Microsoft.Agents/ChatCompletion/ChatClientAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
Expand All @@ -19,6 +20,7 @@ public sealed class ChatClientAgent : Agent
{
private readonly ChatClientAgentOptions? _agentOptions;
private readonly ILogger _logger;
private readonly Type _chatClientType;

/// <summary>
/// Initializes a new instance of the <see cref="ChatClientAgent"/> class.
Expand All @@ -30,27 +32,17 @@ public ChatClientAgent(IChatClient chatClient, ChatClientAgentOptions? options =
{
Throw.IfNull(chatClient);

this._chatClientType = chatClient.GetType();
this.ChatClient = chatClient.AsAgentInvokingChatClient();
this._agentOptions = options;
this._logger = (loggerFactory ?? chatClient.GetService<ILoggerFactory>() ?? NullLoggerFactory.Instance).CreateLogger<ChatClientAgent>();
}

/// <summary>
/// The chat client.
/// The underlying chat client used by the agent to invoke chat completions.
/// </summary>
public IChatClient ChatClient { get; }

/// <summary>
/// Gets the role used for agent instructions. Defaults to "system".
/// </summary>
/// <remarks>
/// Certain versions of "O*" series (deep reasoning) models require the instructions
/// to be provided as "developer" role. Other versions support neither role and
/// an agent targeting such a model cannot provide instructions. Agent functionality
/// will be dictated entirely by the provided plugins.
/// </remarks>
public ChatRole InstructionsRole { get; set; } = ChatRole.System;

/// <inheritdoc/>
public override string Id => this._agentOptions?.Id ?? base.Id;

Expand All @@ -72,35 +64,16 @@ public override async Task<ChatResponse> RunAsync(
{
Throw.IfNull(messages);

// Retrieve chat options from the provided AgentRunOptions if available.
ChatOptions? chatOptions = (options as ChatClientAgentRunOptions)?.ChatOptions;

var chatClientThread = this.ValidateOrCreateThreadType<ChatClientAgentThread>(thread, () => new());

// Add any existing messages from the thread to the messages to be sent to the chat client.
List<ChatMessage> threadMessages = [];
if (chatClientThread is IMessagesRetrievableThread messagesRetrievableThread)
{
await foreach (ChatMessage message in messagesRetrievableThread.GetMessagesAsync(cancellationToken).ConfigureAwait(false))
{
threadMessages.Add(message);
}
}

// Append to the existing thread messages the messages that were passed in to this call.
threadMessages.AddRange(messages);
(ChatClientAgentThread chatClientThread, ChatOptions? chatOptions, List<ChatMessage> threadMessages) =
await this.PrepareThreadAndMessagesAsync(thread, messages, options, cancellationToken).ConfigureAwait(false);

// Update the messages with agent instructions.
this.UpdateThreadMessagesWithAgentInstructions(threadMessages, options);

var agentName = this.Name ?? "UnnamedAgent";
Type serviceType = this.ChatClient.GetType();
var agentName = this.GetAgentName();

this._logger.LogAgentChatClientInvokingAgent(nameof(RunAsync), this.Id, agentName, serviceType);
this._logger.LogAgentChatClientInvokingAgent(nameof(RunAsync), this.Id, agentName, this._chatClientType);

ChatResponse chatResponse = await this.ChatClient.GetResponseAsync(threadMessages, chatOptions, cancellationToken).ConfigureAwait(false);

this._logger.LogAgentChatClientInvokedAgent(nameof(RunAsync), this.Id, agentName, serviceType, messages.Count);
this._logger.LogAgentChatClientInvokedAgent(nameof(RunAsync), this.Id, agentName, this._chatClientType, messages.Count);

// Only notify the thread of new messages if the chatResponse was successful to avoid inconsistent messages state in the thread.
await this.NotifyThreadOfNewMessagesAsync(chatClientThread, messages, cancellationToken).ConfigureAwait(false);
Expand All @@ -112,7 +85,7 @@ public override async Task<ChatResponse> RunAsync(
}

// Convert the chat response messages to a valid IReadOnlyCollection for notification signatures below.
var chatResponseMessages = chatResponse.Messages.ToArray();
var chatResponseMessages = chatResponse.Messages as IReadOnlyCollection<ChatMessage> ?? chatResponse.Messages.ToArray();

await this.NotifyThreadOfNewMessagesAsync(chatClientThread, chatResponseMessages, cancellationToken).ConfigureAwait(false);
if (options?.OnIntermediateMessages is not null)
Expand All @@ -124,28 +97,114 @@ public override async Task<ChatResponse> RunAsync(
}

/// <inheritdoc/>
public override IAsyncEnumerable<ChatResponseUpdate> RunStreamingAsync(IReadOnlyCollection<ChatMessage> messages, AgentThread? thread = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default)
public override async IAsyncEnumerable<ChatResponseUpdate> RunStreamingAsync(
IReadOnlyCollection<ChatMessage> messages,
AgentThread? thread = null,
AgentRunOptions? options = null,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
throw new System.NotImplementedException();
Throw.IfNull(messages);

(ChatClientAgentThread chatClientThread, ChatOptions? chatOptions, List<ChatMessage> threadMessages) =
await this.PrepareThreadAndMessagesAsync(thread, messages, options, cancellationToken).ConfigureAwait(false);

int messageCount = threadMessages.Count;
var agentName = this.GetAgentName();

this._logger.LogAgentChatClientInvokingAgent(nameof(RunStreamingAsync), this.Id, agentName, this._chatClientType);

// Using the enumerator to ensure we consider the case where no updates are returned for notification.
var responseUpdatesEnumerator = this.ChatClient.GetStreamingResponseAsync(threadMessages, chatOptions, cancellationToken).GetAsyncEnumerator(cancellationToken);

this._logger.LogAgentChatClientInvokedStreamingAgent(nameof(RunStreamingAsync), this.Id, agentName, this._chatClientType);

List<ChatResponseUpdate> responseUpdates = [];

// Ensure we start the streaming request
var hasUpdates = await responseUpdatesEnumerator.MoveNextAsync().ConfigureAwait(false);

// To avoid inconsistent state we only notify the thread of the input messages if no error occurs after the initial request.
await this.NotifyThreadOfNewMessagesAsync(chatClientThread, messages, cancellationToken).ConfigureAwait(false);

while (hasUpdates)
{
var update = responseUpdatesEnumerator.Current;
if (update is not null)
{
responseUpdates.Add(update);
update.AuthorName ??= agentName;
yield return update;
}

hasUpdates = await responseUpdatesEnumerator.MoveNextAsync().ConfigureAwait(false);
}

var chatResponse = responseUpdates.ToChatResponse();
var chatResponseMessages = chatResponse.Messages as IReadOnlyCollection<ChatMessage> ?? chatResponse.Messages.ToArray();

await this.NotifyThreadOfNewMessagesAsync(chatClientThread, chatResponseMessages, cancellationToken).ConfigureAwait(false);
if (options?.OnIntermediateMessages is not null)
{
await options.OnIntermediateMessages(chatResponseMessages).ConfigureAwait(false);
}
}

/// <inheritdoc/>
public override AgentThread GetNewThread() => new ChatClientAgentThread();

#region Private

/// <summary>
/// Prepares the thread, chat options, and messages for agent execution.
/// </summary>
/// <param name="thread">The conversation thread to use or create.</param>
/// <param name="inputMessages">The input messages to use.</param>
/// <param name="options">Optional parameters for agent invocation.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A tuple containing the thread, chat options, and thread messages.</returns>
private async Task<(ChatClientAgentThread thread, ChatOptions? chatOptions, List<ChatMessage> threadMessages)> PrepareThreadAndMessagesAsync(
AgentThread? thread,
IReadOnlyCollection<ChatMessage> inputMessages,
AgentRunOptions? options,
CancellationToken cancellationToken)
{
// Retrieve chat options from the provided AgentRunOptions if available.
ChatOptions? chatOptions = (options as ChatClientAgentRunOptions)?.ChatOptions;

var chatClientThread = this.ValidateOrCreateThreadType<ChatClientAgentThread>(thread, () => new());

// Add any existing messages from the thread to the messages to be sent to the chat client.
List<ChatMessage> threadMessages = [];
if (chatClientThread is IMessagesRetrievableThread messagesRetrievableThread)
{
await foreach (ChatMessage message in messagesRetrievableThread.GetMessagesAsync(cancellationToken).ConfigureAwait(false))
{
threadMessages.Add(message);
}
}

// Update the messages with agent instructions.
this.UpdateThreadMessagesWithAgentInstructions(threadMessages, options);

// Add the input messages to the end of thread messages.
threadMessages.AddRange(inputMessages);

return (chatClientThread, chatOptions, threadMessages);
}

private void UpdateThreadMessagesWithAgentInstructions(List<ChatMessage> threadMessages, AgentRunOptions? options)
{
if (!string.IsNullOrWhiteSpace(options?.AdditionalInstructions))
{
threadMessages.Insert(0, new(this.InstructionsRole, options?.AdditionalInstructions) { AuthorName = this.Name });
threadMessages.Insert(0, new(ChatRole.System, options?.AdditionalInstructions) { AuthorName = this.Name });
}

if (!string.IsNullOrWhiteSpace(this.Instructions))
{
threadMessages.Insert(0, new(this.InstructionsRole, this.Instructions) { AuthorName = this.Name });
threadMessages.Insert(0, new(ChatRole.System, this.Instructions) { AuthorName = this.Name });
}
}

private string GetAgentName() => this.Name ?? "UnnamedAgent";
#endregion
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,49 @@ namespace Microsoft.Agents;
public static class ChatClientAgentExtensions
{
/// <summary>
/// Allow running a chat client agent with a <see cref="ChatOptions"/> configuration.
/// Run the agent with the provided message and arguments.
/// </summary>
/// <param name="agent">Target agent to run.</param>
/// <param name="messages">Messages to send to the agent.</param>
/// <param name="thread">Optional thread to use for the agent.</param>
/// <param name="agentOptions">Optional agent run options.</param>
/// <param name="messages">The messages to pass to the agent.</param>
/// <param name="thread">The conversation thread to continue with this invocation. If not provided, creates a new thread. The thread will be mutated with the provided messages and agent reponse.</param>
/// <param name="agentRunOptions">Optional parameters for agent invocation.</param>
/// <param name="chatOptions">Optional chat options.</param>
/// <param name="cancellationToken">Optional cancellation token.</param>
/// <returns>A task representing the asynchronous operation, with the chat response.</returns>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>A <see cref="ChatResponse"/> containing the list of <see cref="ChatMessage"/> items.</returns>
public static Task<ChatResponse> RunAsync(
this ChatClientAgent agent,
IReadOnlyCollection<ChatMessage> messages,
AgentThread? thread = null,
AgentRunOptions? agentOptions = null,
AgentRunOptions? agentRunOptions = null,
ChatOptions? chatOptions = null,
CancellationToken cancellationToken = default)
{
Throw.IfNull(agent);
Throw.IfNull(messages);

return agent.RunAsync(messages, thread, new ChatClientAgentRunOptions(agentOptions, chatOptions), cancellationToken);
return agent.RunAsync(messages, thread, new ChatClientAgentRunOptions(agentRunOptions, chatOptions), cancellationToken);
}

/// <summary>
/// Run the agent with the provided message and arguments.
/// </summary>
/// <param name="agent">Target agent to run.</param>
/// <param name="messages">The messages to pass to the agent.</param>
/// <param name="thread">The conversation thread to continue with this invocation. If not provided, creates a new thread. The thread will be mutated with the provided messages and agent reponse.</param>
/// <param name="agentRunOptions">Optional parameters for agent invocation.</param>
/// <param name="chatOptions">Optional chat options.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
public static IAsyncEnumerable<ChatResponseUpdate> RunStreamingAsync(
this ChatClientAgent agent,
IReadOnlyCollection<ChatMessage> messages,
AgentThread? thread = null,
AgentRunOptions? agentRunOptions = null,
ChatOptions? chatOptions = null,
CancellationToken cancellationToken = default)
{
Throw.IfNull(agent);
Throw.IfNull(messages);

return agent.RunStreamingAsync(messages, thread, new ChatClientAgentRunOptions(agentRunOptions, chatOptions), cancellationToken);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,27 +23,27 @@ internal static partial class ChatClientAgentLogMessages
[LoggerMessage(
EventId = 0,
Level = LogLevel.Debug,
Message = "[{MethodName}] Agent {AgentId}/{AgentName} Invoking service {ServiceType}.")]
Message = "[{MethodName}] Agent {AgentId}/{AgentName} Invoking client {ClientType}.")]
public static partial void LogAgentChatClientInvokingAgent(
this ILogger logger,
string methodName,
string agentId,
string agentName,
Type serviceType);
Type clientType);

/// <summary>
/// Logs <see cref="ChatClientAgent"/> invoked agent (complete).
/// </summary>
[LoggerMessage(
EventId = 0,
Level = LogLevel.Information,
Message = "[{MethodName}] Agent {AgentId}/{AgentName} Invoked service {ServiceType} with message count: {MessageCount}.")]
Message = "[{MethodName}] Agent {AgentId}/{AgentName} Invoked client {ClientType} with message count: {MessageCount}.")]
public static partial void LogAgentChatClientInvokedAgent(
this ILogger logger,
string methodName,
string agentId,
string agentName,
Type serviceType,
Type clientType,
int messageCount);

/// <summary>
Expand All @@ -52,11 +52,11 @@ public static partial void LogAgentChatClientInvokedAgent(
[LoggerMessage(
EventId = 0,
Level = LogLevel.Information,
Message = "[{MethodName}] Agent {AgentId}/{AgentName} Invoked service {ServiceType}.")]
Message = "[{MethodName}] Agent {AgentId}/{AgentName} Invoked client {ClientType}.")]
public static partial void LogAgentChatClientInvokedStreamingAgent(
this ILogger logger,
string methodName,
string agentId,
string agentName,
Type serviceType);
Type clientType);
}
Loading
Loading