From eccd6de90c1c4e291d512b589b086347d00325dc Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Thu, 5 Feb 2026 12:01:13 +0000 Subject: [PATCH 01/28] Add a StateBag to AgentSession and pass Agent and AgentSession to AIContextProvider and ChatHistoryProviders --- .../Program.cs | 8 +- .../AIContextProvider.cs | 39 +- .../AgentAbstractionsJsonUtilities.cs | 2 + .../AgentSession.cs | 5 + .../AgentSessionStateBag.cs | 180 ++++++++ .../AgentSessionStateBagValue.cs | 67 +++ .../ChatHistoryProvider.cs | 39 +- .../ChatClient/ChatClientAgent.cs | 44 +- .../ChatClient/ChatClientAgentSession.cs | 5 + .../IAgentFixture.cs | 2 +- .../RunStreamingTests.cs | 2 +- .../RunTests.cs | 2 +- .../AnthropicChatCompletionFixture.cs | 4 +- .../AIProjectClientFixture.cs | 4 +- .../AzureAIAgentsPersistentFixture.cs | 2 +- .../CopilotStudioFixture.cs | 2 +- .../AIContextProviderTests.cs | 122 +++++- .../AgentSessionStateBagTests.cs | 407 ++++++++++++++++++ .../AgentSessionTests.cs | 15 + .../ChatHistoryProviderExtensionsTests.cs | 9 +- .../ChatHistoryProviderMessageFilterTests.cs | 13 +- .../ChatHistoryProviderTests.cs | 122 +++++- .../InMemoryChatHistoryProviderTests.cs | 21 +- .../CosmosChatHistoryProviderTests.cs | 59 +-- .../DurableAgentSessionTests.cs | 4 +- .../Mem0ProviderTests.cs | 23 +- .../Mem0ProviderTests.cs | 17 +- .../ChatClient/ChatClientAgentSessionTests.cs | 74 ++++ .../Data/TextSearchProviderTests.cs | 89 ++-- .../Memory/ChatHistoryMemoryProviderTests.cs | 17 +- .../TestJsonSerializerContext.cs | 1 + .../OpenAIAssistantFixture.cs | 2 +- .../OpenAIChatCompletionFixture.cs | 4 +- .../OpenAIResponseFixture.cs | 4 +- 34 files changed, 1238 insertions(+), 172 deletions(-) create mode 100644 dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs create mode 100644 dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs create mode 100644 dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs index 5d4e77474a..82f76e7599 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs @@ -55,14 +55,14 @@ protected override async Task RunCoreAsync(IEnumerable responseMessages = CloneAndToUpperCase(messages, this.Name).ToList(); // Notify the session of the input and output messages. - var invokedContext = new ChatHistoryProvider.InvokedContext(messages, storeMessages) + var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, messages, storeMessages) { ResponseMessages = responseMessages }; @@ -87,14 +87,14 @@ protected override async IAsyncEnumerable RunCoreStreamingA } // Get existing messages from the store - var invokingContext = new ChatHistoryProvider.InvokingContext(messages); + var invokingContext = new ChatHistoryProvider.InvokingContext(this, session, messages); var storeMessages = await typedSession.ChatHistoryProvider.InvokingAsync(invokingContext, cancellationToken); // Clone the input messages and turn them into response messages with upper case text. List responseMessages = CloneAndToUpperCase(messages, this.Name).ToList(); // Notify the session of the input and output messages. - var invokedContext = new ChatHistoryProvider.InvokedContext(messages, storeMessages) + var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, messages, storeMessages) { ResponseMessages = responseMessages }; diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs index f104f12890..8428d46f9b 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs @@ -129,13 +129,30 @@ public sealed class InvokingContext /// /// Initializes a new instance of the class with the specified request messages. /// + /// The agent being invoked. + /// The session associated with the agent invocation. /// The messages to be used by the agent for this invocation. /// is . - public InvokingContext(IEnumerable requestMessages) + public InvokingContext( + AIAgent agent, + AgentSession? session, + IEnumerable requestMessages) { + this.Agent = Throw.IfNull(agent); + this.Session = session; this.RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages)); } + /// + /// Gets the agent that is being invoked. + /// + public AIAgent Agent { get; } + + /// + /// Gets the agent session associated with the agent invocation. + /// + public AgentSession? Session { get; } + /// /// Gets the caller provided messages that will be used by the agent for this invocation. /// @@ -158,15 +175,33 @@ public sealed class InvokedContext /// /// Initializes a new instance of the class with the specified request messages. /// + /// The agent being invoked. + /// The session associated with the agent invocation. /// The caller provided messages that were used by the agent for this invocation. /// The messages provided by the for this invocation, if any. /// is . - public InvokedContext(IEnumerable requestMessages, IEnumerable? aiContextProviderMessages) + public InvokedContext( + AIAgent agent, + AgentSession? session, + IEnumerable requestMessages, + IEnumerable? aiContextProviderMessages) { + this.Agent = Throw.IfNull(agent); + this.Session = session; this.RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages)); this.AIContextProviderMessages = aiContextProviderMessages; } + /// + /// Gets the agent that is being invoked. + /// + public AIAgent Agent { get; } + + /// + /// Gets the agent session associated with the agent invocation. + /// + public AgentSession? Session { get; } + /// /// Gets the caller provided messages that were used by the agent for this invocation. /// diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs index 17fbb9e4c6..bf0e835b4b 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Collections.Concurrent; using System.Diagnostics.CodeAnalysis; using System.Text.Encodings.Web; using System.Text.Json; @@ -83,6 +84,7 @@ private static JsonSerializerOptions CreateDefaultOptions() [JsonSerializable(typeof(ServiceIdAgentSession.ServiceIdAgentSessionState))] [JsonSerializable(typeof(InMemoryAgentSession.InMemoryAgentSessionState))] [JsonSerializable(typeof(InMemoryChatHistoryProvider.State))] + [JsonSerializable(typeof(ConcurrentDictionary))] [ExcludeFromCodeCoverage] private sealed partial class JsonContext : JsonSerializerContext; diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs index 3efce9be17..722660d49e 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs @@ -53,6 +53,11 @@ protected AgentSession() { } + /// + /// Gets any arbitrary state associated with this session. + /// + public AgentSessionStateBag StateBag { get; protected set; } = new(); + /// Asks the for an object of the specified type . /// The type of object being requested. /// An optional key that can be used to help identify the target service. diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs new file mode 100644 index 0000000000..47eef61508 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs @@ -0,0 +1,180 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Concurrent; +using System.Text.Json; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Agents.AI; + +/// +/// Provides a thread-safe key-value store for managing session-scoped state with support for type-safe access and JSON +/// serialization options. +/// +/// +/// SessionState enables storing and retrieving objects associated with a session using string keys. +/// Values can be accessed in a type-safe manner and are serialized or deserialized using configurable JSON serializer +/// options. This class is designed for concurrent access and is safe to use across multiple threads. +/// +public class AgentSessionStateBag +{ + private readonly ConcurrentDictionary _state; + + /// + /// Initializes a new instance of the class. + /// + public AgentSessionStateBag() + { + this._state = new ConcurrentDictionary(); + } + + /// + /// Initializes a new instance of the class. + /// + /// The initial state dictionary. + internal AgentSessionStateBag(ConcurrentDictionary? state) + { + this._state = state ?? new ConcurrentDictionary(); + } + + /// + /// Tries to get a value from the session state. + /// + /// The type of the value to retrieve. + /// The key from which to retrieve the value. + /// The value if found and convertable to the required type; otherwise, null. + /// The JSON serializer options to use for serializing/deserialing the value. + /// if the value was successfully retrieved, otherwise. + public bool TryGetValue(string key, out T? value, JsonSerializerOptions? jsonSerializerOptions = null) + where T : class + { + _ = Throw.IfNullOrWhitespace(key); + var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; + + if (this._state.TryGetValue(key, out var stateValue)) + { + if (stateValue.DeserializedValue is T cachedValue) + { + value = cachedValue; + return true; + } + + switch (stateValue.JsonValue) + { + case T tValue: + value = tValue; + return true; + case JsonElement jsonElement when jsonElement.ValueKind == JsonValueKind.Null || jsonElement.ValueKind == JsonValueKind.Undefined: + value = null; + return false; + default: + T? result = stateValue.JsonValue.Deserialize(jso.GetTypeInfo(typeof(T))) as T; + if (result is null) + { + value = null; + return false; + } + + stateValue.DeserializedValue = result; + stateValue.ValueType = typeof(T); + stateValue.JsonSerializerOptions = jso; + + value = result; + return true; + } + } + value = null; + return false; + } + + /// + /// Gets a value from the session state. + /// + /// The type of value to get. + /// The key from which to retrieve the value. + /// The JSON serializer options to use for serializing/deserialing the value. + /// The retrieved value or null if not found. + /// The value could not be deserialized into the required type. + public T? GetValue(string key, JsonSerializerOptions? jsonSerializerOptions = null) + where T : class + { + _ = Throw.IfNullOrWhitespace(key); + var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; + + if (this._state.TryGetValue(key, out var stateValue)) + { + if (stateValue.DeserializedValue is T cachedValue) + { + return cachedValue; + } + + switch (stateValue.JsonValue) + { + case T tValue: + return tValue; + case JsonElement jsonElement when jsonElement.ValueKind == JsonValueKind.Null || jsonElement.ValueKind == JsonValueKind.Undefined: + return null; + default: + T? result = stateValue.JsonValue.Deserialize(jso.GetTypeInfo(typeof(T))) as T; + if (result is null) + { + throw new InvalidOperationException($"Failed to deserialize session state value to type {typeof(T).FullName}."); + } + stateValue.DeserializedValue = result; + stateValue.ValueType = typeof(T); + stateValue.JsonSerializerOptions = jso; + return result; + } + } + + return null; + } + + /// + /// Sets a value in the session state. + /// + /// The type of the value to set. + /// The key to store the value under. + /// The value to set. + /// The JSON serializer options to use for serializing the value. + public void SetValue(string key, T value, JsonSerializerOptions? jsonSerializerOptions = null) + where T : class + { + _ = Throw.IfNullOrWhitespace(key); + var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; + + var stateValue = this._state.GetOrAdd(key, _ => + new AgentSessionStateBagValue(value, typeof(T), jso)); + + stateValue.DeserializedValue = value; + stateValue.ValueType = typeof(T); + stateValue.JsonSerializerOptions = jso; + } + + /// + /// Serializes all session state values to a JSON object. + /// + /// A representing the serialized session state. + /// Thrown when a session state value is not properly initialized. + public JsonElement Serialize() + { + return JsonSerializer.SerializeToElement(this._state, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ConcurrentDictionary))); + } + + /// + /// Deserializes a JSON object into an instance. + /// + /// The element to deserialize. + /// The deserialized . + public static AgentSessionStateBag Deserialize(JsonElement jsonElement) + { + if (jsonElement.ValueKind is JsonValueKind.Undefined or JsonValueKind.Null) + { + return new AgentSessionStateBag(); + } + + return new AgentSessionStateBag( + jsonElement.Deserialize(AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ConcurrentDictionary))) as ConcurrentDictionary + ?? new ConcurrentDictionary()); + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs new file mode 100644 index 0000000000..a4c1743e77 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Microsoft.Agents.AI; + +/// +/// Used to store a value in session state. +/// +internal class AgentSessionStateBagValue +{ + /// + /// Initializes a new instance of the SessionStateValue class with the specified value. + /// + /// The serialized value to associate with the session state. + [JsonConstructor] + public AgentSessionStateBagValue(JsonElement jsonValue) + { + this.JsonValue = jsonValue; + } + + /// + /// Initializes a new instance of the SessionStateValue class with the specified value. + /// + /// The value to associate with the session state. Can be any object, including null. + /// The type of the value. + /// The JSON serializer options to use for serializing the value. + public AgentSessionStateBagValue(object deserializedValue, Type valueType, JsonSerializerOptions jsonSerializerOptions) + { + this.DeserializedValue = deserializedValue; + this.ValueType = valueType; + this.JsonSerializerOptions = jsonSerializerOptions; + } + + /// + /// Gets or sets the value associated with this instance. + /// + public JsonElement JsonValue + { + get + { + if (this.DeserializedValue != null) + { + if (this.ValueType is null || this.JsonSerializerOptions is null) + { + throw new InvalidOperationException($"{nameof(AgentSessionStateBagValue)} has not been properly initialized, please set {nameof(this.ValueType)} and {nameof(this.JsonSerializerOptions)} before accessing {nameof(this.JsonValue)}."); + } + + return JsonSerializer.SerializeToElement(this.DeserializedValue, this.JsonSerializerOptions.GetTypeInfo(this.ValueType)); + } + + return field; + } + set; + } + + [JsonIgnore] + public object? DeserializedValue { get; set; } + + [JsonIgnore] + public Type? ValueType { get; set; } + + [JsonIgnore] + public JsonSerializerOptions? JsonSerializerOptions { get; set; } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs index d809582ea4..cecfa92e8f 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs @@ -143,13 +143,30 @@ public sealed class InvokingContext /// /// Initializes a new instance of the class with the specified request messages. /// + /// The agent being invoked. + /// The session associated with the agent invocation. /// The new messages to be used by the agent for this invocation. /// is . - public InvokingContext(IEnumerable requestMessages) + public InvokingContext( + AIAgent agent, + AgentSession? session, + IEnumerable requestMessages) { + this.Agent = Throw.IfNull(agent); + this.Session = session; this.RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages)); } + /// + /// Gets the agent that is being invoked. + /// + public AIAgent Agent { get; } + + /// + /// Gets the agent session associated with the agent invocation. + /// + public AgentSession? Session { get; } + /// /// Gets the caller provided messages that will be used by the agent for this invocation. /// @@ -172,15 +189,33 @@ public sealed class InvokedContext /// /// Initializes a new instance of the class with the specified request messages. /// + /// The agent being invoked. + /// The session associated with the agent invocation. /// The caller provided messages that were used by the agent for this invocation. /// The messages retrieved from the for this invocation. /// is . - public InvokedContext(IEnumerable requestMessages, IEnumerable? chatHistoryProviderMessages) + public InvokedContext( + AIAgent agent, + AgentSession? session, + IEnumerable requestMessages, + IEnumerable? chatHistoryProviderMessages) { + this.Agent = Throw.IfNull(agent); + this.Session = session; this.RequestMessages = Throw.IfNull(requestMessages); this.ChatHistoryProviderMessages = chatHistoryProviderMessages; } + /// + /// Gets the agent that is being invoked. + /// + public AIAgent Agent { get; } + + /// + /// Gets the agent session associated with the agent invocation. + /// + public AgentSession? Session { get; } + /// /// Gets the caller provided messages that were used by the agent for this invocation. /// diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index ee6db4830d..23e120f14f 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -231,8 +231,8 @@ protected override async IAsyncEnumerable RunCoreStreamingA } catch (Exception ex) { - await NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false); - await NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); + await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false); + await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); throw; } @@ -246,8 +246,8 @@ protected override async IAsyncEnumerable RunCoreStreamingA } catch (Exception ex) { - await NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false); - await NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); + await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false); + await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); throw; } @@ -273,8 +273,8 @@ protected override async IAsyncEnumerable RunCoreStreamingA } catch (Exception ex) { - await NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false); - await NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); + await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false); + await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); throw; } } @@ -286,10 +286,10 @@ protected override async IAsyncEnumerable RunCoreStreamingA await this.UpdateSessionWithTypeAndConversationIdAsync(safeSession, chatResponse.ConversationId, cancellationToken).ConfigureAwait(false); // To avoid inconsistent state we only notify the session of the input messages if no error occurs after the initial request. - await NotifyChatHistoryProviderOfNewMessagesAsync(safeSession, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false); + await this.NotifyChatHistoryProviderOfNewMessagesAsync(safeSession, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false); // Notify the AIContextProvider of all new messages. - await NotifyAIContextProviderOfSuccessAsync(safeSession, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false); + await this.NotifyAIContextProviderOfSuccessAsync(safeSession, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false); } /// @@ -455,8 +455,8 @@ private async Task RunCoreAsync RunCoreAsync RunCoreAsync /// Notify the when an agent run succeeded, if there is an . /// - private static async Task NotifyAIContextProviderOfSuccessAsync( + private async Task NotifyAIContextProviderOfSuccessAsync( ChatClientAgentSession session, IEnumerable inputMessages, IList? aiContextProviderMessages, @@ -497,7 +497,7 @@ private static async Task NotifyAIContextProviderOfSuccessAsync( { if (session.AIContextProvider is not null) { - await session.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProviderMessages) { ResponseMessages = responseMessages }, + await session.AIContextProvider.InvokedAsync(new(this, session, inputMessages, aiContextProviderMessages) { ResponseMessages = responseMessages }, cancellationToken).ConfigureAwait(false); } } @@ -505,7 +505,7 @@ await session.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProvide /// /// Notify the of any failure during an agent run, if there is an . /// - private static async Task NotifyAIContextProviderOfFailureAsync( + private async Task NotifyAIContextProviderOfFailureAsync( ChatClientAgentSession session, Exception ex, IEnumerable inputMessages, @@ -514,7 +514,7 @@ private static async Task NotifyAIContextProviderOfFailureAsync( { if (session.AIContextProvider is not null) { - await session.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProviderMessages) { InvokeException = ex }, + await session.AIContextProvider.InvokedAsync(new(this, session, inputMessages, aiContextProviderMessages) { InvokeException = ex }, cancellationToken).ConfigureAwait(false); } } @@ -726,7 +726,7 @@ private async Task // Add any existing messages from the session to the messages to be sent to the chat client. if (chatHistoryProvider is not null) { - var invokingContext = new ChatHistoryProvider.InvokingContext(inputMessages); + var invokingContext = new ChatHistoryProvider.InvokingContext(this, typedSession, inputMessages); var providerMessages = await chatHistoryProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false); inputMessagesForChatClient.AddRange(providerMessages); chatHistoryProviderMessages = providerMessages as IList ?? providerMessages.ToList(); @@ -739,7 +739,7 @@ private async Task // messages and options with the additional context. if (typedSession.AIContextProvider is not null) { - var invokingContext = new AIContextProvider.InvokingContext(inputMessages); + var invokingContext = new AIContextProvider.InvokingContext(this, typedSession, inputMessages); var aiContext = await typedSession.AIContextProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false); if (aiContext.Messages is { Count: > 0 }) { @@ -812,7 +812,7 @@ private async Task UpdateSessionWithTypeAndConversationIdAsync(ChatClientAgentSe } } - private static Task NotifyChatHistoryProviderOfFailureAsync( + private Task NotifyChatHistoryProviderOfFailureAsync( ChatClientAgentSession session, Exception ex, IEnumerable requestMessages, @@ -827,7 +827,7 @@ private static Task NotifyChatHistoryProviderOfFailureAsync( // If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages. if (provider is not null) { - var invokedContext = new ChatHistoryProvider.InvokedContext(requestMessages, chatHistoryProviderMessages!) + var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, requestMessages, chatHistoryProviderMessages!) { AIContextProviderMessages = aiContextProviderMessages, InvokeException = ex @@ -839,7 +839,7 @@ private static Task NotifyChatHistoryProviderOfFailureAsync( return Task.CompletedTask; } - private static Task NotifyChatHistoryProviderOfNewMessagesAsync( + private Task NotifyChatHistoryProviderOfNewMessagesAsync( ChatClientAgentSession session, IEnumerable requestMessages, IEnumerable? chatHistoryProviderMessages, @@ -854,7 +854,7 @@ private static Task NotifyChatHistoryProviderOfNewMessagesAsync( // If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages. if (provider is not null) { - var invokedContext = new ChatHistoryProvider.InvokedContext(requestMessages, chatHistoryProviderMessages!) + var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, requestMessages, chatHistoryProviderMessages!) { AIContextProviderMessages = aiContextProviderMessages, ResponseMessages = responseMessages diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs index 1a79ae64d1..cb791bd564 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs @@ -148,6 +148,8 @@ internal static async Task DeserializeAsync( ? await aiContextProviderFactory.Invoke(state?.AIContextProviderState ?? default, jsonSerializerOptions, cancellationToken).ConfigureAwait(false) : null; + session.StateBag = AgentSessionStateBag.Deserialize(state?.StateBag ?? default); + if (state?.ConversationId is string sessionId) { session.ConversationId = sessionId; @@ -176,6 +178,7 @@ internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = nu ConversationId = this.ConversationId, ChatHistoryProviderState = chatHistoryProviderState is { ValueKind: not JsonValueKind.Undefined } ? chatHistoryProviderState : null, AIContextProviderState = aiContextProviderState is { ValueKind: not JsonValueKind.Undefined } ? aiContextProviderState : null, + StateBag = this.StateBag.Serialize(), }; return JsonSerializer.SerializeToElement(state, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(SessionState))); @@ -201,5 +204,7 @@ internal sealed class SessionState public JsonElement? ChatHistoryProviderState { get; set; } public JsonElement? AIContextProviderState { get; set; } + + public JsonElement? StateBag { get; set; } } } diff --git a/dotnet/tests/AgentConformance.IntegrationTests/IAgentFixture.cs b/dotnet/tests/AgentConformance.IntegrationTests/IAgentFixture.cs index 96b40d561b..5548c5aaf9 100644 --- a/dotnet/tests/AgentConformance.IntegrationTests/IAgentFixture.cs +++ b/dotnet/tests/AgentConformance.IntegrationTests/IAgentFixture.cs @@ -15,7 +15,7 @@ public interface IAgentFixture : IAsyncLifetime { AIAgent Agent { get; } - Task> GetChatHistoryAsync(AgentSession session); + Task> GetChatHistoryAsync(AIAgent agent, AgentSession session); Task DeleteSessionAsync(AgentSession session); } diff --git a/dotnet/tests/AgentConformance.IntegrationTests/RunStreamingTests.cs b/dotnet/tests/AgentConformance.IntegrationTests/RunStreamingTests.cs index f9cc732175..18982baaad 100644 --- a/dotnet/tests/AgentConformance.IntegrationTests/RunStreamingTests.cs +++ b/dotnet/tests/AgentConformance.IntegrationTests/RunStreamingTests.cs @@ -106,7 +106,7 @@ public virtual async Task SessionMaintainsHistoryAsync() Assert.Contains("Paris", response1Text); Assert.Contains("Vienna", response2Text); - var chatHistory = await this.Fixture.GetChatHistoryAsync(session); + var chatHistory = await this.Fixture.GetChatHistoryAsync(agent, session); Assert.Equal(4, chatHistory.Count); Assert.Equal(2, chatHistory.Count(x => x.Role == ChatRole.User)); Assert.Equal(2, chatHistory.Count(x => x.Role == ChatRole.Assistant)); diff --git a/dotnet/tests/AgentConformance.IntegrationTests/RunTests.cs b/dotnet/tests/AgentConformance.IntegrationTests/RunTests.cs index 302784a1a8..da1cebaf52 100644 --- a/dotnet/tests/AgentConformance.IntegrationTests/RunTests.cs +++ b/dotnet/tests/AgentConformance.IntegrationTests/RunTests.cs @@ -111,7 +111,7 @@ public virtual async Task SessionMaintainsHistoryAsync() Assert.Contains("Paris", result1.Text); Assert.Contains("Vienna", result2.Text); - var chatHistory = await this.Fixture.GetChatHistoryAsync(session); + var chatHistory = await this.Fixture.GetChatHistoryAsync(agent, session); Assert.Equal(4, chatHistory.Count); Assert.Equal(2, chatHistory.Count(x => x.Role == ChatRole.User)); Assert.Equal(2, chatHistory.Count(x => x.Role == ChatRole.Assistant)); diff --git a/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs b/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs index 5f0fcbca2c..a0e4b64763 100644 --- a/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs +++ b/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs @@ -35,7 +35,7 @@ public AnthropicChatCompletionFixture(bool useReasoningChatModel, bool useBeta) public IChatClient ChatClient => this._agent.ChatClient; - public async Task> GetChatHistoryAsync(AgentSession session) + public async Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) { var typedSession = (ChatClientAgentSession)session; @@ -44,7 +44,7 @@ public async Task> GetChatHistoryAsync(AgentSession session) return []; } - return (await typedSession.ChatHistoryProvider.InvokingAsync(new([]))).ToList(); + return (await typedSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); } public Task CreateChatClientAgentAsync( diff --git a/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs b/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs index 4b78d30f1c..c655fd7a58 100644 --- a/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs +++ b/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs @@ -33,7 +33,7 @@ public async Task CreateConversationAsync() return response.Value.Id; } - public async Task> GetChatHistoryAsync(AgentSession session) + public async Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) { var chatClientSession = (ChatClientAgentSession)session; @@ -53,7 +53,7 @@ public async Task> GetChatHistoryAsync(AgentSession session) return []; } - return (await chatClientSession.ChatHistoryProvider.InvokingAsync(new([]))).ToList(); + return (await chatClientSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); } private async Task> GetChatHistoryFromResponsesChainAsync(string conversationId) diff --git a/dotnet/tests/AzureAIAgentsPersistent.IntegrationTests/AzureAIAgentsPersistentFixture.cs b/dotnet/tests/AzureAIAgentsPersistent.IntegrationTests/AzureAIAgentsPersistentFixture.cs index 2f59630c38..3e3272d951 100644 --- a/dotnet/tests/AzureAIAgentsPersistent.IntegrationTests/AzureAIAgentsPersistentFixture.cs +++ b/dotnet/tests/AzureAIAgentsPersistent.IntegrationTests/AzureAIAgentsPersistentFixture.cs @@ -24,7 +24,7 @@ public class AzureAIAgentsPersistentFixture : IChatClientAgentFixture public AIAgent Agent => this._agent; - public async Task> GetChatHistoryAsync(AgentSession session) + public async Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) { List messages = []; var typedSession = (ChatClientAgentSession)session; diff --git a/dotnet/tests/CopilotStudio.IntegrationTests/CopilotStudioFixture.cs b/dotnet/tests/CopilotStudio.IntegrationTests/CopilotStudioFixture.cs index 3b3ac7ff7e..dd5fe46ecc 100644 --- a/dotnet/tests/CopilotStudio.IntegrationTests/CopilotStudioFixture.cs +++ b/dotnet/tests/CopilotStudio.IntegrationTests/CopilotStudioFixture.cs @@ -20,7 +20,7 @@ public class CopilotStudioFixture : IAgentFixture { public AIAgent Agent { get; private set; } = null!; - public Task> GetChatHistoryAsync(AgentSession session) => + public Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) => throw new NotSupportedException("CopilotStudio doesn't allow retrieval of chat history."); public Task DeleteSessionAsync(AgentSession session) => diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs index b6aabd081e..94aa73858a 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs @@ -6,17 +6,21 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; +using Moq; namespace Microsoft.Agents.AI.Abstractions.UnitTests; public class AIContextProviderTests { + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + [Fact] public async Task InvokedAsync_ReturnsCompletedTaskAsync() { var provider = new TestAIContextProvider(); var messages = new ReadOnlyCollection([]); - var task = provider.InvokedAsync(new(messages, aiContextProviderMessages: null)); + var task = provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null)); Assert.Equal(default, task); } @@ -31,13 +35,13 @@ public void Serialize_ReturnsEmptyElement() [Fact] public void InvokingContext_Constructor_ThrowsForNullMessages() { - Assert.Throws(() => new AIContextProvider.InvokingContext(null!)); + Assert.Throws(() => new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, null!)); } [Fact] public void InvokedContext_Constructor_ThrowsForNullMessages() { - Assert.Throws(() => new AIContextProvider.InvokedContext(null!, aiContextProviderMessages: null)); + Assert.Throws(() => new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, null!, aiContextProviderMessages: null)); } #region GetService Method Tests @@ -163,7 +167,7 @@ public void InvokingContext_RequestMessages_SetterThrowsForNull() { // Arrange var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); - var context = new AIContextProvider.InvokingContext(messages); + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, messages); // Act & Assert Assert.Throws(() => context.RequestMessages = null!); @@ -175,7 +179,7 @@ public void InvokingContext_RequestMessages_SetterRoundtrips() // Arrange var initialMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); var newMessages = new List { new(ChatRole.User, "New message") }; - var context = new AIContextProvider.InvokingContext(initialMessages); + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, initialMessages); // Act context.RequestMessages = newMessages; @@ -184,6 +188,55 @@ public void InvokingContext_RequestMessages_SetterRoundtrips() Assert.Same(newMessages, context.RequestMessages); } + [Fact] + public void InvokingContext_Agent_ReturnsConstructorValue() + { + // Arrange + var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + + // Act + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, messages); + + // Assert + Assert.Same(s_mockAgent, context.Agent); + } + + [Fact] + public void InvokingContext_Session_ReturnsConstructorValue() + { + // Arrange + var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + + // Act + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, messages); + + // Assert + Assert.Same(s_mockSession, context.Session); + } + + [Fact] + public void InvokingContext_Session_CanBeNull() + { + // Arrange + var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + + // Act + var context = new AIContextProvider.InvokingContext(s_mockAgent, null, messages); + + // Assert + Assert.Null(context.Session); + } + + [Fact] + public void InvokingContext_Constructor_ThrowsForNullAgent() + { + // Arrange + var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + + // Act & Assert + Assert.Throws(() => new AIContextProvider.InvokingContext(null!, s_mockSession, messages)); + } + #endregion #region InvokedContext Tests @@ -193,7 +246,7 @@ public void InvokedContext_RequestMessages_SetterThrowsForNull() { // Arrange var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); - var context = new AIContextProvider.InvokedContext(messages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null); // Act & Assert Assert.Throws(() => context.RequestMessages = null!); @@ -205,7 +258,7 @@ public void InvokedContext_RequestMessages_SetterRoundtrips() // Arrange var initialMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); var newMessages = new List { new(ChatRole.User, "New message") }; - var context = new AIContextProvider.InvokedContext(initialMessages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, initialMessages, aiContextProviderMessages: null); // Act context.RequestMessages = newMessages; @@ -220,7 +273,7 @@ public void InvokedContext_AIContextProviderMessages_Roundtrips() // Arrange var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); var aiContextMessages = new List { new(ChatRole.System, "AI context message") }; - var context = new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); // Act context.AIContextProviderMessages = aiContextMessages; @@ -235,7 +288,7 @@ public void InvokedContext_ResponseMessages_Roundtrips() // Arrange var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); var responseMessages = new List { new(ChatRole.Assistant, "Response message") }; - var context = new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); // Act context.ResponseMessages = responseMessages; @@ -250,7 +303,7 @@ public void InvokedContext_InvokeException_Roundtrips() // Arrange var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); var exception = new InvalidOperationException("Test exception"); - var context = new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); // Act context.InvokeException = exception; @@ -259,6 +312,55 @@ public void InvokedContext_InvokeException_Roundtrips() Assert.Same(exception, context.InvokeException); } + [Fact] + public void InvokedContext_Agent_ReturnsConstructorValue() + { + // Arrange + var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + + // Act + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); + + // Assert + Assert.Same(s_mockAgent, context.Agent); + } + + [Fact] + public void InvokedContext_Session_ReturnsConstructorValue() + { + // Arrange + var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + + // Act + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); + + // Assert + Assert.Same(s_mockSession, context.Session); + } + + [Fact] + public void InvokedContext_Session_CanBeNull() + { + // Arrange + var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + + // Act + var context = new AIContextProvider.InvokedContext(s_mockAgent, null, requestMessages, aiContextProviderMessages: null); + + // Assert + Assert.Null(context.Session); + } + + [Fact] + public void InvokedContext_Constructor_ThrowsForNullAgent() + { + // Arrange + var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + + // Act & Assert + Assert.Throws(() => new AIContextProvider.InvokedContext(null!, s_mockSession, requestMessages, aiContextProviderMessages: null)); + } + #endregion private sealed class TestAIContextProvider : AIContextProvider diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs new file mode 100644 index 0000000000..5d07a4749f --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs @@ -0,0 +1,407 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json; +using Microsoft.Agents.AI.Abstractions.UnitTests.Models; + +namespace Microsoft.Agents.AI.Abstractions.UnitTests; + +/// +/// Contains tests for the class. +/// +public sealed class AgentSessionStateBagTests +{ + #region Constructor Tests + + [Fact] + public void Constructor_Default_CreatesEmptyStateBag() + { + // Act + var stateBag = new AgentSessionStateBag(); + + // Assert + Assert.False(stateBag.TryGetValue("nonexistent", out _)); + } + + #endregion + + #region SetValue Tests + + [Fact] + public void SetValue_WithValidKeyAndValue_StoresValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + stateBag.SetValue("key1", "value1"); + + // Assert + Assert.True(stateBag.TryGetValue("key1", out var result)); + Assert.Equal("value1", result); + } + + [Fact] + public void SetValue_WithNullKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.SetValue(null!, "value")); + } + + [Fact] + public void SetValue_WithEmptyKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.SetValue("", "value")); + } + + [Fact] + public void SetValue_WithWhitespaceKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.SetValue(" ", "value")); + } + + [Fact] + public void SetValue_OverwritesExistingValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "originalValue"); + + // Act + stateBag.SetValue("key1", "newValue"); + + // Assert + Assert.Equal("newValue", stateBag.GetValue("key1")); + } + + #endregion + + #region GetValue Tests + + [Fact] + public void GetValue_WithExistingKey_ReturnsValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "value1"); + + // Act + var result = stateBag.GetValue("key1"); + + // Assert + Assert.Equal("value1", result); + } + + [Fact] + public void GetValue_WithNonexistentKey_ReturnsNull() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + var result = stateBag.GetValue("nonexistent"); + + // Assert + Assert.Null(result); + } + + [Fact] + public void GetValue_WithNullKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.GetValue(null!)); + } + + [Fact] + public void GetValue_WithEmptyKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.GetValue("")); + } + + [Fact] + public void GetValue_CachesDeserializedValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "value1"); + + // Act + var result1 = stateBag.GetValue("key1"); + var result2 = stateBag.GetValue("key1"); + + // Assert + Assert.Same(result1, result2); + } + + #endregion + + #region TryGetValue Tests + + [Fact] + public void TryGetValue_WithExistingKey_ReturnsTrueAndValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "value1"); + + // Act + var found = stateBag.TryGetValue("key1", out var result); + + // Assert + Assert.True(found); + Assert.Equal("value1", result); + } + + [Fact] + public void TryGetValue_WithNonexistentKey_ReturnsFalseAndNull() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + var found = stateBag.TryGetValue("nonexistent", out var result); + + // Assert + Assert.False(found); + Assert.Null(result); + } + + [Fact] + public void TryGetValue_WithNullKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.TryGetValue(null!, out _)); + } + + [Fact] + public void TryGetValue_WithEmptyKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.TryGetValue("", out _)); + } + + #endregion + + #region Serialize/Deserialize Tests + + [Fact] + public void Serialize_EmptyStateBag_ReturnsEmptyObject() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + var json = stateBag.Serialize(); + + // Assert + Assert.Equal(JsonValueKind.Object, json.ValueKind); + } + + [Fact] + public void Serialize_WithStringValue_ReturnsJsonWithValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("stringKey", "stringValue"); + + // Act + var json = stateBag.Serialize(); + + // Assert + Assert.Equal(JsonValueKind.Object, json.ValueKind); + Assert.True(json.TryGetProperty("stringKey", out _)); + } + + [Fact] + public void Deserialize_FromJsonDocument_ReturnsEmptyStateBag() + { + // Arrange + var emptyJson = JsonDocument.Parse("{}").RootElement; + + // Act + var stateBag = AgentSessionStateBag.Deserialize(emptyJson); + + // Assert + Assert.False(stateBag.TryGetValue("nonexistent", out _)); + } + + [Fact] + public void Deserialize_NullElement_ReturnsEmptyStateBag() + { + // Arrange + var nullJson = default(JsonElement); + + // Act + var stateBag = AgentSessionStateBag.Deserialize(nullJson); + + // Assert + Assert.False(stateBag.TryGetValue("nonexistent", out _)); + } + + [Fact] + public void SerializeDeserialize_WithStringValue_Roundtrips() + { + // Arrange + var originalStateBag = new AgentSessionStateBag(); + originalStateBag.SetValue("stringKey", "stringValue"); + + // Act + var json = originalStateBag.Serialize(); + var restoredStateBag = AgentSessionStateBag.Deserialize(json); + + // Assert + Assert.Equal("stringValue", restoredStateBag.GetValue("stringKey")); + } + + #endregion + + #region Thread Safety Tests + + [Fact] + public async System.Threading.Tasks.Task SetValue_MultipleConcurrentWrites_DoesNotThrowAsync() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + var tasks = new System.Threading.Tasks.Task[100]; + + // Act + for (int i = 0; i < 100; i++) + { + int index = i; + tasks[i] = System.Threading.Tasks.Task.Run(() => stateBag.SetValue($"key{index}", $"value{index}")); + } + + await System.Threading.Tasks.Task.WhenAll(tasks); + + // Assert + for (int i = 0; i < 100; i++) + { + Assert.True(stateBag.TryGetValue($"key{i}", out var value)); + Assert.Equal($"value{i}", value); + } + } + + #endregion + + #region Complex Object Tests + + [Fact] + public void SetValue_WithComplexObject_StoresValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + var animal = new Animal { Id = 1, FullName = "Buddy", Species = Species.Bear }; + + // Act + stateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); + + // Assert + Animal? result = stateBag.GetValue("animal", TestJsonSerializerContext.Default.Options); + Assert.NotNull(result); + Assert.Equal(1, result.Id); + Assert.Equal("Buddy", result.FullName); + Assert.Equal(Species.Bear, result.Species); + } + + [Fact] + public void GetValue_WithComplexObject_CachesDeserializedValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + var animal = new Animal { Id = 2, FullName = "Whiskers", Species = Species.Tiger }; + stateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); + + // Act + Animal? result1 = stateBag.GetValue("animal", TestJsonSerializerContext.Default.Options); + Animal? result2 = stateBag.GetValue("animal", TestJsonSerializerContext.Default.Options); + + // Assert + Assert.Same(result1, result2); + } + + [Fact] + public void TryGetValue_WithComplexObject_ReturnsTrueAndValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + var animal = new Animal { Id = 3, FullName = "Goldie", Species = Species.Walrus }; + stateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); + + // Act + bool found = stateBag.TryGetValue("animal", out Animal? result, TestJsonSerializerContext.Default.Options); + + // Assert + Assert.True(found); + Assert.NotNull(result); + Assert.Equal(3, result.Id); + Assert.Equal("Goldie", result.FullName); + Assert.Equal(Species.Walrus, result.Species); + } + + [Fact] + public void SerializeDeserialize_WithComplexObject_Roundtrips() + { + // Arrange + var originalStateBag = new AgentSessionStateBag(); + var animal = new Animal { Id = 4, FullName = "Polly", Species = Species.Bear }; + originalStateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); + + // Act + JsonElement json = originalStateBag.Serialize(); + AgentSessionStateBag restoredStateBag = AgentSessionStateBag.Deserialize(json); + + // Assert + Animal? restoredAnimal = restoredStateBag.GetValue("animal", TestJsonSerializerContext.Default.Options); + Assert.NotNull(restoredAnimal); + Assert.Equal(4, restoredAnimal.Id); + Assert.Equal("Polly", restoredAnimal.FullName); + Assert.Equal(Species.Bear, restoredAnimal.Species); + } + + [Fact] + public void Serialize_WithComplexObject_ReturnsJsonWithProperties() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + var animal = new Animal { Id = 7, FullName = "Spot", Species = Species.Walrus }; + stateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); + + // Act + JsonElement json = stateBag.Serialize(); + + // Assert + Assert.Equal(JsonValueKind.Object, json.ValueKind); + Assert.True(json.TryGetProperty("animal", out JsonElement animalElement)); + Assert.True(animalElement.TryGetProperty("jsonValue", out JsonElement jsonValueElement)); + Assert.Equal(JsonValueKind.Object, jsonValueElement.ValueKind); + Assert.Equal(7, jsonValueElement.GetProperty("id").GetInt32()); + Assert.Equal("Spot", jsonValueElement.GetProperty("fullName").GetString()); + Assert.Equal("Walrus", jsonValueElement.GetProperty("species").GetString()); + } + + #endregion +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionTests.cs index 5a776c9fb0..b80f0a4fd2 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionTests.cs @@ -11,6 +11,21 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; /// public class AgentSessionTests { + #region StateBag Tests + + [Fact] + public void StateBag_Values_Roundtrips() + { + // Arrange + var session = new TestAgentSession(); + + // Act & Assert + session.StateBag.SetValue("key1", "value1"); + Assert.Equal("value1", session.StateBag.GetValue("key1")); + } + + #endregion + #region GetService Method Tests /// diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs index 84a0242320..a74906c801 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs @@ -14,6 +14,9 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; /// public sealed class ChatHistoryProviderExtensionsTests { + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + [Fact] public void WithMessageFilters_ReturnsChatHistoryProviderMessageFilter() { @@ -35,7 +38,7 @@ public async Task WithMessageFilters_InvokingFilter_IsAppliedAsync() // Arrange Mock providerMock = new(); List innerMessages = [new(ChatRole.User, "Hello"), new(ChatRole.Assistant, "Hi")]; - ChatHistoryProvider.InvokingContext context = new([new ChatMessage(ChatRole.User, "Test")]); + ChatHistoryProvider.InvokingContext context = new(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]); providerMock .Setup(p => p.InvokingAsync(context, It.IsAny())) @@ -59,7 +62,7 @@ public async Task WithMessageFilters_InvokedFilter_IsAppliedAsync() Mock providerMock = new(); List requestMessages = [new(ChatRole.User, "Hello")]; List chatHistoryProviderMessages = [new(ChatRole.System, "System")]; - ChatHistoryProvider.InvokedContext context = new(requestMessages, chatHistoryProviderMessages) + ChatHistoryProvider.InvokedContext context = new(s_mockAgent, s_mockSession, requestMessages, chatHistoryProviderMessages) { ResponseMessages = [new ChatMessage(ChatRole.Assistant, "Response")] }; @@ -106,7 +109,7 @@ public async Task WithAIContextProviderMessageRemoval_RemovesAIContextProviderMe List requestMessages = [new(ChatRole.User, "Hello")]; List chatHistoryProviderMessages = [new(ChatRole.System, "System")]; List aiContextProviderMessages = [new(ChatRole.System, "Context")]; - ChatHistoryProvider.InvokedContext context = new(requestMessages, chatHistoryProviderMessages) + ChatHistoryProvider.InvokedContext context = new(s_mockAgent, s_mockSession, requestMessages, chatHistoryProviderMessages) { AIContextProviderMessages = aiContextProviderMessages }; diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs index 43a3e78f10..4b955a43c0 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs @@ -16,6 +16,9 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; /// public sealed class ChatHistoryProviderMessageFilterTests { + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + [Fact] public void Constructor_WithNullInnerProvider_ThrowsArgumentNullException() { @@ -59,7 +62,7 @@ public async Task InvokingAsync_WithNoOpFilters_ReturnsInnerProviderMessagesAsyn new(ChatRole.User, "Hello"), new(ChatRole.Assistant, "Hi there!") }; - var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Test")]); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]); innerProviderMock .Setup(s => s.InvokingAsync(context, It.IsAny())) @@ -88,7 +91,7 @@ public async Task InvokingAsync_WithInvokingFilter_AppliesFilterAsync() new(ChatRole.Assistant, "Hi there!"), new(ChatRole.User, "How are you?") }; - var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Test")]); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]); innerProviderMock .Setup(s => s.InvokingAsync(context, It.IsAny())) @@ -118,7 +121,7 @@ public async Task InvokingAsync_WithInvokingFilter_CanModifyMessagesAsync() new(ChatRole.User, "Hello"), new(ChatRole.Assistant, "Hi there!") }; - var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Test")]); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]); innerProviderMock .Setup(s => s.InvokingAsync(context, It.IsAny())) @@ -147,7 +150,7 @@ public async Task InvokedAsync_WithInvokedFilter_AppliesFilterAsync() var requestMessages = new List { new(ChatRole.User, "Hello") }; var chatHistoryProviderMessages = new List { new(ChatRole.System, "System") }; var responseMessages = new List { new(ChatRole.Assistant, "Response") }; - var context = new ChatHistoryProvider.InvokedContext(requestMessages, chatHistoryProviderMessages) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, chatHistoryProviderMessages) { ResponseMessages = responseMessages }; @@ -162,7 +165,7 @@ public async Task InvokedAsync_WithInvokedFilter_AppliesFilterAsync() ChatHistoryProvider.InvokedContext InvokedFilter(ChatHistoryProvider.InvokedContext ctx) { var modifiedRequestMessages = ctx.RequestMessages.Select(m => new ChatMessage(m.Role, $"[FILTERED] {m.Text}")).ToList(); - return new ChatHistoryProvider.InvokedContext(modifiedRequestMessages, ctx.ChatHistoryProviderMessages) + return new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, modifiedRequestMessages, ctx.ChatHistoryProviderMessages) { ResponseMessages = ctx.ResponseMessages, AIContextProviderMessages = ctx.AIContextProviderMessages, diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs index a26ef199d9..5e0fbe9817 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs @@ -6,6 +6,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; +using Moq; namespace Microsoft.Agents.AI.Abstractions.UnitTests; @@ -14,6 +15,9 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; /// public class ChatHistoryProviderTests { + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + #region GetService Method Tests [Fact] @@ -82,7 +86,7 @@ public void GetService_Generic_ReturnsNullForUnrelatedType() public void InvokingContext_Constructor_ThrowsForNullMessages() { // Arrange & Act & Assert - Assert.Throws(() => new ChatHistoryProvider.InvokingContext(null!)); + Assert.Throws(() => new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, null!)); } [Fact] @@ -90,7 +94,7 @@ public void InvokingContext_RequestMessages_SetterThrowsForNull() { // Arrange var messages = new List { new(ChatRole.User, "Hello") }; - var context = new ChatHistoryProvider.InvokingContext(messages); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, messages); // Act & Assert Assert.Throws(() => context.RequestMessages = null!); @@ -102,7 +106,7 @@ public void InvokingContext_RequestMessages_SetterRoundtrips() // Arrange var initialMessages = new List { new(ChatRole.User, "Hello") }; var newMessages = new List { new(ChatRole.User, "New message") }; - var context = new ChatHistoryProvider.InvokingContext(initialMessages); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, initialMessages); // Act context.RequestMessages = newMessages; @@ -111,6 +115,55 @@ public void InvokingContext_RequestMessages_SetterRoundtrips() Assert.Same(newMessages, context.RequestMessages); } + [Fact] + public void InvokingContext_Agent_ReturnsConstructorValue() + { + // Arrange + var messages = new List { new(ChatRole.User, "Hello") }; + + // Act + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, messages); + + // Assert + Assert.Same(s_mockAgent, context.Agent); + } + + [Fact] + public void InvokingContext_Session_ReturnsConstructorValue() + { + // Arrange + var messages = new List { new(ChatRole.User, "Hello") }; + + // Act + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, messages); + + // Assert + Assert.Same(s_mockSession, context.Session); + } + + [Fact] + public void InvokingContext_Session_CanBeNull() + { + // Arrange + var messages = new List { new(ChatRole.User, "Hello") }; + + // Act + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, null, messages); + + // Assert + Assert.Null(context.Session); + } + + [Fact] + public void InvokingContext_Constructor_ThrowsForNullAgent() + { + // Arrange + var messages = new List { new(ChatRole.User, "Hello") }; + + // Act & Assert + Assert.Throws(() => new ChatHistoryProvider.InvokingContext(null!, s_mockSession, messages)); + } + #endregion #region InvokedContext Tests @@ -119,7 +172,7 @@ public void InvokingContext_RequestMessages_SetterRoundtrips() public void InvokedContext_Constructor_ThrowsForNullRequestMessages() { // Arrange & Act & Assert - Assert.Throws(() => new ChatHistoryProvider.InvokedContext(null!, [])); + Assert.Throws(() => new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, null!, [])); } [Fact] @@ -127,7 +180,7 @@ public void InvokedContext_RequestMessages_SetterThrowsForNull() { // Arrange var requestMessages = new List { new(ChatRole.User, "Hello") }; - var context = new ChatHistoryProvider.InvokedContext(requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); // Act & Assert Assert.Throws(() => context.RequestMessages = null!); @@ -139,7 +192,7 @@ public void InvokedContext_RequestMessages_SetterRoundtrips() // Arrange var initialMessages = new List { new(ChatRole.User, "Hello") }; var newMessages = new List { new(ChatRole.User, "New message") }; - var context = new ChatHistoryProvider.InvokedContext(initialMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, initialMessages, []); // Act context.RequestMessages = newMessages; @@ -154,7 +207,7 @@ public void InvokedContext_ChatHistoryProviderMessages_SetterRoundtrips() // Arrange var requestMessages = new List { new(ChatRole.User, "Hello") }; var newProviderMessages = new List { new(ChatRole.System, "System message") }; - var context = new ChatHistoryProvider.InvokedContext(requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); // Act context.ChatHistoryProviderMessages = newProviderMessages; @@ -169,7 +222,7 @@ public void InvokedContext_AIContextProviderMessages_Roundtrips() // Arrange var requestMessages = new List { new(ChatRole.User, "Hello") }; var aiContextMessages = new List { new(ChatRole.System, "AI context message") }; - var context = new ChatHistoryProvider.InvokedContext(requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); // Act context.AIContextProviderMessages = aiContextMessages; @@ -184,7 +237,7 @@ public void InvokedContext_ResponseMessages_Roundtrips() // Arrange var requestMessages = new List { new(ChatRole.User, "Hello") }; var responseMessages = new List { new(ChatRole.Assistant, "Response message") }; - var context = new ChatHistoryProvider.InvokedContext(requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); // Act context.ResponseMessages = responseMessages; @@ -199,7 +252,7 @@ public void InvokedContext_InvokeException_Roundtrips() // Arrange var requestMessages = new List { new(ChatRole.User, "Hello") }; var exception = new InvalidOperationException("Test exception"); - var context = new ChatHistoryProvider.InvokedContext(requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); // Act context.InvokeException = exception; @@ -208,6 +261,55 @@ public void InvokedContext_InvokeException_Roundtrips() Assert.Same(exception, context.InvokeException); } + [Fact] + public void InvokedContext_Agent_ReturnsConstructorValue() + { + // Arrange + var requestMessages = new List { new(ChatRole.User, "Hello") }; + + // Act + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); + + // Assert + Assert.Same(s_mockAgent, context.Agent); + } + + [Fact] + public void InvokedContext_Session_ReturnsConstructorValue() + { + // Arrange + var requestMessages = new List { new(ChatRole.User, "Hello") }; + + // Act + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); + + // Assert + Assert.Same(s_mockSession, context.Session); + } + + [Fact] + public void InvokedContext_Session_CanBeNull() + { + // Arrange + var requestMessages = new List { new(ChatRole.User, "Hello") }; + + // Act + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, null, requestMessages, []); + + // Assert + Assert.Null(context.Session); + } + + [Fact] + public void InvokedContext_Constructor_ThrowsForNullAgent() + { + // Arrange + var requestMessages = new List { new(ChatRole.User, "Hello") }; + + // Act & Assert + Assert.Throws(() => new ChatHistoryProvider.InvokedContext(null!, s_mockSession, requestMessages, [])); + } + #endregion private sealed class TestChatHistoryProvider : ChatHistoryProvider diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs index ff31d0afc9..bf8ff998b9 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs @@ -18,6 +18,9 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; /// public class InMemoryChatHistoryProviderTests { + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + [Fact] public void Constructor_Throws_ForNullReducer() => // Arrange & Act & Assert @@ -68,7 +71,7 @@ public async Task InvokedAsyncAddsMessagesAsync() var provider = new InMemoryChatHistoryProvider(); provider.Add(providerMessages[0]); - var context = new ChatHistoryProvider.InvokedContext(requestMessages, providerMessages) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, providerMessages) { AIContextProviderMessages = aiContextProviderMessages, ResponseMessages = responseMessages @@ -87,7 +90,7 @@ public async Task InvokedAsyncWithEmptyDoesNotFailAsync() { var provider = new InMemoryChatHistoryProvider(); - var context = new ChatHistoryProvider.InvokedContext([], []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [], []); await provider.InvokedAsync(context, CancellationToken.None); Assert.Empty(provider); @@ -102,7 +105,7 @@ public async Task InvokingAsyncReturnsAllMessagesAsync() new ChatMessage(ChatRole.Assistant, "Test2") }; - var context = new ChatHistoryProvider.InvokingContext([]); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var result = (await provider.InvokingAsync(context, CancellationToken.None)).ToList(); Assert.Equal(2, result.Count); @@ -183,7 +186,7 @@ public async Task InvokedAsyncWithEmptyMessagesDoesNotChangeProviderAsync() var provider = new InMemoryChatHistoryProvider(); var messages = new List(); - var context = new ChatHistoryProvider.InvokedContext(messages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); await provider.InvokedAsync(context, CancellationToken.None); Assert.Empty(provider); @@ -520,7 +523,7 @@ public async Task AddMessagesAsync_WithReducer_AfterMessageAdded_InvokesReducerA var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded); // Act - var context = new ChatHistoryProvider.InvokedContext(originalMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, originalMessages, []); await provider.InvokedAsync(context, CancellationToken.None); // Assert @@ -556,7 +559,7 @@ public async Task GetMessagesAsync_WithReducer_BeforeMessagesRetrieval_InvokesRe } // Act - var invokingContext = new ChatHistoryProvider.InvokingContext(Array.Empty()); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, Array.Empty()); var result = (await provider.InvokingAsync(invokingContext, CancellationToken.None)).ToList(); // Assert @@ -579,7 +582,7 @@ public async Task AddMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeRedu var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.BeforeMessagesRetrieval); // Act - var context = new ChatHistoryProvider.InvokedContext(originalMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, originalMessages, []); await provider.InvokedAsync(context, CancellationToken.None); // Assert @@ -605,7 +608,7 @@ public async Task GetMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeRedu }; // Act - var invokingContext = new ChatHistoryProvider.InvokingContext(Array.Empty()); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, Array.Empty()); var result = (await provider.InvokingAsync(invokingContext, CancellationToken.None)).ToList(); // Assert @@ -627,7 +630,7 @@ public async Task InvokedAsync_WithException_DoesNotAddMessagesAsync() { new(ChatRole.Assistant, "Hi there!") }; - var context = new ChatHistoryProvider.InvokedContext(requestMessages, []) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []) { ResponseMessages = responseMessages, InvokeException = new InvalidOperationException("Test exception") diff --git a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs index ab2f58dfd5..f6589ff9e3 100644 --- a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs @@ -41,6 +41,9 @@ namespace Microsoft.Agents.AI.CosmosNoSql.UnitTests; [Collection("CosmosDB")] public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable { + private static readonly AIAgent s_mockAgent = new Moq.Mock().Object; + private static readonly AgentSession s_mockSession = new Moq.Mock().Object; + // Cosmos DB Emulator connection settings private const string EmulatorEndpoint = "https://localhost:8081"; private const string EmulatorKey = "C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="; @@ -214,7 +217,7 @@ public async Task InvokedAsync_WithSingleMessage_ShouldAddMessageAsync() using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversationId); var message = new ChatMessage(ChatRole.User, "Hello, world!"); - var context = new ChatHistoryProvider.InvokedContext([message], []) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [message], []) { ResponseMessages = [] }; @@ -226,7 +229,7 @@ public async Task InvokedAsync_WithSingleMessage_ShouldAddMessageAsync() await Task.Delay(100); // Assert - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var messages = await provider.InvokingAsync(invokingContext); var messageList = messages.ToList(); @@ -293,7 +296,7 @@ public async Task InvokedAsync_WithMultipleMessages_ShouldAddAllMessagesAsync() new ChatMessage(ChatRole.Assistant, "Response message") }; - var context = new ChatHistoryProvider.InvokedContext(requestMessages, []) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []) { AIContextProviderMessages = aiContextProviderMessages, ResponseMessages = responseMessages @@ -303,7 +306,7 @@ public async Task InvokedAsync_WithMultipleMessages_ShouldAddAllMessagesAsync() await provider.InvokedAsync(context); // Assert - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var retrievedMessages = await provider.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); Assert.Equal(5, messageList.Count); @@ -327,7 +330,7 @@ public async Task InvokingAsync_WithNoMessages_ShouldReturnEmptyAsync() using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, Guid.NewGuid().ToString()); // Act - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var messages = await provider.InvokingAsync(invokingContext); // Assert @@ -346,15 +349,15 @@ public async Task InvokingAsync_WithConversationIsolation_ShouldOnlyReturnMessag using var store1 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversation1); using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversation2); - var context1 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message for conversation 1")], []); - var context2 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message for conversation 2")], []); + var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message for conversation 1")], []); + var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message for conversation 2")], []); await store1.InvokedAsync(context1); await store2.InvokedAsync(context2); // Act - var invokingContext1 = new ChatHistoryProvider.InvokingContext([]); - var invokingContext2 = new ChatHistoryProvider.InvokingContext([]); + var invokingContext1 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext2 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var messages1 = await store1.InvokingAsync(invokingContext1); var messages2 = await store2.InvokingAsync(invokingContext2); @@ -391,11 +394,11 @@ public async Task FullWorkflow_AddAndGet_ShouldWorkCorrectlyAsync() }; // Act 1: Add messages - var invokedContext = new ChatHistoryProvider.InvokedContext(messages, []); + var invokedContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); await originalStore.InvokedAsync(invokedContext); // Act 2: Verify messages were added - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var retrievedMessages = await originalStore.InvokingAsync(invokingContext); var retrievedList = retrievedMessages.ToList(); Assert.Equal(5, retrievedList.Count); @@ -545,7 +548,7 @@ public async Task InvokedAsync_WithHierarchicalPartitioning_ShouldAddMessageWith using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId); var message = new ChatMessage(ChatRole.User, "Hello from hierarchical partitioning!"); - var context = new ChatHistoryProvider.InvokedContext([message], []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [message], []); // Act await provider.InvokedAsync(context); @@ -554,7 +557,7 @@ public async Task InvokedAsync_WithHierarchicalPartitioning_ShouldAddMessageWith await Task.Delay(100); // Assert - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var messages = await provider.InvokingAsync(invokingContext); var messageList = messages.ToList(); @@ -602,7 +605,7 @@ public async Task InvokedAsync_WithHierarchicalMultipleMessages_ShouldAddAllMess new ChatMessage(ChatRole.User, "Third hierarchical message") }; - var context = new ChatHistoryProvider.InvokedContext(messages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); // Act await provider.InvokedAsync(context); @@ -611,7 +614,7 @@ public async Task InvokedAsync_WithHierarchicalMultipleMessages_ShouldAddAllMess await Task.Delay(100); // Assert - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var retrievedMessages = await provider.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); @@ -637,8 +640,8 @@ public async Task InvokingAsync_WithHierarchicalPartitionIsolation_ShouldIsolate using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId2, SessionId); // Add messages to both stores - var context1 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message from user 1")], []); - var context2 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message from user 2")], []); + var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message from user 1")], []); + var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message from user 2")], []); await store1.InvokedAsync(context1); await store2.InvokedAsync(context2); @@ -647,8 +650,8 @@ public async Task InvokingAsync_WithHierarchicalPartitionIsolation_ShouldIsolate await Task.Delay(100); // Act & Assert - var invokingContext1 = new ChatHistoryProvider.InvokingContext([]); - var invokingContext2 = new ChatHistoryProvider.InvokingContext([]); + var invokingContext1 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext2 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var messages1 = await store1.InvokingAsync(invokingContext1); var messageList1 = messages1.ToList(); @@ -675,7 +678,7 @@ public async Task SerializeDeserialize_WithHierarchicalPartitioning_ShouldPreser using var originalStore = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId); - var context = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Test serialization message")], []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test serialization message")], []); await originalStore.InvokedAsync(context); // Act - Serialize the provider state @@ -693,7 +696,7 @@ public async Task SerializeDeserialize_WithHierarchicalPartitioning_ShouldPreser await Task.Delay(100); // Assert - The deserialized provider should have the same functionality - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var messages = await deserializedStore.InvokingAsync(invokingContext); var messageList = messages.ToList(); @@ -717,8 +720,8 @@ public async Task HierarchicalAndSimplePartitioning_ShouldCoexistAsync() using var hierarchicalProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, "tenant-coexist", "user-coexist", SessionId); // Add messages to both - var simpleContext = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Simple partitioning message")], []); - var hierarchicalContext = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Hierarchical partitioning message")], []); + var simpleContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Simple partitioning message")], []); + var hierarchicalContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Hierarchical partitioning message")], []); await simpleProvider.InvokedAsync(simpleContext); await hierarchicalProvider.InvokedAsync(hierarchicalContext); @@ -727,7 +730,7 @@ public async Task HierarchicalAndSimplePartitioning_ShouldCoexistAsync() await Task.Delay(100); // Act & Assert - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var simpleMessages = await simpleProvider.InvokingAsync(invokingContext); var simpleMessageList = simpleMessages.ToList(); @@ -760,7 +763,7 @@ public async Task MaxMessagesToRetrieve_ShouldLimitAndReturnMostRecentAsync() await Task.Delay(10); // Small delay to ensure different timestamps } - var context = new ChatHistoryProvider.InvokedContext(messages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); await provider.InvokedAsync(context); // Wait for eventual consistency @@ -768,7 +771,7 @@ public async Task MaxMessagesToRetrieve_ShouldLimitAndReturnMostRecentAsync() // Act - Set max to 5 and retrieve provider.MaxMessagesToRetrieve = 5; - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var retrievedMessages = await provider.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); @@ -798,14 +801,14 @@ public async Task MaxMessagesToRetrieve_Null_ShouldReturnAllMessagesAsync() messages.Add(new ChatMessage(ChatRole.User, $"Message {i}")); } - var context = new ChatHistoryProvider.InvokedContext(messages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); await provider.InvokedAsync(context); // Wait for eventual consistency await Task.Delay(100); // Act - No limit set (default null) - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var retrievedMessages = await provider.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); diff --git a/dotnet/tests/Microsoft.Agents.AI.DurableTask.UnitTests/DurableAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.DurableTask.UnitTests/DurableAgentSessionTests.cs index 4bf8ebc718..db6ec99058 100644 --- a/dotnet/tests/Microsoft.Agents.AI.DurableTask.UnitTests/DurableAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.DurableTask.UnitTests/DurableAgentSessionTests.cs @@ -15,7 +15,7 @@ public void BuiltInSerialization() JsonElement serializedSession = session.Serialize(); // Expected format: "{\"sessionId\":\"@dafx-test-agent@\"}" - string expectedSerializedSession = $"{{\"sessionId\":\"@dafx-{sessionId.Name}@{sessionId.Key}\"}}"; + string expectedSerializedSession = $"{{\"sessionId\":\"@dafx-{sessionId.Name}@{sessionId.Key}\",\"stateBag\":{{}}}}"; Assert.Equal(expectedSerializedSession, serializedSession.ToString()); DurableAgentSession deserializedSession = DurableAgentSession.Deserialize(serializedSession); @@ -33,7 +33,7 @@ public void STJSerialization() string serializedSession = JsonSerializer.Serialize(session, typeof(DurableAgentSession)); // Expected format: "{\"sessionId\":\"@dafx-test-agent@\"}" - string expectedSerializedSession = $"{{\"sessionId\":\"@dafx-{sessionId.Name}@{sessionId.Key}\"}}"; + string expectedSerializedSession = $"{{\"sessionId\":\"@dafx-{sessionId.Name}@{sessionId.Key}\",\"StateBag\":{{}}}}"; Assert.Equal(expectedSerializedSession, serializedSession); DurableAgentSession? deserializedSession = JsonSerializer.Deserialize(serializedSession); diff --git a/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs index bacc59833a..81ca4eb588 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs @@ -18,6 +18,9 @@ public sealed class Mem0ProviderTests : IDisposable { private const string SkipReason = "Requires a Mem0 service configured"; // Set to null to enable. + private static readonly AIAgent s_mockAgent = new Moq.Mock().Object; + private static readonly AgentSession s_mockSession = new Moq.Mock().Object; + private readonly HttpClient _httpClient; public Mem0ProviderTests() @@ -49,14 +52,14 @@ public async Task CanAddAndRetrieveUserMemoriesAsync() var sut = new Mem0Provider(this._httpClient, storageScope); await sut.ClearStoredMemoriesAsync(); - var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext([question])); + var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); Assert.DoesNotContain("Caoimhe", ctxBefore.Messages?[0].Text ?? string.Empty); // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext([input], aiContextProviderMessages: null)); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [input], aiContextProviderMessages: null)); var ctxAfterAdding = await GetContextWithRetryAsync(sut, question); await sut.ClearStoredMemoriesAsync(); - var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext([question])); + var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); // Assert Assert.Contains("Caoimhe", ctxAfterAdding.Messages?[0].Text ?? string.Empty); @@ -73,14 +76,14 @@ public async Task CanAddAndRetrieveAgentMemoriesAsync() var sut = new Mem0Provider(this._httpClient, storageScope); await sut.ClearStoredMemoriesAsync(); - var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext([question])); + var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); Assert.DoesNotContain("Caoimhe", ctxBefore.Messages?[0].Text ?? string.Empty); // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext([assistantIntro], aiContextProviderMessages: null)); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [assistantIntro], aiContextProviderMessages: null)); var ctxAfterAdding = await GetContextWithRetryAsync(sut, question); await sut.ClearStoredMemoriesAsync(); - var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext([question])); + var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); // Assert Assert.Contains("Caoimhe", ctxAfterAdding.Messages?[0].Text ?? string.Empty); @@ -99,13 +102,13 @@ public async Task DoesNotLeakMemoriesAcrossAgentScopesAsync() await sut1.ClearStoredMemoriesAsync(); await sut2.ClearStoredMemoriesAsync(); - var ctxBefore1 = await sut1.InvokingAsync(new AIContextProvider.InvokingContext([question])); - var ctxBefore2 = await sut2.InvokingAsync(new AIContextProvider.InvokingContext([question])); + var ctxBefore1 = await sut1.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); + var ctxBefore2 = await sut2.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); Assert.DoesNotContain("Caoimhe", ctxBefore1.Messages?[0].Text ?? string.Empty); Assert.DoesNotContain("Caoimhe", ctxBefore2.Messages?[0].Text ?? string.Empty); // Act - await sut1.InvokedAsync(new AIContextProvider.InvokedContext([assistantIntro], aiContextProviderMessages: null)); + await sut1.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [assistantIntro], aiContextProviderMessages: null)); var ctxAfterAdding1 = await GetContextWithRetryAsync(sut1, question); var ctxAfterAdding2 = await GetContextWithRetryAsync(sut2, question); @@ -123,7 +126,7 @@ private static async Task GetContextWithRetryAsync(Mem0Provider provi AIContext? ctx = null; for (int i = 0; i < attempts; i++) { - ctx = await provider.InvokingAsync(new AIContextProvider.InvokingContext([question]), CancellationToken.None); + ctx = await provider.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question]), CancellationToken.None); var text = ctx.Messages?[0].Text; if (!string.IsNullOrEmpty(text) && text.IndexOf("Caoimhe", StringComparison.OrdinalIgnoreCase) >= 0) { diff --git a/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs index 832881857d..b886784af9 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs @@ -18,6 +18,9 @@ namespace Microsoft.Agents.AI.Mem0.UnitTests; /// public sealed class Mem0ProviderTests : IDisposable { + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + private readonly Mock> _loggerMock; private readonly Mock _loggerFactoryMock; private readonly RecordingHandler _handler = new(); @@ -96,7 +99,7 @@ public async Task InvokingAsync_PerformsSearch_AndReturnsContextMessageAsync() UserId = "user" }; var sut = new Mem0Provider(this._httpClient, storageScope, options: new() { EnableSensitiveTelemetryData = true }, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "What is my name?")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "What is my name?")]); // Act var aiContext = await sut.InvokingAsync(invokingContext); @@ -161,7 +164,7 @@ public async Task InvokingAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsy var options = new Mem0ProviderOptions { EnableSensitiveTelemetryData = enableSensitiveTelemetryData }; var sut = new Mem0Provider(this._httpClient, storageScope, options: options, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Who am I?")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Who am I?")]); // Act await sut.InvokingAsync(invokingContext, CancellationToken.None); @@ -215,7 +218,7 @@ public async Task InvokedAsync_PersistsAllowedMessagesAsync() }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); // Assert var memoryPosts = this._handler.Requests.Where(r => r.RequestMessage.RequestUri!.AbsolutePath == "/v1/memories/" && r.RequestMessage.Method == HttpMethod.Post).ToList(); @@ -242,7 +245,7 @@ public async Task InvokedAsync_PersistsNothingForFailedRequestAsync() }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = null, InvokeException = new InvalidOperationException("Request Failed") }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = null, InvokeException = new InvalidOperationException("Request Failed") }); // Assert Assert.Empty(this._handler.Requests); @@ -268,7 +271,7 @@ public async Task InvokedAsync_ShouldNotThrow_WhenStorageFailsAsync() }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); // Assert this._loggerMock.Verify( @@ -318,7 +321,7 @@ public async Task InvokedAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsyn }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); // Assert Assert.Equal(expectedLogCount, this._loggerMock.Invocations.Count); @@ -400,7 +403,7 @@ public async Task InvokingAsync_ShouldNotThrow_WhenSearchFailsAsync() // Arrange var storageScope = new Mem0ProviderScope { ApplicationId = "app" }; var provider = new Mem0Provider(this._httpClient, storageScope, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs index 4001b59090..2acfaf1a10 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs @@ -152,6 +152,33 @@ public async Task VerifyDeserializeWithAIContextProviderAsync() Assert.Same(session.AIContextProvider, mockProvider.Object); } + [Fact] + public async Task VerifyDeserializeWithStateBagAsync() + { + // Arrange + var json = JsonSerializer.Deserialize(""" + { + "conversationId": "TestConvId", + "stateBag": { + "dog": { + "jsonValue": { + "name": "Fido" + } + } + } + } + """, TestJsonSerializerContext.Default.JsonElement); + Mock mockProvider = new(); + + // Act + var session = await ChatClientAgentSession.DeserializeAsync(json, aiContextProviderFactory: (_, _, _) => new(mockProvider.Object)); + + // Assert + var dog = session.StateBag.GetValue("dog", TestJsonSerializerContext.Default.Options); + Assert.NotNull(dog); + Assert.Equal("Fido", dog.Name); + } + [Fact] public async Task DeserializeWithInvalidJsonThrowsAsync() { @@ -249,6 +276,27 @@ public void VerifySessionSerializationWithWithAIContextProvider() mockProvider.Verify(m => m.Serialize(It.IsAny()), Times.Once); } + [Fact] + public void VerifySessionSerializationWithWithStateBag() + { + // Arrange + var session = new ChatClientAgentSession(); + session.StateBag.SetValue("dog", new Animal { Name = "Fido" }, TestJsonSerializerContext.Default.Options); + + // Act + var json = session.Serialize(); + + // Assert + Assert.Equal(JsonValueKind.Object, json.ValueKind); + Assert.True(json.TryGetProperty("stateBag", out var stateBagProperty)); + Assert.Equal(JsonValueKind.Object, stateBagProperty.ValueKind); + Assert.True(stateBagProperty.TryGetProperty("dog", out var dogProperty)); + Assert.Equal(JsonValueKind.Object, dogProperty.ValueKind); + Assert.True(dogProperty.TryGetProperty("jsonValue", out var dogJsonValueProperty)); + Assert.True(dogJsonValueProperty.TryGetProperty("name", out var nameProperty)); + Assert.Equal("Fido", nameProperty.GetString()); + } + /// /// Verify session serialization to JSON with custom options. /// @@ -289,6 +337,27 @@ public void VerifySessionSerializationWithCustomOptions() #endregion Serialize Tests + #region StateBag Roundtrip Tests + + [Fact] + public async Task VerifyStateBagRoundtripsAsync() + { + // Arrange + var session = new ChatClientAgentSession(); + session.StateBag.SetValue("dog", new Animal { Name = "Fido" }, TestJsonSerializerContext.Default.Options); + + // Act + var serializedSession = session.Serialize(); + var deserializedSession = await ChatClientAgentSession.DeserializeAsync(serializedSession); + + // Assert + var dog = deserializedSession.StateBag.GetValue("dog", TestJsonSerializerContext.Default.Options); + Assert.NotNull(dog); + Assert.Equal("Fido", dog.Name); + } + + #endregion + #region GetService Tests [Fact] @@ -327,4 +396,9 @@ public void GetService_RequestingChatHistoryProvider_ReturnsChatHistoryProvider( } #endregion + + internal sealed class Animal + { + public string Name { get; set; } = string.Empty; + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs index 3698ee7065..360c3071ae 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs @@ -17,6 +17,9 @@ namespace Microsoft.Agents.AI.UnitTests.Data; /// public sealed class TextSearchProviderTests { + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + private readonly Mock> _loggerMock; private readonly Mock _loggerFactoryMock; @@ -64,10 +67,12 @@ public async Task InvokingAsync_ShouldInjectFormattedResultsAsync(string? overri var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options, withLogging ? this._loggerFactoryMock.Object : null); var invokingContext = new AIContextProvider.InvokingContext( - [ - new ChatMessage(ChatRole.User, "Sample user question?"), - new ChatMessage(ChatRole.User, "Additional part") - ]); + s_mockAgent, + s_mockSession, + [ + new ChatMessage(ChatRole.User, "Sample user question?"), + new ChatMessage(ChatRole.User, "Additional part") + ]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -139,7 +144,7 @@ public async Task InvokingAsync_OnDemand_ShouldExposeSearchToolAsync(string? ove FunctionToolDescription = overrideDescription }; var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -158,7 +163,7 @@ public async Task InvokingAsync_ShouldNotThrow_WhenSearchFailsAsync() { // Arrange var provider = new TextSearchProvider(this.FailingSearchAsync, default, null, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -251,7 +256,7 @@ public async Task InvokingAsync_ShouldUseContextFormatterWhenProvidedAsync() ContextFormatter = r => $"Custom formatted context with {r.Count} results." }; var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -285,7 +290,7 @@ public async Task InvokingAsync_WithRawRepresentations_ContextFormatterCanAccess ContextFormatter = r => string.Join(",", r.Select(x => ((RawPayload)x.RawRepresentation!).Id)) }; var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -302,7 +307,7 @@ public async Task InvokingAsync_WithNoResults_ShouldReturnEmptyContextAsync() // Arrange var options = new TextSearchProviderOptions { SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke }; var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -340,12 +345,14 @@ public async Task InvokingAsync_WithPreviousFailedRequest_ShouldNotIncludeFailed new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(new(initialMessages, aiContextProviderMessages: null) { InvokeException = new InvalidOperationException("Request Failed") }); + await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages, aiContextProviderMessages: null) { InvokeException = new InvalidOperationException("Request Failed") }); var invokingContext = new AIContextProvider.InvokingContext( - [ - new ChatMessage(ChatRole.User, "E") - ]); + s_mockAgent, + s_mockSession, + [ + new ChatMessage(ChatRole.User, "E") + ]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -380,12 +387,14 @@ public async Task InvokingAsync_WithRecentMessageMemory_ShouldIncludeStoredMessa new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(new(initialMessages, aiContextProviderMessages: null)); + await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages, aiContextProviderMessages: null)); var invokingContext = new AIContextProvider.InvokingContext( - [ - new ChatMessage(ChatRole.User, "E") - ]); + s_mockAgent, + s_mockSession, + [ + new ChatMessage(ChatRole.User, "E") + ]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -414,20 +423,24 @@ public async Task InvokingAsync_WithAccumulatedMemoryAcrossInvocations_ShouldInc // First memory update (A,B) await provider.InvokedAsync(new( - [ - new ChatMessage(ChatRole.User, "A"), - new ChatMessage(ChatRole.Assistant, "B"), - ], aiContextProviderMessages: null)); + s_mockAgent, + s_mockSession, + [ + new ChatMessage(ChatRole.User, "A"), + new ChatMessage(ChatRole.Assistant, "B"), + ], aiContextProviderMessages: null)); // Second memory update (C,D,E) await provider.InvokedAsync(new( - [ - new ChatMessage(ChatRole.User, "C"), - new ChatMessage(ChatRole.Assistant, "D"), - new ChatMessage(ChatRole.User, "E"), - ], aiContextProviderMessages: null)); + s_mockAgent, + s_mockSession, + [ + new ChatMessage(ChatRole.User, "C"), + new ChatMessage(ChatRole.Assistant, "D"), + new ChatMessage(ChatRole.User, "E"), + ], aiContextProviderMessages: null)); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "F")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "F")]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -462,12 +475,14 @@ public async Task InvokingAsync_WithRecentMessageRolesIncluded_ShouldFilterRoles new ChatMessage(ChatRole.User, "U2"), new ChatMessage(ChatRole.Assistant, "A2"), }; - await provider.InvokedAsync(new(initialMessages, null)); + await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages, null)); var invokingContext = new AIContextProvider.InvokingContext( - [ - new ChatMessage(ChatRole.User, "Question?") // Current request message always appended. - ]); + s_mockAgent, + s_mockSession, + [ + new ChatMessage(ChatRole.User, "Question?") // Current request message always appended. + ]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -518,7 +533,7 @@ public async Task Serialize_WithRecentMessages_ShouldPersistMessagesUpToLimitAsy }; // Act - await provider.InvokedAsync(new(messages, aiContextProviderMessages: null)); // Populate recent memory. + await provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null)); // Populate recent memory. var state = provider.Serialize(); // Assert @@ -547,7 +562,7 @@ public async Task SerializeAndDeserialize_RoundtripRestoresMessagesAsync() new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(new(messages, aiContextProviderMessages: null)); + await provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null)); // Act var state = provider.Serialize(); @@ -563,7 +578,7 @@ public async Task SerializeAndDeserialize_RoundtripRestoresMessagesAsync() RecentMessageMemoryLimit = 4 }); var emptyMessages = Array.Empty(); - await roundTrippedProvider.InvokingAsync(new(emptyMessages), CancellationToken.None); // Trigger search to read memory. + await roundTrippedProvider.InvokingAsync(new(s_mockAgent, s_mockSession, emptyMessages), CancellationToken.None); // Trigger search to read memory. // Assert Assert.NotNull(capturedInput); @@ -588,7 +603,7 @@ public async Task Deserialize_WithChangedLowerLimit_ShouldTruncateToNewLimitAsyn new ChatMessage(ChatRole.Assistant, "L4"), new ChatMessage(ChatRole.User, "L5"), }; - await initialProvider.InvokedAsync(new(messages, aiContextProviderMessages: null)); + await initialProvider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null)); var state = initialProvider.Serialize(); string? capturedInput = null; @@ -604,7 +619,7 @@ public async Task Deserialize_WithChangedLowerLimit_ShouldTruncateToNewLimitAsyn SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, RecentMessageMemoryLimit = 3 // Lower limit }); - await restoredProvider.InvokingAsync(new(Array.Empty()), CancellationToken.None); + await restoredProvider.InvokingAsync(new(s_mockAgent, s_mockSession, Array.Empty()), CancellationToken.None); // Assert Assert.NotNull(capturedInput); @@ -631,7 +646,7 @@ public async Task Deserialize_WithEmptyState_ShouldHaveNoMessagesAsync() RecentMessageMemoryLimit = 3 }); var emptyMessages = Array.Empty(); - await provider.InvokingAsync(new(emptyMessages), CancellationToken.None); + await provider.InvokingAsync(new(s_mockAgent, s_mockSession, emptyMessages), CancellationToken.None); // Assert Assert.NotNull(capturedInput); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs index f46538c8e4..8d3cad85ae 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs @@ -18,6 +18,9 @@ namespace Microsoft.Agents.AI.Memory.UnitTests; /// public class ChatHistoryMemoryProviderTests { + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + private readonly Mock> _loggerMock; private readonly Mock _loggerFactoryMock; @@ -116,7 +119,7 @@ public async Task InvokedAsync_UpsertsMessages_ToCollectionAsync() var requestMsgWithNulls = new ChatMessage(ChatRole.User, "request text nulls"); var responseMsg = new ChatMessage(ChatRole.Assistant, "response text") { MessageId = "resp-1", AuthorName = "assistant" }; - var invokedContext = new AIContextProvider.InvokedContext([requestMsgWithValues, requestMsgWithNulls], aiContextProviderMessages: null) + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsgWithValues, requestMsgWithNulls], aiContextProviderMessages: null) { ResponseMessages = [responseMsg] }; @@ -174,7 +177,7 @@ public async Task InvokedAsync_DoesNotUpsertMessages_WhenInvokeFailedAsync() 1, new ChatHistoryMemoryProviderScope() { UserId = "UID" }); var requestMsg = new ChatMessage(ChatRole.User, "request text") { MessageId = "req-1" }; - var invokedContext = new AIContextProvider.InvokedContext([requestMsg], aiContextProviderMessages: null) + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg], aiContextProviderMessages: null) { InvokeException = new InvalidOperationException("Invoke failed") }; @@ -203,7 +206,7 @@ public async Task InvokedAsync_DoesNotThrow_WhenUpsertThrowsAsync() new ChatHistoryMemoryProviderScope() { UserId = "UID" }, loggerFactory: this._loggerFactoryMock.Object); var requestMsg = new ChatMessage(ChatRole.User, "request text") { MessageId = "req-1" }; - var invokedContext = new AIContextProvider.InvokedContext([requestMsg], aiContextProviderMessages: null); + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg], aiContextProviderMessages: null); // Act await provider.InvokedAsync(invokedContext, CancellationToken.None); @@ -254,7 +257,7 @@ public async Task InvokedAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsyn loggerFactory: this._loggerFactoryMock.Object); var requestMsg = new ChatMessage(ChatRole.User, "request text"); - var invokedContext = new AIContextProvider.InvokedContext([requestMsg], aiContextProviderMessages: null); + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg], aiContextProviderMessages: null); // Act await provider.InvokedAsync(invokedContext, CancellationToken.None); @@ -327,7 +330,7 @@ public async Task InvokedAsync_SearchesVectorStoreAsync() options: providerOptions); var requestMsg = new ChatMessage(ChatRole.User, "requesting relevant history"); - var invokingContext = new AIContextProvider.InvokingContext([requestMsg]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [requestMsg]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -378,7 +381,7 @@ public async Task InvokedAsync_CreatesFilter_WhenSearchScopeProvidedAsync() var provider = new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, TestCollectionName, 1, options: providerOptions, storageScope: searchScope, searchScope: searchScope); var requestMsg = new ChatMessage(ChatRole.User, "requesting relevant history"); - var invokingContext = new AIContextProvider.InvokingContext([requestMsg]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [requestMsg]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -442,7 +445,7 @@ public async Task InvokingAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsy options: options, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "requesting relevant history")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "requesting relevant history")]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestJsonSerializerContext.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestJsonSerializerContext.cs index b145991439..0ac3ab9fbf 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestJsonSerializerContext.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestJsonSerializerContext.cs @@ -14,4 +14,5 @@ namespace Microsoft.Agents.AI.UnitTests; [JsonSerializable(typeof(string))] [JsonSerializable(typeof(string[]))] [JsonSerializable(typeof(Dictionary))] +[JsonSerializable(typeof(ChatClientAgentSessionTests.Animal))] internal sealed partial class TestJsonSerializerContext : JsonSerializerContext; diff --git a/dotnet/tests/OpenAIAssistant.IntegrationTests/OpenAIAssistantFixture.cs b/dotnet/tests/OpenAIAssistant.IntegrationTests/OpenAIAssistantFixture.cs index a90f49f428..bb58f09fb4 100644 --- a/dotnet/tests/OpenAIAssistant.IntegrationTests/OpenAIAssistantFixture.cs +++ b/dotnet/tests/OpenAIAssistant.IntegrationTests/OpenAIAssistantFixture.cs @@ -23,7 +23,7 @@ public class OpenAIAssistantFixture : IChatClientAgentFixture public IChatClient ChatClient => this._agent.ChatClient; - public async Task> GetChatHistoryAsync(AgentSession session) + public async Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) { var typedSession = (ChatClientAgentSession)session; List messages = []; diff --git a/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs b/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs index eeb60620c0..304df28fba 100644 --- a/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs +++ b/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs @@ -28,7 +28,7 @@ public OpenAIChatCompletionFixture(bool useReasoningChatModel) public IChatClient ChatClient => this._agent.ChatClient; - public async Task> GetChatHistoryAsync(AgentSession session) + public async Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) { var typedSession = (ChatClientAgentSession)session; @@ -37,7 +37,7 @@ public async Task> GetChatHistoryAsync(AgentSession session) return []; } - return (await typedSession.ChatHistoryProvider.InvokingAsync(new([]))).ToList(); + return (await typedSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); } public Task CreateChatClientAgentAsync( diff --git a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs index 2006404239..719db6a0b0 100644 --- a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs +++ b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs @@ -25,7 +25,7 @@ public class OpenAIResponseFixture(bool store) : IChatClientAgentFixture public IChatClient ChatClient => this._agent.ChatClient; - public async Task> GetChatHistoryAsync(AgentSession session) + public async Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) { var typedSession = (ChatClientAgentSession)session; @@ -55,7 +55,7 @@ public async Task> GetChatHistoryAsync(AgentSession session) return []; } - return (await typedSession.ChatHistoryProvider.InvokingAsync(new([]))).ToList(); + return (await typedSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); } private static ChatMessage ConvertToChatMessage(ResponseItem item) From d2686c19ce4787c730ddbe180111a4230a4afa10 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Thu, 5 Feb 2026 20:10:25 +0000 Subject: [PATCH 02/28] Convert all AIContextProviders to use the statebag --- .../Program.cs | 17 +- .../Program.cs | 18 +- .../Program.cs | 50 ++-- .../Program.cs | 2 +- .../Program.cs | 2 +- .../Program.cs | 2 +- .../Program.cs | 94 +++---- .../Mem0JsonUtilities.cs | 2 +- .../Microsoft.Agents.AI.Mem0/Mem0Provider.cs | 207 +++++++-------- .../Mem0ProviderOptions.cs | 6 + .../Microsoft.Agents.AI/AgentJsonUtilities.cs | 2 +- .../Memory/ChatHistoryMemoryProvider.cs | 244 +++++++++--------- .../ChatHistoryMemoryProviderOptions.cs | 9 + .../Microsoft.Agents.AI/TextSearchProvider.cs | 83 +++--- .../TextSearchProviderOptions.cs | 9 + .../Mem0ProviderTests.cs | 73 +++--- .../Mem0ProviderTests.cs | 154 ++++++----- .../Data/TextSearchProviderTests.cs | 159 +++++------- .../Memory/ChatHistoryMemoryProviderTests.cs | 128 ++++----- 19 files changed, 641 insertions(+), 620 deletions(-) diff --git a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step01_ChatHistoryMemory/Program.cs b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step01_ChatHistoryMemory/Program.cs index ec55abf3a4..4b2f739725 100644 --- a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step01_ChatHistoryMemory/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step01_ChatHistoryMemory/Program.cs @@ -38,12 +38,17 @@ vectorStore, collectionName: "chathistory", vectorDimensions: 3072, - // Configure the scope values under which chat messages will be stored. - // In this case, we are using a fixed user ID and a unique session ID for each new session. - storageScope: new() { UserId = "UID1", SessionId = Guid.NewGuid().ToString() }, - // Configure the scope which would be used to search for relevant prior messages. - // In this case, we are searching for any messages for the user across all sessions. - searchScope: new() { UserId = "UID1" })) + // Callback to configure the initial state of the ChatHistoryMemoryProvider. + // The ChatHistoryMemoryProvider stores its state in the AgentSession and this callback + // will be called whenever the ChatHistoryMemoryProvider cannot find existing state in the session, + // typically the first time it is used with a new session. + session => new ChatHistoryMemoryProvider.State( + // Configure the scope values under which chat messages will be stored. + // In this case, we are using a fixed user ID and a unique session ID for each new session. + storageScope: new() { UserId = "UID1", SessionId = Guid.NewGuid().ToString() }, + // Configure the scope which would be used to search for relevant prior messages. + // In this case, we are searching for any messages for the user across all sessions. + searchScope: new() { UserId = "UID1" }))) }); // Start a new session for the agent conversation. diff --git a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs index 4a0dbe0839..23154de8f1 100644 --- a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs @@ -31,20 +31,22 @@ .AsAIAgent(new ChatClientAgentOptions() { ChatOptions = new() { Instructions = "You are a friendly travel assistant. Use known memories about the user when responding, and do not invent details." }, - AIContextProviderFactory = (ctx, ct) => new ValueTask(ctx.SerializedState.ValueKind is not JsonValueKind.Null and not JsonValueKind.Undefined - // If each session should have its own Mem0 scope, you can create a new id per session here: - // ? new Mem0Provider(mem0HttpClient, new Mem0ProviderScope() { ThreadId = Guid.NewGuid().ToString() }) - // In this case we are storing memories scoped by application and user instead so that memories are retained across threads. - ? new Mem0Provider(mem0HttpClient, new Mem0ProviderScope() { ApplicationId = "getting-started-agents", UserId = "sample-user" }) - // For cases where we are restoring from serialized state: - : new Mem0Provider(mem0HttpClient, ctx.SerializedState, ctx.JsonSerializerOptions)) + AIContextProviderFactory = (ctx, ct) => new ValueTask( + // The stateInitializer can be used to customize the Mem0 scope per session and it will be called each time a session + // is encountered by the Mem0Provider that does not already have Mem0Provider state stored on the session. + // If each session should have its own Mem0 scope, you can create a new id per session via the stateInitializer, e.g.: + // new Mem0Provider(mem0HttpClient, stateInitializer: _ => new(new Mem0ProviderScope() { ThreadId = Guid.NewGuid().ToString() })) + // In our case we are storing memories scoped by application and user instead so that memories are retained across threads. + new Mem0Provider(mem0HttpClient, stateInitializer: _ => new(new Mem0ProviderScope() { ApplicationId = "getting-started-agents", UserId = "sample-user" }))) }); AgentSession session = await agent.CreateSessionAsync(); // Clear any existing memories for this scope to demonstrate fresh behavior. +// Note that the ClearStoredMemoriesAsync method will clear memories +// using the scope stored in the session, or provided via the stateInitializer. Mem0Provider mem0Provider = session.GetService()!; -await mem0Provider.ClearStoredMemoriesAsync(); +await mem0Provider.ClearStoredMemoriesAsync(session); Console.WriteLine(await agent.RunAsync("Hi there! My name is Taylor and I'm planning a hiking trip to Patagonia in November.", session)); Console.WriteLine(await agent.RunAsync("I'm travelling with my sister and we love finding scenic viewpoints.", session)); diff --git a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs index 509b79e53f..13e0726c38 100644 --- a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs @@ -33,7 +33,7 @@ AIAgent agent = chatClient.AsAIAgent(new ChatClientAgentOptions() { ChatOptions = new() { Instructions = "You are a friendly assistant. Always address the user by their name." }, - AIContextProviderFactory = (ctx, ct) => new ValueTask(new UserInfoMemory(chatClient.AsIChatClient(), ctx.SerializedState, ctx.JsonSerializerOptions)) + AIContextProviderFactory = (ctx, ct) => new ValueTask(new UserInfoMemory(chatClient.AsIChatClient())) }); // Create a new session for the conversation. @@ -58,7 +58,7 @@ Console.WriteLine("\n>> Read memories from memory component\n"); // It's possible to access the memory component via the session's GetService method. -var userInfo = deserializedSession.GetService()?.UserInfo; +var userInfo = deserializedSession.GetService()?.GetUserInfo(deserializedSession); // Output the user info that was captured by the memory component. Console.WriteLine($"MEMORY - User Name: {userInfo?.UserName}"); @@ -71,7 +71,7 @@ var newSession = await agent.CreateSessionAsync(); if (userInfo is not null && newSession.GetService() is UserInfoMemory newSessionMemory) { - newSessionMemory.UserInfo = userInfo; + newSessionMemory.SetUserInfo(newSession, userInfo); } // Invoke the agent and output the text result. @@ -86,28 +86,28 @@ namespace SampleApp internal sealed class UserInfoMemory : AIContextProvider { private readonly IChatClient _chatClient; + private readonly Func? _stateInitializer; - public UserInfoMemory(IChatClient chatClient, UserInfo? userInfo = null) + public UserInfoMemory(IChatClient chatClient, Func? stateInitializer = null) { this._chatClient = chatClient; - this.UserInfo = userInfo ?? new UserInfo(); + this._stateInitializer = stateInitializer; } - public UserInfoMemory(IChatClient chatClient, JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null) - { - this._chatClient = chatClient; - - this.UserInfo = serializedState.ValueKind == JsonValueKind.Object ? - serializedState.Deserialize(jsonSerializerOptions)! : - new UserInfo(); - } + public UserInfo GetUserInfo(AgentSession session) + => session.StateBag.GetValue(nameof(UserInfoMemory)) ?? new UserInfo(); - public UserInfo UserInfo { get; set; } + public void SetUserInfo(AgentSession session, UserInfo userInfo) + => session.StateBag.SetValue(nameof(UserInfoMemory), userInfo); public override async ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) { + var userInfo = context.Session?.StateBag.GetValue(nameof(UserInfoMemory)) + ?? this._stateInitializer?.Invoke(context.Session) + ?? new UserInfo(); + // Try and extract the user name and age from the message if we don't have it already and it's a user message. - if ((this.UserInfo.UserName is null || this.UserInfo.UserAge is null) && context.RequestMessages.Any(x => x.Role == ChatRole.User)) + if ((userInfo.UserName is null || userInfo.UserAge is null) && context.RequestMessages.Any(x => x.Role == ChatRole.User)) { var result = await this._chatClient.GetResponseAsync( context.RequestMessages, @@ -117,25 +117,31 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio }, cancellationToken: cancellationToken); - this.UserInfo.UserName ??= result.Result.UserName; - this.UserInfo.UserAge ??= result.Result.UserAge; + userInfo.UserName ??= result.Result.UserName; + userInfo.UserAge ??= result.Result.UserAge; } + + context.Session?.StateBag.SetValue(nameof(UserInfoMemory), userInfo); } public override ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) { + var userInfo = context.Session?.StateBag.GetValue(nameof(UserInfoMemory)) + ?? this._stateInitializer?.Invoke(context.Session) + ?? new UserInfo(); + StringBuilder instructions = new(); // If we don't already know the user's name and age, add instructions to ask for them, otherwise just provide what we have to the context. instructions .AppendLine( - this.UserInfo.UserName is null ? + userInfo.UserName is null ? "Ask the user for their name and politely decline to answer any questions until they provide it." : - $"The user's name is {this.UserInfo.UserName}.") + $"The user's name is {userInfo.UserName}.") .AppendLine( - this.UserInfo.UserAge is null ? + userInfo.UserAge is null ? "Ask the user for their age and politely decline to answer any questions until they provide it." : - $"The user's age is {this.UserInfo.UserAge}."); + $"The user's age is {userInfo.UserAge}."); return new ValueTask(new AIContext { @@ -145,7 +151,7 @@ this.UserInfo.UserAge is null ? public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { - return JsonSerializer.SerializeToElement(this.UserInfo, jsonSerializerOptions); + return JsonSerializer.SerializeToElement(new UserInfo(), jsonSerializerOptions); } } diff --git a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs index ca798aa333..38f87ad2cf 100644 --- a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs @@ -62,7 +62,7 @@ .AsAIAgent(new ChatClientAgentOptions { ChatOptions = new() { Instructions = "You are a helpful support specialist for Contoso Outdoors. Answer questions using the provided context and cite the source document when available." }, - AIContextProviderFactory = (ctx, ct) => new ValueTask(new TextSearchProvider(SearchAdapter, ctx.SerializedState, ctx.JsonSerializerOptions, textSearchOptions)), + AIContextProviderFactory = (ctx, ct) => new ValueTask(new TextSearchProvider(SearchAdapter, textSearchOptions)), // Since we are using ChatCompletion which stores chat history locally, we can also add a message removal policy // that removes messages produced by the TextSearchProvider before they are added to the chat history, so that // we don't bloat chat history with all the search result messages. diff --git a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step02_CustomVectorStoreRAG/Program.cs b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step02_CustomVectorStoreRAG/Program.cs index 3648ccc898..4ab1a2df6d 100644 --- a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step02_CustomVectorStoreRAG/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step02_CustomVectorStoreRAG/Program.cs @@ -71,7 +71,7 @@ .AsAIAgent(new ChatClientAgentOptions { ChatOptions = new() { Instructions = "You are a helpful support specialist for the Microsoft Agent Framework. Answer questions using the provided context and cite the source document when available. Keep responses brief." }, - AIContextProviderFactory = (ctx, ct) => new ValueTask(new TextSearchProvider(SearchAdapter, ctx.SerializedState, ctx.JsonSerializerOptions, textSearchOptions)) + AIContextProviderFactory = (ctx, ct) => new ValueTask(new TextSearchProvider(SearchAdapter, textSearchOptions)) }); AgentSession session = await agent.CreateSessionAsync(); diff --git a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step03_CustomRAGDataSource/Program.cs b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step03_CustomRAGDataSource/Program.cs index bcc823de46..a4b5a10ba1 100644 --- a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step03_CustomRAGDataSource/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step03_CustomRAGDataSource/Program.cs @@ -29,7 +29,7 @@ .AsAIAgent(new ChatClientAgentOptions { ChatOptions = new() { Instructions = "You are a helpful support specialist for Contoso Outdoors. Answer questions using the provided context and cite the source document when available." }, - AIContextProviderFactory = (ctx, ct) => new ValueTask(new TextSearchProvider(MockSearchAsync, ctx.SerializedState, ctx.JsonSerializerOptions, textSearchOptions)) + AIContextProviderFactory = (ctx, ct) => new ValueTask(new TextSearchProvider(MockSearchAsync, textSearchOptions)) }); AgentSession session = await agent.CreateSessionAsync(); diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs index 28b9780d17..bc7a1ababb 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs @@ -7,7 +7,6 @@ #pragma warning disable CA1869 // Cache and reuse 'JsonSerializerOptions' instances -using System.ComponentModel; using System.Text; using System.Text.Json; using Azure.AI.OpenAI; @@ -52,9 +51,9 @@ You remind users of upcoming calendar events when the user interacts with you. // Add an AI context provider that maintains a todo list for the agent and one that provides upcoming calendar entries. // Wrap these in an AI context provider that aggregates the other two. AIContextProviderFactory = (ctx, ct) => new ValueTask(new AggregatingAIContextProvider([ - AggregatingAIContextProvider.CreateFactory((jsonElement, jsonSerializerOptions) => new TodoListAIContextProvider(jsonElement, jsonSerializerOptions)), - AggregatingAIContextProvider.CreateFactory((_, _) => new CalendarSearchAIContextProvider(loadNextThreeCalendarEvents)) - ], ctx.SerializedState, ctx.JsonSerializerOptions)), + new TodoListAIContextProvider(), + new CalendarSearchAIContextProvider(loadNextThreeCalendarEvents) + ])), }); // Invoke the agent and output the text result. @@ -80,51 +79,65 @@ namespace SampleApp /// internal sealed class TodoListAIContextProvider : AIContextProvider { - private readonly List _todoItems = new(); + private const string StateKey = nameof(TodoListAIContextProvider); - public TodoListAIContextProvider(JsonElement jsonElement, JsonSerializerOptions? jsonSerializerOptions = null) - { - // Only try and restore the state if we got an array, since any other json would be invalid or undefined/null meaning - // it's the first time we are running. - if (jsonElement.ValueKind == JsonValueKind.Array) - { - this._todoItems = JsonSerializer.Deserialize>(jsonElement.GetRawText(), jsonSerializerOptions) ?? new List(); - } - } + private static List GetTodoItems(AgentSession? session) + => session?.StateBag.GetValue>(StateKey) ?? new List(); + + private static void SetTodoItems(AgentSession? session, List items) + => session?.StateBag.SetValue(StateKey, items); public override ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) { + var todoItems = GetTodoItems(context.Session); + StringBuilder outputMessageBuilder = new(); outputMessageBuilder.AppendLine("Your todo list contains the following items:"); - if (this._todoItems.Count == 0) + if (todoItems.Count == 0) { outputMessageBuilder.AppendLine(" (no items)"); } else { - for (int i = 0; i < this._todoItems.Count; i++) + for (int i = 0; i < todoItems.Count; i++) { - outputMessageBuilder.AppendLine($"{i}. {this._todoItems[i]}"); + outputMessageBuilder.AppendLine($"{i}. {todoItems[i]}"); } } return new ValueTask(new AIContext { - Tools = [AIFunctionFactory.Create(this.AddTodoItem), AIFunctionFactory.Create(this.RemoveTodoItem)], + Tools = + [ + AIFunctionFactory.Create(() => AddTodoItem(context.Session, string.Empty), "AddTodoItem", "Adds an item to the todo list."), + AIFunctionFactory.Create(() => RemoveTodoItem(context.Session, 0), "RemoveTodoItem", "Adds an item to the todo list. Index is zero based.") + ], Messages = [new MEAI.ChatMessage(ChatRole.User, outputMessageBuilder.ToString())] }); } - [Description("Adds an item to the todo list. Index is zero based.")] - private void RemoveTodoItem(int index) => - this._todoItems.RemoveAt(index); + private static void RemoveTodoItem(AgentSession? session, int index) + { + var items = GetTodoItems(session); + items.RemoveAt(index); + SetTodoItems(session, items); + } + + private static void AddTodoItem(AgentSession? session, string item) + { + if (string.IsNullOrWhiteSpace(item)) + { + throw new ArgumentException("Item must have a value"); + } - private void AddTodoItem(string item) => - this._todoItems.Add(string.IsNullOrWhiteSpace(item) ? throw new ArgumentException("Item must have a value") : item); + var items = GetTodoItems(session); + items.Add(item); + SetTodoItems(session, items); + } public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) => - JsonSerializer.SerializeToElement(this._todoItems, jsonSerializerOptions); + JsonSerializer.SerializeToElement(new List(), jsonSerializerOptions); } /// @@ -155,28 +168,15 @@ public override async ValueTask InvokingAsync(InvokingContext context /// /// An which aggregates multiple AI context providers into one. - /// Serialized state for the different providers are stored under their type name. /// Tools and messages from all providers are combined, and instructions are concatenated. /// internal sealed class AggregatingAIContextProvider : AIContextProvider { - private readonly List _providers = new(); + private readonly List _providers; - public AggregatingAIContextProvider(ProviderFactory[] providerFactories, JsonElement jsonElement, JsonSerializerOptions? jsonSerializerOptions) + public AggregatingAIContextProvider(List providers) { - // We received a json object, so let's check if it has some previously serialized state that we can use. - if (jsonElement.ValueKind == JsonValueKind.Object) - { - this._providers = providerFactories - .Select(factory => factory.FactoryMethod(jsonElement.TryGetProperty(factory.ProviderType.Name, out var prop) ? prop : default, jsonSerializerOptions)) - .ToList(); - return; - } - - // We didn't receive any valid json, so we can just construct fresh providers. - this._providers = providerFactories - .Select(factory => factory.FactoryMethod(default, jsonSerializerOptions)) - .ToList(); + this._providers = providers; } public override async ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) @@ -210,19 +210,5 @@ public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptio return JsonSerializer.SerializeToElement(elements, jsonSerializerOptions); } - - public static ProviderFactory CreateFactory(Func factoryMethod) - where TProviderType : AIContextProvider => new() - { - FactoryMethod = (jsonElement, jsonSerializerOptions) => factoryMethod(jsonElement, jsonSerializerOptions), - ProviderType = typeof(TProviderType) - }; - - public readonly struct ProviderFactory - { - public Func FactoryMethod { get; init; } - - public Type ProviderType { get; init; } - } } } diff --git a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0JsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0JsonUtilities.cs index d139cb0f76..33f92f3ac2 100644 --- a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0JsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0JsonUtilities.cs @@ -65,7 +65,7 @@ private static JsonSerializerOptions CreateDefaultOptions() NumberHandling = JsonNumberHandling.AllowReadingFromString)] // Agent abstraction types - [JsonSerializable(typeof(Mem0Provider.Mem0State))] + [JsonSerializable(typeof(Mem0Provider.State))] [ExcludeFromCodeCoverage] internal sealed partial class JsonContext : JsonSerializerContext; diff --git a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs index 0e9b4288b1..a49c282dac 100644 --- a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs @@ -26,24 +26,24 @@ namespace Microsoft.Agents.AI.Mem0; public sealed class Mem0Provider : AIContextProvider { private const string DefaultContextPrompt = "## Memories\nConsider the following memories when answering user questions:"; + private const string DefaultStateBagKey = "Mem0Provider.State"; private readonly string _contextPrompt; private readonly bool _enableSensitiveTelemetryData; + private readonly string _stateKey; + private readonly Func _stateInitializer; private readonly Mem0Client _client; private readonly ILogger? _logger; - private readonly Mem0ProviderScope _storageScope; - private readonly Mem0ProviderScope _searchScope; - /// /// Initializes a new instance of the class. /// /// Configured (base address + auth). - /// Optional values to scope the memory storage with. - /// Optional values to scope the memory search with. Defaults to if not provided. + /// A delegate that initializes the provider state on the first invocation, providing the storage and search scopes. /// Provider options. /// Optional logger factory. + /// Thrown when or is . /// /// The base address of the required mem0 service, and any authentication headers, should be set on the /// already, when passed as a parameter here. E.g.: @@ -51,83 +51,46 @@ public sealed class Mem0Provider : AIContextProvider /// using var httpClient = new HttpClient(); /// httpClient.BaseAddress = new Uri("https://api.mem0.ai"); /// httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Token", "<Your APIKey>"); - /// new Mem0AIContextProvider(httpClient); + /// new Mem0Provider(httpClient); /// /// - public Mem0Provider(HttpClient httpClient, Mem0ProviderScope storageScope, Mem0ProviderScope? searchScope = null, Mem0ProviderOptions? options = null, ILoggerFactory? loggerFactory = null) + public Mem0Provider(HttpClient httpClient, Func stateInitializer, Mem0ProviderOptions? options = null, ILoggerFactory? loggerFactory = null) { + Throw.IfNull(httpClient); if (string.IsNullOrWhiteSpace(httpClient.BaseAddress?.AbsoluteUri)) { throw new ArgumentException("The HttpClient BaseAddress must be set for Mem0 operations.", nameof(httpClient)); } + this._stateInitializer = Throw.IfNull(stateInitializer); this._logger = loggerFactory?.CreateLogger(); this._client = new Mem0Client(httpClient); this._contextPrompt = options?.ContextPrompt ?? DefaultContextPrompt; this._enableSensitiveTelemetryData = options?.EnableSensitiveTelemetryData ?? false; - this._storageScope = new Mem0ProviderScope(Throw.IfNull(storageScope)); - this._searchScope = searchScope ?? storageScope; - - if (string.IsNullOrWhiteSpace(this._storageScope.ApplicationId) - && string.IsNullOrWhiteSpace(this._storageScope.AgentId) - && string.IsNullOrWhiteSpace(this._storageScope.ThreadId) - && string.IsNullOrWhiteSpace(this._storageScope.UserId)) - { - throw new ArgumentException("At least one of ApplicationId, AgentId, ThreadId, or UserId must be provided for the storage scope."); - } - - if (string.IsNullOrWhiteSpace(this._searchScope.ApplicationId) - && string.IsNullOrWhiteSpace(this._searchScope.AgentId) - && string.IsNullOrWhiteSpace(this._searchScope.ThreadId) - && string.IsNullOrWhiteSpace(this._searchScope.UserId)) - { - throw new ArgumentException("At least one of ApplicationId, AgentId, ThreadId, or UserId must be provided for the search scope."); - } + this._stateKey = options?.StateKey ?? DefaultStateBagKey; } /// - /// Initializes a new instance of the class, with existing state from a serialized JSON element. + /// Gets the state from the session's StateBag, or initializes it using the StateInitializer if not present. /// - /// Configured (base address + auth). - /// A representing the serialized state of the store. - /// Optional settings for customizing the JSON deserialization process. - /// Provider options. - /// Optional logger factory. - /// - /// - /// The base address of the required mem0 service, and any authentication headers, should be set on the - /// already, when passed as a parameter here. E.g.: - /// - /// using var httpClient = new HttpClient(); - /// httpClient.BaseAddress = new Uri("https://api.mem0.ai"); - /// httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Token", "<Your APIKey>"); - /// new Mem0AIContextProvider(httpClient, state); - /// - /// - public Mem0Provider(HttpClient httpClient, JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, Mem0ProviderOptions? options = null, ILoggerFactory? loggerFactory = null) + /// The agent session containing the StateBag. + /// The provider state, or null if no session is available. + private State? GetOrInitializeState(AgentSession? session) { - if (string.IsNullOrWhiteSpace(httpClient.BaseAddress?.AbsoluteUri)) + var state = session?.StateBag.GetValue(this._stateKey, Mem0JsonUtilities.DefaultOptions); + if (state is not null) { - throw new ArgumentException("The HttpClient BaseAddress must be set for Mem0 operations.", nameof(httpClient)); + return state; } - this._logger = loggerFactory?.CreateLogger(); - this._client = new Mem0Client(httpClient); - - this._contextPrompt = options?.ContextPrompt ?? DefaultContextPrompt; - this._enableSensitiveTelemetryData = options?.EnableSensitiveTelemetryData ?? false; - - var jso = jsonSerializerOptions ?? Mem0JsonUtilities.DefaultOptions; - var state = serializedState.Deserialize(jso.GetTypeInfo(typeof(Mem0State))) as Mem0State; - - if (state == null || state.StorageScope == null || state.SearchScope == null) + state = this._stateInitializer(session); + if (state is not null && session is not null) { - throw new InvalidOperationException("The Mem0Provider state did not contain the required scope properties."); + session.StateBag.SetValue(this._stateKey, state, Mem0JsonUtilities.DefaultOptions); } - this._storageScope = state.StorageScope; - this._searchScope = state.SearchScope; + return state; } /// @@ -135,6 +98,9 @@ public override async ValueTask InvokingAsync(InvokingContext context { Throw.IfNull(context); + var state = this.GetOrInitializeState(context.Session); + var searchScope = state?.SearchScope ?? new Mem0ProviderScope(); + string queryText = string.Join( Environment.NewLine, context.RequestMessages.Where(m => !string.IsNullOrWhiteSpace(m.Text)).Select(m => m.Text)); @@ -142,10 +108,10 @@ public override async ValueTask InvokingAsync(InvokingContext context try { var memories = (await this._client.SearchAsync( - this._searchScope.ApplicationId, - this._searchScope.AgentId, - this._searchScope.ThreadId, - this._searchScope.UserId, + searchScope.ApplicationId, + searchScope.AgentId, + searchScope.ThreadId, + searchScope.UserId, queryText, cancellationToken).ConfigureAwait(false)).ToList(); @@ -158,10 +124,10 @@ public override async ValueTask InvokingAsync(InvokingContext context this._logger.LogInformation( "Mem0AIContextProvider: Retrieved {Count} memories. ApplicationId: '{ApplicationId}', AgentId: '{AgentId}', ThreadId: '{ThreadId}', UserId: '{UserId}'.", memories.Count, - this._searchScope.ApplicationId, - this._searchScope.AgentId, - this._searchScope.ThreadId, - this.SanitizeLogData(this._searchScope.UserId)); + searchScope.ApplicationId, + searchScope.AgentId, + searchScope.ThreadId, + this.SanitizeLogData(searchScope.UserId)); if (outputMessageText is not null && this._logger.IsEnabled(LogLevel.Trace)) { @@ -169,10 +135,10 @@ public override async ValueTask InvokingAsync(InvokingContext context "Mem0AIContextProvider: Search Results\nInput:{Input}\nOutput:{MessageText}\nApplicationId: '{ApplicationId}', AgentId: '{AgentId}', ThreadId: '{ThreadId}', UserId: '{UserId}'.", this.SanitizeLogData(queryText), this.SanitizeLogData(outputMessageText), - this._searchScope.ApplicationId, - this._searchScope.AgentId, - this._searchScope.ThreadId, - this.SanitizeLogData(this._searchScope.UserId)); + searchScope.ApplicationId, + searchScope.AgentId, + searchScope.ThreadId, + this.SanitizeLogData(searchScope.UserId)); } } @@ -192,10 +158,10 @@ public override async ValueTask InvokingAsync(InvokingContext context this._logger.LogError( ex, "Mem0AIContextProvider: Failed to search Mem0 for memories due to error. ApplicationId: '{ApplicationId}', AgentId: '{AgentId}', ThreadId: '{ThreadId}', UserId: '{UserId}'.", - this._searchScope.ApplicationId, - this._searchScope.AgentId, - this._searchScope.ThreadId, - this.SanitizeLogData(this._searchScope.UserId)); + searchScope.ApplicationId, + searchScope.AgentId, + searchScope.ThreadId, + this.SanitizeLogData(searchScope.UserId)); } return new AIContext(); } @@ -209,10 +175,13 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio return; // Do not update memory on failed invocations. } + var state = this.GetOrInitializeState(context.Session); + var storageScope = state?.StorageScope ?? new Mem0ProviderScope(); + try { // Persist request and response messages after invocation. - await this.PersistMessagesAsync(context.RequestMessages.Concat(context.ResponseMessages ?? []), cancellationToken).ConfigureAwait(false); + await this.PersistMessagesAsync(storageScope, context.RequestMessages.Concat(context.ResponseMessages ?? []), cancellationToken).ConfigureAwait(false); } catch (Exception ex) { @@ -221,36 +190,55 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio this._logger.LogError( ex, "Mem0AIContextProvider: Failed to send messages to Mem0 due to error. ApplicationId: '{ApplicationId}', AgentId: '{AgentId}', ThreadId: '{ThreadId}', UserId: '{UserId}'.", - this._storageScope.ApplicationId, - this._storageScope.AgentId, - this._storageScope.ThreadId, - this.SanitizeLogData(this._storageScope.UserId)); + storageScope.ApplicationId, + storageScope.AgentId, + storageScope.ThreadId, + this.SanitizeLogData(storageScope.UserId)); } } } /// - /// Clears stored memories for the configured scopes. + /// Clears stored memories for the specified scope. /// + /// The session containing the scope state to clear memories for. /// Cancellation token. - public Task ClearStoredMemoriesAsync(CancellationToken cancellationToken = default) => - this._client.ClearMemoryAsync( - this._storageScope.ApplicationId, - this._storageScope.AgentId, - this._storageScope.ThreadId, - this._storageScope.UserId, + public Task ClearStoredMemoriesAsync(AgentSession session, CancellationToken cancellationToken = default) + { + Throw.IfNull(session); + var state = this.GetOrInitializeState(session); + var storageScope = state?.StorageScope; + + if (storageScope is null) + { + return Task.CompletedTask; // Nothing to clear if there is no state. + } + + return this._client.ClearMemoryAsync( + storageScope.ApplicationId, + storageScope.AgentId, + storageScope.ThreadId, + storageScope.UserId, cancellationToken); + } - /// + /// + /// Serializes the current provider state to a containing any overridden prompts or descriptions. + /// + /// Optional serializer options (ignored, source generated context is used). + /// An empty object. + /// + /// This method is deprecated. State is now stored in the and serialized as part of the session. + /// public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { - var state = new Mem0State(this._storageScope, this._searchScope); - - var jso = jsonSerializerOptions ?? Mem0JsonUtilities.DefaultOptions; - return JsonSerializer.SerializeToElement(state, jso.GetTypeInfo(typeof(Mem0State))); + // State is now stored in the session StateBag, so there is nothing to serialize here. + // Return an empty JSON object. + using var doc = JsonDocument.Parse("{}"); + return doc.RootElement.Clone(); } - private async Task PersistMessagesAsync(IEnumerable messages, CancellationToken cancellationToken) + private async Task PersistMessagesAsync(Mem0ProviderScope storageScope, IEnumerable messages, CancellationToken cancellationToken) { foreach (var message in messages) { @@ -270,27 +258,42 @@ private async Task PersistMessagesAsync(IEnumerable messages, Cance } await this._client.CreateMemoryAsync( - this._storageScope.ApplicationId, - this._storageScope.AgentId, - this._storageScope.ThreadId, - this._storageScope.UserId, + storageScope.ApplicationId, + storageScope.AgentId, + storageScope.ThreadId, + storageScope.UserId, message.Text, message.Role.Value, cancellationToken).ConfigureAwait(false); } } - internal sealed class Mem0State + /// + /// Represents the state of a stored in the . + /// + public sealed class State { + /// + /// Initializes a new instance of the class with the specified storage and search scopes. + /// + /// The scope to use when storing memories. + /// The scope to use when searching for memories. If null, the storage scope will be used for searching as well. [JsonConstructor] - public Mem0State(Mem0ProviderScope storageScope, Mem0ProviderScope searchScope) + public State(Mem0ProviderScope storageScope, Mem0ProviderScope? searchScope = null) { - this.StorageScope = storageScope; - this.SearchScope = searchScope; + this.StorageScope = Throw.IfNull(storageScope); + this.SearchScope = searchScope ?? storageScope; } - public Mem0ProviderScope StorageScope { get; set; } - public Mem0ProviderScope SearchScope { get; set; } + /// + /// Gets the scope used when storing memories. + /// + public Mem0ProviderScope StorageScope { get; } + + /// + /// Gets the scope used when searching memories. + /// + public Mem0ProviderScope SearchScope { get; } } private string? SanitizeLogData(string? data) => this._enableSensitiveTelemetryData ? data : ""; diff --git a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0ProviderOptions.cs b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0ProviderOptions.cs index f2d3d89e16..f382e80bf2 100644 --- a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0ProviderOptions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0ProviderOptions.cs @@ -18,4 +18,10 @@ public sealed class Mem0ProviderOptions /// /// Defaults to . public bool EnableSensitiveTelemetryData { get; set; } + + /// + /// Gets or sets the key used to store the provider state in the session's . + /// + /// Defaults to "Mem0Provider.State". + public string? StateKey { get; set; } } diff --git a/dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs index 5c915b6b01..79cbf2193c 100644 --- a/dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs @@ -67,7 +67,7 @@ private static JsonSerializerOptions CreateDefaultOptions() // Agent abstraction types [JsonSerializable(typeof(ChatClientAgentSession.SessionState))] [JsonSerializable(typeof(TextSearchProvider.TextSearchProviderState))] - [JsonSerializable(typeof(ChatHistoryMemoryProvider.ChatHistoryMemoryProviderState))] + [JsonSerializable(typeof(ChatHistoryMemoryProvider.State))] [ExcludeFromCodeCoverage] internal sealed partial class JsonContext : JsonSerializerContext; diff --git a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs index 87adc9fd7a..8e9eade875 100644 --- a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs @@ -41,18 +41,21 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable private const int DefaultMaxResults = 3; private const string DefaultFunctionToolName = "Search"; private const string DefaultFunctionToolDescription = "Allows searching for related previous chat history to help answer the user question."; + private const string DefaultStateBagKey = "ChatHistoryMemoryProvider.State"; +#pragma warning disable CA2213 // VectorStore is not owned by this class - caller is responsible for disposal private readonly VectorStore _vectorStore; +#pragma warning restore CA2213 private readonly VectorStoreCollection> _collection; private readonly int _maxResults; private readonly string _contextPrompt; private readonly bool _enableSensitiveTelemetryData; private readonly ChatHistoryMemoryProviderOptions.SearchBehavior _searchTime; - private readonly AITool[] _tools; + private readonly string _toolName; + private readonly string _toolDescription; private readonly ILogger? _logger; - - private readonly ChatHistoryMemoryProviderScope _storageScope; - private readonly ChatHistoryMemoryProviderScope _searchScope; + private readonly string _stateKey; + private readonly Func _stateInitializer; private bool _collectionInitialized; private readonly SemaphoreSlim _initializationLock = new(1, 1); @@ -64,93 +67,30 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable /// The vector store to use for storing and retrieving chat history. /// The name of the collection for storing chat history in the vector store. /// The number of dimensions to use for the chat history vector store embeddings. - /// Optional values to scope the chat history storage with. - /// Optional values to scope the chat history search with. Where values are null, no filtering is done using those values. Defaults to if not provided. + /// A delegate that initializes the provider state on the first invocation, providing the storage and search scopes. /// Optional configuration options. /// Optional logger factory. - /// Thrown when is . + /// Thrown when or is . public ChatHistoryMemoryProvider( VectorStore vectorStore, string collectionName, int vectorDimensions, - ChatHistoryMemoryProviderScope storageScope, - ChatHistoryMemoryProviderScope? searchScope = null, + Func stateInitializer, ChatHistoryMemoryProviderOptions? options = null, ILoggerFactory? loggerFactory = null) - : this( - vectorStore, - collectionName, - vectorDimensions, - new ChatHistoryMemoryProviderState - { - StorageScope = new(Throw.IfNull(storageScope)), - SearchScope = searchScope ?? new(storageScope), - }, - options, - loggerFactory) { - } + this._vectorStore = Throw.IfNull(vectorStore); + this._stateInitializer = Throw.IfNull(stateInitializer); - /// - /// Initializes a new instance of the class from previously serialized state. - /// - /// The vector store to use for storing and retrieving chat history. - /// The name of the collection for storing chat history in the vector store. - /// The number of dimensions to use for the chat history vector store embeddings. - /// A representing the serialized state of the provider. - /// Optional settings for customizing the JSON deserialization process. - /// Optional configuration options. - /// Optional logger factory. - public ChatHistoryMemoryProvider( - VectorStore vectorStore, - string collectionName, - int vectorDimensions, - JsonElement serializedState, - JsonSerializerOptions? jsonSerializerOptions = null, - ChatHistoryMemoryProviderOptions? options = null, - ILoggerFactory? loggerFactory = null) - : this( - vectorStore, - collectionName, - vectorDimensions, - DeserializeState(serializedState, jsonSerializerOptions), - options, - loggerFactory) - { - } - - private ChatHistoryMemoryProvider( - VectorStore vectorStore, - string collectionName, - int vectorDimensions, - ChatHistoryMemoryProviderState? state = null, - ChatHistoryMemoryProviderOptions? options = null, - ILoggerFactory? loggerFactory = null) - { - this._vectorStore = vectorStore ?? throw new ArgumentNullException(nameof(vectorStore)); options ??= new ChatHistoryMemoryProviderOptions(); this._maxResults = options.MaxResults.HasValue ? Throw.IfLessThanOrEqual(options.MaxResults.Value, 0) : DefaultMaxResults; this._contextPrompt = options.ContextPrompt ?? DefaultContextPrompt; this._enableSensitiveTelemetryData = options.EnableSensitiveTelemetryData; this._searchTime = options.SearchTime; + this._stateKey = options.StateKey ?? DefaultStateBagKey; this._logger = loggerFactory?.CreateLogger(); - - if (state == null || state.StorageScope == null || state.SearchScope == null) - { - throw new InvalidOperationException($"The {nameof(ChatHistoryMemoryProvider)} state did not contain the required scope properties."); - } - - this._storageScope = state.StorageScope; - this._searchScope = state.SearchScope; - - // Create on-demand search tool (only used when behavior is OnDemandFunctionCalling) - this._tools = - [ - AIFunctionFactory.Create( - (Func>)this.SearchTextAsync, - name: options.FunctionToolName ?? DefaultFunctionToolName, - description: options.FunctionToolDescription ?? DefaultFunctionToolDescription) - ]; + this._toolName = options.FunctionToolName ?? DefaultFunctionToolName; + this._toolDescription = options.FunctionToolDescription ?? DefaultFunctionToolDescription; // Create a definition so that we can use the dimensions provided at runtime. var definition = new VectorStoreCollectionDefinition @@ -174,15 +114,52 @@ private ChatHistoryMemoryProvider( this._collection = this._vectorStore.GetDynamicCollection(Throw.IfNullOrWhitespace(collectionName), definition); } + /// + /// Gets the state from the session's StateBag, or initializes it using the StateInitializer if not present. + /// + /// The agent session containing the StateBag. + /// The provider state, or null if no session is available. + private State? GetOrInitializeState(AgentSession? session) + { + var state = session?.StateBag.GetValue(this._stateKey, AgentJsonUtilities.DefaultOptions); + if (state is not null) + { + return state; + } + + state = this._stateInitializer(session); + if (state is not null && session is not null) + { + session.StateBag.SetValue(this._stateKey, state, AgentJsonUtilities.DefaultOptions); + } + + return state; + } + /// public override async ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) { _ = Throw.IfNull(context); + var state = this.GetOrInitializeState(context.Session); + var searchScope = state?.SearchScope ?? new ChatHistoryMemoryProviderScope(); + if (this._searchTime == ChatHistoryMemoryProviderOptions.SearchBehavior.OnDemandFunctionCalling) { + Task InlineSearchAsync(string userQuestion, CancellationToken ct) + => this.SearchTextAsync(userQuestion, searchScope, ct); + + // Create on-demand search tool (only used when behavior is OnDemandFunctionCalling) + AITool[] tools = + [ + AIFunctionFactory.Create( + InlineSearchAsync, + name: this._toolName, + description: this._toolDescription) + ]; + // Expose search tool for on-demand invocation by the model - return new AIContext { Tools = this._tools }; + return new AIContext { Tools = tools }; } try @@ -198,7 +175,7 @@ public override async ValueTask InvokingAsync(InvokingContext context } // Search for relevant chat history - var contextText = await this.SearchTextAsync(requestText, cancellationToken).ConfigureAwait(false); + var contextText = await this.SearchTextAsync(requestText, searchScope, cancellationToken).ConfigureAwait(false); if (string.IsNullOrWhiteSpace(contextText)) { @@ -217,10 +194,10 @@ public override async ValueTask InvokingAsync(InvokingContext context this._logger.LogError( ex, "ChatHistoryMemoryProvider: Failed to search for chat history due to error. ApplicationId: '{ApplicationId}', AgentId: '{AgentId}', SessionId: '{SessionId}', UserId: '{UserId}'.", - this._searchScope.ApplicationId, - this._searchScope.AgentId, - this._searchScope.SessionId, - this.SanitizeLogData(this._searchScope.UserId)); + searchScope.ApplicationId, + searchScope.AgentId, + searchScope.SessionId, + this.SanitizeLogData(searchScope.UserId)); } return new AIContext(); @@ -238,6 +215,9 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio return; } + var state = this.GetOrInitializeState(context.Session); + var storageScope = state?.StorageScope ?? new ChatHistoryMemoryProviderScope(); + try { // Ensure the collection is initialized @@ -251,10 +231,10 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio ["Role"] = message.Role.ToString(), ["MessageId"] = message.MessageId, ["AuthorName"] = message.AuthorName, - ["ApplicationId"] = this._storageScope?.ApplicationId, - ["AgentId"] = this._storageScope?.AgentId, - ["UserId"] = this._storageScope?.UserId, - ["SessionId"] = this._storageScope?.SessionId, + ["ApplicationId"] = storageScope.ApplicationId, + ["AgentId"] = storageScope.AgentId, + ["UserId"] = storageScope.UserId, + ["SessionId"] = storageScope.SessionId, ["Content"] = message.Text, ["CreatedAt"] = message.CreatedAt?.ToString("O") ?? DateTimeOffset.UtcNow.ToString("O"), ["ContentEmbedding"] = message.Text, @@ -273,10 +253,10 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio this._logger.LogError( ex, "ChatHistoryMemoryProvider: Failed to add messages to chat history vector store due to error. ApplicationId: '{ApplicationId}', AgentId: '{AgentId}', SessionId: '{SessionId}', UserId: '{UserId}'.", - this._searchScope.ApplicationId, - this._searchScope.AgentId, - this._searchScope.SessionId, - this.SanitizeLogData(this._searchScope.UserId)); + storageScope.ApplicationId, + storageScope.AgentId, + storageScope.SessionId, + this.SanitizeLogData(storageScope.UserId)); } } } @@ -285,16 +265,17 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio /// Function callable by the AI model (when enabled) to perform an ad-hoc chat history search. /// /// The query text. + /// The scope to filter search results with. /// Cancellation token. /// Formatted search results (may be empty). - internal async Task SearchTextAsync(string userQuestion, CancellationToken cancellationToken = default) + private async Task SearchTextAsync(string userQuestion, ChatHistoryMemoryProviderScope searchScope, CancellationToken cancellationToken = default) { if (string.IsNullOrWhiteSpace(userQuestion)) { return string.Empty; } - var results = await this.SearchChatHistoryAsync(userQuestion, this._maxResults, cancellationToken).ConfigureAwait(false); + var results = await this.SearchChatHistoryAsync(userQuestion, searchScope, this._maxResults, cancellationToken).ConfigureAwait(false); if (!results.Any()) { return string.Empty; @@ -315,10 +296,10 @@ internal async Task SearchTextAsync(string userQuestion, CancellationTok "ChatHistoryMemoryProvider: Search Results\nInput:{Input}\nOutput:{MessageText}\n ApplicationId: '{ApplicationId}', AgentId: '{AgentId}', SessionId: '{SessionId}', UserId: '{UserId}'.", this.SanitizeLogData(userQuestion), this.SanitizeLogData(formatted), - this._searchScope.ApplicationId, - this._searchScope.AgentId, - this._searchScope.SessionId, - this.SanitizeLogData(this._searchScope.UserId)); + searchScope.ApplicationId, + searchScope.AgentId, + searchScope.SessionId, + this.SanitizeLogData(searchScope.UserId)); } return formatted; @@ -328,11 +309,13 @@ internal async Task SearchTextAsync(string userQuestion, CancellationTok /// Searches for relevant chat history items based on the provided query text. /// /// The text to search for. + /// The scope to filter search results with. /// The maximum number of results to return. /// The cancellation token. /// A list of relevant chat history items. private async Task>> SearchChatHistoryAsync( string queryText, + ChatHistoryMemoryProviderScope searchScope, int top, CancellationToken cancellationToken = default) { @@ -343,10 +326,10 @@ internal async Task SearchTextAsync(string userQuestion, CancellationTok var collection = await this.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false); - string? applicationId = this._searchScope.ApplicationId; - string? agentId = this._searchScope.AgentId; - string? userId = this._searchScope.UserId; - string? sessionId = this._searchScope.SessionId; + string? applicationId = searchScope.ApplicationId; + string? agentId = searchScope.AgentId; + string? userId = searchScope.UserId; + string? sessionId = searchScope.SessionId; Expression, bool>>? filter = null; if (applicationId != null) @@ -399,10 +382,10 @@ internal async Task SearchTextAsync(string userQuestion, CancellationTok this._logger.LogInformation( "ChatHistoryMemoryProvider: Retrieved {Count} search results. ApplicationId: '{ApplicationId}', AgentId: '{AgentId}', SessionId: '{SessionId}', UserId: '{UserId}'.", results.Count, - this._searchScope.ApplicationId, - this._searchScope.AgentId, - this._searchScope.SessionId, - this.SanitizeLogData(this._searchScope.UserId)); + searchScope.ApplicationId, + searchScope.AgentId, + searchScope.SessionId, + this.SanitizeLogData(searchScope.UserId)); } return results; @@ -464,38 +447,47 @@ public void Dispose() } /// - /// Serializes the current provider state to a including storage and search scopes. + /// Serializes the current provider state to a containing any overridden prompts or descriptions. /// - /// Optional serializer options. - /// Serialized provider state. + /// Optional serializer options (ignored, source generated context is used). + /// An empty object. + /// + /// This method is deprecated. State is now stored in the and serialized as part of the session. + /// public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { - var state = new ChatHistoryMemoryProviderState - { - StorageScope = this._storageScope, - SearchScope = this._searchScope, - }; - - var jso = jsonSerializerOptions ?? AgentJsonUtilities.DefaultOptions; - return JsonSerializer.SerializeToElement(state, jso.GetTypeInfo(typeof(ChatHistoryMemoryProviderState))); + // State is now stored in the session StateBag, so there is nothing to serialize here. + // Return an empty JSON object. + using var doc = JsonDocument.Parse("{}"); + return doc.RootElement.Clone(); } - private static ChatHistoryMemoryProviderState? DeserializeState(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions) + private string? SanitizeLogData(string? data) => this._enableSensitiveTelemetryData ? data : ""; + + /// + /// Represents the state of a stored in the . + /// + public sealed class State { - if (serializedState.ValueKind != JsonValueKind.Object) + /// + /// Initializes a new instance of the class with the specified storage and search scopes. + /// + /// The scope to use when storing chat history messages. + /// The scope to use when searching for relevant chat history messages. If null, the storage scope will be used for searching as well. + public State(ChatHistoryMemoryProviderScope storageScope, ChatHistoryMemoryProviderScope? searchScope = null) { - return null; + this.StorageScope = Throw.IfNull(storageScope); + this.SearchScope = searchScope ?? storageScope; } - var jso = jsonSerializerOptions ?? AgentJsonUtilities.DefaultOptions; - return serializedState.Deserialize(jso.GetTypeInfo(typeof(ChatHistoryMemoryProviderState))) as ChatHistoryMemoryProviderState; - } + /// + /// Gets or sets the scope used when storing chat history messages. + /// + public ChatHistoryMemoryProviderScope StorageScope { get; } - private string? SanitizeLogData(string? data) => this._enableSensitiveTelemetryData ? data : ""; - - internal sealed class ChatHistoryMemoryProviderState - { - public ChatHistoryMemoryProviderScope? StorageScope { get; set; } - public ChatHistoryMemoryProviderScope? SearchScope { get; set; } + /// + /// Gets or sets the scope used when searching chat history messages. + /// + public ChatHistoryMemoryProviderScope SearchScope { get; } } } diff --git a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProviderOptions.cs b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProviderOptions.cs index e09de68a59..05b03fef63 100644 --- a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProviderOptions.cs +++ b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProviderOptions.cs @@ -44,6 +44,15 @@ public sealed class ChatHistoryMemoryProviderOptions /// Defaults to . public bool EnableSensitiveTelemetryData { get; set; } + /// + /// Gets or sets the key used to store provider state in the . + /// + /// + /// Defaults to "ChatHistoryMemoryProvider.State". Override this if you need multiple + /// instances with separate state in the same session. + /// + public string? StateKey { get; set; } + /// /// Behavior choices for the provider. /// diff --git a/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs b/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs index be9eba1365..3ae7350dcb 100644 --- a/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs @@ -39,31 +39,28 @@ public sealed class TextSearchProvider : AIContextProvider private const string DefaultPluginSearchFunctionDescription = "Allows searching for additional information to help answer the user question."; private const string DefaultContextPrompt = "## Additional Context\nConsider the following information from source documents when responding to the user:"; private const string DefaultCitationsPrompt = "Include citations to the source document with document name and link if document name and link is available."; + private const string DefaultStateBagKey = "TextSearchProvider.RecentMessagesText"; private readonly Func>> _searchAsync; private readonly ILogger? _logger; private readonly AITool[] _tools; - private readonly Queue _recentMessagesText; private readonly List _recentMessageRolesIncluded; private readonly int _recentMessageMemoryLimit; private readonly TextSearchProviderOptions.TextSearchBehavior _searchTime; private readonly string _contextPrompt; private readonly string _citationsPrompt; + private readonly string _stateKey; private readonly Func, string>? _contextFormatter; /// /// Initializes a new instance of the class. /// /// Delegate that executes the search logic. Must not be . - /// A representing the serialized provider state. - /// Optional serializer options (unused - source generated context is used). /// Optional configuration options. /// Optional logger factory. /// Thrown when is . public TextSearchProvider( Func>> searchAsync, - JsonElement serializedState, - JsonSerializerOptions? jsonSerializerOptions = null, TextSearchProviderOptions? options = null, ILoggerFactory? loggerFactory = null) { @@ -75,27 +72,9 @@ public TextSearchProvider( this._searchTime = options?.SearchTime ?? TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke; this._contextPrompt = options?.ContextPrompt ?? DefaultContextPrompt; this._citationsPrompt = options?.CitationsPrompt ?? DefaultCitationsPrompt; + this._stateKey = options?.StateKey ?? DefaultStateBagKey; this._contextFormatter = options?.ContextFormatter; - // Restore recent messages from serialized state if provided - List? restoredMessages = null; - if (serializedState.ValueKind is JsonValueKind.Null or JsonValueKind.Undefined) - { - this._recentMessagesText = new(); - } - else - { - var jso = jsonSerializerOptions ?? AgentJsonUtilities.DefaultOptions; - var state = serializedState.Deserialize(jso.GetTypeInfo(typeof(TextSearchProviderState))) as TextSearchProviderState; - if (state?.RecentMessagesText is { Count: > 0 }) - { - restoredMessages = state.RecentMessagesText; - } - - // Restore recent messages respecting the limit (may truncate if limit changed afterwards). - this._recentMessagesText = restoredMessages is null ? new() : new(restoredMessages.Take(this._recentMessageMemoryLimit)); - } - // Create the on-demand search tool (only used if behavior is OnDemandFunctionCalling) this._tools = [ @@ -115,10 +94,14 @@ public override async ValueTask InvokingAsync(InvokingContext context return new AIContext { Tools = this._tools }; // No automatic message injection. } + // Retrieve recent messages from the session state bag. + var recentMessagesText = context.Session?.StateBag.GetValue(this._stateKey, AgentJsonUtilities.DefaultOptions)?.RecentMessagesText + ?? []; + // Aggregate text from memory + current request messages. var sbInput = new StringBuilder(); var requestMessagesText = context.RequestMessages.Where(x => !string.IsNullOrWhiteSpace(x?.Text)).Select(x => x.Text); - foreach (var messageText in this._recentMessagesText.Concat(requestMessagesText)) + foreach (var messageText in recentMessagesText.Concat(requestMessagesText)) { if (sbInput.Length > 0) { @@ -174,12 +157,21 @@ public override ValueTask InvokedAsync(InvokedContext context, CancellationToken return default; // Memory disabled. } + if (context.Session is null) + { + return default; // No session to store state in. + } + if (context.InvokeException is not null) { return default; // Do not update memory on failed invocations. } - var messagesText = context.RequestMessages + // Retrieve existing recent messages from the session state bag. + var recentMessagesText = context.Session.StateBag.GetValue(this._stateKey, AgentJsonUtilities.DefaultOptions)?.RecentMessagesText + ?? []; + + var newMessagesText = context.RequestMessages .Concat(context.ResponseMessages ?? []) .Where(m => this._recentMessageRolesIncluded.Contains(m.Role) && @@ -187,23 +179,19 @@ public override ValueTask InvokedAsync(InvokedContext context, CancellationToken // Filter out any messages that were added by this class in InvokingAsync, since we don't want // a feedback loop where previous search results are used to find new search results. (m.AdditionalProperties == null || m.AdditionalProperties.TryGetValue("IsTextSearchProviderOutput", out bool isTextSearchProviderOutput) == false || !isTextSearchProviderOutput)) - .Select(m => m.Text) - .ToList(); - if (messagesText.Count > limit) - { - // If the current request/response exceeds the limit, only keep the most recent messages from it. - messagesText = messagesText.Skip(messagesText.Count - limit).ToList(); - } + .Select(m => m.Text); - foreach (var message in messagesText) - { - this._recentMessagesText.Enqueue(message); - } + // Combine existing messages with new messages, then take the most recent up to the limit. + var allMessages = recentMessagesText.Concat(newMessagesText).ToList(); + var updatedMessages = allMessages.Count > limit + ? allMessages.Skip(allMessages.Count - limit).ToList() + : allMessages; - while (this._recentMessagesText.Count > limit) - { - this._recentMessagesText.Dequeue(); - } + // Store updated state back to the session state bag. + context.Session.StateBag.SetValue( + this._stateKey, + new TextSearchProviderState { RecentMessagesText = updatedMessages }, + AgentJsonUtilities.DefaultOptions); return default; } @@ -213,16 +201,13 @@ public override ValueTask InvokedAsync(InvokedContext context, CancellationToken /// /// Optional serializer options (ignored, source generated context is used). /// A with overridden values, or default if nothing was overridden. + /// + /// This method is deprecated. State is now stored in the and serialized as part of the session. + /// public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { - // Only persist values that differ from defaults plus recent memory configuration & messages. - TextSearchProviderState state = new(); - if (this._recentMessageMemoryLimit > 0 && this._recentMessagesText.Count > 0) - { - state.RecentMessagesText = this._recentMessagesText.Take(this._recentMessageMemoryLimit).ToList(); - } - - return JsonSerializer.SerializeToElement(state, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(TextSearchProviderState))); + // State is now stored in the session StateBag, so there is nothing to serialize here. + return JsonSerializer.SerializeToElement(new TextSearchProviderState(), AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(TextSearchProviderState))); } /// diff --git a/dotnet/src/Microsoft.Agents.AI/TextSearchProviderOptions.cs b/dotnet/src/Microsoft.Agents.AI/TextSearchProviderOptions.cs index e90a6efa63..89d959f679 100644 --- a/dotnet/src/Microsoft.Agents.AI/TextSearchProviderOptions.cs +++ b/dotnet/src/Microsoft.Agents.AI/TextSearchProviderOptions.cs @@ -59,6 +59,15 @@ public sealed class TextSearchProviderOptions /// public int RecentMessageMemoryLimit { get; set; } + /// + /// Gets or sets the key used to store provider state in the . + /// + /// + /// Defaults to "TextSearchProvider.RecentMessagesText". Override this if you need multiple + /// instances with separate state in the same session. + /// + public string? StateKey { get; set; } + /// /// Gets or sets the list of types to filter recent messages to /// when deciding which recent messages to include when constructing the search input. diff --git a/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs index 81ca4eb588..a78446cf29 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs @@ -19,7 +19,6 @@ public sealed class Mem0ProviderTests : IDisposable private const string SkipReason = "Requires a Mem0 service configured"; // Set to null to enable. private static readonly AIAgent s_mockAgent = new Moq.Mock().Object; - private static readonly AgentSession s_mockSession = new Moq.Mock().Object; private readonly HttpClient _httpClient; @@ -49,17 +48,18 @@ public async Task CanAddAndRetrieveUserMemoriesAsync() var question = new ChatMessage(ChatRole.User, "What is my name?"); var input = new ChatMessage(ChatRole.User, "Hello, my name is Caoimhe."); var storageScope = new Mem0ProviderScope { ThreadId = "it-thread-1", UserId = "it-user-1" }; - var sut = new Mem0Provider(this._httpClient, storageScope); + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope)); - await sut.ClearStoredMemoriesAsync(); - var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); + await sut.ClearStoredMemoriesAsync(mockSession); + var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, mockSession, [question])); Assert.DoesNotContain("Caoimhe", ctxBefore.Messages?[0].Text ?? string.Empty); // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [input], aiContextProviderMessages: null)); - var ctxAfterAdding = await GetContextWithRetryAsync(sut, question); - await sut.ClearStoredMemoriesAsync(); - var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession, [input], aiContextProviderMessages: null)); + var ctxAfterAdding = await GetContextWithRetryAsync(sut, mockSession, question); + await sut.ClearStoredMemoriesAsync(mockSession); + var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, mockSession, [question])); // Assert Assert.Contains("Caoimhe", ctxAfterAdding.Messages?[0].Text ?? string.Empty); @@ -73,17 +73,18 @@ public async Task CanAddAndRetrieveAgentMemoriesAsync() var question = new ChatMessage(ChatRole.User, "What is your name?"); var assistantIntro = new ChatMessage(ChatRole.Assistant, "Hello, I'm a friendly assistant and my name is Caoimhe."); var storageScope = new Mem0ProviderScope { AgentId = "it-agent-1" }; - var sut = new Mem0Provider(this._httpClient, storageScope); + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope)); - await sut.ClearStoredMemoriesAsync(); - var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); + await sut.ClearStoredMemoriesAsync(mockSession); + var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, mockSession, [question])); Assert.DoesNotContain("Caoimhe", ctxBefore.Messages?[0].Text ?? string.Empty); // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [assistantIntro], aiContextProviderMessages: null)); - var ctxAfterAdding = await GetContextWithRetryAsync(sut, question); - await sut.ClearStoredMemoriesAsync(); - var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession, [assistantIntro], aiContextProviderMessages: null)); + var ctxAfterAdding = await GetContextWithRetryAsync(sut, mockSession, question); + await sut.ClearStoredMemoriesAsync(mockSession); + var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, mockSession, [question])); // Assert Assert.Contains("Caoimhe", ctxAfterAdding.Messages?[0].Text ?? string.Empty); @@ -96,37 +97,41 @@ public async Task DoesNotLeakMemoriesAcrossAgentScopesAsync() // Arrange var question = new ChatMessage(ChatRole.User, "What is your name?"); var assistantIntro = new ChatMessage(ChatRole.Assistant, "I'm an AI tutor and my name is Caoimhe."); - var sut1 = new Mem0Provider(this._httpClient, new Mem0ProviderScope { AgentId = "it-agent-a" }); - var sut2 = new Mem0Provider(this._httpClient, new Mem0ProviderScope { AgentId = "it-agent-b" }); - - await sut1.ClearStoredMemoriesAsync(); - await sut2.ClearStoredMemoriesAsync(); - - var ctxBefore1 = await sut1.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); - var ctxBefore2 = await sut2.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); + var storageScope1 = new Mem0ProviderScope { AgentId = "it-agent-a" }; + var storageScope2 = new Mem0ProviderScope { AgentId = "it-agent-b" }; + var mockSession1 = new TestAgentSession(); + var mockSession2 = new TestAgentSession(); + var sut1 = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope1)); + var sut2 = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope2)); + + await sut1.ClearStoredMemoriesAsync(mockSession1); + await sut2.ClearStoredMemoriesAsync(mockSession2); + + var ctxBefore1 = await sut1.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, mockSession1, [question])); + var ctxBefore2 = await sut2.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, mockSession2, [question])); Assert.DoesNotContain("Caoimhe", ctxBefore1.Messages?[0].Text ?? string.Empty); Assert.DoesNotContain("Caoimhe", ctxBefore2.Messages?[0].Text ?? string.Empty); // Act - await sut1.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [assistantIntro], aiContextProviderMessages: null)); - var ctxAfterAdding1 = await GetContextWithRetryAsync(sut1, question); - var ctxAfterAdding2 = await GetContextWithRetryAsync(sut2, question); + await sut1.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession1, [assistantIntro], aiContextProviderMessages: null)); + var ctxAfterAdding1 = await GetContextWithRetryAsync(sut1, mockSession1, question); + var ctxAfterAdding2 = await GetContextWithRetryAsync(sut2, mockSession2, question); // Assert Assert.Contains("Caoimhe", ctxAfterAdding1.Messages?[0].Text ?? string.Empty); Assert.DoesNotContain("Caoimhe", ctxAfterAdding2.Messages?[0].Text ?? string.Empty); // Cleanup - await sut1.ClearStoredMemoriesAsync(); - await sut2.ClearStoredMemoriesAsync(); + await sut1.ClearStoredMemoriesAsync(mockSession1); + await sut2.ClearStoredMemoriesAsync(mockSession2); } - private static async Task GetContextWithRetryAsync(Mem0Provider provider, ChatMessage question, int attempts = 5, int delayMs = 1000) + private static async Task GetContextWithRetryAsync(Mem0Provider provider, AgentSession session, ChatMessage question, int attempts = 5, int delayMs = 1000) { AIContext? ctx = null; for (int i = 0; i < attempts; i++) { - ctx = await provider.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question]), CancellationToken.None); + ctx = await provider.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, session, [question]), CancellationToken.None); var text = ctx.Messages?[0].Text; if (!string.IsNullOrEmpty(text) && text.IndexOf("Caoimhe", StringComparison.OrdinalIgnoreCase) >= 0) { @@ -141,4 +146,12 @@ public void Dispose() { this._httpClient.Dispose(); } + + private sealed class TestAgentSession : AgentSession + { + public TestAgentSession() + { + this.StateBag = new AgentSessionStateBag(); + } + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs index b886784af9..71a4c4fbd3 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs @@ -19,7 +19,6 @@ namespace Microsoft.Agents.AI.Mem0.UnitTests; public sealed class Mem0ProviderTests : IDisposable { private static readonly AIAgent s_mockAgent = new Mock().Object; - private static readonly AgentSession s_mockSession = new Mock().Object; private readonly Mock> _loggerMock; private readonly Mock _loggerFactoryMock; @@ -55,35 +54,16 @@ public void Constructor_Throws_WhenBaseAddressMissing() using HttpClient client = new(); // Act & Assert - var ex = Assert.Throws(() => new Mem0Provider(client, new Mem0ProviderScope() { ThreadId = "tid" })); + var ex = Assert.Throws(() => new Mem0Provider(client, _ => new Mem0Provider.State(new Mem0ProviderScope { ThreadId = "tid" }))); Assert.StartsWith("The HttpClient BaseAddress must be set for Mem0 operations.", ex.Message); } [Fact] - public void Constructor_Throws_WhenNoStorageScopeValueIsSet() + public void Constructor_Throws_WhenStateInitializerIsNull() { // Act & Assert - var ex = Assert.Throws(() => new Mem0Provider(this._httpClient, new Mem0ProviderScope())); - Assert.StartsWith("At least one of ApplicationId, AgentId, ThreadId, or UserId must be provided for the storage scope.", ex.Message); - } - - [Fact] - public void Constructor_Throws_WhenNoSearchScopeValueIsSet() - { - // Act & Assert - var ex = Assert.Throws(() => new Mem0Provider(this._httpClient, new Mem0ProviderScope() { ThreadId = "tid" }, new Mem0ProviderScope())); - Assert.StartsWith("At least one of ApplicationId, AgentId, ThreadId, or UserId must be provided for the search scope.", ex.Message); - } - - [Fact] - public void DeserializingConstructor_Throws_WithEmptyJsonElement() - { - // Arrange - var jsonElement = JsonSerializer.SerializeToElement(new object(), Mem0JsonUtilities.DefaultOptions); - - // Act & Assert - var ex = Assert.Throws(() => new Mem0Provider(this._httpClient, jsonElement)); - Assert.StartsWith("The Mem0Provider state did not contain the required scope properties.", ex.Message); + var ex = Assert.Throws(() => new Mem0Provider(this._httpClient, null!)); + Assert.Contains("stateInitializer", ex.Message); } [Fact] @@ -98,8 +78,9 @@ public async Task InvokingAsync_PerformsSearch_AndReturnsContextMessageAsync() ThreadId = "session", UserId = "user" }; - var sut = new Mem0Provider(this._httpClient, storageScope, options: new() { EnableSensitiveTelemetryData = true }, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "What is my name?")]); + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), options: new() { EnableSensitiveTelemetryData = true }, loggerFactory: this._loggerFactoryMock.Object); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, mockSession, [new ChatMessage(ChatRole.User, "What is my name?")]); // Act var aiContext = await sut.InvokingAsync(invokingContext); @@ -162,9 +143,10 @@ public async Task InvokingAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsy UserId = "user" }; var options = new Mem0ProviderOptions { EnableSensitiveTelemetryData = enableSensitiveTelemetryData }; + var mockSession = new TestAgentSession(); - var sut = new Mem0Provider(this._httpClient, storageScope, options: options, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Who am I?")]); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), options: options, loggerFactory: this._loggerFactoryMock.Object); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, mockSession, [new ChatMessage(ChatRole.User, "Who am I?")]); // Act await sut.InvokingAsync(invokingContext, CancellationToken.None); @@ -204,7 +186,8 @@ public async Task InvokedAsync_PersistsAllowedMessagesAsync() this._handler.EnqueueEmptyOk(); // For second CreateMemory this._handler.EnqueueEmptyOk(); // For third CreateMemory var storageScope = new Mem0ProviderScope { ApplicationId = "a", AgentId = "b", ThreadId = "c", UserId = "d" }; - var sut = new Mem0Provider(this._httpClient, storageScope); + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope)); var requestMessages = new List { @@ -218,7 +201,7 @@ public async Task InvokedAsync_PersistsAllowedMessagesAsync() }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); // Assert var memoryPosts = this._handler.Requests.Where(r => r.RequestMessage.RequestUri!.AbsolutePath == "/v1/memories/" && r.RequestMessage.Method == HttpMethod.Post).ToList(); @@ -235,7 +218,8 @@ public async Task InvokedAsync_PersistsNothingForFailedRequestAsync() { // Arrange var storageScope = new Mem0ProviderScope { ApplicationId = "a", AgentId = "b", ThreadId = "c", UserId = "d" }; - var sut = new Mem0Provider(this._httpClient, storageScope); + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope)); var requestMessages = new List { @@ -245,7 +229,7 @@ public async Task InvokedAsync_PersistsNothingForFailedRequestAsync() }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = null, InvokeException = new InvalidOperationException("Request Failed") }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = null, InvokeException = new InvalidOperationException("Request Failed") }); // Assert Assert.Empty(this._handler.Requests); @@ -256,7 +240,8 @@ public async Task InvokedAsync_ShouldNotThrow_WhenStorageFailsAsync() { // Arrange var storageScope = new Mem0ProviderScope { ApplicationId = "a", AgentId = "b", ThreadId = "c", UserId = "d" }; - var sut = new Mem0Provider(this._httpClient, storageScope, loggerFactory: this._loggerFactoryMock.Object); + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), loggerFactory: this._loggerFactoryMock.Object); this._handler.EnqueueEmptyInternalServerError(); var requestMessages = new List @@ -271,7 +256,7 @@ public async Task InvokedAsync_ShouldNotThrow_WhenStorageFailsAsync() }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); // Assert this._loggerMock.Verify( @@ -310,7 +295,8 @@ public async Task InvokedAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsyn }; var options = new Mem0ProviderOptions { EnableSensitiveTelemetryData = enableSensitiveTelemetryData }; - var sut = new Mem0Provider(this._httpClient, storageScope, options: options, loggerFactory: this._loggerFactoryMock.Object); + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), options: options, loggerFactory: this._loggerFactoryMock.Object); var requestMessages = new List { new(ChatRole.User, "User text") @@ -321,7 +307,7 @@ public async Task InvokedAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsyn }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); // Assert Assert.Equal(expectedLogCount, this._loggerMock.Invocations.Count); @@ -343,11 +329,12 @@ public async Task ClearStoredMemoriesAsync_SendsDeleteWithQueryAsync() { // Arrange var storageScope = new Mem0ProviderScope { ApplicationId = "app", AgentId = "agent", ThreadId = "session", UserId = "user" }; - var sut = new Mem0Provider(this._httpClient, storageScope); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope)); this._handler.EnqueueEmptyOk(); // for DELETE + var mockSession = new TestAgentSession(); // Act - await sut.ClearStoredMemoriesAsync(); + await sut.ClearStoredMemoriesAsync(mockSession); // Assert var delete = Assert.Single(this._handler.Requests, r => r.RequestMessage.Method == HttpMethod.Delete); @@ -355,46 +342,19 @@ public async Task ClearStoredMemoriesAsync_SendsDeleteWithQueryAsync() } [Fact] - public void Serialize_RoundTripsScopes() + public void Serialize_ReturnsEmptyJsonObject() { // Arrange var storageScope = new Mem0ProviderScope { ApplicationId = "app", AgentId = "agent", ThreadId = "session", UserId = "user" }; - var sut = new Mem0Provider(this._httpClient, storageScope, options: new() { ContextPrompt = "Custom:" }, loggerFactory: this._loggerFactoryMock.Object); - - // Act - var stateElement = sut.Serialize(); - using JsonDocument doc = JsonDocument.Parse(stateElement.GetRawText()); - var storageScopeElement = doc.RootElement.GetProperty("storageScope"); - Assert.Equal("app", storageScopeElement.GetProperty("applicationId").GetString()); - Assert.Equal("agent", storageScopeElement.GetProperty("agentId").GetString()); - Assert.Equal("session", storageScopeElement.GetProperty("threadId").GetString()); - Assert.Equal("user", storageScopeElement.GetProperty("userId").GetString()); - - var sut2 = new Mem0Provider(this._httpClient, stateElement); - var stateElement2 = sut2.Serialize(); - - // Assert - using JsonDocument doc2 = JsonDocument.Parse(stateElement2.GetRawText()); - var storageScopeElement2 = doc2.RootElement.GetProperty("storageScope"); - Assert.Equal("app", storageScopeElement2.GetProperty("applicationId").GetString()); - Assert.Equal("agent", storageScopeElement2.GetProperty("agentId").GetString()); - Assert.Equal("session", storageScopeElement2.GetProperty("threadId").GetString()); - Assert.Equal("user", storageScopeElement2.GetProperty("userId").GetString()); - } - - [Fact] - public void Serialize_DoesNotIncludeDefaultContextPrompt() - { - // Arrange - var storageScope = new Mem0ProviderScope { ApplicationId = "app" }; - var sut = new Mem0Provider(this._httpClient, storageScope); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), options: new() { ContextPrompt = "Custom:" }, loggerFactory: this._loggerFactoryMock.Object); // Act var stateElement = sut.Serialize(); // Assert using JsonDocument doc = JsonDocument.Parse(stateElement.GetRawText()); - Assert.False(doc.RootElement.TryGetProperty("contextPrompt", out _)); + Assert.Equal(JsonValueKind.Object, doc.RootElement.ValueKind); + Assert.Empty(doc.RootElement.EnumerateObject()); } [Fact] @@ -402,8 +362,9 @@ public async Task InvokingAsync_ShouldNotThrow_WhenSearchFailsAsync() { // Arrange var storageScope = new Mem0ProviderScope { ApplicationId = "app" }; - var provider = new Mem0Provider(this._httpClient, storageScope, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); + var mockSession = new TestAgentSession(); + var provider = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), loggerFactory: this._loggerFactoryMock.Object); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, mockSession, [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -421,6 +382,49 @@ public async Task InvokingAsync_ShouldNotThrow_WhenSearchFailsAsync() Times.Once); } + [Fact] + public async Task StateInitializer_IsCalledOnceAndStoredInStateBagAsync() + { + // Arrange + this._handler.EnqueueJsonResponse("[]"); + this._handler.EnqueueJsonResponse("[]"); + var storageScope = new Mem0ProviderScope { ApplicationId = "app" }; + var mockSession = new TestAgentSession(); + int initializerCallCount = 0; + var sut = new Mem0Provider(this._httpClient, _ => + { + initializerCallCount++; + return new Mem0Provider.State(storageScope); + }); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, mockSession, [new ChatMessage(ChatRole.User, "Q?")]); + + // Act + await sut.InvokingAsync(invokingContext, CancellationToken.None); + await sut.InvokingAsync(invokingContext, CancellationToken.None); + + // Assert + Assert.Equal(1, initializerCallCount); + } + + [Fact] + public async Task StateKey_CanBeConfiguredViaOptionsAsync() + { + // Arrange + this._handler.EnqueueJsonResponse("[]"); + var storageScope = new Mem0ProviderScope { ApplicationId = "app" }; + var mockSession = new TestAgentSession(); + const string CustomKey = "MyCustomKey"; + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), options: new() { StateKey = CustomKey }); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, mockSession, [new ChatMessage(ChatRole.User, "Q?")]); + + // Act + await sut.InvokingAsync(invokingContext, CancellationToken.None); + + // Assert + Assert.True(mockSession.StateBag.TryGetValue(CustomKey, out var state, Mem0JsonUtilities.DefaultOptions)); + Assert.NotNull(state); + } + private static bool ContainsOrdinal(string source, string value) => source.IndexOf(value, StringComparison.Ordinal) >= 0; public void Dispose() @@ -465,4 +469,12 @@ public void EnqueueJsonResponse(string json) public void EnqueueEmptyInternalServerError() => this._responses.Enqueue(new HttpResponseMessage(System.Net.HttpStatusCode.InternalServerError)); } + + private sealed class TestAgentSession : AgentSession + { + public TestAgentSession() + { + this.StateBag = new AgentSessionStateBag(); + } + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs index 360c3071ae..c9e39d7d46 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs @@ -18,7 +18,6 @@ namespace Microsoft.Agents.AI.UnitTests.Data; public sealed class TextSearchProviderTests { private static readonly AIAgent s_mockAgent = new Mock().Object; - private static readonly AgentSession s_mockSession = new Mock().Object; private readonly Mock> _loggerMock; private readonly Mock _loggerFactoryMock; @@ -64,11 +63,11 @@ public async Task InvokingAsync_ShouldInjectFormattedResultsAsync(string? overri ContextPrompt = overrideContextPrompt, CitationsPrompt = overrideCitationsPrompt }; - var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options, withLogging ? this._loggerFactoryMock.Object : null); + var provider = new TextSearchProvider(SearchDelegateAsync, options, withLogging ? this._loggerFactoryMock.Object : null); var invokingContext = new AIContextProvider.InvokingContext( s_mockAgent, - s_mockSession, + new TestAgentSession(), [ new ChatMessage(ChatRole.User, "Sample user question?"), new ChatMessage(ChatRole.User, "Additional part") @@ -143,8 +142,8 @@ public async Task InvokingAsync_OnDemand_ShouldExposeSearchToolAsync(string? ove FunctionToolName = overrideName, FunctionToolDescription = overrideDescription }; - var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); + var provider = new TextSearchProvider(this.NoResultSearchAsync, options); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -162,8 +161,8 @@ public async Task InvokingAsync_OnDemand_ShouldExposeSearchToolAsync(string? ove public async Task InvokingAsync_ShouldNotThrow_WhenSearchFailsAsync() { // Arrange - var provider = new TextSearchProvider(this.FailingSearchAsync, default, null, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); + var provider = new TextSearchProvider(this.FailingSearchAsync, loggerFactory: this._loggerFactoryMock.Object); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -203,7 +202,7 @@ public async Task SearchAsync_ShouldReturnFormattedResultsAsync(string? override ContextPrompt = overrideContextPrompt, CitationsPrompt = overrideCitationsPrompt }; - var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); + var provider = new TextSearchProvider(SearchDelegateAsync, options); // Act var formatted = await provider.SearchAsync("Sample user question?", CancellationToken.None); @@ -255,8 +254,8 @@ public async Task InvokingAsync_ShouldUseContextFormatterWhenProvidedAsync() SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, ContextFormatter = r => $"Custom formatted context with {r.Count} results." }; - var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); + var provider = new TextSearchProvider(SearchDelegateAsync, options); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -289,8 +288,8 @@ public async Task InvokingAsync_WithRawRepresentations_ContextFormatterCanAccess SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, ContextFormatter = r => string.Join(",", r.Select(x => ((RawPayload)x.RawRepresentation!).Id)) }; - var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); + var provider = new TextSearchProvider(SearchDelegateAsync, options); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -306,8 +305,8 @@ public async Task InvokingAsync_WithNoResults_ShouldReturnEmptyContextAsync() { // Arrange var options = new TextSearchProviderOptions { SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke }; - var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); + var provider = new TextSearchProvider(this.NoResultSearchAsync, options); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -335,7 +334,7 @@ public async Task InvokingAsync_WithPreviousFailedRequest_ShouldNotIncludeFailed capturedInput = input; return Task.FromResult>([]); // No results needed. } - var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); + var provider = new TextSearchProvider(SearchDelegateAsync, options); // Populate memory with more messages than the limit (A,B,C,D) -> should retain B,C,D var initialMessages = new[] @@ -345,11 +344,13 @@ public async Task InvokingAsync_WithPreviousFailedRequest_ShouldNotIncludeFailed new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages, aiContextProviderMessages: null) { InvokeException = new InvalidOperationException("Request Failed") }); + + var session = new TestAgentSession(); + await provider.InvokedAsync(new(s_mockAgent, session, initialMessages, aiContextProviderMessages: null) { InvokeException = new InvalidOperationException("Request Failed") }); var invokingContext = new AIContextProvider.InvokingContext( s_mockAgent, - s_mockSession, + session, [ new ChatMessage(ChatRole.User, "E") ]); @@ -377,7 +378,8 @@ public async Task InvokingAsync_WithRecentMessageMemory_ShouldIncludeStoredMessa capturedInput = input; return Task.FromResult>([]); // No results needed. } - var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); + var provider = new TextSearchProvider(SearchDelegateAsync, options); + var session = new TestAgentSession(); // Populate memory with more messages than the limit (A,B,C,D) -> should retain B,C,D var initialMessages = new[] @@ -387,11 +389,11 @@ public async Task InvokingAsync_WithRecentMessageMemory_ShouldIncludeStoredMessa new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages, aiContextProviderMessages: null)); + await provider.InvokedAsync(new(s_mockAgent, session, initialMessages, aiContextProviderMessages: null)); var invokingContext = new AIContextProvider.InvokingContext( s_mockAgent, - s_mockSession, + session, [ new ChatMessage(ChatRole.User, "E") ]); @@ -419,12 +421,13 @@ public async Task InvokingAsync_WithAccumulatedMemoryAcrossInvocations_ShouldInc capturedInput = input; return Task.FromResult>([]); } - var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); + var provider = new TextSearchProvider(SearchDelegateAsync, options); + var session = new TestAgentSession(); // First memory update (A,B) await provider.InvokedAsync(new( s_mockAgent, - s_mockSession, + session, [ new ChatMessage(ChatRole.User, "A"), new ChatMessage(ChatRole.Assistant, "B"), @@ -433,14 +436,14 @@ await provider.InvokedAsync(new( // Second memory update (C,D,E) await provider.InvokedAsync(new( s_mockAgent, - s_mockSession, + session, [ new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), new ChatMessage(ChatRole.User, "E"), ], aiContextProviderMessages: null)); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "F")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "F")]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -465,7 +468,8 @@ public async Task InvokingAsync_WithRecentMessageRolesIncluded_ShouldFilterRoles capturedInput = input; return Task.FromResult>([]); // No results needed for this test. } - var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); + var provider = new TextSearchProvider(SearchDelegateAsync, options); + var session = new TestAgentSession(); // Populate memory with mixed roles; only Assistant messages (A1,A2) should be retained. var initialMessages = new[] @@ -475,11 +479,11 @@ public async Task InvokingAsync_WithRecentMessageRolesIncluded_ShouldFilterRoles new ChatMessage(ChatRole.User, "U2"), new ChatMessage(ChatRole.Assistant, "A2"), }; - await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages, null)); + await provider.InvokedAsync(new(s_mockAgent, session, initialMessages, null)); var invokingContext = new AIContextProvider.InvokingContext( s_mockAgent, - s_mockSession, + session, [ new ChatMessage(ChatRole.User, "Question?") // Current request message always appended. ]); @@ -496,7 +500,7 @@ public async Task InvokingAsync_WithRecentMessageRolesIncluded_ShouldFilterRoles #region Serialization Tests [Fact] - public void Serialize_WithNoRecentMessages_ShouldReturnEmptyState() + public void Serialize_ShouldReturnEmptyState() { // Arrange var options = new TextSearchProviderOptions @@ -504,18 +508,18 @@ public void Serialize_WithNoRecentMessages_ShouldReturnEmptyState() SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, RecentMessageMemoryLimit = 3 }; - var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options); + var provider = new TextSearchProvider(this.NoResultSearchAsync, options); // Act var state = provider.Serialize(); - // Assert + // Assert - State is now stored in session StateBag, so provider.Serialize() returns empty state Assert.Equal(JsonValueKind.Object, state.ValueKind); Assert.False(state.TryGetProperty("recentMessagesText", out _)); } [Fact] - public async Task Serialize_WithRecentMessages_ShouldPersistMessagesUpToLimitAsync() + public async Task InvokedAsync_ShouldPersistMessagesToSessionStateBagAsync() { // Arrange var options = new TextSearchProviderOptions @@ -524,7 +528,8 @@ public async Task Serialize_WithRecentMessages_ShouldPersistMessagesUpToLimitAsy RecentMessageMemoryLimit = 3, RecentMessageRolesIncluded = [ChatRole.User, ChatRole.Assistant] }; - var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options); + var provider = new TextSearchProvider(this.NoResultSearchAsync, options); + var session = new TestAgentSession(); var messages = new[] { new ChatMessage(ChatRole.User, "M1"), @@ -533,11 +538,13 @@ public async Task Serialize_WithRecentMessages_ShouldPersistMessagesUpToLimitAsy }; // Act - await provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null)); // Populate recent memory. - var state = provider.Serialize(); + await provider.InvokedAsync(new(s_mockAgent, session, messages, aiContextProviderMessages: null)); // Populate recent memory. - // Assert - Assert.True(state.TryGetProperty("recentMessagesText", out var recentProperty)); + // Assert - State should be in the session's StateBag + var stateBagSerialized = session.StateBag.Serialize(); + Assert.True(stateBagSerialized.TryGetProperty("TextSearchProvider.RecentMessagesText", out var stateProperty)); + Assert.True(stateProperty.TryGetProperty("jsonValue", out var jsonValueProperty)); + Assert.True(jsonValueProperty.TryGetProperty("recentMessagesText", out var recentProperty)); Assert.Equal(JsonValueKind.Array, recentProperty.ValueKind); var list = recentProperty.EnumerateArray().Select(e => e.GetString()).ToList(); Assert.Equal(3, list.Count); @@ -545,7 +552,7 @@ public async Task Serialize_WithRecentMessages_ShouldPersistMessagesUpToLimitAsy } [Fact] - public async Task SerializeAndDeserialize_RoundtripRestoresMessagesAsync() + public async Task StateBag_RoundtripRestoresMessagesAsync() { // Arrange var options = new TextSearchProviderOptions @@ -554,7 +561,8 @@ public async Task SerializeAndDeserialize_RoundtripRestoresMessagesAsync() RecentMessageMemoryLimit = 4, RecentMessageRolesIncluded = [ChatRole.User, ChatRole.Assistant] }; - var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options); + var provider = new TextSearchProvider(this.NoResultSearchAsync, options); + var session = new TestAgentSession(); var messages = new[] { new ChatMessage(ChatRole.User, "A"), @@ -562,23 +570,25 @@ public async Task SerializeAndDeserialize_RoundtripRestoresMessagesAsync() new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null)); + await provider.InvokedAsync(new(s_mockAgent, session, messages, aiContextProviderMessages: null)); + + // Act - Serialize and deserialize the StateBag + var serializedStateBag = session.StateBag.Serialize(); + var restoredSession = new TestAgentSession(AgentSessionStateBag.Deserialize(serializedStateBag)); - // Act - var state = provider.Serialize(); string? capturedInput = null; Task> SearchDelegate2Async(string input, CancellationToken ct) { capturedInput = input; return Task.FromResult>([]); } - var roundTrippedProvider = new TextSearchProvider(SearchDelegate2Async, state, options: new TextSearchProviderOptions + var newProvider = new TextSearchProvider(SearchDelegate2Async, new TextSearchProviderOptions { SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, RecentMessageMemoryLimit = 4 }); var emptyMessages = Array.Empty(); - await roundTrippedProvider.InvokingAsync(new(s_mockAgent, s_mockSession, emptyMessages), CancellationToken.None); // Trigger search to read memory. + await newProvider.InvokingAsync(new(s_mockAgent, restoredSession, emptyMessages), CancellationToken.None); // Trigger search to read memory. // Assert Assert.NotNull(capturedInput); @@ -586,25 +596,10 @@ public async Task SerializeAndDeserialize_RoundtripRestoresMessagesAsync() } [Fact] - public async Task Deserialize_WithChangedLowerLimit_ShouldTruncateToNewLimitAsync() + public async Task InvokingAsync_WithEmptyStateBag_ShouldHaveNoMessagesAsync() { // Arrange - var initialProvider = new TextSearchProvider(this.NoResultSearchAsync, default, null, new TextSearchProviderOptions - { - SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, - RecentMessageMemoryLimit = 5, - RecentMessageRolesIncluded = [ChatRole.User, ChatRole.Assistant] - }); - var messages = new[] - { - new ChatMessage(ChatRole.User, "L1"), - new ChatMessage(ChatRole.Assistant, "L2"), - new ChatMessage(ChatRole.User, "L3"), - new ChatMessage(ChatRole.Assistant, "L4"), - new ChatMessage(ChatRole.User, "L5"), - }; - await initialProvider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null)); - var state = initialProvider.Serialize(); + var session = new TestAgentSession(); // Fresh session with empty StateBag string? capturedInput = null; Task> SearchDelegate2Async(string input, CancellationToken ct) @@ -614,43 +609,17 @@ public async Task Deserialize_WithChangedLowerLimit_ShouldTruncateToNewLimitAsyn } // Act - var restoredProvider = new TextSearchProvider(SearchDelegate2Async, state, options: new TextSearchProviderOptions - { - SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, - RecentMessageMemoryLimit = 3 // Lower limit - }); - await restoredProvider.InvokingAsync(new(s_mockAgent, s_mockSession, Array.Empty()), CancellationToken.None); - - // Assert - Assert.NotNull(capturedInput); - Assert.Equal("L1\nL2\nL3", capturedInput); - } - - [Fact] - public async Task Deserialize_WithEmptyState_ShouldHaveNoMessagesAsync() - { - // Arrange - var emptyState = JsonSerializer.Deserialize("{}", TestJsonSerializerContext.Default.JsonElement); - - string? capturedInput = null; - Task> SearchDelegate2Async(string input, CancellationToken ct) - { - capturedInput = input; - return Task.FromResult>([]); - } - - // Act - var provider = new TextSearchProvider(SearchDelegate2Async, emptyState, options: new TextSearchProviderOptions + var provider = new TextSearchProvider(SearchDelegate2Async, new TextSearchProviderOptions { SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, RecentMessageMemoryLimit = 3 }); var emptyMessages = Array.Empty(); - await provider.InvokingAsync(new(s_mockAgent, s_mockSession, emptyMessages), CancellationToken.None); + await provider.InvokingAsync(new(s_mockAgent, session, emptyMessages), CancellationToken.None); // Assert Assert.NotNull(capturedInput); - Assert.Equal(string.Empty, capturedInput); // No recent messages serialized => empty input. + Assert.Equal(string.Empty, capturedInput); // No recent messages in StateBag => empty input. } #endregion @@ -669,4 +638,16 @@ private sealed class RawPayload { public string Id { get; set; } = string.Empty; } + + private sealed class TestAgentSession : AgentSession + { + public TestAgentSession() + { + } + + public TestAgentSession(AgentSessionStateBag stateBag) + { + this.StateBag = stateBag; + } + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs index 8d3cad85ae..f290f9c0db 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs @@ -19,7 +19,6 @@ namespace Microsoft.Agents.AI.Memory.UnitTests; public class ChatHistoryMemoryProviderTests { private static readonly AIAgent s_mockAgent = new Mock().Object; - private static readonly AgentSession s_mockSession = new Mock().Object; private readonly Mock> _loggerMock; private readonly Mock _loggerFactoryMock; @@ -61,29 +60,49 @@ public ChatHistoryMemoryProviderTests() public void Constructor_Throws_ForNullVectorStore() { // Act & Assert - Assert.Throws(() => new ChatHistoryMemoryProvider(null!, "testcollection", 1, new ChatHistoryMemoryProviderScope() { UserId = "UID" })); + Assert.Throws(() => new ChatHistoryMemoryProvider( + null!, + "testcollection", + 1, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }))); } [Fact] public void Constructor_Throws_ForNullCollectionName() { // Act & Assert - Assert.Throws(() => new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, null!, 1, new ChatHistoryMemoryProviderScope() { UserId = "UID" })); + Assert.Throws(() => new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + null!, + 1, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }))); } [Fact] - public void Constructor_Throws_ForNullStorageScope() + public void Constructor_Throws_ForNullStateInitializer() { // Act & Assert - Assert.Throws(() => new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, "testcollection", 1, null!)); + Assert.Throws(() => new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + "testcollection", + 1, + null!)); } [Fact] public void Constructor_Throws_ForInvalidVectorDimensions() { // Act & Assert - Assert.Throws(() => new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, "testcollection", 0, new ChatHistoryMemoryProviderScope() { UserId = "UID" })); - Assert.Throws(() => new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, "testcollection", -5, new ChatHistoryMemoryProviderScope() { UserId = "UID" })); + Assert.Throws(() => new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + "testcollection", + 0, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }))); + Assert.Throws(() => new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + "testcollection", + -5, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }))); } #region InvokedAsync Tests @@ -113,13 +132,17 @@ public async Task InvokedAsync_UpsertsMessages_ToCollectionAsync() UserId = "user1" }; - var provider = new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, TestCollectionName, 1, storeScope); + var provider = new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + TestCollectionName, + 1, + _ => new ChatHistoryMemoryProvider.State(storeScope)); var requestMsgWithValues = new ChatMessage(ChatRole.User, "request text") { MessageId = "req-1", AuthorName = "user1", CreatedAt = new DateTimeOffset(new DateTime(2000, 1, 1), TimeSpan.Zero) }; var requestMsgWithNulls = new ChatMessage(ChatRole.User, "request text nulls"); var responseMsg = new ChatMessage(ChatRole.Assistant, "response text") { MessageId = "resp-1", AuthorName = "assistant" }; - var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsgWithValues, requestMsgWithNulls], aiContextProviderMessages: null) + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, new TestAgentSession(), [requestMsgWithValues, requestMsgWithNulls], aiContextProviderMessages: null) { ResponseMessages = [responseMsg] }; @@ -175,9 +198,9 @@ public async Task InvokedAsync_DoesNotUpsertMessages_WhenInvokeFailedAsync() this._vectorStoreMock.Object, TestCollectionName, 1, - new ChatHistoryMemoryProviderScope() { UserId = "UID" }); + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" })); var requestMsg = new ChatMessage(ChatRole.User, "request text") { MessageId = "req-1" }; - var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg], aiContextProviderMessages: null) + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, new TestAgentSession(), [requestMsg], aiContextProviderMessages: null) { InvokeException = new InvalidOperationException("Invoke failed") }; @@ -203,10 +226,10 @@ public async Task InvokedAsync_DoesNotThrow_WhenUpsertThrowsAsync() this._vectorStoreMock.Object, TestCollectionName, 1, - new ChatHistoryMemoryProviderScope() { UserId = "UID" }, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }), loggerFactory: this._loggerFactoryMock.Object); var requestMsg = new ChatMessage(ChatRole.User, "request text") { MessageId = "req-1" }; - var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg], aiContextProviderMessages: null); + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, new TestAgentSession(), [requestMsg], aiContextProviderMessages: null); // Act await provider.InvokedAsync(invokedContext, CancellationToken.None); @@ -252,12 +275,12 @@ public async Task InvokedAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsyn this._vectorStoreMock.Object, TestCollectionName, 1, - new ChatHistoryMemoryProviderScope { UserId = "user1" }, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "user1" }), options: options, loggerFactory: this._loggerFactoryMock.Object); var requestMsg = new ChatMessage(ChatRole.User, "request text"); - var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg], aiContextProviderMessages: null); + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, new TestAgentSession(), [requestMsg], aiContextProviderMessages: null); // Act await provider.InvokedAsync(invokedContext, CancellationToken.None); @@ -326,11 +349,11 @@ public async Task InvokedAsync_SearchesVectorStoreAsync() this._vectorStoreMock.Object, TestCollectionName, 1, - new ChatHistoryMemoryProviderScope() { UserId = "UID" }, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }), options: providerOptions); var requestMsg = new ChatMessage(ChatRole.User, "requesting relevant history"); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [requestMsg]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [requestMsg]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -378,10 +401,15 @@ public async Task InvokedAsync_CreatesFilter_WhenSearchScopeProvidedAsync() }) .Returns(ToAsyncEnumerableAsync(new List>>())); - var provider = new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, TestCollectionName, 1, options: providerOptions, storageScope: searchScope, searchScope: searchScope); + var provider = new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + TestCollectionName, + 1, + _ => new ChatHistoryMemoryProvider.State(searchScope, searchScope), + options: providerOptions); var requestMsg = new ChatMessage(ChatRole.User, "requesting relevant history"); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [requestMsg]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [requestMsg]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -440,12 +468,11 @@ public async Task InvokingAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsy this._vectorStoreMock.Object, TestCollectionName, 1, - storageScope: scope, - searchScope: scope, + _ => new ChatHistoryMemoryProvider.State(scope, scope), options: options, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "requesting relevant history")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [new ChatMessage(ChatRole.User, "requesting relevant history")]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -482,49 +509,22 @@ public async Task InvokingAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsy #region Serialization Tests [Fact] - public void Serialize_Deserialize_RoundtripsScopes() + public void Serialize_ReturnsEmptyState() { // Arrange - var storageScope = new ChatHistoryMemoryProviderScope - { - ApplicationId = "app", - AgentId = "agent", - SessionId = "session", - UserId = "user" - }; - - var searchScope = new ChatHistoryMemoryProviderScope - { - ApplicationId = "app2", - AgentId = "agent2", - SessionId = "session2", - UserId = "user2" - }; - - var provider = new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, TestCollectionName, 1, storageScope: storageScope, searchScope: searchScope); + var provider = new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + TestCollectionName, + 1, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }, null)); // Act var stateElement = provider.Serialize(); + // Assert - Serialize returns empty object since state is now stored in StateBag using JsonDocument doc = JsonDocument.Parse(stateElement.GetRawText()); - var storage = doc.RootElement.GetProperty("storageScope"); - Assert.Equal("app", storage.GetProperty("applicationId").GetString()); - Assert.Equal("agent", storage.GetProperty("agentId").GetString()); - Assert.Equal("session", storage.GetProperty("sessionId").GetString()); - Assert.Equal("user", storage.GetProperty("userId").GetString()); - - var search = doc.RootElement.GetProperty("searchScope"); - Assert.Equal("app2", search.GetProperty("applicationId").GetString()); - Assert.Equal("agent2", search.GetProperty("agentId").GetString()); - Assert.Equal("session2", search.GetProperty("sessionId").GetString()); - Assert.Equal("user2", search.GetProperty("userId").GetString()); - - // Act - deserialize and serialize again - var provider2 = new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, TestCollectionName, 1, serializedState: stateElement); - var stateElement2 = provider2.Serialize(); - - // Assert - roundtrip the state - Assert.Equal(stateElement.GetRawText(), stateElement2.GetRawText()); + Assert.Equal(JsonValueKind.Object, doc.RootElement.ValueKind); + Assert.Empty(doc.RootElement.EnumerateObject()); } #endregion @@ -537,4 +537,16 @@ private static async IAsyncEnumerable ToAsyncEnumerableAsync(IEnumerable Date: Fri, 6 Feb 2026 11:44:55 +0000 Subject: [PATCH 03/28] Update InMemoryChatHistoryProvider to use StateBag --- .../Program.cs | 10 +- .../Program.cs | 2 +- .../Agent_Step16_ChatReduction/Program.cs | 2 +- .../InMemoryAgentSession.cs | 18 +- .../InMemoryChatHistoryProvider.cs | 222 ++++---- .../ChatClient/ChatClientAgentSession.cs | 4 +- .../InMemoryAgentSessionTests.cs | 60 +-- .../InMemoryChatHistoryProviderTests.cs | 477 ++++-------------- .../ChatClient/ChatClientAgentSessionTests.cs | 24 +- .../ChatClient/ChatClientAgentTests.cs | 25 +- ...tClientAgent_ChatHistoryManagementTests.cs | 7 +- .../TestEchoAgent.cs | 2 +- 12 files changed, 283 insertions(+), 570 deletions(-) diff --git a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs index 13e0726c38..dd451af127 100644 --- a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs @@ -86,12 +86,12 @@ namespace SampleApp internal sealed class UserInfoMemory : AIContextProvider { private readonly IChatClient _chatClient; - private readonly Func? _stateInitializer; + private readonly Func _stateInitializer; public UserInfoMemory(IChatClient chatClient, Func? stateInitializer = null) { this._chatClient = chatClient; - this._stateInitializer = stateInitializer; + this._stateInitializer = stateInitializer ?? (_ => new UserInfo()); } public UserInfo GetUserInfo(AgentSession session) @@ -103,8 +103,7 @@ public void SetUserInfo(AgentSession session, UserInfo userInfo) public override async ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) { var userInfo = context.Session?.StateBag.GetValue(nameof(UserInfoMemory)) - ?? this._stateInitializer?.Invoke(context.Session) - ?? new UserInfo(); + ?? this._stateInitializer.Invoke(context.Session); // Try and extract the user name and age from the message if we don't have it already and it's a user message. if ((userInfo.UserName is null || userInfo.UserAge is null) && context.RequestMessages.Any(x => x.Role == ChatRole.User)) @@ -127,8 +126,7 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio public override ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) { var userInfo = context.Session?.StateBag.GetValue(nameof(UserInfoMemory)) - ?? this._stateInitializer?.Invoke(context.Session) - ?? new UserInfo(); + ?? this._stateInitializer.Invoke(context.Session); StringBuilder instructions = new(); diff --git a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs index 38f87ad2cf..c3956b9c13 100644 --- a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs @@ -66,7 +66,7 @@ // Since we are using ChatCompletion which stores chat history locally, we can also add a message removal policy // that removes messages produced by the TextSearchProvider before they are added to the chat history, so that // we don't bloat chat history with all the search result messages. - ChatHistoryProviderFactory = (ctx, ct) => new ValueTask(new InMemoryChatHistoryProvider(ctx.SerializedState, ctx.JsonSerializerOptions) + ChatHistoryProviderFactory = (ctx, ct) => new ValueTask(new InMemoryChatHistoryProvider() .WithAIContextProviderMessageRemoval()), }); diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs index f9a5a1fc01..0c97334d10 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs @@ -24,7 +24,7 @@ { ChatOptions = new() { Instructions = "You are good at telling jokes." }, Name = "Joker", - ChatHistoryProviderFactory = (ctx, ct) => new ValueTask(new InMemoryChatHistoryProvider(new MessageCountingChatReducer(2), ctx.SerializedState, ctx.JsonSerializerOptions)) + ChatHistoryProviderFactory = (ctx, ct) => new ValueTask(new InMemoryChatHistoryProvider(chatReducer: new MessageCountingChatReducer(2))) }); AgentSession session = await agent.CreateSessionAsync(); diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryAgentSession.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryAgentSession.cs index 05ffafaeb9..6e964914b9 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryAgentSession.cs @@ -5,6 +5,7 @@ using System.Diagnostics; using System.Text.Json; using Microsoft.Extensions.AI; +using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.AI; @@ -38,7 +39,7 @@ public abstract class InMemoryAgentSession : AgentSession /// protected InMemoryAgentSession(InMemoryChatHistoryProvider? chatHistoryProvider = null) { - this.ChatHistoryProvider = chatHistoryProvider ?? []; + this.ChatHistoryProvider = chatHistoryProvider ?? new InMemoryChatHistoryProvider(); } /// @@ -52,7 +53,8 @@ protected InMemoryAgentSession(InMemoryChatHistoryProvider? chatHistoryProvider /// protected InMemoryAgentSession(IEnumerable messages) { - this.ChatHistoryProvider = [.. messages]; + this.ChatHistoryProvider = new InMemoryChatHistoryProvider(); + this.ChatHistoryProvider.GetMessages(this).AddRange(Throw.IfNull(messages)); } /// @@ -85,7 +87,12 @@ protected InMemoryAgentSession( this.ChatHistoryProvider = chatHistoryProviderFactory?.Invoke(state?.ChatHistoryProviderState ?? default, jsonSerializerOptions) ?? - new InMemoryChatHistoryProvider(state?.ChatHistoryProviderState ?? default, jsonSerializerOptions); + new InMemoryChatHistoryProvider(); + + if (state?.StateBag is { ValueKind: JsonValueKind.Object } stateBagElement) + { + this.StateBag = AgentSessionStateBag.Deserialize(stateBagElement); + } } /// @@ -105,6 +112,7 @@ protected internal virtual JsonElement Serialize(JsonSerializerOptions? jsonSeri var state = new InMemoryAgentSessionState { ChatHistoryProviderState = chatHistoryProviderState, + StateBag = this.StateBag.Serialize(), }; return JsonSerializer.SerializeToElement(state, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(InMemoryAgentSessionState))); @@ -115,10 +123,12 @@ protected internal virtual JsonElement Serialize(JsonSerializerOptions? jsonSeri base.GetService(serviceType, serviceKey) ?? this.ChatHistoryProvider?.GetService(serviceType, serviceKey); [DebuggerBrowsable(DebuggerBrowsableState.Never)] - private string DebuggerDisplay => $"Count = {this.ChatHistoryProvider.Count}"; + private string DebuggerDisplay => $"Count = {this.ChatHistoryProvider.GetMessages(this).Count}"; internal sealed class InMemoryAgentSessionState { public JsonElement? ChatHistoryProviderState { get; set; } + + public JsonElement? StateBag { get; set; } } } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs index ab408c6a5e..638d6c2ec4 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs @@ -1,9 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections; using System.Collections.Generic; -using System.Diagnostics; using System.Linq; using System.Text.Json; using System.Threading; @@ -14,99 +12,58 @@ namespace Microsoft.Agents.AI; /// -/// Provides an in-memory implementation of with support for message reduction and collection semantics. +/// Provides an in-memory implementation of with support for message reduction. /// /// /// -/// stores chat messages entirely in local memory, providing fast access and manipulation -/// capabilities. It implements both for agent integration and -/// for direct collection manipulation. +/// stores chat messages in the , +/// providing fast access and manipulation capabilities integrated with session state management. /// /// /// This maintains all messages in memory. For long-running conversations or high-volume scenarios, consider using /// message reduction strategies or alternative storage implementations. /// /// -[DebuggerDisplay("Count = {Count}")] -[DebuggerTypeProxy(typeof(DebugView))] -public sealed class InMemoryChatHistoryProvider : ChatHistoryProvider, IList, IReadOnlyList +public sealed class InMemoryChatHistoryProvider : ChatHistoryProvider { - private List _messages; + private const string DefaultStateBagKey = "InMemoryChatHistoryProvider.State"; - /// - /// Initializes a new instance of the class. - /// - /// - /// This constructor creates a basic in-memory without message reduction capabilities. - /// Messages will be stored exactly as added without any automatic processing or reduction. - /// - public InMemoryChatHistoryProvider() - { - this._messages = []; - } - - /// - /// Initializes a new instance of the class from previously serialized state. - /// - /// A representing the serialized state of the provider. - /// Optional settings for customizing the JSON deserialization process. - /// The is not a valid JSON object or cannot be deserialized. - /// - /// This constructor enables restoration of messages from previously saved state, allowing - /// conversation history to be preserved across application restarts or migrated between instances. - /// - public InMemoryChatHistoryProvider(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null) - : this(null, serializedState, jsonSerializerOptions, ChatReducerTriggerEvent.BeforeMessagesRetrieval) - { - } + private readonly string _stateKey; + private readonly Func _stateInitializer; /// /// Initializes a new instance of the class. /// + /// + /// An optional delegate that initializes the provider state on the first invocation. + /// If , a default initializer that creates an empty state will be used. + /// /// - /// A instance used to process, reduce, or optimize chat messages. + /// An optional instance used to process, reduce, or optimize chat messages. /// This can be used to implement strategies like message summarization, truncation, or cleanup. /// /// /// Specifies when the message reducer should be invoked. The default is , /// which applies reduction logic when messages are retrieved for agent consumption. /// - /// is . + /// + /// An optional key to use for storing the state in the . + /// If , a default key will be used. + /// /// /// Message reducers enable automatic management of message storage by implementing strategies to /// keep memory usage under control while preserving important conversation context. /// - public InMemoryChatHistoryProvider(IChatReducer chatReducer, ChatReducerTriggerEvent reducerTriggerEvent = ChatReducerTriggerEvent.BeforeMessagesRetrieval) - : this(chatReducer, default, null, reducerTriggerEvent) - { - Throw.IfNull(chatReducer); - } - - /// - /// Initializes a new instance of the class, with an existing state from a serialized JSON element. - /// - /// An optional instance used to process or reduce chat messages. If null, no reduction logic will be applied. - /// A representing the serialized state of the provider. - /// Optional settings for customizing the JSON deserialization process. - /// The event that should trigger the reducer invocation. - public InMemoryChatHistoryProvider(IChatReducer? chatReducer, JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, ChatReducerTriggerEvent reducerTriggerEvent = ChatReducerTriggerEvent.BeforeMessagesRetrieval) + public InMemoryChatHistoryProvider( + Func? stateInitializer = null, + IChatReducer? chatReducer = null, + ChatReducerTriggerEvent reducerTriggerEvent = ChatReducerTriggerEvent.BeforeMessagesRetrieval, + string? stateKey = null) { + this._stateInitializer = stateInitializer ?? (_ => new State()); this.ChatReducer = chatReducer; this.ReducerTriggerEvent = reducerTriggerEvent; - - if (serializedState.ValueKind is JsonValueKind.Object) - { - var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; - var state = serializedState.Deserialize( - jso.GetTypeInfo(typeof(State))) as State; - if (state?.Messages is { } messages) - { - this._messages = messages; - return; - } - } - - this._messages = []; + this._stateKey = stateKey ?? DefaultStateBagKey; } /// @@ -119,17 +76,48 @@ public InMemoryChatHistoryProvider(IChatReducer? chatReducer, JsonElement serial /// public ChatReducerTriggerEvent ReducerTriggerEvent { get; } - /// - public int Count => this._messages.Count; + /// + /// Gets the chat messages stored for the specified session. + /// + /// The agent session containing the state. + /// A list of chat messages, or an empty list if no state is found. + public List GetMessages(AgentSession? session) + => this.GetOrInitializeState(session).Messages; - /// - public bool IsReadOnly => ((IList)this._messages).IsReadOnly; + /// + /// Sets the chat messages for the specified session. + /// + /// The agent session containing the state. + /// The messages to store. + /// is . + public void SetMessages(AgentSession? session, List messages) + { + _ = Throw.IfNull(messages); - /// - public ChatMessage this[int index] + var state = this.GetOrInitializeState(session); + state.Messages = messages; + } + + /// + /// Gets the state from the session's StateBag, or initializes it using the state initializer if not present. + /// + /// The agent session containing the StateBag. + /// The provider state, or null if no session is available. + private State GetOrInitializeState(AgentSession? session) { - get => this._messages[index]; - set => this._messages[index] = value; + var state = session?.StateBag.GetValue(this._stateKey, AgentAbstractionsJsonUtilities.DefaultOptions); + if (state is not null) + { + return state; + } + + state = this._stateInitializer(session); + if (session is not null) + { + session.StateBag.SetValue(this._stateKey, state, AgentAbstractionsJsonUtilities.DefaultOptions); + } + + return state; } /// @@ -137,12 +125,14 @@ public override async ValueTask> InvokingAsync(Invoking { _ = Throw.IfNull(context); + var state = this.GetOrInitializeState(context.Session); + if (this.ReducerTriggerEvent is ChatReducerTriggerEvent.BeforeMessagesRetrieval && this.ChatReducer is not null) { - this._messages = (await this.ChatReducer.ReduceAsync(this._messages, cancellationToken).ConfigureAwait(false)).ToList(); + state.Messages = (await this.ChatReducer.ReduceAsync(state.Messages, cancellationToken).ConfigureAwait(false)).ToList(); } - return this._messages; + return state.Messages; } /// @@ -155,70 +145,42 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio return; } + var state = this.GetOrInitializeState(context.Session); + // Add request, AI context provider, and response messages to the provider var allNewMessages = context.RequestMessages.Concat(context.AIContextProviderMessages ?? []).Concat(context.ResponseMessages ?? []); - this._messages.AddRange(allNewMessages); + state.Messages.AddRange(allNewMessages); if (this.ReducerTriggerEvent is ChatReducerTriggerEvent.AfterMessageAdded && this.ChatReducer is not null) { - this._messages = (await this.ChatReducer.ReduceAsync(this._messages, cancellationToken).ConfigureAwait(false)).ToList(); + state.Messages = (await this.ChatReducer.ReduceAsync(state.Messages, cancellationToken).ConfigureAwait(false)).ToList(); } } - /// + /// + /// Serializes the current provider state to a . + /// + /// Optional serializer options (ignored, source generated context is used). + /// An empty object. + /// + /// State is now stored in the and serialized as part of the session. + /// public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { - State state = new() - { - Messages = this._messages, - }; - - var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; - return JsonSerializer.SerializeToElement(state, jso.GetTypeInfo(typeof(State))); + // State is now stored in the session StateBag, so there is nothing to serialize here. + // Return an empty JSON object. + using var doc = JsonDocument.Parse("{}"); + return doc.RootElement.Clone(); } - /// - public int IndexOf(ChatMessage item) - => this._messages.IndexOf(item); - - /// - public void Insert(int index, ChatMessage item) - => this._messages.Insert(index, item); - - /// - public void RemoveAt(int index) - => this._messages.RemoveAt(index); - - /// - public void Add(ChatMessage item) - => this._messages.Add(item); - - /// - public void Clear() - => this._messages.Clear(); - - /// - public bool Contains(ChatMessage item) - => this._messages.Contains(item); - - /// - public void CopyTo(ChatMessage[] array, int arrayIndex) - => this._messages.CopyTo(array, arrayIndex); - - /// - public bool Remove(ChatMessage item) - => this._messages.Remove(item); - - /// - public IEnumerator GetEnumerator() - => this._messages.GetEnumerator(); - - /// - IEnumerator IEnumerable.GetEnumerator() - => this.GetEnumerator(); - - internal sealed class State + /// + /// Represents the state of a stored in the . + /// + public sealed class State { + /// + /// Gets or sets the list of chat messages. + /// public List Messages { get; set; } = []; } @@ -239,10 +201,4 @@ public enum ChatReducerTriggerEvent /// BeforeMessagesRetrieval } - - private sealed class DebugView(InMemoryChatHistoryProvider provider) - { - [DebuggerBrowsable(DebuggerBrowsableState.RootHidden)] - public ChatMessage[] Items => provider._messages.ToArray(); - } } diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs index cb791bd564..a5ace55884 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs @@ -161,7 +161,7 @@ internal static async Task DeserializeAsync( session._chatHistoryProvider = chatHistoryProviderFactory is not null ? await chatHistoryProviderFactory.Invoke(state?.ChatHistoryProviderState ?? default, jsonSerializerOptions, cancellationToken).ConfigureAwait(false) - : new InMemoryChatHistoryProvider(state?.ChatHistoryProviderState ?? default, jsonSerializerOptions); // default to an in-memory ChatHistoryProvider + : new InMemoryChatHistoryProvider(); // default to an in-memory ChatHistoryProvider return session; } @@ -193,7 +193,7 @@ internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = nu [DebuggerBrowsable(DebuggerBrowsableState.Never)] private string DebuggerDisplay => this.ConversationId is { } conversationId ? $"ConversationId = {conversationId}" : - this._chatHistoryProvider is InMemoryChatHistoryProvider inMemoryChatHistoryProvider ? $"Count = {inMemoryChatHistoryProvider.Count}" : + this._chatHistoryProvider is InMemoryChatHistoryProvider ? "InMemoryChatHistoryProvider" : this._chatHistoryProvider is { } chatHistoryProvider ? $"ChatHistoryProvider = {chatHistoryProvider.GetType().Name}" : "Count = 0"; diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryAgentSessionTests.cs index a3a4bf7e0e..725729897c 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryAgentSessionTests.cs @@ -22,23 +22,23 @@ public void Constructor_SetsDefaultChatHistoryProvider() var session = new TestInMemoryAgentSession(); // Assert - Assert.NotNull(session.GetChatHistoryProvider()); - Assert.Empty(session.GetChatHistoryProvider()); + Assert.NotNull(session.ChatHistoryProvider); + Assert.Empty(session.ChatHistoryProvider.GetMessages(session)); } [Fact] public void Constructor_WithChatHistoryProvider_SetsProperty() { // Arrange - InMemoryChatHistoryProvider provider = [new(ChatRole.User, "Hello")]; - - // Act + var provider = new InMemoryChatHistoryProvider(); var session = new TestInMemoryAgentSession(provider); + provider.SetMessages(session, [new(ChatRole.User, "Hello")]); - // Assert - Assert.Same(provider, session.GetChatHistoryProvider()); - Assert.Single(session.GetChatHistoryProvider()); - Assert.Equal("Hello", session.GetChatHistoryProvider()[0].Text); + // Act & Assert + Assert.Same(provider, session.ChatHistoryProvider); + var messages = session.ChatHistoryProvider.GetMessages(session); + Assert.Single(messages); + Assert.Equal("Hello", messages[0].Text); } [Fact] @@ -51,27 +51,26 @@ public void Constructor_WithMessages_SetsProperty() var session = new TestInMemoryAgentSession(messages); // Assert - Assert.NotNull(session.GetChatHistoryProvider()); - Assert.Single(session.GetChatHistoryProvider()); - Assert.Equal("Hi", session.GetChatHistoryProvider()[0].Text); + Assert.NotNull(session.ChatHistoryProvider); + Assert.Single(session.ChatHistoryProvider.GetMessages(session)); + Assert.Equal("Hi", session.ChatHistoryProvider.GetMessages(session)[0].Text); } [Fact] public void Constructor_WithSerializedState_SetsProperty() { - // Arrange - InMemoryChatHistoryProvider provider = [new(ChatRole.User, "TestMsg")]; - var providerState = provider.Serialize(); - var sessionStateWrapper = new InMemoryAgentSession.InMemoryAgentSessionState { ChatHistoryProviderState = providerState }; - var json = JsonSerializer.SerializeToElement(sessionStateWrapper, TestJsonSerializerContext.Default.InMemoryAgentSessionState); + // Arrange - create a session with a StateBag containing chat history + var originalSession = new TestInMemoryAgentSession([new(ChatRole.User, "TestMsg")]); + var json = originalSession.Serialize(); // Act var session = new TestInMemoryAgentSession(json); // Assert - Assert.NotNull(session.GetChatHistoryProvider()); - Assert.Single(session.GetChatHistoryProvider()); - Assert.Equal("TestMsg", session.GetChatHistoryProvider()[0].Text); + Assert.NotNull(session.ChatHistoryProvider); + var messages = session.ChatHistoryProvider.GetMessages(session); + Assert.Single(messages); + Assert.Equal("TestMsg", messages[0].Text); } [Fact] @@ -99,16 +98,20 @@ public void Serialize_ReturnsCorrectJson_WhenMessagesExist() // Assert Assert.Equal(JsonValueKind.Object, json.ValueKind); - Assert.True(json.TryGetProperty("chatHistoryProviderState", out var providerStateProperty)); + Assert.True(json.TryGetProperty("stateBag", out var stateBagProperty)); + Assert.Equal(JsonValueKind.Object, stateBagProperty.ValueKind); + Assert.True(stateBagProperty.TryGetProperty("InMemoryChatHistoryProvider.State", out var providerStateProperty)); Assert.Equal(JsonValueKind.Object, providerStateProperty.ValueKind); - Assert.True(providerStateProperty.TryGetProperty("messages", out var messagesProperty)); + Assert.True(providerStateProperty.TryGetProperty("jsonValue", out var jsonValueProperty)); + Assert.Equal(JsonValueKind.Object, jsonValueProperty.ValueKind); + Assert.True(jsonValueProperty.TryGetProperty("messages", out var messagesProperty)); Assert.Equal(JsonValueKind.Array, messagesProperty.ValueKind); var messagesList = messagesProperty.EnumerateArray().ToList(); Assert.Single(messagesList); } [Fact] - public void Serialize_ReturnsEmptyMessages_WhenNoMessages() + public void Serialize_ReturnsEmptyStateBag_WhenNoMessages() { // Arrange var session = new TestInMemoryAgentSession(); @@ -118,11 +121,9 @@ public void Serialize_ReturnsEmptyMessages_WhenNoMessages() // Assert Assert.Equal(JsonValueKind.Object, json.ValueKind); - Assert.True(json.TryGetProperty("chatHistoryProviderState", out var providerStateProperty)); + Assert.True(json.TryGetProperty("stateBag", out var providerStateProperty)); Assert.Equal(JsonValueKind.Object, providerStateProperty.ValueKind); - Assert.True(providerStateProperty.TryGetProperty("messages", out var messagesProperty)); - Assert.Equal(JsonValueKind.Array, messagesProperty.ValueKind); - Assert.Empty(messagesProperty.EnumerateArray()); + Assert.False(providerStateProperty.EnumerateObject().Any()); } #endregion @@ -137,8 +138,8 @@ public void GetService_RequestingChatHistoryProvider_ReturnsChatHistoryProvider( // Act & Assert Assert.NotNull(session.GetService(typeof(ChatHistoryProvider))); - Assert.Same(session.GetChatHistoryProvider(), session.GetService(typeof(ChatHistoryProvider))); - Assert.Same(session.GetChatHistoryProvider(), session.GetService(typeof(InMemoryChatHistoryProvider))); + Assert.Same(session.ChatHistoryProvider, session.GetService(typeof(ChatHistoryProvider))); + Assert.Same(session.ChatHistoryProvider, session.GetService(typeof(InMemoryChatHistoryProvider))); } #endregion @@ -150,6 +151,5 @@ public TestInMemoryAgentSession() { } public TestInMemoryAgentSession(InMemoryChatHistoryProvider? provider) : base(provider) { } public TestInMemoryAgentSession(IEnumerable messages) : base(messages) { } public TestInMemoryAgentSession(JsonElement serializedSessionState) : base(serializedSessionState) { } - public InMemoryChatHistoryProvider GetChatHistoryProvider() => this.ChatHistoryProvider; } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs index bf8ff998b9..d3346b3ed7 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs @@ -3,9 +3,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Text.Encodings.Web; -using System.Text.Json; -using System.Text.Json.Serialization.Metadata; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -19,19 +16,15 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; public class InMemoryChatHistoryProviderTests { private static readonly AIAgent s_mockAgent = new Mock().Object; - private static readonly AgentSession s_mockSession = new Mock().Object; - [Fact] - public void Constructor_Throws_ForNullReducer() => - // Arrange & Act & Assert - Assert.Throws(() => new InMemoryChatHistoryProvider(null!)); + private static AgentSession CreateMockSession() => new Mock().Object; [Fact] public void Constructor_DefaultsToBeforeMessageRetrieval_ForNotProvidedTriggerEvent() { // Arrange & Act var reducerMock = new Mock(); - var provider = new InMemoryChatHistoryProvider(reducerMock.Object); + var provider = new InMemoryChatHistoryProvider(chatReducer: reducerMock.Object); // Assert Assert.Equal(InMemoryChatHistoryProvider.ChatReducerTriggerEvent.BeforeMessagesRetrieval, provider.ReducerTriggerEvent); @@ -42,7 +35,7 @@ public void Constructor_Arguments_SetOnPropertiesCorrectly() { // Arrange & Act var reducerMock = new Mock(); - var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded); + var provider = new InMemoryChatHistoryProvider(chatReducer: reducerMock.Object, reducerTriggerEvent: InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded); // Assert Assert.Same(reducerMock.Object, provider.ChatReducer); @@ -52,6 +45,9 @@ public void Constructor_Arguments_SetOnPropertiesCorrectly() [Fact] public async Task InvokedAsyncAddsMessagesAsync() { + var session = CreateMockSession(); + + // Arrange var requestMessages = new List { new(ChatRole.User, "Hello") @@ -70,440 +66,156 @@ public async Task InvokedAsyncAddsMessagesAsync() }; var provider = new InMemoryChatHistoryProvider(); - provider.Add(providerMessages[0]); - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, providerMessages) + provider.SetMessages(session, [providerMessages[0]]); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages, providerMessages) { AIContextProviderMessages = aiContextProviderMessages, ResponseMessages = responseMessages }; await provider.InvokedAsync(context, CancellationToken.None); - Assert.Equal(4, provider.Count); - Assert.Equal("original instructions", provider[0].Text); - Assert.Equal("Hello", provider[1].Text); - Assert.Equal("additional context", provider[2].Text); - Assert.Equal("Hi there!", provider[3].Text); + // Assert + var messages = provider.GetMessages(session); + Assert.Equal(4, messages.Count); + Assert.Equal("original instructions", messages[0].Text); + Assert.Equal("Hello", messages[1].Text); + Assert.Equal("additional context", messages[2].Text); + Assert.Equal("Hi there!", messages[3].Text); } [Fact] public async Task InvokedAsyncWithEmptyDoesNotFailAsync() { + var session = CreateMockSession(); + + // Arrange var provider = new InMemoryChatHistoryProvider(); - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [], []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [], []); await provider.InvokedAsync(context, CancellationToken.None); - Assert.Empty(provider); + // Assert + Assert.Empty(provider.GetMessages(session)); } [Fact] public async Task InvokingAsyncReturnsAllMessagesAsync() { - var provider = new InMemoryChatHistoryProvider - { + var session = CreateMockSession(); + + // Arrange + var provider = new InMemoryChatHistoryProvider(); + provider.SetMessages(session, + [ new ChatMessage(ChatRole.User, "Test1"), new ChatMessage(ChatRole.Assistant, "Test2") - }; + ]); - var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var result = (await provider.InvokingAsync(context, CancellationToken.None)).ToList(); + // Assert Assert.Equal(2, result.Count); Assert.Contains(result, m => m.Text == "Test1"); Assert.Contains(result, m => m.Text == "Test2"); } [Fact] - public async Task DeserializeConstructorWithEmptyElementAsync() - { - var emptyObject = JsonSerializer.Deserialize("{}", TestJsonSerializerContext.Default.JsonElement); - - var newProvider = new InMemoryChatHistoryProvider(emptyObject); - - Assert.Empty(newProvider); - } - - [Fact] - public async Task SerializeAndDeserializeConstructorRoundtripsAsync() - { - var provider = new InMemoryChatHistoryProvider - { - new ChatMessage(ChatRole.User, "A"), - new ChatMessage(ChatRole.Assistant, "B") - }; - - var jsonElement = provider.Serialize(); - var newProvider = new InMemoryChatHistoryProvider(jsonElement); - - Assert.Equal(2, newProvider.Count); - Assert.Equal("A", newProvider[0].Text); - Assert.Equal("B", newProvider[1].Text); - } - - [Fact] - public async Task SerializeAndDeserializeConstructorRoundtripsWithCustomAIContentAsync() - { - JsonSerializerOptions options = new(TestJsonSerializerContext.Default.Options) - { - TypeInfoResolver = JsonTypeInfoResolver.Combine(AgentAbstractionsJsonUtilities.DefaultOptions.TypeInfoResolver, TestJsonSerializerContext.Default), - Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping, - }; - options.AddAIContentType(typeDiscriminatorId: "testContent"); - - var provider = new InMemoryChatHistoryProvider - { - new ChatMessage(ChatRole.User, [new TestAIContent("foo data")]), - }; - - var jsonElement = provider.Serialize(options); - var newProvider = new InMemoryChatHistoryProvider(jsonElement, options); - - Assert.Single(newProvider); - var actualTestAIContent = Assert.IsType(newProvider[0].Contents[0]); - Assert.Equal("foo data", actualTestAIContent.TestData); - } - - [Fact] - public async Task SerializeAndDeserializeWorksWithExperimentalContentTypesAsync() - { - var provider = new InMemoryChatHistoryProvider - { - new ChatMessage(ChatRole.User, [new FunctionApprovalRequestContent("call123", new FunctionCallContent("call123", "some_func"))]), - new ChatMessage(ChatRole.Assistant, [new FunctionApprovalResponseContent("call123", true, new FunctionCallContent("call123", "some_func"))]) - }; - - var jsonElement = provider.Serialize(); - var newProvider = new InMemoryChatHistoryProvider(jsonElement); - - Assert.Equal(2, newProvider.Count); - Assert.IsType(newProvider[0].Contents[0]); - Assert.IsType(newProvider[1].Contents[0]); - } - - [Fact] - public async Task InvokedAsyncWithEmptyMessagesDoesNotChangeProviderAsync() - { - var provider = new InMemoryChatHistoryProvider(); - var messages = new List(); - - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); - await provider.InvokedAsync(context, CancellationToken.None); - - Assert.Empty(provider); - } - - [Fact] - public async Task InvokedAsync_WithNullContext_ThrowsArgumentNullExceptionAsync() + public void StateInitializer_IsInvoked_WhenSessionHasNoState() { // Arrange - var provider = new InMemoryChatHistoryProvider(); - - // Act & Assert - await Assert.ThrowsAsync(() => provider.InvokedAsync(null!, CancellationToken.None).AsTask()); - } - - [Fact] - public void DeserializeContructor_WithNullSerializedState_CreatesEmptyProvider() - { - // Act - var provider = new InMemoryChatHistoryProvider(new JsonElement()); - - // Assert - Assert.Empty(provider); - } - - [Fact] - public async Task DeserializeContructor_WithEmptyMessages_DoesNotAddMessagesAsync() - { - // Arrange - var stateWithEmptyMessages = JsonSerializer.SerializeToElement( - new Dictionary { ["messages"] = new List() }, - TestJsonSerializerContext.Default.IDictionaryStringObject); - - // Act - var provider = new InMemoryChatHistoryProvider(stateWithEmptyMessages); - - // Assert - Assert.Empty(provider); - } - - [Fact] - public async Task DeserializeConstructor_WithNullMessages_DoesNotAddMessagesAsync() - { - // Arrange - var stateWithNullMessages = JsonSerializer.SerializeToElement( - new Dictionary { ["messages"] = null! }, - TestJsonSerializerContext.Default.DictionaryStringObject); - - // Act - var provider = new InMemoryChatHistoryProvider(stateWithNullMessages); - - // Assert - Assert.Empty(provider); - } - - [Fact] - public async Task DeserializeConstructor_WithValidMessages_AddsMessagesAsync() - { - // Arrange - var messages = new List + var initialMessages = new List { - new(ChatRole.User, "User message"), - new(ChatRole.Assistant, "Assistant message") + new(ChatRole.User, "Initial message") }; - var state = new Dictionary { ["messages"] = messages }; - var serializedState = JsonSerializer.SerializeToElement( - state, - TestJsonSerializerContext.Default.DictionaryStringObject); + var provider = new InMemoryChatHistoryProvider( + stateInitializer: _ => new InMemoryChatHistoryProvider.State { Messages = initialMessages }); // Act - var provider = new InMemoryChatHistoryProvider(serializedState); + var messages = provider.GetMessages(CreateMockSession()); // Assert - Assert.Equal(2, provider.Count); - Assert.Equal("User message", provider[0].Text); - Assert.Equal("Assistant message", provider[1].Text); - } - - [Fact] - public void IndexerGet_ReturnsCorrectMessage() - { - // Arrange - var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - provider.Add(message1); - provider.Add(message2); - - // Act & Assert - Assert.Same(message1, provider[0]); - Assert.Same(message2, provider[1]); + Assert.Single(messages); + Assert.Equal("Initial message", messages[0].Text); } [Fact] - public void IndexerSet_UpdatesMessage() + public void GetMessages_ReturnsEmptyList_WhenNullSession() { // Arrange var provider = new InMemoryChatHistoryProvider(); - var originalMessage = new ChatMessage(ChatRole.User, "Original"); - var newMessage = new ChatMessage(ChatRole.User, "Updated"); - provider.Add(originalMessage); // Act - provider[0] = newMessage; + var messages = provider.GetMessages(null); // Assert - Assert.Same(newMessage, provider[0]); - Assert.Equal("Updated", provider[0].Text); + Assert.Empty(messages); } [Fact] - public void IsReadOnly_ReturnsFalse() + public void SetMessages_ThrowsForNullMessages() { // Arrange var provider = new InMemoryChatHistoryProvider(); // Act & Assert - Assert.False(provider.IsReadOnly); + Assert.Throws(() => provider.SetMessages(CreateMockSession(), null!)); } [Fact] - public void IndexOf_ReturnsCorrectIndex() + public void SetMessages_UpdatesState() { - // Arrange - var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - var message3 = new ChatMessage(ChatRole.User, "Third"); - provider.Add(message1); - provider.Add(message2); - - // Act & Assert - Assert.Equal(0, provider.IndexOf(message1)); - Assert.Equal(1, provider.IndexOf(message2)); - Assert.Equal(-1, provider.IndexOf(message3)); // Not in provider - } + var session = CreateMockSession(); - [Fact] - public void Insert_InsertsMessageAtCorrectIndex() - { // Arrange var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - var insertMessage = new ChatMessage(ChatRole.User, "Inserted"); - provider.Add(message1); - provider.Add(message2); - - // Act - provider.Insert(1, insertMessage); - - // Assert - Assert.Equal(3, provider.Count); - Assert.Same(message1, provider[0]); - Assert.Same(insertMessage, provider[1]); - Assert.Same(message2, provider[2]); - } - - [Fact] - public void RemoveAt_RemovesMessageAtIndex() - { - // Arrange - var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - var message3 = new ChatMessage(ChatRole.User, "Third"); - provider.Add(message1); - provider.Add(message2); - provider.Add(message3); - - // Act - provider.RemoveAt(1); - - // Assert - Assert.Equal(2, provider.Count); - Assert.Same(message1, provider[0]); - Assert.Same(message3, provider[1]); - } - - [Fact] - public void Clear_RemovesAllMessages() - { - // Arrange - var provider = new InMemoryChatHistoryProvider + var messages = new List { - new ChatMessage(ChatRole.User, "First"), - new ChatMessage(ChatRole.Assistant, "Second") + new(ChatRole.User, "Hello"), + new(ChatRole.Assistant, "World") }; // Act - provider.Clear(); - - // Assert - Assert.Empty(provider); - } - - [Fact] - public void Contains_ReturnsTrueForExistingMessage() - { - // Arrange - var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - provider.Add(message1); - - // Act & Assert - Assert.Contains(message1, provider); - Assert.DoesNotContain(message2, provider); - } - - [Fact] - public void CopyTo_CopiesMessagesToArray() - { - // Arrange - var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - provider.Add(message1); - provider.Add(message2); - var array = new ChatMessage[4]; - - // Act - provider.CopyTo(array, 1); - - // Assert - Assert.Null(array[0]); - Assert.Same(message1, array[1]); - Assert.Same(message2, array[2]); - Assert.Null(array[3]); - } - - [Fact] - public void Remove_RemovesSpecificMessage() - { - // Arrange - var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - var message3 = new ChatMessage(ChatRole.User, "Third"); - provider.Add(message1); - provider.Add(message2); - provider.Add(message3); - - // Act - var removed = provider.Remove(message2); + provider.SetMessages(session, messages); + var retrieved = provider.GetMessages(session); // Assert - Assert.True(removed); - Assert.Equal(2, provider.Count); - Assert.Same(message1, provider[0]); - Assert.Same(message3, provider[1]); + Assert.Equal(2, retrieved.Count); + Assert.Equal("Hello", retrieved[0].Text); + Assert.Equal("World", retrieved[1].Text); } [Fact] - public void Remove_ReturnsFalseForNonExistentMessage() + public async Task InvokedAsyncWithEmptyMessagesDoesNotChangeProviderAsync() { - // Arrange - var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - provider.Add(message1); + var session = CreateMockSession(); - // Act - var removed = provider.Remove(message2); - - // Assert - Assert.False(removed); - Assert.Single(provider); - } - - [Fact] - public void GetEnumerator_Generic_ReturnsAllMessages() - { // Arrange var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - provider.Add(message1); - provider.Add(message2); - - // Act var messages = new List(); - messages.AddRange(provider); + + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, messages, []); + await provider.InvokedAsync(context, CancellationToken.None); // Assert - Assert.Equal(2, messages.Count); - Assert.Same(message1, messages[0]); - Assert.Same(message2, messages[1]); + Assert.Empty(provider.GetMessages(session)); } [Fact] - public void GetEnumerator_NonGeneric_ReturnsAllMessages() + public async Task InvokedAsync_WithNullContext_ThrowsArgumentNullExceptionAsync() { // Arrange var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - provider.Add(message1); - provider.Add(message2); - // Act - var messages = new List(); - var enumerator = ((System.Collections.IEnumerable)provider).GetEnumerator(); - while (enumerator.MoveNext()) - { - messages.Add((ChatMessage)enumerator.Current); - } - - // Assert - Assert.Equal(2, messages.Count); - Assert.Same(message1, messages[0]); - Assert.Same(message2, messages[1]); + // Act & Assert + await Assert.ThrowsAsync(() => provider.InvokedAsync(null!, CancellationToken.None).AsTask()); } [Fact] public async Task AddMessagesAsync_WithReducer_AfterMessageAdded_InvokesReducerAsync() { + var session = CreateMockSession(); + // Arrange var originalMessages = new List { @@ -520,21 +232,24 @@ public async Task AddMessagesAsync_WithReducer_AfterMessageAdded_InvokesReducerA .Setup(r => r.ReduceAsync(It.Is>(x => x.SequenceEqual(originalMessages)), It.IsAny())) .ReturnsAsync(reducedMessages); - var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded); + var provider = new InMemoryChatHistoryProvider(chatReducer: reducerMock.Object, reducerTriggerEvent: InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded); // Act - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, originalMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, originalMessages, []); await provider.InvokedAsync(context, CancellationToken.None); // Assert - Assert.Single(provider); - Assert.Equal("Reduced", provider[0].Text); + var messages = provider.GetMessages(session); + Assert.Single(messages); + Assert.Equal("Reduced", messages[0].Text); reducerMock.Verify(r => r.ReduceAsync(It.Is>(x => x.SequenceEqual(originalMessages)), It.IsAny()), Times.Once); } [Fact] public async Task GetMessagesAsync_WithReducer_BeforeMessagesRetrieval_InvokesReducerAsync() { + var session = CreateMockSession(); + // Arrange var originalMessages = new List { @@ -551,15 +266,12 @@ public async Task GetMessagesAsync_WithReducer_BeforeMessagesRetrieval_InvokesRe .Setup(r => r.ReduceAsync(It.Is>(x => x.SequenceEqual(originalMessages)), It.IsAny())) .ReturnsAsync(reducedMessages); - var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.BeforeMessagesRetrieval); + var provider = new InMemoryChatHistoryProvider(chatReducer: reducerMock.Object, reducerTriggerEvent: InMemoryChatHistoryProvider.ChatReducerTriggerEvent.BeforeMessagesRetrieval); // Add messages directly to the provider for this test - foreach (var msg in originalMessages) - { - provider.Add(msg); - } + provider.SetMessages(session, new List(originalMessages)); // Act - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, Array.Empty()); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, Array.Empty()); var result = (await provider.InvokingAsync(invokingContext, CancellationToken.None)).ToList(); // Assert @@ -571,6 +283,8 @@ public async Task GetMessagesAsync_WithReducer_BeforeMessagesRetrieval_InvokesRe [Fact] public async Task AddMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeReducerAsync() { + var session = CreateMockSession(); + // Arrange var originalMessages = new List { @@ -579,21 +293,24 @@ public async Task AddMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeRedu var reducerMock = new Mock(); - var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.BeforeMessagesRetrieval); + var provider = new InMemoryChatHistoryProvider(chatReducer: reducerMock.Object, reducerTriggerEvent: InMemoryChatHistoryProvider.ChatReducerTriggerEvent.BeforeMessagesRetrieval); // Act - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, originalMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, originalMessages, []); await provider.InvokedAsync(context, CancellationToken.None); // Assert - Assert.Single(provider); - Assert.Equal("Hello", provider[0].Text); + var messages = provider.GetMessages(session); + Assert.Single(messages); + Assert.Equal("Hello", messages[0].Text); reducerMock.Verify(r => r.ReduceAsync(It.IsAny>(), It.IsAny()), Times.Never); } [Fact] public async Task GetMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeReducerAsync() { + var session = CreateMockSession(); + // Arrange var originalMessages = new List { @@ -602,13 +319,11 @@ public async Task GetMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeRedu var reducerMock = new Mock(); - var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded) - { - originalMessages[0] - }; + var provider = new InMemoryChatHistoryProvider(chatReducer: reducerMock.Object, reducerTriggerEvent: InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded); + provider.SetMessages(session, new List(originalMessages)); // Act - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, Array.Empty()); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, Array.Empty()); var result = (await provider.InvokingAsync(invokingContext, CancellationToken.None)).ToList(); // Assert @@ -620,6 +335,8 @@ public async Task GetMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeRedu [Fact] public async Task InvokedAsync_WithException_DoesNotAddMessagesAsync() { + var session = CreateMockSession(); + // Arrange var provider = new InMemoryChatHistoryProvider(); var requestMessages = new List @@ -630,7 +347,7 @@ public async Task InvokedAsync_WithException_DoesNotAddMessagesAsync() { new(ChatRole.Assistant, "Hi there!") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages, []) { ResponseMessages = responseMessages, InvokeException = new InvalidOperationException("Test exception") @@ -640,7 +357,7 @@ public async Task InvokedAsync_WithException_DoesNotAddMessagesAsync() await provider.InvokedAsync(context, CancellationToken.None); // Assert - Assert.Empty(provider); + Assert.Empty(provider.GetMessages(session)); } [Fact] @@ -653,6 +370,20 @@ public async Task InvokingAsync_WithNullContext_ThrowsArgumentNullExceptionAsync await Assert.ThrowsAsync(() => provider.InvokingAsync(null!, CancellationToken.None).AsTask()); } + [Fact] + public void Serialize_ReturnsEmptyJsonObject() + { + // Arrange + var provider = new InMemoryChatHistoryProvider(); + + // Act + var result = provider.Serialize(); + + // Assert + Assert.Equal(System.Text.Json.JsonValueKind.Object, result.ValueKind); + Assert.Empty(result.EnumerateObject()); + } + public class TestAIContent(string testData) : AIContent { public string TestData => testData; diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs index 2acfaf1a10..977b9077d7 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs @@ -98,7 +98,11 @@ public async Task VerifyDeserializeWithMessagesAsync() // Arrange var json = JsonSerializer.Deserialize(""" { - "chatHistoryProviderState": { "messages": [{"authorName": "testAuthor"}] } + "stateBag": { + "InMemoryChatHistoryProvider.State": { + "jsonValue": { "messages": [{"authorName": "testAuthor"}] } + } + } } """, TestJsonSerializerContext.Default.JsonElement); @@ -110,8 +114,9 @@ public async Task VerifyDeserializeWithMessagesAsync() var chatHistoryProvider = session.ChatHistoryProvider as InMemoryChatHistoryProvider; Assert.NotNull(chatHistoryProvider); - Assert.Single(chatHistoryProvider); - Assert.Equal("testAuthor", chatHistoryProvider[0].AuthorName); + var messages = chatHistoryProvider.GetMessages(session); + Assert.Single(messages); + Assert.Equal("testAuthor", messages[0].AuthorName); } [Fact] @@ -222,8 +227,9 @@ public void VerifySessionSerializationWithId() public void VerifySessionSerializationWithMessages() { // Arrange - InMemoryChatHistoryProvider provider = [new(ChatRole.User, "TestContent") { AuthorName = "TestAuthor" }]; + var provider = new InMemoryChatHistoryProvider(); var session = new ChatClientAgentSession { ChatHistoryProvider = provider }; + provider.SetMessages(session, [new(ChatRole.User, "TestContent") { AuthorName = "TestAuthor" }]); // Act var json = session.Serialize(); @@ -233,10 +239,18 @@ public void VerifySessionSerializationWithMessages() Assert.False(json.TryGetProperty("conversationId", out _)); + // chatHistoryProviderState should be an empty JSON object (state is in StateBag now) Assert.True(json.TryGetProperty("chatHistoryProviderState", out var chatHistoryProviderStateProperty)); Assert.Equal(JsonValueKind.Object, chatHistoryProviderStateProperty.ValueKind); - Assert.True(chatHistoryProviderStateProperty.TryGetProperty("messages", out var messagesProperty)); + // Messages should be stored in the stateBag + Assert.True(json.TryGetProperty("stateBag", out var stateBagProperty)); + Assert.Equal(JsonValueKind.Object, stateBagProperty.ValueKind); + Assert.True(stateBagProperty.TryGetProperty("InMemoryChatHistoryProvider.State", out var providerStateProperty)); + Assert.Equal(JsonValueKind.Object, providerStateProperty.ValueKind); + Assert.True(providerStateProperty.TryGetProperty("jsonValue", out var jsonValueProperty)); + Assert.Equal(JsonValueKind.Object, jsonValueProperty.ValueKind); + Assert.True(jsonValueProperty.TryGetProperty("messages", out var messagesProperty)); Assert.Equal(JsonValueKind.Array, messagesProperty.ValueKind); Assert.Single(messagesProperty.EnumerateArray()); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs index e1ff5f8cbd..7668bbdb70 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs @@ -373,10 +373,11 @@ public async Task RunAsyncInvokesAIContextProviderAndUsesResultAsync() // Verify that the session was updated with the ai context provider, input and response messages var chatHistoryProvider = Assert.IsType(session!.ChatHistoryProvider); - Assert.Equal(3, chatHistoryProvider.Count); - Assert.Equal("user message", chatHistoryProvider[0].Text); - Assert.Equal("context provider message", chatHistoryProvider[1].Text); - Assert.Equal("response", chatHistoryProvider[2].Text); + var messages = chatHistoryProvider.GetMessages(session); + Assert.Equal(3, messages.Count); + Assert.Equal("user message", messages[0].Text); + Assert.Equal("context provider message", messages[1].Text); + Assert.Equal("response", messages[2].Text); mockProvider.Verify(p => p.InvokingAsync(It.IsAny(), It.IsAny()), Times.Once); mockProvider.Verify(p => p.InvokedAsync(It.Is(x => @@ -1301,9 +1302,10 @@ public async Task RunStreamingAsyncUsesChatHistoryProviderWhenNoConversationIdRe // Assert var chatHistoryProvider = Assert.IsType(session!.ChatHistoryProvider); - Assert.Equal(2, chatHistoryProvider.Count); - Assert.Equal("test", chatHistoryProvider[0].Text); - Assert.Equal("what?", chatHistoryProvider[1].Text); + var historyMessages = chatHistoryProvider.GetMessages(session); + Assert.Equal(2, historyMessages.Count); + Assert.Equal("test", historyMessages[0].Text); + Assert.Equal("what?", historyMessages[1].Text); mockFactory.Verify(f => f(It.IsAny(), It.IsAny()), Times.Once); } @@ -1409,10 +1411,11 @@ public async Task RunStreamingAsyncInvokesAIContextProviderAndUsesResultAsync() // Verify that the session was updated with the input, ai context provider, and response messages var chatHistoryProvider = Assert.IsType(session!.ChatHistoryProvider); - Assert.Equal(3, chatHistoryProvider.Count); - Assert.Equal("user message", chatHistoryProvider[0].Text); - Assert.Equal("context provider message", chatHistoryProvider[1].Text); - Assert.Equal("response", chatHistoryProvider[2].Text); + var historyMessages2 = chatHistoryProvider.GetMessages(session); + Assert.Equal(3, historyMessages2.Count); + Assert.Equal("user message", historyMessages2[0].Text); + Assert.Equal("context provider message", historyMessages2[1].Text); + Assert.Equal("response", historyMessages2[2].Text); mockProvider.Verify(p => p.InvokingAsync(It.IsAny(), It.IsAny()), Times.Once); mockProvider.Verify(p => p.InvokedAsync(It.Is(x => diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs index a854a76622..3e471e38a5 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs @@ -163,9 +163,10 @@ public async Task RunAsync_UsesDefaultInMemoryChatHistoryProvider_WhenNoConversa // Assert InMemoryChatHistoryProvider chatHistoryProvider = Assert.IsType(session!.ChatHistoryProvider); - Assert.Equal(2, chatHistoryProvider.Count); - Assert.Equal("test", chatHistoryProvider[0].Text); - Assert.Equal("response", chatHistoryProvider[1].Text); + var messages = chatHistoryProvider.GetMessages(session); + Assert.Equal(2, messages.Count); + Assert.Equal("test", messages[0].Text); + Assert.Equal("response", messages[1].Text); } /// diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs index d66443f069..0a359f9cdd 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs @@ -36,7 +36,7 @@ public override ValueTask CreateSessionAsync(CancellationToken can private static ChatMessage UpdateSession(ChatMessage message, InMemoryAgentSession? session = null) { - session?.ChatHistoryProvider.Add(message); + session?.ChatHistoryProvider.GetMessages(session).Add(message); return message; } From 122872417b44edc47acc230c00c0bb920920168b Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 6 Feb 2026 13:29:33 +0000 Subject: [PATCH 04/28] Update Comsos and Workflow ChatHistoryProviders --- .../CosmosChatHistoryProvider.cs | 357 ++++++++---------- .../CosmosDBChatExtensions.cs | 24 +- .../WorkflowChatHistoryProvider.cs | 55 ++- .../WorkflowHostAgent.cs | 2 +- .../WorkflowSession.cs | 21 +- .../CosmosChatHistoryProviderTests.cs | 222 ++++++----- 6 files changed, 319 insertions(+), 362 deletions(-) diff --git a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs index 41c9a211dc..e8d129b25b 100644 --- a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs @@ -21,17 +21,15 @@ namespace Microsoft.Agents.AI; [RequiresDynamicCode("The CosmosChatHistoryProvider uses JSON serialization which is incompatible with NativeAOT.")] public sealed class CosmosChatHistoryProvider : ChatHistoryProvider, IDisposable { + private const string DefaultStateBagKey = "CosmosChatHistoryProvider.State"; + private readonly CosmosClient _cosmosClient; private readonly Container _container; private readonly bool _ownsClient; + private readonly string _stateKey; + private readonly Func _stateInitializer; private bool _disposed; - // Hierarchical partition key support - private readonly string? _tenantId; - private readonly string? _userId; - private readonly PartitionKey _partitionKey; - private readonly bool _useHierarchicalPartitioning; - /// /// Cached JSON serializer options for .NET 9.0 compatibility. /// @@ -72,11 +70,6 @@ private static JsonSerializerOptions CreateDefaultJsonOptions() /// public int? MessageTtlSeconds { get; set; } = 86400; - /// - /// Gets the conversation ID associated with this provider. - /// - public string ConversationId { get; init; } - /// /// Gets the database ID associated with this provider. /// @@ -88,36 +81,31 @@ private static JsonSerializerOptions CreateDefaultJsonOptions() public string ContainerId { get; init; } /// - /// Internal primary constructor used by all public constructors. + /// Initializes a new instance of the class. /// /// The instance to use for Cosmos DB operations. /// The identifier of the Cosmos DB database. /// The identifier of the Cosmos DB container. - /// The unique identifier for this conversation thread. + /// A delegate that initializes the provider state on the first invocation, providing the conversation routing info (conversationId, tenantId, userId). /// Whether this instance owns the CosmosClient and should dispose it. - /// Optional tenant identifier for hierarchical partitioning. - /// Optional user identifier for hierarchical partitioning. - internal CosmosChatHistoryProvider(CosmosClient cosmosClient, string databaseId, string containerId, string conversationId, bool ownsClient, string? tenantId = null, string? userId = null) + /// An optional key to use for storing the state in the . + /// Thrown when or is . + /// Thrown when any string parameter is null or whitespace. + public CosmosChatHistoryProvider( + CosmosClient cosmosClient, + string databaseId, + string containerId, + Func stateInitializer, + bool ownsClient = false, + string? stateKey = null) { this._cosmosClient = Throw.IfNull(cosmosClient); - this._container = this._cosmosClient.GetContainer(Throw.IfNullOrWhitespace(databaseId), Throw.IfNullOrWhitespace(containerId)); - this.ConversationId = Throw.IfNullOrWhitespace(conversationId); - this.DatabaseId = databaseId; - this.ContainerId = containerId; + this.DatabaseId = Throw.IfNullOrWhitespace(databaseId); + this.ContainerId = Throw.IfNullOrWhitespace(containerId); + this._container = this._cosmosClient.GetContainer(databaseId, containerId); + this._stateInitializer = Throw.IfNull(stateInitializer); this._ownsClient = ownsClient; - - // Initialize partitioning mode - this._tenantId = tenantId; - this._userId = userId; - this._useHierarchicalPartitioning = tenantId != null && userId != null; - - this._partitionKey = this._useHierarchicalPartitioning - ? new PartitionKeyBuilder() - .Add(tenantId!) - .Add(userId!) - .Add(conversationId) - .Build() - : new PartitionKey(conversationId); + this._stateKey = stateKey ?? DefaultStateBagKey; } /// @@ -126,24 +114,17 @@ internal CosmosChatHistoryProvider(CosmosClient cosmosClient, string databaseId, /// The Cosmos DB connection string. /// The identifier of the Cosmos DB database. /// The identifier of the Cosmos DB container. + /// A delegate that initializes the provider state on the first invocation. + /// An optional key to use for storing the state in the . /// Thrown when any required parameter is null. /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(string connectionString, string databaseId, string containerId) - : this(connectionString, databaseId, containerId, Guid.NewGuid().ToString("N")) - { - } - - /// - /// Initializes a new instance of the class using a connection string. - /// - /// The Cosmos DB connection string. - /// The identifier of the Cosmos DB database. - /// The identifier of the Cosmos DB container. - /// The unique identifier for this conversation thread. - /// Thrown when any required parameter is null. - /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(string connectionString, string databaseId, string containerId, string conversationId) - : this(new CosmosClient(Throw.IfNullOrWhitespace(connectionString)), databaseId, containerId, conversationId, ownsClient: true) + public CosmosChatHistoryProvider( + string connectionString, + string databaseId, + string containerId, + Func stateInitializer, + string? stateKey = null) + : this(new CosmosClient(Throw.IfNullOrWhitespace(connectionString)), databaseId, containerId, stateInitializer, ownsClient: true, stateKey) { } @@ -154,136 +135,64 @@ public CosmosChatHistoryProvider(string connectionString, string databaseId, str /// The TokenCredential to use for authentication (e.g., DefaultAzureCredential, ManagedIdentityCredential). /// The identifier of the Cosmos DB database. /// The identifier of the Cosmos DB container. + /// A delegate that initializes the provider state on the first invocation. + /// An optional key to use for storing the state in the . /// Thrown when any required parameter is null. /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(string accountEndpoint, TokenCredential tokenCredential, string databaseId, string containerId) - : this(accountEndpoint, tokenCredential, databaseId, containerId, Guid.NewGuid().ToString("N")) - { - } - - /// - /// Initializes a new instance of the class using a TokenCredential for authentication. - /// - /// The Cosmos DB account endpoint URI. - /// The TokenCredential to use for authentication (e.g., DefaultAzureCredential, ManagedIdentityCredential). - /// The identifier of the Cosmos DB database. - /// The identifier of the Cosmos DB container. - /// The unique identifier for this conversation thread. - /// Thrown when any required parameter is null. - /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(string accountEndpoint, TokenCredential tokenCredential, string databaseId, string containerId, string conversationId) - : this(new CosmosClient(Throw.IfNullOrWhitespace(accountEndpoint), Throw.IfNull(tokenCredential)), databaseId, containerId, conversationId, ownsClient: true) + public CosmosChatHistoryProvider( + string accountEndpoint, + TokenCredential tokenCredential, + string databaseId, + string containerId, + Func stateInitializer, + string? stateKey = null) + : this(new CosmosClient(Throw.IfNullOrWhitespace(accountEndpoint), Throw.IfNull(tokenCredential)), databaseId, containerId, stateInitializer, ownsClient: true, stateKey) { } /// - /// Initializes a new instance of the class using an existing . + /// Gets the state from the session's StateBag, or initializes it using the state initializer if not present. /// - /// The instance to use for Cosmos DB operations. - /// The identifier of the Cosmos DB database. - /// The identifier of the Cosmos DB container. - /// Thrown when is null. - /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(CosmosClient cosmosClient, string databaseId, string containerId) - : this(cosmosClient, databaseId, containerId, Guid.NewGuid().ToString("N")) + /// The agent session containing the StateBag. + /// The provider state, or null if no session is available. + private State GetOrInitializeState(AgentSession? session) { - } - - /// - /// Initializes a new instance of the class using an existing . - /// - /// The instance to use for Cosmos DB operations. - /// The identifier of the Cosmos DB database. - /// The identifier of the Cosmos DB container. - /// The unique identifier for this conversation thread. - /// Thrown when is null. - /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(CosmosClient cosmosClient, string databaseId, string containerId, string conversationId) - : this(cosmosClient, databaseId, containerId, conversationId, ownsClient: false) - { - } + var state = session?.StateBag.GetValue(this._stateKey, AgentAbstractionsJsonUtilities.DefaultOptions); + if (state is not null) + { + return state; + } - /// - /// Initializes a new instance of the class using a connection string with hierarchical partition keys. - /// - /// The Cosmos DB connection string. - /// The identifier of the Cosmos DB database. - /// The identifier of the Cosmos DB container. - /// The tenant identifier for hierarchical partitioning. - /// The user identifier for hierarchical partitioning. - /// The session identifier for hierarchical partitioning. - /// Thrown when any required parameter is null. - /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(string connectionString, string databaseId, string containerId, string tenantId, string userId, string sessionId) - : this(new CosmosClient(Throw.IfNullOrWhitespace(connectionString)), databaseId, containerId, Throw.IfNullOrWhitespace(sessionId), ownsClient: true, Throw.IfNullOrWhitespace(tenantId), Throw.IfNullOrWhitespace(userId)) - { - } + state = this._stateInitializer(session); + if (session is not null) + { + session.StateBag.SetValue(this._stateKey, state, AgentAbstractionsJsonUtilities.DefaultOptions); + } - /// - /// Initializes a new instance of the class using a TokenCredential for authentication with hierarchical partition keys. - /// - /// The Cosmos DB account endpoint URI. - /// The TokenCredential to use for authentication (e.g., DefaultAzureCredential, ManagedIdentityCredential). - /// The identifier of the Cosmos DB database. - /// The identifier of the Cosmos DB container. - /// The tenant identifier for hierarchical partitioning. - /// The user identifier for hierarchical partitioning. - /// The session identifier for hierarchical partitioning. - /// Thrown when any required parameter is null. - /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(string accountEndpoint, TokenCredential tokenCredential, string databaseId, string containerId, string tenantId, string userId, string sessionId) - : this(new CosmosClient(Throw.IfNullOrWhitespace(accountEndpoint), Throw.IfNull(tokenCredential)), databaseId, containerId, Throw.IfNullOrWhitespace(sessionId), ownsClient: true, Throw.IfNullOrWhitespace(tenantId), Throw.IfNullOrWhitespace(userId)) - { + return state; } /// - /// Initializes a new instance of the class using an existing with hierarchical partition keys. + /// Determines whether hierarchical partitioning should be used based on the state. /// - /// The instance to use for Cosmos DB operations. - /// The identifier of the Cosmos DB database. - /// The identifier of the Cosmos DB container. - /// The tenant identifier for hierarchical partitioning. - /// The user identifier for hierarchical partitioning. - /// The session identifier for hierarchical partitioning. - /// Thrown when is null. - /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(CosmosClient cosmosClient, string databaseId, string containerId, string tenantId, string userId, string sessionId) - : this(cosmosClient, databaseId, containerId, Throw.IfNullOrWhitespace(sessionId), ownsClient: false, Throw.IfNullOrWhitespace(tenantId), Throw.IfNullOrWhitespace(userId)) - { - } + private static bool UseHierarchicalPartitioning(State state) => + state.TenantId is not null && state.UserId is not null; /// - /// Creates a new instance of the class from previously serialized state. + /// Builds the partition key from the state. /// - /// The instance to use for Cosmos DB operations. - /// A representing the serialized state of the provider. - /// The identifier of the Cosmos DB database. - /// The identifier of the Cosmos DB container. - /// Optional settings for customizing the JSON deserialization process. - /// A new instance of initialized from the serialized state. - /// Thrown when is null. - /// Thrown when the serialized state cannot be deserialized. - public static CosmosChatHistoryProvider CreateFromSerializedState(CosmosClient cosmosClient, JsonElement serializedState, string databaseId, string containerId, JsonSerializerOptions? jsonSerializerOptions = null) + private static PartitionKey BuildPartitionKey(State state) { - Throw.IfNull(cosmosClient); - Throw.IfNullOrWhitespace(databaseId); - Throw.IfNullOrWhitespace(containerId); - - if (serializedState.ValueKind is not JsonValueKind.Object) + if (UseHierarchicalPartitioning(state)) { - throw new ArgumentException("Invalid serialized state", nameof(serializedState)); + return new PartitionKeyBuilder() + .Add(state.TenantId) + .Add(state.UserId) + .Add(state.ConversationId) + .Build(); } - var state = serializedState.Deserialize(jsonSerializerOptions); - if (state?.ConversationIdentifier is not { } conversationId) - { - throw new ArgumentException("Invalid serialized state", nameof(serializedState)); - } - - // Use the internal constructor with all parameters to ensure partition key logic is centralized - return state.UseHierarchicalPartitioning && state.TenantId != null && state.UserId != null - ? new CosmosChatHistoryProvider(cosmosClient, databaseId, containerId, conversationId, ownsClient: false, state.TenantId, state.UserId) - : new CosmosChatHistoryProvider(cosmosClient, databaseId, containerId, conversationId, ownsClient: false); + return new PartitionKey(state.ConversationId); } /// @@ -296,15 +205,20 @@ public override async ValueTask> InvokingAsync(Invoking } #pragma warning restore CA1513 + _ = Throw.IfNull(context); + + var state = this.GetOrInitializeState(context.Session); + var partitionKey = BuildPartitionKey(state); + // Fetch most recent messages in descending order when limit is set, then reverse to ascending var orderDirection = this.MaxMessagesToRetrieve.HasValue ? "DESC" : "ASC"; var query = new QueryDefinition($"SELECT * FROM c WHERE c.conversationId = @conversationId AND c.type = @type ORDER BY c.timestamp {orderDirection}") - .WithParameter("@conversationId", this.ConversationId) + .WithParameter("@conversationId", state.ConversationId) .WithParameter("@type", "ChatMessage"); var iterator = this._container.GetItemQueryIterator(query, requestOptions: new QueryRequestOptions { - PartitionKey = this._partitionKey, + PartitionKey = partitionKey, MaxItemCount = this.MaxItemCount // Configurable query performance }); @@ -364,27 +278,30 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio } #pragma warning restore CA1513 + var state = this.GetOrInitializeState(context.Session); var messageList = context.RequestMessages.Concat(context.AIContextProviderMessages ?? []).Concat(context.ResponseMessages ?? []).ToList(); if (messageList.Count == 0) { return; } + var partitionKey = BuildPartitionKey(state); + // Use transactional batch for atomic operations if (messageList.Count > 1) { - await this.AddMessagesInBatchAsync(messageList, cancellationToken).ConfigureAwait(false); + await this.AddMessagesInBatchAsync(partitionKey, state, messageList, cancellationToken).ConfigureAwait(false); } else { - await this.AddSingleMessageAsync(messageList.First(), cancellationToken).ConfigureAwait(false); + await this.AddSingleMessageAsync(partitionKey, state, messageList.First(), cancellationToken).ConfigureAwait(false); } } /// /// Adds multiple messages using transactional batch operations for atomicity. /// - private async Task AddMessagesInBatchAsync(List messages, CancellationToken cancellationToken) + private async Task AddMessagesInBatchAsync(PartitionKey partitionKey, State state, List messages, CancellationToken cancellationToken) { var currentTimestamp = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); @@ -392,7 +309,7 @@ private async Task AddMessagesInBatchAsync(List messages, Cancellat for (int i = 0; i < messages.Count; i += this.MaxBatchSize) { var batchMessages = messages.Skip(i).Take(this.MaxBatchSize).ToList(); - await this.ExecuteBatchOperationAsync(batchMessages, currentTimestamp, cancellationToken).ConfigureAwait(false); + await this.ExecuteBatchOperationAsync(partitionKey, state, batchMessages, currentTimestamp, cancellationToken).ConfigureAwait(false); } } @@ -400,13 +317,13 @@ private async Task AddMessagesInBatchAsync(List messages, Cancellat /// Executes a single batch operation with enhanced error handling. /// Cosmos SDK handles throttling (429) retries automatically. /// - private async Task ExecuteBatchOperationAsync(List messages, long timestamp, CancellationToken cancellationToken) + private async Task ExecuteBatchOperationAsync(PartitionKey partitionKey, State state, List messages, long timestamp, CancellationToken cancellationToken) { // Create all documents upfront for validation and batch operation var documents = new List(messages.Count); foreach (var message in messages) { - documents.Add(this.CreateMessageDocument(message, timestamp)); + documents.Add(this.CreateMessageDocument(state, message, timestamp)); } // Defensive check: Verify all messages share the same partition key values @@ -414,7 +331,7 @@ private async Task ExecuteBatchOperationAsync(List messages, long t // In simple partitioning, this means same conversationId if (documents.Count > 0) { - if (this._useHierarchicalPartitioning) + if (UseHierarchicalPartitioning(state)) { // Verify all documents have matching hierarchical partition key components var firstDoc = documents[0]; @@ -436,7 +353,7 @@ private async Task ExecuteBatchOperationAsync(List messages, long t // All messages in this store share the same partition key by design // Transactional batches require all items to share the same partition key - var batch = this._container.CreateTransactionalBatch(this._partitionKey); + var batch = this._container.CreateTransactionalBatch(partitionKey); foreach (var document in documents) { @@ -457,7 +374,7 @@ private async Task ExecuteBatchOperationAsync(List messages, long t if (messages.Count == 1) { // Can't split further, use single operation - await this.AddSingleMessageAsync(messages[0], cancellationToken).ConfigureAwait(false); + await this.AddSingleMessageAsync(partitionKey, state, messages[0], cancellationToken).ConfigureAwait(false); return; } @@ -466,21 +383,21 @@ private async Task ExecuteBatchOperationAsync(List messages, long t var firstHalf = messages.Take(midpoint).ToList(); var secondHalf = messages.Skip(midpoint).ToList(); - await this.ExecuteBatchOperationAsync(firstHalf, timestamp, cancellationToken).ConfigureAwait(false); - await this.ExecuteBatchOperationAsync(secondHalf, timestamp, cancellationToken).ConfigureAwait(false); + await this.ExecuteBatchOperationAsync(partitionKey, state, firstHalf, timestamp, cancellationToken).ConfigureAwait(false); + await this.ExecuteBatchOperationAsync(partitionKey, state, secondHalf, timestamp, cancellationToken).ConfigureAwait(false); } } /// /// Adds a single message to the store. /// - private async Task AddSingleMessageAsync(ChatMessage message, CancellationToken cancellationToken) + private async Task AddSingleMessageAsync(PartitionKey partitionKey, State state, ChatMessage message, CancellationToken cancellationToken) { - var document = this.CreateMessageDocument(message, DateTimeOffset.UtcNow.ToUnixTimeSeconds()); + var document = this.CreateMessageDocument(state, message, DateTimeOffset.UtcNow.ToUnixTimeSeconds()); try { - await this._container.CreateItemAsync(document, this._partitionKey, cancellationToken: cancellationToken).ConfigureAwait(false); + await this._container.CreateItemAsync(document, partitionKey, cancellationToken: cancellationToken).ConfigureAwait(false); } catch (CosmosException ex) when (ex.StatusCode == System.Net.HttpStatusCode.RequestEntityTooLarge) { @@ -495,12 +412,14 @@ private async Task AddSingleMessageAsync(ChatMessage message, CancellationToken /// /// Creates a message document with enhanced metadata. /// - private CosmosMessageDocument CreateMessageDocument(ChatMessage message, long timestamp) + private CosmosMessageDocument CreateMessageDocument(State state, ChatMessage message, long timestamp) { + var useHierarchical = UseHierarchicalPartitioning(state); + return new CosmosMessageDocument { Id = Guid.NewGuid().ToString(), - ConversationId = this.ConversationId, + ConversationId = state.ConversationId, Timestamp = timestamp, MessageId = message.MessageId, Role = message.Role.Value, @@ -508,13 +427,20 @@ private CosmosMessageDocument CreateMessageDocument(ChatMessage message, long ti Type = "ChatMessage", // Type discriminator Ttl = this.MessageTtlSeconds, // Configurable TTL // Include hierarchical metadata when using hierarchical partitioning - TenantId = this._useHierarchicalPartitioning ? this._tenantId : null, - UserId = this._useHierarchicalPartitioning ? this._userId : null, - SessionId = this._useHierarchicalPartitioning ? this.ConversationId : null + TenantId = useHierarchical ? state.TenantId : null, + UserId = useHierarchical ? state.UserId : null, + SessionId = useHierarchical ? state.ConversationId : null }; } - /// + /// + /// Serializes the current provider state to a . + /// + /// Optional serializer options (ignored). + /// An empty object. + /// + /// State is now stored in the and serialized as part of the session. + /// public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { #pragma warning disable CA1513 // Use ObjectDisposedException.ThrowIf - not available on all target frameworks @@ -524,25 +450,19 @@ public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptio } #pragma warning restore CA1513 - var state = new State - { - ConversationIdentifier = this.ConversationId, - TenantId = this._tenantId, - UserId = this._userId, - UseHierarchicalPartitioning = this._useHierarchicalPartitioning - }; - - var options = jsonSerializerOptions ?? s_defaultJsonOptions; - return JsonSerializer.SerializeToElement(state, options); + // State is now stored in the session StateBag, so there is nothing to serialize here. + using var doc = JsonDocument.Parse("{}"); + return doc.RootElement.Clone(); } /// /// Gets the count of messages in this conversation. /// This is an additional utility method beyond the base contract. /// + /// The agent session to get state from. /// The cancellation token. /// The number of messages in the conversation. - public async Task GetMessageCountAsync(CancellationToken cancellationToken = default) + public async Task GetMessageCountAsync(AgentSession? session, CancellationToken cancellationToken = default) { #pragma warning disable CA1513 // Use ObjectDisposedException.ThrowIf - not available on all target frameworks if (this._disposed) @@ -551,14 +471,17 @@ public async Task GetMessageCountAsync(CancellationToken cancellationToken } #pragma warning restore CA1513 + var state = this.GetOrInitializeState(session); + var partitionKey = BuildPartitionKey(state); + // Efficient count query var query = new QueryDefinition("SELECT VALUE COUNT(1) FROM c WHERE c.conversationId = @conversationId AND c.Type = @type") - .WithParameter("@conversationId", this.ConversationId) + .WithParameter("@conversationId", state.ConversationId) .WithParameter("@type", "ChatMessage"); var iterator = this._container.GetItemQueryIterator(query, requestOptions: new QueryRequestOptions { - PartitionKey = this._partitionKey + PartitionKey = partitionKey }); // COUNT queries always return a result @@ -570,9 +493,10 @@ public async Task GetMessageCountAsync(CancellationToken cancellationToken /// Deletes all messages in this conversation. /// This is an additional utility method beyond the base contract. /// + /// The agent session to get state from. /// The cancellation token. /// The number of messages deleted. - public async Task ClearMessagesAsync(CancellationToken cancellationToken = default) + public async Task ClearMessagesAsync(AgentSession? session, CancellationToken cancellationToken = default) { #pragma warning disable CA1513 // Use ObjectDisposedException.ThrowIf - not available on all target frameworks if (this._disposed) @@ -581,14 +505,17 @@ public async Task ClearMessagesAsync(CancellationToken cancellationToken = } #pragma warning restore CA1513 + var state = this.GetOrInitializeState(session); + var partitionKey = BuildPartitionKey(state); + // Batch delete for efficiency var query = new QueryDefinition("SELECT VALUE c.id FROM c WHERE c.conversationId = @conversationId AND c.Type = @type") - .WithParameter("@conversationId", this.ConversationId) + .WithParameter("@conversationId", state.ConversationId) .WithParameter("@type", "ChatMessage"); var iterator = this._container.GetItemQueryIterator(query, requestOptions: new QueryRequestOptions { - PartitionKey = this._partitionKey, + PartitionKey = partitionKey, MaxItemCount = this.MaxItemCount }); @@ -597,7 +524,7 @@ public async Task ClearMessagesAsync(CancellationToken cancellationToken = while (iterator.HasMoreResults) { var response = await iterator.ReadNextAsync(cancellationToken).ConfigureAwait(false); - var batch = this._container.CreateTransactionalBatch(this._partitionKey); + var batch = this._container.CreateTransactionalBatch(partitionKey); var batchItemCount = 0; foreach (var itemId in response) @@ -632,12 +559,38 @@ public void Dispose() } } - private sealed class State + /// + /// Represents the per-session state of a stored in the . + /// + public sealed class State { - public string ConversationIdentifier { get; set; } = string.Empty; - public string? TenantId { get; set; } - public string? UserId { get; set; } - public bool UseHierarchicalPartitioning { get; set; } + /// + /// Initializes a new instance of the class. + /// + /// The unique identifier for this conversation thread. + /// Optional tenant identifier for hierarchical partitioning. + /// Optional user identifier for hierarchical partitioning. + public State(string conversationId, string? tenantId = null, string? userId = null) + { + this.ConversationId = Throw.IfNullOrWhitespace(conversationId); + this.TenantId = tenantId; + this.UserId = userId; + } + + /// + /// Gets the conversation ID associated with this state. + /// + public string ConversationId { get; } + + /// + /// Gets the tenant identifier for hierarchical partitioning, if any. + /// + public string? TenantId { get; } + + /// + /// Gets the user identifier for hierarchical partitioning, if any. + /// + public string? UserId { get; } } /// diff --git a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosDBChatExtensions.cs b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosDBChatExtensions.cs index 3d93e9dd6a..b64ec9a30c 100644 --- a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosDBChatExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosDBChatExtensions.cs @@ -13,6 +13,9 @@ namespace Microsoft.Agents.AI; /// public static class CosmosDBChatExtensions { + private static readonly Func s_defaultStateInitializer = + _ => new CosmosChatHistoryProvider.State(Guid.NewGuid().ToString("N")); + /// /// Configures the agent to use Cosmos DB for message storage with connection string authentication. /// @@ -20,6 +23,7 @@ public static class CosmosDBChatExtensions /// The Cosmos DB connection string. /// The identifier of the Cosmos DB database. /// The identifier of the Cosmos DB container. + /// An optional delegate that initializes the provider state on the first invocation, providing the conversation routing info (conversationId, tenantId, userId). When not provided, a new conversation ID is generated automatically. /// The configured . /// Thrown when is null. /// Thrown when any string parameter is null or whitespace. @@ -29,14 +33,16 @@ public static ChatClientAgentOptions WithCosmosDBChatHistoryProvider( this ChatClientAgentOptions options, string connectionString, string databaseId, - string containerId) + string containerId, + Func? stateInitializer = null) { if (options is null) { throw new ArgumentNullException(nameof(options)); } - options.ChatHistoryProviderFactory = (context, ct) => new ValueTask(new CosmosChatHistoryProvider(connectionString, databaseId, containerId)); + options.ChatHistoryProviderFactory = (context, ct) => new ValueTask( + new CosmosChatHistoryProvider(connectionString, databaseId, containerId, stateInitializer ?? s_defaultStateInitializer)); return options; } @@ -48,6 +54,7 @@ public static ChatClientAgentOptions WithCosmosDBChatHistoryProvider( /// The identifier of the Cosmos DB database. /// The identifier of the Cosmos DB container. /// The TokenCredential to use for authentication (e.g., DefaultAzureCredential, ManagedIdentityCredential). + /// An optional delegate that initializes the provider state on the first invocation, providing the conversation routing info (conversationId, tenantId, userId). When not provided, a new conversation ID is generated automatically. /// The configured . /// Thrown when or is null. /// Thrown when any string parameter is null or whitespace. @@ -58,7 +65,8 @@ public static ChatClientAgentOptions WithCosmosDBChatHistoryProviderUsingManaged string accountEndpoint, string databaseId, string containerId, - TokenCredential tokenCredential) + TokenCredential tokenCredential, + Func? stateInitializer = null) { if (options is null) { @@ -70,7 +78,8 @@ public static ChatClientAgentOptions WithCosmosDBChatHistoryProviderUsingManaged throw new ArgumentNullException(nameof(tokenCredential)); } - options.ChatHistoryProviderFactory = (context, ct) => new ValueTask(new CosmosChatHistoryProvider(accountEndpoint, tokenCredential, databaseId, containerId)); + options.ChatHistoryProviderFactory = (context, ct) => new ValueTask( + new CosmosChatHistoryProvider(accountEndpoint, tokenCredential, databaseId, containerId, stateInitializer ?? s_defaultStateInitializer)); return options; } @@ -81,6 +90,7 @@ public static ChatClientAgentOptions WithCosmosDBChatHistoryProviderUsingManaged /// The instance to use for Cosmos DB operations. /// The identifier of the Cosmos DB database. /// The identifier of the Cosmos DB container. + /// An optional delegate that initializes the provider state on the first invocation, providing the conversation routing info (conversationId, tenantId, userId). When not provided, a new conversation ID is generated automatically. /// The configured . /// Thrown when any required parameter is null. /// Thrown when any string parameter is null or whitespace. @@ -90,14 +100,16 @@ public static ChatClientAgentOptions WithCosmosDBChatHistoryProvider( this ChatClientAgentOptions options, CosmosClient cosmosClient, string databaseId, - string containerId) + string containerId, + Func? stateInitializer = null) { if (options is null) { throw new ArgumentNullException(nameof(options)); } - options.ChatHistoryProviderFactory = (context, ct) => new ValueTask(new CosmosChatHistoryProvider(cosmosClient, databaseId, containerId)); + options.ChatHistoryProviderFactory = (context, ct) => new ValueTask( + new CosmosChatHistoryProvider(cosmosClient, databaseId, containerId, stateInitializer ?? s_defaultStateInitializer)); return options; } } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs index afe6706553..8528bda3b5 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs @@ -6,48 +6,45 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; -using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.AI.Workflows; internal sealed class WorkflowChatHistoryProvider : ChatHistoryProvider { - private int _bookmark; - private readonly List _chatMessages = []; + private const string DefaultStateBagKey = "WorkflowChatHistoryProvider.State"; public WorkflowChatHistoryProvider() { } - public WorkflowChatHistoryProvider(StoreState state) + internal sealed class StoreState { - this.ImportStoreState(Throw.IfNull(state)); + public int Bookmark { get; set; } + public List Messages { get; set; } = []; } - private void ImportStoreState(StoreState state, bool clearMessages = false) + private StoreState GetOrInitializeState(AgentSession? session) { - if (clearMessages) + var state = session?.StateBag.GetValue(DefaultStateBagKey, AgentAbstractionsJsonUtilities.DefaultOptions); + if (state is not null) { - this._chatMessages.Clear(); + return state; } - if (state?.Messages is not null) + state = new(); + if (session is not null) { - this._chatMessages.AddRange(state.Messages); + session.StateBag.SetValue(DefaultStateBagKey, state, AgentAbstractionsJsonUtilities.DefaultOptions); } - this._bookmark = state?.Bookmark ?? 0; - } - internal sealed class StoreState - { - public int Bookmark { get; set; } - public IList Messages { get; set; } = []; + return state; } - internal void AddMessages(params IEnumerable messages) => this._chatMessages.AddRange(messages); + internal void AddMessages(AgentSession session, params IEnumerable messages) + => this.GetOrInitializeState(session).Messages.AddRange(messages); public override ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) - => new(this._chatMessages.AsReadOnly()); + => new(this.GetOrInitializeState(context.Session).Messages.AsReadOnly()); public override ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) { @@ -57,28 +54,30 @@ public override ValueTask InvokedAsync(InvokedContext context, CancellationToken } var allNewMessages = context.RequestMessages.Concat(context.AIContextProviderMessages ?? []).Concat(context.ResponseMessages ?? []); - this._chatMessages.AddRange(allNewMessages); + this.GetOrInitializeState(context.Session).Messages.AddRange(allNewMessages); return default; } - public IEnumerable GetFromBookmark() + public IEnumerable GetFromBookmark(AgentSession session) { - for (int i = this._bookmark; i < this._chatMessages.Count; i++) + var state = this.GetOrInitializeState(session); + + for (int i = state.Bookmark; i < state.Messages.Count; i++) { - yield return this._chatMessages[i]; + yield return state.Messages[i]; } } - public void UpdateBookmark() => this._bookmark = this._chatMessages.Count; + public void UpdateBookmark(AgentSession session) + { + var state = this.GetOrInitializeState(session); + state.Bookmark = state.Messages.Count; + } public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { - StoreState state = this.ExportStoreState(); - - return JsonSerializer.SerializeToElement(state, + return JsonSerializer.SerializeToElement(null, WorkflowsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(StoreState))); } - - internal StoreState ExportStoreState() => new() { Bookmark = this._bookmark, Messages = this._chatMessages }; } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs index 189ca43101..eb9ae3f21f 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs @@ -94,7 +94,7 @@ private async ValueTask UpdateSessionAsync(IEnumerable InvokeStageAsync( try { this.LastResponseId = Guid.NewGuid().ToString("N"); - List messages = this.ChatHistoryProvider.GetFromBookmark().ToList(); + List messages = this.ChatHistoryProvider.GetFromBookmark(this).ToList(); #pragma warning disable CA2007 // Analyzer misfiring and not seeing .ConfigureAwait(false) below. await using Checkpointed checkpointed = @@ -240,7 +241,7 @@ IAsyncEnumerable InvokeStageAsync( finally { // Do we want to try to undo the step, and not update the bookmark? - this.ChatHistoryProvider.UpdateBookmark(); + this.ChatHistoryProvider.UpdateBookmark(this); } } @@ -254,12 +255,12 @@ IAsyncEnumerable InvokeStageAsync( internal sealed class SessionState( string runId, CheckpointInfo? lastCheckpoint, - WorkflowChatHistoryProvider.StoreState chatHistoryProviderState, - InMemoryCheckpointManager? checkpointManager = null) + InMemoryCheckpointManager? checkpointManager = null, + AgentSessionStateBag? stateBag = null) { public string RunId { get; } = runId; public CheckpointInfo? LastCheckpoint { get; } = lastCheckpoint; - public WorkflowChatHistoryProvider.StoreState ChatHistoryProviderState { get; } = chatHistoryProviderState; public InMemoryCheckpointManager? CheckpointManager { get; } = checkpointManager; + public AgentSessionStateBag StateBag { get; } = stateBag ?? new(); } } diff --git a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs index f6589ff9e3..a7e3f0411a 100644 --- a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs @@ -3,8 +3,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Text.Json; -using System.Text.Json.Serialization.Metadata; using System.Threading.Tasks; using Azure.Core; using Azure.Identity; @@ -42,7 +40,8 @@ namespace Microsoft.Agents.AI.CosmosNoSql.UnitTests; public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable { private static readonly AIAgent s_mockAgent = new Moq.Mock().Object; - private static readonly AgentSession s_mockSession = new Moq.Mock().Object; + + private static AgentSession CreateMockSession() => new Moq.Mock().Object; // Cosmos DB Emulator connection settings private const string EmulatorEndpoint = "https://localhost:8081"; @@ -157,28 +156,11 @@ public void Constructor_WithConnectionString_ShouldCreateInstance() this.SkipIfEmulatorNotAvailable(); // Act - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, "test-conversation"); - - // Assert - Assert.NotNull(provider); - Assert.Equal("test-conversation", provider.ConversationId); - Assert.Equal(s_testDatabaseId, provider.DatabaseId); - Assert.Equal(TestContainerId, provider.ContainerId); - } - - [SkippableFact] - [Trait("Category", "CosmosDB")] - public void Constructor_WithConnectionStringNoConversationId_ShouldCreateInstance() - { - // Arrange - this.SkipIfEmulatorNotAvailable(); - - // Act - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State("test-conversation")); // Assert Assert.NotNull(provider); - Assert.NotNull(provider.ConversationId); Assert.Equal(s_testDatabaseId, provider.DatabaseId); Assert.Equal(TestContainerId, provider.ContainerId); } @@ -189,18 +171,19 @@ public void Constructor_WithNullConnectionString_ShouldThrowArgumentException() { // Arrange & Act & Assert Assert.Throws(() => - new CosmosChatHistoryProvider((string)null!, s_testDatabaseId, TestContainerId, "test-conversation")); + new CosmosChatHistoryProvider((string)null!, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State("test-conversation"))); } [SkippableFact] [Trait("Category", "CosmosDB")] - public void Constructor_WithEmptyConversationId_ShouldThrowArgumentException() + public void Constructor_WithNullStateInitializer_ShouldThrowArgumentNullException() { // Arrange & Act & Assert this.SkipIfEmulatorNotAvailable(); - Assert.Throws(() => - new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, "")); + Assert.Throws(() => + new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, null!)); } #endregion @@ -213,11 +196,13 @@ public async Task InvokedAsync_WithSingleMessage_ShouldAddMessageAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); var conversationId = Guid.NewGuid().ToString(); - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversationId); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(conversationId)); var message = new ChatMessage(ChatRole.User, "Hello, world!"); - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [message], []) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [message], []) { ResponseMessages = [] }; @@ -229,7 +214,7 @@ public async Task InvokedAsync_WithSingleMessage_ShouldAddMessageAsync() await Task.Delay(100); // Assert - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var messages = await provider.InvokingAsync(invokingContext); var messageList = messages.ToList(); @@ -279,8 +264,10 @@ public async Task InvokedAsync_WithMultipleMessages_ShouldAddAllMessagesAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); var conversationId = Guid.NewGuid().ToString(); - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversationId); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(conversationId)); var requestMessages = new[] { new ChatMessage(ChatRole.User, "First message"), @@ -296,7 +283,7 @@ public async Task InvokedAsync_WithMultipleMessages_ShouldAddAllMessagesAsync() new ChatMessage(ChatRole.Assistant, "Response message") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages, []) { AIContextProviderMessages = aiContextProviderMessages, ResponseMessages = responseMessages @@ -306,7 +293,7 @@ public async Task InvokedAsync_WithMultipleMessages_ShouldAddAllMessagesAsync() await provider.InvokedAsync(context); // Assert - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var retrievedMessages = await provider.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); Assert.Equal(5, messageList.Count); @@ -327,10 +314,12 @@ public async Task InvokingAsync_WithNoMessages_ShouldReturnEmptyAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, Guid.NewGuid().ToString()); + var session = CreateMockSession(); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(Guid.NewGuid().ToString())); // Act - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var messages = await provider.InvokingAsync(invokingContext); // Assert @@ -343,21 +332,24 @@ public async Task InvokingAsync_WithConversationIsolation_ShouldOnlyReturnMessag { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); var conversation1 = Guid.NewGuid().ToString(); var conversation2 = Guid.NewGuid().ToString(); - using var store1 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversation1); - using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversation2); + using var store1 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(conversation1)); + using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(conversation2)); - var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message for conversation 1")], []); - var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message for conversation 2")], []); + var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Message for conversation 1")], []); + var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Message for conversation 2")], []); await store1.InvokedAsync(context1); await store2.InvokedAsync(context2); // Act - var invokingContext1 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); - var invokingContext2 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext1 = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); + var invokingContext2 = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var messages1 = await store1.InvokingAsync(invokingContext1); var messages2 = await store2.InvokingAsync(invokingContext2); @@ -381,8 +373,10 @@ public async Task FullWorkflow_AddAndGet_ShouldWorkCorrectlyAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); var conversationId = $"test-conversation-{Guid.NewGuid():N}"; // Use unique conversation ID - using var originalStore = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversationId); + using var originalStore = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(conversationId)); var messages = new[] { @@ -394,18 +388,21 @@ public async Task FullWorkflow_AddAndGet_ShouldWorkCorrectlyAsync() }; // Act 1: Add messages - var invokedContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); + var invokedContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, messages, []); await originalStore.InvokedAsync(invokedContext); // Act 2: Verify messages were added - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var retrievedMessages = await originalStore.InvokingAsync(invokingContext); var retrievedList = retrievedMessages.ToList(); Assert.Equal(5, retrievedList.Count); // Act 3: Create new provider instance for same conversation (test persistence) - using var newProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversationId); - var persistedMessages = await newProvider.InvokingAsync(invokingContext); + using var newProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(conversationId)); + var newSession = CreateMockSession(); + var newInvokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, newSession, []); + var persistedMessages = await newProvider.InvokingAsync(newInvokingContext); var persistedList = persistedMessages.ToList(); // Assert final state @@ -427,7 +424,8 @@ public void Dispose_AfterUse_ShouldNotThrow() { // Arrange this.SkipIfEmulatorNotAvailable(); - var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, Guid.NewGuid().ToString()); + var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(Guid.NewGuid().ToString())); // Act & Assert provider.Dispose(); // Should not throw @@ -439,7 +437,8 @@ public void Dispose_MultipleCalls_ShouldNotThrow() { // Arrange this.SkipIfEmulatorNotAvailable(); - var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, Guid.NewGuid().ToString()); + var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(Guid.NewGuid().ToString())); // Act & Assert provider.Dispose(); // First call @@ -458,11 +457,11 @@ public void Constructor_WithHierarchicalConnectionString_ShouldCreateInstance() this.SkipIfEmulatorNotAvailable(); // Act - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, "tenant-123", "user-456", "session-789"); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State("session-789", "tenant-123", "user-456")); // Assert Assert.NotNull(provider); - Assert.Equal("session-789", provider.ConversationId); Assert.Equal(s_testDatabaseId, provider.DatabaseId); Assert.Equal(HierarchicalTestContainerId, provider.ContainerId); } @@ -476,11 +475,11 @@ public void Constructor_WithHierarchicalEndpoint_ShouldCreateInstance() // Act TokenCredential credential = new DefaultAzureCredential(); - using var provider = new CosmosChatHistoryProvider(EmulatorEndpoint, credential, s_testDatabaseId, HierarchicalTestContainerId, "tenant-123", "user-456", "session-789"); + using var provider = new CosmosChatHistoryProvider(EmulatorEndpoint, credential, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State("session-789", "tenant-123", "user-456")); // Assert Assert.NotNull(provider); - Assert.Equal("session-789", provider.ConversationId); Assert.Equal(s_testDatabaseId, provider.DatabaseId); Assert.Equal(HierarchicalTestContainerId, provider.ContainerId); } @@ -493,46 +492,31 @@ public void Constructor_WithHierarchicalCosmosClient_ShouldCreateInstance() this.SkipIfEmulatorNotAvailable(); using var cosmosClient = new CosmosClient(EmulatorEndpoint, EmulatorKey); - using var provider = new CosmosChatHistoryProvider(cosmosClient, s_testDatabaseId, HierarchicalTestContainerId, "tenant-123", "user-456", "session-789"); + using var provider = new CosmosChatHistoryProvider(cosmosClient, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State("session-789", "tenant-123", "user-456")); // Assert Assert.NotNull(provider); - Assert.Equal("session-789", provider.ConversationId); Assert.Equal(s_testDatabaseId, provider.DatabaseId); Assert.Equal(HierarchicalTestContainerId, provider.ContainerId); } [SkippableFact] [Trait("Category", "CosmosDB")] - public void Constructor_WithHierarchicalNullTenantId_ShouldThrowArgumentException() - { - // Arrange & Act & Assert - this.SkipIfEmulatorNotAvailable(); - - Assert.Throws(() => - new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, null!, "user-456", "session-789")); - } - - [SkippableFact] - [Trait("Category", "CosmosDB")] - public void Constructor_WithHierarchicalEmptyUserId_ShouldThrowArgumentException() + public void State_WithEmptyConversationId_ShouldThrowArgumentException() { // Arrange & Act & Assert - this.SkipIfEmulatorNotAvailable(); - Assert.Throws(() => - new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, "tenant-123", "", "session-789")); + new CosmosChatHistoryProvider.State("")); } [SkippableFact] [Trait("Category", "CosmosDB")] - public void Constructor_WithHierarchicalWhitespaceSessionId_ShouldThrowArgumentException() + public void State_WithWhitespaceConversationId_ShouldThrowArgumentException() { // Arrange & Act & Assert - this.SkipIfEmulatorNotAvailable(); - Assert.Throws(() => - new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, "tenant-123", "user-456", " ")); + new CosmosChatHistoryProvider.State(" ")); } [SkippableFact] @@ -541,14 +525,16 @@ public async Task InvokedAsync_WithHierarchicalPartitioning_ShouldAddMessageWith { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); const string TenantId = "tenant-123"; const string UserId = "user-456"; const string SessionId = "session-789"; // Test hierarchical partitioning constructor with connection string - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State(SessionId, TenantId, UserId)); var message = new ChatMessage(ChatRole.User, "Hello from hierarchical partitioning!"); - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [message], []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [message], []); // Act await provider.InvokedAsync(context); @@ -557,7 +543,7 @@ public async Task InvokedAsync_WithHierarchicalPartitioning_ShouldAddMessageWith await Task.Delay(100); // Assert - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var messages = await provider.InvokingAsync(invokingContext); var messageList = messages.ToList(); @@ -593,11 +579,13 @@ public async Task InvokedAsync_WithHierarchicalMultipleMessages_ShouldAddAllMess { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); const string TenantId = "tenant-batch"; const string UserId = "user-batch"; const string SessionId = "session-batch"; // Test hierarchical partitioning constructor with connection string - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State(SessionId, TenantId, UserId)); var messages = new[] { new ChatMessage(ChatRole.User, "First hierarchical message"), @@ -605,7 +593,7 @@ public async Task InvokedAsync_WithHierarchicalMultipleMessages_ShouldAddAllMess new ChatMessage(ChatRole.User, "Third hierarchical message") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, messages, []); // Act await provider.InvokedAsync(context); @@ -614,7 +602,7 @@ public async Task InvokedAsync_WithHierarchicalMultipleMessages_ShouldAddAllMess await Task.Delay(100); // Assert - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var retrievedMessages = await provider.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); @@ -630,18 +618,21 @@ public async Task InvokingAsync_WithHierarchicalPartitionIsolation_ShouldIsolate { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); const string TenantId = "tenant-isolation"; const string UserId1 = "user-1"; const string UserId2 = "user-2"; const string SessionId = "session-isolation"; // Different userIds create different hierarchical partitions, providing proper isolation - using var store1 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId1, SessionId); - using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId2, SessionId); + using var store1 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State(SessionId, TenantId, UserId1)); + using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State(SessionId, TenantId, UserId2)); // Add messages to both stores - var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message from user 1")], []); - var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message from user 2")], []); + var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Message from user 1")], []); + var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Message from user 2")], []); await store1.InvokedAsync(context1); await store2.InvokedAsync(context2); @@ -650,8 +641,8 @@ public async Task InvokingAsync_WithHierarchicalPartitionIsolation_ShouldIsolate await Task.Delay(100); // Act & Assert - var invokingContext1 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); - var invokingContext2 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext1 = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); + var invokingContext2 = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var messages1 = await store1.InvokingAsync(invokingContext1); var messageList1 = messages1.ToList(); @@ -668,43 +659,37 @@ public async Task InvokingAsync_WithHierarchicalPartitionIsolation_ShouldIsolate [SkippableFact] [Trait("Category", "CosmosDB")] - public async Task SerializeDeserialize_WithHierarchicalPartitioning_ShouldPreserveStateAsync() + public async Task StateBag_WithHierarchicalPartitioning_ShouldPreserveStateAcrossProviderInstancesAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); const string TenantId = "tenant-serialize"; const string UserId = "user-serialize"; const string SessionId = "session-serialize"; - using var originalStore = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId); + using var originalStore = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State(SessionId, TenantId, UserId)); - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test serialization message")], []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Test serialization message")], []); await originalStore.InvokedAsync(context); - // Act - Serialize the provider state - var serializedState = originalStore.Serialize(); - - // Create a new provider from the serialized state - using var cosmosClient = new CosmosClient(EmulatorEndpoint, EmulatorKey); - var serializerOptions = new JsonSerializerOptions - { - TypeInfoResolver = new DefaultJsonTypeInfoResolver() - }; - using var deserializedStore = CosmosChatHistoryProvider.CreateFromSerializedState(cosmosClient, serializedState, s_testDatabaseId, HierarchicalTestContainerId, serializerOptions); - // Wait a moment for eventual consistency await Task.Delay(100); - // Assert - The deserialized provider should have the same functionality - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); - var messages = await deserializedStore.InvokingAsync(invokingContext); + // Act - Create a new provider that uses a different intializer, but we will use the same session. + using var newStore = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State(Guid.NewGuid().ToString())); + + // Assert - The new provider should read the same messages from Cosmos DB + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); + var messages = await newStore.InvokingAsync(invokingContext); var messageList = messages.ToList(); Assert.Single(messageList); Assert.Equal("Test serialization message", messageList[0].Text); - Assert.Equal(SessionId, deserializedStore.ConversationId); - Assert.Equal(s_testDatabaseId, deserializedStore.DatabaseId); - Assert.Equal(HierarchicalTestContainerId, deserializedStore.ContainerId); + Assert.Equal(s_testDatabaseId, newStore.DatabaseId); + Assert.Equal(HierarchicalTestContainerId, newStore.ContainerId); } [SkippableFact] @@ -715,13 +700,16 @@ public async Task HierarchicalAndSimplePartitioning_ShouldCoexistAsync() this.SkipIfEmulatorNotAvailable(); const string SessionId = "coexist-session"; + var session = CreateMockSession(); // Create simple provider using simple partitioning container and hierarchical provider using hierarchical container - using var simpleProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, SessionId); - using var hierarchicalProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, "tenant-coexist", "user-coexist", SessionId); + using var simpleProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(SessionId)); + using var hierarchicalProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State(SessionId, "tenant-coexist", "user-coexist")); // Add messages to both - var simpleContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Simple partitioning message")], []); - var hierarchicalContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Hierarchical partitioning message")], []); + var simpleContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Simple partitioning message")], []); + var hierarchicalContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Hierarchical partitioning message")], []); await simpleProvider.InvokedAsync(simpleContext); await hierarchicalProvider.InvokedAsync(hierarchicalContext); @@ -730,7 +718,7 @@ public async Task HierarchicalAndSimplePartitioning_ShouldCoexistAsync() await Task.Delay(100); // Act & Assert - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var simpleMessages = await simpleProvider.InvokingAsync(invokingContext); var simpleMessageList = simpleMessages.ToList(); @@ -751,9 +739,11 @@ public async Task MaxMessagesToRetrieve_ShouldLimitAndReturnMostRecentAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); const string ConversationId = "max-messages-test"; - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, ConversationId); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(ConversationId)); // Add 10 messages var messages = new List(); @@ -763,7 +753,7 @@ public async Task MaxMessagesToRetrieve_ShouldLimitAndReturnMostRecentAsync() await Task.Delay(10); // Small delay to ensure different timestamps } - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, messages, []); await provider.InvokedAsync(context); // Wait for eventual consistency @@ -771,7 +761,7 @@ public async Task MaxMessagesToRetrieve_ShouldLimitAndReturnMostRecentAsync() // Act - Set max to 5 and retrieve provider.MaxMessagesToRetrieve = 5; - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var retrievedMessages = await provider.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); @@ -790,9 +780,11 @@ public async Task MaxMessagesToRetrieve_Null_ShouldReturnAllMessagesAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); const string ConversationId = "max-messages-null-test"; - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, ConversationId); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(ConversationId)); // Add 10 messages var messages = new List(); @@ -801,14 +793,14 @@ public async Task MaxMessagesToRetrieve_Null_ShouldReturnAllMessagesAsync() messages.Add(new ChatMessage(ChatRole.User, $"Message {i}")); } - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, messages, []); await provider.InvokedAsync(context); // Wait for eventual consistency await Task.Delay(100); // Act - No limit set (default null) - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var retrievedMessages = await provider.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); From 3b93fff7c427b50713eb3513b42db41c8c687b29 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 6 Feb 2026 14:18:21 +0000 Subject: [PATCH 05/28] Update 3rd party chat history storage sample. --- .../Program.cs | 72 +++++++++++++------ 1 file changed, 52 insertions(+), 20 deletions(-) diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs index 81a2beb3da..f00ad58a5e 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs @@ -3,7 +3,7 @@ #pragma warning disable CA1869 // Cache and reuse 'JsonSerializerOptions' instances // This sample shows how to create and use a simple AI agent with custom ChatHistoryProvider that stores chat history in a custom storage location. -// The state of the custom ChatHistoryProvider (SessionDbKey) is stored with the agent session, so that when the session is resumed later, +// The state of the custom ChatHistoryProvider (SessionDbKey) is stored in the AgentSession's StateBag, so that when the session is resumed later, // the chat history can be retrieved from the custom storage location. using System.Text.Json; @@ -35,9 +35,7 @@ Name = "Joker", ChatHistoryProviderFactory = (ctx, ct) => new ValueTask( // Create a new ChatHistoryProvider for this agent that stores chat history in a vector store. - // Each session must get its own copy of the VectorChatHistoryProvider, since the provider - // also contains the id that the chat history is stored under. - new VectorChatHistoryProvider(vectorStore, ctx.SerializedState, ctx.JsonSerializerOptions)) + new VectorChatHistoryProvider(vectorStore)) }); // Start a new session for the agent conversation. @@ -65,44 +63,62 @@ // We can access the VectorChatHistoryProvider via the session's GetService method if we need to read the key under which chat history is stored. var chatHistoryProvider = resumedSession.GetService()!; -Console.WriteLine($"\nSession is stored in vector store under key: {chatHistoryProvider.SessionDbKey}"); +Console.WriteLine($"\nSession is stored in vector store under key: {chatHistoryProvider.GetSessionDbKey(resumedSession)}"); namespace SampleApp { /// /// A sample implementation of that stores chat history in a vector store. + /// State (the session DB key) is stored in the so it roundtrips + /// automatically with session serialization. /// internal sealed class VectorChatHistoryProvider : ChatHistoryProvider { + private const string DefaultStateBagKey = "VectorChatHistoryProvider.State"; + private readonly VectorStore _vectorStore; + private readonly Func _stateInitializer; + private readonly string _stateKey; - public VectorChatHistoryProvider(VectorStore vectorStore, JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null) + public VectorChatHistoryProvider( + VectorStore vectorStore, + Func? stateInitializer = null, + string? stateKey = null) { this._vectorStore = vectorStore ?? throw new ArgumentNullException(nameof(vectorStore)); + this._stateInitializer = stateInitializer ?? (_ => new State(Guid.NewGuid().ToString("N"))); + this._stateKey = stateKey ?? DefaultStateBagKey; + } - if (serializedState.ValueKind is JsonValueKind.String) + public string GetSessionDbKey(AgentSession session) + => this.GetOrInitializeState(session).SessionDbKey; + + private State GetOrInitializeState(AgentSession? session) + { + var state = session?.StateBag.GetValue(this._stateKey); + state = this._stateInitializer(session); + if (session is not null) { - // Here we can deserialize the session id so that we can access the same messages as before the suspension. - this.SessionDbKey = serializedState.Deserialize(); + session.StateBag.SetValue(this._stateKey, state); } - } - public string? SessionDbKey { get; private set; } + return state; + } public override async ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) { + var state = this.GetOrInitializeState(context.Session); var collection = this._vectorStore.GetCollection("ChatHistory"); await collection.EnsureCollectionExistsAsync(cancellationToken); var records = await collection .GetAsync( - x => x.SessionId == this.SessionDbKey, 10, + x => x.SessionId == state.SessionDbKey, 10, new() { OrderBy = x => x.Descending(y => y.Timestamp) }, cancellationToken) .ToListAsync(cancellationToken); - var messages = records.ConvertAll(x => JsonSerializer.Deserialize(x.SerializedMessage!)!) -; + var messages = records.ConvertAll(x => JsonSerializer.Deserialize(x.SerializedMessage!)!); messages.Reverse(); return messages; } @@ -115,7 +131,7 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio return; } - this.SessionDbKey ??= Guid.NewGuid().ToString("N"); + var state = this.GetOrInitializeState(context.Session); var collection = this._vectorStore.GetCollection("ChatHistory"); await collection.EnsureCollectionExistsAsync(cancellationToken); @@ -126,17 +142,33 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio await collection.UpsertAsync(allNewMessages.Select(x => new ChatHistoryItem() { - Key = this.SessionDbKey + x.MessageId, + Key = state.SessionDbKey + x.MessageId, Timestamp = DateTimeOffset.UtcNow, - SessionId = this.SessionDbKey, + SessionId = state.SessionDbKey, SerializedMessage = JsonSerializer.Serialize(x), MessageText = x.Text }), cancellationToken); } - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) => - // We have to serialize the session id, so that on deserialization we can retrieve the messages using the same session id. - JsonSerializer.SerializeToElement(this.SessionDbKey); + public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) + { + // State is stored in the session StateBag, so there is nothing to serialize here. + using var doc = JsonDocument.Parse("{}"); + return doc.RootElement.Clone(); + } + + /// + /// Represents the per-session state stored in the . + /// + public sealed class State + { + public State(string sessionDbKey) + { + this.SessionDbKey = sessionDbKey ?? throw new ArgumentNullException(nameof(sessionDbKey)); + } + + public string SessionDbKey { get; } + } /// /// The data structure used to store chat history items in the vector store. From c4d80f1d035ffdf60ac6a0250184d97f14ec08f2 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 6 Feb 2026 15:50:40 +0000 Subject: [PATCH 06/28] Remove serialize method from providers --- .../Program.cs | 2 +- .../Program.cs | 2 +- .../Program.cs | 7 +-- .../Program.cs | 4 +- .../Program.cs | 2 +- .../Program.cs | 2 +- .../Program.cs | 9 +--- .../Agent_Step16_ChatReduction/Program.cs | 2 +- .../Program.cs | 24 +-------- .../AIContextProvider.cs | 13 ----- .../ChatHistoryProvider.cs | 9 ---- .../ChatHistoryProviderMessageFilter.cs | 7 --- .../InMemoryAgentSession.cs | 13 ++--- .../InMemoryChatHistoryProvider.cs | 17 ------ .../CosmosChatHistoryProvider.cs | 22 -------- .../CosmosDBChatExtensions.cs | 6 +-- .../Microsoft.Agents.AI.Mem0/Mem0Provider.cs | 17 ------ .../WorkflowChatHistoryProvider.cs | 7 --- .../ChatClient/ChatClientAgent.cs | 19 ++++--- .../ChatClient/ChatClientAgentOptions.cs | 39 +------------- .../ChatClient/ChatClientAgentSession.cs | 24 +++------ .../Memory/ChatHistoryMemoryProvider.cs | 16 ------ .../Microsoft.Agents.AI/TextSearchProvider.cs | 15 ------ .../AIContextProviderTests.cs | 8 --- .../ChatHistoryProviderMessageFilterTests.cs | 21 -------- .../ChatHistoryProviderTests.cs | 3 -- .../InMemoryChatHistoryProviderTests.cs | 14 ----- ...AzureAIProjectChatClientExtensionsTests.cs | 9 +--- .../Mem0ProviderTests.cs | 16 ------ .../ChatClient/ChatClientAgentOptionsTests.cs | 4 +- .../ChatClient/ChatClientAgentSessionTests.cs | 52 ++----------------- .../ChatClient/ChatClientAgentTests.cs | 20 +++---- ...tClientAgent_ChatHistoryManagementTests.cs | 20 +++---- .../ChatClientAgent_CreateSessionTests.cs | 4 +- ...ChatClientAgent_DeserializeSessionTests.cs | 4 +- .../Data/TextSearchProviderTests.cs | 19 ------- .../Memory/ChatHistoryMemoryProviderTests.cs | 24 --------- 37 files changed, 65 insertions(+), 431 deletions(-) diff --git a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step01_ChatHistoryMemory/Program.cs b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step01_ChatHistoryMemory/Program.cs index 4b2f739725..fc6a98df5a 100644 --- a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step01_ChatHistoryMemory/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step01_ChatHistoryMemory/Program.cs @@ -34,7 +34,7 @@ { ChatOptions = new() { Instructions = "You are good at telling jokes." }, Name = "Joker", - AIContextProviderFactory = (ctx, ct) => new ValueTask(new ChatHistoryMemoryProvider( + AIContextProviderFactory = (ct) => new ValueTask(new ChatHistoryMemoryProvider( vectorStore, collectionName: "chathistory", vectorDimensions: 3072, diff --git a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs index 23154de8f1..6072ab0e4c 100644 --- a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs @@ -31,7 +31,7 @@ .AsAIAgent(new ChatClientAgentOptions() { ChatOptions = new() { Instructions = "You are a friendly travel assistant. Use known memories about the user when responding, and do not invent details." }, - AIContextProviderFactory = (ctx, ct) => new ValueTask( + AIContextProviderFactory = (ct) => new ValueTask( // The stateInitializer can be used to customize the Mem0 scope per session and it will be called each time a session // is encountered by the Mem0Provider that does not already have Mem0Provider state stored on the session. // If each session should have its own Mem0 scope, you can create a new id per session via the stateInitializer, e.g.: diff --git a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs index dd451af127..0d1252329d 100644 --- a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs @@ -33,7 +33,7 @@ AIAgent agent = chatClient.AsAIAgent(new ChatClientAgentOptions() { ChatOptions = new() { Instructions = "You are a friendly assistant. Always address the user by their name." }, - AIContextProviderFactory = (ctx, ct) => new ValueTask(new UserInfoMemory(chatClient.AsIChatClient())) + AIContextProviderFactory = (ct) => new ValueTask(new UserInfoMemory(chatClient.AsIChatClient())) }); // Create a new session for the conversation. @@ -146,11 +146,6 @@ userInfo.UserAge is null ? Instructions = instructions.ToString() }); } - - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - return JsonSerializer.SerializeToElement(new UserInfo(), jsonSerializerOptions); - } } internal sealed class UserInfo diff --git a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs index c3956b9c13..0eb2fb6436 100644 --- a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs @@ -62,11 +62,11 @@ .AsAIAgent(new ChatClientAgentOptions { ChatOptions = new() { Instructions = "You are a helpful support specialist for Contoso Outdoors. Answer questions using the provided context and cite the source document when available." }, - AIContextProviderFactory = (ctx, ct) => new ValueTask(new TextSearchProvider(SearchAdapter, textSearchOptions)), + AIContextProviderFactory = (ct) => new ValueTask(new TextSearchProvider(SearchAdapter, textSearchOptions)), // Since we are using ChatCompletion which stores chat history locally, we can also add a message removal policy // that removes messages produced by the TextSearchProvider before they are added to the chat history, so that // we don't bloat chat history with all the search result messages. - ChatHistoryProviderFactory = (ctx, ct) => new ValueTask(new InMemoryChatHistoryProvider() + ChatHistoryProviderFactory = (ct) => new ValueTask(new InMemoryChatHistoryProvider() .WithAIContextProviderMessageRemoval()), }); diff --git a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step02_CustomVectorStoreRAG/Program.cs b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step02_CustomVectorStoreRAG/Program.cs index 4ab1a2df6d..17e50999b9 100644 --- a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step02_CustomVectorStoreRAG/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step02_CustomVectorStoreRAG/Program.cs @@ -71,7 +71,7 @@ .AsAIAgent(new ChatClientAgentOptions { ChatOptions = new() { Instructions = "You are a helpful support specialist for the Microsoft Agent Framework. Answer questions using the provided context and cite the source document when available. Keep responses brief." }, - AIContextProviderFactory = (ctx, ct) => new ValueTask(new TextSearchProvider(SearchAdapter, textSearchOptions)) + AIContextProviderFactory = (ct) => new ValueTask(new TextSearchProvider(SearchAdapter, textSearchOptions)) }); AgentSession session = await agent.CreateSessionAsync(); diff --git a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step03_CustomRAGDataSource/Program.cs b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step03_CustomRAGDataSource/Program.cs index a4b5a10ba1..f1c6256cd2 100644 --- a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step03_CustomRAGDataSource/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step03_CustomRAGDataSource/Program.cs @@ -29,7 +29,7 @@ .AsAIAgent(new ChatClientAgentOptions { ChatOptions = new() { Instructions = "You are a helpful support specialist for Contoso Outdoors. Answer questions using the provided context and cite the source document when available." }, - AIContextProviderFactory = (ctx, ct) => new ValueTask(new TextSearchProvider(MockSearchAsync, textSearchOptions)) + AIContextProviderFactory = (ct) => new ValueTask(new TextSearchProvider(MockSearchAsync, textSearchOptions)) }); AgentSession session = await agent.CreateSessionAsync(); diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs index f00ad58a5e..418691144a 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs @@ -33,7 +33,7 @@ { ChatOptions = new() { Instructions = "You are good at telling jokes." }, Name = "Joker", - ChatHistoryProviderFactory = (ctx, ct) => new ValueTask( + ChatHistoryProviderFactory = (ct) => new ValueTask( // Create a new ChatHistoryProvider for this agent that stores chat history in a vector store. new VectorChatHistoryProvider(vectorStore)) }); @@ -150,13 +150,6 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio }), cancellationToken); } - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - // State is stored in the session StateBag, so there is nothing to serialize here. - using var doc = JsonDocument.Parse("{}"); - return doc.RootElement.Clone(); - } - /// /// Represents the per-session state stored in the . /// diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs index 0c97334d10..83bc33ed25 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs @@ -24,7 +24,7 @@ { ChatOptions = new() { Instructions = "You are good at telling jokes." }, Name = "Joker", - ChatHistoryProviderFactory = (ctx, ct) => new ValueTask(new InMemoryChatHistoryProvider(chatReducer: new MessageCountingChatReducer(2))) + ChatHistoryProviderFactory = (ct) => new ValueTask(new InMemoryChatHistoryProvider(chatReducer: new MessageCountingChatReducer(2))) }); AgentSession session = await agent.CreateSessionAsync(); diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs index bc7a1ababb..7b86959a95 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs @@ -44,13 +44,13 @@ You are a helpful personal assistant. You manage a TODO list for the user. When the user has completed one of the tasks it can be removed from the TODO list. Only provide the list of TODO items if asked. You remind users of upcoming calendar events when the user interacts with you. """ }, - ChatHistoryProviderFactory = (ctx, ct) => new ValueTask(new InMemoryChatHistoryProvider() + ChatHistoryProviderFactory = (ct) => new ValueTask(new InMemoryChatHistoryProvider() // Use WithAIContextProviderMessageRemoval, so that we don't store the messages from the AI context provider in the chat history. // You may want to store these messages, depending on their content and your requirements. .WithAIContextProviderMessageRemoval()), // Add an AI context provider that maintains a todo list for the agent and one that provides upcoming calendar entries. // Wrap these in an AI context provider that aggregates the other two. - AIContextProviderFactory = (ctx, ct) => new ValueTask(new AggregatingAIContextProvider([ + AIContextProviderFactory = (ct) => new ValueTask(new AggregatingAIContextProvider([ new TodoListAIContextProvider(), new CalendarSearchAIContextProvider(loadNextThreeCalendarEvents) ])), @@ -135,9 +135,6 @@ private static void AddTodoItem(AgentSession? session, string item) items.Add(item); SetTodoItems(session, items); } - - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) => - JsonSerializer.SerializeToElement(new List(), jsonSerializerOptions); } /// @@ -193,22 +190,5 @@ public override async ValueTask InvokingAsync(InvokingContext context Instructions = string.Join("\n", results.Select(r => r.Instructions).Where(s => !string.IsNullOrEmpty(s))) }; } - - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - Dictionary elements = new(); - foreach (var provider in this._providers) - { - JsonElement element = provider.Serialize(jsonSerializerOptions); - - // Don't try to store state for any providers that aren't producing any. - if (element.ValueKind != JsonValueKind.Undefined && element.ValueKind != JsonValueKind.Null) - { - elements[provider.GetType().Name] = element; - } - } - - return JsonSerializer.SerializeToElement(elements, jsonSerializerOptions); - } } } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs index f79b0a851d..bc864393eb 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -74,18 +73,6 @@ public abstract class AIContextProvider public virtual ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) => default; - /// - /// Serializes the current object's state to a using the specified serialization options. - /// - /// The JSON serialization options to use for the serialization process. - /// A representation of the object's state, or a default if the provider has no serializable state. - /// - /// The default implementation returns a default . Override this method if the provider - /// maintains state that should be preserved across sessions or distributed scenarios. - /// - public virtual JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => default; - /// Asks the for an object of the specified type . /// The type of object being requested. /// An optional key that can be used to help identify the target service. diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs index 352bae3355..7e36c5fc33 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -26,7 +25,6 @@ namespace Microsoft.Agents.AI; /// Storing chat messages with proper ordering and metadata preservation /// Retrieving messages in chronological order for agent context /// Managing storage limits through truncation, summarization, or other strategies -/// Supporting serialization for thread persistence and migration /// /// /// @@ -94,13 +92,6 @@ public abstract class ChatHistoryProvider /// public abstract ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default); - /// - /// Serializes the current object's state to a using the specified serialization options. - /// - /// The JSON serialization options to use. - /// A representation of the object's state. - public abstract JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null); - /// Asks the for an object of the specified type . /// The type of object being requested. /// An optional key that can be used to help identify the target service. diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderMessageFilter.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderMessageFilter.cs index df7b536ea2..4ab8628ac7 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderMessageFilter.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderMessageFilter.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -65,10 +64,4 @@ public override ValueTask InvokedAsync(InvokedContext context, CancellationToken return this._innerProvider.InvokedAsync(context, cancellationToken); } - - /// - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - return this._innerProvider.Serialize(jsonSerializerOptions); - } } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryAgentSession.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryAgentSession.cs index 6e964914b9..5159431d9c 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryAgentSession.cs @@ -63,7 +63,7 @@ protected InMemoryAgentSession(IEnumerable messages) /// A representing the serialized state of the session. /// Optional settings for customizing the JSON deserialization process. /// - /// Optional factory function to create the from its serialized state. + /// Optional factory function to create the . /// If not provided, a default factory will be used that creates a basic . /// /// The is not a JSON object. @@ -75,7 +75,7 @@ protected InMemoryAgentSession(IEnumerable messages) protected InMemoryAgentSession( JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, - Func? chatHistoryProviderFactory = null) + Func? chatHistoryProviderFactory = null) { if (serializedState.ValueKind != JsonValueKind.Object) { @@ -85,9 +85,7 @@ protected InMemoryAgentSession( var state = serializedState.Deserialize( AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(InMemoryAgentSessionState))) as InMemoryAgentSessionState; - this.ChatHistoryProvider = - chatHistoryProviderFactory?.Invoke(state?.ChatHistoryProviderState ?? default, jsonSerializerOptions) ?? - new InMemoryChatHistoryProvider(); + this.ChatHistoryProvider = chatHistoryProviderFactory?.Invoke() ?? new InMemoryChatHistoryProvider(); if (state?.StateBag is { ValueKind: JsonValueKind.Object } stateBagElement) { @@ -107,11 +105,8 @@ protected InMemoryAgentSession( /// A representation of the object's state. protected internal virtual JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { - var chatHistoryProviderState = this.ChatHistoryProvider.Serialize(jsonSerializerOptions); - var state = new InMemoryAgentSessionState { - ChatHistoryProviderState = chatHistoryProviderState, StateBag = this.StateBag.Serialize(), }; @@ -127,8 +122,6 @@ protected internal virtual JsonElement Serialize(JsonSerializerOptions? jsonSeri internal sealed class InMemoryAgentSessionState { - public JsonElement? ChatHistoryProviderState { get; set; } - public JsonElement? StateBag { get; set; } } } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs index 638d6c2ec4..e968c8154c 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -157,22 +156,6 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio } } - /// - /// Serializes the current provider state to a . - /// - /// Optional serializer options (ignored, source generated context is used). - /// An empty object. - /// - /// State is now stored in the and serialized as part of the session. - /// - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - // State is now stored in the session StateBag, so there is nothing to serialize here. - // Return an empty JSON object. - using var doc = JsonDocument.Parse("{}"); - return doc.RootElement.Clone(); - } - /// /// Represents the state of a stored in the . /// diff --git a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs index e8d129b25b..1071ac6565 100644 --- a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs @@ -433,28 +433,6 @@ private CosmosMessageDocument CreateMessageDocument(State state, ChatMessage mes }; } - /// - /// Serializes the current provider state to a . - /// - /// Optional serializer options (ignored). - /// An empty object. - /// - /// State is now stored in the and serialized as part of the session. - /// - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { -#pragma warning disable CA1513 // Use ObjectDisposedException.ThrowIf - not available on all target frameworks - if (this._disposed) - { - throw new ObjectDisposedException(this.GetType().FullName); - } -#pragma warning restore CA1513 - - // State is now stored in the session StateBag, so there is nothing to serialize here. - using var doc = JsonDocument.Parse("{}"); - return doc.RootElement.Clone(); - } - /// /// Gets the count of messages in this conversation. /// This is an additional utility method beyond the base contract. diff --git a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosDBChatExtensions.cs b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosDBChatExtensions.cs index b64ec9a30c..4ee2c5b5bb 100644 --- a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosDBChatExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosDBChatExtensions.cs @@ -41,7 +41,7 @@ public static ChatClientAgentOptions WithCosmosDBChatHistoryProvider( throw new ArgumentNullException(nameof(options)); } - options.ChatHistoryProviderFactory = (context, ct) => new ValueTask( + options.ChatHistoryProviderFactory = (ct) => new ValueTask( new CosmosChatHistoryProvider(connectionString, databaseId, containerId, stateInitializer ?? s_defaultStateInitializer)); return options; } @@ -78,7 +78,7 @@ public static ChatClientAgentOptions WithCosmosDBChatHistoryProviderUsingManaged throw new ArgumentNullException(nameof(tokenCredential)); } - options.ChatHistoryProviderFactory = (context, ct) => new ValueTask( + options.ChatHistoryProviderFactory = (ct) => new ValueTask( new CosmosChatHistoryProvider(accountEndpoint, tokenCredential, databaseId, containerId, stateInitializer ?? s_defaultStateInitializer)); return options; } @@ -108,7 +108,7 @@ public static ChatClientAgentOptions WithCosmosDBChatHistoryProvider( throw new ArgumentNullException(nameof(options)); } - options.ChatHistoryProviderFactory = (context, ct) => new ValueTask( + options.ChatHistoryProviderFactory = (ct) => new ValueTask( new CosmosChatHistoryProvider(cosmosClient, databaseId, containerId, stateInitializer ?? s_defaultStateInitializer)); return options; } diff --git a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs index a49c282dac..d26b505a7e 100644 --- a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs @@ -4,7 +4,6 @@ using System.Collections.Generic; using System.Linq; using System.Net.Http; -using System.Text.Json; using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; @@ -222,22 +221,6 @@ public Task ClearStoredMemoriesAsync(AgentSession session, CancellationToken can cancellationToken); } - /// - /// Serializes the current provider state to a containing any overridden prompts or descriptions. - /// - /// Optional serializer options (ignored, source generated context is used). - /// An empty object. - /// - /// This method is deprecated. State is now stored in the and serialized as part of the session. - /// - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - // State is now stored in the session StateBag, so there is nothing to serialize here. - // Return an empty JSON object. - using var doc = JsonDocument.Parse("{}"); - return doc.RootElement.Clone(); - } - private async Task PersistMessagesAsync(Mem0ProviderScope storageScope, IEnumerable messages, CancellationToken cancellationToken) { foreach (var message in messages) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs index 8528bda3b5..708100b5ba 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs @@ -2,7 +2,6 @@ using System.Collections.Generic; using System.Linq; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -74,10 +73,4 @@ public void UpdateBookmark(AgentSession session) var state = this.GetOrInitializeState(session); state.Bookmark = state.Messages.Count; } - - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - return JsonSerializer.SerializeToElement(null, - WorkflowsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(StoreState))); - } } diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index 23e120f14f..1048ba3998 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -305,11 +305,11 @@ protected override async IAsyncEnumerable RunCoreStreamingA public override async ValueTask CreateSessionAsync(CancellationToken cancellationToken = default) { ChatHistoryProvider? chatHistoryProvider = this._agentOptions?.ChatHistoryProviderFactory is not null - ? await this._agentOptions.ChatHistoryProviderFactory.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }, cancellationToken).ConfigureAwait(false) + ? await this._agentOptions.ChatHistoryProviderFactory.Invoke(cancellationToken).ConfigureAwait(false) : null; AIContextProvider? contextProvider = this._agentOptions?.AIContextProviderFactory is not null - ? await this._agentOptions.AIContextProviderFactory.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }, cancellationToken).ConfigureAwait(false) + ? await this._agentOptions.AIContextProviderFactory.Invoke(cancellationToken).ConfigureAwait(false) : null; return new ChatClientAgentSession @@ -340,7 +340,7 @@ public override async ValueTask CreateSessionAsync(CancellationTok public async ValueTask CreateSessionAsync(string conversationId, CancellationToken cancellationToken = default) { AIContextProvider? contextProvider = this._agentOptions?.AIContextProviderFactory is not null - ? await this._agentOptions.AIContextProviderFactory.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }, cancellationToken).ConfigureAwait(false) + ? await this._agentOptions.AIContextProviderFactory.Invoke(cancellationToken).ConfigureAwait(false) : null; return new ChatClientAgentSession() @@ -375,7 +375,7 @@ public async ValueTask CreateSessionAsync(string conversationId, C public async ValueTask CreateSessionAsync(ChatHistoryProvider chatHistoryProvider, CancellationToken cancellationToken = default) { AIContextProvider? contextProvider = this._agentOptions?.AIContextProviderFactory is not null - ? await this._agentOptions.AIContextProviderFactory.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }, cancellationToken).ConfigureAwait(false) + ? await this._agentOptions.AIContextProviderFactory.Invoke(cancellationToken).ConfigureAwait(false) : null; return new ChatClientAgentSession() @@ -401,17 +401,16 @@ public override JsonElement SerializeSession(AgentSession session, JsonSerialize /// public override async ValueTask DeserializeSessionAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) { - Func>? chatHistoryProviderFactory = this._agentOptions?.ChatHistoryProviderFactory is null ? + Func>? chatHistoryProviderFactory = this._agentOptions?.ChatHistoryProviderFactory is null ? null : - (jse, jso, ct) => this._agentOptions.ChatHistoryProviderFactory.Invoke(new() { SerializedState = jse, JsonSerializerOptions = jso }, ct); + (ct) => this._agentOptions.ChatHistoryProviderFactory.Invoke(ct); - Func>? aiContextProviderFactory = this._agentOptions?.AIContextProviderFactory is null ? + Func>? aiContextProviderFactory = this._agentOptions?.AIContextProviderFactory is null ? null : - (jse, jso, ct) => this._agentOptions.AIContextProviderFactory.Invoke(new() { SerializedState = jse, JsonSerializerOptions = jso }, ct); + (ct) => this._agentOptions.AIContextProviderFactory.Invoke(ct); return await ChatClientAgentSession.DeserializeAsync( serializedState, - jsonSerializerOptions, chatHistoryProviderFactory, aiContextProviderFactory, cancellationToken).ConfigureAwait(false); @@ -807,7 +806,7 @@ private async Task UpdateSessionWithTypeAndConversationIdAsync(ChatClientAgentSe // the session has no ChatHistoryProvider yet, we should update the session with the custom ChatHistoryProvider or // default InMemoryChatHistoryProvider so that it has somewhere to store the chat history. session.ChatHistoryProvider ??= this._agentOptions?.ChatHistoryProviderFactory is not null - ? await this._agentOptions.ChatHistoryProviderFactory.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }, cancellationToken).ConfigureAwait(false) + ? await this._agentOptions.ChatHistoryProviderFactory.Invoke(cancellationToken).ConfigureAwait(false) : new InMemoryChatHistoryProvider(); } } diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs index 6f8451e2b8..84d64d67e1 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -42,14 +41,14 @@ public sealed class ChatClientAgentOptions /// Gets or sets a factory function to create an instance of /// which will be used to provide chat history for this agent. /// - public Func>? ChatHistoryProviderFactory { get; set; } + public Func>? ChatHistoryProviderFactory { get; set; } /// /// Gets or sets a factory function to create an instance of /// which will be used to create a context provider for each new thread, and can then /// provide additional context for each agent run. /// - public Func>? AIContextProviderFactory { get; set; } + public Func>? AIContextProviderFactory { get; set; } /// /// Gets or sets a value indicating whether to use the provided instance as is, @@ -78,38 +77,4 @@ public ChatClientAgentOptions Clone() ChatHistoryProviderFactory = this.ChatHistoryProviderFactory, AIContextProviderFactory = this.AIContextProviderFactory, }; - - /// - /// Context object passed to the to create a new instance of . - /// - public sealed class AIContextProviderFactoryContext - { - /// - /// Gets or sets the serialized state of the , if any. - /// - /// if there is no state, e.g. when the is first created. - public JsonElement SerializedState { get; set; } - - /// - /// Gets or sets the JSON serialization options to use when deserializing the . - /// - public JsonSerializerOptions? JsonSerializerOptions { get; set; } - } - - /// - /// Context object passed to the to create a new instance of . - /// - public sealed class ChatHistoryProviderFactoryContext - { - /// - /// Gets or sets the serialized state of the , if any. - /// - /// if there is no state, e.g. when the is first created. - public JsonElement SerializedState { get; set; } - - /// - /// Gets or sets the JSON serialization options to use when deserializing the . - /// - public JsonSerializerOptions? JsonSerializerOptions { get; set; } - } } diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs index a5ace55884..935cf89323 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs @@ -116,22 +116,20 @@ internal set /// Creates a new instance of the class from previously serialized state. /// /// A representing the serialized state of the session. - /// Optional settings for customizing the JSON deserialization process. /// - /// An optional factory function to create a custom from its serialized state. + /// An optional factory function to create a custom . /// If not provided, the default will be used. /// /// - /// An optional factory function to create a custom from its serialized state. + /// An optional factory function to create a custom . /// If not provided, no context provider will be configured. /// /// The to monitor for cancellation requests. /// A task representing the asynchronous operation. The task result contains the deserialized . internal static async Task DeserializeAsync( JsonElement serializedState, - JsonSerializerOptions? jsonSerializerOptions = null, - Func>? chatHistoryProviderFactory = null, - Func>? aiContextProviderFactory = null, + Func>? chatHistoryProviderFactory = null, + Func>? aiContextProviderFactory = null, CancellationToken cancellationToken = default) { if (serializedState.ValueKind != JsonValueKind.Object) @@ -145,7 +143,7 @@ internal static async Task DeserializeAsync( var session = new ChatClientAgentSession(); session.AIContextProvider = aiContextProviderFactory is not null - ? await aiContextProviderFactory.Invoke(state?.AIContextProviderState ?? default, jsonSerializerOptions, cancellationToken).ConfigureAwait(false) + ? await aiContextProviderFactory.Invoke(cancellationToken).ConfigureAwait(false) : null; session.StateBag = AgentSessionStateBag.Deserialize(state?.StateBag ?? default); @@ -160,7 +158,7 @@ internal static async Task DeserializeAsync( session._chatHistoryProvider = chatHistoryProviderFactory is not null - ? await chatHistoryProviderFactory.Invoke(state?.ChatHistoryProviderState ?? default, jsonSerializerOptions, cancellationToken).ConfigureAwait(false) + ? await chatHistoryProviderFactory.Invoke(cancellationToken).ConfigureAwait(false) : new InMemoryChatHistoryProvider(); // default to an in-memory ChatHistoryProvider return session; @@ -169,15 +167,9 @@ chatHistoryProviderFactory is not null /// internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { - JsonElement? chatHistoryProviderState = this._chatHistoryProvider?.Serialize(jsonSerializerOptions); - - JsonElement? aiContextProviderState = this.AIContextProvider?.Serialize(jsonSerializerOptions); - var state = new SessionState { ConversationId = this.ConversationId, - ChatHistoryProviderState = chatHistoryProviderState is { ValueKind: not JsonValueKind.Undefined } ? chatHistoryProviderState : null, - AIContextProviderState = aiContextProviderState is { ValueKind: not JsonValueKind.Undefined } ? aiContextProviderState : null, StateBag = this.StateBag.Serialize(), }; @@ -201,10 +193,6 @@ internal sealed class SessionState { public string? ConversationId { get; set; } - public JsonElement? ChatHistoryProviderState { get; set; } - - public JsonElement? AIContextProviderState { get; set; } - public JsonElement? StateBag { get; set; } } } diff --git a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs index 8e9eade875..8f7bd02648 100644 --- a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs @@ -446,22 +446,6 @@ public void Dispose() GC.SuppressFinalize(this); } - /// - /// Serializes the current provider state to a containing any overridden prompts or descriptions. - /// - /// Optional serializer options (ignored, source generated context is used). - /// An empty object. - /// - /// This method is deprecated. State is now stored in the and serialized as part of the session. - /// - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - // State is now stored in the session StateBag, so there is nothing to serialize here. - // Return an empty JSON object. - using var doc = JsonDocument.Parse("{}"); - return doc.RootElement.Clone(); - } - private string? SanitizeLogData(string? data) => this._enableSensitiveTelemetryData ? data : ""; /// diff --git a/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs b/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs index 3ae7350dcb..91c597223f 100644 --- a/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs @@ -4,7 +4,6 @@ using System.Collections.Generic; using System.Linq; using System.Text; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -196,20 +195,6 @@ public override ValueTask InvokedAsync(InvokedContext context, CancellationToken return default; } - /// - /// Serializes the current provider state to a containing any overridden prompts or descriptions. - /// - /// Optional serializer options (ignored, source generated context is used). - /// A with overridden values, or default if nothing was overridden. - /// - /// This method is deprecated. State is now stored in the and serialized as part of the session. - /// - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - // State is now stored in the session StateBag, so there is nothing to serialize here. - return JsonSerializer.SerializeToElement(new TextSearchProviderState(), AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(TextSearchProviderState))); - } - /// /// Function callable by the AI model (when enabled) to perform an ad-hoc search. /// diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs index 94aa73858a..d8820f284b 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs @@ -24,14 +24,6 @@ public async Task InvokedAsync_ReturnsCompletedTaskAsync() Assert.Equal(default, task); } - [Fact] - public void Serialize_ReturnsEmptyElement() - { - var provider = new TestAIContextProvider(); - var actual = provider.Serialize(); - Assert.Equal(default, actual); - } - [Fact] public void InvokingContext_Constructor_ThrowsForNullMessages() { diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs index 4b955a43c0..f91c329cd7 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs @@ -184,25 +184,4 @@ ChatHistoryProvider.InvokedContext InvokedFilter(ChatHistoryProvider.InvokedCont Assert.Equal("[FILTERED] Hello", capturedContext.RequestMessages.First().Text); innerProviderMock.Verify(s => s.InvokedAsync(It.IsAny(), It.IsAny()), Times.Once); } - - [Fact] - public void Serialize_DelegatesToInnerProvider() - { - // Arrange - var innerProviderMock = new Mock(); - var expectedJson = JsonSerializer.SerializeToElement("data", TestJsonSerializerContext.Default.String); - - innerProviderMock - .Setup(s => s.Serialize(It.IsAny())) - .Returns(expectedJson); - - var filter = new ChatHistoryProviderMessageFilter(innerProviderMock.Object, x => x, x => x); - - // Act - var result = filter.Serialize(); - - // Assert - Assert.Equal(expectedJson.GetRawText(), result.GetRawText()); - innerProviderMock.Verify(s => s.Serialize(null), Times.Once); - } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs index 5e0fbe9817..38312d3e80 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs @@ -319,8 +319,5 @@ public override ValueTask> InvokingAsync(InvokingContex public override ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) => default; - - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => default; } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs index d3346b3ed7..fb81b0fc73 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs @@ -370,20 +370,6 @@ public async Task InvokingAsync_WithNullContext_ThrowsArgumentNullExceptionAsync await Assert.ThrowsAsync(() => provider.InvokingAsync(null!, CancellationToken.None).AsTask()); } - [Fact] - public void Serialize_ReturnsEmptyJsonObject() - { - // Arrange - var provider = new InMemoryChatHistoryProvider(); - - // Act - var result = provider.Serialize(); - - // Assert - Assert.Equal(System.Text.Json.JsonValueKind.Object, result.ValueKind); - Assert.Empty(result.EnumerateObject()); - } - public class TestAIContent(string testData) : AIContent { public string TestData => testData; diff --git a/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs index 447c195c83..36222afe32 100644 --- a/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs @@ -2322,7 +2322,7 @@ public async Task GetAIAgentAsync_WithAIContextProviderFactory_PreservesFactoryA { Name = "test-agent", ChatOptions = new ChatOptions { Instructions = "Test" }, - AIContextProviderFactory = (_, _) => + AIContextProviderFactory = (_) => { factoryInvoked = true; return new ValueTask(new TestAIContextProvider()); @@ -2350,7 +2350,7 @@ public async Task GetAIAgentAsync_WithChatHistoryProviderFactory_PreservesFactor { Name = "test-agent", ChatOptions = new ChatOptions { Instructions = "Test" }, - ChatHistoryProviderFactory = (_, _) => new ValueTask(new TestChatHistoryProvider()) + ChatHistoryProviderFactory = (_) => new ValueTask(new TestChatHistoryProvider()) }; // Act @@ -3160,11 +3160,6 @@ public override ValueTask InvokedAsync(InvokedContext context, CancellationToken { return default; } - - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - return default; - } } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs index 71a4c4fbd3..6babc8d7b8 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs @@ -341,22 +341,6 @@ public async Task ClearStoredMemoriesAsync_SendsDeleteWithQueryAsync() Assert.Equal("https://localhost/v1/memories/?app_id=app&agent_id=agent&run_id=session&user_id=user", delete.RequestMessage.RequestUri!.AbsoluteUri); } - [Fact] - public void Serialize_ReturnsEmptyJsonObject() - { - // Arrange - var storageScope = new Mem0ProviderScope { ApplicationId = "app", AgentId = "agent", ThreadId = "session", UserId = "user" }; - var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), options: new() { ContextPrompt = "Custom:" }, loggerFactory: this._loggerFactoryMock.Object); - - // Act - var stateElement = sut.Serialize(); - - // Assert - using JsonDocument doc = JsonDocument.Parse(stateElement.GetRawText()); - Assert.Equal(JsonValueKind.Object, doc.RootElement.ValueKind); - Assert.Empty(doc.RootElement.EnumerateObject()); - } - [Fact] public async Task InvokingAsync_ShouldNotThrow_WhenSearchFailsAsync() { diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentOptionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentOptionsTests.cs index 8502550d2c..2f89c55d62 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentOptionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentOptionsTests.cs @@ -118,10 +118,10 @@ public void Clone_CreatesDeepCopyWithSameValues() var tools = new List { AIFunctionFactory.Create(() => "test") }; static ValueTask ChatHistoryProviderFactoryAsync( - ChatClientAgentOptions.ChatHistoryProviderFactoryContext ctx, CancellationToken ct) => new(new Mock().Object); + CancellationToken ct) => new(new Mock().Object); static ValueTask AIContextProviderFactoryAsync( - ChatClientAgentOptions.AIContextProviderFactoryContext ctx, CancellationToken ct) => new(new Mock().Object); + CancellationToken ct) => new(new Mock().Object); var original = new ChatClientAgentOptions() { diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs index 977b9077d7..81b213e488 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections.Generic; using System.Linq; using System.Text.Json; using System.Threading.Tasks; @@ -150,7 +149,7 @@ public async Task VerifyDeserializeWithAIContextProviderAsync() Mock mockProvider = new(); // Act - var session = await ChatClientAgentSession.DeserializeAsync(json, aiContextProviderFactory: (_, _, _) => new(mockProvider.Object)); + var session = await ChatClientAgentSession.DeserializeAsync(json, aiContextProviderFactory: (_) => new(mockProvider.Object)); // Assert Assert.Null(session.ChatHistoryProvider); @@ -176,7 +175,7 @@ public async Task VerifyDeserializeWithStateBagAsync() Mock mockProvider = new(); // Act - var session = await ChatClientAgentSession.DeserializeAsync(json, aiContextProviderFactory: (_, _, _) => new(mockProvider.Object)); + var session = await ChatClientAgentSession.DeserializeAsync(json, aiContextProviderFactory: (_) => new(mockProvider.Object)); // Assert var dog = session.StateBag.GetValue("dog", TestJsonSerializerContext.Default.Options); @@ -239,10 +238,6 @@ public void VerifySessionSerializationWithMessages() Assert.False(json.TryGetProperty("conversationId", out _)); - // chatHistoryProviderState should be an empty JSON object (state is in StateBag now) - Assert.True(json.TryGetProperty("chatHistoryProviderState", out var chatHistoryProviderStateProperty)); - Assert.Equal(JsonValueKind.Object, chatHistoryProviderStateProperty.ValueKind); - // Messages should be stored in the stateBag Assert.True(json.TryGetProperty("stateBag", out var stateBagProperty)); Assert.Equal(JsonValueKind.Object, stateBagProperty.ValueKind); @@ -264,32 +259,6 @@ public void VerifySessionSerializationWithMessages() Assert.Equal("TestContent", textContent.GetProperty("text").GetString()); } - [Fact] - public void VerifySessionSerializationWithWithAIContextProvider() - { - // Arrange - Mock mockProvider = new(); - mockProvider - .Setup(m => m.Serialize(It.IsAny())) - .Returns(JsonSerializer.SerializeToElement(["CP1"], TestJsonSerializerContext.Default.StringArray)); - - var session = new ChatClientAgentSession - { - AIContextProvider = mockProvider.Object - }; - - // Act - var json = session.Serialize(); - - // Assert - Assert.Equal(JsonValueKind.Object, json.ValueKind); - Assert.True(json.TryGetProperty("aiContextProviderState", out var providerStateProperty)); - Assert.Equal(JsonValueKind.Array, providerStateProperty.ValueKind); - Assert.Single(providerStateProperty.EnumerateArray()); - Assert.Equal("CP1", providerStateProperty.EnumerateArray().First().GetString()); - mockProvider.Verify(m => m.Serialize(It.IsAny()), Times.Once); - } - [Fact] public void VerifySessionSerializationWithWithStateBag() { @@ -322,14 +291,7 @@ public void VerifySessionSerializationWithCustomOptions() JsonSerializerOptions options = new() { PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower }; options.TypeInfoResolverChain.Add(AgentAbstractionsJsonUtilities.DefaultOptions.TypeInfoResolver!); - var chatHistoryProviderStateElement = JsonSerializer.SerializeToElement( - new Dictionary { ["Key"] = "TestValue" }, - TestJsonSerializerContext.Default.DictionaryStringObject); - var chatHistoryProviderMock = new Mock(); - chatHistoryProviderMock - .Setup(m => m.Serialize(options)) - .Returns(chatHistoryProviderStateElement); session.ChatHistoryProvider = chatHistoryProviderMock.Object; // Act @@ -338,15 +300,7 @@ public void VerifySessionSerializationWithCustomOptions() // Assert Assert.Equal(JsonValueKind.Object, json.ValueKind); - Assert.False(json.TryGetProperty("conversationId", out var idProperty)); - - Assert.True(json.TryGetProperty("chatHistoryProviderState", out var chatHistoryProviderStateProperty)); - Assert.Equal(JsonValueKind.Object, chatHistoryProviderStateProperty.ValueKind); - - Assert.True(chatHistoryProviderStateProperty.TryGetProperty("Key", out var keyProperty)); - Assert.Equal("TestValue", keyProperty.GetString()); - - chatHistoryProviderMock.Verify(m => m.Serialize(options), Times.Once); + Assert.False(json.TryGetProperty("conversationId", out var _)); } #endregion Serialize Tests diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs index 7668bbdb70..e8c34e13a3 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs @@ -353,7 +353,7 @@ public async Task RunAsyncInvokesAIContextProviderAndUsesResultAsync() .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny())) .Returns(new ValueTask()); - ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = (_, _) => new(mockProvider.Object), ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); + ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = (_) => new(mockProvider.Object), ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); // Act var session = await agent.CreateSessionAsync() as ChatClientAgentSession; @@ -416,7 +416,7 @@ public async Task RunAsyncInvokesAIContextProviderWhenGetResponseFailsAsync() .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny())) .Returns(new ValueTask()); - ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = (_, _) => new(mockProvider.Object), ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); + ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = (_) => new(mockProvider.Object), ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); // Act await Assert.ThrowsAsync(() => agent.RunAsync(requestMessages)); @@ -462,7 +462,7 @@ public async Task RunAsyncInvokesAIContextProviderAndSucceedsWithEmptyAIContextA .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny())) .ReturnsAsync(new AIContext()); - ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = (_, _) => new(mockProvider.Object), ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); + ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = (_) => new(mockProvider.Object), ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); // Act await agent.RunAsync([new(ChatRole.User, "user message")]); @@ -1288,8 +1288,8 @@ public async Task RunStreamingAsyncUsesChatHistoryProviderWhenNoConversationIdRe It.IsAny>(), It.IsAny(), It.IsAny())).Returns(ToAsyncEnumerableAsync(returnUpdates)); - Mock>> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(new InMemoryChatHistoryProvider()); + Mock>> mockFactory = new(); + mockFactory.Setup(f => f(It.IsAny())).ReturnsAsync(new InMemoryChatHistoryProvider()); ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, @@ -1306,7 +1306,7 @@ public async Task RunStreamingAsyncUsesChatHistoryProviderWhenNoConversationIdRe Assert.Equal(2, historyMessages.Count); Assert.Equal("test", historyMessages[0].Text); Assert.Equal("what?", historyMessages[1].Text); - mockFactory.Verify(f => f(It.IsAny(), It.IsAny()), Times.Once); + mockFactory.Verify(f => f(It.IsAny()), Times.Once); } /// @@ -1327,8 +1327,8 @@ public async Task RunStreamingAsyncThrowsWhenChatHistoryProviderFactoryProvidedA It.IsAny>(), It.IsAny(), It.IsAny())).Returns(ToAsyncEnumerableAsync(returnUpdates)); - Mock>> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(new InMemoryChatHistoryProvider()); + Mock>> mockFactory = new(); + mockFactory.Setup(f => f(It.IsAny())).ReturnsAsync(new InMemoryChatHistoryProvider()); ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, @@ -1389,7 +1389,7 @@ public async Task RunStreamingAsyncInvokesAIContextProviderAndUsesResultAsync() options: new() { ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] }, - AIContextProviderFactory = (_, _) => new(mockProvider.Object) + AIContextProviderFactory = (_) => new(mockProvider.Object) }); // Act @@ -1459,7 +1459,7 @@ public async Task RunStreamingAsyncInvokesAIContextProviderWhenGetResponseFailsA options: new() { ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] }, - AIContextProviderFactory = (_, _) => new(mockProvider.Object) + AIContextProviderFactory = (_) => new(mockProvider.Object) }); // Act diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs index 3e471e38a5..7e0538ec17 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs @@ -191,8 +191,8 @@ public async Task RunAsync_UsesChatHistoryProviderFactory_WhenProvidedAndNoConve It.IsAny(), It.IsAny())).Returns(new ValueTask()); - Mock>> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(mockChatHistoryProvider.Object); + Mock>> mockFactory = new(); + mockFactory.Setup(f => f(It.IsAny())).ReturnsAsync(mockChatHistoryProvider.Object); ChatClientAgent agent = new(mockService.Object, options: new() { @@ -220,7 +220,7 @@ public async Task RunAsync_UsesChatHistoryProviderFactory_WhenProvidedAndNoConve It.Is(x => x.RequestMessages.Count() == 1 && x.ChatHistoryProviderMessages != null && x.ChatHistoryProviderMessages.Count() == 1 && x.ResponseMessages!.Count() == 1), It.IsAny()), Times.Once); - mockFactory.Verify(f => f(It.IsAny(), It.IsAny()), Times.Once); + mockFactory.Verify(f => f(It.IsAny()), Times.Once); } /// @@ -239,8 +239,8 @@ public async Task RunAsync_NotifiesChatHistoryProvider_OnFailureAsync() Mock mockChatHistoryProvider = new(); - Mock>> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(mockChatHistoryProvider.Object); + Mock>> mockFactory = new(); + mockFactory.Setup(f => f(It.IsAny())).ReturnsAsync(mockChatHistoryProvider.Object); ChatClientAgent agent = new(mockService.Object, options: new() { @@ -258,7 +258,7 @@ public async Task RunAsync_NotifiesChatHistoryProvider_OnFailureAsync() It.Is(x => x.RequestMessages.Count() == 1 && x.ResponseMessages == null && x.InvokeException!.Message == "Test Error"), It.IsAny()), Times.Once); - mockFactory.Verify(f => f(It.IsAny(), It.IsAny()), Times.Once); + mockFactory.Verify(f => f(It.IsAny()), Times.Once); } /// @@ -274,8 +274,8 @@ public async Task RunAsync_Throws_WhenChatHistoryProviderFactoryProvidedAndConve It.IsAny>(), It.IsAny(), It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")]) { ConversationId = "ConvId" }); - Mock>> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(new InMemoryChatHistoryProvider()); + Mock>> mockFactory = new(); + mockFactory.Setup(f => f(It.IsAny())).ReturnsAsync(new InMemoryChatHistoryProvider()); ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, @@ -326,8 +326,8 @@ public async Task RunAsync_UsesOverrideChatHistoryProvider_WhenProvidedViaAdditi It.IsAny(), It.IsAny())).Throws(FailException.ForFailure("Base ChatHistoryProvider shouldn't be used.")); - Mock>> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(mockFactoryChatHistoryProvider.Object); + Mock>> mockFactory = new(); + mockFactory.Setup(f => f(It.IsAny())).ReturnsAsync(mockFactoryChatHistoryProvider.Object); ChatClientAgent agent = new(mockService.Object, options: new() { diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_CreateSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_CreateSessionTests.cs index 86220a6462..86d1da2a2c 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_CreateSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_CreateSessionTests.cs @@ -21,7 +21,7 @@ public async Task CreateSession_UsesAIContextProviderFactory_IfProvidedAsync() var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions { ChatOptions = new() { Instructions = "Test instructions" }, - AIContextProviderFactory = (_, _) => + AIContextProviderFactory = (_) => { factoryCalled = true; return new ValueTask(mockContextProvider.Object); @@ -48,7 +48,7 @@ public async Task CreateSession_UsesChatHistoryProviderFactory_IfProvidedAsync() var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions { ChatOptions = new() { Instructions = "Test instructions" }, - ChatHistoryProviderFactory = (_, _) => + ChatHistoryProviderFactory = (_) => { factoryCalled = true; return new ValueTask(mockChatHistoryProvider.Object); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeSessionTests.cs index 014cb1483b..12ae51ad4f 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeSessionTests.cs @@ -22,7 +22,7 @@ public async Task DeserializeSession_UsesAIContextProviderFactory_IfProvidedAsyn var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions { ChatOptions = new() { Instructions = "Test instructions" }, - AIContextProviderFactory = (_, _) => + AIContextProviderFactory = (_) => { factoryCalled = true; return new ValueTask(mockContextProvider.Object); @@ -55,7 +55,7 @@ public async Task DeserializeSession_UsesChatHistoryProviderFactory_IfProvidedAs var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions { ChatOptions = new() { Instructions = "Test instructions" }, - ChatHistoryProviderFactory = (_, _) => + ChatHistoryProviderFactory = (_) => { factoryCalled = true; return new ValueTask(mockChatHistoryProvider.Object); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs index c9e39d7d46..2cc83583f4 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs @@ -499,25 +499,6 @@ public async Task InvokingAsync_WithRecentMessageRolesIncluded_ShouldFilterRoles #region Serialization Tests - [Fact] - public void Serialize_ShouldReturnEmptyState() - { - // Arrange - var options = new TextSearchProviderOptions - { - SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, - RecentMessageMemoryLimit = 3 - }; - var provider = new TextSearchProvider(this.NoResultSearchAsync, options); - - // Act - var state = provider.Serialize(); - - // Assert - State is now stored in session StateBag, so provider.Serialize() returns empty state - Assert.Equal(JsonValueKind.Object, state.ValueKind); - Assert.False(state.TryGetProperty("recentMessagesText", out _)); - } - [Fact] public async Task InvokedAsync_ShouldPersistMessagesToSessionStateBagAsync() { diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs index f290f9c0db..3c2e0e13c4 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -506,29 +505,6 @@ public async Task InvokingAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsy #endregion - #region Serialization Tests - - [Fact] - public void Serialize_ReturnsEmptyState() - { - // Arrange - var provider = new ChatHistoryMemoryProvider( - this._vectorStoreMock.Object, - TestCollectionName, - 1, - _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }, null)); - - // Act - var stateElement = provider.Serialize(); - - // Assert - Serialize returns empty object since state is now stored in StateBag - using JsonDocument doc = JsonDocument.Parse(stateElement.GetRawText()); - Assert.Equal(JsonValueKind.Object, doc.RootElement.ValueKind); - Assert.Empty(doc.RootElement.EnumerateObject()); - } - - #endregion - private static async IAsyncEnumerable ToAsyncEnumerableAsync(IEnumerable values) { await Task.Yield(); From 076bcdafe200611474b7882e4377fa8084da6d20 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 6 Feb 2026 17:16:35 +0000 Subject: [PATCH 07/28] Replacing provider factories with properties --- .../Program.cs | 4 +- .../Program.cs | 13 ++-- .../Program.cs | 2 +- .../Program.cs | 6 +- .../Program.cs | 2 +- .../Program.cs | 2 +- .../Program.cs | 5 +- .../Agent_Step16_ChatReduction/Program.cs | 2 +- .../Program.cs | 8 +-- .../PersistentAgentsClientExtensions.cs | 4 +- .../AzureAIProjectChatClientExtensions.cs | 4 +- .../CosmosDBChatExtensions.cs | 13 ++-- .../OpenAIAssistantClientExtensions.cs | 4 +- .../ChatClient/ChatClientAgent.cs | 64 ++++++------------- .../ChatClient/ChatClientAgentOptions.cs | 18 ++---- .../ChatClient/ChatClientAgentSession.cs | 29 ++++----- ...AzureAIProjectChatClientExtensionsTests.cs | 19 ++---- .../ChatClient/ChatClientAgentOptionsTests.cs | 36 +++++------ .../ChatClient/ChatClientAgentSessionTests.cs | 26 ++++---- .../ChatClient/ChatClientAgentTests.cs | 23 +++---- ...tClientAgent_ChatHistoryManagementTests.cs | 31 +++------ .../ChatClientAgent_CreateSessionTests.cs | 20 ++---- ...ChatClientAgent_DeserializeSessionTests.cs | 20 ++---- 23 files changed, 131 insertions(+), 224 deletions(-) diff --git a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step01_ChatHistoryMemory/Program.cs b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step01_ChatHistoryMemory/Program.cs index fc6a98df5a..54ee8b5008 100644 --- a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step01_ChatHistoryMemory/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step01_ChatHistoryMemory/Program.cs @@ -34,7 +34,7 @@ { ChatOptions = new() { Instructions = "You are good at telling jokes." }, Name = "Joker", - AIContextProviderFactory = (ct) => new ValueTask(new ChatHistoryMemoryProvider( + AIContextProvider = new ChatHistoryMemoryProvider( vectorStore, collectionName: "chathistory", vectorDimensions: 3072, @@ -48,7 +48,7 @@ storageScope: new() { UserId = "UID1", SessionId = Guid.NewGuid().ToString() }, // Configure the scope which would be used to search for relevant prior messages. // In this case, we are searching for any messages for the user across all sessions. - searchScope: new() { UserId = "UID1" }))) + searchScope: new() { UserId = "UID1" })) }); // Start a new session for the agent conversation. diff --git a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs index 6072ab0e4c..0c6406133e 100644 --- a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs @@ -31,13 +31,12 @@ .AsAIAgent(new ChatClientAgentOptions() { ChatOptions = new() { Instructions = "You are a friendly travel assistant. Use known memories about the user when responding, and do not invent details." }, - AIContextProviderFactory = (ct) => new ValueTask( - // The stateInitializer can be used to customize the Mem0 scope per session and it will be called each time a session - // is encountered by the Mem0Provider that does not already have Mem0Provider state stored on the session. - // If each session should have its own Mem0 scope, you can create a new id per session via the stateInitializer, e.g.: - // new Mem0Provider(mem0HttpClient, stateInitializer: _ => new(new Mem0ProviderScope() { ThreadId = Guid.NewGuid().ToString() })) - // In our case we are storing memories scoped by application and user instead so that memories are retained across threads. - new Mem0Provider(mem0HttpClient, stateInitializer: _ => new(new Mem0ProviderScope() { ApplicationId = "getting-started-agents", UserId = "sample-user" }))) + // The stateInitializer can be used to customize the Mem0 scope per session and it will be called each time a session + // is encountered by the Mem0Provider that does not already have Mem0Provider state stored on the session. + // If each session should have its own Mem0 scope, you can create a new id per session via the stateInitializer, e.g.: + // new Mem0Provider(mem0HttpClient, stateInitializer: _ => new(new Mem0ProviderScope() { ThreadId = Guid.NewGuid().ToString() })) + // In our case we are storing memories scoped by application and user instead so that memories are retained across threads. + AIContextProvider = new Mem0Provider(mem0HttpClient, stateInitializer: _ => new(new Mem0ProviderScope() { ApplicationId = "getting-started-agents", UserId = "sample-user" })) }); AgentSession session = await agent.CreateSessionAsync(); diff --git a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs index 0d1252329d..151a32f67e 100644 --- a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs @@ -33,7 +33,7 @@ AIAgent agent = chatClient.AsAIAgent(new ChatClientAgentOptions() { ChatOptions = new() { Instructions = "You are a friendly assistant. Always address the user by their name." }, - AIContextProviderFactory = (ct) => new ValueTask(new UserInfoMemory(chatClient.AsIChatClient())) + AIContextProvider = new UserInfoMemory(chatClient.AsIChatClient()) }); // Create a new session for the conversation. diff --git a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs index 0eb2fb6436..a3dfae995f 100644 --- a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs @@ -62,12 +62,12 @@ .AsAIAgent(new ChatClientAgentOptions { ChatOptions = new() { Instructions = "You are a helpful support specialist for Contoso Outdoors. Answer questions using the provided context and cite the source document when available." }, - AIContextProviderFactory = (ct) => new ValueTask(new TextSearchProvider(SearchAdapter, textSearchOptions)), + AIContextProvider = new TextSearchProvider(SearchAdapter, textSearchOptions), // Since we are using ChatCompletion which stores chat history locally, we can also add a message removal policy // that removes messages produced by the TextSearchProvider before they are added to the chat history, so that // we don't bloat chat history with all the search result messages. - ChatHistoryProviderFactory = (ct) => new ValueTask(new InMemoryChatHistoryProvider() - .WithAIContextProviderMessageRemoval()), + ChatHistoryProvider = new InMemoryChatHistoryProvider() + .WithAIContextProviderMessageRemoval(), }); AgentSession session = await agent.CreateSessionAsync(); diff --git a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step02_CustomVectorStoreRAG/Program.cs b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step02_CustomVectorStoreRAG/Program.cs index 17e50999b9..4e135db78c 100644 --- a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step02_CustomVectorStoreRAG/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step02_CustomVectorStoreRAG/Program.cs @@ -71,7 +71,7 @@ .AsAIAgent(new ChatClientAgentOptions { ChatOptions = new() { Instructions = "You are a helpful support specialist for the Microsoft Agent Framework. Answer questions using the provided context and cite the source document when available. Keep responses brief." }, - AIContextProviderFactory = (ct) => new ValueTask(new TextSearchProvider(SearchAdapter, textSearchOptions)) + AIContextProvider = new TextSearchProvider(SearchAdapter, textSearchOptions) }); AgentSession session = await agent.CreateSessionAsync(); diff --git a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step03_CustomRAGDataSource/Program.cs b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step03_CustomRAGDataSource/Program.cs index f1c6256cd2..6db583ca41 100644 --- a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step03_CustomRAGDataSource/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step03_CustomRAGDataSource/Program.cs @@ -29,7 +29,7 @@ .AsAIAgent(new ChatClientAgentOptions { ChatOptions = new() { Instructions = "You are a helpful support specialist for Contoso Outdoors. Answer questions using the provided context and cite the source document when available." }, - AIContextProviderFactory = (ct) => new ValueTask(new TextSearchProvider(MockSearchAsync, textSearchOptions)) + AIContextProvider = new TextSearchProvider(MockSearchAsync, textSearchOptions) }); AgentSession session = await agent.CreateSessionAsync(); diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs index 418691144a..d58c2253ae 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs @@ -33,9 +33,8 @@ { ChatOptions = new() { Instructions = "You are good at telling jokes." }, Name = "Joker", - ChatHistoryProviderFactory = (ct) => new ValueTask( - // Create a new ChatHistoryProvider for this agent that stores chat history in a vector store. - new VectorChatHistoryProvider(vectorStore)) + // Create a new ChatHistoryProvider for this agent that stores chat history in a vector store. + ChatHistoryProvider = new VectorChatHistoryProvider(vectorStore) }); // Start a new session for the agent conversation. diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs index 83bc33ed25..f1949d4935 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs @@ -24,7 +24,7 @@ { ChatOptions = new() { Instructions = "You are good at telling jokes." }, Name = "Joker", - ChatHistoryProviderFactory = (ct) => new ValueTask(new InMemoryChatHistoryProvider(chatReducer: new MessageCountingChatReducer(2))) + ChatHistoryProvider = new InMemoryChatHistoryProvider(chatReducer: new MessageCountingChatReducer(2)) }); AgentSession session = await agent.CreateSessionAsync(); diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs index 7b86959a95..44167ca894 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs @@ -44,16 +44,16 @@ You are a helpful personal assistant. You manage a TODO list for the user. When the user has completed one of the tasks it can be removed from the TODO list. Only provide the list of TODO items if asked. You remind users of upcoming calendar events when the user interacts with you. """ }, - ChatHistoryProviderFactory = (ct) => new ValueTask(new InMemoryChatHistoryProvider() + ChatHistoryProvider = new InMemoryChatHistoryProvider() // Use WithAIContextProviderMessageRemoval, so that we don't store the messages from the AI context provider in the chat history. // You may want to store these messages, depending on their content and your requirements. - .WithAIContextProviderMessageRemoval()), + .WithAIContextProviderMessageRemoval(), // Add an AI context provider that maintains a todo list for the agent and one that provides upcoming calendar entries. // Wrap these in an AI context provider that aggregates the other two. - AIContextProviderFactory = (ct) => new ValueTask(new AggregatingAIContextProvider([ + AIContextProvider = new AggregatingAIContextProvider([ new TodoListAIContextProvider(), new CalendarSearchAIContextProvider(loadNextThreeCalendarEvents) - ])), + ]), }); // Invoke the agent and output the text result. diff --git a/dotnet/src/Microsoft.Agents.AI.AzureAI.Persistent/PersistentAgentsClientExtensions.cs b/dotnet/src/Microsoft.Agents.AI.AzureAI.Persistent/PersistentAgentsClientExtensions.cs index 2058e6760b..c9c898a18a 100644 --- a/dotnet/src/Microsoft.Agents.AI.AzureAI.Persistent/PersistentAgentsClientExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.AzureAI.Persistent/PersistentAgentsClientExtensions.cs @@ -191,8 +191,8 @@ public static ChatClientAgent AsAIAgent( Name = options.Name ?? persistentAgentMetadata.Name, Description = options.Description ?? persistentAgentMetadata.Description, ChatOptions = options.ChatOptions, - AIContextProviderFactory = options.AIContextProviderFactory, - ChatHistoryProviderFactory = options.ChatHistoryProviderFactory, + AIContextProvider = options.AIContextProvider, + ChatHistoryProvider = options.ChatHistoryProvider, UseProvidedChatClientAsIs = options.UseProvidedChatClientAsIs }; diff --git a/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClientExtensions.cs b/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClientExtensions.cs index 18a5ba3b72..82ca49705e 100644 --- a/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClientExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClientExtensions.cs @@ -589,8 +589,8 @@ private static ChatClientAgentOptions CreateChatClientAgentOptions(AgentVersion var agentOptions = CreateChatClientAgentOptions(agentVersion, options?.ChatOptions, requireInvocableTools); if (options is not null) { - agentOptions.AIContextProviderFactory = options.AIContextProviderFactory; - agentOptions.ChatHistoryProviderFactory = options.ChatHistoryProviderFactory; + agentOptions.AIContextProvider = options.AIContextProvider; + agentOptions.ChatHistoryProvider = options.ChatHistoryProvider; agentOptions.UseProvidedChatClientAsIs = options.UseProvidedChatClientAsIs; } diff --git a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosDBChatExtensions.cs b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosDBChatExtensions.cs index 4ee2c5b5bb..76b865e4c8 100644 --- a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosDBChatExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosDBChatExtensions.cs @@ -2,7 +2,6 @@ using System; using System.Diagnostics.CodeAnalysis; -using System.Threading.Tasks; using Azure.Core; using Microsoft.Azure.Cosmos; @@ -41,8 +40,8 @@ public static ChatClientAgentOptions WithCosmosDBChatHistoryProvider( throw new ArgumentNullException(nameof(options)); } - options.ChatHistoryProviderFactory = (ct) => new ValueTask( - new CosmosChatHistoryProvider(connectionString, databaseId, containerId, stateInitializer ?? s_defaultStateInitializer)); + options.ChatHistoryProvider = + new CosmosChatHistoryProvider(connectionString, databaseId, containerId, stateInitializer ?? s_defaultStateInitializer); return options; } @@ -78,8 +77,8 @@ public static ChatClientAgentOptions WithCosmosDBChatHistoryProviderUsingManaged throw new ArgumentNullException(nameof(tokenCredential)); } - options.ChatHistoryProviderFactory = (ct) => new ValueTask( - new CosmosChatHistoryProvider(accountEndpoint, tokenCredential, databaseId, containerId, stateInitializer ?? s_defaultStateInitializer)); + options.ChatHistoryProvider = + new CosmosChatHistoryProvider(accountEndpoint, tokenCredential, databaseId, containerId, stateInitializer ?? s_defaultStateInitializer); return options; } @@ -108,8 +107,8 @@ public static ChatClientAgentOptions WithCosmosDBChatHistoryProvider( throw new ArgumentNullException(nameof(options)); } - options.ChatHistoryProviderFactory = (ct) => new ValueTask( - new CosmosChatHistoryProvider(cosmosClient, databaseId, containerId, stateInitializer ?? s_defaultStateInitializer)); + options.ChatHistoryProvider = + new CosmosChatHistoryProvider(cosmosClient, databaseId, containerId, stateInitializer ?? s_defaultStateInitializer); return options; } } diff --git a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIAssistantClientExtensions.cs b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIAssistantClientExtensions.cs index d167d1f0b4..3f25f71d83 100644 --- a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIAssistantClientExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIAssistantClientExtensions.cs @@ -204,8 +204,8 @@ public static ChatClientAgent AsAIAgent( Name = options.Name ?? assistantMetadata.Name, Description = options.Description ?? assistantMetadata.Description, ChatOptions = options.ChatOptions, - AIContextProviderFactory = options.AIContextProviderFactory, - ChatHistoryProviderFactory = options.ChatHistoryProviderFactory, + AIContextProvider = options.AIContextProvider, + ChatHistoryProvider = options.ChatHistoryProvider, UseProvidedChatClientAsIs = options.UseProvidedChatClientAsIs }; diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index 1048ba3998..db2ba201c5 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -302,21 +302,13 @@ protected override async IAsyncEnumerable RunCoreStreamingA : this.ChatClient.GetService(serviceType, serviceKey)); /// - public override async ValueTask CreateSessionAsync(CancellationToken cancellationToken = default) + public override ValueTask CreateSessionAsync(CancellationToken cancellationToken = default) { - ChatHistoryProvider? chatHistoryProvider = this._agentOptions?.ChatHistoryProviderFactory is not null - ? await this._agentOptions.ChatHistoryProviderFactory.Invoke(cancellationToken).ConfigureAwait(false) - : null; - - AIContextProvider? contextProvider = this._agentOptions?.AIContextProviderFactory is not null - ? await this._agentOptions.AIContextProviderFactory.Invoke(cancellationToken).ConfigureAwait(false) - : null; - - return new ChatClientAgentSession + return new(new ChatClientAgentSession { - ChatHistoryProvider = chatHistoryProvider, - AIContextProvider = contextProvider - }; + ChatHistoryProvider = this._agentOptions?.ChatHistoryProvider, + AIContextProvider = this._agentOptions?.AIContextProvider + }); } /// @@ -337,17 +329,13 @@ public override async ValueTask CreateSessionAsync(CancellationTok /// instances that support server-side conversation storage through their underlying . /// /// - public async ValueTask CreateSessionAsync(string conversationId, CancellationToken cancellationToken = default) + public ValueTask CreateSessionAsync(string conversationId, CancellationToken cancellationToken = default) { - AIContextProvider? contextProvider = this._agentOptions?.AIContextProviderFactory is not null - ? await this._agentOptions.AIContextProviderFactory.Invoke(cancellationToken).ConfigureAwait(false) - : null; - - return new ChatClientAgentSession() + return new(new ChatClientAgentSession() { ConversationId = conversationId, - AIContextProvider = contextProvider - }; + AIContextProvider = this._agentOptions?.AIContextProvider + }); } /// @@ -372,17 +360,13 @@ public async ValueTask CreateSessionAsync(string conversationId, C /// the session will throw an exception to indicate that it cannot continue using the provided . /// /// - public async ValueTask CreateSessionAsync(ChatHistoryProvider chatHistoryProvider, CancellationToken cancellationToken = default) + public ValueTask CreateSessionAsync(ChatHistoryProvider chatHistoryProvider, CancellationToken cancellationToken = default) { - AIContextProvider? contextProvider = this._agentOptions?.AIContextProviderFactory is not null - ? await this._agentOptions.AIContextProviderFactory.Invoke(cancellationToken).ConfigureAwait(false) - : null; - - return new ChatClientAgentSession() + return new(new ChatClientAgentSession() { ChatHistoryProvider = Throw.IfNull(chatHistoryProvider), - AIContextProvider = contextProvider - }; + AIContextProvider = this._agentOptions?.AIContextProvider + }); } /// @@ -399,21 +383,12 @@ public override JsonElement SerializeSession(AgentSession session, JsonSerialize } /// - public override async ValueTask DeserializeSessionAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + public override ValueTask DeserializeSessionAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) { - Func>? chatHistoryProviderFactory = this._agentOptions?.ChatHistoryProviderFactory is null ? - null : - (ct) => this._agentOptions.ChatHistoryProviderFactory.Invoke(ct); - - Func>? aiContextProviderFactory = this._agentOptions?.AIContextProviderFactory is null ? - null : - (ct) => this._agentOptions.AIContextProviderFactory.Invoke(ct); - - return await ChatClientAgentSession.DeserializeAsync( + return new(ChatClientAgentSession.Deserialize( serializedState, - chatHistoryProviderFactory, - aiContextProviderFactory, - cancellationToken).ConfigureAwait(false); + this._agentOptions?.ChatHistoryProvider, + this._agentOptions?.AIContextProvider)); } #region Private @@ -805,9 +780,8 @@ private async Task UpdateSessionWithTypeAndConversationIdAsync(ChatClientAgentSe // If the service doesn't use service side chat history storage (i.e. we got no id back from invocation), and // the session has no ChatHistoryProvider yet, we should update the session with the custom ChatHistoryProvider or // default InMemoryChatHistoryProvider so that it has somewhere to store the chat history. - session.ChatHistoryProvider ??= this._agentOptions?.ChatHistoryProviderFactory is not null - ? await this._agentOptions.ChatHistoryProviderFactory.Invoke(cancellationToken).ConfigureAwait(false) - : new InMemoryChatHistoryProvider(); + session.ChatHistoryProvider ??= this._agentOptions?.ChatHistoryProvider + ?? new InMemoryChatHistoryProvider(); } } diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs index 84d64d67e1..66f4f797c5 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs @@ -1,8 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using System; -using System.Threading; -using System.Threading.Tasks; using Microsoft.Extensions.AI; namespace Microsoft.Agents.AI; @@ -38,17 +35,14 @@ public sealed class ChatClientAgentOptions public ChatOptions? ChatOptions { get; set; } /// - /// Gets or sets a factory function to create an instance of - /// which will be used to provide chat history for this agent. + /// Gets or sets the instance to use for providing chat history for this agent. /// - public Func>? ChatHistoryProviderFactory { get; set; } + public ChatHistoryProvider? ChatHistoryProvider { get; set; } /// - /// Gets or sets a factory function to create an instance of - /// which will be used to create a context provider for each new thread, and can then - /// provide additional context for each agent run. + /// Gets or sets the instance to use for providing additional context for each agent run. /// - public Func>? AIContextProviderFactory { get; set; } + public AIContextProvider? AIContextProvider { get; set; } /// /// Gets or sets a value indicating whether to use the provided instance as is, @@ -74,7 +68,7 @@ public ChatClientAgentOptions Clone() Name = this.Name, Description = this.Description, ChatOptions = this.ChatOptions?.Clone(), - ChatHistoryProviderFactory = this.ChatHistoryProviderFactory, - AIContextProviderFactory = this.AIContextProviderFactory, + ChatHistoryProvider = this.ChatHistoryProvider, + AIContextProvider = this.AIContextProvider, }; } diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs index 935cf89323..4e9b6bced2 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs @@ -3,8 +3,6 @@ using System; using System.Diagnostics; using System.Text.Json; -using System.Threading; -using System.Threading.Tasks; using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.AI; @@ -116,21 +114,19 @@ internal set /// Creates a new instance of the class from previously serialized state. /// /// A representing the serialized state of the session. - /// - /// An optional factory function to create a custom . + /// + /// An optional instance. /// If not provided, the default will be used. /// - /// - /// An optional factory function to create a custom . + /// + /// An optional instance. /// If not provided, no context provider will be configured. /// - /// The to monitor for cancellation requests. - /// A task representing the asynchronous operation. The task result contains the deserialized . - internal static async Task DeserializeAsync( + /// The deserialized . + internal static ChatClientAgentSession Deserialize( JsonElement serializedState, - Func>? chatHistoryProviderFactory = null, - Func>? aiContextProviderFactory = null, - CancellationToken cancellationToken = default) + ChatHistoryProvider? chatHistoryProvider = null, + AIContextProvider? aiContextProvider = null) { if (serializedState.ValueKind != JsonValueKind.Object) { @@ -142,9 +138,7 @@ internal static async Task DeserializeAsync( var session = new ChatClientAgentSession(); - session.AIContextProvider = aiContextProviderFactory is not null - ? await aiContextProviderFactory.Invoke(cancellationToken).ConfigureAwait(false) - : null; + session.AIContextProvider = aiContextProvider; session.StateBag = AgentSessionStateBag.Deserialize(state?.StateBag ?? default); @@ -157,9 +151,8 @@ internal static async Task DeserializeAsync( } session._chatHistoryProvider = - chatHistoryProviderFactory is not null - ? await chatHistoryProviderFactory.Invoke(cancellationToken).ConfigureAwait(false) - : new InMemoryChatHistoryProvider(); // default to an in-memory ChatHistoryProvider + chatHistoryProvider + ?? new InMemoryChatHistoryProvider(); // default to an in-memory ChatHistoryProvider return session; } diff --git a/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs index 36222afe32..8b8a6f6dc6 100644 --- a/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs @@ -2310,23 +2310,18 @@ public async Task GetAIAgentAsync_WithMatchingToolsProvided_CreatesAgentSuccessf #region CreateChatClientAgentOptions - Options Preservation Tests /// - /// Verify that CreateChatClientAgentOptions preserves AIContextProviderFactory. + /// Verify that CreateChatClientAgentOptions preserves AIContextProvider. /// [Fact] - public async Task GetAIAgentAsync_WithAIContextProviderFactory_PreservesFactoryAsync() + public async Task GetAIAgentAsync_WithAIContextProvider_PreservesProviderAsync() { // Arrange AIProjectClient client = this.CreateTestAgentClient(); - bool factoryInvoked = false; var options = new ChatClientAgentOptions { Name = "test-agent", ChatOptions = new ChatOptions { Instructions = "Test" }, - AIContextProviderFactory = (_) => - { - factoryInvoked = true; - return new ValueTask(new TestAIContextProvider()); - } + AIContextProvider = new TestAIContextProvider() }; // Act @@ -2334,15 +2329,13 @@ public async Task GetAIAgentAsync_WithAIContextProviderFactory_PreservesFactoryA // Assert Assert.NotNull(agent); - // Verify the factory was captured (though not necessarily invoked yet) - Assert.False(factoryInvoked); // Factory is not invoked during creation } /// - /// Verify that CreateChatClientAgentOptions preserves ChatHistoryProviderFactory. + /// Verify that CreateChatClientAgentOptions preserves ChatHistoryProvider. /// [Fact] - public async Task GetAIAgentAsync_WithChatHistoryProviderFactory_PreservesFactoryAsync() + public async Task GetAIAgentAsync_WithChatHistoryProvider_PreservesProviderAsync() { // Arrange AIProjectClient client = this.CreateTestAgentClient(); @@ -2350,7 +2343,7 @@ public async Task GetAIAgentAsync_WithChatHistoryProviderFactory_PreservesFactor { Name = "test-agent", ChatOptions = new ChatOptions { Instructions = "Test" }, - ChatHistoryProviderFactory = (_) => new ValueTask(new TestChatHistoryProvider()) + ChatHistoryProvider = new TestChatHistoryProvider() }; // Act diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentOptionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentOptionsTests.cs index 2f89c55d62..e9ad2b9a6b 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentOptionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentOptionsTests.cs @@ -1,8 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; using Microsoft.Extensions.AI; using Moq; @@ -23,8 +21,8 @@ public void DefaultConstructor_InitializesWithNullValues() Assert.Null(options.Name); Assert.Null(options.Description); Assert.Null(options.ChatOptions); - Assert.Null(options.ChatHistoryProviderFactory); - Assert.Null(options.AIContextProviderFactory); + Assert.Null(options.ChatHistoryProvider); + Assert.Null(options.AIContextProvider); } [Fact] @@ -36,8 +34,8 @@ public void Constructor_WithNullValues_SetsPropertiesCorrectly() // Assert Assert.Null(options.Name); Assert.Null(options.Description); - Assert.Null(options.AIContextProviderFactory); - Assert.Null(options.ChatHistoryProviderFactory); + Assert.Null(options.AIContextProvider); + Assert.Null(options.ChatHistoryProvider); Assert.NotNull(options.ChatOptions); Assert.Null(options.ChatOptions.Instructions); Assert.Null(options.ChatOptions.Tools); @@ -117,11 +115,8 @@ public void Clone_CreatesDeepCopyWithSameValues() const string Description = "Test description"; var tools = new List { AIFunctionFactory.Create(() => "test") }; - static ValueTask ChatHistoryProviderFactoryAsync( - CancellationToken ct) => new(new Mock().Object); - - static ValueTask AIContextProviderFactoryAsync( - CancellationToken ct) => new(new Mock().Object); + var mockChatHistoryProvider = new Mock().Object; + var mockAIContextProvider = new Mock().Object; var original = new ChatClientAgentOptions() { @@ -129,8 +124,8 @@ static ValueTask AIContextProviderFactoryAsync( Description = Description, ChatOptions = new() { Tools = tools }, Id = "test-id", - ChatHistoryProviderFactory = ChatHistoryProviderFactoryAsync, - AIContextProviderFactory = AIContextProviderFactoryAsync + ChatHistoryProvider = mockChatHistoryProvider, + AIContextProvider = mockAIContextProvider }; // Act @@ -141,8 +136,8 @@ static ValueTask AIContextProviderFactoryAsync( Assert.Equal(original.Id, clone.Id); Assert.Equal(original.Name, clone.Name); Assert.Equal(original.Description, clone.Description); - Assert.Same(original.ChatHistoryProviderFactory, clone.ChatHistoryProviderFactory); - Assert.Same(original.AIContextProviderFactory, clone.AIContextProviderFactory); + Assert.Same(original.ChatHistoryProvider, clone.ChatHistoryProvider); + Assert.Same(original.AIContextProvider, clone.AIContextProvider); // ChatOptions should be cloned, not the same reference Assert.NotSame(original.ChatOptions, clone.ChatOptions); @@ -154,11 +149,16 @@ static ValueTask AIContextProviderFactoryAsync( public void Clone_WithoutProvidingChatOptions_ClonesCorrectly() { // Arrange + var mockChatHistoryProvider = new Mock().Object; + var mockAIContextProvider = new Mock().Object; + var original = new ChatClientAgentOptions { Id = "test-id", Name = "Test name", - Description = "Test description" + Description = "Test description", + ChatHistoryProvider = mockChatHistoryProvider, + AIContextProvider = mockAIContextProvider }; // Act @@ -170,8 +170,8 @@ public void Clone_WithoutProvidingChatOptions_ClonesCorrectly() Assert.Equal(original.Name, clone.Name); Assert.Equal(original.Description, clone.Description); Assert.Null(original.ChatOptions); - Assert.Null(clone.ChatHistoryProviderFactory); - Assert.Null(clone.AIContextProviderFactory); + Assert.Same(original.ChatHistoryProvider, clone.ChatHistoryProvider); + Assert.Same(original.AIContextProvider, clone.AIContextProvider); } private static void AssertSameTools(IList? expected, IList? actual) diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs index 81b213e488..93440fcad9 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs @@ -3,7 +3,6 @@ using System; using System.Linq; using System.Text.Json; -using System.Threading.Tasks; using Microsoft.Extensions.AI; using Moq; @@ -92,7 +91,7 @@ public void SetChatHistoryProviderThrowsWhenConversationIdIsSet() #region Deserialize Tests [Fact] - public async Task VerifyDeserializeWithMessagesAsync() + public void VerifyDeserializeWithMessages() { // Arrange var json = JsonSerializer.Deserialize(""" @@ -106,7 +105,7 @@ public async Task VerifyDeserializeWithMessagesAsync() """, TestJsonSerializerContext.Default.JsonElement); // Act. - var session = await ChatClientAgentSession.DeserializeAsync(json); + var session = ChatClientAgentSession.Deserialize(json); // Assert Assert.Null(session.ConversationId); @@ -119,7 +118,7 @@ public async Task VerifyDeserializeWithMessagesAsync() } [Fact] - public async Task VerifyDeserializeWithIdAsync() + public void VerifyDeserializeWithId() { // Arrange var json = JsonSerializer.Deserialize(""" @@ -129,7 +128,7 @@ public async Task VerifyDeserializeWithIdAsync() """, TestJsonSerializerContext.Default.JsonElement); // Act - var session = await ChatClientAgentSession.DeserializeAsync(json); + var session = ChatClientAgentSession.Deserialize(json); // Assert Assert.Equal("TestConvId", session.ConversationId); @@ -137,7 +136,7 @@ public async Task VerifyDeserializeWithIdAsync() } [Fact] - public async Task VerifyDeserializeWithAIContextProviderAsync() + public void VerifyDeserializeWithAIContextProvider() { // Arrange var json = JsonSerializer.Deserialize(""" @@ -149,7 +148,7 @@ public async Task VerifyDeserializeWithAIContextProviderAsync() Mock mockProvider = new(); // Act - var session = await ChatClientAgentSession.DeserializeAsync(json, aiContextProviderFactory: (_) => new(mockProvider.Object)); + var session = ChatClientAgentSession.Deserialize(json, aiContextProvider: mockProvider.Object); // Assert Assert.Null(session.ChatHistoryProvider); @@ -157,7 +156,7 @@ public async Task VerifyDeserializeWithAIContextProviderAsync() } [Fact] - public async Task VerifyDeserializeWithStateBagAsync() + public void VerifyDeserializeWithStateBag() { // Arrange var json = JsonSerializer.Deserialize(""" @@ -175,7 +174,7 @@ public async Task VerifyDeserializeWithStateBagAsync() Mock mockProvider = new(); // Act - var session = await ChatClientAgentSession.DeserializeAsync(json, aiContextProviderFactory: (_) => new(mockProvider.Object)); + var session = ChatClientAgentSession.Deserialize(json, aiContextProvider: mockProvider.Object); // Assert var dog = session.StateBag.GetValue("dog", TestJsonSerializerContext.Default.Options); @@ -184,14 +183,13 @@ public async Task VerifyDeserializeWithStateBagAsync() } [Fact] - public async Task DeserializeWithInvalidJsonThrowsAsync() + public void DeserializeWithInvalidJsonThrows() { // Arrange var invalidJson = JsonSerializer.Deserialize("[42]", TestJsonSerializerContext.Default.JsonElement); - var session = new ChatClientAgentSession(); // Act & Assert - await Assert.ThrowsAsync(() => ChatClientAgentSession.DeserializeAsync(invalidJson)); + Assert.Throws(() => ChatClientAgentSession.Deserialize(invalidJson)); } #endregion Deserialize Tests @@ -308,7 +306,7 @@ public void VerifySessionSerializationWithCustomOptions() #region StateBag Roundtrip Tests [Fact] - public async Task VerifyStateBagRoundtripsAsync() + public void VerifyStateBagRoundtrips() { // Arrange var session = new ChatClientAgentSession(); @@ -316,7 +314,7 @@ public async Task VerifyStateBagRoundtripsAsync() // Act var serializedSession = session.Serialize(); - var deserializedSession = await ChatClientAgentSession.DeserializeAsync(serializedSession); + var deserializedSession = ChatClientAgentSession.Deserialize(serializedSession); // Assert var dog = deserializedSession.StateBag.GetValue("dog", TestJsonSerializerContext.Default.Options); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs index e8c34e13a3..bda334020e 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs @@ -353,7 +353,7 @@ public async Task RunAsyncInvokesAIContextProviderAndUsesResultAsync() .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny())) .Returns(new ValueTask()); - ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = (_) => new(mockProvider.Object), ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); + ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProvider = mockProvider.Object, ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); // Act var session = await agent.CreateSessionAsync() as ChatClientAgentSession; @@ -416,7 +416,7 @@ public async Task RunAsyncInvokesAIContextProviderWhenGetResponseFailsAsync() .Setup(p => p.InvokedAsync(It.IsAny(), It.IsAny())) .Returns(new ValueTask()); - ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = (_) => new(mockProvider.Object), ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); + ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProvider = mockProvider.Object, ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); // Act await Assert.ThrowsAsync(() => agent.RunAsync(requestMessages)); @@ -462,7 +462,7 @@ public async Task RunAsyncInvokesAIContextProviderAndSucceedsWithEmptyAIContextA .Setup(p => p.InvokingAsync(It.IsAny(), It.IsAny())) .ReturnsAsync(new AIContext()); - ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = (_) => new(mockProvider.Object), ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); + ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProvider = mockProvider.Object, ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); // Act await agent.RunAsync([new(ChatRole.User, "user message")]); @@ -1288,12 +1288,10 @@ public async Task RunStreamingAsyncUsesChatHistoryProviderWhenNoConversationIdRe It.IsAny>(), It.IsAny(), It.IsAny())).Returns(ToAsyncEnumerableAsync(returnUpdates)); - Mock>> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny())).ReturnsAsync(new InMemoryChatHistoryProvider()); ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, - ChatHistoryProviderFactory = mockFactory.Object + ChatHistoryProvider = new InMemoryChatHistoryProvider() }); // Act @@ -1306,14 +1304,13 @@ public async Task RunStreamingAsyncUsesChatHistoryProviderWhenNoConversationIdRe Assert.Equal(2, historyMessages.Count); Assert.Equal("test", historyMessages[0].Text); Assert.Equal("what?", historyMessages[1].Text); - mockFactory.Verify(f => f(It.IsAny()), Times.Once); } /// - /// Verify that RunStreamingAsync throws when a factory is provided and the chat client returns a conversation id. + /// Verify that RunStreamingAsync throws when a is provided and the chat client returns a conversation id. /// [Fact] - public async Task RunStreamingAsyncThrowsWhenChatHistoryProviderFactoryProvidedAndConversationIdReturnedByChatClientAsync() + public async Task RunStreamingAsyncThrowsWhenChatHistoryProviderProvidedAndConversationIdReturnedByChatClientAsync() { // Arrange Mock mockService = new(); @@ -1327,12 +1324,10 @@ public async Task RunStreamingAsyncThrowsWhenChatHistoryProviderFactoryProvidedA It.IsAny>(), It.IsAny(), It.IsAny())).Returns(ToAsyncEnumerableAsync(returnUpdates)); - Mock>> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny())).ReturnsAsync(new InMemoryChatHistoryProvider()); ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, - ChatHistoryProviderFactory = mockFactory.Object + ChatHistoryProvider = new InMemoryChatHistoryProvider() }); // Act & Assert @@ -1389,7 +1384,7 @@ public async Task RunStreamingAsyncInvokesAIContextProviderAndUsesResultAsync() options: new() { ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] }, - AIContextProviderFactory = (_) => new(mockProvider.Object) + AIContextProvider = mockProvider.Object }); // Act @@ -1459,7 +1454,7 @@ public async Task RunStreamingAsyncInvokesAIContextProviderWhenGetResponseFailsA options: new() { ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] }, - AIContextProviderFactory = (_) => new(mockProvider.Object) + AIContextProvider = mockProvider.Object }); // Act diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs index 7e0538ec17..db1c0101f2 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs @@ -170,10 +170,10 @@ public async Task RunAsync_UsesDefaultInMemoryChatHistoryProvider_WhenNoConversa } /// - /// Verify that RunAsync uses the ChatHistoryProvider factory when the chat client returns no conversation id. + /// Verify that RunAsync uses the ChatHistoryProvider when the chat client returns no conversation id. /// [Fact] - public async Task RunAsync_UsesChatHistoryProviderFactory_WhenProvidedAndNoConversationIdReturnedByChatClientAsync() + public async Task RunAsync_UsesChatHistoryProvider_WhenProvidedAndNoConversationIdReturnedByChatClientAsync() { // Arrange Mock mockService = new(); @@ -191,13 +191,10 @@ public async Task RunAsync_UsesChatHistoryProviderFactory_WhenProvidedAndNoConve It.IsAny(), It.IsAny())).Returns(new ValueTask()); - Mock>> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny())).ReturnsAsync(mockChatHistoryProvider.Object); - ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, - ChatHistoryProviderFactory = mockFactory.Object + ChatHistoryProvider = mockChatHistoryProvider.Object }); // Act @@ -220,7 +217,6 @@ public async Task RunAsync_UsesChatHistoryProviderFactory_WhenProvidedAndNoConve It.Is(x => x.RequestMessages.Count() == 1 && x.ChatHistoryProviderMessages != null && x.ChatHistoryProviderMessages.Count() == 1 && x.ResponseMessages!.Count() == 1), It.IsAny()), Times.Once); - mockFactory.Verify(f => f(It.IsAny()), Times.Once); } /// @@ -239,13 +235,10 @@ public async Task RunAsync_NotifiesChatHistoryProvider_OnFailureAsync() Mock mockChatHistoryProvider = new(); - Mock>> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny())).ReturnsAsync(mockChatHistoryProvider.Object); - ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, - ChatHistoryProviderFactory = mockFactory.Object + ChatHistoryProvider = mockChatHistoryProvider.Object }); // Act @@ -258,14 +251,13 @@ public async Task RunAsync_NotifiesChatHistoryProvider_OnFailureAsync() It.Is(x => x.RequestMessages.Count() == 1 && x.ResponseMessages == null && x.InvokeException!.Message == "Test Error"), It.IsAny()), Times.Once); - mockFactory.Verify(f => f(It.IsAny()), Times.Once); } /// - /// Verify that RunAsync throws when a ChatHistoryProvider Factory is provided and the chat client returns a conversation id. + /// Verify that RunAsync throws when a ChatHistoryProvider is provided and the chat client returns a conversation id. /// [Fact] - public async Task RunAsync_Throws_WhenChatHistoryProviderFactoryProvidedAndConversationIdReturnedByChatClientAsync() + public async Task RunAsync_Throws_WhenChatHistoryProviderProvidedAndConversationIdReturnedByChatClientAsync() { // Arrange Mock mockService = new(); @@ -274,12 +266,10 @@ public async Task RunAsync_Throws_WhenChatHistoryProviderFactoryProvidedAndConve It.IsAny>(), It.IsAny(), It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")]) { ConversationId = "ConvId" }); - Mock>> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny())).ReturnsAsync(new InMemoryChatHistoryProvider()); ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, - ChatHistoryProviderFactory = mockFactory.Object + ChatHistoryProvider = new InMemoryChatHistoryProvider() }); // Act & Assert @@ -316,7 +306,7 @@ public async Task RunAsync_UsesOverrideChatHistoryProvider_WhenProvidedViaAdditi It.IsAny(), It.IsAny())).Returns(new ValueTask()); - // Arrange a chat history provider to provide to the agent via a factory at construction time. + // Arrange a chat history provider to provide to the agent at construction time. // This one shouldn't be used since it is being overridden. Mock mockFactoryChatHistoryProvider = new(); mockFactoryChatHistoryProvider.Setup(s => s.InvokingAsync( @@ -326,13 +316,10 @@ public async Task RunAsync_UsesOverrideChatHistoryProvider_WhenProvidedViaAdditi It.IsAny(), It.IsAny())).Throws(FailException.ForFailure("Base ChatHistoryProvider shouldn't be used.")); - Mock>> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny())).ReturnsAsync(mockFactoryChatHistoryProvider.Object); - ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, - ChatHistoryProviderFactory = mockFactory.Object + ChatHistoryProvider = mockFactoryChatHistoryProvider.Object }); // Act diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_CreateSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_CreateSessionTests.cs index 86d1da2a2c..f108b0966d 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_CreateSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_CreateSessionTests.cs @@ -12,54 +12,42 @@ namespace Microsoft.Agents.AI.UnitTests; public class ChatClientAgent_CreateSessionTests { [Fact] - public async Task CreateSession_UsesAIContextProviderFactory_IfProvidedAsync() + public async Task CreateSession_UsesAIContextProvider_IfProvidedAsync() { // Arrange var mockChatClient = new Mock(); var mockContextProvider = new Mock(); - var factoryCalled = false; var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions { ChatOptions = new() { Instructions = "Test instructions" }, - AIContextProviderFactory = (_) => - { - factoryCalled = true; - return new ValueTask(mockContextProvider.Object); - } + AIContextProvider = mockContextProvider.Object }); // Act var session = await agent.CreateSessionAsync(); // Assert - Assert.True(factoryCalled, "AIContextProviderFactory was not called."); Assert.IsType(session); var typedSession = (ChatClientAgentSession)session; Assert.Same(mockContextProvider.Object, typedSession.AIContextProvider); } [Fact] - public async Task CreateSession_UsesChatHistoryProviderFactory_IfProvidedAsync() + public async Task CreateSession_UsesChatHistoryProvider_IfProvidedAsync() { // Arrange var mockChatClient = new Mock(); var mockChatHistoryProvider = new Mock(); - var factoryCalled = false; var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions { ChatOptions = new() { Instructions = "Test instructions" }, - ChatHistoryProviderFactory = (_) => - { - factoryCalled = true; - return new ValueTask(mockChatHistoryProvider.Object); - } + ChatHistoryProvider = mockChatHistoryProvider.Object }); // Act var session = await agent.CreateSessionAsync(); // Assert - Assert.True(factoryCalled, "ChatHistoryProviderFactory was not called."); Assert.IsType(session); var typedSession = (ChatClientAgentSession)session; Assert.Same(mockChatHistoryProvider.Object, typedSession.ChatHistoryProvider); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeSessionTests.cs index 12ae51ad4f..d4f2ab9628 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeSessionTests.cs @@ -13,20 +13,15 @@ namespace Microsoft.Agents.AI.UnitTests; public class ChatClientAgent_DeserializeSessionTests { [Fact] - public async Task DeserializeSession_UsesAIContextProviderFactory_IfProvidedAsync() + public async Task DeserializeSession_UsesAIContextProvider_IfProvidedAsync() { // Arrange var mockChatClient = new Mock(); var mockContextProvider = new Mock(); - var factoryCalled = false; var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions { ChatOptions = new() { Instructions = "Test instructions" }, - AIContextProviderFactory = (_) => - { - factoryCalled = true; - return new ValueTask(mockContextProvider.Object); - } + AIContextProvider = mockContextProvider.Object }); var json = JsonSerializer.Deserialize(""" @@ -39,27 +34,21 @@ public async Task DeserializeSession_UsesAIContextProviderFactory_IfProvidedAsyn var session = await agent.DeserializeSessionAsync(json); // Assert - Assert.True(factoryCalled, "AIContextProviderFactory was not called."); Assert.IsType(session); var typedSession = (ChatClientAgentSession)session; Assert.Same(mockContextProvider.Object, typedSession.AIContextProvider); } [Fact] - public async Task DeserializeSession_UsesChatHistoryProviderFactory_IfProvidedAsync() + public async Task DeserializeSession_UsesChatHistoryProvider_IfProvidedAsync() { // Arrange var mockChatClient = new Mock(); var mockChatHistoryProvider = new Mock(); - var factoryCalled = false; var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions { ChatOptions = new() { Instructions = "Test instructions" }, - ChatHistoryProviderFactory = (_) => - { - factoryCalled = true; - return new ValueTask(mockChatHistoryProvider.Object); - } + ChatHistoryProvider = mockChatHistoryProvider.Object }); var json = JsonSerializer.Deserialize(""" @@ -72,7 +61,6 @@ public async Task DeserializeSession_UsesChatHistoryProviderFactory_IfProvidedAs var session = await agent.DeserializeSessionAsync(json); // Assert - Assert.True(factoryCalled, "ChatHistoryProviderFactory was not called."); Assert.IsType(session); var typedSession = (ChatClientAgentSession)session; Assert.Same(mockChatHistoryProvider.Object, typedSession.ChatHistoryProvider); From 8fae8b07bd6d0b753c5aac15b816fec994e0e2b8 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 6 Feb 2026 19:26:16 +0000 Subject: [PATCH 08/28] Remove Providers from Session and flatten state bag serialization --- .../AgentAbstractionsJsonUtilities.cs | 1 + .../AgentSessionStateBag.cs | 7 + .../AgentSessionStateBagJsonConverter.cs | 28 ++++ .../AgentSessionStateBagValue.cs | 5 +- .../AgentSessionStateBagValueJsonConverter.cs | 27 ++++ .../ChatClient/ChatClientAgent.cs | 103 +++++-------- .../ChatClient/ChatClientAgentSession.cs | 105 ++------------ .../AnthropicChatCompletionFixture.cs | 6 +- .../AIProjectClientFixture.cs | 6 +- .../AgentSessionStateBagTests.cs | 91 +++++++++++- .../InMemoryAgentSessionTests.cs | 4 +- .../ChatClient/ChatClientAgentSessionTests.cs | 137 ++---------------- .../ChatClient/ChatClientAgentTests.cs | 10 +- ...hatClientAgent_BackgroundResponsesTests.cs | 36 ++--- ...tClientAgent_ChatHistoryManagementTests.cs | 25 ++-- .../ChatClientAgent_CreateSessionTests.cs | 59 -------- ...ChatClientAgent_DeserializeSessionTests.cs | 68 --------- .../Data/TextSearchProviderTests.cs | 3 +- .../OpenAIChatCompletionFixture.cs | 6 +- .../OpenAIResponseFixture.cs | 6 +- 20 files changed, 260 insertions(+), 473 deletions(-) create mode 100644 dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagJsonConverter.cs create mode 100644 dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValueJsonConverter.cs delete mode 100644 dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeSessionTests.cs diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs index bf0e835b4b..df5375d0b6 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs @@ -84,6 +84,7 @@ private static JsonSerializerOptions CreateDefaultOptions() [JsonSerializable(typeof(ServiceIdAgentSession.ServiceIdAgentSessionState))] [JsonSerializable(typeof(InMemoryAgentSession.InMemoryAgentSessionState))] [JsonSerializable(typeof(InMemoryChatHistoryProvider.State))] + [JsonSerializable(typeof(AgentSessionStateBag))] [JsonSerializable(typeof(ConcurrentDictionary))] [ExcludeFromCodeCoverage] diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs index 47eef61508..1bc55d59ae 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Concurrent; using System.Text.Json; +using System.Text.Json.Serialization; using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.AI; @@ -16,6 +17,7 @@ namespace Microsoft.Agents.AI; /// Values can be accessed in a type-safe manner and are serialized or deserialized using configurable JSON serializer /// options. This class is designed for concurrent access and is safe to use across multiple threads. /// +[JsonConverter(typeof(AgentSessionStateBagJsonConverter))] public class AgentSessionStateBag { private readonly ConcurrentDictionary _state; @@ -37,6 +39,11 @@ internal AgentSessionStateBag(ConcurrentDictionary(); } + /// + /// Gets the number of key-value pairs contained in the session state. + /// + public int Count => this._state.Count; + /// /// Tries to get a value from the session state. /// diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagJsonConverter.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagJsonConverter.cs new file mode 100644 index 0000000000..999f3c627b --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagJsonConverter.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Microsoft.Agents.AI; + +/// +/// Custom JSON converter for that serializes and deserializes +/// the internal dictionary contents rather than the container object's public properties. +/// +public sealed class AgentSessionStateBagJsonConverter : JsonConverter +{ + /// + public override AgentSessionStateBag Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + var element = JsonElement.ParseValue(ref reader); + return AgentSessionStateBag.Deserialize(element); + } + + /// + public override void Write(Utf8JsonWriter writer, AgentSessionStateBag value, JsonSerializerOptions options) + { + var element = value.Serialize(); + element.WriteTo(writer); + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs index a4c1743e77..10536a9fa1 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs @@ -9,13 +9,13 @@ namespace Microsoft.Agents.AI; /// /// Used to store a value in session state. /// +[JsonConverter(typeof(AgentSessionStateBagValueJsonConverter))] internal class AgentSessionStateBagValue { /// /// Initializes a new instance of the SessionStateValue class with the specified value. /// /// The serialized value to associate with the session state. - [JsonConstructor] public AgentSessionStateBagValue(JsonElement jsonValue) { this.JsonValue = jsonValue; @@ -56,12 +56,9 @@ public JsonElement JsonValue set; } - [JsonIgnore] public object? DeserializedValue { get; set; } - [JsonIgnore] public Type? ValueType { get; set; } - [JsonIgnore] public JsonSerializerOptions? JsonSerializerOptions { get; set; } } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValueJsonConverter.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValueJsonConverter.cs new file mode 100644 index 0000000000..0020526d6c --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValueJsonConverter.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Microsoft.Agents.AI; + +/// +/// Custom JSON converter for that serializes and deserializes +/// the directly rather than wrapping it in a container object. +/// +internal sealed class AgentSessionStateBagValueJsonConverter : JsonConverter +{ + /// + public override AgentSessionStateBagValue Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + var element = JsonElement.ParseValue(ref reader); + return new AgentSessionStateBagValue(element); + } + + /// + public override void Write(Utf8JsonWriter writer, AgentSessionStateBagValue value, JsonSerializerOptions options) + { + value.JsonValue.WriteTo(writer); + } +} diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index cbcfeaa24d..93070e00be 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -105,6 +105,9 @@ public ChatClientAgent(IChatClient chatClient, ChatClientAgentOptions? options, // If the user has not opted out of using our default decorators, we wrap the chat client. this.ChatClient = options?.UseProvidedChatClientAsIs is true ? chatClient : chatClient.WithDefaultAgentMiddleware(options, services); + // Default to an InMemoryChatHistoryProvider if no provider is configured. + this.ChatHistoryProvider = options?.ChatHistoryProvider ?? new InMemoryChatHistoryProvider(); + this._logger = (loggerFactory ?? chatClient.GetService() ?? NullLoggerFactory.Instance).CreateLogger(); } @@ -120,6 +123,14 @@ public ChatClientAgent(IChatClient chatClient, ChatClientAgentOptions? options, /// public IChatClient ChatClient { get; } + /// + /// Gets the used by this agent, to support cases where the chat history is not stored by the agent service. + /// + /// + /// This property may be null in case the agent stores messages in the underlying agent service. + /// + public ChatHistoryProvider? ChatHistoryProvider { get; } + /// protected override string? IdCore => this._agentOptions?.Id; @@ -283,7 +294,7 @@ protected override async IAsyncEnumerable RunCoreStreamingA // We can derive the type of supported session from whether we have a conversation id, // so let's update it and set the conversation id for the service session case. - await this.UpdateSessionWithTypeAndConversationIdAsync(safeSession, chatResponse.ConversationId, cancellationToken).ConfigureAwait(false); + this.UpdateSessionConversationId(safeSession, chatResponse.ConversationId, cancellationToken); // To avoid inconsistent state we only notify the session of the input messages if no error occurs after the initial request. await this.NotifyChatHistoryProviderOfNewMessagesAsync(safeSession, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false); @@ -299,16 +310,14 @@ protected override async IAsyncEnumerable RunCoreStreamingA : serviceType == typeof(IChatClient) ? this.ChatClient : serviceType == typeof(ChatOptions) ? this._agentOptions?.ChatOptions : serviceType == typeof(ChatClientAgentOptions) ? this._agentOptions - : this.ChatClient.GetService(serviceType, serviceKey)); + : this._agentOptions?.AIContextProvider?.GetService(serviceType, serviceKey) + ?? this._agentOptions?.ChatHistoryProvider?.GetService(serviceType, serviceKey) + ?? this.ChatClient.GetService(serviceType, serviceKey)); /// protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) { - return new(new ChatClientAgentSession - { - ChatHistoryProvider = this._agentOptions?.ChatHistoryProvider, - AIContextProvider = this._agentOptions?.AIContextProvider - }); + return new(new ChatClientAgentSession()); } /// @@ -334,38 +343,6 @@ public ValueTask CreateSessionAsync(string conversationId, Cancell return new(new ChatClientAgentSession() { ConversationId = conversationId, - AIContextProvider = this._agentOptions?.AIContextProvider - }); - } - - /// - /// Creates a new agent session instance using an existing to continue a conversation. - /// - /// The instance to use for managing the conversation's message history. - /// The to monitor for cancellation requests. - /// - /// A value task representing the asynchronous operation. The task result contains a new instance configured to work with the provided . - /// - /// - /// - /// This method creates threads that do not support server-side conversation storage. - /// Some AI services require server-side conversation storage to function properly, and creating a session - /// with a may not be compatible with these services. - /// - /// - /// Where a service requires server-side conversation storage, use . - /// - /// - /// If the agent detects, during the first run, that the underlying AI service requires server-side conversation storage, - /// the session will throw an exception to indicate that it cannot continue using the provided . - /// - /// - public ValueTask CreateSessionAsync(ChatHistoryProvider chatHistoryProvider, CancellationToken cancellationToken = default) - { - return new(new ChatClientAgentSession() - { - ChatHistoryProvider = Throw.IfNull(chatHistoryProvider), - AIContextProvider = this._agentOptions?.AIContextProvider }); } @@ -385,10 +362,7 @@ protected override JsonElement SerializeSessionCore(AgentSession session, JsonSe /// protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) { - return new(ChatClientAgentSession.Deserialize( - serializedState, - this._agentOptions?.ChatHistoryProvider, - this._agentOptions?.AIContextProvider)); + return new(ChatClientAgentSession.Deserialize(serializedState)); } #region Private @@ -438,7 +412,7 @@ private async Task RunCoreAsync responseMessages, CancellationToken cancellationToken) { - if (session.AIContextProvider is not null) + if (this._agentOptions?.AIContextProvider is { } contextProvider) { - await session.AIContextProvider.InvokedAsync(new(this, session, inputMessages, aiContextProviderMessages) { ResponseMessages = responseMessages }, + await contextProvider.InvokedAsync(new(this, session, inputMessages, aiContextProviderMessages) { ResponseMessages = responseMessages }, cancellationToken).ConfigureAwait(false); } } @@ -486,9 +460,9 @@ private async Task NotifyAIContextProviderOfFailureAsync( IList? aiContextProviderMessages, CancellationToken cancellationToken) { - if (session.AIContextProvider is not null) + if (this._agentOptions?.AIContextProvider is { } contextProvider) { - await session.AIContextProvider.InvokedAsync(new(this, session, inputMessages, aiContextProviderMessages) { InvokeException = ex }, + await contextProvider.InvokedAsync(new(this, session, inputMessages, aiContextProviderMessages) { InvokeException = ex }, cancellationToken).ConfigureAwait(false); } } @@ -695,7 +669,7 @@ private async Task // Populate the session messages only if we are not continuing an existing response as it's not allowed if (chatOptions?.ContinuationToken is null) { - ChatHistoryProvider? chatHistoryProvider = ResolveChatHistoryProvider(typedSession, chatOptions); + ChatHistoryProvider? chatHistoryProvider = this.ResolveChatHistoryProvider(chatOptions); // Add any existing messages from the session to the messages to be sent to the chat client. if (chatHistoryProvider is not null) @@ -711,10 +685,10 @@ private async Task // If we have an AIContextProvider, we should get context from it, and update our // messages and options with the additional context. - if (typedSession.AIContextProvider is not null) + if (this._agentOptions?.AIContextProvider is { } aiContextProvider) { var invokingContext = new AIContextProvider.InvokingContext(this, typedSession, inputMessages); - var aiContext = await typedSession.AIContextProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false); + var aiContext = await aiContextProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false); if (aiContext.Messages is { Count: > 0 }) { inputMessagesForChatClient.AddRange(aiContext.Messages); @@ -760,7 +734,7 @@ private async Task return (typedSession, chatOptions, inputMessagesForChatClient, aiContextProviderMessages, chatHistoryProviderMessages, continuationToken); } - private async Task UpdateSessionWithTypeAndConversationIdAsync(ChatClientAgentSession session, string? responseConversationId, CancellationToken cancellationToken) + private void UpdateSessionConversationId(ChatClientAgentSession session, string? responseConversationId, CancellationToken cancellationToken) { if (string.IsNullOrWhiteSpace(responseConversationId) && !string.IsNullOrWhiteSpace(session.ConversationId)) { @@ -771,18 +745,17 @@ private async Task UpdateSessionWithTypeAndConversationIdAsync(ChatClientAgentSe if (!string.IsNullOrWhiteSpace(responseConversationId)) { + if (this._agentOptions?.ChatHistoryProvider is not null) + { + // The agent has a ChatHistoryProvider configured, but the service returned a conversation id, + // meaning the service manages chat history server-side. Both cannot be used simultaneously. + throw new InvalidOperationException("Only the ConversationId or ChatHistoryProvider may be used, but not both. The service returned a conversation id indicating server-side chat history management, but the agent has a ChatHistoryProvider configured."); + } + // If we got a conversation id back from the chat client, it means that the service supports server side session storage // so we should update the session with the new id. session.ConversationId = responseConversationId; } - else - { - // If the service doesn't use service side chat history storage (i.e. we got no id back from invocation), and - // the session has no ChatHistoryProvider yet, we should update the session with the custom ChatHistoryProvider or - // default InMemoryChatHistoryProvider so that it has somewhere to store the chat history. - session.ChatHistoryProvider ??= this._agentOptions?.ChatHistoryProvider - ?? new InMemoryChatHistoryProvider(); - } } private Task NotifyChatHistoryProviderOfFailureAsync( @@ -794,7 +767,7 @@ private Task NotifyChatHistoryProviderOfFailureAsync( ChatOptions? chatOptions, CancellationToken cancellationToken) { - ChatHistoryProvider? provider = ResolveChatHistoryProvider(session, chatOptions); + ChatHistoryProvider? provider = this.ResolveChatHistoryProvider(chatOptions); // Only notify the provider if we have one. // If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages. @@ -821,7 +794,7 @@ private Task NotifyChatHistoryProviderOfNewMessagesAsync( ChatOptions? chatOptions, CancellationToken cancellationToken) { - ChatHistoryProvider? provider = ResolveChatHistoryProvider(session, chatOptions); + ChatHistoryProvider? provider = this.ResolveChatHistoryProvider(chatOptions); // Only notify the provider if we have one. // If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages. @@ -838,11 +811,11 @@ private Task NotifyChatHistoryProviderOfNewMessagesAsync( return Task.CompletedTask; } - private static ChatHistoryProvider? ResolveChatHistoryProvider(ChatClientAgentSession session, ChatOptions? chatOptions) + private ChatHistoryProvider? ResolveChatHistoryProvider(ChatOptions? chatOptions) { - ChatHistoryProvider? provider = session.ChatHistoryProvider; + ChatHistoryProvider? provider = this.ChatHistoryProvider; - // If someone provided an override ChatHistoryProvider via AdditionalProperties, we should use that instead of the one on the session. + // If someone provided an override ChatHistoryProvider via AdditionalProperties, we should use that instead. if (chatOptions?.AdditionalProperties?.TryGetValue(out ChatHistoryProvider? overrideProvider) is true) { provider = overrideProvider; diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs index 4e9b6bced2..b220170771 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs @@ -13,8 +13,6 @@ namespace Microsoft.Agents.AI; [DebuggerDisplay("{DebuggerDisplay,nq}")] public sealed class ChatClientAgentSession : AgentSession { - private ChatHistoryProvider? _chatHistoryProvider; - /// /// Initializes a new instance of the class. /// @@ -23,28 +21,22 @@ internal ChatClientAgentSession() } /// - /// Gets or sets the ID of the underlying service thread to support cases where the chat history is stored by the agent service. + /// Gets or sets the ID of the underlying service chat history to support cases where the chat history is stored by the agent service. /// /// /// - /// Note that either or may be set, but not both. - /// If is not null, setting will throw an - /// exception. - /// - /// /// This property may be null in the following cases: /// - /// The thread stores messages via the and not in the agent service. - /// This thread object is new and a server managed thread has not yet been created in the agent service. + /// The agent stores messages via a and not in the agent service. + /// This session object is new and server managed chat history has not yet been created in the agent service. /// /// /// - /// The id may also change over time where the id is pointing at a - /// agent service managed thread, and the default behavior of a service is - /// to fork the thread with each iteration. + /// The id may also change over time where the id is pointing at + /// agent service managed chat history, and the default behavior of a service is + /// to fork the chat history with each iteration. /// /// - /// Attempted to set a conversation ID but a is already set. public string? ConversationId { get; @@ -55,78 +47,16 @@ internal set return; } - if (this._chatHistoryProvider is not null) - { - // If we have a ChatHistoryProvider already, we shouldn't switch the session to use a conversation id - // since it means that the session contents will essentially be deleted, and the session will not work - // with the original agent anymore. - throw new InvalidOperationException("Only the ConversationId or ChatHistoryProvider may be set, but not both and switching from one to another is not supported."); - } - field = Throw.IfNullOrWhitespace(value); } } - /// - /// Gets or sets the used by this thread, for cases where messages should be stored in a custom location. - /// - /// - /// - /// Note that either or may be set, but not both. - /// If is not null, and is set, - /// will be reverted to null, and vice versa. - /// - /// - /// This property may be null in the following cases: - /// - /// The thread stores messages in the agent service and just has an id to the remove thread, instead of in an . - /// This thread object is new it is not yet clear whether it will be backed by a server managed thread or an . - /// - /// - /// - public ChatHistoryProvider? ChatHistoryProvider - { - get => this._chatHistoryProvider; - internal set - { - if (this._chatHistoryProvider is null && value is null) - { - return; - } - - if (!string.IsNullOrWhiteSpace(this.ConversationId)) - { - // If we have a conversation id already, we shouldn't switch the session to use a ChatHistoryProvider - // since it means that the session will not work with the original agent anymore. - throw new InvalidOperationException("Only the ConversationId or ChatHistoryProvider may be set, but not both and switching from one to another is not supported."); - } - - this._chatHistoryProvider = Throw.IfNull(value); - } - } - - /// - /// Gets or sets the used by this thread to provide additional context to the AI model before each invocation. - /// - public AIContextProvider? AIContextProvider { get; internal set; } - /// /// Creates a new instance of the class from previously serialized state. /// /// A representing the serialized state of the session. - /// - /// An optional instance. - /// If not provided, the default will be used. - /// - /// - /// An optional instance. - /// If not provided, no context provider will be configured. - /// /// The deserialized . - internal static ChatClientAgentSession Deserialize( - JsonElement serializedState, - ChatHistoryProvider? chatHistoryProvider = null, - AIContextProvider? aiContextProvider = null) + internal static ChatClientAgentSession Deserialize(JsonElement serializedState) { if (serializedState.ValueKind != JsonValueKind.Object) { @@ -138,22 +68,13 @@ internal static ChatClientAgentSession Deserialize( var session = new ChatClientAgentSession(); - session.AIContextProvider = aiContextProvider; - session.StateBag = AgentSessionStateBag.Deserialize(state?.StateBag ?? default); if (state?.ConversationId is string sessionId) { session.ConversationId = sessionId; - - // Since we have an ID, we should not have a ChatHistoryProvider and we can return here. - return session; } - session._chatHistoryProvider = - chatHistoryProvider - ?? new InMemoryChatHistoryProvider(); // default to an in-memory ChatHistoryProvider - return session; } @@ -169,18 +90,10 @@ internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = nu return JsonSerializer.SerializeToElement(state, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(SessionState))); } - /// - public override object? GetService(Type serviceType, object? serviceKey = null) => - base.GetService(serviceType, serviceKey) - ?? this.AIContextProvider?.GetService(serviceType, serviceKey) - ?? this.ChatHistoryProvider?.GetService(serviceType, serviceKey); - [DebuggerBrowsable(DebuggerBrowsableState.Never)] private string DebuggerDisplay => - this.ConversationId is { } conversationId ? $"ConversationId = {conversationId}" : - this._chatHistoryProvider is InMemoryChatHistoryProvider ? "InMemoryChatHistoryProvider" : - this._chatHistoryProvider is { } chatHistoryProvider ? $"ChatHistoryProvider = {chatHistoryProvider.GetType().Name}" : - "Count = 0"; + this.ConversationId is { } conversationId ? $"ConversationId = {conversationId}, StateBag Count = {this.StateBag.Count}" : + $"StateBag Count = {this.StateBag.Count}"; internal sealed class SessionState { diff --git a/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs b/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs index a0e4b64763..f36cb119d9 100644 --- a/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs +++ b/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs @@ -37,14 +37,14 @@ public AnthropicChatCompletionFixture(bool useReasoningChatModel, bool useBeta) public async Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) { - var typedSession = (ChatClientAgentSession)session; + var chatHistoryProvider = agent.GetService(); - if (typedSession.ChatHistoryProvider is null) + if (chatHistoryProvider is null) { return []; } - return (await typedSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); + return (await chatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); } public Task CreateChatClientAgentAsync( diff --git a/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs b/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs index c655fd7a58..dd926174c0 100644 --- a/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs +++ b/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs @@ -48,12 +48,14 @@ public async Task> GetChatHistoryAsync(AIAgent agent, AgentSes return await this.GetChatHistoryFromResponsesChainAsync(chatClientSession.ConversationId); } - if (chatClientSession.ChatHistoryProvider is null) + var chatHistoryProvider = agent.GetService(); + + if (chatHistoryProvider is null) { return []; } - return (await chatClientSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); + return (await chatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); } private async Task> GetChatHistoryFromResponsesChainAsync(string conversationId) diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs index 5d07a4749f..0ecc8b4163 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs @@ -396,11 +396,92 @@ public void Serialize_WithComplexObject_ReturnsJsonWithProperties() // Assert Assert.Equal(JsonValueKind.Object, json.ValueKind); Assert.True(json.TryGetProperty("animal", out JsonElement animalElement)); - Assert.True(animalElement.TryGetProperty("jsonValue", out JsonElement jsonValueElement)); - Assert.Equal(JsonValueKind.Object, jsonValueElement.ValueKind); - Assert.Equal(7, jsonValueElement.GetProperty("id").GetInt32()); - Assert.Equal("Spot", jsonValueElement.GetProperty("fullName").GetString()); - Assert.Equal("Walrus", jsonValueElement.GetProperty("species").GetString()); + Assert.Equal(JsonValueKind.Object, animalElement.ValueKind); + Assert.Equal(7, animalElement.GetProperty("id").GetInt32()); + Assert.Equal("Spot", animalElement.GetProperty("fullName").GetString()); + Assert.Equal("Walrus", animalElement.GetProperty("species").GetString()); + } + + #endregion + + #region JsonSerializer Integration Tests + + [Fact] + public void JsonSerializerSerialize_EmptyStateBag_ReturnsEmptyObject() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + var json = JsonSerializer.Serialize(stateBag, AgentAbstractionsJsonUtilities.DefaultOptions); + + // Assert + Assert.Equal("{}", json); + } + + [Fact] + public void JsonSerializerSerialize_WithStringValue_ProducesSameOutputAsSerializeMethod() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("stringKey", "stringValue"); + + // Act + var jsonFromSerializer = JsonSerializer.Serialize(stateBag, AgentAbstractionsJsonUtilities.DefaultOptions); + var jsonFromMethod = stateBag.Serialize().GetRawText(); + + // Assert + Assert.Equal(jsonFromMethod, jsonFromSerializer); + } + + [Fact] + public void JsonSerializerRoundtrip_WithStringValue_PreservesData() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("greeting", "hello world"); + + // Act + var json = JsonSerializer.Serialize(stateBag, AgentAbstractionsJsonUtilities.DefaultOptions); + var restored = JsonSerializer.Deserialize(json, AgentAbstractionsJsonUtilities.DefaultOptions); + + // Assert + Assert.NotNull(restored); + Assert.Equal("hello world", restored!.GetValue("greeting")); + } + + [Fact] + public void JsonSerializerRoundtrip_WithComplexObject_PreservesData() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + var animal = new Animal { Id = 10, FullName = "Rex", Species = Species.Tiger }; + stateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); + + // Act + var json = JsonSerializer.Serialize(stateBag, AgentAbstractionsJsonUtilities.DefaultOptions); + var restored = JsonSerializer.Deserialize(json, AgentAbstractionsJsonUtilities.DefaultOptions); + + // Assert + Assert.NotNull(restored); + var restoredAnimal = restored!.GetValue("animal", TestJsonSerializerContext.Default.Options); + Assert.NotNull(restoredAnimal); + Assert.Equal(10, restoredAnimal!.Id); + Assert.Equal("Rex", restoredAnimal.FullName); + Assert.Equal(Species.Tiger, restoredAnimal.Species); + } + + [Fact] + public void JsonSerializerDeserialize_NullJson_ReturnsNull() + { + // Arrange + const string Json = "null"; + + // Act + var stateBag = JsonSerializer.Deserialize(Json, AgentAbstractionsJsonUtilities.DefaultOptions); + + // Assert + Assert.Null(stateBag); } #endregion diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryAgentSessionTests.cs index 725729897c..5511181276 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryAgentSessionTests.cs @@ -102,9 +102,7 @@ public void Serialize_ReturnsCorrectJson_WhenMessagesExist() Assert.Equal(JsonValueKind.Object, stateBagProperty.ValueKind); Assert.True(stateBagProperty.TryGetProperty("InMemoryChatHistoryProvider.State", out var providerStateProperty)); Assert.Equal(JsonValueKind.Object, providerStateProperty.ValueKind); - Assert.True(providerStateProperty.TryGetProperty("jsonValue", out var jsonValueProperty)); - Assert.Equal(JsonValueKind.Object, jsonValueProperty.ValueKind); - Assert.True(jsonValueProperty.TryGetProperty("messages", out var messagesProperty)); + Assert.True(providerStateProperty.TryGetProperty("messages", out var messagesProperty)); Assert.Equal(JsonValueKind.Array, messagesProperty.ValueKind); var messagesList = messagesProperty.EnumerateArray().ToList(); Assert.Single(messagesList); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs index 93440fcad9..52b3652e02 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs @@ -4,7 +4,6 @@ using System.Linq; using System.Text.Json; using Microsoft.Extensions.AI; -using Moq; #pragma warning disable CA1861 // Avoid constant arrays as arguments @@ -22,7 +21,6 @@ public void ConstructorSetsDefaults() // Assert Assert.Null(session.ConversationId); - Assert.Null(session.ChatHistoryProvider); } [Fact] @@ -37,53 +35,6 @@ public void SetConversationIdRoundtrips() // Assert Assert.Equal(ConversationId, session.ConversationId); - Assert.Null(session.ChatHistoryProvider); - } - - [Fact] - public void SetChatHistoryProviderRoundtrips() - { - // Arrange - var session = new ChatClientAgentSession(); - var chatHistoryProvider = new InMemoryChatHistoryProvider(); - - // Act - session.ChatHistoryProvider = chatHistoryProvider; - - // Assert - Assert.Same(chatHistoryProvider, session.ChatHistoryProvider); - Assert.Null(session.ConversationId); - } - - [Fact] - public void SetConversationIdThrowsWhenChatHistoryProviderIsSet() - { - // Arrange - var session = new ChatClientAgentSession - { - ChatHistoryProvider = new InMemoryChatHistoryProvider() - }; - - // Act & Assert - var exception = Assert.Throws(() => session.ConversationId = "new-session-id"); - Assert.Equal("Only the ConversationId or ChatHistoryProvider may be set, but not both and switching from one to another is not supported.", exception.Message); - Assert.NotNull(session.ChatHistoryProvider); - } - - [Fact] - public void SetChatHistoryProviderThrowsWhenConversationIdIsSet() - { - // Arrange - var session = new ChatClientAgentSession - { - ConversationId = "existing-session-id" - }; - var provider = new InMemoryChatHistoryProvider(); - - // Act & Assert - var exception = Assert.Throws(() => session.ChatHistoryProvider = provider); - Assert.Equal("Only the ConversationId or ChatHistoryProvider may be set, but not both and switching from one to another is not supported.", exception.Message); - Assert.NotNull(session.ConversationId); } #endregion Constructor and Property Tests @@ -98,7 +49,7 @@ public void VerifyDeserializeWithMessages() { "stateBag": { "InMemoryChatHistoryProvider.State": { - "jsonValue": { "messages": [{"authorName": "testAuthor"}] } + "messages": [{"authorName": "testAuthor"}] } } } @@ -110,8 +61,10 @@ public void VerifyDeserializeWithMessages() // Assert Assert.Null(session.ConversationId); - var chatHistoryProvider = session.ChatHistoryProvider as InMemoryChatHistoryProvider; - Assert.NotNull(chatHistoryProvider); + // Verify the StateBag contains the serialized chat history provider state + Assert.True(session.StateBag.TryGetValue("InMemoryChatHistoryProvider.State", out _)); + + var chatHistoryProvider = new InMemoryChatHistoryProvider(); var messages = chatHistoryProvider.GetMessages(session); Assert.Single(messages); Assert.Equal("testAuthor", messages[0].AuthorName); @@ -132,27 +85,6 @@ public void VerifyDeserializeWithId() // Assert Assert.Equal("TestConvId", session.ConversationId); - Assert.Null(session.ChatHistoryProvider); - } - - [Fact] - public void VerifyDeserializeWithAIContextProvider() - { - // Arrange - var json = JsonSerializer.Deserialize(""" - { - "conversationId": "TestConvId", - "aiContextProviderState": ["CP1"] - } - """, TestJsonSerializerContext.Default.JsonElement); - Mock mockProvider = new(); - - // Act - var session = ChatClientAgentSession.Deserialize(json, aiContextProvider: mockProvider.Object); - - // Assert - Assert.Null(session.ChatHistoryProvider); - Assert.Same(session.AIContextProvider, mockProvider.Object); } [Fact] @@ -164,17 +96,13 @@ public void VerifyDeserializeWithStateBag() "conversationId": "TestConvId", "stateBag": { "dog": { - "jsonValue": { - "name": "Fido" - } + "name": "Fido" } } } """, TestJsonSerializerContext.Default.JsonElement); - Mock mockProvider = new(); - // Act - var session = ChatClientAgentSession.Deserialize(json, aiContextProvider: mockProvider.Object); + var session = ChatClientAgentSession.Deserialize(json); // Assert var dog = session.StateBag.GetValue("dog", TestJsonSerializerContext.Default.Options); @@ -225,7 +153,7 @@ public void VerifySessionSerializationWithMessages() { // Arrange var provider = new InMemoryChatHistoryProvider(); - var session = new ChatClientAgentSession { ChatHistoryProvider = provider }; + var session = new ChatClientAgentSession(); provider.SetMessages(session, [new(ChatRole.User, "TestContent") { AuthorName = "TestAuthor" }]); // Act @@ -241,9 +169,7 @@ public void VerifySessionSerializationWithMessages() Assert.Equal(JsonValueKind.Object, stateBagProperty.ValueKind); Assert.True(stateBagProperty.TryGetProperty("InMemoryChatHistoryProvider.State", out var providerStateProperty)); Assert.Equal(JsonValueKind.Object, providerStateProperty.ValueKind); - Assert.True(providerStateProperty.TryGetProperty("jsonValue", out var jsonValueProperty)); - Assert.Equal(JsonValueKind.Object, jsonValueProperty.ValueKind); - Assert.True(jsonValueProperty.TryGetProperty("messages", out var messagesProperty)); + Assert.True(providerStateProperty.TryGetProperty("messages", out var messagesProperty)); Assert.Equal(JsonValueKind.Array, messagesProperty.ValueKind); Assert.Single(messagesProperty.EnumerateArray()); @@ -273,8 +199,7 @@ public void VerifySessionSerializationWithWithStateBag() Assert.Equal(JsonValueKind.Object, stateBagProperty.ValueKind); Assert.True(stateBagProperty.TryGetProperty("dog", out var dogProperty)); Assert.Equal(JsonValueKind.Object, dogProperty.ValueKind); - Assert.True(dogProperty.TryGetProperty("jsonValue", out var dogJsonValueProperty)); - Assert.True(dogJsonValueProperty.TryGetProperty("name", out var nameProperty)); + Assert.True(dogProperty.TryGetProperty("name", out var nameProperty)); Assert.Equal("Fido", nameProperty.GetString()); } @@ -289,9 +214,6 @@ public void VerifySessionSerializationWithCustomOptions() JsonSerializerOptions options = new() { PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower }; options.TypeInfoResolverChain.Add(AgentAbstractionsJsonUtilities.DefaultOptions.TypeInfoResolver!); - var chatHistoryProviderMock = new Mock(); - session.ChatHistoryProvider = chatHistoryProviderMock.Object; - // Act var json = session.Serialize(options); @@ -324,45 +246,6 @@ public void VerifyStateBagRoundtrips() #endregion - #region GetService Tests - - [Fact] - public void GetService_RequestingAIContextProvider_ReturnsAIContextProvider() - { - // Arrange - var session = new ChatClientAgentSession(); - var mockProvider = new Mock(); - mockProvider - .Setup(m => m.GetService(It.Is(x => x == typeof(AIContextProvider)), null)) - .Returns(mockProvider.Object); - session.AIContextProvider = mockProvider.Object; - - // Act - var result = session.GetService(typeof(AIContextProvider)); - - // Assert - Assert.NotNull(result); - Assert.Same(mockProvider.Object, result); - } - - [Fact] - public void GetService_RequestingChatHistoryProvider_ReturnsChatHistoryProvider() - { - // Arrange - var session = new ChatClientAgentSession(); - var chatHistoryProvider = new InMemoryChatHistoryProvider(); - session.ChatHistoryProvider = chatHistoryProvider; - - // Act - var result = session.GetService(typeof(ChatHistoryProvider)); - - // Assert - Assert.NotNull(result); - Assert.Same(chatHistoryProvider, result); - } - - #endregion - internal sealed class Animal { public string Name { get; set; } = string.Empty; diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs index bda334020e..156679a466 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs @@ -372,7 +372,8 @@ public async Task RunAsyncInvokesAIContextProviderAndUsesResultAsync() Assert.Contains(capturedTools, t => t.Name == "context provider function"); // Verify that the session was updated with the ai context provider, input and response messages - var chatHistoryProvider = Assert.IsType(session!.ChatHistoryProvider); + var chatHistoryProvider = agent.ChatHistoryProvider as InMemoryChatHistoryProvider; + Assert.NotNull(chatHistoryProvider); var messages = chatHistoryProvider.GetMessages(session); Assert.Equal(3, messages.Count); Assert.Equal("user message", messages[0].Text); @@ -1299,7 +1300,7 @@ public async Task RunStreamingAsyncUsesChatHistoryProviderWhenNoConversationIdRe await agent.RunStreamingAsync([new(ChatRole.User, "test")], session).ToListAsync(); // Assert - var chatHistoryProvider = Assert.IsType(session!.ChatHistoryProvider); + var chatHistoryProvider = Assert.IsType(agent.GetService(typeof(ChatHistoryProvider))); var historyMessages = chatHistoryProvider.GetMessages(session); Assert.Equal(2, historyMessages.Count); Assert.Equal("test", historyMessages[0].Text); @@ -1333,7 +1334,7 @@ public async Task RunStreamingAsyncThrowsWhenChatHistoryProviderProvidedAndConve // Act & Assert ChatClientAgentSession? session = await agent.CreateSessionAsync() as ChatClientAgentSession; var exception = await Assert.ThrowsAsync(async () => await agent.RunStreamingAsync([new(ChatRole.User, "test")], session).ToListAsync()); - Assert.Equal("Only the ConversationId or ChatHistoryProvider may be set, but not both and switching from one to another is not supported.", exception.Message); + Assert.Equal("Only the ConversationId or ChatHistoryProvider may be used, but not both. The service returned a conversation id indicating server-side chat history management, but the agent has a ChatHistoryProvider configured.", exception.Message); } /// @@ -1405,7 +1406,8 @@ public async Task RunStreamingAsyncInvokesAIContextProviderAndUsesResultAsync() Assert.Contains(capturedTools, t => t.Name == "context provider function"); // Verify that the session was updated with the input, ai context provider, and response messages - var chatHistoryProvider = Assert.IsType(session!.ChatHistoryProvider); + var chatHistoryProvider = agent.ChatHistoryProvider as InMemoryChatHistoryProvider; + Assert.NotNull(chatHistoryProvider); var historyMessages2 = chatHistoryProvider.GetMessages(session); Assert.Equal(3, historyMessages2.Count); Assert.Equal("user message", historyMessages2[0].Text); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs index 87be3fb96e..e5863efd15 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs @@ -362,14 +362,14 @@ public async Task RunAsync_WhenContinuationTokenProvided_SkipsSessionMessagePopu capturedMessages.AddRange(msgs)) .ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "continued response")])); - ChatClientAgent agent = new(mockChatClient.Object); - - // Create a session with both chat history provider and AI context provider - ChatClientAgentSession? session = new() + ChatClientAgent agent = new(mockChatClient.Object, options: new() { ChatHistoryProvider = mockChatHistoryProvider.Object, AIContextProvider = mockContextProvider.Object - }; + }); + + // Create a session + ChatClientAgentSession? session = new(); AgentRunOptions runOptions = new() { @@ -427,14 +427,14 @@ public async Task RunStreamingAsync_WhenContinuationTokenProvided_SkipsSessionMe capturedMessages.AddRange(msgs)) .Returns(ToAsyncEnumerableAsync([new ChatResponseUpdate(role: ChatRole.Assistant, content: "continued response")])); - ChatClientAgent agent = new(mockChatClient.Object); - - // Create a session with both chat history provider and AI context provider - ChatClientAgentSession? session = new() + ChatClientAgent agent = new(mockChatClient.Object, options: new() { ChatHistoryProvider = mockChatHistoryProvider.Object, AIContextProvider = mockContextProvider.Object - }; + }); + + // Create a session + ChatClientAgentSession? session = new(); AgentRunOptions runOptions = new() { @@ -628,8 +628,6 @@ public async Task RunStreamingAsync_WhenResumingStreaming_UsesUpdatesFromInitial It.IsAny())) .Returns(ToAsyncEnumerableAsync(returnUpdates)); - ChatClientAgent agent = new(mockChatClient.Object); - List capturedMessagesAddedToProvider = []; var mockChatHistoryProvider = new Mock(); mockChatHistoryProvider @@ -644,11 +642,13 @@ public async Task RunStreamingAsync_WhenResumingStreaming_UsesUpdatesFromInitial .Callback((context, ct) => capturedInvokedContext = context) .Returns(new ValueTask()); - ChatClientAgentSession? session = new() + ChatClientAgent agent = new(mockChatClient.Object, options: new() { ChatHistoryProvider = mockChatHistoryProvider.Object, AIContextProvider = mockContextProvider.Object - }; + }); + + ChatClientAgentSession? session = new(); AgentRunOptions runOptions = new() { @@ -684,8 +684,6 @@ public async Task RunStreamingAsync_WhenResumingStreaming_UsesInputMessagesFromI It.IsAny())) .Returns(ToAsyncEnumerableAsync(Array.Empty())); - ChatClientAgent agent = new(mockChatClient.Object); - List capturedMessagesAddedToProvider = []; var mockChatHistoryProvider = new Mock(); mockChatHistoryProvider @@ -700,11 +698,13 @@ public async Task RunStreamingAsync_WhenResumingStreaming_UsesInputMessagesFromI .Callback((context, ct) => capturedInvokedContext = context) .Returns(new ValueTask()); - ChatClientAgentSession? session = new() + ChatClientAgent agent = new(mockChatClient.Object, options: new() { ChatHistoryProvider = mockChatHistoryProvider.Object, AIContextProvider = mockContextProvider.Object - }; + }); + + ChatClientAgentSession? session = new(); AgentRunOptions runOptions = new() { diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs index db1c0101f2..9e7121467c 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs @@ -162,8 +162,9 @@ public async Task RunAsync_UsesDefaultInMemoryChatHistoryProvider_WhenNoConversa await agent.RunAsync([new(ChatRole.User, "test")], session); // Assert - InMemoryChatHistoryProvider chatHistoryProvider = Assert.IsType(session!.ChatHistoryProvider); - var messages = chatHistoryProvider.GetMessages(session); + var inMemoryProvider = agent.ChatHistoryProvider as InMemoryChatHistoryProvider; + Assert.NotNull(inMemoryProvider); + var messages = inMemoryProvider.GetMessages(session!); Assert.Equal(2, messages.Count); Assert.Equal("test", messages[0].Text); Assert.Equal("response", messages[1].Text); @@ -202,7 +203,7 @@ public async Task RunAsync_UsesChatHistoryProvider_WhenProvidedAndNoConversation await agent.RunAsync([new(ChatRole.User, "test")], session); // Assert - Assert.IsType(session!.ChatHistoryProvider, exactMatch: false); + Assert.Same(mockChatHistoryProvider.Object, agent.ChatHistoryProvider); mockService.Verify( x => x.GetResponseAsync( It.Is>(msgs => msgs.Count() == 2 && msgs.Any(m => m.Text == "Existing Chat History") && msgs.Any(m => m.Text == "test")), @@ -246,7 +247,7 @@ public async Task RunAsync_NotifiesChatHistoryProvider_OnFailureAsync() await Assert.ThrowsAsync(() => agent.RunAsync([new(ChatRole.User, "test")], session)); // Assert - Assert.IsType(session!.ChatHistoryProvider, exactMatch: false); + Assert.Same(mockChatHistoryProvider.Object, agent.ChatHistoryProvider); mockChatHistoryProvider.Verify(s => s.InvokedAsync( It.Is(x => x.RequestMessages.Count() == 1 && x.ResponseMessages == null && x.InvokeException!.Message == "Test Error"), It.IsAny()), @@ -275,7 +276,7 @@ public async Task RunAsync_Throws_WhenChatHistoryProviderProvidedAndConversation // Act & Assert ChatClientAgentSession? session = await agent.CreateSessionAsync() as ChatClientAgentSession; InvalidOperationException exception = await Assert.ThrowsAsync(() => agent.RunAsync([new(ChatRole.User, "test")], session)); - Assert.Equal("Only the ConversationId or ChatHistoryProvider may be set, but not both and switching from one to another is not supported.", exception.Message); + Assert.Equal("Only the ConversationId or ChatHistoryProvider may be used, but not both. The service returned a conversation id indicating server-side chat history management, but the agent has a ChatHistoryProvider configured.", exception.Message); } #endregion @@ -308,18 +309,18 @@ public async Task RunAsync_UsesOverrideChatHistoryProvider_WhenProvidedViaAdditi // Arrange a chat history provider to provide to the agent at construction time. // This one shouldn't be used since it is being overridden. - Mock mockFactoryChatHistoryProvider = new(); - mockFactoryChatHistoryProvider.Setup(s => s.InvokingAsync( + Mock mockAgentOptionsChatHistoryProvider = new(); + mockAgentOptionsChatHistoryProvider.Setup(s => s.InvokingAsync( It.IsAny(), It.IsAny())).ThrowsAsync(FailException.ForFailure("Base ChatHistoryProvider shouldn't be used.")); - mockFactoryChatHistoryProvider.Setup(s => s.InvokedAsync( + mockAgentOptionsChatHistoryProvider.Setup(s => s.InvokedAsync( It.IsAny(), It.IsAny())).Throws(FailException.ForFailure("Base ChatHistoryProvider shouldn't be used.")); ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, - ChatHistoryProvider = mockFactoryChatHistoryProvider.Object + ChatHistoryProvider = mockAgentOptionsChatHistoryProvider.Object }); // Act @@ -329,7 +330,7 @@ public async Task RunAsync_UsesOverrideChatHistoryProvider_WhenProvidedViaAdditi await agent.RunAsync([new(ChatRole.User, "test")], session, options: new AgentRunOptions { AdditionalProperties = additionalProperties }); // Assert - Assert.Same(mockFactoryChatHistoryProvider.Object, session!.ChatHistoryProvider); + Assert.Same(mockAgentOptionsChatHistoryProvider.Object, agent.ChatHistoryProvider); mockService.Verify( x => x.GetResponseAsync( It.Is>(msgs => msgs.Count() == 2 && msgs.Any(m => m.Text == "Existing Chat History") && msgs.Any(m => m.Text == "test")), @@ -345,11 +346,11 @@ public async Task RunAsync_UsesOverrideChatHistoryProvider_WhenProvidedViaAdditi It.IsAny()), Times.Once); - mockFactoryChatHistoryProvider.Verify(s => s.InvokingAsync( + mockAgentOptionsChatHistoryProvider.Verify(s => s.InvokingAsync( It.IsAny(), It.IsAny()), Times.Never); - mockFactoryChatHistoryProvider.Verify(s => s.InvokedAsync( + mockAgentOptionsChatHistoryProvider.Verify(s => s.InvokedAsync( It.IsAny(), It.IsAny()), Times.Never); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_CreateSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_CreateSessionTests.cs index f108b0966d..68fd008e9e 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_CreateSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_CreateSessionTests.cs @@ -11,65 +11,6 @@ namespace Microsoft.Agents.AI.UnitTests; /// public class ChatClientAgent_CreateSessionTests { - [Fact] - public async Task CreateSession_UsesAIContextProvider_IfProvidedAsync() - { - // Arrange - var mockChatClient = new Mock(); - var mockContextProvider = new Mock(); - var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions - { - ChatOptions = new() { Instructions = "Test instructions" }, - AIContextProvider = mockContextProvider.Object - }); - - // Act - var session = await agent.CreateSessionAsync(); - - // Assert - Assert.IsType(session); - var typedSession = (ChatClientAgentSession)session; - Assert.Same(mockContextProvider.Object, typedSession.AIContextProvider); - } - - [Fact] - public async Task CreateSession_UsesChatHistoryProvider_IfProvidedAsync() - { - // Arrange - var mockChatClient = new Mock(); - var mockChatHistoryProvider = new Mock(); - var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions - { - ChatOptions = new() { Instructions = "Test instructions" }, - ChatHistoryProvider = mockChatHistoryProvider.Object - }); - - // Act - var session = await agent.CreateSessionAsync(); - - // Assert - Assert.IsType(session); - var typedSession = (ChatClientAgentSession)session; - Assert.Same(mockChatHistoryProvider.Object, typedSession.ChatHistoryProvider); - } - - [Fact] - public async Task CreateSession_UsesChatHistoryProvider_FromTypedOverloadAsync() - { - // Arrange - var mockChatClient = new Mock(); - var mockChatHistoryProvider = new Mock(); - var agent = new ChatClientAgent(mockChatClient.Object); - - // Act - var session = await agent.CreateSessionAsync(mockChatHistoryProvider.Object); - - // Assert - Assert.IsType(session); - var typedSession = (ChatClientAgentSession)session; - Assert.Same(mockChatHistoryProvider.Object, typedSession.ChatHistoryProvider); - } - [Fact] public async Task CreateSession_UsesConversationId_FromTypedOverloadAsync() { diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeSessionTests.cs deleted file mode 100644 index d4f2ab9628..0000000000 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeSessionTests.cs +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Text.Json; -using System.Threading.Tasks; -using Microsoft.Extensions.AI; -using Moq; - -namespace Microsoft.Agents.AI.UnitTests; - -/// -/// Contains unit tests for the ChatClientAgent.DeserializeSession methods. -/// -public class ChatClientAgent_DeserializeSessionTests -{ - [Fact] - public async Task DeserializeSession_UsesAIContextProvider_IfProvidedAsync() - { - // Arrange - var mockChatClient = new Mock(); - var mockContextProvider = new Mock(); - var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions - { - ChatOptions = new() { Instructions = "Test instructions" }, - AIContextProvider = mockContextProvider.Object - }); - - var json = JsonSerializer.Deserialize(""" - { - "aiContextProviderState": ["CP1"] - } - """, TestJsonSerializerContext.Default.JsonElement); - - // Act - var session = await agent.DeserializeSessionAsync(json); - - // Assert - Assert.IsType(session); - var typedSession = (ChatClientAgentSession)session; - Assert.Same(mockContextProvider.Object, typedSession.AIContextProvider); - } - - [Fact] - public async Task DeserializeSession_UsesChatHistoryProvider_IfProvidedAsync() - { - // Arrange - var mockChatClient = new Mock(); - var mockChatHistoryProvider = new Mock(); - var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions - { - ChatOptions = new() { Instructions = "Test instructions" }, - ChatHistoryProvider = mockChatHistoryProvider.Object - }); - - var json = JsonSerializer.Deserialize(""" - { - "chatHistoryProviderState": { } - } - """, TestJsonSerializerContext.Default.JsonElement); - - // Act - var session = await agent.DeserializeSessionAsync(json); - - // Assert - Assert.IsType(session); - var typedSession = (ChatClientAgentSession)session; - Assert.Same(mockChatHistoryProvider.Object, typedSession.ChatHistoryProvider); - } -} diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs index 2cc83583f4..7a4caf820c 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs @@ -524,8 +524,7 @@ public async Task InvokedAsync_ShouldPersistMessagesToSessionStateBagAsync() // Assert - State should be in the session's StateBag var stateBagSerialized = session.StateBag.Serialize(); Assert.True(stateBagSerialized.TryGetProperty("TextSearchProvider.RecentMessagesText", out var stateProperty)); - Assert.True(stateProperty.TryGetProperty("jsonValue", out var jsonValueProperty)); - Assert.True(jsonValueProperty.TryGetProperty("recentMessagesText", out var recentProperty)); + Assert.True(stateProperty.TryGetProperty("recentMessagesText", out var recentProperty)); Assert.Equal(JsonValueKind.Array, recentProperty.ValueKind); var list = recentProperty.EnumerateArray().Select(e => e.GetString()).ToList(); Assert.Equal(3, list.Count); diff --git a/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs b/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs index 304df28fba..96a9a17ae8 100644 --- a/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs +++ b/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs @@ -30,14 +30,14 @@ public OpenAIChatCompletionFixture(bool useReasoningChatModel) public async Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) { - var typedSession = (ChatClientAgentSession)session; + var chatHistoryProvider = agent.GetService(); - if (typedSession.ChatHistoryProvider is null) + if (chatHistoryProvider is null) { return []; } - return (await typedSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); + return (await chatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); } public Task CreateChatClientAgentAsync( diff --git a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs index 719db6a0b0..e36f8990f6 100644 --- a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs +++ b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs @@ -50,12 +50,14 @@ public async Task> GetChatHistoryAsync(AIAgent agent, AgentSes return [.. previousMessages, responseMessage]; } - if (typedSession.ChatHistoryProvider is null) + var chatHistoryProvider = agent.GetService(); + + if (chatHistoryProvider is null) { return []; } - return (await typedSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); + return (await chatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); } private static ChatMessage ConvertToChatMessage(ResponseItem item) From e2ee1355e4b5bd23a6e7db49098d3fd07ddc4ffb Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 6 Feb 2026 19:46:10 +0000 Subject: [PATCH 09/28] Update samples to use getservice on agent --- .../AgentWithMemory_Step02_MemoryUsingMem0/Program.cs | 2 +- .../AgentWithMemory_Step03_CustomMemory/Program.cs | 10 +++++----- .../Agent_Step07_3rdPartyChatHistoryStorage/Program.cs | 6 ++++-- .../Agents/Agent_Step16_ChatReduction/Program.cs | 5 ++++- .../Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs | 4 ++-- 5 files changed, 16 insertions(+), 11 deletions(-) diff --git a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs index 0c6406133e..588c79ba9a 100644 --- a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs @@ -44,7 +44,7 @@ // Clear any existing memories for this scope to demonstrate fresh behavior. // Note that the ClearStoredMemoriesAsync method will clear memories // using the scope stored in the session, or provided via the stateInitializer. -Mem0Provider mem0Provider = session.GetService()!; +Mem0Provider mem0Provider = agent.GetService()!; await mem0Provider.ClearStoredMemoriesAsync(session); Console.WriteLine(await agent.RunAsync("Hi there! My name is Taylor and I'm planning a hiking trip to Patagonia in November.", session)); diff --git a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs index 151a32f67e..06b70004bf 100644 --- a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs @@ -55,10 +55,10 @@ var deserializedSession = await agent.DeserializeSessionAsync(sesionElement); Console.WriteLine(await agent.RunAsync("What is my name and age?", deserializedSession)); -Console.WriteLine("\n>> Read memories from memory component\n"); +Console.WriteLine("\n>> Read memories using memory component\n"); -// It's possible to access the memory component via the session's GetService method. -var userInfo = deserializedSession.GetService()?.GetUserInfo(deserializedSession); +// It's possible to access the memory component via the agent's GetService method. +var userInfo = agent.GetService()?.GetUserInfo(deserializedSession); // Output the user info that was captured by the memory component. Console.WriteLine($"MEMORY - User Name: {userInfo?.UserName}"); @@ -66,10 +66,10 @@ Console.WriteLine("\n>> Use new session with previously created memories\n"); -// It is also possible to set the memories in a memory component on an individual session. +// It is also possible to set the memories using a memory component on an individual session. // This is useful if we want to start a new session, but have it share the same memories as a previous session. var newSession = await agent.CreateSessionAsync(); -if (userInfo is not null && newSession.GetService() is UserInfoMemory newSessionMemory) +if (userInfo is not null && agent.GetService() is UserInfoMemory newSessionMemory) { newSessionMemory.SetUserInfo(newSession, userInfo); } diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs index d58c2253ae..d56c4adc0b 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs @@ -60,8 +60,10 @@ // Run the agent with the session that stores chat history in the vector store a second time. Console.WriteLine(await agent.RunAsync("Now tell the same joke in the voice of a pirate, and add some emojis to the joke.", resumedSession)); -// We can access the VectorChatHistoryProvider via the session's GetService method if we need to read the key under which chat history is stored. -var chatHistoryProvider = resumedSession.GetService()!; +// We can access the VectorChatHistoryProvider via the agent's GetService method +// if we need to read the key under which chat history is stored. The key is stored +// in the session state, and therefore we need to provide the session when reading it. +var chatHistoryProvider = agent.GetService()!; Console.WriteLine($"\nSession is stored in vector store under key: {chatHistoryProvider.GetSessionDbKey(resumedSession)}"); namespace SampleApp diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs index f1949d4935..dfeddb386d 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs @@ -33,7 +33,10 @@ Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.", session)); // Get the chat history to see how many messages are stored. -IList? chatHistory = session.GetService>(); +// We can use the ChatHistoryProvider, that is also used by the agent, to read the +// chat history from the session state, and see how the reducer is affecting the stored messages. +var provider = agent.GetService(); +List? chatHistory = provider?.GetMessages(session); Console.WriteLine($"\nChat history has {chatHistory?.Count} messages.\n"); // Invoke the agent a few more times. diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index 93070e00be..c20d464d26 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -311,7 +311,7 @@ protected override async IAsyncEnumerable RunCoreStreamingA : serviceType == typeof(ChatOptions) ? this._agentOptions?.ChatOptions : serviceType == typeof(ChatClientAgentOptions) ? this._agentOptions : this._agentOptions?.AIContextProvider?.GetService(serviceType, serviceKey) - ?? this._agentOptions?.ChatHistoryProvider?.GetService(serviceType, serviceKey) + ?? this.ChatHistoryProvider?.GetService(serviceType, serviceKey) ?? this.ChatClient.GetService(serviceType, serviceKey)); /// @@ -745,7 +745,7 @@ private void UpdateSessionConversationId(ChatClientAgentSession session, string? if (!string.IsNullOrWhiteSpace(responseConversationId)) { - if (this._agentOptions?.ChatHistoryProvider is not null) + if (this.ChatHistoryProvider is not null) { // The agent has a ChatHistoryProvider configured, but the service returned a conversation id, // meaning the service manages chat history server-side. Both cannot be used simultaneously. From c541ae32a24de4119e97e1f516a6169d8d729238 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 6 Feb 2026 20:27:31 +0000 Subject: [PATCH 10/28] Updated additional session types to serialize statebag --- .../A2AAgentSession.cs | 7 +++- .../AgentSession.cs | 10 +++++ .../ServiceIdAgentSession.cs | 5 +++ .../DurableAgentSession.cs | 13 ++++++- .../A2AAgentSessionTests.cs | 18 +++++++++ .../ServiceIdAgentSessionTests.cs | 17 +++++++++ .../DurableAgentSessionTests.cs | 38 ++++++++++++++++++- 7 files changed, 104 insertions(+), 4 deletions(-) diff --git a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgentSession.cs b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgentSession.cs index cac9b43a30..1d63a44a85 100644 --- a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgentSession.cs @@ -33,6 +33,8 @@ internal A2AAgentSession(JsonElement serializedSessionState, JsonSerializerOptio { this.TaskId = taskId; } + + this.StateBag = AgentSessionStateBag.Deserialize(state?.StateBag ?? default); } /// @@ -51,7 +53,8 @@ internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = nu var state = new A2AAgentSessionState { ContextId = this.ContextId, - TaskId = this.TaskId + TaskId = this.TaskId, + StateBag = this.StateBag.Serialize(), }; return JsonSerializer.SerializeToElement(state, A2AJsonUtilities.DefaultOptions.GetTypeInfo(typeof(A2AAgentSessionState))); @@ -62,5 +65,7 @@ internal sealed class A2AAgentSessionState public string? ContextId { get; set; } public string? TaskId { get; set; } + + public JsonElement? StateBag { get; set; } } } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs index 722660d49e..ad10a19d81 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs @@ -2,6 +2,7 @@ using System; using System.Text.Json; +using System.Text.Json.Serialization; using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.AI; @@ -53,9 +54,18 @@ protected AgentSession() { } + /// + /// Initializes a new instance of the class. + /// + protected AgentSession(AgentSessionStateBag stateBag) + { + this.StateBag = Throw.IfNull(stateBag); + } + /// /// Gets any arbitrary state associated with this session. /// + [JsonPropertyName("stateBag")] public AgentSessionStateBag StateBag { get; protected set; } = new(); /// Asks the for an object of the specified type . diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ServiceIdAgentSession.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ServiceIdAgentSession.cs index cf00635984..940c7fd211 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ServiceIdAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ServiceIdAgentSession.cs @@ -66,6 +66,8 @@ protected ServiceIdAgentSession( { this.ServiceSessionId = serviceSessionId; } + + this.StateBag = AgentSessionStateBag.Deserialize(state?.StateBag ?? default); } /// @@ -92,6 +94,7 @@ protected internal virtual JsonElement Serialize(JsonSerializerOptions? jsonSeri var state = new ServiceIdAgentSessionState { ServiceSessionId = this.ServiceSessionId, + StateBag = this.StateBag.Serialize(), }; return JsonSerializer.SerializeToElement(state, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ServiceIdAgentSessionState))); @@ -100,5 +103,7 @@ protected internal virtual JsonElement Serialize(JsonSerializerOptions? jsonSeri internal sealed class ServiceIdAgentSessionState { public string? ServiceSessionId { get; set; } + + public JsonElement? StateBag { get; set; } } } diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentSession.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentSession.cs index b9d9807728..31c8eec2db 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentSession.cs @@ -12,12 +12,17 @@ namespace Microsoft.Agents.AI.DurableTask; [DebuggerDisplay("{SessionId}")] public sealed class DurableAgentSession : AgentSession { - [JsonConstructor] internal DurableAgentSession(AgentSessionId sessionId) { this.SessionId = sessionId; } + [JsonConstructor] + internal DurableAgentSession(AgentSessionId sessionId, AgentSessionStateBag stateBag) : base(stateBag) + { + this.SessionId = sessionId; + } + /// /// Gets the agent session ID. /// @@ -49,7 +54,11 @@ internal static DurableAgentSession Deserialize(JsonElement serializedSession, J string sessionIdString = sessionIdElement.GetString() ?? throw new JsonException("sessionId property is null."); AgentSessionId sessionId = AgentSessionId.Parse(sessionIdString); - return new DurableAgentSession(sessionId); + AgentSessionStateBag stateBag = serializedSession.TryGetProperty("stateBag", out JsonElement stateBagElement) + ? AgentSessionStateBag.Deserialize(stateBagElement) + : new AgentSessionStateBag(); + + return new DurableAgentSession(sessionId, stateBag); } /// diff --git a/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentSessionTests.cs index 66018e0131..6b7348c05b 100644 --- a/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentSessionTests.cs @@ -27,4 +27,22 @@ public void Constructor_RoundTrip_SerializationPreservesState() Assert.Equal(originalSession.ContextId, deserializedSession.ContextId); Assert.Equal(originalSession.TaskId, deserializedSession.TaskId); } + + [Fact] + public void Constructor_RoundTrip_SerializationPreservesStateBag() + { + // Arrange + A2AAgentSession originalSession = new() { ContextId = "ctx-1", TaskId = "task-1" }; + originalSession.StateBag.SetValue("testKey", "testValue"); + + // Act + JsonElement serialized = originalSession.Serialize(); + A2AAgentSession deserializedSession = new(serialized); + + // Assert + Assert.Equal("ctx-1", deserializedSession.ContextId); + Assert.Equal("task-1", deserializedSession.TaskId); + Assert.True(deserializedSession.StateBag.TryGetValue("testKey", out var value)); + Assert.Equal("testValue", value); + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ServiceIdAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ServiceIdAgentSessionTests.cs index e4a6626f72..778f4939f5 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ServiceIdAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ServiceIdAgentSessionTests.cs @@ -103,6 +103,23 @@ public void Serialize_ReturnsUndefinedServiceSessionId_WhenNotSet() Assert.False(json.TryGetProperty("serviceSessionId", out _)); } + [Fact] + public void Serialize_RoundTrip_PreservesStateBag() + { + // Arrange + var session = new TestServiceIdAgentSession("service-id-roundtrip"); + session.StateBag.SetValue("myKey", "myValue"); + + // Act + var json = session.Serialize(); + var restored = new TestServiceIdAgentSession(json); + + // Assert + Assert.Equal("service-id-roundtrip", restored.GetServiceSessionId()); + Assert.True(restored.StateBag.TryGetValue("myKey", out var value)); + Assert.Equal("myValue", value); + } + #endregion // Sealed test subclass to expose protected members for testing diff --git a/dotnet/tests/Microsoft.Agents.AI.DurableTask.UnitTests/DurableAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.DurableTask.UnitTests/DurableAgentSessionTests.cs index db6ec99058..bc06c35ab8 100644 --- a/dotnet/tests/Microsoft.Agents.AI.DurableTask.UnitTests/DurableAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.DurableTask.UnitTests/DurableAgentSessionTests.cs @@ -33,11 +33,47 @@ public void STJSerialization() string serializedSession = JsonSerializer.Serialize(session, typeof(DurableAgentSession)); // Expected format: "{\"sessionId\":\"@dafx-test-agent@\"}" - string expectedSerializedSession = $"{{\"sessionId\":\"@dafx-{sessionId.Name}@{sessionId.Key}\",\"StateBag\":{{}}}}"; + string expectedSerializedSession = $"{{\"sessionId\":\"@dafx-{sessionId.Name}@{sessionId.Key}\",\"stateBag\":{{}}}}"; Assert.Equal(expectedSerializedSession, serializedSession); DurableAgentSession? deserializedSession = JsonSerializer.Deserialize(serializedSession); Assert.NotNull(deserializedSession); Assert.Equal(sessionId, deserializedSession.SessionId); } + + [Fact] + public void BuiltInSerialization_RoundTrip_PreservesStateBag() + { + // Arrange + AgentSessionId sessionId = AgentSessionId.WithRandomKey("test-agent"); + DurableAgentSession session = new(sessionId); + session.StateBag.SetValue("durableKey", "durableValue"); + + // Act + JsonElement serializedSession = session.Serialize(); + DurableAgentSession deserializedSession = DurableAgentSession.Deserialize(serializedSession); + + // Assert + Assert.Equal(sessionId, deserializedSession.SessionId); + Assert.True(deserializedSession.StateBag.TryGetValue("durableKey", out var value)); + Assert.Equal("durableValue", value); + } + + [Fact] + public void STJSerialization_RoundTrip_PreservesStateBag() + { + // Arrange + AgentSessionId sessionId = AgentSessionId.WithRandomKey("test-agent"); + DurableAgentSession session = new(sessionId); + session.StateBag.SetValue("stjKey", "stjValue"); + + // Act + string serializedSession = JsonSerializer.Serialize(session, typeof(DurableAgentSession)); + DurableAgentSession? deserializedSession = JsonSerializer.Deserialize(serializedSession); + + // Assert + Assert.NotNull(deserializedSession); + Assert.True(deserializedSession.StateBag.TryGetValue("stjKey", out var value)); + Assert.Equal("stjValue", value); + } } From 6fcb9dcbf3ce65782fe53e49db7409eaf9a98c19 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 6 Feb 2026 20:47:32 +0000 Subject: [PATCH 11/28] Fix regression --- .../ChatClient/ChatClientAgent.cs | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index c20d464d26..39538a398c 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -105,8 +105,10 @@ public ChatClientAgent(IChatClient chatClient, ChatClientAgentOptions? options, // If the user has not opted out of using our default decorators, we wrap the chat client. this.ChatClient = options?.UseProvidedChatClientAsIs is true ? chatClient : chatClient.WithDefaultAgentMiddleware(options, services); - // Default to an InMemoryChatHistoryProvider if no provider is configured. - this.ChatHistoryProvider = options?.ChatHistoryProvider ?? new InMemoryChatHistoryProvider(); + // Use the ChatHistoryProvider from options if provided. + // If one was not provided, and we later find out that the underlying service does not manage chat history server-side, + // we will use the defalut InMemoryChatHistoryProvider at that time. + this.ChatHistoryProvider = options?.ChatHistoryProvider; this._logger = (loggerFactory ?? chatClient.GetService() ?? NullLoggerFactory.Instance).CreateLogger(); } @@ -129,7 +131,7 @@ public ChatClientAgent(IChatClient chatClient, ChatClientAgentOptions? options, /// /// This property may be null in case the agent stores messages in the underlying agent service. /// - public ChatHistoryProvider? ChatHistoryProvider { get; } + public ChatHistoryProvider? ChatHistoryProvider { get; private set; } /// protected override string? IdCore => this._agentOptions?.Id; @@ -756,6 +758,13 @@ private void UpdateSessionConversationId(ChatClientAgentSession session, string? // so we should update the session with the new id. session.ConversationId = responseConversationId; } + else + { + // If the service doesn't use service side chat history storage (i.e. we got no id back from invocation), and + // the agent has no ChatHistoryProvider yet, we should use the default InMemoryChatHistoryProvider so that + // we have somewhere to store the chat history. + this.ChatHistoryProvider ??= new InMemoryChatHistoryProvider(); + } } private Task NotifyChatHistoryProviderOfFailureAsync( From 6073ba58a15b459ca21c8caf38f4b5c5e56b321c Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 6 Feb 2026 21:13:39 +0000 Subject: [PATCH 12/28] Address PR comments --- .../Program.cs | 6 + .../Program.cs | 4 +- .../AgentSessionStateBag.cs | 22 +- .../AgentSessionStateBagValue.cs | 9 +- .../AgentSessionStateBagTests.cs | 189 ++++++++++++++++++ 5 files changed, 218 insertions(+), 12 deletions(-) diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs index d56c4adc0b..69f16bfc0b 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs @@ -97,6 +97,12 @@ public string GetSessionDbKey(AgentSession session) private State GetOrInitializeState(AgentSession? session) { var state = session?.StateBag.GetValue(this._stateKey); + + if (state is not null) + { + return state; + } + state = this._stateInitializer(session); if (session is not null) { diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs index 44167ca894..024667796d 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs @@ -110,8 +110,8 @@ public override ValueTask InvokingAsync(InvokingContext context, Canc { Tools = [ - AIFunctionFactory.Create(() => AddTodoItem(context.Session, string.Empty), "AddTodoItem", "Adds an item to the todo list."), - AIFunctionFactory.Create(() => RemoveTodoItem(context.Session, 0), "RemoveTodoItem", "Adds an item to the todo list. Index is zero based.") + AIFunctionFactory.Create((string item) => AddTodoItem(context.Session, item), "AddTodoItem", "Adds an item to the todo list."), + AIFunctionFactory.Create((int index) => RemoveTodoItem(context.Session, index), "RemoveTodoItem", "Adds an item to the todo list. Index is zero based.") ], Messages = [new MEAI.ChatMessage(ChatRole.User, outputMessageBuilder.ToString())] }); diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs index 1bc55d59ae..dc73a93a89 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs @@ -68,12 +68,12 @@ public bool TryGetValue(string key, out T? value, JsonSerializerOptions? json switch (stateValue.JsonValue) { - case T tValue: - value = tValue; - return true; - case JsonElement jsonElement when jsonElement.ValueKind == JsonValueKind.Null || jsonElement.ValueKind == JsonValueKind.Undefined: + case JsonElement jsonElement when jsonElement.ValueKind == JsonValueKind.Undefined: value = null; return false; + case JsonElement jsonElement when jsonElement.ValueKind == JsonValueKind.Null: + value = null; + return true; default: T? result = stateValue.JsonValue.Deserialize(jso.GetTypeInfo(typeof(T))) as T; if (result is null) @@ -117,8 +117,6 @@ public bool TryGetValue(string key, out T? value, JsonSerializerOptions? json switch (stateValue.JsonValue) { - case T tValue: - return tValue; case JsonElement jsonElement when jsonElement.ValueKind == JsonValueKind.Null || jsonElement.ValueKind == JsonValueKind.Undefined: return null; default: @@ -127,6 +125,7 @@ public bool TryGetValue(string key, out T? value, JsonSerializerOptions? json { throw new InvalidOperationException($"Failed to deserialize session state value to type {typeof(T).FullName}."); } + stateValue.IsDeserialized = true; stateValue.DeserializedValue = result; stateValue.ValueType = typeof(T); stateValue.JsonSerializerOptions = jso; @@ -144,7 +143,7 @@ public bool TryGetValue(string key, out T? value, JsonSerializerOptions? json /// The key to store the value under. /// The value to set. /// The JSON serializer options to use for serializing the value. - public void SetValue(string key, T value, JsonSerializerOptions? jsonSerializerOptions = null) + public void SetValue(string key, T? value, JsonSerializerOptions? jsonSerializerOptions = null) where T : class { _ = Throw.IfNullOrWhitespace(key); @@ -153,11 +152,20 @@ public void SetValue(string key, T value, JsonSerializerOptions? jsonSerializ var stateValue = this._state.GetOrAdd(key, _ => new AgentSessionStateBagValue(value, typeof(T), jso)); + stateValue.IsDeserialized = true; stateValue.DeserializedValue = value; stateValue.ValueType = typeof(T); stateValue.JsonSerializerOptions = jso; } + /// + /// Tries to remove a value from the session state. + /// + /// The key of the value to remove. + /// if the value was successfully removed; otherwise, . + public bool TryRemoveValue(string key) + => this._state.TryRemove(Throw.IfNullOrWhitespace(key), out _); + /// /// Serializes all session state values to a JSON object. /// diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs index 10536a9fa1..e3295f7089 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs @@ -27,8 +27,9 @@ public AgentSessionStateBagValue(JsonElement jsonValue) /// The value to associate with the session state. Can be any object, including null. /// The type of the value. /// The JSON serializer options to use for serializing the value. - public AgentSessionStateBagValue(object deserializedValue, Type valueType, JsonSerializerOptions jsonSerializerOptions) + public AgentSessionStateBagValue(object? deserializedValue, Type valueType, JsonSerializerOptions jsonSerializerOptions) { + this.IsDeserialized = true; this.DeserializedValue = deserializedValue; this.ValueType = valueType; this.JsonSerializerOptions = jsonSerializerOptions; @@ -41,14 +42,14 @@ public JsonElement JsonValue { get { - if (this.DeserializedValue != null) + if (this.IsDeserialized) { if (this.ValueType is null || this.JsonSerializerOptions is null) { throw new InvalidOperationException($"{nameof(AgentSessionStateBagValue)} has not been properly initialized, please set {nameof(this.ValueType)} and {nameof(this.JsonSerializerOptions)} before accessing {nameof(this.JsonValue)}."); } - return JsonSerializer.SerializeToElement(this.DeserializedValue, this.JsonSerializerOptions.GetTypeInfo(this.ValueType)); + field = JsonSerializer.SerializeToElement(this.DeserializedValue, this.JsonSerializerOptions.GetTypeInfo(this.ValueType)); } return field; @@ -56,6 +57,8 @@ public JsonElement JsonValue set; } + public bool IsDeserialized { get; set; } + public object? DeserializedValue { get; set; } public Type? ValueType { get; set; } diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs index 0ecc8b4163..33993bd42c 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs @@ -206,6 +206,195 @@ public void TryGetValue_WithEmptyKey_ThrowsArgumentException() #endregion + #region Null Value Tests + + [Fact] + public void SetValue_WithNullValue_StoresNull() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + stateBag.SetValue("key1", null); + + // Assert + Assert.Equal(1, stateBag.Count); + } + + [Fact] + public void TryGetValue_WithNullValue_ReturnsTrueAndNull() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", null); + + // Act + var found = stateBag.TryGetValue("key1", out var result); + + // Assert + Assert.True(found); + Assert.Null(result); + } + + [Fact] + public void GetValue_WithNullValue_ReturnsNull() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", null); + + // Act + var result = stateBag.GetValue("key1"); + + // Assert + Assert.Null(result); + } + + [Fact] + public void SetValue_OverwriteWithNull_ReturnsNull() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "value1"); + + // Act + stateBag.SetValue("key1", null); + + // Assert + Assert.True(stateBag.TryGetValue("key1", out var result)); + Assert.Null(result); + } + + [Fact] + public void SetValue_OverwriteNullWithValue_ReturnsValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", null); + + // Act + stateBag.SetValue("key1", "newValue"); + + // Assert + Assert.True(stateBag.TryGetValue("key1", out var result)); + Assert.Equal("newValue", result); + } + + [Fact] + public void SerializeDeserialize_WithNullValue_SerializesAsNull() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("nullKey", null); + + // Act + var json = stateBag.Serialize(); + + // Assert - null values are serialized as JSON null + Assert.Equal(JsonValueKind.Object, json.ValueKind); + Assert.True(json.TryGetProperty("nullKey", out var nullElement)); + Assert.Equal(JsonValueKind.Null, nullElement.ValueKind); + } + + #endregion + + #region TryRemoveValue Tests + + [Fact] + public void TryRemoveValue_ExistingKey_ReturnsTrueAndRemoves() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "value1"); + + // Act + var removed = stateBag.TryRemoveValue("key1"); + + // Assert + Assert.True(removed); + Assert.Equal(0, stateBag.Count); + Assert.False(stateBag.TryGetValue("key1", out _)); + } + + [Fact] + public void TryRemoveValue_NonexistentKey_ReturnsFalse() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + var removed = stateBag.TryRemoveValue("nonexistent"); + + // Assert + Assert.False(removed); + } + + [Fact] + public void TryRemoveValue_WithNullKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.TryRemoveValue(null!)); + } + + [Fact] + public void TryRemoveValue_WithEmptyKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.TryRemoveValue("")); + } + + [Fact] + public void TryRemoveValue_WithWhitespaceKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.TryRemoveValue(" ")); + } + + [Fact] + public void TryRemoveValue_DoesNotAffectOtherKeys() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "value1"); + stateBag.SetValue("key2", "value2"); + + // Act + stateBag.TryRemoveValue("key1"); + + // Assert + Assert.Equal(1, stateBag.Count); + Assert.False(stateBag.TryGetValue("key1", out _)); + Assert.True(stateBag.TryGetValue("key2", out var value)); + Assert.Equal("value2", value); + } + + [Fact] + public void TryRemoveValue_ThenSetValue_Works() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "original"); + + // Act + stateBag.TryRemoveValue("key1"); + stateBag.SetValue("key1", "replacement"); + + // Assert + Assert.True(stateBag.TryGetValue("key1", out var result)); + Assert.Equal("replacement", result); + } + + #endregion + #region Serialize/Deserialize Tests [Fact] From d8b14346804f9fad4cd0004fbf8d4cc21185eefd Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 6 Feb 2026 21:15:56 +0000 Subject: [PATCH 13/28] Address PR comments. --- .../Agents/Agent_Step20_AdditionalAIContext/Program.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs index 024667796d..b0e96c11b9 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs @@ -111,7 +111,7 @@ public override ValueTask InvokingAsync(InvokingContext context, Canc Tools = [ AIFunctionFactory.Create((string item) => AddTodoItem(context.Session, item), "AddTodoItem", "Adds an item to the todo list."), - AIFunctionFactory.Create((int index) => RemoveTodoItem(context.Session, index), "RemoveTodoItem", "Adds an item to the todo list. Index is zero based.") + AIFunctionFactory.Create((int index) => RemoveTodoItem(context.Session, index), "RemoveTodoItem", "Removes an item from the todo list. Index is zero based.") ], Messages = [new MEAI.ChatMessage(ChatRole.User, outputMessageBuilder.ToString())] }); From e0cdb89c696a6a430af22cbabc2d61496d5ebf22 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 6 Feb 2026 21:38:40 +0000 Subject: [PATCH 14/28] Fix formatting --- .../Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs | 2 +- .../AgentSessionStateBagJsonConverter.cs | 2 +- .../AgentSessionStateBagValue.cs | 2 +- .../AgentSessionStateBagValueJsonConverter.cs | 2 +- .../src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs | 1 - .../AgentSessionStateBagTests.cs | 2 +- .../ChatHistoryProviderMessageFilterTests.cs | 1 - .../ChatHistoryProviderTests.cs | 1 - 8 files changed, 5 insertions(+), 8 deletions(-) diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs index dc73a93a89..89a825c1a5 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. using System; using System.Collections.Concurrent; diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagJsonConverter.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagJsonConverter.cs index 999f3c627b..bfb6904320 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagJsonConverter.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagJsonConverter.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. using System; using System.Text.Json; diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs index e3295f7089..fd1a455ce8 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. using System; using System.Text.Json; diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValueJsonConverter.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValueJsonConverter.cs index 0020526d6c..27c9dc08a8 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValueJsonConverter.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValueJsonConverter.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. using System; using System.Text.Json; diff --git a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs index 8f7bd02648..afafb68f71 100644 --- a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs @@ -4,7 +4,6 @@ using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs index 33993bd42c..68c6d24dd4 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. using System; using System.Text.Json; diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs index f91c329cd7..a42c515f1d 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs index 38312d3e80..c2ca047070 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; From 88849e888f6dd3faf036355ebdbcc4125bcc8ca7 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Fri, 6 Feb 2026 22:35:54 +0000 Subject: [PATCH 15/28] Fix unit tests --- .../CosmosChatHistoryProviderTests.cs | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs index a7e3f0411a..0dae2454fd 100644 --- a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs @@ -336,10 +336,11 @@ public async Task InvokingAsync_WithConversationIsolation_ShouldOnlyReturnMessag var conversation1 = Guid.NewGuid().ToString(); var conversation2 = Guid.NewGuid().ToString(); + // Use different stateKey values so the providers don't overwrite each other's state in the shared session using var store1 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, - _ => new CosmosChatHistoryProvider.State(conversation1)); + _ => new CosmosChatHistoryProvider.State(conversation1), stateKey: "conv1"); using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, - _ => new CosmosChatHistoryProvider.State(conversation2)); + _ => new CosmosChatHistoryProvider.State(conversation2), stateKey: "conv2"); var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Message for conversation 1")], []); var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Message for conversation 2")], []); @@ -625,10 +626,11 @@ public async Task InvokingAsync_WithHierarchicalPartitionIsolation_ShouldIsolate const string SessionId = "session-isolation"; // Different userIds create different hierarchical partitions, providing proper isolation + // Use different stateKey values so the providers don't overwrite each other's state in the shared session using var store1 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, - _ => new CosmosChatHistoryProvider.State(SessionId, TenantId, UserId1)); + _ => new CosmosChatHistoryProvider.State(SessionId, TenantId, UserId1), stateKey: "user1"); using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, - _ => new CosmosChatHistoryProvider.State(SessionId, TenantId, UserId2)); + _ => new CosmosChatHistoryProvider.State(SessionId, TenantId, UserId2), stateKey: "user2"); // Add messages to both stores var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Message from user 1")], []); @@ -702,10 +704,11 @@ public async Task HierarchicalAndSimplePartitioning_ShouldCoexistAsync() var session = CreateMockSession(); // Create simple provider using simple partitioning container and hierarchical provider using hierarchical container + // Use different stateKey values so the providers don't overwrite each other's state in the shared session using var simpleProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, - _ => new CosmosChatHistoryProvider.State(SessionId)); + _ => new CosmosChatHistoryProvider.State(SessionId), stateKey: "simple"); using var hierarchicalProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, - _ => new CosmosChatHistoryProvider.State(SessionId, "tenant-coexist", "user-coexist")); + _ => new CosmosChatHistoryProvider.State(SessionId, "tenant-coexist", "user-coexist"), stateKey: "hierarchical"); // Add messages to both var simpleContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Simple partitioning message")], []); From 03fa29bde4f72b8dabd1a9b97f8cc086725f5fa5 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Mon, 9 Feb 2026 11:32:34 +0000 Subject: [PATCH 16/28] Remove InMemoryAgentSession since it is not required anymore. --- .../Program.cs | 30 ++-- .../AgentAbstractionsJsonUtilities.cs | 1 - .../AgentSession.cs | 5 + .../InMemoryAgentSession.cs | 127 --------------- .../InMemoryAgentSessionTests.cs | 153 ------------------ .../TestJsonSerializerContext.cs | 1 - .../BasicStreamingTests.cs | 33 ++-- .../ForwardedPropertiesTests.cs | 21 ++- .../SharedStateTests.cs | 21 ++- ...AGUIEndpointRouteBuilderExtensionsTests.cs | 29 ++-- .../AgentWorkflowBuilderTests.cs | 2 +- .../InProcessExecutionTests.cs | 2 +- .../RoleCheckAgent.cs | 2 +- .../Sample/06_GroupChat_Workflow.cs | 2 +- .../TestEchoAgent.cs | 27 ++-- .../TestReplayAgent.cs | 2 +- .../TestRequestAgent.cs | 18 +-- .../WorkflowHostSmokeTests.cs | 8 +- 18 files changed, 97 insertions(+), 387 deletions(-) delete mode 100644 dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryAgentSession.cs delete mode 100644 dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryAgentSessionTests.cs diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs index cc0c15eda5..fbed86bacb 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs @@ -6,6 +6,7 @@ using System.Runtime.CompilerServices; using System.Text.Json; +using System.Text.Json.Serialization; using Microsoft.Agents.AI; using Microsoft.Extensions.AI; using SampleApp; @@ -28,6 +29,8 @@ internal sealed class UpperCaseParrotAgent : AIAgent { public override string? Name => "UpperCaseParrotAgent"; + public ChatHistoryProvider ChatHistoryProvider = new InMemoryChatHistoryProvider(); + protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => new(new CustomAgentSession()); @@ -38,11 +41,11 @@ protected override JsonElement SerializeSessionCore(AgentSession session, JsonSe throw new ArgumentException($"The provided session is not of type {nameof(CustomAgentSession)}.", nameof(session)); } - return typedSession.Serialize(jsonSerializerOptions); + return JsonSerializer.SerializeToElement(typedSession, jsonSerializerOptions); } protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) - => new(new CustomAgentSession(serializedState, jsonSerializerOptions)); + => new(serializedState.Deserialize(jsonSerializerOptions)!); protected override async Task RunCoreAsync(IEnumerable messages, AgentSession? session = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { @@ -56,7 +59,7 @@ protected override async Task RunCoreAsync(IEnumerable responseMessages = CloneAndToUpperCase(messages, this.Name).ToList(); @@ -66,7 +69,7 @@ protected override async Task RunCoreAsync(IEnumerable RunCoreStreamingA // Get existing messages from the store var invokingContext = new ChatHistoryProvider.InvokingContext(this, session, messages); - var storeMessages = await typedSession.ChatHistoryProvider.InvokingAsync(invokingContext, cancellationToken); + var storeMessages = await this.ChatHistoryProvider.InvokingAsync(invokingContext, cancellationToken); // Clone the input messages and turn them into response messages with upper case text. List responseMessages = CloneAndToUpperCase(messages, this.Name).ToList(); @@ -98,7 +101,7 @@ protected override async IAsyncEnumerable RunCoreStreamingA { ResponseMessages = responseMessages }; - await typedSession.ChatHistoryProvider.InvokedAsync(invokedContext, cancellationToken); + await this.ChatHistoryProvider.InvokedAsync(invokedContext, cancellationToken); foreach (var message in responseMessages) { @@ -140,15 +143,16 @@ private static IEnumerable CloneAndToUpperCase(IEnumerable /// A session type for our custom agent that only supports in memory storage of messages. /// - internal sealed class CustomAgentSession : InMemoryAgentSession + internal sealed class CustomAgentSession : AgentSession { - internal CustomAgentSession() { } - - internal CustomAgentSession(JsonElement serializedSessionState, JsonSerializerOptions? jsonSerializerOptions = null) - : base(serializedSessionState, jsonSerializerOptions) { } + internal CustomAgentSession() + { + } - internal new JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => base.Serialize(jsonSerializerOptions); + [JsonConstructor] + internal CustomAgentSession(AgentSessionStateBag stateBag) : base(stateBag) + { + } } } } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs index df5375d0b6..af20162e89 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs @@ -82,7 +82,6 @@ private static JsonSerializerOptions CreateDefaultOptions() [JsonSerializable(typeof(AgentResponseUpdate))] [JsonSerializable(typeof(AgentResponseUpdate[]))] [JsonSerializable(typeof(ServiceIdAgentSession.ServiceIdAgentSessionState))] - [JsonSerializable(typeof(InMemoryAgentSession.InMemoryAgentSessionState))] [JsonSerializable(typeof(InMemoryChatHistoryProvider.State))] [JsonSerializable(typeof(AgentSessionStateBag))] [JsonSerializable(typeof(ConcurrentDictionary))] diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs index ad10a19d81..144c09a541 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Diagnostics; using System.Text.Json; using System.Text.Json.Serialization; using Microsoft.Shared.Diagnostics; @@ -45,6 +46,7 @@ namespace Microsoft.Agents.AI; /// /// /// +[DebuggerDisplay("{DebuggerDisplay,nq}")] public abstract class AgentSession { /// @@ -97,4 +99,7 @@ protected AgentSession(AgentSessionStateBag stateBag) /// public TService? GetService(object? serviceKey = null) => this.GetService(typeof(TService), serviceKey) is TService service ? service : default; + + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private string DebuggerDisplay => $"StateBag Count = {this.StateBag.Count}"; } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryAgentSession.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryAgentSession.cs deleted file mode 100644 index 5159431d9c..0000000000 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryAgentSession.cs +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Text.Json; -using Microsoft.Extensions.AI; -using Microsoft.Shared.Diagnostics; - -namespace Microsoft.Agents.AI; - -/// -/// Provides an abstract base class for an that maintain all chat history in local memory. -/// -/// -/// -/// is designed for scenarios where chat history should be stored locally -/// rather than in external services or databases. This approach provides high performance and simplicity while -/// maintaining full control over the conversation data. -/// -/// -/// In-memory threads do not persist conversation data across application restarts -/// unless explicitly serialized and restored. -/// -/// -[DebuggerDisplay("{DebuggerDisplay,nq}")] -public abstract class InMemoryAgentSession : AgentSession -{ - /// - /// Initializes a new instance of the class. - /// - /// - /// An optional instance to use for storing chat messages. - /// If , a new empty will be created. - /// - /// - /// This constructor allows sharing of between sessions or providing pre-configured - /// with specific reduction or processing logic. - /// - protected InMemoryAgentSession(InMemoryChatHistoryProvider? chatHistoryProvider = null) - { - this.ChatHistoryProvider = chatHistoryProvider ?? new InMemoryChatHistoryProvider(); - } - - /// - /// Initializes a new instance of the class. - /// - /// The initial messages to populate the conversation history. - /// is . - /// - /// This constructor is useful for initializing sessions with existing conversation history or - /// for migrating conversations from other storage systems. - /// - protected InMemoryAgentSession(IEnumerable messages) - { - this.ChatHistoryProvider = new InMemoryChatHistoryProvider(); - this.ChatHistoryProvider.GetMessages(this).AddRange(Throw.IfNull(messages)); - } - - /// - /// Initializes a new instance of the class from previously serialized state. - /// - /// A representing the serialized state of the session. - /// Optional settings for customizing the JSON deserialization process. - /// - /// Optional factory function to create the . - /// If not provided, a default factory will be used that creates a basic . - /// - /// The is not a JSON object. - /// The is invalid or cannot be deserialized to the expected type. - /// - /// This constructor enables restoration of in-memory threads from previously saved state, allowing - /// conversations to be resumed across application restarts or migrated between different instances. - /// - protected InMemoryAgentSession( - JsonElement serializedState, - JsonSerializerOptions? jsonSerializerOptions = null, - Func? chatHistoryProviderFactory = null) - { - if (serializedState.ValueKind != JsonValueKind.Object) - { - throw new ArgumentException("The serialized session state must be a JSON object.", nameof(serializedState)); - } - - var state = serializedState.Deserialize( - AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(InMemoryAgentSessionState))) as InMemoryAgentSessionState; - - this.ChatHistoryProvider = chatHistoryProviderFactory?.Invoke() ?? new InMemoryChatHistoryProvider(); - - if (state?.StateBag is { ValueKind: JsonValueKind.Object } stateBagElement) - { - this.StateBag = AgentSessionStateBag.Deserialize(stateBagElement); - } - } - - /// - /// Gets or sets the used by this thread. - /// - public InMemoryChatHistoryProvider ChatHistoryProvider { get; } - - /// - /// Serializes the current object's state to a using the specified serialization options. - /// - /// The JSON serialization options to use. - /// A representation of the object's state. - protected internal virtual JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - var state = new InMemoryAgentSessionState - { - StateBag = this.StateBag.Serialize(), - }; - - return JsonSerializer.SerializeToElement(state, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(InMemoryAgentSessionState))); - } - - /// - public override object? GetService(Type serviceType, object? serviceKey = null) => - base.GetService(serviceType, serviceKey) ?? this.ChatHistoryProvider?.GetService(serviceType, serviceKey); - - [DebuggerBrowsable(DebuggerBrowsableState.Never)] - private string DebuggerDisplay => $"Count = {this.ChatHistoryProvider.GetMessages(this).Count}"; - - internal sealed class InMemoryAgentSessionState - { - public JsonElement? StateBag { get; set; } - } -} diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryAgentSessionTests.cs deleted file mode 100644 index 5511181276..0000000000 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryAgentSessionTests.cs +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text.Json; -using Microsoft.Extensions.AI; - -namespace Microsoft.Agents.AI.Abstractions.UnitTests; - -/// -/// Contains tests for . -/// -public class InMemoryAgentSessionTests -{ - #region Constructor and Property Tests - - [Fact] - public void Constructor_SetsDefaultChatHistoryProvider() - { - // Arrange & Act - var session = new TestInMemoryAgentSession(); - - // Assert - Assert.NotNull(session.ChatHistoryProvider); - Assert.Empty(session.ChatHistoryProvider.GetMessages(session)); - } - - [Fact] - public void Constructor_WithChatHistoryProvider_SetsProperty() - { - // Arrange - var provider = new InMemoryChatHistoryProvider(); - var session = new TestInMemoryAgentSession(provider); - provider.SetMessages(session, [new(ChatRole.User, "Hello")]); - - // Act & Assert - Assert.Same(provider, session.ChatHistoryProvider); - var messages = session.ChatHistoryProvider.GetMessages(session); - Assert.Single(messages); - Assert.Equal("Hello", messages[0].Text); - } - - [Fact] - public void Constructor_WithMessages_SetsProperty() - { - // Arrange - var messages = new List { new(ChatRole.User, "Hi") }; - - // Act - var session = new TestInMemoryAgentSession(messages); - - // Assert - Assert.NotNull(session.ChatHistoryProvider); - Assert.Single(session.ChatHistoryProvider.GetMessages(session)); - Assert.Equal("Hi", session.ChatHistoryProvider.GetMessages(session)[0].Text); - } - - [Fact] - public void Constructor_WithSerializedState_SetsProperty() - { - // Arrange - create a session with a StateBag containing chat history - var originalSession = new TestInMemoryAgentSession([new(ChatRole.User, "TestMsg")]); - var json = originalSession.Serialize(); - - // Act - var session = new TestInMemoryAgentSession(json); - - // Assert - Assert.NotNull(session.ChatHistoryProvider); - var messages = session.ChatHistoryProvider.GetMessages(session); - Assert.Single(messages); - Assert.Equal("TestMsg", messages[0].Text); - } - - [Fact] - public void Constructor_WithInvalidJson_ThrowsArgumentException() - { - // Arrange - var invalidJson = JsonSerializer.SerializeToElement(42, TestJsonSerializerContext.Default.Int32); - - // Act & Assert - Assert.Throws(() => new TestInMemoryAgentSession(invalidJson)); - } - - #endregion - - #region SerializeAsync Tests - - [Fact] - public void Serialize_ReturnsCorrectJson_WhenMessagesExist() - { - // Arrange - var session = new TestInMemoryAgentSession([new(ChatRole.User, "TestContent")]); - - // Act - var json = session.Serialize(); - - // Assert - Assert.Equal(JsonValueKind.Object, json.ValueKind); - Assert.True(json.TryGetProperty("stateBag", out var stateBagProperty)); - Assert.Equal(JsonValueKind.Object, stateBagProperty.ValueKind); - Assert.True(stateBagProperty.TryGetProperty("InMemoryChatHistoryProvider.State", out var providerStateProperty)); - Assert.Equal(JsonValueKind.Object, providerStateProperty.ValueKind); - Assert.True(providerStateProperty.TryGetProperty("messages", out var messagesProperty)); - Assert.Equal(JsonValueKind.Array, messagesProperty.ValueKind); - var messagesList = messagesProperty.EnumerateArray().ToList(); - Assert.Single(messagesList); - } - - [Fact] - public void Serialize_ReturnsEmptyStateBag_WhenNoMessages() - { - // Arrange - var session = new TestInMemoryAgentSession(); - - // Act - var json = session.Serialize(); - - // Assert - Assert.Equal(JsonValueKind.Object, json.ValueKind); - Assert.True(json.TryGetProperty("stateBag", out var providerStateProperty)); - Assert.Equal(JsonValueKind.Object, providerStateProperty.ValueKind); - Assert.False(providerStateProperty.EnumerateObject().Any()); - } - - #endregion - - #region GetService Tests - - [Fact] - public void GetService_RequestingChatHistoryProvider_ReturnsChatHistoryProvider() - { - // Arrange - var session = new TestInMemoryAgentSession(); - - // Act & Assert - Assert.NotNull(session.GetService(typeof(ChatHistoryProvider))); - Assert.Same(session.ChatHistoryProvider, session.GetService(typeof(ChatHistoryProvider))); - Assert.Same(session.ChatHistoryProvider, session.GetService(typeof(InMemoryChatHistoryProvider))); - } - - #endregion - - // Sealed test subclass to expose protected members for testing - private sealed class TestInMemoryAgentSession : InMemoryAgentSession - { - public TestInMemoryAgentSession() { } - public TestInMemoryAgentSession(InMemoryChatHistoryProvider? provider) : base(provider) { } - public TestInMemoryAgentSession(IEnumerable messages) : base(messages) { } - public TestInMemoryAgentSession(JsonElement serializedSessionState) : base(serializedSessionState) { } - } -} diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/TestJsonSerializerContext.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/TestJsonSerializerContext.cs index 05d13c2e95..37ec8e0a03 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/TestJsonSerializerContext.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/TestJsonSerializerContext.cs @@ -19,7 +19,6 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; [JsonSerializable(typeof(Dictionary))] [JsonSerializable(typeof(string[]))] [JsonSerializable(typeof(int))] -[JsonSerializable(typeof(InMemoryAgentSession.InMemoryAgentSessionState))] [JsonSerializable(typeof(ServiceIdAgentSession.ServiceIdAgentSessionState))] [JsonSerializable(typeof(ServiceIdAgentSessionTests.EmptyObject))] [JsonSerializable(typeof(InMemoryChatHistoryProviderTests.TestAIContent))] diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs index a6e7aab212..f990650e7c 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs @@ -7,6 +7,7 @@ using System.Net.Http; using System.Runtime.CompilerServices; using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using FluentAssertions; @@ -281,10 +282,10 @@ internal sealed class FakeChatClientAgent : AIAgent public override string? Description => "A fake agent for testing"; protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => - new(new FakeInMemoryAgentSession()); + new(new FakeAgentSession()); protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => - new(new FakeInMemoryAgentSession(serializedState, jsonSerializerOptions)); + new(serializedState.Deserialize(jsonSerializerOptions)!); protected override JsonElement SerializeSessionCore(AgentSession session, JsonSerializerOptions? jsonSerializerOptions = null) => throw new NotImplementedException(); @@ -326,15 +327,14 @@ protected override async IAsyncEnumerable RunCoreStreamingA } } - private sealed class FakeInMemoryAgentSession : InMemoryAgentSession + private sealed class FakeAgentSession : AgentSession { - public FakeInMemoryAgentSession() - : base() + public FakeAgentSession() { } - public FakeInMemoryAgentSession(JsonElement serializedSession, JsonSerializerOptions? jsonSerializerOptions = null) - : base(serializedSession, jsonSerializerOptions) + [JsonConstructor] + public FakeAgentSession(AgentSessionStateBag stateBag) : base(stateBag) { } } @@ -348,19 +348,19 @@ internal sealed class FakeMultiMessageAgent : AIAgent public override string? Description => "A fake agent that sends multiple messages for testing"; protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => - new(new FakeInMemoryAgentSession()); + new(new FakeAgentSession()); protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => - new(new FakeInMemoryAgentSession(serializedState, jsonSerializerOptions)); + new(serializedState.Deserialize(jsonSerializerOptions)!); protected override JsonElement SerializeSessionCore(AgentSession session, JsonSerializerOptions? jsonSerializerOptions = null) { - if (session is not FakeInMemoryAgentSession fakeSession) + if (session is not FakeAgentSession fakeSession) { throw new InvalidOperationException("The provided session is not compatible with the agent. Only sessions created by the agent can be serialized."); } - return fakeSession.Serialize(jsonSerializerOptions); + return JsonSerializer.SerializeToElement(fakeSession, jsonSerializerOptions); } protected override async Task RunCoreAsync( @@ -427,19 +427,16 @@ protected override async IAsyncEnumerable RunCoreStreamingA } } - private sealed class FakeInMemoryAgentSession : InMemoryAgentSession + private sealed class FakeAgentSession : AgentSession { - public FakeInMemoryAgentSession() - : base() + public FakeAgentSession() { } - public FakeInMemoryAgentSession(JsonElement serializedSession, JsonSerializerOptions? jsonSerializerOptions = null) - : base(serializedSession, jsonSerializerOptions) + [JsonConstructor] + public FakeAgentSession(AgentSessionStateBag stateBag) : base(stateBag) { } - internal new JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => base.Serialize(jsonSerializerOptions); } public override object? GetService(Type serviceType, object? serviceKey = null) => null; diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/ForwardedPropertiesTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/ForwardedPropertiesTests.cs index afd3db44b3..353641502c 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/ForwardedPropertiesTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/ForwardedPropertiesTests.cs @@ -9,6 +9,7 @@ using System.Runtime.CompilerServices; using System.Text; using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using FluentAssertions; @@ -335,35 +336,31 @@ protected override async IAsyncEnumerable RunCoreStreamingA } protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => - new(new FakeInMemoryAgentSession()); + new(new FakeAgentSession()); protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => - new(new FakeInMemoryAgentSession(serializedState, jsonSerializerOptions)); + new(serializedState.Deserialize(jsonSerializerOptions)!); protected override JsonElement SerializeSessionCore(AgentSession session, JsonSerializerOptions? jsonSerializerOptions = null) { - if (session is not FakeInMemoryAgentSession fakeSession) + if (session is not FakeAgentSession fakeSession) { throw new InvalidOperationException("The provided session is not compatible with the agent. Only sessions created by the agent can be serialized."); } - return fakeSession.Serialize(jsonSerializerOptions); + return JsonSerializer.SerializeToElement(fakeSession, jsonSerializerOptions); } - private sealed class FakeInMemoryAgentSession : InMemoryAgentSession + private sealed class FakeAgentSession : AgentSession { - public FakeInMemoryAgentSession() - : base() + public FakeAgentSession() { } - public FakeInMemoryAgentSession(JsonElement serializedSession, JsonSerializerOptions? jsonSerializerOptions = null) - : base(serializedSession, jsonSerializerOptions) + [JsonConstructor] + public FakeAgentSession(AgentSessionStateBag stateBag) : base(stateBag) { } - - internal new JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => base.Serialize(jsonSerializerOptions); } public override object? GetService(Type serviceType, object? serviceKey = null) => null; diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/SharedStateTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/SharedStateTests.cs index b78ddd0e11..67805edd86 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/SharedStateTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/SharedStateTests.cs @@ -7,6 +7,7 @@ using System.Net.Http; using System.Runtime.CompilerServices; using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using FluentAssertions; @@ -418,35 +419,31 @@ stateObj is JsonElement state && } protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => - new(new FakeInMemoryAgentSession()); + new(new FakeAgentSession()); protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => - new(new FakeInMemoryAgentSession(serializedState, jsonSerializerOptions)); + new(serializedState.Deserialize(jsonSerializerOptions)!); protected override JsonElement SerializeSessionCore(AgentSession session, JsonSerializerOptions? jsonSerializerOptions = null) { - if (session is not FakeInMemoryAgentSession fakeSession) + if (session is not FakeAgentSession fakeSession) { throw new InvalidOperationException("The provided session is not compatible with the agent. Only sessions created by the agent can be serialized."); } - return fakeSession.Serialize(jsonSerializerOptions); + return JsonSerializer.SerializeToElement(fakeSession, jsonSerializerOptions); } - private sealed class FakeInMemoryAgentSession : InMemoryAgentSession + private sealed class FakeAgentSession : AgentSession { - public FakeInMemoryAgentSession() - : base() + public FakeAgentSession() { } - public FakeInMemoryAgentSession(JsonElement serializedSession, JsonSerializerOptions? jsonSerializerOptions = null) - : base(serializedSession, jsonSerializerOptions) + [JsonConstructor] + public FakeAgentSession(AgentSessionStateBag stateBag) : base(stateBag) { } - - internal new JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => base.Serialize(jsonSerializerOptions); } public override object? GetService(Type serviceType, object? serviceKey = null) => null; diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs index fecb9d421b..70070eb28e 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Text; using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.Shared; @@ -426,19 +427,19 @@ private sealed class MultiResponseAgent : AIAgent public override string? Description => "Agent that produces multiple text chunks"; protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => - new(new TestInMemoryAgentSession()); + new(new TestAgentSession()); protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => - new(new TestInMemoryAgentSession(serializedState, jsonSerializerOptions)); + new(serializedState.Deserialize(jsonSerializerOptions)!); protected override JsonElement SerializeSessionCore(AgentSession session, JsonSerializerOptions? jsonSerializerOptions = null) { - if (session is not TestInMemoryAgentSession testSession) + if (session is not TestAgentSession testSession) { throw new InvalidOperationException("The provided session is not compatible with the agent. Only sessions created by the agent can be serialized."); } - return testSession.Serialize(jsonSerializerOptions); + return JsonSerializer.SerializeToElement(testSession, jsonSerializerOptions); } protected override Task RunCoreAsync(IEnumerable messages, AgentSession? session = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) @@ -506,20 +507,16 @@ private RequestDelegate CreateRequestDelegate( }; } - private sealed class TestInMemoryAgentSession : InMemoryAgentSession + private sealed class TestAgentSession : AgentSession { - public TestInMemoryAgentSession() - : base() + public TestAgentSession() { } - public TestInMemoryAgentSession(JsonElement serializedSessionState, JsonSerializerOptions? jsonSerializerOptions = null) - : base(serializedSessionState, jsonSerializerOptions, null) + [JsonConstructor] + public TestAgentSession(AgentSessionStateBag stateBag) : base(stateBag) { } - - internal new JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => base.Serialize(jsonSerializerOptions); } private sealed class TestAgent : AIAgent @@ -529,19 +526,19 @@ private sealed class TestAgent : AIAgent public override string? Description => "Test agent"; protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => - new(new TestInMemoryAgentSession()); + new(new TestAgentSession()); protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => - new(new TestInMemoryAgentSession(serializedState, jsonSerializerOptions)); + new(serializedState.Deserialize(jsonSerializerOptions)!); protected override JsonElement SerializeSessionCore(AgentSession session, JsonSerializerOptions? jsonSerializerOptions = null) { - if (session is not TestInMemoryAgentSession testSession) + if (session is not TestAgentSession testSession) { throw new InvalidOperationException("The provided session is not compatible with the agent. Only sessions created by the agent can be serialized."); } - return testSession.Serialize(jsonSerializerOptions); + return JsonSerializer.SerializeToElement(testSession, jsonSerializerOptions); } protected override Task RunCoreAsync(IEnumerable messages, AgentSession? session = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs index a21eda21c4..c0c1c92a3c 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs @@ -161,7 +161,7 @@ protected override async IAsyncEnumerable RunCoreStreamingA } } - private sealed class DoubleEchoAgentSession() : InMemoryAgentSession(); + private sealed class DoubleEchoAgentSession() : AgentSession(); [Fact] public async Task BuildConcurrent_AgentsRunInParallelAsync() diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessExecutionTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessExecutionTests.cs index f12ffc6988..922f7a300f 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessExecutionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessExecutionTests.cs @@ -195,5 +195,5 @@ protected override async IAsyncEnumerable RunCoreStreamingA /// /// Simple session implementation for SimpleTestAgent. /// - private sealed class SimpleTestAgentSession : InMemoryAgentSession; + private sealed class SimpleTestAgentSession : AgentSession; } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RoleCheckAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RoleCheckAgent.cs index 48fc432eeb..b1b149de23 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RoleCheckAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RoleCheckAgent.cs @@ -46,5 +46,5 @@ protected override async IAsyncEnumerable RunCoreStreamingA }; } - private sealed class RoleCheckAgentSession : InMemoryAgentSession; + private sealed class RoleCheckAgentSession : AgentSession; } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs index e4c905f814..1d9bc40516 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs @@ -90,4 +90,4 @@ protected override async IAsyncEnumerable RunCoreStreamingA } } -internal sealed class HelloAgentSession() : InMemoryAgentSession(); +internal sealed class HelloAgentSession() : AgentSession(); diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs index df68dbbe28..4eda20b956 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs @@ -5,6 +5,7 @@ using System.Linq; using System.Runtime.CompilerServices; using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -16,6 +17,8 @@ internal class TestEchoAgent(string? id = null, string? name = null, string? pre protected override string? IdCore => id; public override string? Name => name ?? base.Name; + public InMemoryChatHistoryProvider ChatHistoryProvider { get; } = new(); + protected override async ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) { return serializedState.Deserialize(jsonSerializerOptions) ?? await this.CreateSessionAsync(cancellationToken); @@ -28,15 +31,15 @@ protected override JsonElement SerializeSessionCore(AgentSession session, JsonSe throw new InvalidOperationException("The provided session is not compatible with the agent. Only sessions created by the agent can be serialized."); } - return typedSession.Serialize(jsonSerializerOptions); + return JsonSerializer.SerializeToElement(typedSession, jsonSerializerOptions); } protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => new(new EchoAgentSession()); - private static ChatMessage UpdateSession(ChatMessage message, InMemoryAgentSession? session = null) + private ChatMessage UpdateSession(ChatMessage message, AgentSession? session = null) { - session?.ChatHistoryProvider.GetMessages(session).Add(message); + this.ChatHistoryProvider.GetMessages(session).Add(message); return message; } @@ -45,7 +48,7 @@ private IEnumerable EchoMessages(IEnumerable messages, { foreach (ChatMessage message in messages) { - UpdateSession(message, session as InMemoryAgentSession); + this.UpdateSession(message, session); } IEnumerable echoMessages @@ -53,14 +56,14 @@ IEnumerable echoMessages where message.Role == ChatRole.User && !string.IsNullOrEmpty(message.Text) select - UpdateSession(new ChatMessage(ChatRole.Assistant, $"{prefix}{message.Text}") + this.UpdateSession(new ChatMessage(ChatRole.Assistant, $"{prefix}{message.Text}") { AuthorName = this.Name ?? this.Id, CreatedAt = DateTimeOffset.Now, MessageId = Guid.NewGuid().ToString("N") - }, session as InMemoryAgentSession); + }, session); - return echoMessages.Concat(this.GetEpilogueMessages(options).Select(m => UpdateSession(m, session as InMemoryAgentSession))); + return echoMessages.Concat(this.GetEpilogueMessages(options).Select(m => this.UpdateSession(m, session))); } protected virtual IEnumerable GetEpilogueMessages(AgentRunOptions? options = null) @@ -99,11 +102,11 @@ protected override async IAsyncEnumerable RunCoreStreamingA } } - private sealed class EchoAgentSession : InMemoryAgentSession + private sealed class EchoAgentSession : AgentSession { - internal new JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - return base.Serialize(jsonSerializerOptions); - } + internal EchoAgentSession() { } + + [JsonConstructor] + internal EchoAgentSession(AgentSessionStateBag stateBag) : base(stateBag) { } } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestReplayAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestReplayAgent.cs index 032ba9001c..1721cdd4c4 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestReplayAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestReplayAgent.cs @@ -104,5 +104,5 @@ protected override async IAsyncEnumerable RunCoreStreamingA return candidateMessages; } - private sealed class ReplayAgentSession() : InMemoryAgentSession(); + private sealed class ReplayAgentSession() : AgentSession(); } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRequestAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRequestAgent.cs index 63cd8dd6f0..5533d97091 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRequestAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRequestAgent.cs @@ -330,7 +330,7 @@ static TRequest AssertAndExtractRequestContent(ExternalRequest request } } - private sealed class TestRequestAgentSession : InMemoryAgentSession + private sealed class TestRequestAgentSession : AgentSession where TRequest : AIContent where TResponse : AIContent { @@ -343,19 +343,13 @@ public TestRequestAgentSession() public HashSet ServicedRequests { get; } = new(); public HashSet PairedRequests { get; } = new(); - private static JsonElement DeserializeAndExtractState(JsonElement serializedState, - out TestRequestAgentSessionState state, - JsonSerializerOptions? jsonSerializerOptions = null) + public TestRequestAgentSession(JsonElement element, JsonSerializerOptions? jsonSerializerOptions = null) { - state = JsonSerializer.Deserialize(serializedState, jsonSerializerOptions) + var state = JsonSerializer.Deserialize(element, jsonSerializerOptions) ?? throw new ArgumentException("Unable to deserialize session state."); - return state.SessionState; - } + this.StateBag = AgentSessionStateBag.Deserialize(state.SessionState); - public TestRequestAgentSession(JsonElement element, JsonSerializerOptions? jsonSerializerOptions = null) - : base(DeserializeAndExtractState(element, out TestRequestAgentSessionState state, jsonSerializerOptions)) - { this.UnservicedRequests = state.UnservicedRequests.ToDictionary( keySelector: item => item.Key, elementSelector: item => item.Value.As()!); @@ -364,9 +358,9 @@ public TestRequestAgentSession(JsonElement element, JsonSerializerOptions? jsonS this.PairedRequests = state.PairedRequests; } - protected override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) + internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { - JsonElement sessionState = base.Serialize(jsonSerializerOptions); + JsonElement sessionState = this.StateBag.Serialize(); Dictionary portableUnservicedRequests = this.UnservicedRequests.ToDictionary( diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowHostSmokeTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowHostSmokeTests.cs index ac6485131d..e2f01d675f 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowHostSmokeTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowHostSmokeTests.cs @@ -32,18 +32,16 @@ public class WorkflowHostSmokeTests { private sealed class AlwaysFailsAIAgent(bool failByThrowing) : AIAgent { - private sealed class Session : InMemoryAgentSession + private sealed class Session : AgentSession { public Session() { } - public Session(JsonElement serializedSession, JsonSerializerOptions? jsonSerializerOptions = null) - : base(serializedSession, jsonSerializerOptions) - { } + public Session(AgentSessionStateBag stateBag) : base(stateBag) { } } protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) { - return new(new Session(serializedState, jsonSerializerOptions)); + return new(serializedState.Deserialize(jsonSerializerOptions)!); } protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) From d419f735869313398a471f26c1c4fc32c1b79bd7 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Mon, 9 Feb 2026 11:42:32 +0000 Subject: [PATCH 17/28] Address PR comments --- .../Microsoft.Agents.AI.Abstractions/AIContextProvider.cs | 7 ++++++- .../AgentSessionStateBag.cs | 2 ++ .../ChatHistoryProvider.cs | 5 +++++ .../src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs | 2 +- 4 files changed, 14 insertions(+), 2 deletions(-) diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs index bc864393eb..39ddc4729c 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs @@ -59,13 +59,18 @@ public abstract class AIContextProvider /// /// Implementers can use the request and response messages in the provided to: /// - /// Update internal state based on conversation outcomes + /// Update state based on conversation outcomes /// Extract and store memories or preferences from user messages /// Log or audit conversation details /// Perform cleanup or finalization tasks /// /// /// + /// The is passed a reference to the via and + /// allowing it to store state in the . Since an is used with many different sessions, it should + /// not store any session-specific information within its own instance fields. Instead, any session-specific state should be stored in the associated . + /// + /// /// This method is called regardless of whether the invocation succeeded or failed. /// To check if the invocation was successful, inspect the property. /// diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs index 89a825c1a5..0d4e189a1c 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs @@ -82,6 +82,7 @@ public bool TryGetValue(string key, out T? value, JsonSerializerOptions? json return false; } + stateValue.IsDeserialized = true; stateValue.DeserializedValue = result; stateValue.ValueType = typeof(T); stateValue.JsonSerializerOptions = jso; @@ -125,6 +126,7 @@ public bool TryGetValue(string key, out T? value, JsonSerializerOptions? json { throw new InvalidOperationException($"Failed to deserialize session state value to type {typeof(T).FullName}."); } + stateValue.IsDeserialized = true; stateValue.DeserializedValue = result; stateValue.ValueType = typeof(T); diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs index 7e36c5fc33..07bf7e5ed2 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs @@ -28,6 +28,11 @@ namespace Microsoft.Agents.AI; /// /// /// +/// The is passed a reference to the via and +/// allowing it to store state in the . Since a is used with many different sessions, it should +/// not store any session-specific information within its own instance fields. Instead, any session-specific state should be stored in the associated . +/// +/// /// A is only relevant for scenarios where the underlying AI service that the agent is using /// does not use in-service chat history storage. /// diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index 39538a398c..7fdc7bc84e 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -107,7 +107,7 @@ public ChatClientAgent(IChatClient chatClient, ChatClientAgentOptions? options, // Use the ChatHistoryProvider from options if provided. // If one was not provided, and we later find out that the underlying service does not manage chat history server-side, - // we will use the defalut InMemoryChatHistoryProvider at that time. + // we will use the default InMemoryChatHistoryProvider at that time. this.ChatHistoryProvider = options?.ChatHistoryProvider; this._logger = (loggerFactory ?? chatClient.GetService() ?? NullLoggerFactory.Instance).CreateLogger(); From 23ee6f75acda88159a673a85ea24fc1e39390653 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Mon, 9 Feb 2026 15:21:44 +0000 Subject: [PATCH 18/28] Convert sessions for A2AAgent, ChatClientAgent, CopilotStudioAgent and GithubCopilotAgent to use regular json serialization. --- .../src/Microsoft.Agents.AI.A2A/A2AAgent.cs | 2 +- .../A2AAgentSession.cs | 56 +++----- .../A2AJsonUtilities.cs | 2 +- .../AgentAbstractionsJsonUtilities.cs | 1 - .../ServiceIdAgentSession.cs | 109 -------------- .../CopilotStudioAgent.cs | 2 +- .../CopilotStudioAgentSession.cs | 40 ++++-- .../CopilotStudioJsonUtilities.cs | 48 +++++++ .../GitHubCopilotAgent.cs | 2 +- .../GitHubCopilotAgentSession.cs | 44 +++--- .../GitHubCopilotJsonUtilities.cs | 2 +- .../Microsoft.Agents.AI/AgentJsonUtilities.cs | 2 +- .../ChatClient/ChatClientAgent.cs | 2 +- .../ChatClient/ChatClientAgentSession.cs | 43 ++---- .../A2AAgentSessionTests.cs | 4 +- .../ServiceIdAgentSessionTests.cs | 136 ------------------ .../TestJsonSerializerContext.cs | 2 - .../ChatClient/ChatClientAgentSessionTests.cs | 5 +- 18 files changed, 152 insertions(+), 350 deletions(-) delete mode 100644 dotnet/src/Microsoft.Agents.AI.Abstractions/ServiceIdAgentSession.cs create mode 100644 dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioJsonUtilities.cs delete mode 100644 dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ServiceIdAgentSessionTests.cs diff --git a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs index 533b50c8fe..9d819c90d8 100644 --- a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs @@ -80,7 +80,7 @@ protected override JsonElement SerializeSessionCore(AgentSession session, JsonSe /// protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) - => new(new A2AAgentSession(serializedState, jsonSerializerOptions)); + => new(A2AAgentSession.Deserialize(serializedState, jsonSerializerOptions)); /// protected override async Task RunCoreAsync(IEnumerable messages, AgentSession? session = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) diff --git a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgentSession.cs b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgentSession.cs index 1d63a44a85..045abc736a 100644 --- a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgentSession.cs @@ -1,71 +1,61 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Diagnostics; using System.Text.Json; +using System.Text.Json.Serialization; namespace Microsoft.Agents.AI.A2A; /// /// Session for A2A based agents. /// +[DebuggerDisplay("{DebuggerDisplay,nq}")] public sealed class A2AAgentSession : AgentSession { internal A2AAgentSession() { } - internal A2AAgentSession(JsonElement serializedSessionState, JsonSerializerOptions? jsonSerializerOptions = null) + [JsonConstructor] + internal A2AAgentSession(string? contextId, string? taskId, AgentSessionStateBag? stateBag) : base(stateBag ?? new()) { - if (serializedSessionState.ValueKind != JsonValueKind.Object) - { - throw new ArgumentException("The serialized session state must be a JSON object.", nameof(serializedSessionState)); - } - - var state = serializedSessionState.Deserialize( - A2AJsonUtilities.DefaultOptions.GetTypeInfo(typeof(A2AAgentSessionState))) as A2AAgentSessionState; - - if (state?.ContextId is string contextId) - { - this.ContextId = contextId; - } - - if (state?.TaskId is string taskId) - { - this.TaskId = taskId; - } - - this.StateBag = AgentSessionStateBag.Deserialize(state?.StateBag ?? default); + this.ContextId = contextId; + this.TaskId = taskId; } /// /// Gets the ID for the current conversation with the A2A agent. /// + [JsonPropertyName("contextId")] public string? ContextId { get; internal set; } /// /// Gets the ID for the task the agent is currently working on. /// + [JsonPropertyName("taskId")] public string? TaskId { get; internal set; } /// internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { - var state = new A2AAgentSessionState - { - ContextId = this.ContextId, - TaskId = this.TaskId, - StateBag = this.StateBag.Serialize(), - }; - - return JsonSerializer.SerializeToElement(state, A2AJsonUtilities.DefaultOptions.GetTypeInfo(typeof(A2AAgentSessionState))); + var jso = jsonSerializerOptions ?? A2AJsonUtilities.DefaultOptions; + return JsonSerializer.SerializeToElement(this, jso.GetTypeInfo(typeof(A2AAgentSession))); } - internal sealed class A2AAgentSessionState + internal static A2AAgentSession Deserialize(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null) { - public string? ContextId { get; set; } - - public string? TaskId { get; set; } + if (serializedState.ValueKind != JsonValueKind.Object) + { + throw new ArgumentException("The serialized session state must be a JSON object.", nameof(serializedState)); + } - public JsonElement? StateBag { get; set; } + var jso = jsonSerializerOptions ?? A2AJsonUtilities.DefaultOptions; + return serializedState.Deserialize(jso.GetTypeInfo(typeof(A2AAgentSession))) as A2AAgentSession + ?? new A2AAgentSession(); } + + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private string DebuggerDisplay => + $"ContextId = {this.ContextId}, TaskId = {this.TaskId}, StateBag Count = {this.StateBag.Count}"; } diff --git a/dotnet/src/Microsoft.Agents.AI.A2A/A2AJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI.A2A/A2AJsonUtilities.cs index 5f079d9573..3c25e350ae 100644 --- a/dotnet/src/Microsoft.Agents.AI.A2A/A2AJsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI.A2A/A2AJsonUtilities.cs @@ -74,7 +74,7 @@ private static JsonSerializerOptions CreateDefaultOptions() NumberHandling = JsonNumberHandling.AllowReadingFromString)] // A2A agent types - [JsonSerializable(typeof(A2AAgentSession.A2AAgentSessionState))] + [JsonSerializable(typeof(A2AAgentSession))] [ExcludeFromCodeCoverage] private sealed partial class JsonContext : JsonSerializerContext; } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs index af20162e89..f8c8aa9b98 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs @@ -81,7 +81,6 @@ private static JsonSerializerOptions CreateDefaultOptions() [JsonSerializable(typeof(AgentResponse[]))] [JsonSerializable(typeof(AgentResponseUpdate))] [JsonSerializable(typeof(AgentResponseUpdate[]))] - [JsonSerializable(typeof(ServiceIdAgentSession.ServiceIdAgentSessionState))] [JsonSerializable(typeof(InMemoryChatHistoryProvider.State))] [JsonSerializable(typeof(AgentSessionStateBag))] [JsonSerializable(typeof(ConcurrentDictionary))] diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ServiceIdAgentSession.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ServiceIdAgentSession.cs deleted file mode 100644 index 940c7fd211..0000000000 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ServiceIdAgentSession.cs +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Diagnostics; -using System.Text.Json; -using Microsoft.Shared.Diagnostics; - -namespace Microsoft.Agents.AI; - -/// -/// Provides a base class for agent sessions that store conversation state remotely in a service and maintain only an identifier reference locally. -/// -/// -/// This class is designed for scenarios where conversation state is managed by an external service (such as a cloud-based AI service) -/// rather than being stored locally. The session maintains only the service identifier needed to reference the remote conversation state. -/// -[DebuggerDisplay("ServiceSessionId = {ServiceSessionId}")] -public abstract class ServiceIdAgentSession : AgentSession -{ - /// - /// Initializes a new instance of the class without a service session identifier. - /// - /// - /// When using this constructor, the will be initially - /// and should be set by derived classes when the remote conversation is created. - /// - protected ServiceIdAgentSession() - { - } - - /// - /// Initializes a new instance of the class with the specified service session identifier. - /// - /// The unique identifier that references the conversation state stored in the remote service. - /// is . - /// is empty or contains only whitespace. - protected ServiceIdAgentSession(string serviceSessionId) - { - this.ServiceSessionId = Throw.IfNullOrEmpty(serviceSessionId); - } - - /// - /// Initializes a new instance of the class from previously serialized state. - /// - /// A representing the serialized state of the session. - /// Optional settings for customizing the JSON deserialization process. - /// The is not a JSON object. - /// The is invalid or cannot be deserialized to the expected type. - /// - /// This constructor enables restoration of a service-backed session from serialized state, typically used - /// when deserializing session information that was previously saved or transmitted across application boundaries. - /// - protected ServiceIdAgentSession( - JsonElement serializedState, - JsonSerializerOptions? jsonSerializerOptions = null) - { - if (serializedState.ValueKind != JsonValueKind.Object) - { - throw new ArgumentException("The serialized session state must be a JSON object.", nameof(serializedState)); - } - - var state = serializedState.Deserialize( - AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ServiceIdAgentSessionState))) as ServiceIdAgentSessionState; - - if (state?.ServiceSessionId is string serviceSessionId) - { - this.ServiceSessionId = serviceSessionId; - } - - this.StateBag = AgentSessionStateBag.Deserialize(state?.StateBag ?? default); - } - - /// - /// Gets or sets the unique identifier that references the conversation state stored in the remote service. - /// - /// - /// A string identifier that uniquely identifies the conversation within the remote service, - /// or if no remote conversation has been established yet. - /// - /// - /// This identifier is used by derived classes to reference the remote conversation state when making - /// API calls to the backing service. The exact format and meaning of this identifier depends on the - /// specific service implementation. - /// - protected string? ServiceSessionId { get; set; } - - /// - /// Serializes the current object's state to a using the specified serialization options. - /// - /// The JSON serialization options to use. - /// A representation of the object's state. - protected internal virtual JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - var state = new ServiceIdAgentSessionState - { - ServiceSessionId = this.ServiceSessionId, - StateBag = this.StateBag.Serialize(), - }; - - return JsonSerializer.SerializeToElement(state, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ServiceIdAgentSessionState))); - } - - internal sealed class ServiceIdAgentSessionState - { - public string? ServiceSessionId { get; set; } - - public JsonElement? StateBag { get; set; } - } -} diff --git a/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs index 48642139a9..1bcb84ffbd 100644 --- a/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs @@ -68,7 +68,7 @@ protected override JsonElement SerializeSessionCore(AgentSession session, JsonSe /// protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) - => new(new CopilotStudioAgentSession(serializedState, jsonSerializerOptions)); + => new(CopilotStudioAgentSession.Deserialize(serializedState, jsonSerializerOptions)); /// protected override async Task RunCoreAsync( diff --git a/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgentSession.cs b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgentSession.cs index ec3e21ca91..ba101df082 100644 --- a/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgentSession.cs @@ -1,36 +1,58 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.Diagnostics; using System.Text.Json; +using System.Text.Json.Serialization; namespace Microsoft.Agents.AI.CopilotStudio; /// /// Session for CopilotStudio based agents. /// -public sealed class CopilotStudioAgentSession : ServiceIdAgentSession +[DebuggerDisplay("{DebuggerDisplay,nq}")] +public sealed class CopilotStudioAgentSession : AgentSession { internal CopilotStudioAgentSession() { } - internal CopilotStudioAgentSession(JsonElement serializedSessionState, JsonSerializerOptions? jsonSerializerOptions = null) : base(serializedSessionState, jsonSerializerOptions) + [JsonConstructor] + internal CopilotStudioAgentSession(string? conversationId, AgentSessionStateBag? stateBag) : base(stateBag ?? new()) { + this.ConversationId = conversationId; } /// /// Gets the ID for the current conversation with the Copilot Studio agent. /// - public string? ConversationId - { - get { return this.ServiceSessionId; } - internal set { this.ServiceSessionId = value; } - } + [JsonPropertyName("serviceSessionId")] + public string? ConversationId { get; internal set; } /// /// Serializes the current object's state to a using the specified serialization options. /// /// The JSON serialization options to use. /// A representation of the object's state. - internal new JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => base.Serialize(jsonSerializerOptions); + internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) + { + var jso = jsonSerializerOptions ?? CopilotStudioJsonUtilities.DefaultOptions; + return JsonSerializer.SerializeToElement(this, jso.GetTypeInfo(typeof(CopilotStudioAgentSession))); + } + + internal static CopilotStudioAgentSession Deserialize(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null) + { + if (serializedState.ValueKind != JsonValueKind.Object) + { + throw new ArgumentException("The serialized session state must be a JSON object.", nameof(serializedState)); + } + + var jso = jsonSerializerOptions ?? CopilotStudioJsonUtilities.DefaultOptions; + return serializedState.Deserialize(jso.GetTypeInfo(typeof(CopilotStudioAgentSession))) as CopilotStudioAgentSession + ?? new CopilotStudioAgentSession(); + } + + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private string DebuggerDisplay => + $"ConversationId = {this.ConversationId}, StateBag Count = {this.StateBag.Count}"; } diff --git a/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioJsonUtilities.cs new file mode 100644 index 0000000000..44177b0708 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioJsonUtilities.cs @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; +using System.Text.Encodings.Web; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Microsoft.Agents.AI.CopilotStudio; + +/// +/// Provides utility methods and configurations for JSON serialization operations within the Copilot Studio agent implementation. +/// +internal static partial class CopilotStudioJsonUtilities +{ + /// + /// Gets the default instance used for JSON serialization operations. + /// + public static JsonSerializerOptions DefaultOptions { get; } = CreateDefaultOptions(); + + /// + /// Creates and configures the default JSON serialization options. + /// + /// The configured options. + private static JsonSerializerOptions CreateDefaultOptions() + { + // Copy the configuration from the source generated context. + JsonSerializerOptions options = new(JsonContext.Default.Options) + { + Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping, + }; + + // Chain in the resolvers from both AgentAbstractionsJsonUtilities and our source generated context. + options.TypeInfoResolverChain.Clear(); + options.TypeInfoResolverChain.Add(AgentAbstractionsJsonUtilities.DefaultOptions.TypeInfoResolver!); + options.TypeInfoResolverChain.Add(JsonContext.Default.Options.TypeInfoResolver!); + + options.MakeReadOnly(); + return options; + } + + [JsonSourceGenerationOptions(JsonSerializerDefaults.Web, + UseStringEnumConverter = true, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + NumberHandling = JsonNumberHandling.AllowReadingFromString)] + [JsonSerializable(typeof(CopilotStudioAgentSession))] + [ExcludeFromCodeCoverage] + private sealed partial class JsonContext : JsonSerializerContext; +} diff --git a/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgent.cs b/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgent.cs index 06c7f24ee2..0556430636 100644 --- a/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgent.cs @@ -115,7 +115,7 @@ protected override ValueTask DeserializeSessionCoreAsync( JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) - => new(new GitHubCopilotAgentSession(serializedState, jsonSerializerOptions)); + => new(GitHubCopilotAgentSession.Deserialize(serializedState, jsonSerializerOptions)); /// protected override Task RunCoreAsync( diff --git a/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgentSession.cs b/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgentSession.cs index f514eeb71b..70fe43425e 100644 --- a/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgentSession.cs @@ -1,17 +1,22 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.Diagnostics; using System.Text.Json; +using System.Text.Json.Serialization; namespace Microsoft.Agents.AI.GitHub.Copilot; /// /// Represents a session for a GitHub Copilot agent conversation. /// +[DebuggerDisplay("{DebuggerDisplay,nq}")] public sealed class GitHubCopilotAgentSession : AgentSession { /// /// Gets or sets the session ID for the GitHub Copilot conversation. /// + [JsonPropertyName("sessionId")] public string? SessionId { get; internal set; } /// @@ -21,35 +26,32 @@ internal GitHubCopilotAgentSession() { } - /// - /// Initializes a new instance of the class from serialized data. - /// - /// The serialized thread data. - /// Optional JSON serialization options. - internal GitHubCopilotAgentSession(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + [JsonConstructor] + internal GitHubCopilotAgentSession(string? sessionId, AgentSessionStateBag? stateBag) : base(stateBag ?? new()) { - // The JSON serialization uses camelCase - if (serializedThread.TryGetProperty("sessionId", out JsonElement sessionIdElement)) - { - this.SessionId = sessionIdElement.GetString(); - } + this.SessionId = sessionId; } /// internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { - State state = new() - { - SessionId = this.SessionId - }; - - return JsonSerializer.SerializeToElement( - state, - GitHubCopilotJsonUtilities.DefaultOptions.GetTypeInfo(typeof(State))); + var jso = jsonSerializerOptions ?? GitHubCopilotJsonUtilities.DefaultOptions; + return JsonSerializer.SerializeToElement(this, jso.GetTypeInfo(typeof(GitHubCopilotAgentSession))); } - internal sealed class State + internal static GitHubCopilotAgentSession Deserialize(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null) { - public string? SessionId { get; set; } + if (serializedState.ValueKind != JsonValueKind.Object) + { + throw new ArgumentException("The serialized session state must be a JSON object.", nameof(serializedState)); + } + + var jso = jsonSerializerOptions ?? GitHubCopilotJsonUtilities.DefaultOptions; + return serializedState.Deserialize(jso.GetTypeInfo(typeof(GitHubCopilotAgentSession))) as GitHubCopilotAgentSession + ?? new GitHubCopilotAgentSession(); } + + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private string DebuggerDisplay => + $"SessionId = {this.SessionId}, StateBag Count = {this.StateBag.Count}"; } diff --git a/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotJsonUtilities.cs index c5254efd6b..9e97c0585b 100644 --- a/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotJsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotJsonUtilities.cs @@ -42,7 +42,7 @@ private static JsonSerializerOptions CreateDefaultOptions() UseStringEnumConverter = true, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, NumberHandling = JsonNumberHandling.AllowReadingFromString)] - [JsonSerializable(typeof(GitHubCopilotAgentSession.State))] + [JsonSerializable(typeof(GitHubCopilotAgentSession))] [ExcludeFromCodeCoverage] private sealed partial class JsonContext : JsonSerializerContext; } diff --git a/dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs index 79cbf2193c..96ec6dbecb 100644 --- a/dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs @@ -65,7 +65,7 @@ private static JsonSerializerOptions CreateDefaultOptions() NumberHandling = JsonNumberHandling.AllowReadingFromString)] // Agent abstraction types - [JsonSerializable(typeof(ChatClientAgentSession.SessionState))] + [JsonSerializable(typeof(ChatClientAgentSession))] [JsonSerializable(typeof(TextSearchProvider.TextSearchProviderState))] [JsonSerializable(typeof(ChatHistoryMemoryProvider.State))] diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index 7fdc7bc84e..5d05daf6e3 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -364,7 +364,7 @@ protected override JsonElement SerializeSessionCore(AgentSession session, JsonSe /// protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) { - return new(ChatClientAgentSession.Deserialize(serializedState)); + return new(ChatClientAgentSession.Deserialize(serializedState, jsonSerializerOptions)); } #region Private diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs index b220170771..400bfbcaf6 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs @@ -3,6 +3,7 @@ using System; using System.Diagnostics; using System.Text.Json; +using System.Text.Json.Serialization; using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.AI; @@ -20,6 +21,12 @@ internal ChatClientAgentSession() { } + [JsonConstructor] + internal ChatClientAgentSession(string? conversationId, AgentSessionStateBag? stateBag) : base(stateBag ?? new()) + { + this.ConversationId = conversationId; + } + /// /// Gets or sets the ID of the underlying service chat history to support cases where the chat history is stored by the agent service. /// @@ -37,6 +44,7 @@ internal ChatClientAgentSession() /// to fork the chat history with each iteration. /// /// + [JsonPropertyName("conversationId")] public string? ConversationId { get; @@ -55,50 +63,29 @@ internal set /// Creates a new instance of the class from previously serialized state. /// /// A representing the serialized state of the session. + /// Optional JSON serialization options to use instead of the default options. /// The deserialized . - internal static ChatClientAgentSession Deserialize(JsonElement serializedState) + internal static ChatClientAgentSession Deserialize(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null) { if (serializedState.ValueKind != JsonValueKind.Object) { throw new ArgumentException("The serialized session state must be a JSON object.", nameof(serializedState)); } - var state = serializedState.Deserialize( - AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(SessionState))) as SessionState; - - var session = new ChatClientAgentSession(); - - session.StateBag = AgentSessionStateBag.Deserialize(state?.StateBag ?? default); - - if (state?.ConversationId is string sessionId) - { - session.ConversationId = sessionId; - } - - return session; + var jso = jsonSerializerOptions ?? AgentJsonUtilities.DefaultOptions; + return serializedState.Deserialize(jso.GetTypeInfo(typeof(ChatClientAgentSession))) as ChatClientAgentSession + ?? new ChatClientAgentSession(); } /// internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { - var state = new SessionState - { - ConversationId = this.ConversationId, - StateBag = this.StateBag.Serialize(), - }; - - return JsonSerializer.SerializeToElement(state, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(SessionState))); + var jso = jsonSerializerOptions ?? AgentJsonUtilities.DefaultOptions; + return JsonSerializer.SerializeToElement(this, jso.GetTypeInfo(typeof(ChatClientAgentSession))); } [DebuggerBrowsable(DebuggerBrowsableState.Never)] private string DebuggerDisplay => this.ConversationId is { } conversationId ? $"ConversationId = {conversationId}, StateBag Count = {this.StateBag.Count}" : $"StateBag Count = {this.StateBag.Count}"; - - internal sealed class SessionState - { - public string? ConversationId { get; set; } - - public JsonElement? StateBag { get; set; } - } } diff --git a/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentSessionTests.cs index 6b7348c05b..8c3e89adf6 100644 --- a/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentSessionTests.cs @@ -21,7 +21,7 @@ public void Constructor_RoundTrip_SerializationPreservesState() // Act JsonElement serialized = originalSession.Serialize(); - A2AAgentSession deserializedSession = new(serialized); + A2AAgentSession deserializedSession = A2AAgentSession.Deserialize(serialized); // Assert Assert.Equal(originalSession.ContextId, deserializedSession.ContextId); @@ -37,7 +37,7 @@ public void Constructor_RoundTrip_SerializationPreservesStateBag() // Act JsonElement serialized = originalSession.Serialize(); - A2AAgentSession deserializedSession = new(serialized); + A2AAgentSession deserializedSession = A2AAgentSession.Deserialize(serialized); // Assert Assert.Equal("ctx-1", deserializedSession.ContextId); diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ServiceIdAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ServiceIdAgentSessionTests.cs deleted file mode 100644 index 778f4939f5..0000000000 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ServiceIdAgentSessionTests.cs +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Text.Json; - -namespace Microsoft.Agents.AI.Abstractions.UnitTests; - -/// -/// Tests for . -/// -public class ServiceIdAgentSessionTests -{ - #region Constructor and Property Tests - - [Fact] - public void Constructor_SetsDefaults() - { - // Arrange & Act - var session = new TestServiceIdAgentSession(); - - // Assert - Assert.Null(session.GetServiceSessionId()); - } - - [Fact] - public void Constructor_WithServiceSessionId_SetsProperty() - { - // Arrange & Act - var session = new TestServiceIdAgentSession("service-id-123"); - - // Assert - Assert.Equal("service-id-123", session.GetServiceSessionId()); - } - - [Fact] - public void Constructor_WithSerializedId_SetsProperty() - { - // Arrange - var serviceSessionWrapper = new ServiceIdAgentSession.ServiceIdAgentSessionState { ServiceSessionId = "service-id-456" }; - var json = JsonSerializer.SerializeToElement(serviceSessionWrapper, TestJsonSerializerContext.Default.ServiceIdAgentSessionState); - - // Act - var session = new TestServiceIdAgentSession(json); - - // Assert - Assert.Equal("service-id-456", session.GetServiceSessionId()); - } - - [Fact] - public void Constructor_WithSerializedUndefinedId_SetsProperty() - { - // Arrange - var emptyObject = new EmptyObject(); - var json = JsonSerializer.SerializeToElement(emptyObject, TestJsonSerializerContext.Default.EmptyObject); - - // Act - var session = new TestServiceIdAgentSession(json); - - // Assert - Assert.Null(session.GetServiceSessionId()); - } - - [Fact] - public void Constructor_WithInvalidJson_ThrowsArgumentException() - { - // Arrange - var invalidJson = JsonSerializer.SerializeToElement(42, TestJsonSerializerContext.Default.Int32); - - // Act & Assert - Assert.Throws(() => new TestServiceIdAgentSession(invalidJson)); - } - - #endregion - - #region SerializeAsync Tests - - [Fact] - public void Serialize_ReturnsCorrectJson_WhenServiceSessionIdIsSet() - { - // Arrange - var session = new TestServiceIdAgentSession("service-id-789"); - - // Act - var json = session.Serialize(); - - // Assert - Assert.Equal(JsonValueKind.Object, json.ValueKind); - Assert.True(json.TryGetProperty("serviceSessionId", out var idProperty)); - Assert.Equal("service-id-789", idProperty.GetString()); - } - - [Fact] - public void Serialize_ReturnsUndefinedServiceSessionId_WhenNotSet() - { - // Arrange - var session = new TestServiceIdAgentSession(); - - // Act - var json = session.Serialize(); - - // Assert - Assert.Equal(JsonValueKind.Object, json.ValueKind); - Assert.False(json.TryGetProperty("serviceSessionId", out _)); - } - - [Fact] - public void Serialize_RoundTrip_PreservesStateBag() - { - // Arrange - var session = new TestServiceIdAgentSession("service-id-roundtrip"); - session.StateBag.SetValue("myKey", "myValue"); - - // Act - var json = session.Serialize(); - var restored = new TestServiceIdAgentSession(json); - - // Assert - Assert.Equal("service-id-roundtrip", restored.GetServiceSessionId()); - Assert.True(restored.StateBag.TryGetValue("myKey", out var value)); - Assert.Equal("myValue", value); - } - - #endregion - - // Sealed test subclass to expose protected members for testing - private sealed class TestServiceIdAgentSession : ServiceIdAgentSession - { - public TestServiceIdAgentSession() { } - public TestServiceIdAgentSession(string serviceSessionId) : base(serviceSessionId) { } - public TestServiceIdAgentSession(JsonElement serializedSessionState) : base(serializedSessionState) { } - public string? GetServiceSessionId() => this.ServiceSessionId; - } - - // Helper class to represent empty objects - internal sealed class EmptyObject; -} diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/TestJsonSerializerContext.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/TestJsonSerializerContext.cs index 37ec8e0a03..c4f3b7511a 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/TestJsonSerializerContext.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/TestJsonSerializerContext.cs @@ -19,7 +19,5 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; [JsonSerializable(typeof(Dictionary))] [JsonSerializable(typeof(string[]))] [JsonSerializable(typeof(int))] -[JsonSerializable(typeof(ServiceIdAgentSession.ServiceIdAgentSessionState))] -[JsonSerializable(typeof(ServiceIdAgentSessionTests.EmptyObject))] [JsonSerializable(typeof(InMemoryChatHistoryProviderTests.TestAIContent))] internal sealed partial class TestJsonSerializerContext : JsonSerializerContext; diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs index 52b3652e02..f60a7d126f 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs @@ -212,7 +212,7 @@ public void VerifySessionSerializationWithCustomOptions() // Arrange var session = new ChatClientAgentSession(); JsonSerializerOptions options = new() { PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower }; - options.TypeInfoResolverChain.Add(AgentAbstractionsJsonUtilities.DefaultOptions.TypeInfoResolver!); + options.TypeInfoResolverChain.Add(AgentJsonUtilities.DefaultOptions.TypeInfoResolver!); // Act var json = session.Serialize(options); @@ -220,7 +220,8 @@ public void VerifySessionSerializationWithCustomOptions() // Assert Assert.Equal(JsonValueKind.Object, json.ValueKind); - Assert.False(json.TryGetProperty("conversationId", out var _)); + // [JsonPropertyName] takes precedence over naming policy + Assert.True(json.TryGetProperty("conversationId", out var _)); } #endregion Serialize Tests From b399e07b3afea9932f0a53588f102bfb1438068c Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Mon, 9 Feb 2026 15:37:18 +0000 Subject: [PATCH 19/28] Fix durable agent session jso usgae --- .../DurableAgentSession.cs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentSession.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentSession.cs index 31c8eec2db..ba33c15d32 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentSession.cs @@ -7,9 +7,9 @@ namespace Microsoft.Agents.AI.DurableTask; /// -/// An agent thread implementation for durable agents. +/// An implementation for durable agents. /// -[DebuggerDisplay("{SessionId}")] +[DebuggerDisplay("{DebuggerDisplay,nq}")] public sealed class DurableAgentSession : AgentSession { internal DurableAgentSession(AgentSessionId sessionId) @@ -33,9 +33,8 @@ internal DurableAgentSession(AgentSessionId sessionId, AgentSessionStateBag stat /// internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { - return JsonSerializer.SerializeToElement( - this, - DurableAgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(DurableAgentSession))); + var jso = jsonSerializerOptions ?? DurableAgentJsonUtilities.DefaultOptions; + return JsonSerializer.SerializeToElement(this, jso.GetTypeInfo(typeof(DurableAgentSession))); } /// @@ -77,4 +76,8 @@ public override string ToString() { return this.SessionId.ToString(); } + + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private string DebuggerDisplay => + $"SessionId = {this.SessionId}, StateBag Count = {this.StateBag.Count}"; } From c907e12ebc8ccd37cc8613c8c18372684ed983e8 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Mon, 9 Feb 2026 16:47:47 +0000 Subject: [PATCH 20/28] Add jso to InMemory and Workflow ChatHistoryProviders --- .../InMemoryChatHistoryProvider.cs | 15 ++++++++++++--- .../WorkflowChatHistoryProvider.cs | 19 +++++++++++++++---- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs index e968c8154c..52d356e010 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -29,6 +30,7 @@ public sealed class InMemoryChatHistoryProvider : ChatHistoryProvider private readonly string _stateKey; private readonly Func _stateInitializer; + private readonly JsonSerializerOptions _jsonSerializerOptions; /// /// Initializes a new instance of the class. @@ -49,6 +51,11 @@ public sealed class InMemoryChatHistoryProvider : ChatHistoryProvider /// An optional key to use for storing the state in the . /// If , a default key will be used. /// + /// + /// Optional JSON serializer options for serializing the state of this provider. + /// This is valuable for cases like when the chat history contains custom types + /// and source generated serializers are required, or Native AOT / Trimming is required. + /// /// /// Message reducers enable automatic management of message storage by implementing strategies to /// keep memory usage under control while preserving important conversation context. @@ -57,12 +64,14 @@ public InMemoryChatHistoryProvider( Func? stateInitializer = null, IChatReducer? chatReducer = null, ChatReducerTriggerEvent reducerTriggerEvent = ChatReducerTriggerEvent.BeforeMessagesRetrieval, - string? stateKey = null) + string? stateKey = null, + JsonSerializerOptions? jsonSerializerOptions = null) { this._stateInitializer = stateInitializer ?? (_ => new State()); this.ChatReducer = chatReducer; this.ReducerTriggerEvent = reducerTriggerEvent; this._stateKey = stateKey ?? DefaultStateBagKey; + this._jsonSerializerOptions = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; } /// @@ -104,7 +113,7 @@ public void SetMessages(AgentSession? session, List messages) /// The provider state, or null if no session is available. private State GetOrInitializeState(AgentSession? session) { - var state = session?.StateBag.GetValue(this._stateKey, AgentAbstractionsJsonUtilities.DefaultOptions); + var state = session?.StateBag.GetValue(this._stateKey, this._jsonSerializerOptions); if (state is not null) { return state; @@ -113,7 +122,7 @@ private State GetOrInitializeState(AgentSession? session) state = this._stateInitializer(session); if (session is not null) { - session.StateBag.SetValue(this._stateKey, state, AgentAbstractionsJsonUtilities.DefaultOptions); + session.StateBag.SetValue(this._stateKey, state, this._jsonSerializerOptions); } return state; diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs index 708100b5ba..51ede16fca 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Linq; +using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -11,9 +12,19 @@ namespace Microsoft.Agents.AI.Workflows; internal sealed class WorkflowChatHistoryProvider : ChatHistoryProvider { private const string DefaultStateBagKey = "WorkflowChatHistoryProvider.State"; - - public WorkflowChatHistoryProvider() + private readonly JsonSerializerOptions _jsonSerializerOptions; + + /// + /// Initializes a new instance of the class. + /// + /// + /// Optional JSON serializer options for serializing the state of this provider. + /// This is valuable for cases like when the chat history contains custom types + /// and source generated serializers are required, or Native AOT / Trimming is required. + /// + public WorkflowChatHistoryProvider(JsonSerializerOptions? jsonSerializerOptions = null) { + this._jsonSerializerOptions = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; } internal sealed class StoreState @@ -24,7 +35,7 @@ internal sealed class StoreState private StoreState GetOrInitializeState(AgentSession? session) { - var state = session?.StateBag.GetValue(DefaultStateBagKey, AgentAbstractionsJsonUtilities.DefaultOptions); + var state = session?.StateBag.GetValue(DefaultStateBagKey, this._jsonSerializerOptions); if (state is not null) { return state; @@ -33,7 +44,7 @@ private StoreState GetOrInitializeState(AgentSession? session) state = new(); if (session is not null) { - session.StateBag.SetValue(DefaultStateBagKey, state, AgentAbstractionsJsonUtilities.DefaultOptions); + session.StateBag.SetValue(DefaultStateBagKey, state, this._jsonSerializerOptions); } return state; From d7356cf5786f14aba52df8487137a421e37c0778 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Mon, 9 Feb 2026 17:02:38 +0000 Subject: [PATCH 21/28] Update InMemoryChatHistoryProvider to use an options class for it's many optional settings. --- .../Agent_Step16_ChatReduction/Program.cs | 2 +- .../InMemoryChatHistoryProvider.cs | 68 ++++--------------- .../InMemoryChatHistoryProviderOptions.cs | 67 ++++++++++++++++++ .../InMemoryChatHistoryProviderTests.cs | 23 ++++--- 4 files changed, 92 insertions(+), 68 deletions(-) create mode 100644 dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProviderOptions.cs diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs index dfeddb386d..1989dbc2d6 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs @@ -24,7 +24,7 @@ { ChatOptions = new() { Instructions = "You are good at telling jokes." }, Name = "Joker", - ChatHistoryProvider = new InMemoryChatHistoryProvider(chatReducer: new MessageCountingChatReducer(2)) + ChatHistoryProvider = new InMemoryChatHistoryProvider(new() { ChatReducer = new MessageCountingChatReducer(2) }) }); AgentSession session = await agent.CreateSessionAsync(); diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs index 52d356e010..f322d0bfd4 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs @@ -35,43 +35,17 @@ public sealed class InMemoryChatHistoryProvider : ChatHistoryProvider /// /// Initializes a new instance of the class. /// - /// - /// An optional delegate that initializes the provider state on the first invocation. - /// If , a default initializer that creates an empty state will be used. + /// + /// Optional configuration options that control the provider's behavior, including state initialization, + /// message reduction, and serialization settings. If , default settings will be used. /// - /// - /// An optional instance used to process, reduce, or optimize chat messages. - /// This can be used to implement strategies like message summarization, truncation, or cleanup. - /// - /// - /// Specifies when the message reducer should be invoked. The default is , - /// which applies reduction logic when messages are retrieved for agent consumption. - /// - /// - /// An optional key to use for storing the state in the . - /// If , a default key will be used. - /// - /// - /// Optional JSON serializer options for serializing the state of this provider. - /// This is valuable for cases like when the chat history contains custom types - /// and source generated serializers are required, or Native AOT / Trimming is required. - /// - /// - /// Message reducers enable automatic management of message storage by implementing strategies to - /// keep memory usage under control while preserving important conversation context. - /// - public InMemoryChatHistoryProvider( - Func? stateInitializer = null, - IChatReducer? chatReducer = null, - ChatReducerTriggerEvent reducerTriggerEvent = ChatReducerTriggerEvent.BeforeMessagesRetrieval, - string? stateKey = null, - JsonSerializerOptions? jsonSerializerOptions = null) + public InMemoryChatHistoryProvider(InMemoryChatHistoryProviderOptions? options = null) { - this._stateInitializer = stateInitializer ?? (_ => new State()); - this.ChatReducer = chatReducer; - this.ReducerTriggerEvent = reducerTriggerEvent; - this._stateKey = stateKey ?? DefaultStateBagKey; - this._jsonSerializerOptions = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; + this._stateInitializer = options?.StateInitializer ?? (_ => new State()); + this.ChatReducer = options?.ChatReducer; + this.ReducerTriggerEvent = options?.ReducerTriggerEvent ?? InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.BeforeMessagesRetrieval; + this._stateKey = options?.StateKey ?? DefaultStateBagKey; + this._jsonSerializerOptions = options?.JsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; } /// @@ -82,7 +56,7 @@ public InMemoryChatHistoryProvider( /// /// Gets the event that triggers the reducer invocation in this provider. /// - public ChatReducerTriggerEvent ReducerTriggerEvent { get; } + public InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent ReducerTriggerEvent { get; } /// /// Gets the chat messages stored for the specified session. @@ -135,7 +109,7 @@ public override async ValueTask> InvokingAsync(Invoking var state = this.GetOrInitializeState(context.Session); - if (this.ReducerTriggerEvent is ChatReducerTriggerEvent.BeforeMessagesRetrieval && this.ChatReducer is not null) + if (this.ReducerTriggerEvent is InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.BeforeMessagesRetrieval && this.ChatReducer is not null) { state.Messages = (await this.ChatReducer.ReduceAsync(state.Messages, cancellationToken).ConfigureAwait(false)).ToList(); } @@ -159,7 +133,7 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio var allNewMessages = context.RequestMessages.Concat(context.AIContextProviderMessages ?? []).Concat(context.ResponseMessages ?? []); state.Messages.AddRange(allNewMessages); - if (this.ReducerTriggerEvent is ChatReducerTriggerEvent.AfterMessageAdded && this.ChatReducer is not null) + if (this.ReducerTriggerEvent is InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.AfterMessageAdded && this.ChatReducer is not null) { state.Messages = (await this.ChatReducer.ReduceAsync(state.Messages, cancellationToken).ConfigureAwait(false)).ToList(); } @@ -175,22 +149,4 @@ public sealed class State /// public List Messages { get; set; } = []; } - - /// - /// Defines the events that can trigger a reducer in the . - /// - public enum ChatReducerTriggerEvent - { - /// - /// Trigger the reducer when a new message is added. - /// will only complete when reducer processing is done. - /// - AfterMessageAdded, - - /// - /// Trigger the reducer before messages are retrieved from the provider. - /// The reducer will process the messages before they are returned to the caller. - /// - BeforeMessagesRetrieval - } } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProviderOptions.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProviderOptions.cs new file mode 100644 index 0000000000..df0cae090d --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProviderOptions.cs @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json; +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI; + +/// +/// Represents configuration options for . +/// +public sealed class InMemoryChatHistoryProviderOptions +{ + /// + /// Gets or sets an optional delegate that initializes the provider state on the first invocation. + /// If , a default initializer that creates an empty state will be used. + /// + public Func? StateInitializer { get; set; } + + /// + /// Gets or sets an optional instance used to process, reduce, or optimize chat messages. + /// This can be used to implement strategies like message summarization, truncation, or cleanup. + /// + public IChatReducer? ChatReducer { get; set; } + + /// + /// Gets or sets when the message reducer should be invoked. + /// The default is , + /// which applies reduction logic when messages are retrieved for agent consumption. + /// + /// + /// Message reducers enable automatic management of message storage by implementing strategies to + /// keep memory usage under control while preserving important conversation context. + /// + public ChatReducerTriggerEvent ReducerTriggerEvent { get; set; } = ChatReducerTriggerEvent.BeforeMessagesRetrieval; + + /// + /// Gets or sets an optional key to use for storing the state in the . + /// If , a default key will be used. + /// + public string? StateKey { get; set; } + + /// + /// Gets or sets optional JSON serializer options for serializing the state of this provider. + /// This is valuable for cases like when the chat history contains custom types + /// and source generated serializers are required, or Native AOT / Trimming is required. + /// + public JsonSerializerOptions? JsonSerializerOptions { get; set; } + + /// + /// Defines the events that can trigger a reducer in the . + /// + public enum ChatReducerTriggerEvent + { + /// + /// Trigger the reducer when a new message is added. + /// will only complete when reducer processing is done. + /// + AfterMessageAdded, + + /// + /// Trigger the reducer before messages are retrieved from the provider. + /// The reducer will process the messages before they are returned to the caller. + /// + BeforeMessagesRetrieval + } +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs index fb81b0fc73..7556022489 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs @@ -24,10 +24,10 @@ public void Constructor_DefaultsToBeforeMessageRetrieval_ForNotProvidedTriggerEv { // Arrange & Act var reducerMock = new Mock(); - var provider = new InMemoryChatHistoryProvider(chatReducer: reducerMock.Object); + var provider = new InMemoryChatHistoryProvider(new() { ChatReducer = reducerMock.Object }); // Assert - Assert.Equal(InMemoryChatHistoryProvider.ChatReducerTriggerEvent.BeforeMessagesRetrieval, provider.ReducerTriggerEvent); + Assert.Equal(InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.BeforeMessagesRetrieval, provider.ReducerTriggerEvent); } [Fact] @@ -35,11 +35,11 @@ public void Constructor_Arguments_SetOnPropertiesCorrectly() { // Arrange & Act var reducerMock = new Mock(); - var provider = new InMemoryChatHistoryProvider(chatReducer: reducerMock.Object, reducerTriggerEvent: InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded); + var provider = new InMemoryChatHistoryProvider(new() { ChatReducer = reducerMock.Object, ReducerTriggerEvent = InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.AfterMessageAdded }); // Assert Assert.Same(reducerMock.Object, provider.ChatReducer); - Assert.Equal(InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded, provider.ReducerTriggerEvent); + Assert.Equal(InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.AfterMessageAdded, provider.ReducerTriggerEvent); } [Fact] @@ -128,8 +128,10 @@ public void StateInitializer_IsInvoked_WhenSessionHasNoState() { new(ChatRole.User, "Initial message") }; - var provider = new InMemoryChatHistoryProvider( - stateInitializer: _ => new InMemoryChatHistoryProvider.State { Messages = initialMessages }); + var provider = new InMemoryChatHistoryProvider(new() + { + StateInitializer = _ => new InMemoryChatHistoryProvider.State { Messages = initialMessages } + }); // Act var messages = provider.GetMessages(CreateMockSession()); @@ -232,7 +234,7 @@ public async Task AddMessagesAsync_WithReducer_AfterMessageAdded_InvokesReducerA .Setup(r => r.ReduceAsync(It.Is>(x => x.SequenceEqual(originalMessages)), It.IsAny())) .ReturnsAsync(reducedMessages); - var provider = new InMemoryChatHistoryProvider(chatReducer: reducerMock.Object, reducerTriggerEvent: InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded); + var provider = new InMemoryChatHistoryProvider(new() { ChatReducer = reducerMock.Object, ReducerTriggerEvent = InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.AfterMessageAdded }); // Act var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, originalMessages, []); @@ -266,8 +268,7 @@ public async Task GetMessagesAsync_WithReducer_BeforeMessagesRetrieval_InvokesRe .Setup(r => r.ReduceAsync(It.Is>(x => x.SequenceEqual(originalMessages)), It.IsAny())) .ReturnsAsync(reducedMessages); - var provider = new InMemoryChatHistoryProvider(chatReducer: reducerMock.Object, reducerTriggerEvent: InMemoryChatHistoryProvider.ChatReducerTriggerEvent.BeforeMessagesRetrieval); - // Add messages directly to the provider for this test + var provider = new InMemoryChatHistoryProvider(new() { ChatReducer = reducerMock.Object, ReducerTriggerEvent = InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.BeforeMessagesRetrieval }); provider.SetMessages(session, new List(originalMessages)); // Act @@ -293,7 +294,7 @@ public async Task AddMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeRedu var reducerMock = new Mock(); - var provider = new InMemoryChatHistoryProvider(chatReducer: reducerMock.Object, reducerTriggerEvent: InMemoryChatHistoryProvider.ChatReducerTriggerEvent.BeforeMessagesRetrieval); + var provider = new InMemoryChatHistoryProvider(new() { ChatReducer = reducerMock.Object, ReducerTriggerEvent = InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.BeforeMessagesRetrieval }); // Act var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, originalMessages, []); @@ -319,7 +320,7 @@ public async Task GetMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeRedu var reducerMock = new Mock(); - var provider = new InMemoryChatHistoryProvider(chatReducer: reducerMock.Object, reducerTriggerEvent: InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded); + var provider = new InMemoryChatHistoryProvider(new() { ChatReducer = reducerMock.Object, ReducerTriggerEvent = InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.AfterMessageAdded }); provider.SetMessages(session, new List(originalMessages)); // Act From 99abe7ff0fe726eff3bed17067494758a511c103 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Mon, 9 Feb 2026 17:04:21 +0000 Subject: [PATCH 22/28] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../AgentProviders/Agent_With_CustomImplementation/Program.cs | 2 +- .../Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs index fbed86bacb..4223830f65 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs @@ -29,7 +29,7 @@ internal sealed class UpperCaseParrotAgent : AIAgent { public override string? Name => "UpperCaseParrotAgent"; - public ChatHistoryProvider ChatHistoryProvider = new InMemoryChatHistoryProvider(); + public readonly ChatHistoryProvider ChatHistoryProvider = new InMemoryChatHistoryProvider(); protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => new(new CustomAgentSession()); diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs index 0d4e189a1c..45450100e0 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs @@ -49,8 +49,8 @@ internal AgentSessionStateBag(ConcurrentDictionary /// The type of the value to retrieve. /// The key from which to retrieve the value. - /// The value if found and convertable to the required type; otherwise, null. - /// The JSON serializer options to use for serializing/deserialing the value. + /// The value if found and convertible to the required type; otherwise, null. + /// The JSON serializer options to use for serializing/deserializing the value. /// if the value was successfully retrieved, otherwise. public bool TryGetValue(string key, out T? value, JsonSerializerOptions? jsonSerializerOptions = null) where T : class From 23704fdfc87fd4a886a738dba6ae3cc5a629f1c7 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Mon, 9 Feb 2026 17:32:27 +0000 Subject: [PATCH 23/28] Address PR feedback --- .../Microsoft.Agents.AI.Mem0/Mem0Provider.cs | 12 +++++++++- .../ChatClient/ChatClientAgent.cs | 23 +++++++++++++++---- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs index d26b505a7e..562d312caf 100644 --- a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs @@ -84,7 +84,17 @@ public Mem0Provider(HttpClient httpClient, Func stateIniti } state = this._stateInitializer(session); - if (state is not null && session is not null) + + if (state is null + || state.StorageScope is null + || (state.StorageScope.AgentId is null && state.StorageScope.ThreadId is null && state.StorageScope.UserId is null && state.StorageScope.ApplicationId is null) + || state.SearchScope is null + || (state.SearchScope.AgentId is null && state.SearchScope.ThreadId is null && state.SearchScope.UserId is null && state.SearchScope.ApplicationId is null)) + { + throw new InvalidOperationException("State initializer must return a non-null state with valid storage and search scopes, where at lest one scoping parameter is set for each."); + } + + if (session is not null) { session.StateBag.SetValue(this._stateKey, state, Mem0JsonUtilities.DefaultOptions); } diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index 5d05daf6e3..8d925cbc09 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -671,7 +671,7 @@ private async Task // Populate the session messages only if we are not continuing an existing response as it's not allowed if (chatOptions?.ContinuationToken is null) { - ChatHistoryProvider? chatHistoryProvider = this.ResolveChatHistoryProvider(chatOptions); + ChatHistoryProvider? chatHistoryProvider = this.ResolveChatHistoryProvider(chatOptions, typedSession); // Add any existing messages from the session to the messages to be sent to the chat client. if (chatHistoryProvider is not null) @@ -751,7 +751,8 @@ private void UpdateSessionConversationId(ChatClientAgentSession session, string? { // The agent has a ChatHistoryProvider configured, but the service returned a conversation id, // meaning the service manages chat history server-side. Both cannot be used simultaneously. - throw new InvalidOperationException("Only the ConversationId or ChatHistoryProvider may be used, but not both. The service returned a conversation id indicating server-side chat history management, but the agent has a ChatHistoryProvider configured."); + throw new InvalidOperationException( + $"Only {nameof(ChatClientAgentSession.ConversationId)} or {nameof(this.ChatHistoryProvider)} may be used, but not both. The service returned a conversation id indicating server-side chat history management, but the agent has a {nameof(this.ChatHistoryProvider)} configured."); } // If we got a conversation id back from the chat client, it means that the service supports server side session storage @@ -776,7 +777,7 @@ private Task NotifyChatHistoryProviderOfFailureAsync( ChatOptions? chatOptions, CancellationToken cancellationToken) { - ChatHistoryProvider? provider = this.ResolveChatHistoryProvider(chatOptions); + ChatHistoryProvider? provider = this.ResolveChatHistoryProvider(chatOptions, session); // Only notify the provider if we have one. // If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages. @@ -803,7 +804,7 @@ private Task NotifyChatHistoryProviderOfNewMessagesAsync( ChatOptions? chatOptions, CancellationToken cancellationToken) { - ChatHistoryProvider? provider = this.ResolveChatHistoryProvider(chatOptions); + ChatHistoryProvider? provider = this.ResolveChatHistoryProvider(chatOptions, session); // Only notify the provider if we have one. // If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages. @@ -820,13 +821,25 @@ private Task NotifyChatHistoryProviderOfNewMessagesAsync( return Task.CompletedTask; } - private ChatHistoryProvider? ResolveChatHistoryProvider(ChatOptions? chatOptions) + private ChatHistoryProvider? ResolveChatHistoryProvider(ChatOptions? chatOptions, ChatClientAgentSession session) { ChatHistoryProvider? provider = this.ChatHistoryProvider; + if (provider is not null) + { + throw new InvalidOperationException( + $"Only {nameof(ChatClientAgentSession.ConversationId)} or {nameof(this.ChatHistoryProvider)} may be used, but not both. The current {nameof(ChatClientAgentSession)} has a {nameof(ChatClientAgentSession.ConversationId)} indicating server-side chat history management, but the agent has a {nameof(this.ChatHistoryProvider)} configured."); + } + // If someone provided an override ChatHistoryProvider via AdditionalProperties, we should use that instead. if (chatOptions?.AdditionalProperties?.TryGetValue(out ChatHistoryProvider? overrideProvider) is true) { + if (overrideProvider is not null) + { + throw new InvalidOperationException( + $"Only {nameof(ChatClientAgentSession.ConversationId)} or {nameof(this.ChatHistoryProvider)} may be used, but not both. The current {nameof(ChatClientAgentSession)} has a {nameof(ChatClientAgentSession.ConversationId)} indicating server-side chat history management, but an override {nameof(this.ChatHistoryProvider)} is was provided via {nameof(AgentRunOptions.AdditionalProperties)}."); + } + provider = overrideProvider; } From d32139c218105bd370f250c384197a457a68766c Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Mon, 9 Feb 2026 18:21:40 +0000 Subject: [PATCH 24/28] Fix verification bug. --- dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index 8d925cbc09..8cbc4026de 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -825,7 +825,7 @@ private Task NotifyChatHistoryProviderOfNewMessagesAsync( { ChatHistoryProvider? provider = this.ChatHistoryProvider; - if (provider is not null) + if (session.ConversationId is not null && provider is not null) { throw new InvalidOperationException( $"Only {nameof(ChatClientAgentSession.ConversationId)} or {nameof(this.ChatHistoryProvider)} may be used, but not both. The current {nameof(ChatClientAgentSession)} has a {nameof(ChatClientAgentSession.ConversationId)} indicating server-side chat history management, but the agent has a {nameof(this.ChatHistoryProvider)} configured."); @@ -834,7 +834,7 @@ private Task NotifyChatHistoryProviderOfNewMessagesAsync( // If someone provided an override ChatHistoryProvider via AdditionalProperties, we should use that instead. if (chatOptions?.AdditionalProperties?.TryGetValue(out ChatHistoryProvider? overrideProvider) is true) { - if (overrideProvider is not null) + if (session.ConversationId is not null && overrideProvider is not null) { throw new InvalidOperationException( $"Only {nameof(ChatClientAgentSession.ConversationId)} or {nameof(this.ChatHistoryProvider)} may be used, but not both. The current {nameof(ChatClientAgentSession)} has a {nameof(ChatClientAgentSession.ConversationId)} indicating server-side chat history management, but an override {nameof(this.ChatHistoryProvider)} is was provided via {nameof(AgentRunOptions.AdditionalProperties)}."); From dc433a6d474cbdbef17611199bc67f500010875f Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Mon, 9 Feb 2026 18:45:31 +0000 Subject: [PATCH 25/28] Improve state bag thread safety --- .../AgentSessionStateBag.cs | 60 +------- .../AgentSessionStateBagValue.cs | 131 ++++++++++++++++-- .../AgentSessionStateBagTests.cs | 57 ++++++++ 3 files changed, 178 insertions(+), 70 deletions(-) diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs index 45450100e0..d78a866b2c 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs @@ -60,37 +60,9 @@ public bool TryGetValue(string key, out T? value, JsonSerializerOptions? json if (this._state.TryGetValue(key, out var stateValue)) { - if (stateValue.DeserializedValue is T cachedValue) - { - value = cachedValue; - return true; - } - - switch (stateValue.JsonValue) - { - case JsonElement jsonElement when jsonElement.ValueKind == JsonValueKind.Undefined: - value = null; - return false; - case JsonElement jsonElement when jsonElement.ValueKind == JsonValueKind.Null: - value = null; - return true; - default: - T? result = stateValue.JsonValue.Deserialize(jso.GetTypeInfo(typeof(T))) as T; - if (result is null) - { - value = null; - return false; - } - - stateValue.IsDeserialized = true; - stateValue.DeserializedValue = result; - stateValue.ValueType = typeof(T); - stateValue.JsonSerializerOptions = jso; - - value = result; - return true; - } + return stateValue.TryReadDeserializedValue(out value, jso); } + value = null; return false; } @@ -111,28 +83,7 @@ public bool TryGetValue(string key, out T? value, JsonSerializerOptions? json if (this._state.TryGetValue(key, out var stateValue)) { - if (stateValue.DeserializedValue is T cachedValue) - { - return cachedValue; - } - - switch (stateValue.JsonValue) - { - case JsonElement jsonElement when jsonElement.ValueKind == JsonValueKind.Null || jsonElement.ValueKind == JsonValueKind.Undefined: - return null; - default: - T? result = stateValue.JsonValue.Deserialize(jso.GetTypeInfo(typeof(T))) as T; - if (result is null) - { - throw new InvalidOperationException($"Failed to deserialize session state value to type {typeof(T).FullName}."); - } - - stateValue.IsDeserialized = true; - stateValue.DeserializedValue = result; - stateValue.ValueType = typeof(T); - stateValue.JsonSerializerOptions = jso; - return result; - } + return stateValue.ReadDeserializedValue(jso); } return null; @@ -154,10 +105,7 @@ public void SetValue(string key, T? value, JsonSerializerOptions? jsonSeriali var stateValue = this._state.GetOrAdd(key, _ => new AgentSessionStateBagValue(value, typeof(T), jso)); - stateValue.IsDeserialized = true; - stateValue.DeserializedValue = value; - stateValue.ValueType = typeof(T); - stateValue.JsonSerializerOptions = jso; + stateValue.SetDeserialized(value, typeof(T), jso); } /// diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs index fd1a455ce8..6f97237516 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs @@ -12,6 +12,10 @@ namespace Microsoft.Agents.AI; [JsonConverter(typeof(AgentSessionStateBagValueJsonConverter))] internal class AgentSessionStateBagValue { + private readonly object _lock = new(); + private DeserializedCache? _cache; + private JsonElement _jsonValue; + /// /// Initializes a new instance of the SessionStateValue class with the specified value. /// @@ -29,10 +33,7 @@ public AgentSessionStateBagValue(JsonElement jsonValue) /// The JSON serializer options to use for serializing the value. public AgentSessionStateBagValue(object? deserializedValue, Type valueType, JsonSerializerOptions jsonSerializerOptions) { - this.IsDeserialized = true; - this.DeserializedValue = deserializedValue; - this.ValueType = valueType; - this.JsonSerializerOptions = jsonSerializerOptions; + this._cache = new DeserializedCache(deserializedValue, valueType, jsonSerializerOptions); } /// @@ -42,26 +43,128 @@ public JsonElement JsonValue { get { - if (this.IsDeserialized) + lock (this._lock) { - if (this.ValueType is null || this.JsonSerializerOptions is null) + // We are assuming here that JsonValue will only be read when the object is being serialized, + // which means that we will only call SerializeToElement when serializing and therefore it's + // OK to serialize on each read if the cache is set. + if (this._cache is { } cache) { - throw new InvalidOperationException($"{nameof(AgentSessionStateBagValue)} has not been properly initialized, please set {nameof(this.ValueType)} and {nameof(this.JsonSerializerOptions)} before accessing {nameof(this.JsonValue)}."); + this._jsonValue = JsonSerializer.SerializeToElement(cache.Value, cache.Options.GetTypeInfo(cache.ValueType)); } - field = JsonSerializer.SerializeToElement(this.DeserializedValue, this.JsonSerializerOptions.GetTypeInfo(this.ValueType)); + return this._jsonValue; + } + } + set + { + lock (this._lock) + { + this._jsonValue = value; + this._cache = null; + } + } + } + + /// + /// Tries to read the deserialized value of this session state value. + /// Returns false if the value could not be deserialized into the required type, or if the value is undefined. + /// Returns true and sets the out parameter to null if the value is null. + /// + public bool TryReadDeserializedValue(out T? value, JsonSerializerOptions? jsonSerializerOptions = null) + where T : class + { + var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; + + lock (this._lock) + { + if (this._cache is { } cache) + { + value = cache.Value as T; + return true; + } + + switch (this._jsonValue) + { + case JsonElement jsonElement when jsonElement.ValueKind == JsonValueKind.Undefined: + value = null; + return false; + case JsonElement jsonElement when jsonElement.ValueKind == JsonValueKind.Null: + value = null; + return true; + default: + T? result = this._jsonValue.Deserialize(jso.GetTypeInfo(typeof(T))) as T; + if (result is null) + { + value = null; + return false; + } + + this._cache = new DeserializedCache(result, typeof(T), jso); + + value = result; + return true; + } + } + } + + /// + /// Reads the deserialized value of this session state value, throwing an exception if the value could not be deserialized into the required type or is undefined. + /// + public T? ReadDeserializedValue(JsonSerializerOptions? jsonSerializerOptions = null) + where T : class + { + var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; + + lock (this._lock) + { + if (this._cache is { } cache) + { + return cache.Value as T; + } + + switch (this._jsonValue) + { + case JsonElement jsonElement when jsonElement.ValueKind == JsonValueKind.Null || jsonElement.ValueKind == JsonValueKind.Undefined: + return null; + default: + T? result = this._jsonValue.Deserialize(jso.GetTypeInfo(typeof(T))) as T; + if (result is null) + { + throw new InvalidOperationException($"Failed to deserialize session state value to type {typeof(T).FullName}."); + } + + this._cache = new DeserializedCache(result, typeof(T), jso); + return result; } + } + } - return field; + /// + /// Sets the deserialized value of this session state value, updating the cache accordingly. + /// This does not update the JsonValue directly; the JsonValue will be updated on the next read or when the object is serialized. + /// + public void SetDeserialized(object? deserializedValue, Type valueType, JsonSerializerOptions jsonSerializerOptions) + { + lock (this._lock) + { + this._cache = new DeserializedCache(deserializedValue, valueType, jsonSerializerOptions); } - set; } - public bool IsDeserialized { get; set; } + private readonly struct DeserializedCache + { + public DeserializedCache(object? value, Type valueType, JsonSerializerOptions options) + { + this.Value = value; + this.ValueType = valueType; + this.Options = options; + } - public object? DeserializedValue { get; set; } + public object? Value { get; } - public Type? ValueType { get; set; } + public Type ValueType { get; } - public JsonSerializerOptions? JsonSerializerOptions { get; set; } + public JsonSerializerOptions Options { get; } + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs index 68c6d24dd4..b10f9acf32 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs @@ -494,6 +494,63 @@ public async System.Threading.Tasks.Task SetValue_MultipleConcurrentWrites_DoesN } } + [Fact] + public async System.Threading.Tasks.Task ConcurrentWritesAndSerialize_DoesNotThrowAsync() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("shared", "initial"); + var tasks = new System.Threading.Tasks.Task[100]; + + // Act - concurrently write and serialize the same key + for (int i = 0; i < 100; i++) + { + int index = i; + tasks[i] = System.Threading.Tasks.Task.Run(() => + { + stateBag.SetValue("shared", $"value{index}"); + _ = stateBag.Serialize(); + }); + } + + await System.Threading.Tasks.Task.WhenAll(tasks); + + // Assert - should have some value and serialize without error + Assert.True(stateBag.TryGetValue("shared", out var result)); + Assert.NotNull(result); + var json = stateBag.Serialize(); + Assert.Equal(JsonValueKind.Object, json.ValueKind); + } + + [Fact] + public async System.Threading.Tasks.Task ConcurrentReadsAndWrites_DoesNotThrowAsync() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key", "initial"); + var tasks = new System.Threading.Tasks.Task[200]; + + // Act - half readers, half writers on the same key + for (int i = 0; i < 200; i++) + { + int index = i; + if (index % 2 == 0) + { + tasks[i] = System.Threading.Tasks.Task.Run(() => stateBag.GetValue("key")); + } + else + { + tasks[i] = System.Threading.Tasks.Task.Run(() => stateBag.SetValue("key", $"value{index}")); + } + } + + await System.Threading.Tasks.Task.WhenAll(tasks); + + // Assert - should have a consistent value + Assert.True(stateBag.TryGetValue("key", out var result)); + Assert.NotNull(result); + } + #endregion #region Complex Object Tests From 1cb99f242f6d096c1bdce788f0ff1bd28a122f7b Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Mon, 9 Feb 2026 19:08:45 +0000 Subject: [PATCH 26/28] Address PR comments and fix unit tests --- .../Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs | 4 ---- .../InMemoryChatHistoryProvider.cs | 2 ++ dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs | 2 +- .../ChatClient/ChatClientAgentSessionTests.cs | 5 +---- .../ChatClient/ChatClientAgentTests.cs | 2 +- .../ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs | 2 +- .../TestJsonSerializerContext.cs | 1 + 7 files changed, 7 insertions(+), 11 deletions(-) diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs index 07bf7e5ed2..81209c63dc 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs @@ -63,10 +63,6 @@ public abstract class ChatHistoryProvider /// Archiving old messages while keeping active conversation context /// /// - /// - /// Each instance should be associated with a single to ensure proper message isolation - /// and context management. - /// /// public abstract ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default); diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs index f322d0bfd4..d48f4dd2de 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Linq; using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -147,6 +148,7 @@ public sealed class State /// /// Gets or sets the list of chat messages. /// + [JsonPropertyName("messages")] public List Messages { get; set; } = []; } } diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index 8cbc4026de..f30eadf3bc 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -837,7 +837,7 @@ private Task NotifyChatHistoryProviderOfNewMessagesAsync( if (session.ConversationId is not null && overrideProvider is not null) { throw new InvalidOperationException( - $"Only {nameof(ChatClientAgentSession.ConversationId)} or {nameof(this.ChatHistoryProvider)} may be used, but not both. The current {nameof(ChatClientAgentSession)} has a {nameof(ChatClientAgentSession.ConversationId)} indicating server-side chat history management, but an override {nameof(this.ChatHistoryProvider)} is was provided via {nameof(AgentRunOptions.AdditionalProperties)}."); + $"Only {nameof(ChatClientAgentSession.ConversationId)} or {nameof(this.ChatHistoryProvider)} may be used, but not both. The current {nameof(ChatClientAgentSession)} has a {nameof(ChatClientAgentSession.ConversationId)} indicating server-side chat history management, but an override {nameof(this.ChatHistoryProvider)} was provided via {nameof(AgentRunOptions.AdditionalProperties)}."); } provider = overrideProvider; diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs index f60a7d126f..27a13c3b49 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs @@ -56,14 +56,11 @@ public void VerifyDeserializeWithMessages() """, TestJsonSerializerContext.Default.JsonElement); // Act. - var session = ChatClientAgentSession.Deserialize(json); + var session = ChatClientAgentSession.Deserialize(json, TestJsonSerializerContext.Default.Options); // Assert Assert.Null(session.ConversationId); - // Verify the StateBag contains the serialized chat history provider state - Assert.True(session.StateBag.TryGetValue("InMemoryChatHistoryProvider.State", out _)); - var chatHistoryProvider = new InMemoryChatHistoryProvider(); var messages = chatHistoryProvider.GetMessages(session); Assert.Single(messages); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs index 156679a466..fd9dc25bc9 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs @@ -1334,7 +1334,7 @@ public async Task RunStreamingAsyncThrowsWhenChatHistoryProviderProvidedAndConve // Act & Assert ChatClientAgentSession? session = await agent.CreateSessionAsync() as ChatClientAgentSession; var exception = await Assert.ThrowsAsync(async () => await agent.RunStreamingAsync([new(ChatRole.User, "test")], session).ToListAsync()); - Assert.Equal("Only the ConversationId or ChatHistoryProvider may be used, but not both. The service returned a conversation id indicating server-side chat history management, but the agent has a ChatHistoryProvider configured.", exception.Message); + Assert.Equal("Only ConversationId or ChatHistoryProvider may be used, but not both. The service returned a conversation id indicating server-side chat history management, but the agent has a ChatHistoryProvider configured.", exception.Message); } /// diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs index 9e7121467c..c5ff91c5e8 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs @@ -276,7 +276,7 @@ public async Task RunAsync_Throws_WhenChatHistoryProviderProvidedAndConversation // Act & Assert ChatClientAgentSession? session = await agent.CreateSessionAsync() as ChatClientAgentSession; InvalidOperationException exception = await Assert.ThrowsAsync(() => agent.RunAsync([new(ChatRole.User, "test")], session)); - Assert.Equal("Only the ConversationId or ChatHistoryProvider may be used, but not both. The service returned a conversation id indicating server-side chat history management, but the agent has a ChatHistoryProvider configured.", exception.Message); + Assert.Equal("Only ConversationId or ChatHistoryProvider may be used, but not both. The service returned a conversation id indicating server-side chat history management, but the agent has a ChatHistoryProvider configured.", exception.Message); } #endregion diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestJsonSerializerContext.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestJsonSerializerContext.cs index 0ac3ab9fbf..c07dd6eb8d 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestJsonSerializerContext.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestJsonSerializerContext.cs @@ -15,4 +15,5 @@ namespace Microsoft.Agents.AI.UnitTests; [JsonSerializable(typeof(string[]))] [JsonSerializable(typeof(Dictionary))] [JsonSerializable(typeof(ChatClientAgentSessionTests.Animal))] +[JsonSerializable(typeof(ChatClientAgentSession))] internal sealed partial class TestJsonSerializerContext : JsonSerializerContext; From 10001a5a3cf864e34032cac731b13ff92b17c33d Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Tue, 10 Feb 2026 11:27:54 +0000 Subject: [PATCH 27/28] Address PR comments --- .../Program.cs | 4 +--- .../InMemoryChatHistoryProvider.cs | 3 +-- .../CosmosChatHistoryProvider.cs | 3 +-- dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs | 3 +-- .../WorkflowChatHistoryProvider.cs | 3 +-- .../Memory/ChatHistoryMemoryProvider.cs | 3 +-- .../AgentSessionStateBagTests.cs | 13 +++++++++++++ 7 files changed, 19 insertions(+), 13 deletions(-) diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs index 23b07a1920..a044ac496c 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs @@ -96,9 +96,7 @@ public string GetSessionDbKey(AgentSession session) private State GetOrInitializeState(AgentSession? session) { - var state = session?.StateBag.GetValue(this._stateKey); - - if (state is not null) + if (session?.StateBag.TryGetValue(this._stateKey, out var state) is true && state is not null) { return state; } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs index 61d23380c0..f85b7a4662 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs @@ -88,8 +88,7 @@ public void SetMessages(AgentSession? session, List messages) /// The provider state, or null if no session is available. private State GetOrInitializeState(AgentSession? session) { - var state = session?.StateBag.GetValue(this._stateKey, this._jsonSerializerOptions); - if (state is not null) + if (session?.StateBag.TryGetValue(this._stateKey, out var state, this._jsonSerializerOptions) is true && state is not null) { return state; } diff --git a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs index adc5fd89d0..1646f9216b 100644 --- a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs @@ -157,8 +157,7 @@ public CosmosChatHistoryProvider( /// The provider state, or null if no session is available. private State GetOrInitializeState(AgentSession? session) { - var state = session?.StateBag.GetValue(this._stateKey, AgentAbstractionsJsonUtilities.DefaultOptions); - if (state is not null) + if (session?.StateBag.TryGetValue(this._stateKey, out var state, AgentAbstractionsJsonUtilities.DefaultOptions) is true && state is not null) { return state; } diff --git a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs index 2dd4aa3903..7e645971d8 100644 --- a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs @@ -77,8 +77,7 @@ public Mem0Provider(HttpClient httpClient, Func stateIniti /// The provider state, or null if no session is available. private State? GetOrInitializeState(AgentSession? session) { - var state = session?.StateBag.GetValue(this._stateKey, Mem0JsonUtilities.DefaultOptions); - if (state is not null) + if (session?.StateBag.TryGetValue(this._stateKey, out var state, Mem0JsonUtilities.DefaultOptions) is true && state is not null) { return state; } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs index 8d08410ce6..349c680085 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs @@ -35,8 +35,7 @@ internal sealed class StoreState private StoreState GetOrInitializeState(AgentSession? session) { - var state = session?.StateBag.GetValue(DefaultStateBagKey, this._jsonSerializerOptions); - if (state is not null) + if (session?.StateBag.TryGetValue(DefaultStateBagKey, out var state, this._jsonSerializerOptions) is true && state is not null) { return state; } diff --git a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs index 403742c753..93901deeee 100644 --- a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs @@ -120,8 +120,7 @@ public ChatHistoryMemoryProvider( /// The provider state, or null if no session is available. private State? GetOrInitializeState(AgentSession? session) { - var state = session?.StateBag.GetValue(this._stateKey, AgentJsonUtilities.DefaultOptions); - if (state is not null) + if (session?.StateBag.TryGetValue(this._stateKey, out var state, AgentJsonUtilities.DefaultOptions) is true && state is not null) { return state; } diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs index b10f9acf32..d14621c434 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs @@ -730,5 +730,18 @@ public void JsonSerializerDeserialize_NullJson_ReturnsNull() Assert.Null(stateBag); } +#if NET + [Fact] + public void JsonSerializerSerialize_WithUnknownType_Throws() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key", new { Name = "Test" }); // Anonymous type which cannot be deserialized + + // Act & Assert + Assert.Throws(() => JsonSerializer.Serialize(stateBag, AgentAbstractionsJsonUtilities.DefaultOptions)); + } +#endif + #endregion } From a6d2f0df69cfaba6924d5c35eacc2f70dad91b28 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Tue, 10 Feb 2026 11:43:48 +0000 Subject: [PATCH 28/28] Fix unit test --- .../AgentSessionStateBagTests.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs index d14621c434..b30af7acc6 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs @@ -730,7 +730,7 @@ public void JsonSerializerDeserialize_NullJson_ReturnsNull() Assert.Null(stateBag); } -#if NET +#if NET10_0_OR_GREATER [Fact] public void JsonSerializerSerialize_WithUnknownType_Throws() { @@ -739,7 +739,7 @@ public void JsonSerializerSerialize_WithUnknownType_Throws() stateBag.SetValue("key", new { Name = "Test" }); // Anonymous type which cannot be deserialized // Act & Assert - Assert.Throws(() => JsonSerializer.Serialize(stateBag, AgentAbstractionsJsonUtilities.DefaultOptions)); + Assert.Throws(() => JsonSerializer.Serialize(stateBag, AgentAbstractionsJsonUtilities.DefaultOptions)); } #endif