From eccd6de90c1c4e291d512b589b086347d00325dc Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Thu, 5 Feb 2026 12:01:13 +0000 Subject: [PATCH 1/5] Add a StateBag to AgentSession and pass Agent and AgentSession to AIContextProvider and ChatHistoryProviders --- .../Program.cs | 8 +- .../AIContextProvider.cs | 39 +- .../AgentAbstractionsJsonUtilities.cs | 2 + .../AgentSession.cs | 5 + .../AgentSessionStateBag.cs | 180 ++++++++ .../AgentSessionStateBagValue.cs | 67 +++ .../ChatHistoryProvider.cs | 39 +- .../ChatClient/ChatClientAgent.cs | 44 +- .../ChatClient/ChatClientAgentSession.cs | 5 + .../IAgentFixture.cs | 2 +- .../RunStreamingTests.cs | 2 +- .../RunTests.cs | 2 +- .../AnthropicChatCompletionFixture.cs | 4 +- .../AIProjectClientFixture.cs | 4 +- .../AzureAIAgentsPersistentFixture.cs | 2 +- .../CopilotStudioFixture.cs | 2 +- .../AIContextProviderTests.cs | 122 +++++- .../AgentSessionStateBagTests.cs | 407 ++++++++++++++++++ .../AgentSessionTests.cs | 15 + .../ChatHistoryProviderExtensionsTests.cs | 9 +- .../ChatHistoryProviderMessageFilterTests.cs | 13 +- .../ChatHistoryProviderTests.cs | 122 +++++- .../InMemoryChatHistoryProviderTests.cs | 21 +- .../CosmosChatHistoryProviderTests.cs | 59 +-- .../DurableAgentSessionTests.cs | 4 +- .../Mem0ProviderTests.cs | 23 +- .../Mem0ProviderTests.cs | 17 +- .../ChatClient/ChatClientAgentSessionTests.cs | 74 ++++ .../Data/TextSearchProviderTests.cs | 89 ++-- .../Memory/ChatHistoryMemoryProviderTests.cs | 17 +- .../TestJsonSerializerContext.cs | 1 + .../OpenAIAssistantFixture.cs | 2 +- .../OpenAIChatCompletionFixture.cs | 4 +- .../OpenAIResponseFixture.cs | 4 +- 34 files changed, 1238 insertions(+), 172 deletions(-) create mode 100644 dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs create mode 100644 dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs create mode 100644 dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs index 5d4e77474a..82f76e7599 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs @@ -55,14 +55,14 @@ protected override async Task RunCoreAsync(IEnumerable responseMessages = CloneAndToUpperCase(messages, this.Name).ToList(); // Notify the session of the input and output messages. - var invokedContext = new ChatHistoryProvider.InvokedContext(messages, storeMessages) + var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, messages, storeMessages) { ResponseMessages = responseMessages }; @@ -87,14 +87,14 @@ protected override async IAsyncEnumerable RunCoreStreamingA } // Get existing messages from the store - var invokingContext = new ChatHistoryProvider.InvokingContext(messages); + var invokingContext = new ChatHistoryProvider.InvokingContext(this, session, messages); var storeMessages = await typedSession.ChatHistoryProvider.InvokingAsync(invokingContext, cancellationToken); // Clone the input messages and turn them into response messages with upper case text. List responseMessages = CloneAndToUpperCase(messages, this.Name).ToList(); // Notify the session of the input and output messages. - var invokedContext = new ChatHistoryProvider.InvokedContext(messages, storeMessages) + var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, messages, storeMessages) { ResponseMessages = responseMessages }; diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs index f104f12890..8428d46f9b 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs @@ -129,13 +129,30 @@ public sealed class InvokingContext /// /// Initializes a new instance of the class with the specified request messages. /// + /// The agent being invoked. + /// The session associated with the agent invocation. /// The messages to be used by the agent for this invocation. /// is . - public InvokingContext(IEnumerable requestMessages) + public InvokingContext( + AIAgent agent, + AgentSession? session, + IEnumerable requestMessages) { + this.Agent = Throw.IfNull(agent); + this.Session = session; this.RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages)); } + /// + /// Gets the agent that is being invoked. + /// + public AIAgent Agent { get; } + + /// + /// Gets the agent session associated with the agent invocation. + /// + public AgentSession? Session { get; } + /// /// Gets the caller provided messages that will be used by the agent for this invocation. /// @@ -158,15 +175,33 @@ public sealed class InvokedContext /// /// Initializes a new instance of the class with the specified request messages. /// + /// The agent being invoked. + /// The session associated with the agent invocation. /// The caller provided messages that were used by the agent for this invocation. /// The messages provided by the for this invocation, if any. /// is . - public InvokedContext(IEnumerable requestMessages, IEnumerable? aiContextProviderMessages) + public InvokedContext( + AIAgent agent, + AgentSession? session, + IEnumerable requestMessages, + IEnumerable? aiContextProviderMessages) { + this.Agent = Throw.IfNull(agent); + this.Session = session; this.RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages)); this.AIContextProviderMessages = aiContextProviderMessages; } + /// + /// Gets the agent that is being invoked. + /// + public AIAgent Agent { get; } + + /// + /// Gets the agent session associated with the agent invocation. + /// + public AgentSession? Session { get; } + /// /// Gets the caller provided messages that were used by the agent for this invocation. /// diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs index 17fbb9e4c6..bf0e835b4b 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Collections.Concurrent; using System.Diagnostics.CodeAnalysis; using System.Text.Encodings.Web; using System.Text.Json; @@ -83,6 +84,7 @@ private static JsonSerializerOptions CreateDefaultOptions() [JsonSerializable(typeof(ServiceIdAgentSession.ServiceIdAgentSessionState))] [JsonSerializable(typeof(InMemoryAgentSession.InMemoryAgentSessionState))] [JsonSerializable(typeof(InMemoryChatHistoryProvider.State))] + [JsonSerializable(typeof(ConcurrentDictionary))] [ExcludeFromCodeCoverage] private sealed partial class JsonContext : JsonSerializerContext; diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs index 3efce9be17..722660d49e 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs @@ -53,6 +53,11 @@ protected AgentSession() { } + /// + /// Gets any arbitrary state associated with this session. + /// + public AgentSessionStateBag StateBag { get; protected set; } = new(); + /// Asks the for an object of the specified type . /// The type of object being requested. /// An optional key that can be used to help identify the target service. diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs new file mode 100644 index 0000000000..47eef61508 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs @@ -0,0 +1,180 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Concurrent; +using System.Text.Json; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Agents.AI; + +/// +/// Provides a thread-safe key-value store for managing session-scoped state with support for type-safe access and JSON +/// serialization options. +/// +/// +/// SessionState enables storing and retrieving objects associated with a session using string keys. +/// Values can be accessed in a type-safe manner and are serialized or deserialized using configurable JSON serializer +/// options. This class is designed for concurrent access and is safe to use across multiple threads. +/// +public class AgentSessionStateBag +{ + private readonly ConcurrentDictionary _state; + + /// + /// Initializes a new instance of the class. + /// + public AgentSessionStateBag() + { + this._state = new ConcurrentDictionary(); + } + + /// + /// Initializes a new instance of the class. + /// + /// The initial state dictionary. + internal AgentSessionStateBag(ConcurrentDictionary? state) + { + this._state = state ?? new ConcurrentDictionary(); + } + + /// + /// Tries to get a value from the session state. + /// + /// The type of the value to retrieve. + /// The key from which to retrieve the value. + /// The value if found and convertable to the required type; otherwise, null. + /// The JSON serializer options to use for serializing/deserialing the value. + /// if the value was successfully retrieved, otherwise. + public bool TryGetValue(string key, out T? value, JsonSerializerOptions? jsonSerializerOptions = null) + where T : class + { + _ = Throw.IfNullOrWhitespace(key); + var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; + + if (this._state.TryGetValue(key, out var stateValue)) + { + if (stateValue.DeserializedValue is T cachedValue) + { + value = cachedValue; + return true; + } + + switch (stateValue.JsonValue) + { + case T tValue: + value = tValue; + return true; + case JsonElement jsonElement when jsonElement.ValueKind == JsonValueKind.Null || jsonElement.ValueKind == JsonValueKind.Undefined: + value = null; + return false; + default: + T? result = stateValue.JsonValue.Deserialize(jso.GetTypeInfo(typeof(T))) as T; + if (result is null) + { + value = null; + return false; + } + + stateValue.DeserializedValue = result; + stateValue.ValueType = typeof(T); + stateValue.JsonSerializerOptions = jso; + + value = result; + return true; + } + } + value = null; + return false; + } + + /// + /// Gets a value from the session state. + /// + /// The type of value to get. + /// The key from which to retrieve the value. + /// The JSON serializer options to use for serializing/deserialing the value. + /// The retrieved value or null if not found. + /// The value could not be deserialized into the required type. + public T? GetValue(string key, JsonSerializerOptions? jsonSerializerOptions = null) + where T : class + { + _ = Throw.IfNullOrWhitespace(key); + var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; + + if (this._state.TryGetValue(key, out var stateValue)) + { + if (stateValue.DeserializedValue is T cachedValue) + { + return cachedValue; + } + + switch (stateValue.JsonValue) + { + case T tValue: + return tValue; + case JsonElement jsonElement when jsonElement.ValueKind == JsonValueKind.Null || jsonElement.ValueKind == JsonValueKind.Undefined: + return null; + default: + T? result = stateValue.JsonValue.Deserialize(jso.GetTypeInfo(typeof(T))) as T; + if (result is null) + { + throw new InvalidOperationException($"Failed to deserialize session state value to type {typeof(T).FullName}."); + } + stateValue.DeserializedValue = result; + stateValue.ValueType = typeof(T); + stateValue.JsonSerializerOptions = jso; + return result; + } + } + + return null; + } + + /// + /// Sets a value in the session state. + /// + /// The type of the value to set. + /// The key to store the value under. + /// The value to set. + /// The JSON serializer options to use for serializing the value. + public void SetValue(string key, T value, JsonSerializerOptions? jsonSerializerOptions = null) + where T : class + { + _ = Throw.IfNullOrWhitespace(key); + var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; + + var stateValue = this._state.GetOrAdd(key, _ => + new AgentSessionStateBagValue(value, typeof(T), jso)); + + stateValue.DeserializedValue = value; + stateValue.ValueType = typeof(T); + stateValue.JsonSerializerOptions = jso; + } + + /// + /// Serializes all session state values to a JSON object. + /// + /// A representing the serialized session state. + /// Thrown when a session state value is not properly initialized. + public JsonElement Serialize() + { + return JsonSerializer.SerializeToElement(this._state, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ConcurrentDictionary))); + } + + /// + /// Deserializes a JSON object into an instance. + /// + /// The element to deserialize. + /// The deserialized . + public static AgentSessionStateBag Deserialize(JsonElement jsonElement) + { + if (jsonElement.ValueKind is JsonValueKind.Undefined or JsonValueKind.Null) + { + return new AgentSessionStateBag(); + } + + return new AgentSessionStateBag( + jsonElement.Deserialize(AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ConcurrentDictionary))) as ConcurrentDictionary + ?? new ConcurrentDictionary()); + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs new file mode 100644 index 0000000000..a4c1743e77 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Microsoft.Agents.AI; + +/// +/// Used to store a value in session state. +/// +internal class AgentSessionStateBagValue +{ + /// + /// Initializes a new instance of the SessionStateValue class with the specified value. + /// + /// The serialized value to associate with the session state. + [JsonConstructor] + public AgentSessionStateBagValue(JsonElement jsonValue) + { + this.JsonValue = jsonValue; + } + + /// + /// Initializes a new instance of the SessionStateValue class with the specified value. + /// + /// The value to associate with the session state. Can be any object, including null. + /// The type of the value. + /// The JSON serializer options to use for serializing the value. + public AgentSessionStateBagValue(object deserializedValue, Type valueType, JsonSerializerOptions jsonSerializerOptions) + { + this.DeserializedValue = deserializedValue; + this.ValueType = valueType; + this.JsonSerializerOptions = jsonSerializerOptions; + } + + /// + /// Gets or sets the value associated with this instance. + /// + public JsonElement JsonValue + { + get + { + if (this.DeserializedValue != null) + { + if (this.ValueType is null || this.JsonSerializerOptions is null) + { + throw new InvalidOperationException($"{nameof(AgentSessionStateBagValue)} has not been properly initialized, please set {nameof(this.ValueType)} and {nameof(this.JsonSerializerOptions)} before accessing {nameof(this.JsonValue)}."); + } + + return JsonSerializer.SerializeToElement(this.DeserializedValue, this.JsonSerializerOptions.GetTypeInfo(this.ValueType)); + } + + return field; + } + set; + } + + [JsonIgnore] + public object? DeserializedValue { get; set; } + + [JsonIgnore] + public Type? ValueType { get; set; } + + [JsonIgnore] + public JsonSerializerOptions? JsonSerializerOptions { get; set; } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs index d809582ea4..cecfa92e8f 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs @@ -143,13 +143,30 @@ public sealed class InvokingContext /// /// Initializes a new instance of the class with the specified request messages. /// + /// The agent being invoked. + /// The session associated with the agent invocation. /// The new messages to be used by the agent for this invocation. /// is . - public InvokingContext(IEnumerable requestMessages) + public InvokingContext( + AIAgent agent, + AgentSession? session, + IEnumerable requestMessages) { + this.Agent = Throw.IfNull(agent); + this.Session = session; this.RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages)); } + /// + /// Gets the agent that is being invoked. + /// + public AIAgent Agent { get; } + + /// + /// Gets the agent session associated with the agent invocation. + /// + public AgentSession? Session { get; } + /// /// Gets the caller provided messages that will be used by the agent for this invocation. /// @@ -172,15 +189,33 @@ public sealed class InvokedContext /// /// Initializes a new instance of the class with the specified request messages. /// + /// The agent being invoked. + /// The session associated with the agent invocation. /// The caller provided messages that were used by the agent for this invocation. /// The messages retrieved from the for this invocation. /// is . - public InvokedContext(IEnumerable requestMessages, IEnumerable? chatHistoryProviderMessages) + public InvokedContext( + AIAgent agent, + AgentSession? session, + IEnumerable requestMessages, + IEnumerable? chatHistoryProviderMessages) { + this.Agent = Throw.IfNull(agent); + this.Session = session; this.RequestMessages = Throw.IfNull(requestMessages); this.ChatHistoryProviderMessages = chatHistoryProviderMessages; } + /// + /// Gets the agent that is being invoked. + /// + public AIAgent Agent { get; } + + /// + /// Gets the agent session associated with the agent invocation. + /// + public AgentSession? Session { get; } + /// /// Gets the caller provided messages that were used by the agent for this invocation. /// diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index ee6db4830d..23e120f14f 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -231,8 +231,8 @@ protected override async IAsyncEnumerable RunCoreStreamingA } catch (Exception ex) { - await NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false); - await NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); + await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false); + await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); throw; } @@ -246,8 +246,8 @@ protected override async IAsyncEnumerable RunCoreStreamingA } catch (Exception ex) { - await NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false); - await NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); + await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false); + await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); throw; } @@ -273,8 +273,8 @@ protected override async IAsyncEnumerable RunCoreStreamingA } catch (Exception ex) { - await NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false); - await NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); + await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false); + await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); throw; } } @@ -286,10 +286,10 @@ protected override async IAsyncEnumerable RunCoreStreamingA await this.UpdateSessionWithTypeAndConversationIdAsync(safeSession, chatResponse.ConversationId, cancellationToken).ConfigureAwait(false); // To avoid inconsistent state we only notify the session of the input messages if no error occurs after the initial request. - await NotifyChatHistoryProviderOfNewMessagesAsync(safeSession, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false); + await this.NotifyChatHistoryProviderOfNewMessagesAsync(safeSession, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false); // Notify the AIContextProvider of all new messages. - await NotifyAIContextProviderOfSuccessAsync(safeSession, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false); + await this.NotifyAIContextProviderOfSuccessAsync(safeSession, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false); } /// @@ -455,8 +455,8 @@ private async Task RunCoreAsync RunCoreAsync RunCoreAsync /// Notify the when an agent run succeeded, if there is an . /// - private static async Task NotifyAIContextProviderOfSuccessAsync( + private async Task NotifyAIContextProviderOfSuccessAsync( ChatClientAgentSession session, IEnumerable inputMessages, IList? aiContextProviderMessages, @@ -497,7 +497,7 @@ private static async Task NotifyAIContextProviderOfSuccessAsync( { if (session.AIContextProvider is not null) { - await session.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProviderMessages) { ResponseMessages = responseMessages }, + await session.AIContextProvider.InvokedAsync(new(this, session, inputMessages, aiContextProviderMessages) { ResponseMessages = responseMessages }, cancellationToken).ConfigureAwait(false); } } @@ -505,7 +505,7 @@ await session.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProvide /// /// Notify the of any failure during an agent run, if there is an . /// - private static async Task NotifyAIContextProviderOfFailureAsync( + private async Task NotifyAIContextProviderOfFailureAsync( ChatClientAgentSession session, Exception ex, IEnumerable inputMessages, @@ -514,7 +514,7 @@ private static async Task NotifyAIContextProviderOfFailureAsync( { if (session.AIContextProvider is not null) { - await session.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProviderMessages) { InvokeException = ex }, + await session.AIContextProvider.InvokedAsync(new(this, session, inputMessages, aiContextProviderMessages) { InvokeException = ex }, cancellationToken).ConfigureAwait(false); } } @@ -726,7 +726,7 @@ private async Task // Add any existing messages from the session to the messages to be sent to the chat client. if (chatHistoryProvider is not null) { - var invokingContext = new ChatHistoryProvider.InvokingContext(inputMessages); + var invokingContext = new ChatHistoryProvider.InvokingContext(this, typedSession, inputMessages); var providerMessages = await chatHistoryProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false); inputMessagesForChatClient.AddRange(providerMessages); chatHistoryProviderMessages = providerMessages as IList ?? providerMessages.ToList(); @@ -739,7 +739,7 @@ private async Task // messages and options with the additional context. if (typedSession.AIContextProvider is not null) { - var invokingContext = new AIContextProvider.InvokingContext(inputMessages); + var invokingContext = new AIContextProvider.InvokingContext(this, typedSession, inputMessages); var aiContext = await typedSession.AIContextProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false); if (aiContext.Messages is { Count: > 0 }) { @@ -812,7 +812,7 @@ private async Task UpdateSessionWithTypeAndConversationIdAsync(ChatClientAgentSe } } - private static Task NotifyChatHistoryProviderOfFailureAsync( + private Task NotifyChatHistoryProviderOfFailureAsync( ChatClientAgentSession session, Exception ex, IEnumerable requestMessages, @@ -827,7 +827,7 @@ private static Task NotifyChatHistoryProviderOfFailureAsync( // If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages. if (provider is not null) { - var invokedContext = new ChatHistoryProvider.InvokedContext(requestMessages, chatHistoryProviderMessages!) + var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, requestMessages, chatHistoryProviderMessages!) { AIContextProviderMessages = aiContextProviderMessages, InvokeException = ex @@ -839,7 +839,7 @@ private static Task NotifyChatHistoryProviderOfFailureAsync( return Task.CompletedTask; } - private static Task NotifyChatHistoryProviderOfNewMessagesAsync( + private Task NotifyChatHistoryProviderOfNewMessagesAsync( ChatClientAgentSession session, IEnumerable requestMessages, IEnumerable? chatHistoryProviderMessages, @@ -854,7 +854,7 @@ private static Task NotifyChatHistoryProviderOfNewMessagesAsync( // If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages. if (provider is not null) { - var invokedContext = new ChatHistoryProvider.InvokedContext(requestMessages, chatHistoryProviderMessages!) + var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, requestMessages, chatHistoryProviderMessages!) { AIContextProviderMessages = aiContextProviderMessages, ResponseMessages = responseMessages diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs index 1a79ae64d1..cb791bd564 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs @@ -148,6 +148,8 @@ internal static async Task DeserializeAsync( ? await aiContextProviderFactory.Invoke(state?.AIContextProviderState ?? default, jsonSerializerOptions, cancellationToken).ConfigureAwait(false) : null; + session.StateBag = AgentSessionStateBag.Deserialize(state?.StateBag ?? default); + if (state?.ConversationId is string sessionId) { session.ConversationId = sessionId; @@ -176,6 +178,7 @@ internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = nu ConversationId = this.ConversationId, ChatHistoryProviderState = chatHistoryProviderState is { ValueKind: not JsonValueKind.Undefined } ? chatHistoryProviderState : null, AIContextProviderState = aiContextProviderState is { ValueKind: not JsonValueKind.Undefined } ? aiContextProviderState : null, + StateBag = this.StateBag.Serialize(), }; return JsonSerializer.SerializeToElement(state, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(SessionState))); @@ -201,5 +204,7 @@ internal sealed class SessionState public JsonElement? ChatHistoryProviderState { get; set; } public JsonElement? AIContextProviderState { get; set; } + + public JsonElement? StateBag { get; set; } } } diff --git a/dotnet/tests/AgentConformance.IntegrationTests/IAgentFixture.cs b/dotnet/tests/AgentConformance.IntegrationTests/IAgentFixture.cs index 96b40d561b..5548c5aaf9 100644 --- a/dotnet/tests/AgentConformance.IntegrationTests/IAgentFixture.cs +++ b/dotnet/tests/AgentConformance.IntegrationTests/IAgentFixture.cs @@ -15,7 +15,7 @@ public interface IAgentFixture : IAsyncLifetime { AIAgent Agent { get; } - Task> GetChatHistoryAsync(AgentSession session); + Task> GetChatHistoryAsync(AIAgent agent, AgentSession session); Task DeleteSessionAsync(AgentSession session); } diff --git a/dotnet/tests/AgentConformance.IntegrationTests/RunStreamingTests.cs b/dotnet/tests/AgentConformance.IntegrationTests/RunStreamingTests.cs index f9cc732175..18982baaad 100644 --- a/dotnet/tests/AgentConformance.IntegrationTests/RunStreamingTests.cs +++ b/dotnet/tests/AgentConformance.IntegrationTests/RunStreamingTests.cs @@ -106,7 +106,7 @@ public virtual async Task SessionMaintainsHistoryAsync() Assert.Contains("Paris", response1Text); Assert.Contains("Vienna", response2Text); - var chatHistory = await this.Fixture.GetChatHistoryAsync(session); + var chatHistory = await this.Fixture.GetChatHistoryAsync(agent, session); Assert.Equal(4, chatHistory.Count); Assert.Equal(2, chatHistory.Count(x => x.Role == ChatRole.User)); Assert.Equal(2, chatHistory.Count(x => x.Role == ChatRole.Assistant)); diff --git a/dotnet/tests/AgentConformance.IntegrationTests/RunTests.cs b/dotnet/tests/AgentConformance.IntegrationTests/RunTests.cs index 302784a1a8..da1cebaf52 100644 --- a/dotnet/tests/AgentConformance.IntegrationTests/RunTests.cs +++ b/dotnet/tests/AgentConformance.IntegrationTests/RunTests.cs @@ -111,7 +111,7 @@ public virtual async Task SessionMaintainsHistoryAsync() Assert.Contains("Paris", result1.Text); Assert.Contains("Vienna", result2.Text); - var chatHistory = await this.Fixture.GetChatHistoryAsync(session); + var chatHistory = await this.Fixture.GetChatHistoryAsync(agent, session); Assert.Equal(4, chatHistory.Count); Assert.Equal(2, chatHistory.Count(x => x.Role == ChatRole.User)); Assert.Equal(2, chatHistory.Count(x => x.Role == ChatRole.Assistant)); diff --git a/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs b/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs index 5f0fcbca2c..a0e4b64763 100644 --- a/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs +++ b/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs @@ -35,7 +35,7 @@ public AnthropicChatCompletionFixture(bool useReasoningChatModel, bool useBeta) public IChatClient ChatClient => this._agent.ChatClient; - public async Task> GetChatHistoryAsync(AgentSession session) + public async Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) { var typedSession = (ChatClientAgentSession)session; @@ -44,7 +44,7 @@ public async Task> GetChatHistoryAsync(AgentSession session) return []; } - return (await typedSession.ChatHistoryProvider.InvokingAsync(new([]))).ToList(); + return (await typedSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); } public Task CreateChatClientAgentAsync( diff --git a/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs b/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs index 4b78d30f1c..c655fd7a58 100644 --- a/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs +++ b/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs @@ -33,7 +33,7 @@ public async Task CreateConversationAsync() return response.Value.Id; } - public async Task> GetChatHistoryAsync(AgentSession session) + public async Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) { var chatClientSession = (ChatClientAgentSession)session; @@ -53,7 +53,7 @@ public async Task> GetChatHistoryAsync(AgentSession session) return []; } - return (await chatClientSession.ChatHistoryProvider.InvokingAsync(new([]))).ToList(); + return (await chatClientSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); } private async Task> GetChatHistoryFromResponsesChainAsync(string conversationId) diff --git a/dotnet/tests/AzureAIAgentsPersistent.IntegrationTests/AzureAIAgentsPersistentFixture.cs b/dotnet/tests/AzureAIAgentsPersistent.IntegrationTests/AzureAIAgentsPersistentFixture.cs index 2f59630c38..3e3272d951 100644 --- a/dotnet/tests/AzureAIAgentsPersistent.IntegrationTests/AzureAIAgentsPersistentFixture.cs +++ b/dotnet/tests/AzureAIAgentsPersistent.IntegrationTests/AzureAIAgentsPersistentFixture.cs @@ -24,7 +24,7 @@ public class AzureAIAgentsPersistentFixture : IChatClientAgentFixture public AIAgent Agent => this._agent; - public async Task> GetChatHistoryAsync(AgentSession session) + public async Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) { List messages = []; var typedSession = (ChatClientAgentSession)session; diff --git a/dotnet/tests/CopilotStudio.IntegrationTests/CopilotStudioFixture.cs b/dotnet/tests/CopilotStudio.IntegrationTests/CopilotStudioFixture.cs index 3b3ac7ff7e..dd5fe46ecc 100644 --- a/dotnet/tests/CopilotStudio.IntegrationTests/CopilotStudioFixture.cs +++ b/dotnet/tests/CopilotStudio.IntegrationTests/CopilotStudioFixture.cs @@ -20,7 +20,7 @@ public class CopilotStudioFixture : IAgentFixture { public AIAgent Agent { get; private set; } = null!; - public Task> GetChatHistoryAsync(AgentSession session) => + public Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) => throw new NotSupportedException("CopilotStudio doesn't allow retrieval of chat history."); public Task DeleteSessionAsync(AgentSession session) => diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs index b6aabd081e..94aa73858a 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs @@ -6,17 +6,21 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; +using Moq; namespace Microsoft.Agents.AI.Abstractions.UnitTests; public class AIContextProviderTests { + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + [Fact] public async Task InvokedAsync_ReturnsCompletedTaskAsync() { var provider = new TestAIContextProvider(); var messages = new ReadOnlyCollection([]); - var task = provider.InvokedAsync(new(messages, aiContextProviderMessages: null)); + var task = provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null)); Assert.Equal(default, task); } @@ -31,13 +35,13 @@ public void Serialize_ReturnsEmptyElement() [Fact] public void InvokingContext_Constructor_ThrowsForNullMessages() { - Assert.Throws(() => new AIContextProvider.InvokingContext(null!)); + Assert.Throws(() => new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, null!)); } [Fact] public void InvokedContext_Constructor_ThrowsForNullMessages() { - Assert.Throws(() => new AIContextProvider.InvokedContext(null!, aiContextProviderMessages: null)); + Assert.Throws(() => new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, null!, aiContextProviderMessages: null)); } #region GetService Method Tests @@ -163,7 +167,7 @@ public void InvokingContext_RequestMessages_SetterThrowsForNull() { // Arrange var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); - var context = new AIContextProvider.InvokingContext(messages); + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, messages); // Act & Assert Assert.Throws(() => context.RequestMessages = null!); @@ -175,7 +179,7 @@ public void InvokingContext_RequestMessages_SetterRoundtrips() // Arrange var initialMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); var newMessages = new List { new(ChatRole.User, "New message") }; - var context = new AIContextProvider.InvokingContext(initialMessages); + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, initialMessages); // Act context.RequestMessages = newMessages; @@ -184,6 +188,55 @@ public void InvokingContext_RequestMessages_SetterRoundtrips() Assert.Same(newMessages, context.RequestMessages); } + [Fact] + public void InvokingContext_Agent_ReturnsConstructorValue() + { + // Arrange + var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + + // Act + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, messages); + + // Assert + Assert.Same(s_mockAgent, context.Agent); + } + + [Fact] + public void InvokingContext_Session_ReturnsConstructorValue() + { + // Arrange + var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + + // Act + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, messages); + + // Assert + Assert.Same(s_mockSession, context.Session); + } + + [Fact] + public void InvokingContext_Session_CanBeNull() + { + // Arrange + var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + + // Act + var context = new AIContextProvider.InvokingContext(s_mockAgent, null, messages); + + // Assert + Assert.Null(context.Session); + } + + [Fact] + public void InvokingContext_Constructor_ThrowsForNullAgent() + { + // Arrange + var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + + // Act & Assert + Assert.Throws(() => new AIContextProvider.InvokingContext(null!, s_mockSession, messages)); + } + #endregion #region InvokedContext Tests @@ -193,7 +246,7 @@ public void InvokedContext_RequestMessages_SetterThrowsForNull() { // Arrange var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); - var context = new AIContextProvider.InvokedContext(messages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null); // Act & Assert Assert.Throws(() => context.RequestMessages = null!); @@ -205,7 +258,7 @@ public void InvokedContext_RequestMessages_SetterRoundtrips() // Arrange var initialMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); var newMessages = new List { new(ChatRole.User, "New message") }; - var context = new AIContextProvider.InvokedContext(initialMessages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, initialMessages, aiContextProviderMessages: null); // Act context.RequestMessages = newMessages; @@ -220,7 +273,7 @@ public void InvokedContext_AIContextProviderMessages_Roundtrips() // Arrange var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); var aiContextMessages = new List { new(ChatRole.System, "AI context message") }; - var context = new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); // Act context.AIContextProviderMessages = aiContextMessages; @@ -235,7 +288,7 @@ public void InvokedContext_ResponseMessages_Roundtrips() // Arrange var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); var responseMessages = new List { new(ChatRole.Assistant, "Response message") }; - var context = new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); // Act context.ResponseMessages = responseMessages; @@ -250,7 +303,7 @@ public void InvokedContext_InvokeException_Roundtrips() // Arrange var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); var exception = new InvalidOperationException("Test exception"); - var context = new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); // Act context.InvokeException = exception; @@ -259,6 +312,55 @@ public void InvokedContext_InvokeException_Roundtrips() Assert.Same(exception, context.InvokeException); } + [Fact] + public void InvokedContext_Agent_ReturnsConstructorValue() + { + // Arrange + var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + + // Act + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); + + // Assert + Assert.Same(s_mockAgent, context.Agent); + } + + [Fact] + public void InvokedContext_Session_ReturnsConstructorValue() + { + // Arrange + var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + + // Act + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); + + // Assert + Assert.Same(s_mockSession, context.Session); + } + + [Fact] + public void InvokedContext_Session_CanBeNull() + { + // Arrange + var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + + // Act + var context = new AIContextProvider.InvokedContext(s_mockAgent, null, requestMessages, aiContextProviderMessages: null); + + // Assert + Assert.Null(context.Session); + } + + [Fact] + public void InvokedContext_Constructor_ThrowsForNullAgent() + { + // Arrange + var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + + // Act & Assert + Assert.Throws(() => new AIContextProvider.InvokedContext(null!, s_mockSession, requestMessages, aiContextProviderMessages: null)); + } + #endregion private sealed class TestAIContextProvider : AIContextProvider diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs new file mode 100644 index 0000000000..5d07a4749f --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs @@ -0,0 +1,407 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json; +using Microsoft.Agents.AI.Abstractions.UnitTests.Models; + +namespace Microsoft.Agents.AI.Abstractions.UnitTests; + +/// +/// Contains tests for the class. +/// +public sealed class AgentSessionStateBagTests +{ + #region Constructor Tests + + [Fact] + public void Constructor_Default_CreatesEmptyStateBag() + { + // Act + var stateBag = new AgentSessionStateBag(); + + // Assert + Assert.False(stateBag.TryGetValue("nonexistent", out _)); + } + + #endregion + + #region SetValue Tests + + [Fact] + public void SetValue_WithValidKeyAndValue_StoresValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + stateBag.SetValue("key1", "value1"); + + // Assert + Assert.True(stateBag.TryGetValue("key1", out var result)); + Assert.Equal("value1", result); + } + + [Fact] + public void SetValue_WithNullKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.SetValue(null!, "value")); + } + + [Fact] + public void SetValue_WithEmptyKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.SetValue("", "value")); + } + + [Fact] + public void SetValue_WithWhitespaceKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.SetValue(" ", "value")); + } + + [Fact] + public void SetValue_OverwritesExistingValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "originalValue"); + + // Act + stateBag.SetValue("key1", "newValue"); + + // Assert + Assert.Equal("newValue", stateBag.GetValue("key1")); + } + + #endregion + + #region GetValue Tests + + [Fact] + public void GetValue_WithExistingKey_ReturnsValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "value1"); + + // Act + var result = stateBag.GetValue("key1"); + + // Assert + Assert.Equal("value1", result); + } + + [Fact] + public void GetValue_WithNonexistentKey_ReturnsNull() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + var result = stateBag.GetValue("nonexistent"); + + // Assert + Assert.Null(result); + } + + [Fact] + public void GetValue_WithNullKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.GetValue(null!)); + } + + [Fact] + public void GetValue_WithEmptyKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.GetValue("")); + } + + [Fact] + public void GetValue_CachesDeserializedValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "value1"); + + // Act + var result1 = stateBag.GetValue("key1"); + var result2 = stateBag.GetValue("key1"); + + // Assert + Assert.Same(result1, result2); + } + + #endregion + + #region TryGetValue Tests + + [Fact] + public void TryGetValue_WithExistingKey_ReturnsTrueAndValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "value1"); + + // Act + var found = stateBag.TryGetValue("key1", out var result); + + // Assert + Assert.True(found); + Assert.Equal("value1", result); + } + + [Fact] + public void TryGetValue_WithNonexistentKey_ReturnsFalseAndNull() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + var found = stateBag.TryGetValue("nonexistent", out var result); + + // Assert + Assert.False(found); + Assert.Null(result); + } + + [Fact] + public void TryGetValue_WithNullKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.TryGetValue(null!, out _)); + } + + [Fact] + public void TryGetValue_WithEmptyKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.TryGetValue("", out _)); + } + + #endregion + + #region Serialize/Deserialize Tests + + [Fact] + public void Serialize_EmptyStateBag_ReturnsEmptyObject() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + var json = stateBag.Serialize(); + + // Assert + Assert.Equal(JsonValueKind.Object, json.ValueKind); + } + + [Fact] + public void Serialize_WithStringValue_ReturnsJsonWithValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("stringKey", "stringValue"); + + // Act + var json = stateBag.Serialize(); + + // Assert + Assert.Equal(JsonValueKind.Object, json.ValueKind); + Assert.True(json.TryGetProperty("stringKey", out _)); + } + + [Fact] + public void Deserialize_FromJsonDocument_ReturnsEmptyStateBag() + { + // Arrange + var emptyJson = JsonDocument.Parse("{}").RootElement; + + // Act + var stateBag = AgentSessionStateBag.Deserialize(emptyJson); + + // Assert + Assert.False(stateBag.TryGetValue("nonexistent", out _)); + } + + [Fact] + public void Deserialize_NullElement_ReturnsEmptyStateBag() + { + // Arrange + var nullJson = default(JsonElement); + + // Act + var stateBag = AgentSessionStateBag.Deserialize(nullJson); + + // Assert + Assert.False(stateBag.TryGetValue("nonexistent", out _)); + } + + [Fact] + public void SerializeDeserialize_WithStringValue_Roundtrips() + { + // Arrange + var originalStateBag = new AgentSessionStateBag(); + originalStateBag.SetValue("stringKey", "stringValue"); + + // Act + var json = originalStateBag.Serialize(); + var restoredStateBag = AgentSessionStateBag.Deserialize(json); + + // Assert + Assert.Equal("stringValue", restoredStateBag.GetValue("stringKey")); + } + + #endregion + + #region Thread Safety Tests + + [Fact] + public async System.Threading.Tasks.Task SetValue_MultipleConcurrentWrites_DoesNotThrowAsync() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + var tasks = new System.Threading.Tasks.Task[100]; + + // Act + for (int i = 0; i < 100; i++) + { + int index = i; + tasks[i] = System.Threading.Tasks.Task.Run(() => stateBag.SetValue($"key{index}", $"value{index}")); + } + + await System.Threading.Tasks.Task.WhenAll(tasks); + + // Assert + for (int i = 0; i < 100; i++) + { + Assert.True(stateBag.TryGetValue($"key{i}", out var value)); + Assert.Equal($"value{i}", value); + } + } + + #endregion + + #region Complex Object Tests + + [Fact] + public void SetValue_WithComplexObject_StoresValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + var animal = new Animal { Id = 1, FullName = "Buddy", Species = Species.Bear }; + + // Act + stateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); + + // Assert + Animal? result = stateBag.GetValue("animal", TestJsonSerializerContext.Default.Options); + Assert.NotNull(result); + Assert.Equal(1, result.Id); + Assert.Equal("Buddy", result.FullName); + Assert.Equal(Species.Bear, result.Species); + } + + [Fact] + public void GetValue_WithComplexObject_CachesDeserializedValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + var animal = new Animal { Id = 2, FullName = "Whiskers", Species = Species.Tiger }; + stateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); + + // Act + Animal? result1 = stateBag.GetValue("animal", TestJsonSerializerContext.Default.Options); + Animal? result2 = stateBag.GetValue("animal", TestJsonSerializerContext.Default.Options); + + // Assert + Assert.Same(result1, result2); + } + + [Fact] + public void TryGetValue_WithComplexObject_ReturnsTrueAndValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + var animal = new Animal { Id = 3, FullName = "Goldie", Species = Species.Walrus }; + stateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); + + // Act + bool found = stateBag.TryGetValue("animal", out Animal? result, TestJsonSerializerContext.Default.Options); + + // Assert + Assert.True(found); + Assert.NotNull(result); + Assert.Equal(3, result.Id); + Assert.Equal("Goldie", result.FullName); + Assert.Equal(Species.Walrus, result.Species); + } + + [Fact] + public void SerializeDeserialize_WithComplexObject_Roundtrips() + { + // Arrange + var originalStateBag = new AgentSessionStateBag(); + var animal = new Animal { Id = 4, FullName = "Polly", Species = Species.Bear }; + originalStateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); + + // Act + JsonElement json = originalStateBag.Serialize(); + AgentSessionStateBag restoredStateBag = AgentSessionStateBag.Deserialize(json); + + // Assert + Animal? restoredAnimal = restoredStateBag.GetValue("animal", TestJsonSerializerContext.Default.Options); + Assert.NotNull(restoredAnimal); + Assert.Equal(4, restoredAnimal.Id); + Assert.Equal("Polly", restoredAnimal.FullName); + Assert.Equal(Species.Bear, restoredAnimal.Species); + } + + [Fact] + public void Serialize_WithComplexObject_ReturnsJsonWithProperties() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + var animal = new Animal { Id = 7, FullName = "Spot", Species = Species.Walrus }; + stateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); + + // Act + JsonElement json = stateBag.Serialize(); + + // Assert + Assert.Equal(JsonValueKind.Object, json.ValueKind); + Assert.True(json.TryGetProperty("animal", out JsonElement animalElement)); + Assert.True(animalElement.TryGetProperty("jsonValue", out JsonElement jsonValueElement)); + Assert.Equal(JsonValueKind.Object, jsonValueElement.ValueKind); + Assert.Equal(7, jsonValueElement.GetProperty("id").GetInt32()); + Assert.Equal("Spot", jsonValueElement.GetProperty("fullName").GetString()); + Assert.Equal("Walrus", jsonValueElement.GetProperty("species").GetString()); + } + + #endregion +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionTests.cs index 5a776c9fb0..b80f0a4fd2 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionTests.cs @@ -11,6 +11,21 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; /// public class AgentSessionTests { + #region StateBag Tests + + [Fact] + public void StateBag_Values_Roundtrips() + { + // Arrange + var session = new TestAgentSession(); + + // Act & Assert + session.StateBag.SetValue("key1", "value1"); + Assert.Equal("value1", session.StateBag.GetValue("key1")); + } + + #endregion + #region GetService Method Tests /// diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs index 84a0242320..a74906c801 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs @@ -14,6 +14,9 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; /// public sealed class ChatHistoryProviderExtensionsTests { + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + [Fact] public void WithMessageFilters_ReturnsChatHistoryProviderMessageFilter() { @@ -35,7 +38,7 @@ public async Task WithMessageFilters_InvokingFilter_IsAppliedAsync() // Arrange Mock providerMock = new(); List innerMessages = [new(ChatRole.User, "Hello"), new(ChatRole.Assistant, "Hi")]; - ChatHistoryProvider.InvokingContext context = new([new ChatMessage(ChatRole.User, "Test")]); + ChatHistoryProvider.InvokingContext context = new(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]); providerMock .Setup(p => p.InvokingAsync(context, It.IsAny())) @@ -59,7 +62,7 @@ public async Task WithMessageFilters_InvokedFilter_IsAppliedAsync() Mock providerMock = new(); List requestMessages = [new(ChatRole.User, "Hello")]; List chatHistoryProviderMessages = [new(ChatRole.System, "System")]; - ChatHistoryProvider.InvokedContext context = new(requestMessages, chatHistoryProviderMessages) + ChatHistoryProvider.InvokedContext context = new(s_mockAgent, s_mockSession, requestMessages, chatHistoryProviderMessages) { ResponseMessages = [new ChatMessage(ChatRole.Assistant, "Response")] }; @@ -106,7 +109,7 @@ public async Task WithAIContextProviderMessageRemoval_RemovesAIContextProviderMe List requestMessages = [new(ChatRole.User, "Hello")]; List chatHistoryProviderMessages = [new(ChatRole.System, "System")]; List aiContextProviderMessages = [new(ChatRole.System, "Context")]; - ChatHistoryProvider.InvokedContext context = new(requestMessages, chatHistoryProviderMessages) + ChatHistoryProvider.InvokedContext context = new(s_mockAgent, s_mockSession, requestMessages, chatHistoryProviderMessages) { AIContextProviderMessages = aiContextProviderMessages }; diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs index 43a3e78f10..4b955a43c0 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs @@ -16,6 +16,9 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; /// public sealed class ChatHistoryProviderMessageFilterTests { + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + [Fact] public void Constructor_WithNullInnerProvider_ThrowsArgumentNullException() { @@ -59,7 +62,7 @@ public async Task InvokingAsync_WithNoOpFilters_ReturnsInnerProviderMessagesAsyn new(ChatRole.User, "Hello"), new(ChatRole.Assistant, "Hi there!") }; - var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Test")]); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]); innerProviderMock .Setup(s => s.InvokingAsync(context, It.IsAny())) @@ -88,7 +91,7 @@ public async Task InvokingAsync_WithInvokingFilter_AppliesFilterAsync() new(ChatRole.Assistant, "Hi there!"), new(ChatRole.User, "How are you?") }; - var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Test")]); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]); innerProviderMock .Setup(s => s.InvokingAsync(context, It.IsAny())) @@ -118,7 +121,7 @@ public async Task InvokingAsync_WithInvokingFilter_CanModifyMessagesAsync() new(ChatRole.User, "Hello"), new(ChatRole.Assistant, "Hi there!") }; - var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Test")]); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]); innerProviderMock .Setup(s => s.InvokingAsync(context, It.IsAny())) @@ -147,7 +150,7 @@ public async Task InvokedAsync_WithInvokedFilter_AppliesFilterAsync() var requestMessages = new List { new(ChatRole.User, "Hello") }; var chatHistoryProviderMessages = new List { new(ChatRole.System, "System") }; var responseMessages = new List { new(ChatRole.Assistant, "Response") }; - var context = new ChatHistoryProvider.InvokedContext(requestMessages, chatHistoryProviderMessages) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, chatHistoryProviderMessages) { ResponseMessages = responseMessages }; @@ -162,7 +165,7 @@ public async Task InvokedAsync_WithInvokedFilter_AppliesFilterAsync() ChatHistoryProvider.InvokedContext InvokedFilter(ChatHistoryProvider.InvokedContext ctx) { var modifiedRequestMessages = ctx.RequestMessages.Select(m => new ChatMessage(m.Role, $"[FILTERED] {m.Text}")).ToList(); - return new ChatHistoryProvider.InvokedContext(modifiedRequestMessages, ctx.ChatHistoryProviderMessages) + return new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, modifiedRequestMessages, ctx.ChatHistoryProviderMessages) { ResponseMessages = ctx.ResponseMessages, AIContextProviderMessages = ctx.AIContextProviderMessages, diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs index a26ef199d9..5e0fbe9817 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs @@ -6,6 +6,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; +using Moq; namespace Microsoft.Agents.AI.Abstractions.UnitTests; @@ -14,6 +15,9 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; /// public class ChatHistoryProviderTests { + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + #region GetService Method Tests [Fact] @@ -82,7 +86,7 @@ public void GetService_Generic_ReturnsNullForUnrelatedType() public void InvokingContext_Constructor_ThrowsForNullMessages() { // Arrange & Act & Assert - Assert.Throws(() => new ChatHistoryProvider.InvokingContext(null!)); + Assert.Throws(() => new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, null!)); } [Fact] @@ -90,7 +94,7 @@ public void InvokingContext_RequestMessages_SetterThrowsForNull() { // Arrange var messages = new List { new(ChatRole.User, "Hello") }; - var context = new ChatHistoryProvider.InvokingContext(messages); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, messages); // Act & Assert Assert.Throws(() => context.RequestMessages = null!); @@ -102,7 +106,7 @@ public void InvokingContext_RequestMessages_SetterRoundtrips() // Arrange var initialMessages = new List { new(ChatRole.User, "Hello") }; var newMessages = new List { new(ChatRole.User, "New message") }; - var context = new ChatHistoryProvider.InvokingContext(initialMessages); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, initialMessages); // Act context.RequestMessages = newMessages; @@ -111,6 +115,55 @@ public void InvokingContext_RequestMessages_SetterRoundtrips() Assert.Same(newMessages, context.RequestMessages); } + [Fact] + public void InvokingContext_Agent_ReturnsConstructorValue() + { + // Arrange + var messages = new List { new(ChatRole.User, "Hello") }; + + // Act + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, messages); + + // Assert + Assert.Same(s_mockAgent, context.Agent); + } + + [Fact] + public void InvokingContext_Session_ReturnsConstructorValue() + { + // Arrange + var messages = new List { new(ChatRole.User, "Hello") }; + + // Act + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, messages); + + // Assert + Assert.Same(s_mockSession, context.Session); + } + + [Fact] + public void InvokingContext_Session_CanBeNull() + { + // Arrange + var messages = new List { new(ChatRole.User, "Hello") }; + + // Act + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, null, messages); + + // Assert + Assert.Null(context.Session); + } + + [Fact] + public void InvokingContext_Constructor_ThrowsForNullAgent() + { + // Arrange + var messages = new List { new(ChatRole.User, "Hello") }; + + // Act & Assert + Assert.Throws(() => new ChatHistoryProvider.InvokingContext(null!, s_mockSession, messages)); + } + #endregion #region InvokedContext Tests @@ -119,7 +172,7 @@ public void InvokingContext_RequestMessages_SetterRoundtrips() public void InvokedContext_Constructor_ThrowsForNullRequestMessages() { // Arrange & Act & Assert - Assert.Throws(() => new ChatHistoryProvider.InvokedContext(null!, [])); + Assert.Throws(() => new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, null!, [])); } [Fact] @@ -127,7 +180,7 @@ public void InvokedContext_RequestMessages_SetterThrowsForNull() { // Arrange var requestMessages = new List { new(ChatRole.User, "Hello") }; - var context = new ChatHistoryProvider.InvokedContext(requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); // Act & Assert Assert.Throws(() => context.RequestMessages = null!); @@ -139,7 +192,7 @@ public void InvokedContext_RequestMessages_SetterRoundtrips() // Arrange var initialMessages = new List { new(ChatRole.User, "Hello") }; var newMessages = new List { new(ChatRole.User, "New message") }; - var context = new ChatHistoryProvider.InvokedContext(initialMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, initialMessages, []); // Act context.RequestMessages = newMessages; @@ -154,7 +207,7 @@ public void InvokedContext_ChatHistoryProviderMessages_SetterRoundtrips() // Arrange var requestMessages = new List { new(ChatRole.User, "Hello") }; var newProviderMessages = new List { new(ChatRole.System, "System message") }; - var context = new ChatHistoryProvider.InvokedContext(requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); // Act context.ChatHistoryProviderMessages = newProviderMessages; @@ -169,7 +222,7 @@ public void InvokedContext_AIContextProviderMessages_Roundtrips() // Arrange var requestMessages = new List { new(ChatRole.User, "Hello") }; var aiContextMessages = new List { new(ChatRole.System, "AI context message") }; - var context = new ChatHistoryProvider.InvokedContext(requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); // Act context.AIContextProviderMessages = aiContextMessages; @@ -184,7 +237,7 @@ public void InvokedContext_ResponseMessages_Roundtrips() // Arrange var requestMessages = new List { new(ChatRole.User, "Hello") }; var responseMessages = new List { new(ChatRole.Assistant, "Response message") }; - var context = new ChatHistoryProvider.InvokedContext(requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); // Act context.ResponseMessages = responseMessages; @@ -199,7 +252,7 @@ public void InvokedContext_InvokeException_Roundtrips() // Arrange var requestMessages = new List { new(ChatRole.User, "Hello") }; var exception = new InvalidOperationException("Test exception"); - var context = new ChatHistoryProvider.InvokedContext(requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); // Act context.InvokeException = exception; @@ -208,6 +261,55 @@ public void InvokedContext_InvokeException_Roundtrips() Assert.Same(exception, context.InvokeException); } + [Fact] + public void InvokedContext_Agent_ReturnsConstructorValue() + { + // Arrange + var requestMessages = new List { new(ChatRole.User, "Hello") }; + + // Act + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); + + // Assert + Assert.Same(s_mockAgent, context.Agent); + } + + [Fact] + public void InvokedContext_Session_ReturnsConstructorValue() + { + // Arrange + var requestMessages = new List { new(ChatRole.User, "Hello") }; + + // Act + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); + + // Assert + Assert.Same(s_mockSession, context.Session); + } + + [Fact] + public void InvokedContext_Session_CanBeNull() + { + // Arrange + var requestMessages = new List { new(ChatRole.User, "Hello") }; + + // Act + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, null, requestMessages, []); + + // Assert + Assert.Null(context.Session); + } + + [Fact] + public void InvokedContext_Constructor_ThrowsForNullAgent() + { + // Arrange + var requestMessages = new List { new(ChatRole.User, "Hello") }; + + // Act & Assert + Assert.Throws(() => new ChatHistoryProvider.InvokedContext(null!, s_mockSession, requestMessages, [])); + } + #endregion private sealed class TestChatHistoryProvider : ChatHistoryProvider diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs index ff31d0afc9..bf8ff998b9 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs @@ -18,6 +18,9 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; /// public class InMemoryChatHistoryProviderTests { + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + [Fact] public void Constructor_Throws_ForNullReducer() => // Arrange & Act & Assert @@ -68,7 +71,7 @@ public async Task InvokedAsyncAddsMessagesAsync() var provider = new InMemoryChatHistoryProvider(); provider.Add(providerMessages[0]); - var context = new ChatHistoryProvider.InvokedContext(requestMessages, providerMessages) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, providerMessages) { AIContextProviderMessages = aiContextProviderMessages, ResponseMessages = responseMessages @@ -87,7 +90,7 @@ public async Task InvokedAsyncWithEmptyDoesNotFailAsync() { var provider = new InMemoryChatHistoryProvider(); - var context = new ChatHistoryProvider.InvokedContext([], []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [], []); await provider.InvokedAsync(context, CancellationToken.None); Assert.Empty(provider); @@ -102,7 +105,7 @@ public async Task InvokingAsyncReturnsAllMessagesAsync() new ChatMessage(ChatRole.Assistant, "Test2") }; - var context = new ChatHistoryProvider.InvokingContext([]); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var result = (await provider.InvokingAsync(context, CancellationToken.None)).ToList(); Assert.Equal(2, result.Count); @@ -183,7 +186,7 @@ public async Task InvokedAsyncWithEmptyMessagesDoesNotChangeProviderAsync() var provider = new InMemoryChatHistoryProvider(); var messages = new List(); - var context = new ChatHistoryProvider.InvokedContext(messages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); await provider.InvokedAsync(context, CancellationToken.None); Assert.Empty(provider); @@ -520,7 +523,7 @@ public async Task AddMessagesAsync_WithReducer_AfterMessageAdded_InvokesReducerA var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded); // Act - var context = new ChatHistoryProvider.InvokedContext(originalMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, originalMessages, []); await provider.InvokedAsync(context, CancellationToken.None); // Assert @@ -556,7 +559,7 @@ public async Task GetMessagesAsync_WithReducer_BeforeMessagesRetrieval_InvokesRe } // Act - var invokingContext = new ChatHistoryProvider.InvokingContext(Array.Empty()); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, Array.Empty()); var result = (await provider.InvokingAsync(invokingContext, CancellationToken.None)).ToList(); // Assert @@ -579,7 +582,7 @@ public async Task AddMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeRedu var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.BeforeMessagesRetrieval); // Act - var context = new ChatHistoryProvider.InvokedContext(originalMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, originalMessages, []); await provider.InvokedAsync(context, CancellationToken.None); // Assert @@ -605,7 +608,7 @@ public async Task GetMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeRedu }; // Act - var invokingContext = new ChatHistoryProvider.InvokingContext(Array.Empty()); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, Array.Empty()); var result = (await provider.InvokingAsync(invokingContext, CancellationToken.None)).ToList(); // Assert @@ -627,7 +630,7 @@ public async Task InvokedAsync_WithException_DoesNotAddMessagesAsync() { new(ChatRole.Assistant, "Hi there!") }; - var context = new ChatHistoryProvider.InvokedContext(requestMessages, []) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []) { ResponseMessages = responseMessages, InvokeException = new InvalidOperationException("Test exception") diff --git a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs index ab2f58dfd5..f6589ff9e3 100644 --- a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs @@ -41,6 +41,9 @@ namespace Microsoft.Agents.AI.CosmosNoSql.UnitTests; [Collection("CosmosDB")] public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable { + private static readonly AIAgent s_mockAgent = new Moq.Mock().Object; + private static readonly AgentSession s_mockSession = new Moq.Mock().Object; + // Cosmos DB Emulator connection settings private const string EmulatorEndpoint = "https://localhost:8081"; private const string EmulatorKey = "C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="; @@ -214,7 +217,7 @@ public async Task InvokedAsync_WithSingleMessage_ShouldAddMessageAsync() using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversationId); var message = new ChatMessage(ChatRole.User, "Hello, world!"); - var context = new ChatHistoryProvider.InvokedContext([message], []) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [message], []) { ResponseMessages = [] }; @@ -226,7 +229,7 @@ public async Task InvokedAsync_WithSingleMessage_ShouldAddMessageAsync() await Task.Delay(100); // Assert - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var messages = await provider.InvokingAsync(invokingContext); var messageList = messages.ToList(); @@ -293,7 +296,7 @@ public async Task InvokedAsync_WithMultipleMessages_ShouldAddAllMessagesAsync() new ChatMessage(ChatRole.Assistant, "Response message") }; - var context = new ChatHistoryProvider.InvokedContext(requestMessages, []) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []) { AIContextProviderMessages = aiContextProviderMessages, ResponseMessages = responseMessages @@ -303,7 +306,7 @@ public async Task InvokedAsync_WithMultipleMessages_ShouldAddAllMessagesAsync() await provider.InvokedAsync(context); // Assert - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var retrievedMessages = await provider.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); Assert.Equal(5, messageList.Count); @@ -327,7 +330,7 @@ public async Task InvokingAsync_WithNoMessages_ShouldReturnEmptyAsync() using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, Guid.NewGuid().ToString()); // Act - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var messages = await provider.InvokingAsync(invokingContext); // Assert @@ -346,15 +349,15 @@ public async Task InvokingAsync_WithConversationIsolation_ShouldOnlyReturnMessag using var store1 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversation1); using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversation2); - var context1 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message for conversation 1")], []); - var context2 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message for conversation 2")], []); + var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message for conversation 1")], []); + var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message for conversation 2")], []); await store1.InvokedAsync(context1); await store2.InvokedAsync(context2); // Act - var invokingContext1 = new ChatHistoryProvider.InvokingContext([]); - var invokingContext2 = new ChatHistoryProvider.InvokingContext([]); + var invokingContext1 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext2 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var messages1 = await store1.InvokingAsync(invokingContext1); var messages2 = await store2.InvokingAsync(invokingContext2); @@ -391,11 +394,11 @@ public async Task FullWorkflow_AddAndGet_ShouldWorkCorrectlyAsync() }; // Act 1: Add messages - var invokedContext = new ChatHistoryProvider.InvokedContext(messages, []); + var invokedContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); await originalStore.InvokedAsync(invokedContext); // Act 2: Verify messages were added - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var retrievedMessages = await originalStore.InvokingAsync(invokingContext); var retrievedList = retrievedMessages.ToList(); Assert.Equal(5, retrievedList.Count); @@ -545,7 +548,7 @@ public async Task InvokedAsync_WithHierarchicalPartitioning_ShouldAddMessageWith using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId); var message = new ChatMessage(ChatRole.User, "Hello from hierarchical partitioning!"); - var context = new ChatHistoryProvider.InvokedContext([message], []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [message], []); // Act await provider.InvokedAsync(context); @@ -554,7 +557,7 @@ public async Task InvokedAsync_WithHierarchicalPartitioning_ShouldAddMessageWith await Task.Delay(100); // Assert - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var messages = await provider.InvokingAsync(invokingContext); var messageList = messages.ToList(); @@ -602,7 +605,7 @@ public async Task InvokedAsync_WithHierarchicalMultipleMessages_ShouldAddAllMess new ChatMessage(ChatRole.User, "Third hierarchical message") }; - var context = new ChatHistoryProvider.InvokedContext(messages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); // Act await provider.InvokedAsync(context); @@ -611,7 +614,7 @@ public async Task InvokedAsync_WithHierarchicalMultipleMessages_ShouldAddAllMess await Task.Delay(100); // Assert - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var retrievedMessages = await provider.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); @@ -637,8 +640,8 @@ public async Task InvokingAsync_WithHierarchicalPartitionIsolation_ShouldIsolate using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId2, SessionId); // Add messages to both stores - var context1 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message from user 1")], []); - var context2 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message from user 2")], []); + var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message from user 1")], []); + var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message from user 2")], []); await store1.InvokedAsync(context1); await store2.InvokedAsync(context2); @@ -647,8 +650,8 @@ public async Task InvokingAsync_WithHierarchicalPartitionIsolation_ShouldIsolate await Task.Delay(100); // Act & Assert - var invokingContext1 = new ChatHistoryProvider.InvokingContext([]); - var invokingContext2 = new ChatHistoryProvider.InvokingContext([]); + var invokingContext1 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext2 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var messages1 = await store1.InvokingAsync(invokingContext1); var messageList1 = messages1.ToList(); @@ -675,7 +678,7 @@ public async Task SerializeDeserialize_WithHierarchicalPartitioning_ShouldPreser using var originalStore = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId); - var context = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Test serialization message")], []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test serialization message")], []); await originalStore.InvokedAsync(context); // Act - Serialize the provider state @@ -693,7 +696,7 @@ public async Task SerializeDeserialize_WithHierarchicalPartitioning_ShouldPreser await Task.Delay(100); // Assert - The deserialized provider should have the same functionality - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var messages = await deserializedStore.InvokingAsync(invokingContext); var messageList = messages.ToList(); @@ -717,8 +720,8 @@ public async Task HierarchicalAndSimplePartitioning_ShouldCoexistAsync() using var hierarchicalProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, "tenant-coexist", "user-coexist", SessionId); // Add messages to both - var simpleContext = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Simple partitioning message")], []); - var hierarchicalContext = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Hierarchical partitioning message")], []); + var simpleContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Simple partitioning message")], []); + var hierarchicalContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Hierarchical partitioning message")], []); await simpleProvider.InvokedAsync(simpleContext); await hierarchicalProvider.InvokedAsync(hierarchicalContext); @@ -727,7 +730,7 @@ public async Task HierarchicalAndSimplePartitioning_ShouldCoexistAsync() await Task.Delay(100); // Act & Assert - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var simpleMessages = await simpleProvider.InvokingAsync(invokingContext); var simpleMessageList = simpleMessages.ToList(); @@ -760,7 +763,7 @@ public async Task MaxMessagesToRetrieve_ShouldLimitAndReturnMostRecentAsync() await Task.Delay(10); // Small delay to ensure different timestamps } - var context = new ChatHistoryProvider.InvokedContext(messages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); await provider.InvokedAsync(context); // Wait for eventual consistency @@ -768,7 +771,7 @@ public async Task MaxMessagesToRetrieve_ShouldLimitAndReturnMostRecentAsync() // Act - Set max to 5 and retrieve provider.MaxMessagesToRetrieve = 5; - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var retrievedMessages = await provider.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); @@ -798,14 +801,14 @@ public async Task MaxMessagesToRetrieve_Null_ShouldReturnAllMessagesAsync() messages.Add(new ChatMessage(ChatRole.User, $"Message {i}")); } - var context = new ChatHistoryProvider.InvokedContext(messages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); await provider.InvokedAsync(context); // Wait for eventual consistency await Task.Delay(100); // Act - No limit set (default null) - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var retrievedMessages = await provider.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); diff --git a/dotnet/tests/Microsoft.Agents.AI.DurableTask.UnitTests/DurableAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.DurableTask.UnitTests/DurableAgentSessionTests.cs index 4bf8ebc718..db6ec99058 100644 --- a/dotnet/tests/Microsoft.Agents.AI.DurableTask.UnitTests/DurableAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.DurableTask.UnitTests/DurableAgentSessionTests.cs @@ -15,7 +15,7 @@ public void BuiltInSerialization() JsonElement serializedSession = session.Serialize(); // Expected format: "{\"sessionId\":\"@dafx-test-agent@\"}" - string expectedSerializedSession = $"{{\"sessionId\":\"@dafx-{sessionId.Name}@{sessionId.Key}\"}}"; + string expectedSerializedSession = $"{{\"sessionId\":\"@dafx-{sessionId.Name}@{sessionId.Key}\",\"stateBag\":{{}}}}"; Assert.Equal(expectedSerializedSession, serializedSession.ToString()); DurableAgentSession deserializedSession = DurableAgentSession.Deserialize(serializedSession); @@ -33,7 +33,7 @@ public void STJSerialization() string serializedSession = JsonSerializer.Serialize(session, typeof(DurableAgentSession)); // Expected format: "{\"sessionId\":\"@dafx-test-agent@\"}" - string expectedSerializedSession = $"{{\"sessionId\":\"@dafx-{sessionId.Name}@{sessionId.Key}\"}}"; + string expectedSerializedSession = $"{{\"sessionId\":\"@dafx-{sessionId.Name}@{sessionId.Key}\",\"StateBag\":{{}}}}"; Assert.Equal(expectedSerializedSession, serializedSession); DurableAgentSession? deserializedSession = JsonSerializer.Deserialize(serializedSession); diff --git a/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs index bacc59833a..81ca4eb588 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs @@ -18,6 +18,9 @@ public sealed class Mem0ProviderTests : IDisposable { private const string SkipReason = "Requires a Mem0 service configured"; // Set to null to enable. + private static readonly AIAgent s_mockAgent = new Moq.Mock().Object; + private static readonly AgentSession s_mockSession = new Moq.Mock().Object; + private readonly HttpClient _httpClient; public Mem0ProviderTests() @@ -49,14 +52,14 @@ public async Task CanAddAndRetrieveUserMemoriesAsync() var sut = new Mem0Provider(this._httpClient, storageScope); await sut.ClearStoredMemoriesAsync(); - var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext([question])); + var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); Assert.DoesNotContain("Caoimhe", ctxBefore.Messages?[0].Text ?? string.Empty); // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext([input], aiContextProviderMessages: null)); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [input], aiContextProviderMessages: null)); var ctxAfterAdding = await GetContextWithRetryAsync(sut, question); await sut.ClearStoredMemoriesAsync(); - var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext([question])); + var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); // Assert Assert.Contains("Caoimhe", ctxAfterAdding.Messages?[0].Text ?? string.Empty); @@ -73,14 +76,14 @@ public async Task CanAddAndRetrieveAgentMemoriesAsync() var sut = new Mem0Provider(this._httpClient, storageScope); await sut.ClearStoredMemoriesAsync(); - var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext([question])); + var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); Assert.DoesNotContain("Caoimhe", ctxBefore.Messages?[0].Text ?? string.Empty); // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext([assistantIntro], aiContextProviderMessages: null)); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [assistantIntro], aiContextProviderMessages: null)); var ctxAfterAdding = await GetContextWithRetryAsync(sut, question); await sut.ClearStoredMemoriesAsync(); - var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext([question])); + var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); // Assert Assert.Contains("Caoimhe", ctxAfterAdding.Messages?[0].Text ?? string.Empty); @@ -99,13 +102,13 @@ public async Task DoesNotLeakMemoriesAcrossAgentScopesAsync() await sut1.ClearStoredMemoriesAsync(); await sut2.ClearStoredMemoriesAsync(); - var ctxBefore1 = await sut1.InvokingAsync(new AIContextProvider.InvokingContext([question])); - var ctxBefore2 = await sut2.InvokingAsync(new AIContextProvider.InvokingContext([question])); + var ctxBefore1 = await sut1.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); + var ctxBefore2 = await sut2.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); Assert.DoesNotContain("Caoimhe", ctxBefore1.Messages?[0].Text ?? string.Empty); Assert.DoesNotContain("Caoimhe", ctxBefore2.Messages?[0].Text ?? string.Empty); // Act - await sut1.InvokedAsync(new AIContextProvider.InvokedContext([assistantIntro], aiContextProviderMessages: null)); + await sut1.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [assistantIntro], aiContextProviderMessages: null)); var ctxAfterAdding1 = await GetContextWithRetryAsync(sut1, question); var ctxAfterAdding2 = await GetContextWithRetryAsync(sut2, question); @@ -123,7 +126,7 @@ private static async Task GetContextWithRetryAsync(Mem0Provider provi AIContext? ctx = null; for (int i = 0; i < attempts; i++) { - ctx = await provider.InvokingAsync(new AIContextProvider.InvokingContext([question]), CancellationToken.None); + ctx = await provider.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question]), CancellationToken.None); var text = ctx.Messages?[0].Text; if (!string.IsNullOrEmpty(text) && text.IndexOf("Caoimhe", StringComparison.OrdinalIgnoreCase) >= 0) { diff --git a/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs index 832881857d..b886784af9 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs @@ -18,6 +18,9 @@ namespace Microsoft.Agents.AI.Mem0.UnitTests; /// public sealed class Mem0ProviderTests : IDisposable { + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + private readonly Mock> _loggerMock; private readonly Mock _loggerFactoryMock; private readonly RecordingHandler _handler = new(); @@ -96,7 +99,7 @@ public async Task InvokingAsync_PerformsSearch_AndReturnsContextMessageAsync() UserId = "user" }; var sut = new Mem0Provider(this._httpClient, storageScope, options: new() { EnableSensitiveTelemetryData = true }, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "What is my name?")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "What is my name?")]); // Act var aiContext = await sut.InvokingAsync(invokingContext); @@ -161,7 +164,7 @@ public async Task InvokingAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsy var options = new Mem0ProviderOptions { EnableSensitiveTelemetryData = enableSensitiveTelemetryData }; var sut = new Mem0Provider(this._httpClient, storageScope, options: options, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Who am I?")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Who am I?")]); // Act await sut.InvokingAsync(invokingContext, CancellationToken.None); @@ -215,7 +218,7 @@ public async Task InvokedAsync_PersistsAllowedMessagesAsync() }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); // Assert var memoryPosts = this._handler.Requests.Where(r => r.RequestMessage.RequestUri!.AbsolutePath == "/v1/memories/" && r.RequestMessage.Method == HttpMethod.Post).ToList(); @@ -242,7 +245,7 @@ public async Task InvokedAsync_PersistsNothingForFailedRequestAsync() }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = null, InvokeException = new InvalidOperationException("Request Failed") }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = null, InvokeException = new InvalidOperationException("Request Failed") }); // Assert Assert.Empty(this._handler.Requests); @@ -268,7 +271,7 @@ public async Task InvokedAsync_ShouldNotThrow_WhenStorageFailsAsync() }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); // Assert this._loggerMock.Verify( @@ -318,7 +321,7 @@ public async Task InvokedAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsyn }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); // Assert Assert.Equal(expectedLogCount, this._loggerMock.Invocations.Count); @@ -400,7 +403,7 @@ public async Task InvokingAsync_ShouldNotThrow_WhenSearchFailsAsync() // Arrange var storageScope = new Mem0ProviderScope { ApplicationId = "app" }; var provider = new Mem0Provider(this._httpClient, storageScope, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs index 4001b59090..2acfaf1a10 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs @@ -152,6 +152,33 @@ public async Task VerifyDeserializeWithAIContextProviderAsync() Assert.Same(session.AIContextProvider, mockProvider.Object); } + [Fact] + public async Task VerifyDeserializeWithStateBagAsync() + { + // Arrange + var json = JsonSerializer.Deserialize(""" + { + "conversationId": "TestConvId", + "stateBag": { + "dog": { + "jsonValue": { + "name": "Fido" + } + } + } + } + """, TestJsonSerializerContext.Default.JsonElement); + Mock mockProvider = new(); + + // Act + var session = await ChatClientAgentSession.DeserializeAsync(json, aiContextProviderFactory: (_, _, _) => new(mockProvider.Object)); + + // Assert + var dog = session.StateBag.GetValue("dog", TestJsonSerializerContext.Default.Options); + Assert.NotNull(dog); + Assert.Equal("Fido", dog.Name); + } + [Fact] public async Task DeserializeWithInvalidJsonThrowsAsync() { @@ -249,6 +276,27 @@ public void VerifySessionSerializationWithWithAIContextProvider() mockProvider.Verify(m => m.Serialize(It.IsAny()), Times.Once); } + [Fact] + public void VerifySessionSerializationWithWithStateBag() + { + // Arrange + var session = new ChatClientAgentSession(); + session.StateBag.SetValue("dog", new Animal { Name = "Fido" }, TestJsonSerializerContext.Default.Options); + + // Act + var json = session.Serialize(); + + // Assert + Assert.Equal(JsonValueKind.Object, json.ValueKind); + Assert.True(json.TryGetProperty("stateBag", out var stateBagProperty)); + Assert.Equal(JsonValueKind.Object, stateBagProperty.ValueKind); + Assert.True(stateBagProperty.TryGetProperty("dog", out var dogProperty)); + Assert.Equal(JsonValueKind.Object, dogProperty.ValueKind); + Assert.True(dogProperty.TryGetProperty("jsonValue", out var dogJsonValueProperty)); + Assert.True(dogJsonValueProperty.TryGetProperty("name", out var nameProperty)); + Assert.Equal("Fido", nameProperty.GetString()); + } + /// /// Verify session serialization to JSON with custom options. /// @@ -289,6 +337,27 @@ public void VerifySessionSerializationWithCustomOptions() #endregion Serialize Tests + #region StateBag Roundtrip Tests + + [Fact] + public async Task VerifyStateBagRoundtripsAsync() + { + // Arrange + var session = new ChatClientAgentSession(); + session.StateBag.SetValue("dog", new Animal { Name = "Fido" }, TestJsonSerializerContext.Default.Options); + + // Act + var serializedSession = session.Serialize(); + var deserializedSession = await ChatClientAgentSession.DeserializeAsync(serializedSession); + + // Assert + var dog = deserializedSession.StateBag.GetValue("dog", TestJsonSerializerContext.Default.Options); + Assert.NotNull(dog); + Assert.Equal("Fido", dog.Name); + } + + #endregion + #region GetService Tests [Fact] @@ -327,4 +396,9 @@ public void GetService_RequestingChatHistoryProvider_ReturnsChatHistoryProvider( } #endregion + + internal sealed class Animal + { + public string Name { get; set; } = string.Empty; + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs index 3698ee7065..360c3071ae 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs @@ -17,6 +17,9 @@ namespace Microsoft.Agents.AI.UnitTests.Data; /// public sealed class TextSearchProviderTests { + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + private readonly Mock> _loggerMock; private readonly Mock _loggerFactoryMock; @@ -64,10 +67,12 @@ public async Task InvokingAsync_ShouldInjectFormattedResultsAsync(string? overri var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options, withLogging ? this._loggerFactoryMock.Object : null); var invokingContext = new AIContextProvider.InvokingContext( - [ - new ChatMessage(ChatRole.User, "Sample user question?"), - new ChatMessage(ChatRole.User, "Additional part") - ]); + s_mockAgent, + s_mockSession, + [ + new ChatMessage(ChatRole.User, "Sample user question?"), + new ChatMessage(ChatRole.User, "Additional part") + ]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -139,7 +144,7 @@ public async Task InvokingAsync_OnDemand_ShouldExposeSearchToolAsync(string? ove FunctionToolDescription = overrideDescription }; var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -158,7 +163,7 @@ public async Task InvokingAsync_ShouldNotThrow_WhenSearchFailsAsync() { // Arrange var provider = new TextSearchProvider(this.FailingSearchAsync, default, null, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -251,7 +256,7 @@ public async Task InvokingAsync_ShouldUseContextFormatterWhenProvidedAsync() ContextFormatter = r => $"Custom formatted context with {r.Count} results." }; var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -285,7 +290,7 @@ public async Task InvokingAsync_WithRawRepresentations_ContextFormatterCanAccess ContextFormatter = r => string.Join(",", r.Select(x => ((RawPayload)x.RawRepresentation!).Id)) }; var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -302,7 +307,7 @@ public async Task InvokingAsync_WithNoResults_ShouldReturnEmptyContextAsync() // Arrange var options = new TextSearchProviderOptions { SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke }; var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -340,12 +345,14 @@ public async Task InvokingAsync_WithPreviousFailedRequest_ShouldNotIncludeFailed new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(new(initialMessages, aiContextProviderMessages: null) { InvokeException = new InvalidOperationException("Request Failed") }); + await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages, aiContextProviderMessages: null) { InvokeException = new InvalidOperationException("Request Failed") }); var invokingContext = new AIContextProvider.InvokingContext( - [ - new ChatMessage(ChatRole.User, "E") - ]); + s_mockAgent, + s_mockSession, + [ + new ChatMessage(ChatRole.User, "E") + ]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -380,12 +387,14 @@ public async Task InvokingAsync_WithRecentMessageMemory_ShouldIncludeStoredMessa new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(new(initialMessages, aiContextProviderMessages: null)); + await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages, aiContextProviderMessages: null)); var invokingContext = new AIContextProvider.InvokingContext( - [ - new ChatMessage(ChatRole.User, "E") - ]); + s_mockAgent, + s_mockSession, + [ + new ChatMessage(ChatRole.User, "E") + ]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -414,20 +423,24 @@ public async Task InvokingAsync_WithAccumulatedMemoryAcrossInvocations_ShouldInc // First memory update (A,B) await provider.InvokedAsync(new( - [ - new ChatMessage(ChatRole.User, "A"), - new ChatMessage(ChatRole.Assistant, "B"), - ], aiContextProviderMessages: null)); + s_mockAgent, + s_mockSession, + [ + new ChatMessage(ChatRole.User, "A"), + new ChatMessage(ChatRole.Assistant, "B"), + ], aiContextProviderMessages: null)); // Second memory update (C,D,E) await provider.InvokedAsync(new( - [ - new ChatMessage(ChatRole.User, "C"), - new ChatMessage(ChatRole.Assistant, "D"), - new ChatMessage(ChatRole.User, "E"), - ], aiContextProviderMessages: null)); + s_mockAgent, + s_mockSession, + [ + new ChatMessage(ChatRole.User, "C"), + new ChatMessage(ChatRole.Assistant, "D"), + new ChatMessage(ChatRole.User, "E"), + ], aiContextProviderMessages: null)); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "F")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "F")]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -462,12 +475,14 @@ public async Task InvokingAsync_WithRecentMessageRolesIncluded_ShouldFilterRoles new ChatMessage(ChatRole.User, "U2"), new ChatMessage(ChatRole.Assistant, "A2"), }; - await provider.InvokedAsync(new(initialMessages, null)); + await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages, null)); var invokingContext = new AIContextProvider.InvokingContext( - [ - new ChatMessage(ChatRole.User, "Question?") // Current request message always appended. - ]); + s_mockAgent, + s_mockSession, + [ + new ChatMessage(ChatRole.User, "Question?") // Current request message always appended. + ]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -518,7 +533,7 @@ public async Task Serialize_WithRecentMessages_ShouldPersistMessagesUpToLimitAsy }; // Act - await provider.InvokedAsync(new(messages, aiContextProviderMessages: null)); // Populate recent memory. + await provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null)); // Populate recent memory. var state = provider.Serialize(); // Assert @@ -547,7 +562,7 @@ public async Task SerializeAndDeserialize_RoundtripRestoresMessagesAsync() new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(new(messages, aiContextProviderMessages: null)); + await provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null)); // Act var state = provider.Serialize(); @@ -563,7 +578,7 @@ public async Task SerializeAndDeserialize_RoundtripRestoresMessagesAsync() RecentMessageMemoryLimit = 4 }); var emptyMessages = Array.Empty(); - await roundTrippedProvider.InvokingAsync(new(emptyMessages), CancellationToken.None); // Trigger search to read memory. + await roundTrippedProvider.InvokingAsync(new(s_mockAgent, s_mockSession, emptyMessages), CancellationToken.None); // Trigger search to read memory. // Assert Assert.NotNull(capturedInput); @@ -588,7 +603,7 @@ public async Task Deserialize_WithChangedLowerLimit_ShouldTruncateToNewLimitAsyn new ChatMessage(ChatRole.Assistant, "L4"), new ChatMessage(ChatRole.User, "L5"), }; - await initialProvider.InvokedAsync(new(messages, aiContextProviderMessages: null)); + await initialProvider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null)); var state = initialProvider.Serialize(); string? capturedInput = null; @@ -604,7 +619,7 @@ public async Task Deserialize_WithChangedLowerLimit_ShouldTruncateToNewLimitAsyn SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, RecentMessageMemoryLimit = 3 // Lower limit }); - await restoredProvider.InvokingAsync(new(Array.Empty()), CancellationToken.None); + await restoredProvider.InvokingAsync(new(s_mockAgent, s_mockSession, Array.Empty()), CancellationToken.None); // Assert Assert.NotNull(capturedInput); @@ -631,7 +646,7 @@ public async Task Deserialize_WithEmptyState_ShouldHaveNoMessagesAsync() RecentMessageMemoryLimit = 3 }); var emptyMessages = Array.Empty(); - await provider.InvokingAsync(new(emptyMessages), CancellationToken.None); + await provider.InvokingAsync(new(s_mockAgent, s_mockSession, emptyMessages), CancellationToken.None); // Assert Assert.NotNull(capturedInput); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs index f46538c8e4..8d3cad85ae 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs @@ -18,6 +18,9 @@ namespace Microsoft.Agents.AI.Memory.UnitTests; /// public class ChatHistoryMemoryProviderTests { + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + private readonly Mock> _loggerMock; private readonly Mock _loggerFactoryMock; @@ -116,7 +119,7 @@ public async Task InvokedAsync_UpsertsMessages_ToCollectionAsync() var requestMsgWithNulls = new ChatMessage(ChatRole.User, "request text nulls"); var responseMsg = new ChatMessage(ChatRole.Assistant, "response text") { MessageId = "resp-1", AuthorName = "assistant" }; - var invokedContext = new AIContextProvider.InvokedContext([requestMsgWithValues, requestMsgWithNulls], aiContextProviderMessages: null) + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsgWithValues, requestMsgWithNulls], aiContextProviderMessages: null) { ResponseMessages = [responseMsg] }; @@ -174,7 +177,7 @@ public async Task InvokedAsync_DoesNotUpsertMessages_WhenInvokeFailedAsync() 1, new ChatHistoryMemoryProviderScope() { UserId = "UID" }); var requestMsg = new ChatMessage(ChatRole.User, "request text") { MessageId = "req-1" }; - var invokedContext = new AIContextProvider.InvokedContext([requestMsg], aiContextProviderMessages: null) + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg], aiContextProviderMessages: null) { InvokeException = new InvalidOperationException("Invoke failed") }; @@ -203,7 +206,7 @@ public async Task InvokedAsync_DoesNotThrow_WhenUpsertThrowsAsync() new ChatHistoryMemoryProviderScope() { UserId = "UID" }, loggerFactory: this._loggerFactoryMock.Object); var requestMsg = new ChatMessage(ChatRole.User, "request text") { MessageId = "req-1" }; - var invokedContext = new AIContextProvider.InvokedContext([requestMsg], aiContextProviderMessages: null); + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg], aiContextProviderMessages: null); // Act await provider.InvokedAsync(invokedContext, CancellationToken.None); @@ -254,7 +257,7 @@ public async Task InvokedAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsyn loggerFactory: this._loggerFactoryMock.Object); var requestMsg = new ChatMessage(ChatRole.User, "request text"); - var invokedContext = new AIContextProvider.InvokedContext([requestMsg], aiContextProviderMessages: null); + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg], aiContextProviderMessages: null); // Act await provider.InvokedAsync(invokedContext, CancellationToken.None); @@ -327,7 +330,7 @@ public async Task InvokedAsync_SearchesVectorStoreAsync() options: providerOptions); var requestMsg = new ChatMessage(ChatRole.User, "requesting relevant history"); - var invokingContext = new AIContextProvider.InvokingContext([requestMsg]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [requestMsg]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -378,7 +381,7 @@ public async Task InvokedAsync_CreatesFilter_WhenSearchScopeProvidedAsync() var provider = new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, TestCollectionName, 1, options: providerOptions, storageScope: searchScope, searchScope: searchScope); var requestMsg = new ChatMessage(ChatRole.User, "requesting relevant history"); - var invokingContext = new AIContextProvider.InvokingContext([requestMsg]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [requestMsg]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -442,7 +445,7 @@ public async Task InvokingAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsy options: options, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "requesting relevant history")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "requesting relevant history")]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestJsonSerializerContext.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestJsonSerializerContext.cs index b145991439..0ac3ab9fbf 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestJsonSerializerContext.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestJsonSerializerContext.cs @@ -14,4 +14,5 @@ namespace Microsoft.Agents.AI.UnitTests; [JsonSerializable(typeof(string))] [JsonSerializable(typeof(string[]))] [JsonSerializable(typeof(Dictionary))] +[JsonSerializable(typeof(ChatClientAgentSessionTests.Animal))] internal sealed partial class TestJsonSerializerContext : JsonSerializerContext; diff --git a/dotnet/tests/OpenAIAssistant.IntegrationTests/OpenAIAssistantFixture.cs b/dotnet/tests/OpenAIAssistant.IntegrationTests/OpenAIAssistantFixture.cs index a90f49f428..bb58f09fb4 100644 --- a/dotnet/tests/OpenAIAssistant.IntegrationTests/OpenAIAssistantFixture.cs +++ b/dotnet/tests/OpenAIAssistant.IntegrationTests/OpenAIAssistantFixture.cs @@ -23,7 +23,7 @@ public class OpenAIAssistantFixture : IChatClientAgentFixture public IChatClient ChatClient => this._agent.ChatClient; - public async Task> GetChatHistoryAsync(AgentSession session) + public async Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) { var typedSession = (ChatClientAgentSession)session; List messages = []; diff --git a/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs b/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs index eeb60620c0..304df28fba 100644 --- a/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs +++ b/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs @@ -28,7 +28,7 @@ public OpenAIChatCompletionFixture(bool useReasoningChatModel) public IChatClient ChatClient => this._agent.ChatClient; - public async Task> GetChatHistoryAsync(AgentSession session) + public async Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) { var typedSession = (ChatClientAgentSession)session; @@ -37,7 +37,7 @@ public async Task> GetChatHistoryAsync(AgentSession session) return []; } - return (await typedSession.ChatHistoryProvider.InvokingAsync(new([]))).ToList(); + return (await typedSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); } public Task CreateChatClientAgentAsync( diff --git a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs index 2006404239..719db6a0b0 100644 --- a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs +++ b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs @@ -25,7 +25,7 @@ public class OpenAIResponseFixture(bool store) : IChatClientAgentFixture public IChatClient ChatClient => this._agent.ChatClient; - public async Task> GetChatHistoryAsync(AgentSession session) + public async Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) { var typedSession = (ChatClientAgentSession)session; @@ -55,7 +55,7 @@ public async Task> GetChatHistoryAsync(AgentSession session) return []; } - return (await typedSession.ChatHistoryProvider.InvokingAsync(new([]))).ToList(); + return (await typedSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); } private static ChatMessage ConvertToChatMessage(ResponseItem item) From ddc8ac9fc3f57963f4479bbd45cf7d2882a422d1 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Thu, 5 Feb 2026 12:14:52 +0000 Subject: [PATCH 2/5] Remove statebag code from this branch, to get the refactoring out of the way first --- .../AgentAbstractionsJsonUtilities.cs | 2 - .../AgentSession.cs | 5 - .../AgentSessionStateBag.cs | 180 -------- .../AgentSessionStateBagValue.cs | 67 --- .../ChatClient/ChatClientAgentSession.cs | 5 - .../AgentSessionStateBagTests.cs | 407 ------------------ .../AgentSessionTests.cs | 15 - .../DurableAgentSessionTests.cs | 4 +- .../ChatClient/ChatClientAgentSessionTests.cs | 69 --- 9 files changed, 2 insertions(+), 752 deletions(-) delete mode 100644 dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs delete mode 100644 dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs delete mode 100644 dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs index bf0e835b4b..17fbb9e4c6 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs @@ -1,6 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using System.Collections.Concurrent; using System.Diagnostics.CodeAnalysis; using System.Text.Encodings.Web; using System.Text.Json; @@ -84,7 +83,6 @@ private static JsonSerializerOptions CreateDefaultOptions() [JsonSerializable(typeof(ServiceIdAgentSession.ServiceIdAgentSessionState))] [JsonSerializable(typeof(InMemoryAgentSession.InMemoryAgentSessionState))] [JsonSerializable(typeof(InMemoryChatHistoryProvider.State))] - [JsonSerializable(typeof(ConcurrentDictionary))] [ExcludeFromCodeCoverage] private sealed partial class JsonContext : JsonSerializerContext; diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs index 722660d49e..3efce9be17 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs @@ -53,11 +53,6 @@ protected AgentSession() { } - /// - /// Gets any arbitrary state associated with this session. - /// - public AgentSessionStateBag StateBag { get; protected set; } = new(); - /// Asks the for an object of the specified type . /// The type of object being requested. /// An optional key that can be used to help identify the target service. diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs deleted file mode 100644 index 47eef61508..0000000000 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs +++ /dev/null @@ -1,180 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Concurrent; -using System.Text.Json; -using Microsoft.Shared.Diagnostics; - -namespace Microsoft.Agents.AI; - -/// -/// Provides a thread-safe key-value store for managing session-scoped state with support for type-safe access and JSON -/// serialization options. -/// -/// -/// SessionState enables storing and retrieving objects associated with a session using string keys. -/// Values can be accessed in a type-safe manner and are serialized or deserialized using configurable JSON serializer -/// options. This class is designed for concurrent access and is safe to use across multiple threads. -/// -public class AgentSessionStateBag -{ - private readonly ConcurrentDictionary _state; - - /// - /// Initializes a new instance of the class. - /// - public AgentSessionStateBag() - { - this._state = new ConcurrentDictionary(); - } - - /// - /// Initializes a new instance of the class. - /// - /// The initial state dictionary. - internal AgentSessionStateBag(ConcurrentDictionary? state) - { - this._state = state ?? new ConcurrentDictionary(); - } - - /// - /// Tries to get a value from the session state. - /// - /// The type of the value to retrieve. - /// The key from which to retrieve the value. - /// The value if found and convertable to the required type; otherwise, null. - /// The JSON serializer options to use for serializing/deserialing the value. - /// if the value was successfully retrieved, otherwise. - public bool TryGetValue(string key, out T? value, JsonSerializerOptions? jsonSerializerOptions = null) - where T : class - { - _ = Throw.IfNullOrWhitespace(key); - var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; - - if (this._state.TryGetValue(key, out var stateValue)) - { - if (stateValue.DeserializedValue is T cachedValue) - { - value = cachedValue; - return true; - } - - switch (stateValue.JsonValue) - { - case T tValue: - value = tValue; - return true; - case JsonElement jsonElement when jsonElement.ValueKind == JsonValueKind.Null || jsonElement.ValueKind == JsonValueKind.Undefined: - value = null; - return false; - default: - T? result = stateValue.JsonValue.Deserialize(jso.GetTypeInfo(typeof(T))) as T; - if (result is null) - { - value = null; - return false; - } - - stateValue.DeserializedValue = result; - stateValue.ValueType = typeof(T); - stateValue.JsonSerializerOptions = jso; - - value = result; - return true; - } - } - value = null; - return false; - } - - /// - /// Gets a value from the session state. - /// - /// The type of value to get. - /// The key from which to retrieve the value. - /// The JSON serializer options to use for serializing/deserialing the value. - /// The retrieved value or null if not found. - /// The value could not be deserialized into the required type. - public T? GetValue(string key, JsonSerializerOptions? jsonSerializerOptions = null) - where T : class - { - _ = Throw.IfNullOrWhitespace(key); - var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; - - if (this._state.TryGetValue(key, out var stateValue)) - { - if (stateValue.DeserializedValue is T cachedValue) - { - return cachedValue; - } - - switch (stateValue.JsonValue) - { - case T tValue: - return tValue; - case JsonElement jsonElement when jsonElement.ValueKind == JsonValueKind.Null || jsonElement.ValueKind == JsonValueKind.Undefined: - return null; - default: - T? result = stateValue.JsonValue.Deserialize(jso.GetTypeInfo(typeof(T))) as T; - if (result is null) - { - throw new InvalidOperationException($"Failed to deserialize session state value to type {typeof(T).FullName}."); - } - stateValue.DeserializedValue = result; - stateValue.ValueType = typeof(T); - stateValue.JsonSerializerOptions = jso; - return result; - } - } - - return null; - } - - /// - /// Sets a value in the session state. - /// - /// The type of the value to set. - /// The key to store the value under. - /// The value to set. - /// The JSON serializer options to use for serializing the value. - public void SetValue(string key, T value, JsonSerializerOptions? jsonSerializerOptions = null) - where T : class - { - _ = Throw.IfNullOrWhitespace(key); - var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; - - var stateValue = this._state.GetOrAdd(key, _ => - new AgentSessionStateBagValue(value, typeof(T), jso)); - - stateValue.DeserializedValue = value; - stateValue.ValueType = typeof(T); - stateValue.JsonSerializerOptions = jso; - } - - /// - /// Serializes all session state values to a JSON object. - /// - /// A representing the serialized session state. - /// Thrown when a session state value is not properly initialized. - public JsonElement Serialize() - { - return JsonSerializer.SerializeToElement(this._state, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ConcurrentDictionary))); - } - - /// - /// Deserializes a JSON object into an instance. - /// - /// The element to deserialize. - /// The deserialized . - public static AgentSessionStateBag Deserialize(JsonElement jsonElement) - { - if (jsonElement.ValueKind is JsonValueKind.Undefined or JsonValueKind.Null) - { - return new AgentSessionStateBag(); - } - - return new AgentSessionStateBag( - jsonElement.Deserialize(AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ConcurrentDictionary))) as ConcurrentDictionary - ?? new ConcurrentDictionary()); - } -} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs deleted file mode 100644 index a4c1743e77..0000000000 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Text.Json; -using System.Text.Json.Serialization; - -namespace Microsoft.Agents.AI; - -/// -/// Used to store a value in session state. -/// -internal class AgentSessionStateBagValue -{ - /// - /// Initializes a new instance of the SessionStateValue class with the specified value. - /// - /// The serialized value to associate with the session state. - [JsonConstructor] - public AgentSessionStateBagValue(JsonElement jsonValue) - { - this.JsonValue = jsonValue; - } - - /// - /// Initializes a new instance of the SessionStateValue class with the specified value. - /// - /// The value to associate with the session state. Can be any object, including null. - /// The type of the value. - /// The JSON serializer options to use for serializing the value. - public AgentSessionStateBagValue(object deserializedValue, Type valueType, JsonSerializerOptions jsonSerializerOptions) - { - this.DeserializedValue = deserializedValue; - this.ValueType = valueType; - this.JsonSerializerOptions = jsonSerializerOptions; - } - - /// - /// Gets or sets the value associated with this instance. - /// - public JsonElement JsonValue - { - get - { - if (this.DeserializedValue != null) - { - if (this.ValueType is null || this.JsonSerializerOptions is null) - { - throw new InvalidOperationException($"{nameof(AgentSessionStateBagValue)} has not been properly initialized, please set {nameof(this.ValueType)} and {nameof(this.JsonSerializerOptions)} before accessing {nameof(this.JsonValue)}."); - } - - return JsonSerializer.SerializeToElement(this.DeserializedValue, this.JsonSerializerOptions.GetTypeInfo(this.ValueType)); - } - - return field; - } - set; - } - - [JsonIgnore] - public object? DeserializedValue { get; set; } - - [JsonIgnore] - public Type? ValueType { get; set; } - - [JsonIgnore] - public JsonSerializerOptions? JsonSerializerOptions { get; set; } -} diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs index cb791bd564..1a79ae64d1 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs @@ -148,8 +148,6 @@ internal static async Task DeserializeAsync( ? await aiContextProviderFactory.Invoke(state?.AIContextProviderState ?? default, jsonSerializerOptions, cancellationToken).ConfigureAwait(false) : null; - session.StateBag = AgentSessionStateBag.Deserialize(state?.StateBag ?? default); - if (state?.ConversationId is string sessionId) { session.ConversationId = sessionId; @@ -178,7 +176,6 @@ internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = nu ConversationId = this.ConversationId, ChatHistoryProviderState = chatHistoryProviderState is { ValueKind: not JsonValueKind.Undefined } ? chatHistoryProviderState : null, AIContextProviderState = aiContextProviderState is { ValueKind: not JsonValueKind.Undefined } ? aiContextProviderState : null, - StateBag = this.StateBag.Serialize(), }; return JsonSerializer.SerializeToElement(state, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(SessionState))); @@ -204,7 +201,5 @@ internal sealed class SessionState public JsonElement? ChatHistoryProviderState { get; set; } public JsonElement? AIContextProviderState { get; set; } - - public JsonElement? StateBag { get; set; } } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs deleted file mode 100644 index 5d07a4749f..0000000000 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs +++ /dev/null @@ -1,407 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Text.Json; -using Microsoft.Agents.AI.Abstractions.UnitTests.Models; - -namespace Microsoft.Agents.AI.Abstractions.UnitTests; - -/// -/// Contains tests for the class. -/// -public sealed class AgentSessionStateBagTests -{ - #region Constructor Tests - - [Fact] - public void Constructor_Default_CreatesEmptyStateBag() - { - // Act - var stateBag = new AgentSessionStateBag(); - - // Assert - Assert.False(stateBag.TryGetValue("nonexistent", out _)); - } - - #endregion - - #region SetValue Tests - - [Fact] - public void SetValue_WithValidKeyAndValue_StoresValue() - { - // Arrange - var stateBag = new AgentSessionStateBag(); - - // Act - stateBag.SetValue("key1", "value1"); - - // Assert - Assert.True(stateBag.TryGetValue("key1", out var result)); - Assert.Equal("value1", result); - } - - [Fact] - public void SetValue_WithNullKey_ThrowsArgumentException() - { - // Arrange - var stateBag = new AgentSessionStateBag(); - - // Act & Assert - Assert.Throws(() => stateBag.SetValue(null!, "value")); - } - - [Fact] - public void SetValue_WithEmptyKey_ThrowsArgumentException() - { - // Arrange - var stateBag = new AgentSessionStateBag(); - - // Act & Assert - Assert.Throws(() => stateBag.SetValue("", "value")); - } - - [Fact] - public void SetValue_WithWhitespaceKey_ThrowsArgumentException() - { - // Arrange - var stateBag = new AgentSessionStateBag(); - - // Act & Assert - Assert.Throws(() => stateBag.SetValue(" ", "value")); - } - - [Fact] - public void SetValue_OverwritesExistingValue() - { - // Arrange - var stateBag = new AgentSessionStateBag(); - stateBag.SetValue("key1", "originalValue"); - - // Act - stateBag.SetValue("key1", "newValue"); - - // Assert - Assert.Equal("newValue", stateBag.GetValue("key1")); - } - - #endregion - - #region GetValue Tests - - [Fact] - public void GetValue_WithExistingKey_ReturnsValue() - { - // Arrange - var stateBag = new AgentSessionStateBag(); - stateBag.SetValue("key1", "value1"); - - // Act - var result = stateBag.GetValue("key1"); - - // Assert - Assert.Equal("value1", result); - } - - [Fact] - public void GetValue_WithNonexistentKey_ReturnsNull() - { - // Arrange - var stateBag = new AgentSessionStateBag(); - - // Act - var result = stateBag.GetValue("nonexistent"); - - // Assert - Assert.Null(result); - } - - [Fact] - public void GetValue_WithNullKey_ThrowsArgumentException() - { - // Arrange - var stateBag = new AgentSessionStateBag(); - - // Act & Assert - Assert.Throws(() => stateBag.GetValue(null!)); - } - - [Fact] - public void GetValue_WithEmptyKey_ThrowsArgumentException() - { - // Arrange - var stateBag = new AgentSessionStateBag(); - - // Act & Assert - Assert.Throws(() => stateBag.GetValue("")); - } - - [Fact] - public void GetValue_CachesDeserializedValue() - { - // Arrange - var stateBag = new AgentSessionStateBag(); - stateBag.SetValue("key1", "value1"); - - // Act - var result1 = stateBag.GetValue("key1"); - var result2 = stateBag.GetValue("key1"); - - // Assert - Assert.Same(result1, result2); - } - - #endregion - - #region TryGetValue Tests - - [Fact] - public void TryGetValue_WithExistingKey_ReturnsTrueAndValue() - { - // Arrange - var stateBag = new AgentSessionStateBag(); - stateBag.SetValue("key1", "value1"); - - // Act - var found = stateBag.TryGetValue("key1", out var result); - - // Assert - Assert.True(found); - Assert.Equal("value1", result); - } - - [Fact] - public void TryGetValue_WithNonexistentKey_ReturnsFalseAndNull() - { - // Arrange - var stateBag = new AgentSessionStateBag(); - - // Act - var found = stateBag.TryGetValue("nonexistent", out var result); - - // Assert - Assert.False(found); - Assert.Null(result); - } - - [Fact] - public void TryGetValue_WithNullKey_ThrowsArgumentException() - { - // Arrange - var stateBag = new AgentSessionStateBag(); - - // Act & Assert - Assert.Throws(() => stateBag.TryGetValue(null!, out _)); - } - - [Fact] - public void TryGetValue_WithEmptyKey_ThrowsArgumentException() - { - // Arrange - var stateBag = new AgentSessionStateBag(); - - // Act & Assert - Assert.Throws(() => stateBag.TryGetValue("", out _)); - } - - #endregion - - #region Serialize/Deserialize Tests - - [Fact] - public void Serialize_EmptyStateBag_ReturnsEmptyObject() - { - // Arrange - var stateBag = new AgentSessionStateBag(); - - // Act - var json = stateBag.Serialize(); - - // Assert - Assert.Equal(JsonValueKind.Object, json.ValueKind); - } - - [Fact] - public void Serialize_WithStringValue_ReturnsJsonWithValue() - { - // Arrange - var stateBag = new AgentSessionStateBag(); - stateBag.SetValue("stringKey", "stringValue"); - - // Act - var json = stateBag.Serialize(); - - // Assert - Assert.Equal(JsonValueKind.Object, json.ValueKind); - Assert.True(json.TryGetProperty("stringKey", out _)); - } - - [Fact] - public void Deserialize_FromJsonDocument_ReturnsEmptyStateBag() - { - // Arrange - var emptyJson = JsonDocument.Parse("{}").RootElement; - - // Act - var stateBag = AgentSessionStateBag.Deserialize(emptyJson); - - // Assert - Assert.False(stateBag.TryGetValue("nonexistent", out _)); - } - - [Fact] - public void Deserialize_NullElement_ReturnsEmptyStateBag() - { - // Arrange - var nullJson = default(JsonElement); - - // Act - var stateBag = AgentSessionStateBag.Deserialize(nullJson); - - // Assert - Assert.False(stateBag.TryGetValue("nonexistent", out _)); - } - - [Fact] - public void SerializeDeserialize_WithStringValue_Roundtrips() - { - // Arrange - var originalStateBag = new AgentSessionStateBag(); - originalStateBag.SetValue("stringKey", "stringValue"); - - // Act - var json = originalStateBag.Serialize(); - var restoredStateBag = AgentSessionStateBag.Deserialize(json); - - // Assert - Assert.Equal("stringValue", restoredStateBag.GetValue("stringKey")); - } - - #endregion - - #region Thread Safety Tests - - [Fact] - public async System.Threading.Tasks.Task SetValue_MultipleConcurrentWrites_DoesNotThrowAsync() - { - // Arrange - var stateBag = new AgentSessionStateBag(); - var tasks = new System.Threading.Tasks.Task[100]; - - // Act - for (int i = 0; i < 100; i++) - { - int index = i; - tasks[i] = System.Threading.Tasks.Task.Run(() => stateBag.SetValue($"key{index}", $"value{index}")); - } - - await System.Threading.Tasks.Task.WhenAll(tasks); - - // Assert - for (int i = 0; i < 100; i++) - { - Assert.True(stateBag.TryGetValue($"key{i}", out var value)); - Assert.Equal($"value{i}", value); - } - } - - #endregion - - #region Complex Object Tests - - [Fact] - public void SetValue_WithComplexObject_StoresValue() - { - // Arrange - var stateBag = new AgentSessionStateBag(); - var animal = new Animal { Id = 1, FullName = "Buddy", Species = Species.Bear }; - - // Act - stateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); - - // Assert - Animal? result = stateBag.GetValue("animal", TestJsonSerializerContext.Default.Options); - Assert.NotNull(result); - Assert.Equal(1, result.Id); - Assert.Equal("Buddy", result.FullName); - Assert.Equal(Species.Bear, result.Species); - } - - [Fact] - public void GetValue_WithComplexObject_CachesDeserializedValue() - { - // Arrange - var stateBag = new AgentSessionStateBag(); - var animal = new Animal { Id = 2, FullName = "Whiskers", Species = Species.Tiger }; - stateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); - - // Act - Animal? result1 = stateBag.GetValue("animal", TestJsonSerializerContext.Default.Options); - Animal? result2 = stateBag.GetValue("animal", TestJsonSerializerContext.Default.Options); - - // Assert - Assert.Same(result1, result2); - } - - [Fact] - public void TryGetValue_WithComplexObject_ReturnsTrueAndValue() - { - // Arrange - var stateBag = new AgentSessionStateBag(); - var animal = new Animal { Id = 3, FullName = "Goldie", Species = Species.Walrus }; - stateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); - - // Act - bool found = stateBag.TryGetValue("animal", out Animal? result, TestJsonSerializerContext.Default.Options); - - // Assert - Assert.True(found); - Assert.NotNull(result); - Assert.Equal(3, result.Id); - Assert.Equal("Goldie", result.FullName); - Assert.Equal(Species.Walrus, result.Species); - } - - [Fact] - public void SerializeDeserialize_WithComplexObject_Roundtrips() - { - // Arrange - var originalStateBag = new AgentSessionStateBag(); - var animal = new Animal { Id = 4, FullName = "Polly", Species = Species.Bear }; - originalStateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); - - // Act - JsonElement json = originalStateBag.Serialize(); - AgentSessionStateBag restoredStateBag = AgentSessionStateBag.Deserialize(json); - - // Assert - Animal? restoredAnimal = restoredStateBag.GetValue("animal", TestJsonSerializerContext.Default.Options); - Assert.NotNull(restoredAnimal); - Assert.Equal(4, restoredAnimal.Id); - Assert.Equal("Polly", restoredAnimal.FullName); - Assert.Equal(Species.Bear, restoredAnimal.Species); - } - - [Fact] - public void Serialize_WithComplexObject_ReturnsJsonWithProperties() - { - // Arrange - var stateBag = new AgentSessionStateBag(); - var animal = new Animal { Id = 7, FullName = "Spot", Species = Species.Walrus }; - stateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); - - // Act - JsonElement json = stateBag.Serialize(); - - // Assert - Assert.Equal(JsonValueKind.Object, json.ValueKind); - Assert.True(json.TryGetProperty("animal", out JsonElement animalElement)); - Assert.True(animalElement.TryGetProperty("jsonValue", out JsonElement jsonValueElement)); - Assert.Equal(JsonValueKind.Object, jsonValueElement.ValueKind); - Assert.Equal(7, jsonValueElement.GetProperty("id").GetInt32()); - Assert.Equal("Spot", jsonValueElement.GetProperty("fullName").GetString()); - Assert.Equal("Walrus", jsonValueElement.GetProperty("species").GetString()); - } - - #endregion -} diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionTests.cs index b80f0a4fd2..5a776c9fb0 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionTests.cs @@ -11,21 +11,6 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; /// public class AgentSessionTests { - #region StateBag Tests - - [Fact] - public void StateBag_Values_Roundtrips() - { - // Arrange - var session = new TestAgentSession(); - - // Act & Assert - session.StateBag.SetValue("key1", "value1"); - Assert.Equal("value1", session.StateBag.GetValue("key1")); - } - - #endregion - #region GetService Method Tests /// diff --git a/dotnet/tests/Microsoft.Agents.AI.DurableTask.UnitTests/DurableAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.DurableTask.UnitTests/DurableAgentSessionTests.cs index db6ec99058..4bf8ebc718 100644 --- a/dotnet/tests/Microsoft.Agents.AI.DurableTask.UnitTests/DurableAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.DurableTask.UnitTests/DurableAgentSessionTests.cs @@ -15,7 +15,7 @@ public void BuiltInSerialization() JsonElement serializedSession = session.Serialize(); // Expected format: "{\"sessionId\":\"@dafx-test-agent@\"}" - string expectedSerializedSession = $"{{\"sessionId\":\"@dafx-{sessionId.Name}@{sessionId.Key}\",\"stateBag\":{{}}}}"; + string expectedSerializedSession = $"{{\"sessionId\":\"@dafx-{sessionId.Name}@{sessionId.Key}\"}}"; Assert.Equal(expectedSerializedSession, serializedSession.ToString()); DurableAgentSession deserializedSession = DurableAgentSession.Deserialize(serializedSession); @@ -33,7 +33,7 @@ public void STJSerialization() string serializedSession = JsonSerializer.Serialize(session, typeof(DurableAgentSession)); // Expected format: "{\"sessionId\":\"@dafx-test-agent@\"}" - string expectedSerializedSession = $"{{\"sessionId\":\"@dafx-{sessionId.Name}@{sessionId.Key}\",\"StateBag\":{{}}}}"; + string expectedSerializedSession = $"{{\"sessionId\":\"@dafx-{sessionId.Name}@{sessionId.Key}\"}}"; Assert.Equal(expectedSerializedSession, serializedSession); DurableAgentSession? deserializedSession = JsonSerializer.Deserialize(serializedSession); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs index 2acfaf1a10..fd311f9225 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs @@ -152,33 +152,6 @@ public async Task VerifyDeserializeWithAIContextProviderAsync() Assert.Same(session.AIContextProvider, mockProvider.Object); } - [Fact] - public async Task VerifyDeserializeWithStateBagAsync() - { - // Arrange - var json = JsonSerializer.Deserialize(""" - { - "conversationId": "TestConvId", - "stateBag": { - "dog": { - "jsonValue": { - "name": "Fido" - } - } - } - } - """, TestJsonSerializerContext.Default.JsonElement); - Mock mockProvider = new(); - - // Act - var session = await ChatClientAgentSession.DeserializeAsync(json, aiContextProviderFactory: (_, _, _) => new(mockProvider.Object)); - - // Assert - var dog = session.StateBag.GetValue("dog", TestJsonSerializerContext.Default.Options); - Assert.NotNull(dog); - Assert.Equal("Fido", dog.Name); - } - [Fact] public async Task DeserializeWithInvalidJsonThrowsAsync() { @@ -276,27 +249,6 @@ public void VerifySessionSerializationWithWithAIContextProvider() mockProvider.Verify(m => m.Serialize(It.IsAny()), Times.Once); } - [Fact] - public void VerifySessionSerializationWithWithStateBag() - { - // Arrange - var session = new ChatClientAgentSession(); - session.StateBag.SetValue("dog", new Animal { Name = "Fido" }, TestJsonSerializerContext.Default.Options); - - // Act - var json = session.Serialize(); - - // Assert - Assert.Equal(JsonValueKind.Object, json.ValueKind); - Assert.True(json.TryGetProperty("stateBag", out var stateBagProperty)); - Assert.Equal(JsonValueKind.Object, stateBagProperty.ValueKind); - Assert.True(stateBagProperty.TryGetProperty("dog", out var dogProperty)); - Assert.Equal(JsonValueKind.Object, dogProperty.ValueKind); - Assert.True(dogProperty.TryGetProperty("jsonValue", out var dogJsonValueProperty)); - Assert.True(dogJsonValueProperty.TryGetProperty("name", out var nameProperty)); - Assert.Equal("Fido", nameProperty.GetString()); - } - /// /// Verify session serialization to JSON with custom options. /// @@ -337,27 +289,6 @@ public void VerifySessionSerializationWithCustomOptions() #endregion Serialize Tests - #region StateBag Roundtrip Tests - - [Fact] - public async Task VerifyStateBagRoundtripsAsync() - { - // Arrange - var session = new ChatClientAgentSession(); - session.StateBag.SetValue("dog", new Animal { Name = "Fido" }, TestJsonSerializerContext.Default.Options); - - // Act - var serializedSession = session.Serialize(); - var deserializedSession = await ChatClientAgentSession.DeserializeAsync(serializedSession); - - // Assert - var dog = deserializedSession.StateBag.GetValue("dog", TestJsonSerializerContext.Default.Options); - Assert.NotNull(dog); - Assert.Equal("Fido", dog.Name); - } - - #endregion - #region GetService Tests [Fact] From 4a658ac940fbd508e9e7e3c83e556c2c658464f5 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Thu, 5 Feb 2026 14:42:32 +0000 Subject: [PATCH 3/5] Apply suggestion from @rogerbarreto Co-authored-by: Roger Barreto <19890735+rogerbarreto@users.noreply.github.com> --- .../src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs index 8428d46f9b..bce1bf2cb8 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs @@ -140,7 +140,7 @@ public InvokingContext( { this.Agent = Throw.IfNull(agent); this.Session = session; - this.RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages)); + this.RequestMessages = Throw.IfNull(requestMessages); } /// From debb5ab57c4b9533ef7f45bd76c18dad777c7feb Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Thu, 5 Feb 2026 14:43:52 +0000 Subject: [PATCH 4/5] Apply suggestion from @westey-m --- .../src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs index cecfa92e8f..352bae3355 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs @@ -154,7 +154,7 @@ public InvokingContext( { this.Agent = Throw.IfNull(agent); this.Session = session; - this.RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages)); + this.RequestMessages = Throw.IfNull(requestMessages); } /// From c5eddc2a334b840d00f614568343321fb4384768 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Thu, 5 Feb 2026 14:44:08 +0000 Subject: [PATCH 5/5] Apply suggestion from @westey-m --- .../src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs index bce1bf2cb8..f79b0a851d 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs @@ -188,7 +188,7 @@ public InvokedContext( { this.Agent = Throw.IfNull(agent); this.Session = session; - this.RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages)); + this.RequestMessages = Throw.IfNull(requestMessages); this.AIContextProviderMessages = aiContextProviderMessages; }