diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs index dd294040c4..bebb3a01d1 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs @@ -6,6 +6,7 @@ using System.Runtime.CompilerServices; using System.Text.Json; +using System.Text.Json.Serialization; using Microsoft.Agents.AI; using Microsoft.Extensions.AI; using SampleApp; @@ -28,6 +29,8 @@ internal sealed class UpperCaseParrotAgent : AIAgent { public override string? Name => "UpperCaseParrotAgent"; + public readonly ChatHistoryProvider ChatHistoryProvider = new InMemoryChatHistoryProvider(); + protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => new(new CustomAgentSession()); @@ -38,11 +41,11 @@ protected override JsonElement SerializeSessionCore(AgentSession session, JsonSe throw new ArgumentException($"The provided session is not of type {nameof(CustomAgentSession)}.", nameof(session)); } - return typedSession.Serialize(jsonSerializerOptions); + return JsonSerializer.SerializeToElement(typedSession, jsonSerializerOptions); } protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) - => new(new CustomAgentSession(serializedState, jsonSerializerOptions)); + => new(serializedState.Deserialize(jsonSerializerOptions)!); protected override async Task RunCoreAsync(IEnumerable messages, AgentSession? session = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { @@ -56,7 +59,7 @@ protected override async Task RunCoreAsync(IEnumerable responseMessages = CloneAndToUpperCase(messages, this.Name).ToList(); @@ -66,7 +69,7 @@ protected override async Task RunCoreAsync(IEnumerable RunCoreStreamingA // Get existing messages from the store var invokingContext = new ChatHistoryProvider.InvokingContext(this, session, messages); - var storeMessages = await typedSession.ChatHistoryProvider.InvokingAsync(invokingContext, cancellationToken); + var storeMessages = await this.ChatHistoryProvider.InvokingAsync(invokingContext, cancellationToken); // Clone the input messages and turn them into response messages with upper case text. List responseMessages = CloneAndToUpperCase(messages, this.Name).ToList(); @@ -98,7 +101,7 @@ protected override async IAsyncEnumerable RunCoreStreamingA { ResponseMessages = responseMessages }; - await typedSession.ChatHistoryProvider.InvokedAsync(invokedContext, cancellationToken); + await this.ChatHistoryProvider.InvokedAsync(invokedContext, cancellationToken); foreach (var message in responseMessages) { @@ -140,15 +143,16 @@ private static IEnumerable CloneAndToUpperCase(IEnumerable /// A session type for our custom agent that only supports in memory storage of messages. /// - internal sealed class CustomAgentSession : InMemoryAgentSession + internal sealed class CustomAgentSession : AgentSession { - internal CustomAgentSession() { } - - internal CustomAgentSession(JsonElement serializedSessionState, JsonSerializerOptions? jsonSerializerOptions = null) - : base(serializedSessionState, jsonSerializerOptions) { } + internal CustomAgentSession() + { + } - internal new JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => base.Serialize(jsonSerializerOptions); + [JsonConstructor] + internal CustomAgentSession(AgentSessionStateBag stateBag) : base(stateBag) + { + } } } } diff --git a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step01_ChatHistoryMemory/Program.cs b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step01_ChatHistoryMemory/Program.cs index ec55abf3a4..54ee8b5008 100644 --- a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step01_ChatHistoryMemory/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step01_ChatHistoryMemory/Program.cs @@ -34,16 +34,21 @@ { ChatOptions = new() { Instructions = "You are good at telling jokes." }, Name = "Joker", - AIContextProviderFactory = (ctx, ct) => new ValueTask(new ChatHistoryMemoryProvider( + AIContextProvider = new ChatHistoryMemoryProvider( vectorStore, collectionName: "chathistory", vectorDimensions: 3072, - // Configure the scope values under which chat messages will be stored. - // In this case, we are using a fixed user ID and a unique session ID for each new session. - storageScope: new() { UserId = "UID1", SessionId = Guid.NewGuid().ToString() }, - // Configure the scope which would be used to search for relevant prior messages. - // In this case, we are searching for any messages for the user across all sessions. - searchScope: new() { UserId = "UID1" })) + // Callback to configure the initial state of the ChatHistoryMemoryProvider. + // The ChatHistoryMemoryProvider stores its state in the AgentSession and this callback + // will be called whenever the ChatHistoryMemoryProvider cannot find existing state in the session, + // typically the first time it is used with a new session. + session => new ChatHistoryMemoryProvider.State( + // Configure the scope values under which chat messages will be stored. + // In this case, we are using a fixed user ID and a unique session ID for each new session. + storageScope: new() { UserId = "UID1", SessionId = Guid.NewGuid().ToString() }, + // Configure the scope which would be used to search for relevant prior messages. + // In this case, we are searching for any messages for the user across all sessions. + searchScope: new() { UserId = "UID1" })) }); // Start a new session for the agent conversation. diff --git a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs index 4a0dbe0839..588c79ba9a 100644 --- a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs @@ -31,20 +31,21 @@ .AsAIAgent(new ChatClientAgentOptions() { ChatOptions = new() { Instructions = "You are a friendly travel assistant. Use known memories about the user when responding, and do not invent details." }, - AIContextProviderFactory = (ctx, ct) => new ValueTask(ctx.SerializedState.ValueKind is not JsonValueKind.Null and not JsonValueKind.Undefined - // If each session should have its own Mem0 scope, you can create a new id per session here: - // ? new Mem0Provider(mem0HttpClient, new Mem0ProviderScope() { ThreadId = Guid.NewGuid().ToString() }) - // In this case we are storing memories scoped by application and user instead so that memories are retained across threads. - ? new Mem0Provider(mem0HttpClient, new Mem0ProviderScope() { ApplicationId = "getting-started-agents", UserId = "sample-user" }) - // For cases where we are restoring from serialized state: - : new Mem0Provider(mem0HttpClient, ctx.SerializedState, ctx.JsonSerializerOptions)) + // The stateInitializer can be used to customize the Mem0 scope per session and it will be called each time a session + // is encountered by the Mem0Provider that does not already have Mem0Provider state stored on the session. + // If each session should have its own Mem0 scope, you can create a new id per session via the stateInitializer, e.g.: + // new Mem0Provider(mem0HttpClient, stateInitializer: _ => new(new Mem0ProviderScope() { ThreadId = Guid.NewGuid().ToString() })) + // In our case we are storing memories scoped by application and user instead so that memories are retained across threads. + AIContextProvider = new Mem0Provider(mem0HttpClient, stateInitializer: _ => new(new Mem0ProviderScope() { ApplicationId = "getting-started-agents", UserId = "sample-user" })) }); AgentSession session = await agent.CreateSessionAsync(); // Clear any existing memories for this scope to demonstrate fresh behavior. -Mem0Provider mem0Provider = session.GetService()!; -await mem0Provider.ClearStoredMemoriesAsync(); +// Note that the ClearStoredMemoriesAsync method will clear memories +// using the scope stored in the session, or provided via the stateInitializer. +Mem0Provider mem0Provider = agent.GetService()!; +await mem0Provider.ClearStoredMemoriesAsync(session); Console.WriteLine(await agent.RunAsync("Hi there! My name is Taylor and I'm planning a hiking trip to Patagonia in November.", session)); Console.WriteLine(await agent.RunAsync("I'm travelling with my sister and we love finding scenic viewpoints.", session)); diff --git a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs index edd9248ff9..2945f46df8 100644 --- a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs @@ -33,7 +33,7 @@ AIAgent agent = chatClient.AsAIAgent(new ChatClientAgentOptions() { ChatOptions = new() { Instructions = "You are a friendly assistant. Always address the user by their name." }, - AIContextProviderFactory = (ctx, ct) => new ValueTask(new UserInfoMemory(chatClient.AsIChatClient(), ctx.SerializedState, ctx.JsonSerializerOptions)) + AIContextProvider = new UserInfoMemory(chatClient.AsIChatClient()) }); // Create a new session for the conversation. @@ -55,10 +55,10 @@ var deserializedSession = await agent.DeserializeSessionAsync(sesionElement); Console.WriteLine(await agent.RunAsync("What is my name and age?", deserializedSession)); -Console.WriteLine("\n>> Read memories from memory component\n"); +Console.WriteLine("\n>> Read memories using memory component\n"); -// It's possible to access the memory component via the session's GetService method. -var userInfo = deserializedSession.GetService()?.UserInfo; +// It's possible to access the memory component via the agent's GetService method. +var userInfo = agent.GetService()?.GetUserInfo(deserializedSession); // Output the user info that was captured by the memory component. Console.WriteLine($"MEMORY - User Name: {userInfo?.UserName}"); @@ -66,12 +66,12 @@ Console.WriteLine("\n>> Use new session with previously created memories\n"); -// It is also possible to set the memories in a memory component on an individual session. +// It is also possible to set the memories using a memory component on an individual session. // This is useful if we want to start a new session, but have it share the same memories as a previous session. var newSession = await agent.CreateSessionAsync(); -if (userInfo is not null && newSession.GetService() is UserInfoMemory newSessionMemory) +if (userInfo is not null && agent.GetService() is UserInfoMemory newSessionMemory) { - newSessionMemory.UserInfo = userInfo; + newSessionMemory.SetUserInfo(newSession, userInfo); } // Invoke the agent and output the text result. @@ -86,28 +86,27 @@ namespace SampleApp internal sealed class UserInfoMemory : AIContextProvider { private readonly IChatClient _chatClient; + private readonly Func _stateInitializer; - public UserInfoMemory(IChatClient chatClient, UserInfo? userInfo = null) + public UserInfoMemory(IChatClient chatClient, Func? stateInitializer = null) { this._chatClient = chatClient; - this.UserInfo = userInfo ?? new UserInfo(); + this._stateInitializer = stateInitializer ?? (_ => new UserInfo()); } - public UserInfoMemory(IChatClient chatClient, JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null) - { - this._chatClient = chatClient; + public UserInfo GetUserInfo(AgentSession session) + => session.StateBag.GetValue(nameof(UserInfoMemory)) ?? new UserInfo(); - this.UserInfo = serializedState.ValueKind == JsonValueKind.Object ? - serializedState.Deserialize(jsonSerializerOptions)! : - new UserInfo(); - } - - public UserInfo UserInfo { get; set; } + public void SetUserInfo(AgentSession session, UserInfo userInfo) + => session.StateBag.SetValue(nameof(UserInfoMemory), userInfo); protected override async ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { + var userInfo = context.Session?.StateBag.GetValue(nameof(UserInfoMemory)) + ?? this._stateInitializer.Invoke(context.Session); + // Try and extract the user name and age from the message if we don't have it already and it's a user message. - if ((this.UserInfo.UserName is null || this.UserInfo.UserAge is null) && context.RequestMessages.Any(x => x.Role == ChatRole.User)) + if ((userInfo.UserName is null || userInfo.UserAge is null) && context.RequestMessages.Any(x => x.Role == ChatRole.User)) { var result = await this._chatClient.GetResponseAsync( context.RequestMessages, @@ -117,36 +116,36 @@ protected override async ValueTask InvokedCoreAsync(InvokedContext context, Canc }, cancellationToken: cancellationToken); - this.UserInfo.UserName ??= result.Result.UserName; - this.UserInfo.UserAge ??= result.Result.UserAge; + userInfo.UserName ??= result.Result.UserName; + userInfo.UserAge ??= result.Result.UserAge; } + + context.Session?.StateBag.SetValue(nameof(UserInfoMemory), userInfo); } protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { + var userInfo = context.Session?.StateBag.GetValue(nameof(UserInfoMemory)) + ?? this._stateInitializer.Invoke(context.Session); + StringBuilder instructions = new(); // If we don't already know the user's name and age, add instructions to ask for them, otherwise just provide what we have to the context. instructions .AppendLine( - this.UserInfo.UserName is null ? + userInfo.UserName is null ? "Ask the user for their name and politely decline to answer any questions until they provide it." : - $"The user's name is {this.UserInfo.UserName}.") + $"The user's name is {userInfo.UserName}.") .AppendLine( - this.UserInfo.UserAge is null ? + userInfo.UserAge is null ? "Ask the user for their age and politely decline to answer any questions until they provide it." : - $"The user's age is {this.UserInfo.UserAge}."); + $"The user's age is {userInfo.UserAge}."); return new ValueTask(new AIContext { Instructions = instructions.ToString() }); } - - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - return JsonSerializer.SerializeToElement(this.UserInfo, jsonSerializerOptions); - } } internal sealed class UserInfo diff --git a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs index ca798aa333..a3dfae995f 100644 --- a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs @@ -62,12 +62,12 @@ .AsAIAgent(new ChatClientAgentOptions { ChatOptions = new() { Instructions = "You are a helpful support specialist for Contoso Outdoors. Answer questions using the provided context and cite the source document when available." }, - AIContextProviderFactory = (ctx, ct) => new ValueTask(new TextSearchProvider(SearchAdapter, ctx.SerializedState, ctx.JsonSerializerOptions, textSearchOptions)), + AIContextProvider = new TextSearchProvider(SearchAdapter, textSearchOptions), // Since we are using ChatCompletion which stores chat history locally, we can also add a message removal policy // that removes messages produced by the TextSearchProvider before they are added to the chat history, so that // we don't bloat chat history with all the search result messages. - ChatHistoryProviderFactory = (ctx, ct) => new ValueTask(new InMemoryChatHistoryProvider(ctx.SerializedState, ctx.JsonSerializerOptions) - .WithAIContextProviderMessageRemoval()), + ChatHistoryProvider = new InMemoryChatHistoryProvider() + .WithAIContextProviderMessageRemoval(), }); AgentSession session = await agent.CreateSessionAsync(); diff --git a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step02_CustomVectorStoreRAG/Program.cs b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step02_CustomVectorStoreRAG/Program.cs index 3648ccc898..4e135db78c 100644 --- a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step02_CustomVectorStoreRAG/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step02_CustomVectorStoreRAG/Program.cs @@ -71,7 +71,7 @@ .AsAIAgent(new ChatClientAgentOptions { ChatOptions = new() { Instructions = "You are a helpful support specialist for the Microsoft Agent Framework. Answer questions using the provided context and cite the source document when available. Keep responses brief." }, - AIContextProviderFactory = (ctx, ct) => new ValueTask(new TextSearchProvider(SearchAdapter, ctx.SerializedState, ctx.JsonSerializerOptions, textSearchOptions)) + AIContextProvider = new TextSearchProvider(SearchAdapter, textSearchOptions) }); AgentSession session = await agent.CreateSessionAsync(); diff --git a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step03_CustomRAGDataSource/Program.cs b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step03_CustomRAGDataSource/Program.cs index bcc823de46..6db583ca41 100644 --- a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step03_CustomRAGDataSource/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step03_CustomRAGDataSource/Program.cs @@ -29,7 +29,7 @@ .AsAIAgent(new ChatClientAgentOptions { ChatOptions = new() { Instructions = "You are a helpful support specialist for Contoso Outdoors. Answer questions using the provided context and cite the source document when available." }, - AIContextProviderFactory = (ctx, ct) => new ValueTask(new TextSearchProvider(MockSearchAsync, ctx.SerializedState, ctx.JsonSerializerOptions, textSearchOptions)) + AIContextProvider = new TextSearchProvider(MockSearchAsync, textSearchOptions) }); AgentSession session = await agent.CreateSessionAsync(); diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs index 33176d8fdf..a044ac496c 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs @@ -3,7 +3,7 @@ #pragma warning disable CA1869 // Cache and reuse 'JsonSerializerOptions' instances // This sample shows how to create and use a simple AI agent with custom ChatHistoryProvider that stores chat history in a custom storage location. -// The state of the custom ChatHistoryProvider (SessionDbKey) is stored with the agent session, so that when the session is resumed later, +// The state of the custom ChatHistoryProvider (SessionDbKey) is stored in the AgentSession's StateBag, so that when the session is resumed later, // the chat history can be retrieved from the custom storage location. using System.Text.Json; @@ -33,11 +33,8 @@ { ChatOptions = new() { Instructions = "You are good at telling jokes." }, Name = "Joker", - ChatHistoryProviderFactory = (ctx, ct) => new ValueTask( - // Create a new ChatHistoryProvider for this agent that stores chat history in a vector store. - // Each session must get its own copy of the VectorChatHistoryProvider, since the provider - // also contains the id that the chat history is stored under. - new VectorChatHistoryProvider(vectorStore, ctx.SerializedState, ctx.JsonSerializerOptions)) + // Create a new ChatHistoryProvider for this agent that stores chat history in a vector store. + ChatHistoryProvider = new VectorChatHistoryProvider(vectorStore) }); // Start a new session for the agent conversation. @@ -63,46 +60,70 @@ // Run the agent with the session that stores chat history in the vector store a second time. Console.WriteLine(await agent.RunAsync("Now tell the same joke in the voice of a pirate, and add some emojis to the joke.", resumedSession)); -// We can access the VectorChatHistoryProvider via the session's GetService method if we need to read the key under which chat history is stored. -var chatHistoryProvider = resumedSession.GetService()!; -Console.WriteLine($"\nSession is stored in vector store under key: {chatHistoryProvider.SessionDbKey}"); +// We can access the VectorChatHistoryProvider via the agent's GetService method +// if we need to read the key under which chat history is stored. The key is stored +// in the session state, and therefore we need to provide the session when reading it. +var chatHistoryProvider = agent.GetService()!; +Console.WriteLine($"\nSession is stored in vector store under key: {chatHistoryProvider.GetSessionDbKey(resumedSession)}"); namespace SampleApp { /// /// A sample implementation of that stores chat history in a vector store. + /// State (the session DB key) is stored in the so it roundtrips + /// automatically with session serialization. /// internal sealed class VectorChatHistoryProvider : ChatHistoryProvider { + private const string DefaultStateBagKey = "VectorChatHistoryProvider.State"; + private readonly VectorStore _vectorStore; + private readonly Func _stateInitializer; + private readonly string _stateKey; - public VectorChatHistoryProvider(VectorStore vectorStore, JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null) + public VectorChatHistoryProvider( + VectorStore vectorStore, + Func? stateInitializer = null, + string? stateKey = null) { this._vectorStore = vectorStore ?? throw new ArgumentNullException(nameof(vectorStore)); + this._stateInitializer = stateInitializer ?? (_ => new State(Guid.NewGuid().ToString("N"))); + this._stateKey = stateKey ?? DefaultStateBagKey; + } + + public string GetSessionDbKey(AgentSession session) + => this.GetOrInitializeState(session).SessionDbKey; + + private State GetOrInitializeState(AgentSession? session) + { + if (session?.StateBag.TryGetValue(this._stateKey, out var state) is true && state is not null) + { + return state; + } - if (serializedState.ValueKind is JsonValueKind.String) + state = this._stateInitializer(session); + if (session is not null) { - // Here we can deserialize the session id so that we can access the same messages as before the suspension. - this.SessionDbKey = serializedState.Deserialize(); + session.StateBag.SetValue(this._stateKey, state); } - } - public string? SessionDbKey { get; private set; } + return state; + } protected override async ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { + var state = this.GetOrInitializeState(context.Session); var collection = this._vectorStore.GetCollection("ChatHistory"); await collection.EnsureCollectionExistsAsync(cancellationToken); var records = await collection .GetAsync( - x => x.SessionId == this.SessionDbKey, 10, + x => x.SessionId == state.SessionDbKey, 10, new() { OrderBy = x => x.Descending(y => y.Timestamp) }, cancellationToken) .ToListAsync(cancellationToken); - var messages = records.ConvertAll(x => JsonSerializer.Deserialize(x.SerializedMessage!)!) -; + var messages = records.ConvertAll(x => JsonSerializer.Deserialize(x.SerializedMessage!)!); messages.Reverse(); return messages; } @@ -115,7 +136,7 @@ protected override async ValueTask InvokedCoreAsync(InvokedContext context, Canc return; } - this.SessionDbKey ??= Guid.NewGuid().ToString("N"); + var state = this.GetOrInitializeState(context.Session); var collection = this._vectorStore.GetCollection("ChatHistory"); await collection.EnsureCollectionExistsAsync(cancellationToken); @@ -126,17 +147,26 @@ protected override async ValueTask InvokedCoreAsync(InvokedContext context, Canc await collection.UpsertAsync(allNewMessages.Select(x => new ChatHistoryItem() { - Key = this.SessionDbKey + x.MessageId, + Key = state.SessionDbKey + x.MessageId, Timestamp = DateTimeOffset.UtcNow, - SessionId = this.SessionDbKey, + SessionId = state.SessionDbKey, SerializedMessage = JsonSerializer.Serialize(x), MessageText = x.Text }), cancellationToken); } - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) => - // We have to serialize the session id, so that on deserialization we can retrieve the messages using the same session id. - JsonSerializer.SerializeToElement(this.SessionDbKey); + /// + /// Represents the per-session state stored in the . + /// + public sealed class State + { + public State(string sessionDbKey) + { + this.SessionDbKey = sessionDbKey ?? throw new ArgumentNullException(nameof(sessionDbKey)); + } + + public string SessionDbKey { get; } + } /// /// The data structure used to store chat history items in the vector store. diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs index f9a5a1fc01..1989dbc2d6 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs @@ -24,7 +24,7 @@ { ChatOptions = new() { Instructions = "You are good at telling jokes." }, Name = "Joker", - ChatHistoryProviderFactory = (ctx, ct) => new ValueTask(new InMemoryChatHistoryProvider(new MessageCountingChatReducer(2), ctx.SerializedState, ctx.JsonSerializerOptions)) + ChatHistoryProvider = new InMemoryChatHistoryProvider(new() { ChatReducer = new MessageCountingChatReducer(2) }) }); AgentSession session = await agent.CreateSessionAsync(); @@ -33,7 +33,10 @@ Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.", session)); // Get the chat history to see how many messages are stored. -IList? chatHistory = session.GetService>(); +// We can use the ChatHistoryProvider, that is also used by the agent, to read the +// chat history from the session state, and see how the reducer is affecting the stored messages. +var provider = agent.GetService(); +List? chatHistory = provider?.GetMessages(session); Console.WriteLine($"\nChat history has {chatHistory?.Count} messages.\n"); // Invoke the agent a few more times. diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs index 055172fa82..a8fb2f3df5 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs @@ -7,7 +7,6 @@ #pragma warning disable CA1869 // Cache and reuse 'JsonSerializerOptions' instances -using System.ComponentModel; using System.Text; using System.Text.Json; using Azure.AI.OpenAI; @@ -45,16 +44,16 @@ You are a helpful personal assistant. You manage a TODO list for the user. When the user has completed one of the tasks it can be removed from the TODO list. Only provide the list of TODO items if asked. You remind users of upcoming calendar events when the user interacts with you. """ }, - ChatHistoryProviderFactory = (ctx, ct) => new ValueTask(new InMemoryChatHistoryProvider() + ChatHistoryProvider = new InMemoryChatHistoryProvider() // Use WithAIContextProviderMessageRemoval, so that we don't store the messages from the AI context provider in the chat history. // You may want to store these messages, depending on their content and your requirements. - .WithAIContextProviderMessageRemoval()), + .WithAIContextProviderMessageRemoval(), // Add an AI context provider that maintains a todo list for the agent and one that provides upcoming calendar entries. // Wrap these in an AI context provider that aggregates the other two. - AIContextProviderFactory = (ctx, ct) => new ValueTask(new AggregatingAIContextProvider([ - AggregatingAIContextProvider.CreateFactory((jsonElement, jsonSerializerOptions) => new TodoListAIContextProvider(jsonElement, jsonSerializerOptions)), - AggregatingAIContextProvider.CreateFactory((_, _) => new CalendarSearchAIContextProvider(loadNextThreeCalendarEvents)) - ], ctx.SerializedState, ctx.JsonSerializerOptions)), + AIContextProvider = new AggregatingAIContextProvider([ + new TodoListAIContextProvider(), + new CalendarSearchAIContextProvider(loadNextThreeCalendarEvents) + ]), }); // Invoke the agent and output the text result. @@ -80,51 +79,62 @@ namespace SampleApp /// internal sealed class TodoListAIContextProvider : AIContextProvider { - private readonly List _todoItems = new(); + private const string StateKey = nameof(TodoListAIContextProvider); - public TodoListAIContextProvider(JsonElement jsonElement, JsonSerializerOptions? jsonSerializerOptions = null) - { - // Only try and restore the state if we got an array, since any other json would be invalid or undefined/null meaning - // it's the first time we are running. - if (jsonElement.ValueKind == JsonValueKind.Array) - { - this._todoItems = JsonSerializer.Deserialize>(jsonElement.GetRawText(), jsonSerializerOptions) ?? new List(); - } - } + private static List GetTodoItems(AgentSession? session) + => session?.StateBag.GetValue>(StateKey) ?? new List(); + + private static void SetTodoItems(AgentSession? session, List items) + => session?.StateBag.SetValue(StateKey, items); protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { + var todoItems = GetTodoItems(context.Session); + StringBuilder outputMessageBuilder = new(); outputMessageBuilder.AppendLine("Your todo list contains the following items:"); - if (this._todoItems.Count == 0) + if (todoItems.Count == 0) { outputMessageBuilder.AppendLine(" (no items)"); } else { - for (int i = 0; i < this._todoItems.Count; i++) + for (int i = 0; i < todoItems.Count; i++) { - outputMessageBuilder.AppendLine($"{i}. {this._todoItems[i]}"); + outputMessageBuilder.AppendLine($"{i}. {todoItems[i]}"); } } return new ValueTask(new AIContext { - Tools = [AIFunctionFactory.Create(this.AddTodoItem), AIFunctionFactory.Create(this.RemoveTodoItem)], + Tools = + [ + AIFunctionFactory.Create((string item) => AddTodoItem(context.Session, item), "AddTodoItem", "Adds an item to the todo list."), + AIFunctionFactory.Create((int index) => RemoveTodoItem(context.Session, index), "RemoveTodoItem", "Removes an item from the todo list. Index is zero based.") + ], Messages = [new MEAI.ChatMessage(ChatRole.User, outputMessageBuilder.ToString())] }); } - [Description("Adds an item to the todo list. Index is zero based.")] - private void RemoveTodoItem(int index) => - this._todoItems.RemoveAt(index); + private static void RemoveTodoItem(AgentSession? session, int index) + { + var items = GetTodoItems(session); + items.RemoveAt(index); + SetTodoItems(session, items); + } - private void AddTodoItem(string item) => - this._todoItems.Add(string.IsNullOrWhiteSpace(item) ? throw new ArgumentException("Item must have a value") : item); + private static void AddTodoItem(AgentSession? session, string item) + { + if (string.IsNullOrWhiteSpace(item)) + { + throw new ArgumentException("Item must have a value"); + } - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) => - JsonSerializer.SerializeToElement(this._todoItems, jsonSerializerOptions); + var items = GetTodoItems(session); + items.Add(item); + SetTodoItems(session, items); + } } /// @@ -155,28 +165,15 @@ protected override async ValueTask InvokingCoreAsync(InvokingContext /// /// An which aggregates multiple AI context providers into one. - /// Serialized state for the different providers are stored under their type name. /// Tools and messages from all providers are combined, and instructions are concatenated. /// internal sealed class AggregatingAIContextProvider : AIContextProvider { - private readonly List _providers = new(); + private readonly List _providers; - public AggregatingAIContextProvider(ProviderFactory[] providerFactories, JsonElement jsonElement, JsonSerializerOptions? jsonSerializerOptions) + public AggregatingAIContextProvider(List providers) { - // We received a json object, so let's check if it has some previously serialized state that we can use. - if (jsonElement.ValueKind == JsonValueKind.Object) - { - this._providers = providerFactories - .Select(factory => factory.FactoryMethod(jsonElement.TryGetProperty(factory.ProviderType.Name, out var prop) ? prop : default, jsonSerializerOptions)) - .ToList(); - return; - } - - // We didn't receive any valid json, so we can just construct fresh providers. - this._providers = providerFactories - .Select(factory => factory.FactoryMethod(default, jsonSerializerOptions)) - .ToList(); + this._providers = providers; } protected override async ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) @@ -193,36 +190,5 @@ protected override async ValueTask InvokingCoreAsync(InvokingContext Instructions = string.Join("\n", results.Select(r => r.Instructions).Where(s => !string.IsNullOrEmpty(s))) }; } - - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - Dictionary elements = new(); - foreach (var provider in this._providers) - { - JsonElement element = provider.Serialize(jsonSerializerOptions); - - // Don't try to store state for any providers that aren't producing any. - if (element.ValueKind != JsonValueKind.Undefined && element.ValueKind != JsonValueKind.Null) - { - elements[provider.GetType().Name] = element; - } - } - - return JsonSerializer.SerializeToElement(elements, jsonSerializerOptions); - } - - public static ProviderFactory CreateFactory(Func factoryMethod) - where TProviderType : AIContextProvider => new() - { - FactoryMethod = (jsonElement, jsonSerializerOptions) => factoryMethod(jsonElement, jsonSerializerOptions), - ProviderType = typeof(TProviderType) - }; - - public readonly struct ProviderFactory - { - public Func FactoryMethod { get; init; } - - public Type ProviderType { get; init; } - } } } diff --git a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs index 533b50c8fe..9d819c90d8 100644 --- a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs @@ -80,7 +80,7 @@ protected override JsonElement SerializeSessionCore(AgentSession session, JsonSe /// protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) - => new(new A2AAgentSession(serializedState, jsonSerializerOptions)); + => new(A2AAgentSession.Deserialize(serializedState, jsonSerializerOptions)); /// protected override async Task RunCoreAsync(IEnumerable messages, AgentSession? session = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) diff --git a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgentSession.cs b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgentSession.cs index cac9b43a30..045abc736a 100644 --- a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgentSession.cs @@ -1,66 +1,61 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Diagnostics; using System.Text.Json; +using System.Text.Json.Serialization; namespace Microsoft.Agents.AI.A2A; /// /// Session for A2A based agents. /// +[DebuggerDisplay("{DebuggerDisplay,nq}")] public sealed class A2AAgentSession : AgentSession { internal A2AAgentSession() { } - internal A2AAgentSession(JsonElement serializedSessionState, JsonSerializerOptions? jsonSerializerOptions = null) + [JsonConstructor] + internal A2AAgentSession(string? contextId, string? taskId, AgentSessionStateBag? stateBag) : base(stateBag ?? new()) { - if (serializedSessionState.ValueKind != JsonValueKind.Object) - { - throw new ArgumentException("The serialized session state must be a JSON object.", nameof(serializedSessionState)); - } - - var state = serializedSessionState.Deserialize( - A2AJsonUtilities.DefaultOptions.GetTypeInfo(typeof(A2AAgentSessionState))) as A2AAgentSessionState; - - if (state?.ContextId is string contextId) - { - this.ContextId = contextId; - } - - if (state?.TaskId is string taskId) - { - this.TaskId = taskId; - } + this.ContextId = contextId; + this.TaskId = taskId; } /// /// Gets the ID for the current conversation with the A2A agent. /// + [JsonPropertyName("contextId")] public string? ContextId { get; internal set; } /// /// Gets the ID for the task the agent is currently working on. /// + [JsonPropertyName("taskId")] public string? TaskId { get; internal set; } /// internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { - var state = new A2AAgentSessionState - { - ContextId = this.ContextId, - TaskId = this.TaskId - }; - - return JsonSerializer.SerializeToElement(state, A2AJsonUtilities.DefaultOptions.GetTypeInfo(typeof(A2AAgentSessionState))); + var jso = jsonSerializerOptions ?? A2AJsonUtilities.DefaultOptions; + return JsonSerializer.SerializeToElement(this, jso.GetTypeInfo(typeof(A2AAgentSession))); } - internal sealed class A2AAgentSessionState + internal static A2AAgentSession Deserialize(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null) { - public string? ContextId { get; set; } + if (serializedState.ValueKind != JsonValueKind.Object) + { + throw new ArgumentException("The serialized session state must be a JSON object.", nameof(serializedState)); + } - public string? TaskId { get; set; } + var jso = jsonSerializerOptions ?? A2AJsonUtilities.DefaultOptions; + return serializedState.Deserialize(jso.GetTypeInfo(typeof(A2AAgentSession))) as A2AAgentSession + ?? new A2AAgentSession(); } + + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private string DebuggerDisplay => + $"ContextId = {this.ContextId}, TaskId = {this.TaskId}, StateBag Count = {this.StateBag.Count}"; } diff --git a/dotnet/src/Microsoft.Agents.AI.A2A/A2AJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI.A2A/A2AJsonUtilities.cs index 5f079d9573..3c25e350ae 100644 --- a/dotnet/src/Microsoft.Agents.AI.A2A/A2AJsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI.A2A/A2AJsonUtilities.cs @@ -74,7 +74,7 @@ private static JsonSerializerOptions CreateDefaultOptions() NumberHandling = JsonNumberHandling.AllowReadingFromString)] // A2A agent types - [JsonSerializable(typeof(A2AAgentSession.A2AAgentSessionState))] + [JsonSerializable(typeof(A2AAgentSession))] [ExcludeFromCodeCoverage] private sealed partial class JsonContext : JsonSerializerContext; } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs index a4b606e6a1..e45d56a18a 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -130,13 +129,18 @@ public async ValueTask InvokingAsync(InvokingContext context, Cancell /// /// Implementers can use the request and response messages in the provided to: /// - /// Update internal state based on conversation outcomes + /// Update state based on conversation outcomes /// Extract and store memories or preferences from user messages /// Log or audit conversation details /// Perform cleanup or finalization tasks /// /// /// + /// The is passed a reference to the via and + /// allowing it to store state in the . Since an is used with many different sessions, it should + /// not store any session-specific information within its own instance fields. Instead, any session-specific state should be stored in the associated . + /// + /// /// This method is called regardless of whether the invocation succeeded or failed. /// To check if the invocation was successful, inspect the property. /// @@ -168,18 +172,6 @@ public ValueTask InvokedAsync(InvokedContext context, CancellationToken cancella protected virtual ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) => default; - /// - /// Serializes the current object's state to a using the specified serialization options. - /// - /// The JSON serialization options to use for the serialization process. - /// A representation of the object's state, or a default if the provider has no serializable state. - /// - /// The default implementation returns a default . Override this method if the provider - /// maintains state that should be preserved across sessions or distributed scenarios. - /// - public virtual JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => default; - /// Asks the for an object of the specified type . /// The type of object being requested. /// An optional key that can be used to help identify the target service. diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs index 17fbb9e4c6..f8c8aa9b98 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Collections.Concurrent; using System.Diagnostics.CodeAnalysis; using System.Text.Encodings.Web; using System.Text.Json; @@ -80,9 +81,9 @@ private static JsonSerializerOptions CreateDefaultOptions() [JsonSerializable(typeof(AgentResponse[]))] [JsonSerializable(typeof(AgentResponseUpdate))] [JsonSerializable(typeof(AgentResponseUpdate[]))] - [JsonSerializable(typeof(ServiceIdAgentSession.ServiceIdAgentSessionState))] - [JsonSerializable(typeof(InMemoryAgentSession.InMemoryAgentSessionState))] [JsonSerializable(typeof(InMemoryChatHistoryProvider.State))] + [JsonSerializable(typeof(AgentSessionStateBag))] + [JsonSerializable(typeof(ConcurrentDictionary))] [ExcludeFromCodeCoverage] private sealed partial class JsonContext : JsonSerializerContext; diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs index 3efce9be17..144c09a541 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs @@ -1,7 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Diagnostics; using System.Text.Json; +using System.Text.Json.Serialization; using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.AI; @@ -44,6 +46,7 @@ namespace Microsoft.Agents.AI; /// /// /// +[DebuggerDisplay("{DebuggerDisplay,nq}")] public abstract class AgentSession { /// @@ -53,6 +56,20 @@ protected AgentSession() { } + /// + /// Initializes a new instance of the class. + /// + protected AgentSession(AgentSessionStateBag stateBag) + { + this.StateBag = Throw.IfNull(stateBag); + } + + /// + /// Gets any arbitrary state associated with this session. + /// + [JsonPropertyName("stateBag")] + public AgentSessionStateBag StateBag { get; protected set; } = new(); + /// Asks the for an object of the specified type . /// The type of object being requested. /// An optional key that can be used to help identify the target service. @@ -82,4 +99,7 @@ protected AgentSession() /// public TService? GetService(object? serviceKey = null) => this.GetService(typeof(TService), serviceKey) is TService service ? service : default; + + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private string DebuggerDisplay => $"StateBag Count = {this.StateBag.Count}"; } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs new file mode 100644 index 0000000000..d78a866b2c --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs @@ -0,0 +1,145 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Concurrent; +using System.Text.Json; +using System.Text.Json.Serialization; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Agents.AI; + +/// +/// Provides a thread-safe key-value store for managing session-scoped state with support for type-safe access and JSON +/// serialization options. +/// +/// +/// SessionState enables storing and retrieving objects associated with a session using string keys. +/// Values can be accessed in a type-safe manner and are serialized or deserialized using configurable JSON serializer +/// options. This class is designed for concurrent access and is safe to use across multiple threads. +/// +[JsonConverter(typeof(AgentSessionStateBagJsonConverter))] +public class AgentSessionStateBag +{ + private readonly ConcurrentDictionary _state; + + /// + /// Initializes a new instance of the class. + /// + public AgentSessionStateBag() + { + this._state = new ConcurrentDictionary(); + } + + /// + /// Initializes a new instance of the class. + /// + /// The initial state dictionary. + internal AgentSessionStateBag(ConcurrentDictionary? state) + { + this._state = state ?? new ConcurrentDictionary(); + } + + /// + /// Gets the number of key-value pairs contained in the session state. + /// + public int Count => this._state.Count; + + /// + /// Tries to get a value from the session state. + /// + /// The type of the value to retrieve. + /// The key from which to retrieve the value. + /// The value if found and convertible to the required type; otherwise, null. + /// The JSON serializer options to use for serializing/deserializing the value. + /// if the value was successfully retrieved, otherwise. + public bool TryGetValue(string key, out T? value, JsonSerializerOptions? jsonSerializerOptions = null) + where T : class + { + _ = Throw.IfNullOrWhitespace(key); + var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; + + if (this._state.TryGetValue(key, out var stateValue)) + { + return stateValue.TryReadDeserializedValue(out value, jso); + } + + value = null; + return false; + } + + /// + /// Gets a value from the session state. + /// + /// The type of value to get. + /// The key from which to retrieve the value. + /// The JSON serializer options to use for serializing/deserialing the value. + /// The retrieved value or null if not found. + /// The value could not be deserialized into the required type. + public T? GetValue(string key, JsonSerializerOptions? jsonSerializerOptions = null) + where T : class + { + _ = Throw.IfNullOrWhitespace(key); + var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; + + if (this._state.TryGetValue(key, out var stateValue)) + { + return stateValue.ReadDeserializedValue(jso); + } + + return null; + } + + /// + /// Sets a value in the session state. + /// + /// The type of the value to set. + /// The key to store the value under. + /// The value to set. + /// The JSON serializer options to use for serializing the value. + public void SetValue(string key, T? value, JsonSerializerOptions? jsonSerializerOptions = null) + where T : class + { + _ = Throw.IfNullOrWhitespace(key); + var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; + + var stateValue = this._state.GetOrAdd(key, _ => + new AgentSessionStateBagValue(value, typeof(T), jso)); + + stateValue.SetDeserialized(value, typeof(T), jso); + } + + /// + /// Tries to remove a value from the session state. + /// + /// The key of the value to remove. + /// if the value was successfully removed; otherwise, . + public bool TryRemoveValue(string key) + => this._state.TryRemove(Throw.IfNullOrWhitespace(key), out _); + + /// + /// Serializes all session state values to a JSON object. + /// + /// A representing the serialized session state. + /// Thrown when a session state value is not properly initialized. + public JsonElement Serialize() + { + return JsonSerializer.SerializeToElement(this._state, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ConcurrentDictionary))); + } + + /// + /// Deserializes a JSON object into an instance. + /// + /// The element to deserialize. + /// The deserialized . + public static AgentSessionStateBag Deserialize(JsonElement jsonElement) + { + if (jsonElement.ValueKind is JsonValueKind.Undefined or JsonValueKind.Null) + { + return new AgentSessionStateBag(); + } + + return new AgentSessionStateBag( + jsonElement.Deserialize(AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ConcurrentDictionary))) as ConcurrentDictionary + ?? new ConcurrentDictionary()); + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagJsonConverter.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagJsonConverter.cs new file mode 100644 index 0000000000..bfb6904320 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagJsonConverter.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Microsoft.Agents.AI; + +/// +/// Custom JSON converter for that serializes and deserializes +/// the internal dictionary contents rather than the container object's public properties. +/// +public sealed class AgentSessionStateBagJsonConverter : JsonConverter +{ + /// + public override AgentSessionStateBag Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + var element = JsonElement.ParseValue(ref reader); + return AgentSessionStateBag.Deserialize(element); + } + + /// + public override void Write(Utf8JsonWriter writer, AgentSessionStateBag value, JsonSerializerOptions options) + { + var element = value.Serialize(); + element.WriteTo(writer); + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs new file mode 100644 index 0000000000..6f97237516 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs @@ -0,0 +1,170 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Microsoft.Agents.AI; + +/// +/// Used to store a value in session state. +/// +[JsonConverter(typeof(AgentSessionStateBagValueJsonConverter))] +internal class AgentSessionStateBagValue +{ + private readonly object _lock = new(); + private DeserializedCache? _cache; + private JsonElement _jsonValue; + + /// + /// Initializes a new instance of the SessionStateValue class with the specified value. + /// + /// The serialized value to associate with the session state. + public AgentSessionStateBagValue(JsonElement jsonValue) + { + this.JsonValue = jsonValue; + } + + /// + /// Initializes a new instance of the SessionStateValue class with the specified value. + /// + /// The value to associate with the session state. Can be any object, including null. + /// The type of the value. + /// The JSON serializer options to use for serializing the value. + public AgentSessionStateBagValue(object? deserializedValue, Type valueType, JsonSerializerOptions jsonSerializerOptions) + { + this._cache = new DeserializedCache(deserializedValue, valueType, jsonSerializerOptions); + } + + /// + /// Gets or sets the value associated with this instance. + /// + public JsonElement JsonValue + { + get + { + lock (this._lock) + { + // We are assuming here that JsonValue will only be read when the object is being serialized, + // which means that we will only call SerializeToElement when serializing and therefore it's + // OK to serialize on each read if the cache is set. + if (this._cache is { } cache) + { + this._jsonValue = JsonSerializer.SerializeToElement(cache.Value, cache.Options.GetTypeInfo(cache.ValueType)); + } + + return this._jsonValue; + } + } + set + { + lock (this._lock) + { + this._jsonValue = value; + this._cache = null; + } + } + } + + /// + /// Tries to read the deserialized value of this session state value. + /// Returns false if the value could not be deserialized into the required type, or if the value is undefined. + /// Returns true and sets the out parameter to null if the value is null. + /// + public bool TryReadDeserializedValue(out T? value, JsonSerializerOptions? jsonSerializerOptions = null) + where T : class + { + var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; + + lock (this._lock) + { + if (this._cache is { } cache) + { + value = cache.Value as T; + return true; + } + + switch (this._jsonValue) + { + case JsonElement jsonElement when jsonElement.ValueKind == JsonValueKind.Undefined: + value = null; + return false; + case JsonElement jsonElement when jsonElement.ValueKind == JsonValueKind.Null: + value = null; + return true; + default: + T? result = this._jsonValue.Deserialize(jso.GetTypeInfo(typeof(T))) as T; + if (result is null) + { + value = null; + return false; + } + + this._cache = new DeserializedCache(result, typeof(T), jso); + + value = result; + return true; + } + } + } + + /// + /// Reads the deserialized value of this session state value, throwing an exception if the value could not be deserialized into the required type or is undefined. + /// + public T? ReadDeserializedValue(JsonSerializerOptions? jsonSerializerOptions = null) + where T : class + { + var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; + + lock (this._lock) + { + if (this._cache is { } cache) + { + return cache.Value as T; + } + + switch (this._jsonValue) + { + case JsonElement jsonElement when jsonElement.ValueKind == JsonValueKind.Null || jsonElement.ValueKind == JsonValueKind.Undefined: + return null; + default: + T? result = this._jsonValue.Deserialize(jso.GetTypeInfo(typeof(T))) as T; + if (result is null) + { + throw new InvalidOperationException($"Failed to deserialize session state value to type {typeof(T).FullName}."); + } + + this._cache = new DeserializedCache(result, typeof(T), jso); + return result; + } + } + } + + /// + /// Sets the deserialized value of this session state value, updating the cache accordingly. + /// This does not update the JsonValue directly; the JsonValue will be updated on the next read or when the object is serialized. + /// + public void SetDeserialized(object? deserializedValue, Type valueType, JsonSerializerOptions jsonSerializerOptions) + { + lock (this._lock) + { + this._cache = new DeserializedCache(deserializedValue, valueType, jsonSerializerOptions); + } + } + + private readonly struct DeserializedCache + { + public DeserializedCache(object? value, Type valueType, JsonSerializerOptions options) + { + this.Value = value; + this.ValueType = valueType; + this.Options = options; + } + + public object? Value { get; } + + public Type ValueType { get; } + + public JsonSerializerOptions Options { get; } + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValueJsonConverter.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValueJsonConverter.cs new file mode 100644 index 0000000000..27c9dc08a8 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValueJsonConverter.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Microsoft.Agents.AI; + +/// +/// Custom JSON converter for that serializes and deserializes +/// the directly rather than wrapping it in a container object. +/// +internal sealed class AgentSessionStateBagValueJsonConverter : JsonConverter +{ + /// + public override AgentSessionStateBagValue Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + var element = JsonElement.ParseValue(ref reader); + return new AgentSessionStateBagValue(element); + } + + /// + public override void Write(Utf8JsonWriter writer, AgentSessionStateBagValue value, JsonSerializerOptions options) + { + value.JsonValue.WriteTo(writer); + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs index f49c5d46a7..1967752d6e 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -27,10 +26,14 @@ namespace Microsoft.Agents.AI; /// Storing chat messages with proper ordering and metadata preservation /// Retrieving messages in chronological order for agent context /// Managing storage limits through truncation, summarization, or other strategies -/// Supporting serialization for thread persistence and migration /// /// /// +/// The is passed a reference to the via and +/// allowing it to store state in the . Since a is used with many different sessions, it should +/// not store any session-specific information within its own instance fields. Instead, any session-specific state should be stored in the associated . +/// +/// /// A is only relevant for scenarios where the underlying AI service that the agent is using /// does not use in-service chat history storage. /// @@ -80,10 +83,6 @@ protected ChatHistoryProvider(string sourceName) /// Archiving old messages while keeping active conversation context /// /// - /// - /// Each instance should be associated with a single to ensure proper message isolation - /// and context management. - /// /// public async ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) { @@ -198,13 +197,6 @@ public ValueTask InvokedAsync(InvokedContext context, CancellationToken cancella /// protected abstract ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default); - /// - /// Serializes the current object's state to a using the specified serialization options. - /// - /// The JSON serialization options to use. - /// A representation of the object's state. - public abstract JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null); - /// Asks the for an object of the specified type . /// The type of object being requested. /// An optional key that can be used to help identify the target service. diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderMessageFilter.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderMessageFilter.cs index 6cee80986b..384d11b83a 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderMessageFilter.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderMessageFilter.cs @@ -2,7 +2,6 @@ using System; using System.Collections.Generic; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -65,10 +64,4 @@ protected override ValueTask InvokedCoreAsync(InvokedContext context, Cancellati return this._innerProvider.InvokedAsync(context, cancellationToken); } - - /// - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - return this._innerProvider.Serialize(jsonSerializerOptions); - } } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryAgentSession.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryAgentSession.cs deleted file mode 100644 index 05ffafaeb9..0000000000 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryAgentSession.cs +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Text.Json; -using Microsoft.Extensions.AI; - -namespace Microsoft.Agents.AI; - -/// -/// Provides an abstract base class for an that maintain all chat history in local memory. -/// -/// -/// -/// is designed for scenarios where chat history should be stored locally -/// rather than in external services or databases. This approach provides high performance and simplicity while -/// maintaining full control over the conversation data. -/// -/// -/// In-memory threads do not persist conversation data across application restarts -/// unless explicitly serialized and restored. -/// -/// -[DebuggerDisplay("{DebuggerDisplay,nq}")] -public abstract class InMemoryAgentSession : AgentSession -{ - /// - /// Initializes a new instance of the class. - /// - /// - /// An optional instance to use for storing chat messages. - /// If , a new empty will be created. - /// - /// - /// This constructor allows sharing of between sessions or providing pre-configured - /// with specific reduction or processing logic. - /// - protected InMemoryAgentSession(InMemoryChatHistoryProvider? chatHistoryProvider = null) - { - this.ChatHistoryProvider = chatHistoryProvider ?? []; - } - - /// - /// Initializes a new instance of the class. - /// - /// The initial messages to populate the conversation history. - /// is . - /// - /// This constructor is useful for initializing sessions with existing conversation history or - /// for migrating conversations from other storage systems. - /// - protected InMemoryAgentSession(IEnumerable messages) - { - this.ChatHistoryProvider = [.. messages]; - } - - /// - /// Initializes a new instance of the class from previously serialized state. - /// - /// A representing the serialized state of the session. - /// Optional settings for customizing the JSON deserialization process. - /// - /// Optional factory function to create the from its serialized state. - /// If not provided, a default factory will be used that creates a basic . - /// - /// The is not a JSON object. - /// The is invalid or cannot be deserialized to the expected type. - /// - /// This constructor enables restoration of in-memory threads from previously saved state, allowing - /// conversations to be resumed across application restarts or migrated between different instances. - /// - protected InMemoryAgentSession( - JsonElement serializedState, - JsonSerializerOptions? jsonSerializerOptions = null, - Func? chatHistoryProviderFactory = null) - { - if (serializedState.ValueKind != JsonValueKind.Object) - { - throw new ArgumentException("The serialized session state must be a JSON object.", nameof(serializedState)); - } - - var state = serializedState.Deserialize( - AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(InMemoryAgentSessionState))) as InMemoryAgentSessionState; - - this.ChatHistoryProvider = - chatHistoryProviderFactory?.Invoke(state?.ChatHistoryProviderState ?? default, jsonSerializerOptions) ?? - new InMemoryChatHistoryProvider(state?.ChatHistoryProviderState ?? default, jsonSerializerOptions); - } - - /// - /// Gets or sets the used by this thread. - /// - public InMemoryChatHistoryProvider ChatHistoryProvider { get; } - - /// - /// Serializes the current object's state to a using the specified serialization options. - /// - /// The JSON serialization options to use. - /// A representation of the object's state. - protected internal virtual JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - var chatHistoryProviderState = this.ChatHistoryProvider.Serialize(jsonSerializerOptions); - - var state = new InMemoryAgentSessionState - { - ChatHistoryProviderState = chatHistoryProviderState, - }; - - return JsonSerializer.SerializeToElement(state, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(InMemoryAgentSessionState))); - } - - /// - public override object? GetService(Type serviceType, object? serviceKey = null) => - base.GetService(serviceType, serviceKey) ?? this.ChatHistoryProvider?.GetService(serviceType, serviceKey); - - [DebuggerBrowsable(DebuggerBrowsableState.Never)] - private string DebuggerDisplay => $"Count = {this.ChatHistoryProvider.Count}"; - - internal sealed class InMemoryAgentSessionState - { - public JsonElement? ChatHistoryProviderState { get; set; } - } -} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs index 001b3a3bcc..f85b7a4662 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs @@ -1,11 +1,10 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections; using System.Collections.Generic; -using System.Diagnostics; using System.Linq; using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -14,122 +13,93 @@ namespace Microsoft.Agents.AI; /// -/// Provides an in-memory implementation of with support for message reduction and collection semantics. +/// Provides an in-memory implementation of with support for message reduction. /// /// /// -/// stores chat messages entirely in local memory, providing fast access and manipulation -/// capabilities. It implements both for agent integration and -/// for direct collection manipulation. +/// stores chat messages in the , +/// providing fast access and manipulation capabilities integrated with session state management. /// /// /// This maintains all messages in memory. For long-running conversations or high-volume scenarios, consider using /// message reduction strategies or alternative storage implementations. /// /// -[DebuggerDisplay("Count = {Count}")] -[DebuggerTypeProxy(typeof(DebugView))] -public sealed class InMemoryChatHistoryProvider : ChatHistoryProvider, IList, IReadOnlyList +public sealed class InMemoryChatHistoryProvider : ChatHistoryProvider { - private List _messages; + private const string DefaultStateBagKey = "InMemoryChatHistoryProvider.State"; + + private readonly string _stateKey; + private readonly Func _stateInitializer; + private readonly JsonSerializerOptions _jsonSerializerOptions; /// /// Initializes a new instance of the class. /// - /// - /// This constructor creates a basic in-memory without message reduction capabilities. - /// Messages will be stored exactly as added without any automatic processing or reduction. - /// - public InMemoryChatHistoryProvider() + /// + /// Optional configuration options that control the provider's behavior, including state initialization, + /// message reduction, and serialization settings. If , default settings will be used. + /// + public InMemoryChatHistoryProvider(InMemoryChatHistoryProviderOptions? options = null) { - this._messages = []; + this._stateInitializer = options?.StateInitializer ?? (_ => new State()); + this.ChatReducer = options?.ChatReducer; + this.ReducerTriggerEvent = options?.ReducerTriggerEvent ?? InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.BeforeMessagesRetrieval; + this._stateKey = options?.StateKey ?? DefaultStateBagKey; + this._jsonSerializerOptions = options?.JsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; } /// - /// Initializes a new instance of the class from previously serialized state. + /// Gets the chat reducer used to process or reduce chat messages. If null, no reduction logic will be applied. /// - /// A representing the serialized state of the provider. - /// Optional settings for customizing the JSON deserialization process. - /// The is not a valid JSON object or cannot be deserialized. - /// - /// This constructor enables restoration of messages from previously saved state, allowing - /// conversation history to be preserved across application restarts or migrated between instances. - /// - public InMemoryChatHistoryProvider(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null) - : this(null, serializedState, jsonSerializerOptions, ChatReducerTriggerEvent.BeforeMessagesRetrieval) - { - } + public IChatReducer? ChatReducer { get; } /// - /// Initializes a new instance of the class. + /// Gets the event that triggers the reducer invocation in this provider. /// - /// - /// A instance used to process, reduce, or optimize chat messages. - /// This can be used to implement strategies like message summarization, truncation, or cleanup. - /// - /// - /// Specifies when the message reducer should be invoked. The default is , - /// which applies reduction logic when messages are retrieved for agent consumption. - /// - /// is . - /// - /// Message reducers enable automatic management of message storage by implementing strategies to - /// keep memory usage under control while preserving important conversation context. - /// - public InMemoryChatHistoryProvider(IChatReducer chatReducer, ChatReducerTriggerEvent reducerTriggerEvent = ChatReducerTriggerEvent.BeforeMessagesRetrieval) - : this(chatReducer, default, null, reducerTriggerEvent) - { - Throw.IfNull(chatReducer); - } + public InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent ReducerTriggerEvent { get; } /// - /// Initializes a new instance of the class, with an existing state from a serialized JSON element. + /// Gets the chat messages stored for the specified session. /// - /// An optional instance used to process or reduce chat messages. If null, no reduction logic will be applied. - /// A representing the serialized state of the provider. - /// Optional settings for customizing the JSON deserialization process. - /// The event that should trigger the reducer invocation. - public InMemoryChatHistoryProvider(IChatReducer? chatReducer, JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, ChatReducerTriggerEvent reducerTriggerEvent = ChatReducerTriggerEvent.BeforeMessagesRetrieval) - { - this.ChatReducer = chatReducer; - this.ReducerTriggerEvent = reducerTriggerEvent; - - if (serializedState.ValueKind is JsonValueKind.Object) - { - var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; - var state = serializedState.Deserialize( - jso.GetTypeInfo(typeof(State))) as State; - if (state?.Messages is { } messages) - { - this._messages = messages; - return; - } - } - - this._messages = []; - } + /// The agent session containing the state. + /// A list of chat messages, or an empty list if no state is found. + public List GetMessages(AgentSession? session) + => this.GetOrInitializeState(session).Messages; /// - /// Gets the chat reducer used to process or reduce chat messages. If null, no reduction logic will be applied. + /// Sets the chat messages for the specified session. /// - public IChatReducer? ChatReducer { get; } + /// The agent session containing the state. + /// The messages to store. + /// is . + public void SetMessages(AgentSession? session, List messages) + { + _ = Throw.IfNull(messages); + + var state = this.GetOrInitializeState(session); + state.Messages = messages; + } /// - /// Gets the event that triggers the reducer invocation in this provider. + /// Gets the state from the session's StateBag, or initializes it using the state initializer if not present. /// - public ChatReducerTriggerEvent ReducerTriggerEvent { get; } - - /// - public int Count => this._messages.Count; + /// The agent session containing the StateBag. + /// The provider state, or null if no session is available. + private State GetOrInitializeState(AgentSession? session) + { + if (session?.StateBag.TryGetValue(this._stateKey, out var state, this._jsonSerializerOptions) is true && state is not null) + { + return state; + } - /// - public bool IsReadOnly => ((IList)this._messages).IsReadOnly; + state = this._stateInitializer(session); + if (session is not null) + { + session.StateBag.SetValue(this._stateKey, state, this._jsonSerializerOptions); + } - /// - public ChatMessage this[int index] - { - get => this._messages[index]; - set => this._messages[index] = value; + return state; } /// @@ -137,12 +107,14 @@ protected override async ValueTask> InvokingCoreAsync(I { _ = Throw.IfNull(context); - if (this.ReducerTriggerEvent is ChatReducerTriggerEvent.BeforeMessagesRetrieval && this.ChatReducer is not null) + var state = this.GetOrInitializeState(context.Session); + + if (this.ReducerTriggerEvent is InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.BeforeMessagesRetrieval && this.ChatReducer is not null) { - this._messages = (await this.ChatReducer.ReduceAsync(this._messages, cancellationToken).ConfigureAwait(false)).ToList(); + state.Messages = (await this.ChatReducer.ReduceAsync(state.Messages, cancellationToken).ConfigureAwait(false)).ToList(); } - return this._messages; + return state.Messages; } /// @@ -155,94 +127,27 @@ protected override async ValueTask InvokedCoreAsync(InvokedContext context, Canc return; } + var state = this.GetOrInitializeState(context.Session); + // Add request and response messages to the provider var allNewMessages = context.RequestMessages.Concat(context.ResponseMessages ?? []); - this._messages.AddRange(allNewMessages); + state.Messages.AddRange(allNewMessages); - if (this.ReducerTriggerEvent is ChatReducerTriggerEvent.AfterMessageAdded && this.ChatReducer is not null) + if (this.ReducerTriggerEvent is InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.AfterMessageAdded && this.ChatReducer is not null) { - this._messages = (await this.ChatReducer.ReduceAsync(this._messages, cancellationToken).ConfigureAwait(false)).ToList(); + state.Messages = (await this.ChatReducer.ReduceAsync(state.Messages, cancellationToken).ConfigureAwait(false)).ToList(); } } - /// - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - State state = new() - { - Messages = this._messages, - }; - - var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; - return JsonSerializer.SerializeToElement(state, jso.GetTypeInfo(typeof(State))); - } - - /// - public int IndexOf(ChatMessage item) - => this._messages.IndexOf(item); - - /// - public void Insert(int index, ChatMessage item) - => this._messages.Insert(index, item); - - /// - public void RemoveAt(int index) - => this._messages.RemoveAt(index); - - /// - public void Add(ChatMessage item) - => this._messages.Add(item); - - /// - public void Clear() - => this._messages.Clear(); - - /// - public bool Contains(ChatMessage item) - => this._messages.Contains(item); - - /// - public void CopyTo(ChatMessage[] array, int arrayIndex) - => this._messages.CopyTo(array, arrayIndex); - - /// - public bool Remove(ChatMessage item) - => this._messages.Remove(item); - - /// - public IEnumerator GetEnumerator() - => this._messages.GetEnumerator(); - - /// - IEnumerator IEnumerable.GetEnumerator() - => this.GetEnumerator(); - - internal sealed class State - { - public List Messages { get; set; } = []; - } - /// - /// Defines the events that can trigger a reducer in the . + /// Represents the state of a stored in the . /// - public enum ChatReducerTriggerEvent + public sealed class State { /// - /// Trigger the reducer when a new message is added. - /// will only complete when reducer processing is done. - /// - AfterMessageAdded, - - /// - /// Trigger the reducer before messages are retrieved from the provider. - /// The reducer will process the messages before they are returned to the caller. + /// Gets or sets the list of chat messages. /// - BeforeMessagesRetrieval - } - - private sealed class DebugView(InMemoryChatHistoryProvider provider) - { - [DebuggerBrowsable(DebuggerBrowsableState.RootHidden)] - public ChatMessage[] Items => provider._messages.ToArray(); + [JsonPropertyName("messages")] + public List Messages { get; set; } = []; } } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProviderOptions.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProviderOptions.cs new file mode 100644 index 0000000000..848951329c --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProviderOptions.cs @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json; +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI; + +/// +/// Represents configuration options for . +/// +public sealed class InMemoryChatHistoryProviderOptions +{ + /// + /// Gets or sets an optional delegate that initializes the provider state on the first invocation. + /// If , a default initializer that creates an empty state will be used. + /// + public Func? StateInitializer { get; set; } + + /// + /// Gets or sets an optional instance used to process, reduce, or optimize chat messages. + /// This can be used to implement strategies like message summarization, truncation, or cleanup. + /// + public IChatReducer? ChatReducer { get; set; } + + /// + /// Gets or sets when the message reducer should be invoked. + /// The default is , + /// which applies reduction logic when messages are retrieved for agent consumption. + /// + /// + /// Message reducers enable automatic management of message storage by implementing strategies to + /// keep memory usage under control while preserving important conversation context. + /// + public ChatReducerTriggerEvent ReducerTriggerEvent { get; set; } = ChatReducerTriggerEvent.BeforeMessagesRetrieval; + + /// + /// Gets or sets an optional key to use for storing the state in the . + /// If , a default key will be used. + /// + public string? StateKey { get; set; } + + /// + /// Gets or sets optional JSON serializer options for serializing the state of this provider. + /// This is valuable for cases like when the chat history contains custom types + /// and source generated serializers are required, or Native AOT / Trimming is required. + /// + public JsonSerializerOptions? JsonSerializerOptions { get; set; } + + /// + /// Defines the events that can trigger a reducer in the . + /// + public enum ChatReducerTriggerEvent + { + /// + /// Trigger the reducer when a new message is added. + /// will only complete when reducer processing is done. + /// + AfterMessageAdded, + + /// + /// Trigger the reducer before messages are retrieved from the provider. + /// The reducer will process the messages before they are returned to the caller. + /// + BeforeMessagesRetrieval + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ServiceIdAgentSession.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ServiceIdAgentSession.cs deleted file mode 100644 index cf00635984..0000000000 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ServiceIdAgentSession.cs +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Diagnostics; -using System.Text.Json; -using Microsoft.Shared.Diagnostics; - -namespace Microsoft.Agents.AI; - -/// -/// Provides a base class for agent sessions that store conversation state remotely in a service and maintain only an identifier reference locally. -/// -/// -/// This class is designed for scenarios where conversation state is managed by an external service (such as a cloud-based AI service) -/// rather than being stored locally. The session maintains only the service identifier needed to reference the remote conversation state. -/// -[DebuggerDisplay("ServiceSessionId = {ServiceSessionId}")] -public abstract class ServiceIdAgentSession : AgentSession -{ - /// - /// Initializes a new instance of the class without a service session identifier. - /// - /// - /// When using this constructor, the will be initially - /// and should be set by derived classes when the remote conversation is created. - /// - protected ServiceIdAgentSession() - { - } - - /// - /// Initializes a new instance of the class with the specified service session identifier. - /// - /// The unique identifier that references the conversation state stored in the remote service. - /// is . - /// is empty or contains only whitespace. - protected ServiceIdAgentSession(string serviceSessionId) - { - this.ServiceSessionId = Throw.IfNullOrEmpty(serviceSessionId); - } - - /// - /// Initializes a new instance of the class from previously serialized state. - /// - /// A representing the serialized state of the session. - /// Optional settings for customizing the JSON deserialization process. - /// The is not a JSON object. - /// The is invalid or cannot be deserialized to the expected type. - /// - /// This constructor enables restoration of a service-backed session from serialized state, typically used - /// when deserializing session information that was previously saved or transmitted across application boundaries. - /// - protected ServiceIdAgentSession( - JsonElement serializedState, - JsonSerializerOptions? jsonSerializerOptions = null) - { - if (serializedState.ValueKind != JsonValueKind.Object) - { - throw new ArgumentException("The serialized session state must be a JSON object.", nameof(serializedState)); - } - - var state = serializedState.Deserialize( - AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ServiceIdAgentSessionState))) as ServiceIdAgentSessionState; - - if (state?.ServiceSessionId is string serviceSessionId) - { - this.ServiceSessionId = serviceSessionId; - } - } - - /// - /// Gets or sets the unique identifier that references the conversation state stored in the remote service. - /// - /// - /// A string identifier that uniquely identifies the conversation within the remote service, - /// or if no remote conversation has been established yet. - /// - /// - /// This identifier is used by derived classes to reference the remote conversation state when making - /// API calls to the backing service. The exact format and meaning of this identifier depends on the - /// specific service implementation. - /// - protected string? ServiceSessionId { get; set; } - - /// - /// Serializes the current object's state to a using the specified serialization options. - /// - /// The JSON serialization options to use. - /// A representation of the object's state. - protected internal virtual JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - var state = new ServiceIdAgentSessionState - { - ServiceSessionId = this.ServiceSessionId, - }; - - return JsonSerializer.SerializeToElement(state, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ServiceIdAgentSessionState))); - } - - internal sealed class ServiceIdAgentSessionState - { - public string? ServiceSessionId { get; set; } - } -} diff --git a/dotnet/src/Microsoft.Agents.AI.AzureAI.Persistent/PersistentAgentsClientExtensions.cs b/dotnet/src/Microsoft.Agents.AI.AzureAI.Persistent/PersistentAgentsClientExtensions.cs index 2058e6760b..c9c898a18a 100644 --- a/dotnet/src/Microsoft.Agents.AI.AzureAI.Persistent/PersistentAgentsClientExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.AzureAI.Persistent/PersistentAgentsClientExtensions.cs @@ -191,8 +191,8 @@ public static ChatClientAgent AsAIAgent( Name = options.Name ?? persistentAgentMetadata.Name, Description = options.Description ?? persistentAgentMetadata.Description, ChatOptions = options.ChatOptions, - AIContextProviderFactory = options.AIContextProviderFactory, - ChatHistoryProviderFactory = options.ChatHistoryProviderFactory, + AIContextProvider = options.AIContextProvider, + ChatHistoryProvider = options.ChatHistoryProvider, UseProvidedChatClientAsIs = options.UseProvidedChatClientAsIs }; diff --git a/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClientExtensions.cs b/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClientExtensions.cs index 18a5ba3b72..82ca49705e 100644 --- a/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClientExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClientExtensions.cs @@ -589,8 +589,8 @@ private static ChatClientAgentOptions CreateChatClientAgentOptions(AgentVersion var agentOptions = CreateChatClientAgentOptions(agentVersion, options?.ChatOptions, requireInvocableTools); if (options is not null) { - agentOptions.AIContextProviderFactory = options.AIContextProviderFactory; - agentOptions.ChatHistoryProviderFactory = options.ChatHistoryProviderFactory; + agentOptions.AIContextProvider = options.AIContextProvider; + agentOptions.ChatHistoryProvider = options.ChatHistoryProvider; agentOptions.UseProvidedChatClientAsIs = options.UseProvidedChatClientAsIs; } diff --git a/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs index 48642139a9..1bcb84ffbd 100644 --- a/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs @@ -68,7 +68,7 @@ protected override JsonElement SerializeSessionCore(AgentSession session, JsonSe /// protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) - => new(new CopilotStudioAgentSession(serializedState, jsonSerializerOptions)); + => new(CopilotStudioAgentSession.Deserialize(serializedState, jsonSerializerOptions)); /// protected override async Task RunCoreAsync( diff --git a/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgentSession.cs b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgentSession.cs index ec3e21ca91..ba101df082 100644 --- a/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgentSession.cs @@ -1,36 +1,58 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.Diagnostics; using System.Text.Json; +using System.Text.Json.Serialization; namespace Microsoft.Agents.AI.CopilotStudio; /// /// Session for CopilotStudio based agents. /// -public sealed class CopilotStudioAgentSession : ServiceIdAgentSession +[DebuggerDisplay("{DebuggerDisplay,nq}")] +public sealed class CopilotStudioAgentSession : AgentSession { internal CopilotStudioAgentSession() { } - internal CopilotStudioAgentSession(JsonElement serializedSessionState, JsonSerializerOptions? jsonSerializerOptions = null) : base(serializedSessionState, jsonSerializerOptions) + [JsonConstructor] + internal CopilotStudioAgentSession(string? conversationId, AgentSessionStateBag? stateBag) : base(stateBag ?? new()) { + this.ConversationId = conversationId; } /// /// Gets the ID for the current conversation with the Copilot Studio agent. /// - public string? ConversationId - { - get { return this.ServiceSessionId; } - internal set { this.ServiceSessionId = value; } - } + [JsonPropertyName("serviceSessionId")] + public string? ConversationId { get; internal set; } /// /// Serializes the current object's state to a using the specified serialization options. /// /// The JSON serialization options to use. /// A representation of the object's state. - internal new JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => base.Serialize(jsonSerializerOptions); + internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) + { + var jso = jsonSerializerOptions ?? CopilotStudioJsonUtilities.DefaultOptions; + return JsonSerializer.SerializeToElement(this, jso.GetTypeInfo(typeof(CopilotStudioAgentSession))); + } + + internal static CopilotStudioAgentSession Deserialize(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null) + { + if (serializedState.ValueKind != JsonValueKind.Object) + { + throw new ArgumentException("The serialized session state must be a JSON object.", nameof(serializedState)); + } + + var jso = jsonSerializerOptions ?? CopilotStudioJsonUtilities.DefaultOptions; + return serializedState.Deserialize(jso.GetTypeInfo(typeof(CopilotStudioAgentSession))) as CopilotStudioAgentSession + ?? new CopilotStudioAgentSession(); + } + + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private string DebuggerDisplay => + $"ConversationId = {this.ConversationId}, StateBag Count = {this.StateBag.Count}"; } diff --git a/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioJsonUtilities.cs new file mode 100644 index 0000000000..44177b0708 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioJsonUtilities.cs @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; +using System.Text.Encodings.Web; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Microsoft.Agents.AI.CopilotStudio; + +/// +/// Provides utility methods and configurations for JSON serialization operations within the Copilot Studio agent implementation. +/// +internal static partial class CopilotStudioJsonUtilities +{ + /// + /// Gets the default instance used for JSON serialization operations. + /// + public static JsonSerializerOptions DefaultOptions { get; } = CreateDefaultOptions(); + + /// + /// Creates and configures the default JSON serialization options. + /// + /// The configured options. + private static JsonSerializerOptions CreateDefaultOptions() + { + // Copy the configuration from the source generated context. + JsonSerializerOptions options = new(JsonContext.Default.Options) + { + Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping, + }; + + // Chain in the resolvers from both AgentAbstractionsJsonUtilities and our source generated context. + options.TypeInfoResolverChain.Clear(); + options.TypeInfoResolverChain.Add(AgentAbstractionsJsonUtilities.DefaultOptions.TypeInfoResolver!); + options.TypeInfoResolverChain.Add(JsonContext.Default.Options.TypeInfoResolver!); + + options.MakeReadOnly(); + return options; + } + + [JsonSourceGenerationOptions(JsonSerializerDefaults.Web, + UseStringEnumConverter = true, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + NumberHandling = JsonNumberHandling.AllowReadingFromString)] + [JsonSerializable(typeof(CopilotStudioAgentSession))] + [ExcludeFromCodeCoverage] + private sealed partial class JsonContext : JsonSerializerContext; +} diff --git a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs index 85c5865f07..1646f9216b 100644 --- a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs @@ -21,17 +21,15 @@ namespace Microsoft.Agents.AI; [RequiresDynamicCode("The CosmosChatHistoryProvider uses JSON serialization which is incompatible with NativeAOT.")] public sealed class CosmosChatHistoryProvider : ChatHistoryProvider, IDisposable { + private const string DefaultStateBagKey = "CosmosChatHistoryProvider.State"; + private readonly CosmosClient _cosmosClient; private readonly Container _container; private readonly bool _ownsClient; + private readonly string _stateKey; + private readonly Func _stateInitializer; private bool _disposed; - // Hierarchical partition key support - private readonly string? _tenantId; - private readonly string? _userId; - private readonly PartitionKey _partitionKey; - private readonly bool _useHierarchicalPartitioning; - /// /// Cached JSON serializer options for .NET 9.0 compatibility. /// @@ -72,11 +70,6 @@ private static JsonSerializerOptions CreateDefaultJsonOptions() /// public int? MessageTtlSeconds { get; set; } = 86400; - /// - /// Gets the conversation ID associated with this provider. - /// - public string ConversationId { get; init; } - /// /// Gets the database ID associated with this provider. /// @@ -88,36 +81,31 @@ private static JsonSerializerOptions CreateDefaultJsonOptions() public string ContainerId { get; init; } /// - /// Internal primary constructor used by all public constructors. + /// Initializes a new instance of the class. /// /// The instance to use for Cosmos DB operations. /// The identifier of the Cosmos DB database. /// The identifier of the Cosmos DB container. - /// The unique identifier for this conversation thread. + /// A delegate that initializes the provider state on the first invocation, providing the conversation routing info (conversationId, tenantId, userId). /// Whether this instance owns the CosmosClient and should dispose it. - /// Optional tenant identifier for hierarchical partitioning. - /// Optional user identifier for hierarchical partitioning. - internal CosmosChatHistoryProvider(CosmosClient cosmosClient, string databaseId, string containerId, string conversationId, bool ownsClient, string? tenantId = null, string? userId = null) + /// An optional key to use for storing the state in the . + /// Thrown when or is . + /// Thrown when any string parameter is null or whitespace. + public CosmosChatHistoryProvider( + CosmosClient cosmosClient, + string databaseId, + string containerId, + Func stateInitializer, + bool ownsClient = false, + string? stateKey = null) { this._cosmosClient = Throw.IfNull(cosmosClient); - this._container = this._cosmosClient.GetContainer(Throw.IfNullOrWhitespace(databaseId), Throw.IfNullOrWhitespace(containerId)); - this.ConversationId = Throw.IfNullOrWhitespace(conversationId); - this.DatabaseId = databaseId; - this.ContainerId = containerId; + this.DatabaseId = Throw.IfNullOrWhitespace(databaseId); + this.ContainerId = Throw.IfNullOrWhitespace(containerId); + this._container = this._cosmosClient.GetContainer(databaseId, containerId); + this._stateInitializer = Throw.IfNull(stateInitializer); this._ownsClient = ownsClient; - - // Initialize partitioning mode - this._tenantId = tenantId; - this._userId = userId; - this._useHierarchicalPartitioning = tenantId != null && userId != null; - - this._partitionKey = this._useHierarchicalPartitioning - ? new PartitionKeyBuilder() - .Add(tenantId!) - .Add(userId!) - .Add(conversationId) - .Build() - : new PartitionKey(conversationId); + this._stateKey = stateKey ?? DefaultStateBagKey; } /// @@ -126,24 +114,17 @@ internal CosmosChatHistoryProvider(CosmosClient cosmosClient, string databaseId, /// The Cosmos DB connection string. /// The identifier of the Cosmos DB database. /// The identifier of the Cosmos DB container. + /// A delegate that initializes the provider state on the first invocation. + /// An optional key to use for storing the state in the . /// Thrown when any required parameter is null. /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(string connectionString, string databaseId, string containerId) - : this(connectionString, databaseId, containerId, Guid.NewGuid().ToString("N")) - { - } - - /// - /// Initializes a new instance of the class using a connection string. - /// - /// The Cosmos DB connection string. - /// The identifier of the Cosmos DB database. - /// The identifier of the Cosmos DB container. - /// The unique identifier for this conversation thread. - /// Thrown when any required parameter is null. - /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(string connectionString, string databaseId, string containerId, string conversationId) - : this(new CosmosClient(Throw.IfNullOrWhitespace(connectionString)), databaseId, containerId, conversationId, ownsClient: true) + public CosmosChatHistoryProvider( + string connectionString, + string databaseId, + string containerId, + Func stateInitializer, + string? stateKey = null) + : this(new CosmosClient(Throw.IfNullOrWhitespace(connectionString)), databaseId, containerId, stateInitializer, ownsClient: true, stateKey) { } @@ -154,136 +135,63 @@ public CosmosChatHistoryProvider(string connectionString, string databaseId, str /// The TokenCredential to use for authentication (e.g., DefaultAzureCredential, ManagedIdentityCredential). /// The identifier of the Cosmos DB database. /// The identifier of the Cosmos DB container. + /// A delegate that initializes the provider state on the first invocation. + /// An optional key to use for storing the state in the . /// Thrown when any required parameter is null. /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(string accountEndpoint, TokenCredential tokenCredential, string databaseId, string containerId) - : this(accountEndpoint, tokenCredential, databaseId, containerId, Guid.NewGuid().ToString("N")) + public CosmosChatHistoryProvider( + string accountEndpoint, + TokenCredential tokenCredential, + string databaseId, + string containerId, + Func stateInitializer, + string? stateKey = null) + : this(new CosmosClient(Throw.IfNullOrWhitespace(accountEndpoint), Throw.IfNull(tokenCredential)), databaseId, containerId, stateInitializer, ownsClient: true, stateKey) { } /// - /// Initializes a new instance of the class using a TokenCredential for authentication. + /// Gets the state from the session's StateBag, or initializes it using the state initializer if not present. /// - /// The Cosmos DB account endpoint URI. - /// The TokenCredential to use for authentication (e.g., DefaultAzureCredential, ManagedIdentityCredential). - /// The identifier of the Cosmos DB database. - /// The identifier of the Cosmos DB container. - /// The unique identifier for this conversation thread. - /// Thrown when any required parameter is null. - /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(string accountEndpoint, TokenCredential tokenCredential, string databaseId, string containerId, string conversationId) - : this(new CosmosClient(Throw.IfNullOrWhitespace(accountEndpoint), Throw.IfNull(tokenCredential)), databaseId, containerId, conversationId, ownsClient: true) + /// The agent session containing the StateBag. + /// The provider state, or null if no session is available. + private State GetOrInitializeState(AgentSession? session) { - } + if (session?.StateBag.TryGetValue(this._stateKey, out var state, AgentAbstractionsJsonUtilities.DefaultOptions) is true && state is not null) + { + return state; + } - /// - /// Initializes a new instance of the class using an existing . - /// - /// The instance to use for Cosmos DB operations. - /// The identifier of the Cosmos DB database. - /// The identifier of the Cosmos DB container. - /// Thrown when is null. - /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(CosmosClient cosmosClient, string databaseId, string containerId) - : this(cosmosClient, databaseId, containerId, Guid.NewGuid().ToString("N")) - { - } + state = this._stateInitializer(session); + if (session is not null) + { + session.StateBag.SetValue(this._stateKey, state, AgentAbstractionsJsonUtilities.DefaultOptions); + } - /// - /// Initializes a new instance of the class using an existing . - /// - /// The instance to use for Cosmos DB operations. - /// The identifier of the Cosmos DB database. - /// The identifier of the Cosmos DB container. - /// The unique identifier for this conversation thread. - /// Thrown when is null. - /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(CosmosClient cosmosClient, string databaseId, string containerId, string conversationId) - : this(cosmosClient, databaseId, containerId, conversationId, ownsClient: false) - { + return state; } /// - /// Initializes a new instance of the class using a connection string with hierarchical partition keys. + /// Determines whether hierarchical partitioning should be used based on the state. /// - /// The Cosmos DB connection string. - /// The identifier of the Cosmos DB database. - /// The identifier of the Cosmos DB container. - /// The tenant identifier for hierarchical partitioning. - /// The user identifier for hierarchical partitioning. - /// The session identifier for hierarchical partitioning. - /// Thrown when any required parameter is null. - /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(string connectionString, string databaseId, string containerId, string tenantId, string userId, string sessionId) - : this(new CosmosClient(Throw.IfNullOrWhitespace(connectionString)), databaseId, containerId, Throw.IfNullOrWhitespace(sessionId), ownsClient: true, Throw.IfNullOrWhitespace(tenantId), Throw.IfNullOrWhitespace(userId)) - { - } + private static bool UseHierarchicalPartitioning(State state) => + state.TenantId is not null && state.UserId is not null; /// - /// Initializes a new instance of the class using a TokenCredential for authentication with hierarchical partition keys. + /// Builds the partition key from the state. /// - /// The Cosmos DB account endpoint URI. - /// The TokenCredential to use for authentication (e.g., DefaultAzureCredential, ManagedIdentityCredential). - /// The identifier of the Cosmos DB database. - /// The identifier of the Cosmos DB container. - /// The tenant identifier for hierarchical partitioning. - /// The user identifier for hierarchical partitioning. - /// The session identifier for hierarchical partitioning. - /// Thrown when any required parameter is null. - /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(string accountEndpoint, TokenCredential tokenCredential, string databaseId, string containerId, string tenantId, string userId, string sessionId) - : this(new CosmosClient(Throw.IfNullOrWhitespace(accountEndpoint), Throw.IfNull(tokenCredential)), databaseId, containerId, Throw.IfNullOrWhitespace(sessionId), ownsClient: true, Throw.IfNullOrWhitespace(tenantId), Throw.IfNullOrWhitespace(userId)) + private static PartitionKey BuildPartitionKey(State state) { - } - - /// - /// Initializes a new instance of the class using an existing with hierarchical partition keys. - /// - /// The instance to use for Cosmos DB operations. - /// The identifier of the Cosmos DB database. - /// The identifier of the Cosmos DB container. - /// The tenant identifier for hierarchical partitioning. - /// The user identifier for hierarchical partitioning. - /// The session identifier for hierarchical partitioning. - /// Thrown when is null. - /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(CosmosClient cosmosClient, string databaseId, string containerId, string tenantId, string userId, string sessionId) - : this(cosmosClient, databaseId, containerId, Throw.IfNullOrWhitespace(sessionId), ownsClient: false, Throw.IfNullOrWhitespace(tenantId), Throw.IfNullOrWhitespace(userId)) - { - } - - /// - /// Creates a new instance of the class from previously serialized state. - /// - /// The instance to use for Cosmos DB operations. - /// A representing the serialized state of the provider. - /// The identifier of the Cosmos DB database. - /// The identifier of the Cosmos DB container. - /// Optional settings for customizing the JSON deserialization process. - /// A new instance of initialized from the serialized state. - /// Thrown when is null. - /// Thrown when the serialized state cannot be deserialized. - public static CosmosChatHistoryProvider CreateFromSerializedState(CosmosClient cosmosClient, JsonElement serializedState, string databaseId, string containerId, JsonSerializerOptions? jsonSerializerOptions = null) - { - Throw.IfNull(cosmosClient); - Throw.IfNullOrWhitespace(databaseId); - Throw.IfNullOrWhitespace(containerId); - - if (serializedState.ValueKind is not JsonValueKind.Object) + if (UseHierarchicalPartitioning(state)) { - throw new ArgumentException("Invalid serialized state", nameof(serializedState)); + return new PartitionKeyBuilder() + .Add(state.TenantId) + .Add(state.UserId) + .Add(state.ConversationId) + .Build(); } - var state = serializedState.Deserialize(jsonSerializerOptions); - if (state?.ConversationIdentifier is not { } conversationId) - { - throw new ArgumentException("Invalid serialized state", nameof(serializedState)); - } - - // Use the internal constructor with all parameters to ensure partition key logic is centralized - return state.UseHierarchicalPartitioning && state.TenantId != null && state.UserId != null - ? new CosmosChatHistoryProvider(cosmosClient, databaseId, containerId, conversationId, ownsClient: false, state.TenantId, state.UserId) - : new CosmosChatHistoryProvider(cosmosClient, databaseId, containerId, conversationId, ownsClient: false); + return new PartitionKey(state.ConversationId); } /// @@ -296,15 +204,20 @@ protected override async ValueTask> InvokingCoreAsync(I } #pragma warning restore CA1513 + _ = Throw.IfNull(context); + + var state = this.GetOrInitializeState(context.Session); + var partitionKey = BuildPartitionKey(state); + // Fetch most recent messages in descending order when limit is set, then reverse to ascending var orderDirection = this.MaxMessagesToRetrieve.HasValue ? "DESC" : "ASC"; var query = new QueryDefinition($"SELECT * FROM c WHERE c.conversationId = @conversationId AND c.type = @type ORDER BY c.timestamp {orderDirection}") - .WithParameter("@conversationId", this.ConversationId) + .WithParameter("@conversationId", state.ConversationId) .WithParameter("@type", "ChatMessage"); var iterator = this._container.GetItemQueryIterator(query, requestOptions: new QueryRequestOptions { - PartitionKey = this._partitionKey, + PartitionKey = partitionKey, MaxItemCount = this.MaxItemCount // Configurable query performance }); @@ -364,27 +277,30 @@ protected override async ValueTask InvokedCoreAsync(InvokedContext context, Canc } #pragma warning restore CA1513 + var state = this.GetOrInitializeState(context.Session); var messageList = context.RequestMessages.Concat(context.ResponseMessages ?? []).ToList(); if (messageList.Count == 0) { return; } + var partitionKey = BuildPartitionKey(state); + // Use transactional batch for atomic operations if (messageList.Count > 1) { - await this.AddMessagesInBatchAsync(messageList, cancellationToken).ConfigureAwait(false); + await this.AddMessagesInBatchAsync(partitionKey, state, messageList, cancellationToken).ConfigureAwait(false); } else { - await this.AddSingleMessageAsync(messageList.First(), cancellationToken).ConfigureAwait(false); + await this.AddSingleMessageAsync(partitionKey, state, messageList.First(), cancellationToken).ConfigureAwait(false); } } /// /// Adds multiple messages using transactional batch operations for atomicity. /// - private async Task AddMessagesInBatchAsync(List messages, CancellationToken cancellationToken) + private async Task AddMessagesInBatchAsync(PartitionKey partitionKey, State state, List messages, CancellationToken cancellationToken) { var currentTimestamp = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); @@ -392,7 +308,7 @@ private async Task AddMessagesInBatchAsync(List messages, Cancellat for (int i = 0; i < messages.Count; i += this.MaxBatchSize) { var batchMessages = messages.Skip(i).Take(this.MaxBatchSize).ToList(); - await this.ExecuteBatchOperationAsync(batchMessages, currentTimestamp, cancellationToken).ConfigureAwait(false); + await this.ExecuteBatchOperationAsync(partitionKey, state, batchMessages, currentTimestamp, cancellationToken).ConfigureAwait(false); } } @@ -400,13 +316,13 @@ private async Task AddMessagesInBatchAsync(List messages, Cancellat /// Executes a single batch operation with enhanced error handling. /// Cosmos SDK handles throttling (429) retries automatically. /// - private async Task ExecuteBatchOperationAsync(List messages, long timestamp, CancellationToken cancellationToken) + private async Task ExecuteBatchOperationAsync(PartitionKey partitionKey, State state, List messages, long timestamp, CancellationToken cancellationToken) { // Create all documents upfront for validation and batch operation var documents = new List(messages.Count); foreach (var message in messages) { - documents.Add(this.CreateMessageDocument(message, timestamp)); + documents.Add(this.CreateMessageDocument(state, message, timestamp)); } // Defensive check: Verify all messages share the same partition key values @@ -414,7 +330,7 @@ private async Task ExecuteBatchOperationAsync(List messages, long t // In simple partitioning, this means same conversationId if (documents.Count > 0) { - if (this._useHierarchicalPartitioning) + if (UseHierarchicalPartitioning(state)) { // Verify all documents have matching hierarchical partition key components var firstDoc = documents[0]; @@ -436,7 +352,7 @@ private async Task ExecuteBatchOperationAsync(List messages, long t // All messages in this store share the same partition key by design // Transactional batches require all items to share the same partition key - var batch = this._container.CreateTransactionalBatch(this._partitionKey); + var batch = this._container.CreateTransactionalBatch(partitionKey); foreach (var document in documents) { @@ -457,7 +373,7 @@ private async Task ExecuteBatchOperationAsync(List messages, long t if (messages.Count == 1) { // Can't split further, use single operation - await this.AddSingleMessageAsync(messages[0], cancellationToken).ConfigureAwait(false); + await this.AddSingleMessageAsync(partitionKey, state, messages[0], cancellationToken).ConfigureAwait(false); return; } @@ -466,21 +382,21 @@ private async Task ExecuteBatchOperationAsync(List messages, long t var firstHalf = messages.Take(midpoint).ToList(); var secondHalf = messages.Skip(midpoint).ToList(); - await this.ExecuteBatchOperationAsync(firstHalf, timestamp, cancellationToken).ConfigureAwait(false); - await this.ExecuteBatchOperationAsync(secondHalf, timestamp, cancellationToken).ConfigureAwait(false); + await this.ExecuteBatchOperationAsync(partitionKey, state, firstHalf, timestamp, cancellationToken).ConfigureAwait(false); + await this.ExecuteBatchOperationAsync(partitionKey, state, secondHalf, timestamp, cancellationToken).ConfigureAwait(false); } } /// /// Adds a single message to the store. /// - private async Task AddSingleMessageAsync(ChatMessage message, CancellationToken cancellationToken) + private async Task AddSingleMessageAsync(PartitionKey partitionKey, State state, ChatMessage message, CancellationToken cancellationToken) { - var document = this.CreateMessageDocument(message, DateTimeOffset.UtcNow.ToUnixTimeSeconds()); + var document = this.CreateMessageDocument(state, message, DateTimeOffset.UtcNow.ToUnixTimeSeconds()); try { - await this._container.CreateItemAsync(document, this._partitionKey, cancellationToken: cancellationToken).ConfigureAwait(false); + await this._container.CreateItemAsync(document, partitionKey, cancellationToken: cancellationToken).ConfigureAwait(false); } catch (CosmosException ex) when (ex.StatusCode == System.Net.HttpStatusCode.RequestEntityTooLarge) { @@ -495,12 +411,14 @@ private async Task AddSingleMessageAsync(ChatMessage message, CancellationToken /// /// Creates a message document with enhanced metadata. /// - private CosmosMessageDocument CreateMessageDocument(ChatMessage message, long timestamp) + private CosmosMessageDocument CreateMessageDocument(State state, ChatMessage message, long timestamp) { + var useHierarchical = UseHierarchicalPartitioning(state); + return new CosmosMessageDocument { Id = Guid.NewGuid().ToString(), - ConversationId = this.ConversationId, + ConversationId = state.ConversationId, Timestamp = timestamp, MessageId = message.MessageId, Role = message.Role.Value, @@ -508,41 +426,20 @@ private CosmosMessageDocument CreateMessageDocument(ChatMessage message, long ti Type = "ChatMessage", // Type discriminator Ttl = this.MessageTtlSeconds, // Configurable TTL // Include hierarchical metadata when using hierarchical partitioning - TenantId = this._useHierarchicalPartitioning ? this._tenantId : null, - UserId = this._useHierarchicalPartitioning ? this._userId : null, - SessionId = this._useHierarchicalPartitioning ? this.ConversationId : null + TenantId = useHierarchical ? state.TenantId : null, + UserId = useHierarchical ? state.UserId : null, + SessionId = useHierarchical ? state.ConversationId : null }; } - /// - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { -#pragma warning disable CA1513 // Use ObjectDisposedException.ThrowIf - not available on all target frameworks - if (this._disposed) - { - throw new ObjectDisposedException(this.GetType().FullName); - } -#pragma warning restore CA1513 - - var state = new State - { - ConversationIdentifier = this.ConversationId, - TenantId = this._tenantId, - UserId = this._userId, - UseHierarchicalPartitioning = this._useHierarchicalPartitioning - }; - - var options = jsonSerializerOptions ?? s_defaultJsonOptions; - return JsonSerializer.SerializeToElement(state, options); - } - /// /// Gets the count of messages in this conversation. /// This is an additional utility method beyond the base contract. /// + /// The agent session to get state from. /// The cancellation token. /// The number of messages in the conversation. - public async Task GetMessageCountAsync(CancellationToken cancellationToken = default) + public async Task GetMessageCountAsync(AgentSession? session, CancellationToken cancellationToken = default) { #pragma warning disable CA1513 // Use ObjectDisposedException.ThrowIf - not available on all target frameworks if (this._disposed) @@ -551,14 +448,17 @@ public async Task GetMessageCountAsync(CancellationToken cancellationToken } #pragma warning restore CA1513 + var state = this.GetOrInitializeState(session); + var partitionKey = BuildPartitionKey(state); + // Efficient count query var query = new QueryDefinition("SELECT VALUE COUNT(1) FROM c WHERE c.conversationId = @conversationId AND c.Type = @type") - .WithParameter("@conversationId", this.ConversationId) + .WithParameter("@conversationId", state.ConversationId) .WithParameter("@type", "ChatMessage"); var iterator = this._container.GetItemQueryIterator(query, requestOptions: new QueryRequestOptions { - PartitionKey = this._partitionKey + PartitionKey = partitionKey }); // COUNT queries always return a result @@ -570,9 +470,10 @@ public async Task GetMessageCountAsync(CancellationToken cancellationToken /// Deletes all messages in this conversation. /// This is an additional utility method beyond the base contract. /// + /// The agent session to get state from. /// The cancellation token. /// The number of messages deleted. - public async Task ClearMessagesAsync(CancellationToken cancellationToken = default) + public async Task ClearMessagesAsync(AgentSession? session, CancellationToken cancellationToken = default) { #pragma warning disable CA1513 // Use ObjectDisposedException.ThrowIf - not available on all target frameworks if (this._disposed) @@ -581,14 +482,17 @@ public async Task ClearMessagesAsync(CancellationToken cancellationToken = } #pragma warning restore CA1513 + var state = this.GetOrInitializeState(session); + var partitionKey = BuildPartitionKey(state); + // Batch delete for efficiency var query = new QueryDefinition("SELECT VALUE c.id FROM c WHERE c.conversationId = @conversationId AND c.Type = @type") - .WithParameter("@conversationId", this.ConversationId) + .WithParameter("@conversationId", state.ConversationId) .WithParameter("@type", "ChatMessage"); var iterator = this._container.GetItemQueryIterator(query, requestOptions: new QueryRequestOptions { - PartitionKey = this._partitionKey, + PartitionKey = partitionKey, MaxItemCount = this.MaxItemCount }); @@ -597,7 +501,7 @@ public async Task ClearMessagesAsync(CancellationToken cancellationToken = while (iterator.HasMoreResults) { var response = await iterator.ReadNextAsync(cancellationToken).ConfigureAwait(false); - var batch = this._container.CreateTransactionalBatch(this._partitionKey); + var batch = this._container.CreateTransactionalBatch(partitionKey); var batchItemCount = 0; foreach (var itemId in response) @@ -632,12 +536,38 @@ public void Dispose() } } - private sealed class State + /// + /// Represents the per-session state of a stored in the . + /// + public sealed class State { - public string ConversationIdentifier { get; set; } = string.Empty; - public string? TenantId { get; set; } - public string? UserId { get; set; } - public bool UseHierarchicalPartitioning { get; set; } + /// + /// Initializes a new instance of the class. + /// + /// The unique identifier for this conversation thread. + /// Optional tenant identifier for hierarchical partitioning. + /// Optional user identifier for hierarchical partitioning. + public State(string conversationId, string? tenantId = null, string? userId = null) + { + this.ConversationId = Throw.IfNullOrWhitespace(conversationId); + this.TenantId = tenantId; + this.UserId = userId; + } + + /// + /// Gets the conversation ID associated with this state. + /// + public string ConversationId { get; } + + /// + /// Gets the tenant identifier for hierarchical partitioning, if any. + /// + public string? TenantId { get; } + + /// + /// Gets the user identifier for hierarchical partitioning, if any. + /// + public string? UserId { get; } } /// diff --git a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosDBChatExtensions.cs b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosDBChatExtensions.cs index 3d93e9dd6a..76b865e4c8 100644 --- a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosDBChatExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosDBChatExtensions.cs @@ -2,7 +2,6 @@ using System; using System.Diagnostics.CodeAnalysis; -using System.Threading.Tasks; using Azure.Core; using Microsoft.Azure.Cosmos; @@ -13,6 +12,9 @@ namespace Microsoft.Agents.AI; /// public static class CosmosDBChatExtensions { + private static readonly Func s_defaultStateInitializer = + _ => new CosmosChatHistoryProvider.State(Guid.NewGuid().ToString("N")); + /// /// Configures the agent to use Cosmos DB for message storage with connection string authentication. /// @@ -20,6 +22,7 @@ public static class CosmosDBChatExtensions /// The Cosmos DB connection string. /// The identifier of the Cosmos DB database. /// The identifier of the Cosmos DB container. + /// An optional delegate that initializes the provider state on the first invocation, providing the conversation routing info (conversationId, tenantId, userId). When not provided, a new conversation ID is generated automatically. /// The configured . /// Thrown when is null. /// Thrown when any string parameter is null or whitespace. @@ -29,14 +32,16 @@ public static ChatClientAgentOptions WithCosmosDBChatHistoryProvider( this ChatClientAgentOptions options, string connectionString, string databaseId, - string containerId) + string containerId, + Func? stateInitializer = null) { if (options is null) { throw new ArgumentNullException(nameof(options)); } - options.ChatHistoryProviderFactory = (context, ct) => new ValueTask(new CosmosChatHistoryProvider(connectionString, databaseId, containerId)); + options.ChatHistoryProvider = + new CosmosChatHistoryProvider(connectionString, databaseId, containerId, stateInitializer ?? s_defaultStateInitializer); return options; } @@ -48,6 +53,7 @@ public static ChatClientAgentOptions WithCosmosDBChatHistoryProvider( /// The identifier of the Cosmos DB database. /// The identifier of the Cosmos DB container. /// The TokenCredential to use for authentication (e.g., DefaultAzureCredential, ManagedIdentityCredential). + /// An optional delegate that initializes the provider state on the first invocation, providing the conversation routing info (conversationId, tenantId, userId). When not provided, a new conversation ID is generated automatically. /// The configured . /// Thrown when or is null. /// Thrown when any string parameter is null or whitespace. @@ -58,7 +64,8 @@ public static ChatClientAgentOptions WithCosmosDBChatHistoryProviderUsingManaged string accountEndpoint, string databaseId, string containerId, - TokenCredential tokenCredential) + TokenCredential tokenCredential, + Func? stateInitializer = null) { if (options is null) { @@ -70,7 +77,8 @@ public static ChatClientAgentOptions WithCosmosDBChatHistoryProviderUsingManaged throw new ArgumentNullException(nameof(tokenCredential)); } - options.ChatHistoryProviderFactory = (context, ct) => new ValueTask(new CosmosChatHistoryProvider(accountEndpoint, tokenCredential, databaseId, containerId)); + options.ChatHistoryProvider = + new CosmosChatHistoryProvider(accountEndpoint, tokenCredential, databaseId, containerId, stateInitializer ?? s_defaultStateInitializer); return options; } @@ -81,6 +89,7 @@ public static ChatClientAgentOptions WithCosmosDBChatHistoryProviderUsingManaged /// The instance to use for Cosmos DB operations. /// The identifier of the Cosmos DB database. /// The identifier of the Cosmos DB container. + /// An optional delegate that initializes the provider state on the first invocation, providing the conversation routing info (conversationId, tenantId, userId). When not provided, a new conversation ID is generated automatically. /// The configured . /// Thrown when any required parameter is null. /// Thrown when any string parameter is null or whitespace. @@ -90,14 +99,16 @@ public static ChatClientAgentOptions WithCosmosDBChatHistoryProvider( this ChatClientAgentOptions options, CosmosClient cosmosClient, string databaseId, - string containerId) + string containerId, + Func? stateInitializer = null) { if (options is null) { throw new ArgumentNullException(nameof(options)); } - options.ChatHistoryProviderFactory = (context, ct) => new ValueTask(new CosmosChatHistoryProvider(cosmosClient, databaseId, containerId)); + options.ChatHistoryProvider = + new CosmosChatHistoryProvider(cosmosClient, databaseId, containerId, stateInitializer ?? s_defaultStateInitializer); return options; } } diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentSession.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentSession.cs index b9d9807728..ba33c15d32 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentSession.cs @@ -7,17 +7,22 @@ namespace Microsoft.Agents.AI.DurableTask; /// -/// An agent thread implementation for durable agents. +/// An implementation for durable agents. /// -[DebuggerDisplay("{SessionId}")] +[DebuggerDisplay("{DebuggerDisplay,nq}")] public sealed class DurableAgentSession : AgentSession { - [JsonConstructor] internal DurableAgentSession(AgentSessionId sessionId) { this.SessionId = sessionId; } + [JsonConstructor] + internal DurableAgentSession(AgentSessionId sessionId, AgentSessionStateBag stateBag) : base(stateBag) + { + this.SessionId = sessionId; + } + /// /// Gets the agent session ID. /// @@ -28,9 +33,8 @@ internal DurableAgentSession(AgentSessionId sessionId) /// internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { - return JsonSerializer.SerializeToElement( - this, - DurableAgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(DurableAgentSession))); + var jso = jsonSerializerOptions ?? DurableAgentJsonUtilities.DefaultOptions; + return JsonSerializer.SerializeToElement(this, jso.GetTypeInfo(typeof(DurableAgentSession))); } /// @@ -49,7 +53,11 @@ internal static DurableAgentSession Deserialize(JsonElement serializedSession, J string sessionIdString = sessionIdElement.GetString() ?? throw new JsonException("sessionId property is null."); AgentSessionId sessionId = AgentSessionId.Parse(sessionIdString); - return new DurableAgentSession(sessionId); + AgentSessionStateBag stateBag = serializedSession.TryGetProperty("stateBag", out JsonElement stateBagElement) + ? AgentSessionStateBag.Deserialize(stateBagElement) + : new AgentSessionStateBag(); + + return new DurableAgentSession(sessionId, stateBag); } /// @@ -68,4 +76,8 @@ public override string ToString() { return this.SessionId.ToString(); } + + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private string DebuggerDisplay => + $"SessionId = {this.SessionId}, StateBag Count = {this.StateBag.Count}"; } diff --git a/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgent.cs b/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgent.cs index 06c7f24ee2..0556430636 100644 --- a/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgent.cs @@ -115,7 +115,7 @@ protected override ValueTask DeserializeSessionCoreAsync( JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) - => new(new GitHubCopilotAgentSession(serializedState, jsonSerializerOptions)); + => new(GitHubCopilotAgentSession.Deserialize(serializedState, jsonSerializerOptions)); /// protected override Task RunCoreAsync( diff --git a/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgentSession.cs b/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgentSession.cs index f514eeb71b..70fe43425e 100644 --- a/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgentSession.cs @@ -1,17 +1,22 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.Diagnostics; using System.Text.Json; +using System.Text.Json.Serialization; namespace Microsoft.Agents.AI.GitHub.Copilot; /// /// Represents a session for a GitHub Copilot agent conversation. /// +[DebuggerDisplay("{DebuggerDisplay,nq}")] public sealed class GitHubCopilotAgentSession : AgentSession { /// /// Gets or sets the session ID for the GitHub Copilot conversation. /// + [JsonPropertyName("sessionId")] public string? SessionId { get; internal set; } /// @@ -21,35 +26,32 @@ internal GitHubCopilotAgentSession() { } - /// - /// Initializes a new instance of the class from serialized data. - /// - /// The serialized thread data. - /// Optional JSON serialization options. - internal GitHubCopilotAgentSession(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + [JsonConstructor] + internal GitHubCopilotAgentSession(string? sessionId, AgentSessionStateBag? stateBag) : base(stateBag ?? new()) { - // The JSON serialization uses camelCase - if (serializedThread.TryGetProperty("sessionId", out JsonElement sessionIdElement)) - { - this.SessionId = sessionIdElement.GetString(); - } + this.SessionId = sessionId; } /// internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { - State state = new() - { - SessionId = this.SessionId - }; - - return JsonSerializer.SerializeToElement( - state, - GitHubCopilotJsonUtilities.DefaultOptions.GetTypeInfo(typeof(State))); + var jso = jsonSerializerOptions ?? GitHubCopilotJsonUtilities.DefaultOptions; + return JsonSerializer.SerializeToElement(this, jso.GetTypeInfo(typeof(GitHubCopilotAgentSession))); } - internal sealed class State + internal static GitHubCopilotAgentSession Deserialize(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null) { - public string? SessionId { get; set; } + if (serializedState.ValueKind != JsonValueKind.Object) + { + throw new ArgumentException("The serialized session state must be a JSON object.", nameof(serializedState)); + } + + var jso = jsonSerializerOptions ?? GitHubCopilotJsonUtilities.DefaultOptions; + return serializedState.Deserialize(jso.GetTypeInfo(typeof(GitHubCopilotAgentSession))) as GitHubCopilotAgentSession + ?? new GitHubCopilotAgentSession(); } + + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private string DebuggerDisplay => + $"SessionId = {this.SessionId}, StateBag Count = {this.StateBag.Count}"; } diff --git a/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotJsonUtilities.cs index c5254efd6b..9e97c0585b 100644 --- a/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotJsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotJsonUtilities.cs @@ -42,7 +42,7 @@ private static JsonSerializerOptions CreateDefaultOptions() UseStringEnumConverter = true, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, NumberHandling = JsonNumberHandling.AllowReadingFromString)] - [JsonSerializable(typeof(GitHubCopilotAgentSession.State))] + [JsonSerializable(typeof(GitHubCopilotAgentSession))] [ExcludeFromCodeCoverage] private sealed partial class JsonContext : JsonSerializerContext; } diff --git a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0JsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0JsonUtilities.cs index d139cb0f76..33f92f3ac2 100644 --- a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0JsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0JsonUtilities.cs @@ -65,7 +65,7 @@ private static JsonSerializerOptions CreateDefaultOptions() NumberHandling = JsonNumberHandling.AllowReadingFromString)] // Agent abstraction types - [JsonSerializable(typeof(Mem0Provider.Mem0State))] + [JsonSerializable(typeof(Mem0Provider.State))] [ExcludeFromCodeCoverage] internal sealed partial class JsonContext : JsonSerializerContext; diff --git a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs index 8a0c016f07..7e645971d8 100644 --- a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs @@ -4,7 +4,6 @@ using System.Collections.Generic; using System.Linq; using System.Net.Http; -using System.Text.Json; using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; @@ -26,24 +25,24 @@ namespace Microsoft.Agents.AI.Mem0; public sealed class Mem0Provider : AIContextProvider { private const string DefaultContextPrompt = "## Memories\nConsider the following memories when answering user questions:"; + private const string DefaultStateBagKey = "Mem0Provider.State"; private readonly string _contextPrompt; private readonly bool _enableSensitiveTelemetryData; + private readonly string _stateKey; + private readonly Func _stateInitializer; private readonly Mem0Client _client; private readonly ILogger? _logger; - private readonly Mem0ProviderScope _storageScope; - private readonly Mem0ProviderScope _searchScope; - /// /// Initializes a new instance of the class. /// /// Configured (base address + auth). - /// Optional values to scope the memory storage with. - /// Optional values to scope the memory search with. Defaults to if not provided. + /// A delegate that initializes the provider state on the first invocation, providing the storage and search scopes. /// Provider options. /// Optional logger factory. + /// Thrown when or is . /// /// The base address of the required mem0 service, and any authentication headers, should be set on the /// already, when passed as a parameter here. E.g.: @@ -51,83 +50,55 @@ public sealed class Mem0Provider : AIContextProvider /// using var httpClient = new HttpClient(); /// httpClient.BaseAddress = new Uri("https://api.mem0.ai"); /// httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Token", "<Your APIKey>"); - /// new Mem0AIContextProvider(httpClient); + /// new Mem0Provider(httpClient); /// /// - public Mem0Provider(HttpClient httpClient, Mem0ProviderScope storageScope, Mem0ProviderScope? searchScope = null, Mem0ProviderOptions? options = null, ILoggerFactory? loggerFactory = null) + public Mem0Provider(HttpClient httpClient, Func stateInitializer, Mem0ProviderOptions? options = null, ILoggerFactory? loggerFactory = null) { + Throw.IfNull(httpClient); if (string.IsNullOrWhiteSpace(httpClient.BaseAddress?.AbsoluteUri)) { throw new ArgumentException("The HttpClient BaseAddress must be set for Mem0 operations.", nameof(httpClient)); } + this._stateInitializer = Throw.IfNull(stateInitializer); this._logger = loggerFactory?.CreateLogger(); this._client = new Mem0Client(httpClient); this._contextPrompt = options?.ContextPrompt ?? DefaultContextPrompt; this._enableSensitiveTelemetryData = options?.EnableSensitiveTelemetryData ?? false; - this._storageScope = new Mem0ProviderScope(Throw.IfNull(storageScope)); - this._searchScope = searchScope ?? storageScope; - - if (string.IsNullOrWhiteSpace(this._storageScope.ApplicationId) - && string.IsNullOrWhiteSpace(this._storageScope.AgentId) - && string.IsNullOrWhiteSpace(this._storageScope.ThreadId) - && string.IsNullOrWhiteSpace(this._storageScope.UserId)) - { - throw new ArgumentException("At least one of ApplicationId, AgentId, ThreadId, or UserId must be provided for the storage scope."); - } - - if (string.IsNullOrWhiteSpace(this._searchScope.ApplicationId) - && string.IsNullOrWhiteSpace(this._searchScope.AgentId) - && string.IsNullOrWhiteSpace(this._searchScope.ThreadId) - && string.IsNullOrWhiteSpace(this._searchScope.UserId)) - { - throw new ArgumentException("At least one of ApplicationId, AgentId, ThreadId, or UserId must be provided for the search scope."); - } + this._stateKey = options?.StateKey ?? DefaultStateBagKey; } /// - /// Initializes a new instance of the class, with existing state from a serialized JSON element. + /// Gets the state from the session's StateBag, or initializes it using the StateInitializer if not present. /// - /// Configured (base address + auth). - /// A representing the serialized state of the store. - /// Optional settings for customizing the JSON deserialization process. - /// Provider options. - /// Optional logger factory. - /// - /// - /// The base address of the required mem0 service, and any authentication headers, should be set on the - /// already, when passed as a parameter here. E.g.: - /// - /// using var httpClient = new HttpClient(); - /// httpClient.BaseAddress = new Uri("https://api.mem0.ai"); - /// httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Token", "<Your APIKey>"); - /// new Mem0AIContextProvider(httpClient, state); - /// - /// - public Mem0Provider(HttpClient httpClient, JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, Mem0ProviderOptions? options = null, ILoggerFactory? loggerFactory = null) + /// The agent session containing the StateBag. + /// The provider state, or null if no session is available. + private State? GetOrInitializeState(AgentSession? session) { - if (string.IsNullOrWhiteSpace(httpClient.BaseAddress?.AbsoluteUri)) + if (session?.StateBag.TryGetValue(this._stateKey, out var state, Mem0JsonUtilities.DefaultOptions) is true && state is not null) { - throw new ArgumentException("The HttpClient BaseAddress must be set for Mem0 operations.", nameof(httpClient)); + return state; } - this._logger = loggerFactory?.CreateLogger(); - this._client = new Mem0Client(httpClient); - - this._contextPrompt = options?.ContextPrompt ?? DefaultContextPrompt; - this._enableSensitiveTelemetryData = options?.EnableSensitiveTelemetryData ?? false; + state = this._stateInitializer(session); - var jso = jsonSerializerOptions ?? Mem0JsonUtilities.DefaultOptions; - var state = serializedState.Deserialize(jso.GetTypeInfo(typeof(Mem0State))) as Mem0State; + if (state is null + || state.StorageScope is null + || (state.StorageScope.AgentId is null && state.StorageScope.ThreadId is null && state.StorageScope.UserId is null && state.StorageScope.ApplicationId is null) + || state.SearchScope is null + || (state.SearchScope.AgentId is null && state.SearchScope.ThreadId is null && state.SearchScope.UserId is null && state.SearchScope.ApplicationId is null)) + { + throw new InvalidOperationException("State initializer must return a non-null state with valid storage and search scopes, where at lest one scoping parameter is set for each."); + } - if (state == null || state.StorageScope == null || state.SearchScope == null) + if (session is not null) { - throw new InvalidOperationException("The Mem0Provider state did not contain the required scope properties."); + session.StateBag.SetValue(this._stateKey, state, Mem0JsonUtilities.DefaultOptions); } - this._storageScope = state.StorageScope; - this._searchScope = state.SearchScope; + return state; } /// @@ -135,6 +106,9 @@ protected override async ValueTask InvokingCoreAsync(InvokingContext { Throw.IfNull(context); + var state = this.GetOrInitializeState(context.Session); + var searchScope = state?.SearchScope ?? new Mem0ProviderScope(); + string queryText = string.Join( Environment.NewLine, context.RequestMessages @@ -145,10 +119,10 @@ protected override async ValueTask InvokingCoreAsync(InvokingContext try { var memories = (await this._client.SearchAsync( - this._searchScope.ApplicationId, - this._searchScope.AgentId, - this._searchScope.ThreadId, - this._searchScope.UserId, + searchScope.ApplicationId, + searchScope.AgentId, + searchScope.ThreadId, + searchScope.UserId, queryText, cancellationToken).ConfigureAwait(false)).ToList(); @@ -161,10 +135,10 @@ protected override async ValueTask InvokingCoreAsync(InvokingContext this._logger.LogInformation( "Mem0AIContextProvider: Retrieved {Count} memories. ApplicationId: '{ApplicationId}', AgentId: '{AgentId}', ThreadId: '{ThreadId}', UserId: '{UserId}'.", memories.Count, - this._searchScope.ApplicationId, - this._searchScope.AgentId, - this._searchScope.ThreadId, - this.SanitizeLogData(this._searchScope.UserId)); + searchScope.ApplicationId, + searchScope.AgentId, + searchScope.ThreadId, + this.SanitizeLogData(searchScope.UserId)); if (outputMessageText is not null && this._logger.IsEnabled(LogLevel.Trace)) { @@ -172,10 +146,10 @@ protected override async ValueTask InvokingCoreAsync(InvokingContext "Mem0AIContextProvider: Search Results\nInput:{Input}\nOutput:{MessageText}\nApplicationId: '{ApplicationId}', AgentId: '{AgentId}', ThreadId: '{ThreadId}', UserId: '{UserId}'.", this.SanitizeLogData(queryText), this.SanitizeLogData(outputMessageText), - this._searchScope.ApplicationId, - this._searchScope.AgentId, - this._searchScope.ThreadId, - this.SanitizeLogData(this._searchScope.UserId)); + searchScope.ApplicationId, + searchScope.AgentId, + searchScope.ThreadId, + this.SanitizeLogData(searchScope.UserId)); } } @@ -195,10 +169,10 @@ protected override async ValueTask InvokingCoreAsync(InvokingContext this._logger.LogError( ex, "Mem0AIContextProvider: Failed to search Mem0 for memories due to error. ApplicationId: '{ApplicationId}', AgentId: '{AgentId}', ThreadId: '{ThreadId}', UserId: '{UserId}'.", - this._searchScope.ApplicationId, - this._searchScope.AgentId, - this._searchScope.ThreadId, - this.SanitizeLogData(this._searchScope.UserId)); + searchScope.ApplicationId, + searchScope.AgentId, + searchScope.ThreadId, + this.SanitizeLogData(searchScope.UserId)); } return new AIContext(); } @@ -212,10 +186,14 @@ protected override async ValueTask InvokedCoreAsync(InvokedContext context, Canc return; // Do not update memory on failed invocations. } + var state = this.GetOrInitializeState(context.Session); + var storageScope = state?.StorageScope ?? new Mem0ProviderScope(); + try { // Persist request and response messages after invocation. await this.PersistMessagesAsync( + storageScope, context.RequestMessages .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSourceType.External) .Concat(context.ResponseMessages ?? []), @@ -228,36 +206,39 @@ await this.PersistMessagesAsync( this._logger.LogError( ex, "Mem0AIContextProvider: Failed to send messages to Mem0 due to error. ApplicationId: '{ApplicationId}', AgentId: '{AgentId}', ThreadId: '{ThreadId}', UserId: '{UserId}'.", - this._storageScope.ApplicationId, - this._storageScope.AgentId, - this._storageScope.ThreadId, - this.SanitizeLogData(this._storageScope.UserId)); + storageScope.ApplicationId, + storageScope.AgentId, + storageScope.ThreadId, + this.SanitizeLogData(storageScope.UserId)); } } } /// - /// Clears stored memories for the configured scopes. + /// Clears stored memories for the specified scope. /// + /// The session containing the scope state to clear memories for. /// Cancellation token. - public Task ClearStoredMemoriesAsync(CancellationToken cancellationToken = default) => - this._client.ClearMemoryAsync( - this._storageScope.ApplicationId, - this._storageScope.AgentId, - this._storageScope.ThreadId, - this._storageScope.UserId, - cancellationToken); - - /// - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) + public Task ClearStoredMemoriesAsync(AgentSession session, CancellationToken cancellationToken = default) { - var state = new Mem0State(this._storageScope, this._searchScope); + Throw.IfNull(session); + var state = this.GetOrInitializeState(session); + var storageScope = state?.StorageScope; + + if (storageScope is null) + { + return Task.CompletedTask; // Nothing to clear if there is no state. + } - var jso = jsonSerializerOptions ?? Mem0JsonUtilities.DefaultOptions; - return JsonSerializer.SerializeToElement(state, jso.GetTypeInfo(typeof(Mem0State))); + return this._client.ClearMemoryAsync( + storageScope.ApplicationId, + storageScope.AgentId, + storageScope.ThreadId, + storageScope.UserId, + cancellationToken); } - private async Task PersistMessagesAsync(IEnumerable messages, CancellationToken cancellationToken) + private async Task PersistMessagesAsync(Mem0ProviderScope storageScope, IEnumerable messages, CancellationToken cancellationToken) { foreach (var message in messages) { @@ -277,27 +258,42 @@ private async Task PersistMessagesAsync(IEnumerable messages, Cance } await this._client.CreateMemoryAsync( - this._storageScope.ApplicationId, - this._storageScope.AgentId, - this._storageScope.ThreadId, - this._storageScope.UserId, + storageScope.ApplicationId, + storageScope.AgentId, + storageScope.ThreadId, + storageScope.UserId, message.Text, message.Role.Value, cancellationToken).ConfigureAwait(false); } } - internal sealed class Mem0State + /// + /// Represents the state of a stored in the . + /// + public sealed class State { + /// + /// Initializes a new instance of the class with the specified storage and search scopes. + /// + /// The scope to use when storing memories. + /// The scope to use when searching for memories. If null, the storage scope will be used for searching as well. [JsonConstructor] - public Mem0State(Mem0ProviderScope storageScope, Mem0ProviderScope searchScope) + public State(Mem0ProviderScope storageScope, Mem0ProviderScope? searchScope = null) { - this.StorageScope = storageScope; - this.SearchScope = searchScope; + this.StorageScope = Throw.IfNull(storageScope); + this.SearchScope = searchScope ?? storageScope; } - public Mem0ProviderScope StorageScope { get; set; } - public Mem0ProviderScope SearchScope { get; set; } + /// + /// Gets the scope used when storing memories. + /// + public Mem0ProviderScope StorageScope { get; } + + /// + /// Gets the scope used when searching memories. + /// + public Mem0ProviderScope SearchScope { get; } } private string? SanitizeLogData(string? data) => this._enableSensitiveTelemetryData ? data : ""; diff --git a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0ProviderOptions.cs b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0ProviderOptions.cs index f2d3d89e16..f382e80bf2 100644 --- a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0ProviderOptions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0ProviderOptions.cs @@ -18,4 +18,10 @@ public sealed class Mem0ProviderOptions /// /// Defaults to . public bool EnableSensitiveTelemetryData { get; set; } + + /// + /// Gets or sets the key used to store the provider state in the session's . + /// + /// Defaults to "Mem0Provider.State". + public string? StateKey { get; set; } } diff --git a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIAssistantClientExtensions.cs b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIAssistantClientExtensions.cs index d167d1f0b4..3f25f71d83 100644 --- a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIAssistantClientExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIAssistantClientExtensions.cs @@ -204,8 +204,8 @@ public static ChatClientAgent AsAIAgent( Name = options.Name ?? assistantMetadata.Name, Description = options.Description ?? assistantMetadata.Description, ChatOptions = options.ChatOptions, - AIContextProviderFactory = options.AIContextProviderFactory, - ChatHistoryProviderFactory = options.ChatHistoryProviderFactory, + AIContextProvider = options.AIContextProvider, + ChatHistoryProvider = options.ChatHistoryProvider, UseProvidedChatClientAsIs = options.UseProvidedChatClientAsIs }; diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs index f631de8e8a..349c680085 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs @@ -6,48 +6,54 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; -using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.AI.Workflows; internal sealed class WorkflowChatHistoryProvider : ChatHistoryProvider { - private int _bookmark; - private readonly List _chatMessages = []; - - public WorkflowChatHistoryProvider() + private const string DefaultStateBagKey = "WorkflowChatHistoryProvider.State"; + private readonly JsonSerializerOptions _jsonSerializerOptions; + + /// + /// Initializes a new instance of the class. + /// + /// + /// Optional JSON serializer options for serializing the state of this provider. + /// This is valuable for cases like when the chat history contains custom types + /// and source generated serializers are required, or Native AOT / Trimming is required. + /// + public WorkflowChatHistoryProvider(JsonSerializerOptions? jsonSerializerOptions = null) { + this._jsonSerializerOptions = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; } - public WorkflowChatHistoryProvider(StoreState state) + internal sealed class StoreState { - this.ImportStoreState(Throw.IfNull(state)); + public int Bookmark { get; set; } + public List Messages { get; set; } = []; } - private void ImportStoreState(StoreState state, bool clearMessages = false) + private StoreState GetOrInitializeState(AgentSession? session) { - if (clearMessages) + if (session?.StateBag.TryGetValue(DefaultStateBagKey, out var state, this._jsonSerializerOptions) is true && state is not null) { - this._chatMessages.Clear(); + return state; } - if (state?.Messages is not null) + state = new(); + if (session is not null) { - this._chatMessages.AddRange(state.Messages); + session.StateBag.SetValue(DefaultStateBagKey, state, this._jsonSerializerOptions); } - this._bookmark = state?.Bookmark ?? 0; - } - internal sealed class StoreState - { - public int Bookmark { get; set; } - public IList Messages { get; set; } = []; + return state; } - internal void AddMessages(params IEnumerable messages) => this._chatMessages.AddRange(messages); + internal void AddMessages(AgentSession session, params IEnumerable messages) + => this.GetOrInitializeState(session).Messages.AddRange(messages); protected override ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) - => new(this._chatMessages.AsReadOnly()); + => new(this.GetOrInitializeState(context.Session).Messages.AsReadOnly()); protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { @@ -57,28 +63,24 @@ protected override ValueTask InvokedCoreAsync(InvokedContext context, Cancellati } var allNewMessages = context.RequestMessages.Concat(context.ResponseMessages ?? []); - this._chatMessages.AddRange(allNewMessages); + this.GetOrInitializeState(context.Session).Messages.AddRange(allNewMessages); return default; } - public IEnumerable GetFromBookmark() + public IEnumerable GetFromBookmark(AgentSession session) { - for (int i = this._bookmark; i < this._chatMessages.Count; i++) + var state = this.GetOrInitializeState(session); + + for (int i = state.Bookmark; i < state.Messages.Count; i++) { - yield return this._chatMessages[i]; + yield return state.Messages[i]; } } - public void UpdateBookmark() => this._bookmark = this._chatMessages.Count; - - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) + public void UpdateBookmark(AgentSession session) { - StoreState state = this.ExportStoreState(); - - return JsonSerializer.SerializeToElement(state, - WorkflowsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(StoreState))); + var state = this.GetOrInitializeState(session); + state.Bookmark = state.Messages.Count; } - - internal StoreState ExportStoreState() => new() { Bookmark = this._bookmark, Messages = this._chatMessages }; } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs index 66f71a219a..95335b37c0 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs @@ -94,7 +94,7 @@ private async ValueTask UpdateSessionAsync(IEnumerable InvokeStageAsync( try { this.LastResponseId = Guid.NewGuid().ToString("N"); - List messages = this.ChatHistoryProvider.GetFromBookmark().ToList(); + List messages = this.ChatHistoryProvider.GetFromBookmark(this).ToList(); #pragma warning disable CA2007 // Analyzer misfiring and not seeing .ConfigureAwait(false) below. await using Checkpointed checkpointed = @@ -240,7 +241,7 @@ IAsyncEnumerable InvokeStageAsync( finally { // Do we want to try to undo the step, and not update the bookmark? - this.ChatHistoryProvider.UpdateBookmark(); + this.ChatHistoryProvider.UpdateBookmark(this); } } @@ -254,12 +255,12 @@ IAsyncEnumerable InvokeStageAsync( internal sealed class SessionState( string runId, CheckpointInfo? lastCheckpoint, - WorkflowChatHistoryProvider.StoreState chatHistoryProviderState, - InMemoryCheckpointManager? checkpointManager = null) + InMemoryCheckpointManager? checkpointManager = null, + AgentSessionStateBag? stateBag = null) { public string RunId { get; } = runId; public CheckpointInfo? LastCheckpoint { get; } = lastCheckpoint; - public WorkflowChatHistoryProvider.StoreState ChatHistoryProviderState { get; } = chatHistoryProviderState; public InMemoryCheckpointManager? CheckpointManager { get; } = checkpointManager; + public AgentSessionStateBag StateBag { get; } = stateBag ?? new(); } } diff --git a/dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs index 5c915b6b01..96ec6dbecb 100644 --- a/dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs @@ -65,9 +65,9 @@ private static JsonSerializerOptions CreateDefaultOptions() NumberHandling = JsonNumberHandling.AllowReadingFromString)] // Agent abstraction types - [JsonSerializable(typeof(ChatClientAgentSession.SessionState))] + [JsonSerializable(typeof(ChatClientAgentSession))] [JsonSerializable(typeof(TextSearchProvider.TextSearchProviderState))] - [JsonSerializable(typeof(ChatHistoryMemoryProvider.ChatHistoryMemoryProviderState))] + [JsonSerializable(typeof(ChatHistoryMemoryProvider.State))] [ExcludeFromCodeCoverage] internal sealed partial class JsonContext : JsonSerializerContext; diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index 5878d877b2..daaf584e83 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -105,6 +105,11 @@ public ChatClientAgent(IChatClient chatClient, ChatClientAgentOptions? options, // If the user has not opted out of using our default decorators, we wrap the chat client. this.ChatClient = options?.UseProvidedChatClientAsIs is true ? chatClient : chatClient.WithDefaultAgentMiddleware(options, services); + // Use the ChatHistoryProvider from options if provided. + // If one was not provided, and we later find out that the underlying service does not manage chat history server-side, + // we will use the default InMemoryChatHistoryProvider at that time. + this.ChatHistoryProvider = options?.ChatHistoryProvider; + this._logger = (loggerFactory ?? chatClient.GetService() ?? NullLoggerFactory.Instance).CreateLogger(); } @@ -120,6 +125,14 @@ public ChatClientAgent(IChatClient chatClient, ChatClientAgentOptions? options, /// public IChatClient ChatClient { get; } + /// + /// Gets the used by this agent, to support cases where the chat history is not stored by the agent service. + /// + /// + /// This property may be null in case the agent stores messages in the underlying agent service. + /// + public ChatHistoryProvider? ChatHistoryProvider { get; private set; } + /// protected override string? IdCore => this._agentOptions?.Id; @@ -282,7 +295,7 @@ protected override async IAsyncEnumerable RunCoreStreamingA // We can derive the type of supported session from whether we have a conversation id, // so let's update it and set the conversation id for the service session case. - await this.UpdateSessionWithTypeAndConversationIdAsync(safeSession, chatResponse.ConversationId, cancellationToken).ConfigureAwait(false); + this.UpdateSessionConversationId(safeSession, chatResponse.ConversationId, cancellationToken); // To avoid inconsistent state we only notify the session of the input messages if no error occurs after the initial request. await this.NotifyChatHistoryProviderOfNewMessagesAsync(safeSession, GetInputMessages(inputMessagesForProviders, continuationToken), chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false); @@ -298,24 +311,14 @@ protected override async IAsyncEnumerable RunCoreStreamingA : serviceType == typeof(IChatClient) ? this.ChatClient : serviceType == typeof(ChatOptions) ? this._agentOptions?.ChatOptions : serviceType == typeof(ChatClientAgentOptions) ? this._agentOptions - : this.ChatClient.GetService(serviceType, serviceKey)); + : this._agentOptions?.AIContextProvider?.GetService(serviceType, serviceKey) + ?? this.ChatHistoryProvider?.GetService(serviceType, serviceKey) + ?? this.ChatClient.GetService(serviceType, serviceKey)); /// - protected override async ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) + protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) { - ChatHistoryProvider? chatHistoryProvider = this._agentOptions?.ChatHistoryProviderFactory is not null - ? await this._agentOptions.ChatHistoryProviderFactory.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }, cancellationToken).ConfigureAwait(false) - : null; - - AIContextProvider? contextProvider = this._agentOptions?.AIContextProviderFactory is not null - ? await this._agentOptions.AIContextProviderFactory.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }, cancellationToken).ConfigureAwait(false) - : null; - - return new ChatClientAgentSession - { - ChatHistoryProvider = chatHistoryProvider, - AIContextProvider = contextProvider - }; + return new(new ChatClientAgentSession()); } /// @@ -336,52 +339,12 @@ protected override async ValueTask CreateSessionCoreAsync(Cancella /// instances that support server-side conversation storage through their underlying . /// /// - public async ValueTask CreateSessionAsync(string conversationId, CancellationToken cancellationToken = default) + public ValueTask CreateSessionAsync(string conversationId, CancellationToken cancellationToken = default) { - AIContextProvider? contextProvider = this._agentOptions?.AIContextProviderFactory is not null - ? await this._agentOptions.AIContextProviderFactory.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }, cancellationToken).ConfigureAwait(false) - : null; - - return new ChatClientAgentSession() + return new(new ChatClientAgentSession() { ConversationId = conversationId, - AIContextProvider = contextProvider - }; - } - - /// - /// Creates a new agent session instance using an existing to continue a conversation. - /// - /// The instance to use for managing the conversation's message history. - /// The to monitor for cancellation requests. - /// - /// A value task representing the asynchronous operation. The task result contains a new instance configured to work with the provided . - /// - /// - /// - /// This method creates threads that do not support server-side conversation storage. - /// Some AI services require server-side conversation storage to function properly, and creating a session - /// with a may not be compatible with these services. - /// - /// - /// Where a service requires server-side conversation storage, use . - /// - /// - /// If the agent detects, during the first run, that the underlying AI service requires server-side conversation storage, - /// the session will throw an exception to indicate that it cannot continue using the provided . - /// - /// - public async ValueTask CreateSessionAsync(ChatHistoryProvider chatHistoryProvider, CancellationToken cancellationToken = default) - { - AIContextProvider? contextProvider = this._agentOptions?.AIContextProviderFactory is not null - ? await this._agentOptions.AIContextProviderFactory.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }, cancellationToken).ConfigureAwait(false) - : null; - - return new ChatClientAgentSession() - { - ChatHistoryProvider = Throw.IfNull(chatHistoryProvider), - AIContextProvider = contextProvider - }; + }); } /// @@ -398,22 +361,9 @@ protected override JsonElement SerializeSessionCore(AgentSession session, JsonSe } /// - protected override async ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) { - Func>? chatHistoryProviderFactory = this._agentOptions?.ChatHistoryProviderFactory is null ? - null : - (jse, jso, ct) => this._agentOptions.ChatHistoryProviderFactory.Invoke(new() { SerializedState = jse, JsonSerializerOptions = jso }, ct); - - Func>? aiContextProviderFactory = this._agentOptions?.AIContextProviderFactory is null ? - null : - (jse, jso, ct) => this._agentOptions.AIContextProviderFactory.Invoke(new() { SerializedState = jse, JsonSerializerOptions = jso }, ct); - - return await ChatClientAgentSession.DeserializeAsync( - serializedState, - jsonSerializerOptions, - chatHistoryProviderFactory, - aiContextProviderFactory, - cancellationToken).ConfigureAwait(false); + return new(ChatClientAgentSession.Deserialize(serializedState, jsonSerializerOptions)); } #region Private @@ -462,7 +412,7 @@ private async Task RunCoreAsync responseMessages, CancellationToken cancellationToken) { - if (session.AIContextProvider is not null) + if (this._agentOptions?.AIContextProvider is { } contextProvider) { - await session.AIContextProvider.InvokedAsync(new(this, session, inputMessages) { ResponseMessages = responseMessages }, + await contextProvider.InvokedAsync(new(this, session, inputMessages) { ResponseMessages = responseMessages }, cancellationToken).ConfigureAwait(false); } } @@ -508,9 +458,9 @@ private async Task NotifyAIContextProviderOfFailureAsync( IEnumerable inputMessages, CancellationToken cancellationToken) { - if (session.AIContextProvider is not null) + if (this._agentOptions?.AIContextProvider is { } contextProvider) { - await session.AIContextProvider.InvokedAsync(new(this, session, inputMessages) { InvokeException = ex }, + await contextProvider.InvokedAsync(new(this, session, inputMessages) { InvokeException = ex }, cancellationToken).ConfigureAwait(false); } } @@ -715,7 +665,7 @@ private async Task // Populate the session messages only if we are not continuing an existing response as it's not allowed if (chatOptions?.ContinuationToken is null) { - ChatHistoryProvider? chatHistoryProvider = ResolveChatHistoryProvider(typedSession, chatOptions); + ChatHistoryProvider? chatHistoryProvider = this.ResolveChatHistoryProvider(chatOptions, typedSession); // Add any existing messages from the session to the messages to be sent to the chat client. if (chatHistoryProvider is not null) @@ -731,10 +681,10 @@ private async Task // If we have an AIContextProvider, we should get context from it, and update our // messages and options with the additional context. - if (typedSession.AIContextProvider is not null) + if (this._agentOptions?.AIContextProvider is { } aiContextProvider) { var invokingContext = new AIContextProvider.InvokingContext(this, typedSession, inputMessages); - var aiContext = await typedSession.AIContextProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false); + var aiContext = await aiContextProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false); if (aiContext.Messages is { Count: > 0 }) { inputMessagesForProviders.AddRange(aiContext.Messages); @@ -780,7 +730,7 @@ private async Task return (typedSession, chatOptions, inputMessagesForProviders, inputMessagesForChatClient, continuationToken); } - private async Task UpdateSessionWithTypeAndConversationIdAsync(ChatClientAgentSession session, string? responseConversationId, CancellationToken cancellationToken) + private void UpdateSessionConversationId(ChatClientAgentSession session, string? responseConversationId, CancellationToken cancellationToken) { if (string.IsNullOrWhiteSpace(responseConversationId) && !string.IsNullOrWhiteSpace(session.ConversationId)) { @@ -791,6 +741,14 @@ private async Task UpdateSessionWithTypeAndConversationIdAsync(ChatClientAgentSe if (!string.IsNullOrWhiteSpace(responseConversationId)) { + if (this.ChatHistoryProvider is not null) + { + // The agent has a ChatHistoryProvider configured, but the service returned a conversation id, + // meaning the service manages chat history server-side. Both cannot be used simultaneously. + throw new InvalidOperationException( + $"Only {nameof(ChatClientAgentSession.ConversationId)} or {nameof(this.ChatHistoryProvider)} may be used, but not both. The service returned a conversation id indicating server-side chat history management, but the agent has a {nameof(this.ChatHistoryProvider)} configured."); + } + // If we got a conversation id back from the chat client, it means that the service supports server side session storage // so we should update the session with the new id. session.ConversationId = responseConversationId; @@ -798,11 +756,9 @@ private async Task UpdateSessionWithTypeAndConversationIdAsync(ChatClientAgentSe else { // If the service doesn't use service side chat history storage (i.e. we got no id back from invocation), and - // the session has no ChatHistoryProvider yet, we should update the session with the custom ChatHistoryProvider or - // default InMemoryChatHistoryProvider so that it has somewhere to store the chat history. - session.ChatHistoryProvider ??= this._agentOptions?.ChatHistoryProviderFactory is not null - ? await this._agentOptions.ChatHistoryProviderFactory.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }, cancellationToken).ConfigureAwait(false) - : new InMemoryChatHistoryProvider(); + // the agent has no ChatHistoryProvider yet, we should use the default InMemoryChatHistoryProvider so that + // we have somewhere to store the chat history. + this.ChatHistoryProvider ??= new InMemoryChatHistoryProvider(); } } @@ -813,7 +769,7 @@ private Task NotifyChatHistoryProviderOfFailureAsync( ChatOptions? chatOptions, CancellationToken cancellationToken) { - ChatHistoryProvider? provider = ResolveChatHistoryProvider(session, chatOptions); + ChatHistoryProvider? provider = this.ResolveChatHistoryProvider(chatOptions, session); // Only notify the provider if we have one. // If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages. @@ -837,7 +793,7 @@ private Task NotifyChatHistoryProviderOfNewMessagesAsync( ChatOptions? chatOptions, CancellationToken cancellationToken) { - ChatHistoryProvider? provider = ResolveChatHistoryProvider(session, chatOptions); + ChatHistoryProvider? provider = this.ResolveChatHistoryProvider(chatOptions, session); // Only notify the provider if we have one. // If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages. @@ -853,13 +809,25 @@ private Task NotifyChatHistoryProviderOfNewMessagesAsync( return Task.CompletedTask; } - private static ChatHistoryProvider? ResolveChatHistoryProvider(ChatClientAgentSession session, ChatOptions? chatOptions) + private ChatHistoryProvider? ResolveChatHistoryProvider(ChatOptions? chatOptions, ChatClientAgentSession session) { - ChatHistoryProvider? provider = session.ChatHistoryProvider; + ChatHistoryProvider? provider = this.ChatHistoryProvider; + + if (session.ConversationId is not null && provider is not null) + { + throw new InvalidOperationException( + $"Only {nameof(ChatClientAgentSession.ConversationId)} or {nameof(this.ChatHistoryProvider)} may be used, but not both. The current {nameof(ChatClientAgentSession)} has a {nameof(ChatClientAgentSession.ConversationId)} indicating server-side chat history management, but the agent has a {nameof(this.ChatHistoryProvider)} configured."); + } - // If someone provided an override ChatHistoryProvider via AdditionalProperties, we should use that instead of the one on the session. + // If someone provided an override ChatHistoryProvider via AdditionalProperties, we should use that instead. if (chatOptions?.AdditionalProperties?.TryGetValue(out ChatHistoryProvider? overrideProvider) is true) { + if (session.ConversationId is not null && overrideProvider is not null) + { + throw new InvalidOperationException( + $"Only {nameof(ChatClientAgentSession.ConversationId)} or {nameof(this.ChatHistoryProvider)} may be used, but not both. The current {nameof(ChatClientAgentSession)} has a {nameof(ChatClientAgentSession.ConversationId)} indicating server-side chat history management, but an override {nameof(this.ChatHistoryProvider)} was provided via {nameof(AgentRunOptions.AdditionalProperties)}."); + } + provider = overrideProvider; } diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs index 6f8451e2b8..66f4f797c5 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs @@ -1,9 +1,5 @@ // Copyright (c) Microsoft. All rights reserved. -using System; -using System.Text.Json; -using System.Threading; -using System.Threading.Tasks; using Microsoft.Extensions.AI; namespace Microsoft.Agents.AI; @@ -39,17 +35,14 @@ public sealed class ChatClientAgentOptions public ChatOptions? ChatOptions { get; set; } /// - /// Gets or sets a factory function to create an instance of - /// which will be used to provide chat history for this agent. + /// Gets or sets the instance to use for providing chat history for this agent. /// - public Func>? ChatHistoryProviderFactory { get; set; } + public ChatHistoryProvider? ChatHistoryProvider { get; set; } /// - /// Gets or sets a factory function to create an instance of - /// which will be used to create a context provider for each new thread, and can then - /// provide additional context for each agent run. + /// Gets or sets the instance to use for providing additional context for each agent run. /// - public Func>? AIContextProviderFactory { get; set; } + public AIContextProvider? AIContextProvider { get; set; } /// /// Gets or sets a value indicating whether to use the provided instance as is, @@ -75,41 +68,7 @@ public ChatClientAgentOptions Clone() Name = this.Name, Description = this.Description, ChatOptions = this.ChatOptions?.Clone(), - ChatHistoryProviderFactory = this.ChatHistoryProviderFactory, - AIContextProviderFactory = this.AIContextProviderFactory, + ChatHistoryProvider = this.ChatHistoryProvider, + AIContextProvider = this.AIContextProvider, }; - - /// - /// Context object passed to the to create a new instance of . - /// - public sealed class AIContextProviderFactoryContext - { - /// - /// Gets or sets the serialized state of the , if any. - /// - /// if there is no state, e.g. when the is first created. - public JsonElement SerializedState { get; set; } - - /// - /// Gets or sets the JSON serialization options to use when deserializing the . - /// - public JsonSerializerOptions? JsonSerializerOptions { get; set; } - } - - /// - /// Context object passed to the to create a new instance of . - /// - public sealed class ChatHistoryProviderFactoryContext - { - /// - /// Gets or sets the serialized state of the , if any. - /// - /// if there is no state, e.g. when the is first created. - public JsonElement SerializedState { get; set; } - - /// - /// Gets or sets the JSON serialization options to use when deserializing the . - /// - public JsonSerializerOptions? JsonSerializerOptions { get; set; } - } } diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs index 1a79ae64d1..400bfbcaf6 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs @@ -3,8 +3,7 @@ using System; using System.Diagnostics; using System.Text.Json; -using System.Threading; -using System.Threading.Tasks; +using System.Text.Json.Serialization; using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.AI; @@ -15,8 +14,6 @@ namespace Microsoft.Agents.AI; [DebuggerDisplay("{DebuggerDisplay,nq}")] public sealed class ChatClientAgentSession : AgentSession { - private ChatHistoryProvider? _chatHistoryProvider; - /// /// Initializes a new instance of the class. /// @@ -24,29 +21,30 @@ internal ChatClientAgentSession() { } + [JsonConstructor] + internal ChatClientAgentSession(string? conversationId, AgentSessionStateBag? stateBag) : base(stateBag ?? new()) + { + this.ConversationId = conversationId; + } + /// - /// Gets or sets the ID of the underlying service thread to support cases where the chat history is stored by the agent service. + /// Gets or sets the ID of the underlying service chat history to support cases where the chat history is stored by the agent service. /// /// /// - /// Note that either or may be set, but not both. - /// If is not null, setting will throw an - /// exception. - /// - /// /// This property may be null in the following cases: /// - /// The thread stores messages via the and not in the agent service. - /// This thread object is new and a server managed thread has not yet been created in the agent service. + /// The agent stores messages via a and not in the agent service. + /// This session object is new and server managed chat history has not yet been created in the agent service. /// /// /// - /// The id may also change over time where the id is pointing at a - /// agent service managed thread, and the default behavior of a service is - /// to fork the thread with each iteration. + /// The id may also change over time where the id is pointing at + /// agent service managed chat history, and the default behavior of a service is + /// to fork the chat history with each iteration. /// /// - /// Attempted to set a conversation ID but a is already set. + [JsonPropertyName("conversationId")] public string? ConversationId { get; @@ -57,149 +55,37 @@ internal set return; } - if (this._chatHistoryProvider is not null) - { - // If we have a ChatHistoryProvider already, we shouldn't switch the session to use a conversation id - // since it means that the session contents will essentially be deleted, and the session will not work - // with the original agent anymore. - throw new InvalidOperationException("Only the ConversationId or ChatHistoryProvider may be set, but not both and switching from one to another is not supported."); - } - field = Throw.IfNullOrWhitespace(value); } } - /// - /// Gets or sets the used by this thread, for cases where messages should be stored in a custom location. - /// - /// - /// - /// Note that either or may be set, but not both. - /// If is not null, and is set, - /// will be reverted to null, and vice versa. - /// - /// - /// This property may be null in the following cases: - /// - /// The thread stores messages in the agent service and just has an id to the remove thread, instead of in an . - /// This thread object is new it is not yet clear whether it will be backed by a server managed thread or an . - /// - /// - /// - public ChatHistoryProvider? ChatHistoryProvider - { - get => this._chatHistoryProvider; - internal set - { - if (this._chatHistoryProvider is null && value is null) - { - return; - } - - if (!string.IsNullOrWhiteSpace(this.ConversationId)) - { - // If we have a conversation id already, we shouldn't switch the session to use a ChatHistoryProvider - // since it means that the session will not work with the original agent anymore. - throw new InvalidOperationException("Only the ConversationId or ChatHistoryProvider may be set, but not both and switching from one to another is not supported."); - } - - this._chatHistoryProvider = Throw.IfNull(value); - } - } - - /// - /// Gets or sets the used by this thread to provide additional context to the AI model before each invocation. - /// - public AIContextProvider? AIContextProvider { get; internal set; } - /// /// Creates a new instance of the class from previously serialized state. /// /// A representing the serialized state of the session. - /// Optional settings for customizing the JSON deserialization process. - /// - /// An optional factory function to create a custom from its serialized state. - /// If not provided, the default will be used. - /// - /// - /// An optional factory function to create a custom from its serialized state. - /// If not provided, no context provider will be configured. - /// - /// The to monitor for cancellation requests. - /// A task representing the asynchronous operation. The task result contains the deserialized . - internal static async Task DeserializeAsync( - JsonElement serializedState, - JsonSerializerOptions? jsonSerializerOptions = null, - Func>? chatHistoryProviderFactory = null, - Func>? aiContextProviderFactory = null, - CancellationToken cancellationToken = default) + /// Optional JSON serialization options to use instead of the default options. + /// The deserialized . + internal static ChatClientAgentSession Deserialize(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null) { if (serializedState.ValueKind != JsonValueKind.Object) { throw new ArgumentException("The serialized session state must be a JSON object.", nameof(serializedState)); } - var state = serializedState.Deserialize( - AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(SessionState))) as SessionState; - - var session = new ChatClientAgentSession(); - - session.AIContextProvider = aiContextProviderFactory is not null - ? await aiContextProviderFactory.Invoke(state?.AIContextProviderState ?? default, jsonSerializerOptions, cancellationToken).ConfigureAwait(false) - : null; - - if (state?.ConversationId is string sessionId) - { - session.ConversationId = sessionId; - - // Since we have an ID, we should not have a ChatHistoryProvider and we can return here. - return session; - } - - session._chatHistoryProvider = - chatHistoryProviderFactory is not null - ? await chatHistoryProviderFactory.Invoke(state?.ChatHistoryProviderState ?? default, jsonSerializerOptions, cancellationToken).ConfigureAwait(false) - : new InMemoryChatHistoryProvider(state?.ChatHistoryProviderState ?? default, jsonSerializerOptions); // default to an in-memory ChatHistoryProvider - - return session; + var jso = jsonSerializerOptions ?? AgentJsonUtilities.DefaultOptions; + return serializedState.Deserialize(jso.GetTypeInfo(typeof(ChatClientAgentSession))) as ChatClientAgentSession + ?? new ChatClientAgentSession(); } /// internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { - JsonElement? chatHistoryProviderState = this._chatHistoryProvider?.Serialize(jsonSerializerOptions); - - JsonElement? aiContextProviderState = this.AIContextProvider?.Serialize(jsonSerializerOptions); - - var state = new SessionState - { - ConversationId = this.ConversationId, - ChatHistoryProviderState = chatHistoryProviderState is { ValueKind: not JsonValueKind.Undefined } ? chatHistoryProviderState : null, - AIContextProviderState = aiContextProviderState is { ValueKind: not JsonValueKind.Undefined } ? aiContextProviderState : null, - }; - - return JsonSerializer.SerializeToElement(state, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(SessionState))); + var jso = jsonSerializerOptions ?? AgentJsonUtilities.DefaultOptions; + return JsonSerializer.SerializeToElement(this, jso.GetTypeInfo(typeof(ChatClientAgentSession))); } - /// - public override object? GetService(Type serviceType, object? serviceKey = null) => - base.GetService(serviceType, serviceKey) - ?? this.AIContextProvider?.GetService(serviceType, serviceKey) - ?? this.ChatHistoryProvider?.GetService(serviceType, serviceKey); - [DebuggerBrowsable(DebuggerBrowsableState.Never)] private string DebuggerDisplay => - this.ConversationId is { } conversationId ? $"ConversationId = {conversationId}" : - this._chatHistoryProvider is InMemoryChatHistoryProvider inMemoryChatHistoryProvider ? $"Count = {inMemoryChatHistoryProvider.Count}" : - this._chatHistoryProvider is { } chatHistoryProvider ? $"ChatHistoryProvider = {chatHistoryProvider.GetType().Name}" : - "Count = 0"; - - internal sealed class SessionState - { - public string? ConversationId { get; set; } - - public JsonElement? ChatHistoryProviderState { get; set; } - - public JsonElement? AIContextProviderState { get; set; } - } + this.ConversationId is { } conversationId ? $"ConversationId = {conversationId}, StateBag Count = {this.StateBag.Count}" : + $"StateBag Count = {this.StateBag.Count}"; } diff --git a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs index c63e8ac682..93901deeee 100644 --- a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs @@ -4,7 +4,6 @@ using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -41,18 +40,21 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable private const int DefaultMaxResults = 3; private const string DefaultFunctionToolName = "Search"; private const string DefaultFunctionToolDescription = "Allows searching for related previous chat history to help answer the user question."; + private const string DefaultStateBagKey = "ChatHistoryMemoryProvider.State"; +#pragma warning disable CA2213 // VectorStore is not owned by this class - caller is responsible for disposal private readonly VectorStore _vectorStore; +#pragma warning restore CA2213 private readonly VectorStoreCollection> _collection; private readonly int _maxResults; private readonly string _contextPrompt; private readonly bool _enableSensitiveTelemetryData; private readonly ChatHistoryMemoryProviderOptions.SearchBehavior _searchTime; - private readonly AITool[] _tools; + private readonly string _toolName; + private readonly string _toolDescription; private readonly ILogger? _logger; - - private readonly ChatHistoryMemoryProviderScope _storageScope; - private readonly ChatHistoryMemoryProviderScope _searchScope; + private readonly string _stateKey; + private readonly Func _stateInitializer; private bool _collectionInitialized; private readonly SemaphoreSlim _initializationLock = new(1, 1); @@ -64,93 +66,30 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable /// The vector store to use for storing and retrieving chat history. /// The name of the collection for storing chat history in the vector store. /// The number of dimensions to use for the chat history vector store embeddings. - /// Optional values to scope the chat history storage with. - /// Optional values to scope the chat history search with. Where values are null, no filtering is done using those values. Defaults to if not provided. - /// Optional configuration options. - /// Optional logger factory. - /// Thrown when is . - public ChatHistoryMemoryProvider( - VectorStore vectorStore, - string collectionName, - int vectorDimensions, - ChatHistoryMemoryProviderScope storageScope, - ChatHistoryMemoryProviderScope? searchScope = null, - ChatHistoryMemoryProviderOptions? options = null, - ILoggerFactory? loggerFactory = null) - : this( - vectorStore, - collectionName, - vectorDimensions, - new ChatHistoryMemoryProviderState - { - StorageScope = new(Throw.IfNull(storageScope)), - SearchScope = searchScope ?? new(storageScope), - }, - options, - loggerFactory) - { - } - - /// - /// Initializes a new instance of the class from previously serialized state. - /// - /// The vector store to use for storing and retrieving chat history. - /// The name of the collection for storing chat history in the vector store. - /// The number of dimensions to use for the chat history vector store embeddings. - /// A representing the serialized state of the provider. - /// Optional settings for customizing the JSON deserialization process. + /// A delegate that initializes the provider state on the first invocation, providing the storage and search scopes. /// Optional configuration options. /// Optional logger factory. + /// Thrown when or is . public ChatHistoryMemoryProvider( VectorStore vectorStore, string collectionName, int vectorDimensions, - JsonElement serializedState, - JsonSerializerOptions? jsonSerializerOptions = null, + Func stateInitializer, ChatHistoryMemoryProviderOptions? options = null, ILoggerFactory? loggerFactory = null) - : this( - vectorStore, - collectionName, - vectorDimensions, - DeserializeState(serializedState, jsonSerializerOptions), - options, - loggerFactory) { - } + this._vectorStore = Throw.IfNull(vectorStore); + this._stateInitializer = Throw.IfNull(stateInitializer); - private ChatHistoryMemoryProvider( - VectorStore vectorStore, - string collectionName, - int vectorDimensions, - ChatHistoryMemoryProviderState? state = null, - ChatHistoryMemoryProviderOptions? options = null, - ILoggerFactory? loggerFactory = null) - { - this._vectorStore = vectorStore ?? throw new ArgumentNullException(nameof(vectorStore)); options ??= new ChatHistoryMemoryProviderOptions(); this._maxResults = options.MaxResults.HasValue ? Throw.IfLessThanOrEqual(options.MaxResults.Value, 0) : DefaultMaxResults; this._contextPrompt = options.ContextPrompt ?? DefaultContextPrompt; this._enableSensitiveTelemetryData = options.EnableSensitiveTelemetryData; this._searchTime = options.SearchTime; + this._stateKey = options.StateKey ?? DefaultStateBagKey; this._logger = loggerFactory?.CreateLogger(); - - if (state == null || state.StorageScope == null || state.SearchScope == null) - { - throw new InvalidOperationException($"The {nameof(ChatHistoryMemoryProvider)} state did not contain the required scope properties."); - } - - this._storageScope = state.StorageScope; - this._searchScope = state.SearchScope; - - // Create on-demand search tool (only used when behavior is OnDemandFunctionCalling) - this._tools = - [ - AIFunctionFactory.Create( - (Func>)this.SearchTextAsync, - name: options.FunctionToolName ?? DefaultFunctionToolName, - description: options.FunctionToolDescription ?? DefaultFunctionToolDescription) - ]; + this._toolName = options.FunctionToolName ?? DefaultFunctionToolName; + this._toolDescription = options.FunctionToolDescription ?? DefaultFunctionToolDescription; // Create a definition so that we can use the dimensions provided at runtime. var definition = new VectorStoreCollectionDefinition @@ -174,15 +113,51 @@ private ChatHistoryMemoryProvider( this._collection = this._vectorStore.GetDynamicCollection(Throw.IfNullOrWhitespace(collectionName), definition); } + /// + /// Gets the state from the session's StateBag, or initializes it using the StateInitializer if not present. + /// + /// The agent session containing the StateBag. + /// The provider state, or null if no session is available. + private State? GetOrInitializeState(AgentSession? session) + { + if (session?.StateBag.TryGetValue(this._stateKey, out var state, AgentJsonUtilities.DefaultOptions) is true && state is not null) + { + return state; + } + + state = this._stateInitializer(session); + if (state is not null && session is not null) + { + session.StateBag.SetValue(this._stateKey, state, AgentJsonUtilities.DefaultOptions); + } + + return state; + } + /// protected override async ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { _ = Throw.IfNull(context); + var state = this.GetOrInitializeState(context.Session); + var searchScope = state?.SearchScope ?? new ChatHistoryMemoryProviderScope(); + if (this._searchTime == ChatHistoryMemoryProviderOptions.SearchBehavior.OnDemandFunctionCalling) { + Task InlineSearchAsync(string userQuestion, CancellationToken ct) + => this.SearchTextAsync(userQuestion, searchScope, ct); + + // Create on-demand search tool (only used when behavior is OnDemandFunctionCalling) + AITool[] tools = + [ + AIFunctionFactory.Create( + InlineSearchAsync, + name: this._toolName, + description: this._toolDescription) + ]; + // Expose search tool for on-demand invocation by the model - return new AIContext { Tools = this._tools }; + return new AIContext { Tools = tools }; } try @@ -199,7 +174,7 @@ protected override async ValueTask InvokingCoreAsync(InvokingContext } // Search for relevant chat history - var contextText = await this.SearchTextAsync(requestText, cancellationToken).ConfigureAwait(false); + var contextText = await this.SearchTextAsync(requestText, searchScope, cancellationToken).ConfigureAwait(false); if (string.IsNullOrWhiteSpace(contextText)) { @@ -218,10 +193,10 @@ protected override async ValueTask InvokingCoreAsync(InvokingContext this._logger.LogError( ex, "ChatHistoryMemoryProvider: Failed to search for chat history due to error. ApplicationId: '{ApplicationId}', AgentId: '{AgentId}', SessionId: '{SessionId}', UserId: '{UserId}'.", - this._searchScope.ApplicationId, - this._searchScope.AgentId, - this._searchScope.SessionId, - this.SanitizeLogData(this._searchScope.UserId)); + searchScope.ApplicationId, + searchScope.AgentId, + searchScope.SessionId, + this.SanitizeLogData(searchScope.UserId)); } return new AIContext(); @@ -239,6 +214,9 @@ protected override async ValueTask InvokedCoreAsync(InvokedContext context, Canc return; } + var state = this.GetOrInitializeState(context.Session); + var storageScope = state?.StorageScope ?? new ChatHistoryMemoryProviderScope(); + try { // Ensure the collection is initialized @@ -253,10 +231,10 @@ protected override async ValueTask InvokedCoreAsync(InvokedContext context, Canc ["Role"] = message.Role.ToString(), ["MessageId"] = message.MessageId, ["AuthorName"] = message.AuthorName, - ["ApplicationId"] = this._storageScope?.ApplicationId, - ["AgentId"] = this._storageScope?.AgentId, - ["UserId"] = this._storageScope?.UserId, - ["SessionId"] = this._storageScope?.SessionId, + ["ApplicationId"] = storageScope.ApplicationId, + ["AgentId"] = storageScope.AgentId, + ["UserId"] = storageScope.UserId, + ["SessionId"] = storageScope.SessionId, ["Content"] = message.Text, ["CreatedAt"] = message.CreatedAt?.ToString("O") ?? DateTimeOffset.UtcNow.ToString("O"), ["ContentEmbedding"] = message.Text, @@ -275,10 +253,10 @@ protected override async ValueTask InvokedCoreAsync(InvokedContext context, Canc this._logger.LogError( ex, "ChatHistoryMemoryProvider: Failed to add messages to chat history vector store due to error. ApplicationId: '{ApplicationId}', AgentId: '{AgentId}', SessionId: '{SessionId}', UserId: '{UserId}'.", - this._searchScope.ApplicationId, - this._searchScope.AgentId, - this._searchScope.SessionId, - this.SanitizeLogData(this._searchScope.UserId)); + storageScope.ApplicationId, + storageScope.AgentId, + storageScope.SessionId, + this.SanitizeLogData(storageScope.UserId)); } } } @@ -287,16 +265,17 @@ protected override async ValueTask InvokedCoreAsync(InvokedContext context, Canc /// Function callable by the AI model (when enabled) to perform an ad-hoc chat history search. /// /// The query text. + /// The scope to filter search results with. /// Cancellation token. /// Formatted search results (may be empty). - internal async Task SearchTextAsync(string userQuestion, CancellationToken cancellationToken = default) + private async Task SearchTextAsync(string userQuestion, ChatHistoryMemoryProviderScope searchScope, CancellationToken cancellationToken = default) { if (string.IsNullOrWhiteSpace(userQuestion)) { return string.Empty; } - var results = await this.SearchChatHistoryAsync(userQuestion, this._maxResults, cancellationToken).ConfigureAwait(false); + var results = await this.SearchChatHistoryAsync(userQuestion, searchScope, this._maxResults, cancellationToken).ConfigureAwait(false); if (!results.Any()) { return string.Empty; @@ -317,10 +296,10 @@ internal async Task SearchTextAsync(string userQuestion, CancellationTok "ChatHistoryMemoryProvider: Search Results\nInput:{Input}\nOutput:{MessageText}\n ApplicationId: '{ApplicationId}', AgentId: '{AgentId}', SessionId: '{SessionId}', UserId: '{UserId}'.", this.SanitizeLogData(userQuestion), this.SanitizeLogData(formatted), - this._searchScope.ApplicationId, - this._searchScope.AgentId, - this._searchScope.SessionId, - this.SanitizeLogData(this._searchScope.UserId)); + searchScope.ApplicationId, + searchScope.AgentId, + searchScope.SessionId, + this.SanitizeLogData(searchScope.UserId)); } return formatted; @@ -330,11 +309,13 @@ internal async Task SearchTextAsync(string userQuestion, CancellationTok /// Searches for relevant chat history items based on the provided query text. /// /// The text to search for. + /// The scope to filter search results with. /// The maximum number of results to return. /// The cancellation token. /// A list of relevant chat history items. private async Task>> SearchChatHistoryAsync( string queryText, + ChatHistoryMemoryProviderScope searchScope, int top, CancellationToken cancellationToken = default) { @@ -345,10 +326,10 @@ internal async Task SearchTextAsync(string userQuestion, CancellationTok var collection = await this.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false); - string? applicationId = this._searchScope.ApplicationId; - string? agentId = this._searchScope.AgentId; - string? userId = this._searchScope.UserId; - string? sessionId = this._searchScope.SessionId; + string? applicationId = searchScope.ApplicationId; + string? agentId = searchScope.AgentId; + string? userId = searchScope.UserId; + string? sessionId = searchScope.SessionId; Expression, bool>>? filter = null; if (applicationId != null) @@ -401,10 +382,10 @@ internal async Task SearchTextAsync(string userQuestion, CancellationTok this._logger.LogInformation( "ChatHistoryMemoryProvider: Retrieved {Count} search results. ApplicationId: '{ApplicationId}', AgentId: '{AgentId}', SessionId: '{SessionId}', UserId: '{UserId}'.", results.Count, - this._searchScope.ApplicationId, - this._searchScope.AgentId, - this._searchScope.SessionId, - this.SanitizeLogData(this._searchScope.UserId)); + searchScope.ApplicationId, + searchScope.AgentId, + searchScope.SessionId, + this.SanitizeLogData(searchScope.UserId)); } return results; @@ -465,39 +446,32 @@ public void Dispose() GC.SuppressFinalize(this); } + private string? SanitizeLogData(string? data) => this._enableSensitiveTelemetryData ? data : ""; + /// - /// Serializes the current provider state to a including storage and search scopes. + /// Represents the state of a stored in the . /// - /// Optional serializer options. - /// Serialized provider state. - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - var state = new ChatHistoryMemoryProviderState - { - StorageScope = this._storageScope, - SearchScope = this._searchScope, - }; - - var jso = jsonSerializerOptions ?? AgentJsonUtilities.DefaultOptions; - return JsonSerializer.SerializeToElement(state, jso.GetTypeInfo(typeof(ChatHistoryMemoryProviderState))); - } - - private static ChatHistoryMemoryProviderState? DeserializeState(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions) + public sealed class State { - if (serializedState.ValueKind != JsonValueKind.Object) + /// + /// Initializes a new instance of the class with the specified storage and search scopes. + /// + /// The scope to use when storing chat history messages. + /// The scope to use when searching for relevant chat history messages. If null, the storage scope will be used for searching as well. + public State(ChatHistoryMemoryProviderScope storageScope, ChatHistoryMemoryProviderScope? searchScope = null) { - return null; + this.StorageScope = Throw.IfNull(storageScope); + this.SearchScope = searchScope ?? storageScope; } - var jso = jsonSerializerOptions ?? AgentJsonUtilities.DefaultOptions; - return serializedState.Deserialize(jso.GetTypeInfo(typeof(ChatHistoryMemoryProviderState))) as ChatHistoryMemoryProviderState; - } + /// + /// Gets or sets the scope used when storing chat history messages. + /// + public ChatHistoryMemoryProviderScope StorageScope { get; } - private string? SanitizeLogData(string? data) => this._enableSensitiveTelemetryData ? data : ""; - - internal sealed class ChatHistoryMemoryProviderState - { - public ChatHistoryMemoryProviderScope? StorageScope { get; set; } - public ChatHistoryMemoryProviderScope? SearchScope { get; set; } + /// + /// Gets or sets the scope used when searching chat history messages. + /// + public ChatHistoryMemoryProviderScope SearchScope { get; } } } diff --git a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProviderOptions.cs b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProviderOptions.cs index e09de68a59..05b03fef63 100644 --- a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProviderOptions.cs +++ b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProviderOptions.cs @@ -44,6 +44,15 @@ public sealed class ChatHistoryMemoryProviderOptions /// Defaults to . public bool EnableSensitiveTelemetryData { get; set; } + /// + /// Gets or sets the key used to store provider state in the . + /// + /// + /// Defaults to "ChatHistoryMemoryProvider.State". Override this if you need multiple + /// instances with separate state in the same session. + /// + public string? StateKey { get; set; } + /// /// Behavior choices for the provider. /// diff --git a/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs b/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs index ee87d4f00c..abb1297616 100644 --- a/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs @@ -4,7 +4,6 @@ using System.Collections.Generic; using System.Linq; using System.Text; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -39,31 +38,28 @@ public sealed class TextSearchProvider : AIContextProvider private const string DefaultPluginSearchFunctionDescription = "Allows searching for additional information to help answer the user question."; private const string DefaultContextPrompt = "## Additional Context\nConsider the following information from source documents when responding to the user:"; private const string DefaultCitationsPrompt = "Include citations to the source document with document name and link if document name and link is available."; + private const string DefaultStateBagKey = "TextSearchProvider.RecentMessagesText"; private readonly Func>> _searchAsync; private readonly ILogger? _logger; private readonly AITool[] _tools; - private readonly Queue _recentMessagesText; private readonly List _recentMessageRolesIncluded; private readonly int _recentMessageMemoryLimit; private readonly TextSearchProviderOptions.TextSearchBehavior _searchTime; private readonly string _contextPrompt; private readonly string _citationsPrompt; + private readonly string _stateKey; private readonly Func, string>? _contextFormatter; /// /// Initializes a new instance of the class. /// /// Delegate that executes the search logic. Must not be . - /// A representing the serialized provider state. - /// Optional serializer options (unused - source generated context is used). /// Optional configuration options. /// Optional logger factory. /// Thrown when is . public TextSearchProvider( Func>> searchAsync, - JsonElement serializedState, - JsonSerializerOptions? jsonSerializerOptions = null, TextSearchProviderOptions? options = null, ILoggerFactory? loggerFactory = null) { @@ -75,27 +71,9 @@ public TextSearchProvider( this._searchTime = options?.SearchTime ?? TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke; this._contextPrompt = options?.ContextPrompt ?? DefaultContextPrompt; this._citationsPrompt = options?.CitationsPrompt ?? DefaultCitationsPrompt; + this._stateKey = options?.StateKey ?? DefaultStateBagKey; this._contextFormatter = options?.ContextFormatter; - // Restore recent messages from serialized state if provided - List? restoredMessages = null; - if (serializedState.ValueKind is JsonValueKind.Null or JsonValueKind.Undefined) - { - this._recentMessagesText = new(); - } - else - { - var jso = jsonSerializerOptions ?? AgentJsonUtilities.DefaultOptions; - var state = serializedState.Deserialize(jso.GetTypeInfo(typeof(TextSearchProviderState))) as TextSearchProviderState; - if (state?.RecentMessagesText is { Count: > 0 }) - { - restoredMessages = state.RecentMessagesText; - } - - // Restore recent messages respecting the limit (may truncate if limit changed afterwards). - this._recentMessagesText = restoredMessages is null ? new() : new(restoredMessages.Take(this._recentMessageMemoryLimit)); - } - // Create the on-demand search tool (only used if behavior is OnDemandFunctionCalling) this._tools = [ @@ -115,12 +93,16 @@ protected override async ValueTask InvokingCoreAsync(InvokingContext return new AIContext { Tools = this._tools }; // No automatic message injection. } + // Retrieve recent messages from the session state bag. + var recentMessagesText = context.Session?.StateBag.GetValue(this._stateKey, AgentJsonUtilities.DefaultOptions)?.RecentMessagesText + ?? []; + // Aggregate text from memory + current request messages. var sbInput = new StringBuilder(); var requestMessagesText = context.RequestMessages .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSourceType.External) .Where(x => !string.IsNullOrWhiteSpace(x?.Text)).Select(x => x.Text); - foreach (var messageText in this._recentMessagesText.Concat(requestMessagesText)) + foreach (var messageText in recentMessagesText.Concat(requestMessagesText)) { if (sbInput.Length > 0) { @@ -176,12 +158,21 @@ protected override ValueTask InvokedCoreAsync(InvokedContext context, Cancellati return default; // Memory disabled. } + if (context.Session is null) + { + return default; // No session to store state in. + } + if (context.InvokeException is not null) { return default; // Do not update memory on failed invocations. } - var messagesText = context.RequestMessages + // Retrieve existing recent messages from the session state bag. + var recentMessagesText = context.Session.StateBag.GetValue(this._stateKey, AgentJsonUtilities.DefaultOptions)?.RecentMessagesText + ?? []; + + var newMessagesText = context.RequestMessages .Where(m => m.GetAgentRequestMessageSource() == AgentRequestMessageSourceType.External) .Concat(context.ResponseMessages ?? []) .Where(m => @@ -190,44 +181,23 @@ protected override ValueTask InvokedCoreAsync(InvokedContext context, Cancellati // Filter out any messages that were added by this class in InvokingAsync, since we don't want // a feedback loop where previous search results are used to find new search results. (m.AdditionalProperties == null || m.AdditionalProperties.TryGetValue("IsTextSearchProviderOutput", out bool isTextSearchProviderOutput) == false || !isTextSearchProviderOutput)) - .Select(m => m.Text) - .ToList(); - if (messagesText.Count > limit) - { - // If the current request/response exceeds the limit, only keep the most recent messages from it. - messagesText = messagesText.Skip(messagesText.Count - limit).ToList(); - } + .Select(m => m.Text); - foreach (var message in messagesText) - { - this._recentMessagesText.Enqueue(message); - } + // Combine existing messages with new messages, then take the most recent up to the limit. + var allMessages = recentMessagesText.Concat(newMessagesText).ToList(); + var updatedMessages = allMessages.Count > limit + ? allMessages.Skip(allMessages.Count - limit).ToList() + : allMessages; - while (this._recentMessagesText.Count > limit) - { - this._recentMessagesText.Dequeue(); - } + // Store updated state back to the session state bag. + context.Session.StateBag.SetValue( + this._stateKey, + new TextSearchProviderState { RecentMessagesText = updatedMessages }, + AgentJsonUtilities.DefaultOptions); return default; } - /// - /// Serializes the current provider state to a containing any overridden prompts or descriptions. - /// - /// Optional serializer options (ignored, source generated context is used). - /// A with overridden values, or default if nothing was overridden. - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - // Only persist values that differ from defaults plus recent memory configuration & messages. - TextSearchProviderState state = new(); - if (this._recentMessageMemoryLimit > 0 && this._recentMessagesText.Count > 0) - { - state.RecentMessagesText = this._recentMessagesText.Take(this._recentMessageMemoryLimit).ToList(); - } - - return JsonSerializer.SerializeToElement(state, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(TextSearchProviderState))); - } - /// /// Function callable by the AI model (when enabled) to perform an ad-hoc search. /// diff --git a/dotnet/src/Microsoft.Agents.AI/TextSearchProviderOptions.cs b/dotnet/src/Microsoft.Agents.AI/TextSearchProviderOptions.cs index e90a6efa63..89d959f679 100644 --- a/dotnet/src/Microsoft.Agents.AI/TextSearchProviderOptions.cs +++ b/dotnet/src/Microsoft.Agents.AI/TextSearchProviderOptions.cs @@ -59,6 +59,15 @@ public sealed class TextSearchProviderOptions /// public int RecentMessageMemoryLimit { get; set; } + /// + /// Gets or sets the key used to store provider state in the . + /// + /// + /// Defaults to "TextSearchProvider.RecentMessagesText". Override this if you need multiple + /// instances with separate state in the same session. + /// + public string? StateKey { get; set; } + /// /// Gets or sets the list of types to filter recent messages to /// when deciding which recent messages to include when constructing the search input. diff --git a/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs b/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs index a0e4b64763..f36cb119d9 100644 --- a/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs +++ b/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs @@ -37,14 +37,14 @@ public AnthropicChatCompletionFixture(bool useReasoningChatModel, bool useBeta) public async Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) { - var typedSession = (ChatClientAgentSession)session; + var chatHistoryProvider = agent.GetService(); - if (typedSession.ChatHistoryProvider is null) + if (chatHistoryProvider is null) { return []; } - return (await typedSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); + return (await chatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); } public Task CreateChatClientAgentAsync( diff --git a/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs b/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs index c655fd7a58..dd926174c0 100644 --- a/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs +++ b/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs @@ -48,12 +48,14 @@ public async Task> GetChatHistoryAsync(AIAgent agent, AgentSes return await this.GetChatHistoryFromResponsesChainAsync(chatClientSession.ConversationId); } - if (chatClientSession.ChatHistoryProvider is null) + var chatHistoryProvider = agent.GetService(); + + if (chatHistoryProvider is null) { return []; } - return (await chatClientSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); + return (await chatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); } private async Task> GetChatHistoryFromResponsesChainAsync(string conversationId) diff --git a/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentSessionTests.cs index 66018e0131..8c3e89adf6 100644 --- a/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentSessionTests.cs @@ -21,10 +21,28 @@ public void Constructor_RoundTrip_SerializationPreservesState() // Act JsonElement serialized = originalSession.Serialize(); - A2AAgentSession deserializedSession = new(serialized); + A2AAgentSession deserializedSession = A2AAgentSession.Deserialize(serialized); // Assert Assert.Equal(originalSession.ContextId, deserializedSession.ContextId); Assert.Equal(originalSession.TaskId, deserializedSession.TaskId); } + + [Fact] + public void Constructor_RoundTrip_SerializationPreservesStateBag() + { + // Arrange + A2AAgentSession originalSession = new() { ContextId = "ctx-1", TaskId = "task-1" }; + originalSession.StateBag.SetValue("testKey", "testValue"); + + // Act + JsonElement serialized = originalSession.Serialize(); + A2AAgentSession deserializedSession = A2AAgentSession.Deserialize(serialized); + + // Assert + Assert.Equal("ctx-1", deserializedSession.ContextId); + Assert.Equal("task-1", deserializedSession.TaskId); + Assert.True(deserializedSession.StateBag.TryGetValue("testKey", out var value)); + Assert.Equal("testValue", value); + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs index 44d1be2e74..38d89827ae 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs @@ -136,19 +136,6 @@ public async Task InvokedAsync_ReturnsCompletedTaskAsync() Assert.Equal(default, task); } - [Fact] - public void Serialize_ReturnsEmptyElement() - { - // Arrange - var provider = new TestAIContextProvider(); - - // Act - var actual = provider.Serialize(); - - // Assert - Assert.Equal(default, actual); - } - [Fact] public void InvokingContext_Constructor_ThrowsForNullMessages() { diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs new file mode 100644 index 0000000000..b30af7acc6 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs @@ -0,0 +1,747 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json; +using Microsoft.Agents.AI.Abstractions.UnitTests.Models; + +namespace Microsoft.Agents.AI.Abstractions.UnitTests; + +/// +/// Contains tests for the class. +/// +public sealed class AgentSessionStateBagTests +{ + #region Constructor Tests + + [Fact] + public void Constructor_Default_CreatesEmptyStateBag() + { + // Act + var stateBag = new AgentSessionStateBag(); + + // Assert + Assert.False(stateBag.TryGetValue("nonexistent", out _)); + } + + #endregion + + #region SetValue Tests + + [Fact] + public void SetValue_WithValidKeyAndValue_StoresValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + stateBag.SetValue("key1", "value1"); + + // Assert + Assert.True(stateBag.TryGetValue("key1", out var result)); + Assert.Equal("value1", result); + } + + [Fact] + public void SetValue_WithNullKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.SetValue(null!, "value")); + } + + [Fact] + public void SetValue_WithEmptyKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.SetValue("", "value")); + } + + [Fact] + public void SetValue_WithWhitespaceKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.SetValue(" ", "value")); + } + + [Fact] + public void SetValue_OverwritesExistingValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "originalValue"); + + // Act + stateBag.SetValue("key1", "newValue"); + + // Assert + Assert.Equal("newValue", stateBag.GetValue("key1")); + } + + #endregion + + #region GetValue Tests + + [Fact] + public void GetValue_WithExistingKey_ReturnsValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "value1"); + + // Act + var result = stateBag.GetValue("key1"); + + // Assert + Assert.Equal("value1", result); + } + + [Fact] + public void GetValue_WithNonexistentKey_ReturnsNull() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + var result = stateBag.GetValue("nonexistent"); + + // Assert + Assert.Null(result); + } + + [Fact] + public void GetValue_WithNullKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.GetValue(null!)); + } + + [Fact] + public void GetValue_WithEmptyKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.GetValue("")); + } + + [Fact] + public void GetValue_CachesDeserializedValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "value1"); + + // Act + var result1 = stateBag.GetValue("key1"); + var result2 = stateBag.GetValue("key1"); + + // Assert + Assert.Same(result1, result2); + } + + #endregion + + #region TryGetValue Tests + + [Fact] + public void TryGetValue_WithExistingKey_ReturnsTrueAndValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "value1"); + + // Act + var found = stateBag.TryGetValue("key1", out var result); + + // Assert + Assert.True(found); + Assert.Equal("value1", result); + } + + [Fact] + public void TryGetValue_WithNonexistentKey_ReturnsFalseAndNull() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + var found = stateBag.TryGetValue("nonexistent", out var result); + + // Assert + Assert.False(found); + Assert.Null(result); + } + + [Fact] + public void TryGetValue_WithNullKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.TryGetValue(null!, out _)); + } + + [Fact] + public void TryGetValue_WithEmptyKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.TryGetValue("", out _)); + } + + #endregion + + #region Null Value Tests + + [Fact] + public void SetValue_WithNullValue_StoresNull() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + stateBag.SetValue("key1", null); + + // Assert + Assert.Equal(1, stateBag.Count); + } + + [Fact] + public void TryGetValue_WithNullValue_ReturnsTrueAndNull() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", null); + + // Act + var found = stateBag.TryGetValue("key1", out var result); + + // Assert + Assert.True(found); + Assert.Null(result); + } + + [Fact] + public void GetValue_WithNullValue_ReturnsNull() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", null); + + // Act + var result = stateBag.GetValue("key1"); + + // Assert + Assert.Null(result); + } + + [Fact] + public void SetValue_OverwriteWithNull_ReturnsNull() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "value1"); + + // Act + stateBag.SetValue("key1", null); + + // Assert + Assert.True(stateBag.TryGetValue("key1", out var result)); + Assert.Null(result); + } + + [Fact] + public void SetValue_OverwriteNullWithValue_ReturnsValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", null); + + // Act + stateBag.SetValue("key1", "newValue"); + + // Assert + Assert.True(stateBag.TryGetValue("key1", out var result)); + Assert.Equal("newValue", result); + } + + [Fact] + public void SerializeDeserialize_WithNullValue_SerializesAsNull() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("nullKey", null); + + // Act + var json = stateBag.Serialize(); + + // Assert - null values are serialized as JSON null + Assert.Equal(JsonValueKind.Object, json.ValueKind); + Assert.True(json.TryGetProperty("nullKey", out var nullElement)); + Assert.Equal(JsonValueKind.Null, nullElement.ValueKind); + } + + #endregion + + #region TryRemoveValue Tests + + [Fact] + public void TryRemoveValue_ExistingKey_ReturnsTrueAndRemoves() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "value1"); + + // Act + var removed = stateBag.TryRemoveValue("key1"); + + // Assert + Assert.True(removed); + Assert.Equal(0, stateBag.Count); + Assert.False(stateBag.TryGetValue("key1", out _)); + } + + [Fact] + public void TryRemoveValue_NonexistentKey_ReturnsFalse() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + var removed = stateBag.TryRemoveValue("nonexistent"); + + // Assert + Assert.False(removed); + } + + [Fact] + public void TryRemoveValue_WithNullKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.TryRemoveValue(null!)); + } + + [Fact] + public void TryRemoveValue_WithEmptyKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.TryRemoveValue("")); + } + + [Fact] + public void TryRemoveValue_WithWhitespaceKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.TryRemoveValue(" ")); + } + + [Fact] + public void TryRemoveValue_DoesNotAffectOtherKeys() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "value1"); + stateBag.SetValue("key2", "value2"); + + // Act + stateBag.TryRemoveValue("key1"); + + // Assert + Assert.Equal(1, stateBag.Count); + Assert.False(stateBag.TryGetValue("key1", out _)); + Assert.True(stateBag.TryGetValue("key2", out var value)); + Assert.Equal("value2", value); + } + + [Fact] + public void TryRemoveValue_ThenSetValue_Works() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "original"); + + // Act + stateBag.TryRemoveValue("key1"); + stateBag.SetValue("key1", "replacement"); + + // Assert + Assert.True(stateBag.TryGetValue("key1", out var result)); + Assert.Equal("replacement", result); + } + + #endregion + + #region Serialize/Deserialize Tests + + [Fact] + public void Serialize_EmptyStateBag_ReturnsEmptyObject() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + var json = stateBag.Serialize(); + + // Assert + Assert.Equal(JsonValueKind.Object, json.ValueKind); + } + + [Fact] + public void Serialize_WithStringValue_ReturnsJsonWithValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("stringKey", "stringValue"); + + // Act + var json = stateBag.Serialize(); + + // Assert + Assert.Equal(JsonValueKind.Object, json.ValueKind); + Assert.True(json.TryGetProperty("stringKey", out _)); + } + + [Fact] + public void Deserialize_FromJsonDocument_ReturnsEmptyStateBag() + { + // Arrange + var emptyJson = JsonDocument.Parse("{}").RootElement; + + // Act + var stateBag = AgentSessionStateBag.Deserialize(emptyJson); + + // Assert + Assert.False(stateBag.TryGetValue("nonexistent", out _)); + } + + [Fact] + public void Deserialize_NullElement_ReturnsEmptyStateBag() + { + // Arrange + var nullJson = default(JsonElement); + + // Act + var stateBag = AgentSessionStateBag.Deserialize(nullJson); + + // Assert + Assert.False(stateBag.TryGetValue("nonexistent", out _)); + } + + [Fact] + public void SerializeDeserialize_WithStringValue_Roundtrips() + { + // Arrange + var originalStateBag = new AgentSessionStateBag(); + originalStateBag.SetValue("stringKey", "stringValue"); + + // Act + var json = originalStateBag.Serialize(); + var restoredStateBag = AgentSessionStateBag.Deserialize(json); + + // Assert + Assert.Equal("stringValue", restoredStateBag.GetValue("stringKey")); + } + + #endregion + + #region Thread Safety Tests + + [Fact] + public async System.Threading.Tasks.Task SetValue_MultipleConcurrentWrites_DoesNotThrowAsync() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + var tasks = new System.Threading.Tasks.Task[100]; + + // Act + for (int i = 0; i < 100; i++) + { + int index = i; + tasks[i] = System.Threading.Tasks.Task.Run(() => stateBag.SetValue($"key{index}", $"value{index}")); + } + + await System.Threading.Tasks.Task.WhenAll(tasks); + + // Assert + for (int i = 0; i < 100; i++) + { + Assert.True(stateBag.TryGetValue($"key{i}", out var value)); + Assert.Equal($"value{i}", value); + } + } + + [Fact] + public async System.Threading.Tasks.Task ConcurrentWritesAndSerialize_DoesNotThrowAsync() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("shared", "initial"); + var tasks = new System.Threading.Tasks.Task[100]; + + // Act - concurrently write and serialize the same key + for (int i = 0; i < 100; i++) + { + int index = i; + tasks[i] = System.Threading.Tasks.Task.Run(() => + { + stateBag.SetValue("shared", $"value{index}"); + _ = stateBag.Serialize(); + }); + } + + await System.Threading.Tasks.Task.WhenAll(tasks); + + // Assert - should have some value and serialize without error + Assert.True(stateBag.TryGetValue("shared", out var result)); + Assert.NotNull(result); + var json = stateBag.Serialize(); + Assert.Equal(JsonValueKind.Object, json.ValueKind); + } + + [Fact] + public async System.Threading.Tasks.Task ConcurrentReadsAndWrites_DoesNotThrowAsync() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key", "initial"); + var tasks = new System.Threading.Tasks.Task[200]; + + // Act - half readers, half writers on the same key + for (int i = 0; i < 200; i++) + { + int index = i; + if (index % 2 == 0) + { + tasks[i] = System.Threading.Tasks.Task.Run(() => stateBag.GetValue("key")); + } + else + { + tasks[i] = System.Threading.Tasks.Task.Run(() => stateBag.SetValue("key", $"value{index}")); + } + } + + await System.Threading.Tasks.Task.WhenAll(tasks); + + // Assert - should have a consistent value + Assert.True(stateBag.TryGetValue("key", out var result)); + Assert.NotNull(result); + } + + #endregion + + #region Complex Object Tests + + [Fact] + public void SetValue_WithComplexObject_StoresValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + var animal = new Animal { Id = 1, FullName = "Buddy", Species = Species.Bear }; + + // Act + stateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); + + // Assert + Animal? result = stateBag.GetValue("animal", TestJsonSerializerContext.Default.Options); + Assert.NotNull(result); + Assert.Equal(1, result.Id); + Assert.Equal("Buddy", result.FullName); + Assert.Equal(Species.Bear, result.Species); + } + + [Fact] + public void GetValue_WithComplexObject_CachesDeserializedValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + var animal = new Animal { Id = 2, FullName = "Whiskers", Species = Species.Tiger }; + stateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); + + // Act + Animal? result1 = stateBag.GetValue("animal", TestJsonSerializerContext.Default.Options); + Animal? result2 = stateBag.GetValue("animal", TestJsonSerializerContext.Default.Options); + + // Assert + Assert.Same(result1, result2); + } + + [Fact] + public void TryGetValue_WithComplexObject_ReturnsTrueAndValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + var animal = new Animal { Id = 3, FullName = "Goldie", Species = Species.Walrus }; + stateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); + + // Act + bool found = stateBag.TryGetValue("animal", out Animal? result, TestJsonSerializerContext.Default.Options); + + // Assert + Assert.True(found); + Assert.NotNull(result); + Assert.Equal(3, result.Id); + Assert.Equal("Goldie", result.FullName); + Assert.Equal(Species.Walrus, result.Species); + } + + [Fact] + public void SerializeDeserialize_WithComplexObject_Roundtrips() + { + // Arrange + var originalStateBag = new AgentSessionStateBag(); + var animal = new Animal { Id = 4, FullName = "Polly", Species = Species.Bear }; + originalStateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); + + // Act + JsonElement json = originalStateBag.Serialize(); + AgentSessionStateBag restoredStateBag = AgentSessionStateBag.Deserialize(json); + + // Assert + Animal? restoredAnimal = restoredStateBag.GetValue("animal", TestJsonSerializerContext.Default.Options); + Assert.NotNull(restoredAnimal); + Assert.Equal(4, restoredAnimal.Id); + Assert.Equal("Polly", restoredAnimal.FullName); + Assert.Equal(Species.Bear, restoredAnimal.Species); + } + + [Fact] + public void Serialize_WithComplexObject_ReturnsJsonWithProperties() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + var animal = new Animal { Id = 7, FullName = "Spot", Species = Species.Walrus }; + stateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); + + // Act + JsonElement json = stateBag.Serialize(); + + // Assert + Assert.Equal(JsonValueKind.Object, json.ValueKind); + Assert.True(json.TryGetProperty("animal", out JsonElement animalElement)); + Assert.Equal(JsonValueKind.Object, animalElement.ValueKind); + Assert.Equal(7, animalElement.GetProperty("id").GetInt32()); + Assert.Equal("Spot", animalElement.GetProperty("fullName").GetString()); + Assert.Equal("Walrus", animalElement.GetProperty("species").GetString()); + } + + #endregion + + #region JsonSerializer Integration Tests + + [Fact] + public void JsonSerializerSerialize_EmptyStateBag_ReturnsEmptyObject() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + var json = JsonSerializer.Serialize(stateBag, AgentAbstractionsJsonUtilities.DefaultOptions); + + // Assert + Assert.Equal("{}", json); + } + + [Fact] + public void JsonSerializerSerialize_WithStringValue_ProducesSameOutputAsSerializeMethod() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("stringKey", "stringValue"); + + // Act + var jsonFromSerializer = JsonSerializer.Serialize(stateBag, AgentAbstractionsJsonUtilities.DefaultOptions); + var jsonFromMethod = stateBag.Serialize().GetRawText(); + + // Assert + Assert.Equal(jsonFromMethod, jsonFromSerializer); + } + + [Fact] + public void JsonSerializerRoundtrip_WithStringValue_PreservesData() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("greeting", "hello world"); + + // Act + var json = JsonSerializer.Serialize(stateBag, AgentAbstractionsJsonUtilities.DefaultOptions); + var restored = JsonSerializer.Deserialize(json, AgentAbstractionsJsonUtilities.DefaultOptions); + + // Assert + Assert.NotNull(restored); + Assert.Equal("hello world", restored!.GetValue("greeting")); + } + + [Fact] + public void JsonSerializerRoundtrip_WithComplexObject_PreservesData() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + var animal = new Animal { Id = 10, FullName = "Rex", Species = Species.Tiger }; + stateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); + + // Act + var json = JsonSerializer.Serialize(stateBag, AgentAbstractionsJsonUtilities.DefaultOptions); + var restored = JsonSerializer.Deserialize(json, AgentAbstractionsJsonUtilities.DefaultOptions); + + // Assert + Assert.NotNull(restored); + var restoredAnimal = restored!.GetValue("animal", TestJsonSerializerContext.Default.Options); + Assert.NotNull(restoredAnimal); + Assert.Equal(10, restoredAnimal!.Id); + Assert.Equal("Rex", restoredAnimal.FullName); + Assert.Equal(Species.Tiger, restoredAnimal.Species); + } + + [Fact] + public void JsonSerializerDeserialize_NullJson_ReturnsNull() + { + // Arrange + const string Json = "null"; + + // Act + var stateBag = JsonSerializer.Deserialize(Json, AgentAbstractionsJsonUtilities.DefaultOptions); + + // Assert + Assert.Null(stateBag); + } + +#if NET10_0_OR_GREATER + [Fact] + public void JsonSerializerSerialize_WithUnknownType_Throws() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key", new { Name = "Test" }); // Anonymous type which cannot be deserialized + + // Act & Assert + Assert.Throws(() => JsonSerializer.Serialize(stateBag, AgentAbstractionsJsonUtilities.DefaultOptions)); + } +#endif + + #endregion +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionTests.cs index 5a776c9fb0..b80f0a4fd2 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionTests.cs @@ -11,6 +11,21 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; /// public class AgentSessionTests { + #region StateBag Tests + + [Fact] + public void StateBag_Values_Roundtrips() + { + // Arrange + var session = new TestAgentSession(); + + // Act & Assert + session.StateBag.SetValue("key1", "value1"); + Assert.Equal("value1", session.StateBag.GetValue("key1")); + } + + #endregion + #region GetService Method Tests /// diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs index 5b48d025be..52a66b7643 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -197,25 +196,4 @@ ChatHistoryProvider.InvokedContext InvokedFilter(ChatHistoryProvider.InvokedCont .Protected() .Verify("InvokedCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); } - - [Fact] - public void Serialize_DelegatesToInnerProvider() - { - // Arrange - var innerProviderMock = new Mock(); - var expectedJson = JsonSerializer.SerializeToElement("data", TestJsonSerializerContext.Default.String); - - innerProviderMock - .Setup(s => s.Serialize(It.IsAny())) - .Returns(expectedJson); - - var filter = new ChatHistoryProviderMessageFilter(innerProviderMock.Object, x => x, x => x); - - // Act - var result = filter.Serialize(); - - // Assert - Assert.Equal(expectedJson.GetRawText(), result.GetRawText()); - innerProviderMock.Verify(s => s.Serialize(null), Times.Once); - } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs index e158b159ca..1ede826e4e 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -376,9 +375,6 @@ protected override ValueTask> InvokingCoreAsync(Invokin protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) => default; - - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => default; } private sealed class TestChatHistoryProviderWithCustomSource : ChatHistoryProvider @@ -392,9 +388,6 @@ protected override ValueTask> InvokingCoreAsync(Invokin protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) => default; - - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => default; } private sealed class TestChatHistoryProviderWithPreStampedMessages : ChatHistoryProvider @@ -412,9 +405,6 @@ protected override ValueTask> InvokingCoreAsync(Invokin protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) => default; - - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => default; } private sealed class TestChatHistoryProviderWithMultipleMessages : ChatHistoryProvider @@ -428,8 +418,5 @@ protected override ValueTask> InvokingCoreAsync(Invokin protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) => default; - - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => default; } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryAgentSessionTests.cs deleted file mode 100644 index a3a4bf7e0e..0000000000 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryAgentSessionTests.cs +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text.Json; -using Microsoft.Extensions.AI; - -namespace Microsoft.Agents.AI.Abstractions.UnitTests; - -/// -/// Contains tests for . -/// -public class InMemoryAgentSessionTests -{ - #region Constructor and Property Tests - - [Fact] - public void Constructor_SetsDefaultChatHistoryProvider() - { - // Arrange & Act - var session = new TestInMemoryAgentSession(); - - // Assert - Assert.NotNull(session.GetChatHistoryProvider()); - Assert.Empty(session.GetChatHistoryProvider()); - } - - [Fact] - public void Constructor_WithChatHistoryProvider_SetsProperty() - { - // Arrange - InMemoryChatHistoryProvider provider = [new(ChatRole.User, "Hello")]; - - // Act - var session = new TestInMemoryAgentSession(provider); - - // Assert - Assert.Same(provider, session.GetChatHistoryProvider()); - Assert.Single(session.GetChatHistoryProvider()); - Assert.Equal("Hello", session.GetChatHistoryProvider()[0].Text); - } - - [Fact] - public void Constructor_WithMessages_SetsProperty() - { - // Arrange - var messages = new List { new(ChatRole.User, "Hi") }; - - // Act - var session = new TestInMemoryAgentSession(messages); - - // Assert - Assert.NotNull(session.GetChatHistoryProvider()); - Assert.Single(session.GetChatHistoryProvider()); - Assert.Equal("Hi", session.GetChatHistoryProvider()[0].Text); - } - - [Fact] - public void Constructor_WithSerializedState_SetsProperty() - { - // Arrange - InMemoryChatHistoryProvider provider = [new(ChatRole.User, "TestMsg")]; - var providerState = provider.Serialize(); - var sessionStateWrapper = new InMemoryAgentSession.InMemoryAgentSessionState { ChatHistoryProviderState = providerState }; - var json = JsonSerializer.SerializeToElement(sessionStateWrapper, TestJsonSerializerContext.Default.InMemoryAgentSessionState); - - // Act - var session = new TestInMemoryAgentSession(json); - - // Assert - Assert.NotNull(session.GetChatHistoryProvider()); - Assert.Single(session.GetChatHistoryProvider()); - Assert.Equal("TestMsg", session.GetChatHistoryProvider()[0].Text); - } - - [Fact] - public void Constructor_WithInvalidJson_ThrowsArgumentException() - { - // Arrange - var invalidJson = JsonSerializer.SerializeToElement(42, TestJsonSerializerContext.Default.Int32); - - // Act & Assert - Assert.Throws(() => new TestInMemoryAgentSession(invalidJson)); - } - - #endregion - - #region SerializeAsync Tests - - [Fact] - public void Serialize_ReturnsCorrectJson_WhenMessagesExist() - { - // Arrange - var session = new TestInMemoryAgentSession([new(ChatRole.User, "TestContent")]); - - // Act - var json = session.Serialize(); - - // Assert - Assert.Equal(JsonValueKind.Object, json.ValueKind); - Assert.True(json.TryGetProperty("chatHistoryProviderState", out var providerStateProperty)); - Assert.Equal(JsonValueKind.Object, providerStateProperty.ValueKind); - Assert.True(providerStateProperty.TryGetProperty("messages", out var messagesProperty)); - Assert.Equal(JsonValueKind.Array, messagesProperty.ValueKind); - var messagesList = messagesProperty.EnumerateArray().ToList(); - Assert.Single(messagesList); - } - - [Fact] - public void Serialize_ReturnsEmptyMessages_WhenNoMessages() - { - // Arrange - var session = new TestInMemoryAgentSession(); - - // Act - var json = session.Serialize(); - - // Assert - Assert.Equal(JsonValueKind.Object, json.ValueKind); - Assert.True(json.TryGetProperty("chatHistoryProviderState", out var providerStateProperty)); - Assert.Equal(JsonValueKind.Object, providerStateProperty.ValueKind); - Assert.True(providerStateProperty.TryGetProperty("messages", out var messagesProperty)); - Assert.Equal(JsonValueKind.Array, messagesProperty.ValueKind); - Assert.Empty(messagesProperty.EnumerateArray()); - } - - #endregion - - #region GetService Tests - - [Fact] - public void GetService_RequestingChatHistoryProvider_ReturnsChatHistoryProvider() - { - // Arrange - var session = new TestInMemoryAgentSession(); - - // Act & Assert - Assert.NotNull(session.GetService(typeof(ChatHistoryProvider))); - Assert.Same(session.GetChatHistoryProvider(), session.GetService(typeof(ChatHistoryProvider))); - Assert.Same(session.GetChatHistoryProvider(), session.GetService(typeof(InMemoryChatHistoryProvider))); - } - - #endregion - - // Sealed test subclass to expose protected members for testing - private sealed class TestInMemoryAgentSession : InMemoryAgentSession - { - public TestInMemoryAgentSession() { } - public TestInMemoryAgentSession(InMemoryChatHistoryProvider? provider) : base(provider) { } - public TestInMemoryAgentSession(IEnumerable messages) : base(messages) { } - public TestInMemoryAgentSession(JsonElement serializedSessionState) : base(serializedSessionState) { } - public InMemoryChatHistoryProvider GetChatHistoryProvider() => this.ChatHistoryProvider; - } -} diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs index 75232073a6..4ab8a02aaa 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs @@ -3,9 +3,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Text.Encodings.Web; -using System.Text.Json; -using System.Text.Json.Serialization.Metadata; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -19,22 +16,18 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; public class InMemoryChatHistoryProviderTests { private static readonly AIAgent s_mockAgent = new Mock().Object; - private static readonly AgentSession s_mockSession = new Mock().Object; - [Fact] - public void Constructor_Throws_ForNullReducer() => - // Arrange & Act & Assert - Assert.Throws(() => new InMemoryChatHistoryProvider(null!)); + private static AgentSession CreateMockSession() => new Mock().Object; [Fact] public void Constructor_DefaultsToBeforeMessageRetrieval_ForNotProvidedTriggerEvent() { // Arrange & Act var reducerMock = new Mock(); - var provider = new InMemoryChatHistoryProvider(reducerMock.Object); + var provider = new InMemoryChatHistoryProvider(new() { ChatReducer = reducerMock.Object }); // Assert - Assert.Equal(InMemoryChatHistoryProvider.ChatReducerTriggerEvent.BeforeMessagesRetrieval, provider.ReducerTriggerEvent); + Assert.Equal(InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.BeforeMessagesRetrieval, provider.ReducerTriggerEvent); } [Fact] @@ -42,16 +35,19 @@ public void Constructor_Arguments_SetOnPropertiesCorrectly() { // Arrange & Act var reducerMock = new Mock(); - var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded); + var provider = new InMemoryChatHistoryProvider(new() { ChatReducer = reducerMock.Object, ReducerTriggerEvent = InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.AfterMessageAdded }); // Assert Assert.Same(reducerMock.Object, provider.ChatReducer); - Assert.Equal(InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded, provider.ReducerTriggerEvent); + Assert.Equal(InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.AfterMessageAdded, provider.ReducerTriggerEvent); } [Fact] public async Task InvokedAsyncAddsMessagesAsync() { + var session = CreateMockSession(); + + // Arrange var requestMessages = new List { new(ChatRole.User, "Hello"), @@ -67,439 +63,157 @@ public async Task InvokedAsyncAddsMessagesAsync() }; var provider = new InMemoryChatHistoryProvider(); - provider.Add(providerMessages[0]); - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) + provider.SetMessages(session, [providerMessages[0]]); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages) { ResponseMessages = responseMessages }; await provider.InvokedAsync(context, CancellationToken.None); - Assert.Equal(4, provider.Count); - Assert.Equal("original instructions", provider[0].Text); - Assert.Equal("Hello", provider[1].Text); - Assert.Equal("additional context", provider[2].Text); - Assert.Equal("Hi there!", provider[3].Text); + // Assert + var messages = provider.GetMessages(session); + Assert.Equal(4, messages.Count); + Assert.Equal("original instructions", messages[0].Text); + Assert.Equal("Hello", messages[1].Text); + Assert.Equal("additional context", messages[2].Text); + Assert.Equal("Hi there!", messages[3].Text); } [Fact] public async Task InvokedAsyncWithEmptyDoesNotFailAsync() { + var session = CreateMockSession(); + + // Arrange var provider = new InMemoryChatHistoryProvider(); - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, []); await provider.InvokedAsync(context, CancellationToken.None); - Assert.Empty(provider); + // Assert + Assert.Empty(provider.GetMessages(session)); } [Fact] public async Task InvokingAsyncReturnsAllMessagesAsync() { - var provider = new InMemoryChatHistoryProvider - { + var session = CreateMockSession(); + + // Arrange + var provider = new InMemoryChatHistoryProvider(); + provider.SetMessages(session, + [ new ChatMessage(ChatRole.User, "Test1"), new ChatMessage(ChatRole.Assistant, "Test2") - }; + ]); - var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var result = (await provider.InvokingAsync(context, CancellationToken.None)).ToList(); + // Assert Assert.Equal(2, result.Count); Assert.Contains(result, m => m.Text == "Test1"); Assert.Contains(result, m => m.Text == "Test2"); } [Fact] - public async Task DeserializeConstructorWithEmptyElementAsync() - { - var emptyObject = JsonSerializer.Deserialize("{}", TestJsonSerializerContext.Default.JsonElement); - - var newProvider = new InMemoryChatHistoryProvider(emptyObject); - - Assert.Empty(newProvider); - } - - [Fact] - public async Task SerializeAndDeserializeConstructorRoundtripsAsync() - { - var provider = new InMemoryChatHistoryProvider - { - new ChatMessage(ChatRole.User, "A"), - new ChatMessage(ChatRole.Assistant, "B") - }; - - var jsonElement = provider.Serialize(); - var newProvider = new InMemoryChatHistoryProvider(jsonElement); - - Assert.Equal(2, newProvider.Count); - Assert.Equal("A", newProvider[0].Text); - Assert.Equal("B", newProvider[1].Text); - } - - [Fact] - public async Task SerializeAndDeserializeConstructorRoundtripsWithCustomAIContentAsync() - { - JsonSerializerOptions options = new(TestJsonSerializerContext.Default.Options) - { - TypeInfoResolver = JsonTypeInfoResolver.Combine(AgentAbstractionsJsonUtilities.DefaultOptions.TypeInfoResolver, TestJsonSerializerContext.Default), - Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping, - }; - options.AddAIContentType(typeDiscriminatorId: "testContent"); - - var provider = new InMemoryChatHistoryProvider - { - new ChatMessage(ChatRole.User, [new TestAIContent("foo data")]), - }; - - var jsonElement = provider.Serialize(options); - var newProvider = new InMemoryChatHistoryProvider(jsonElement, options); - - Assert.Single(newProvider); - var actualTestAIContent = Assert.IsType(newProvider[0].Contents[0]); - Assert.Equal("foo data", actualTestAIContent.TestData); - } - - [Fact] - public async Task SerializeAndDeserializeWorksWithExperimentalContentTypesAsync() - { - var provider = new InMemoryChatHistoryProvider - { - new ChatMessage(ChatRole.User, [new FunctionApprovalRequestContent("call123", new FunctionCallContent("call123", "some_func"))]), - new ChatMessage(ChatRole.Assistant, [new FunctionApprovalResponseContent("call123", true, new FunctionCallContent("call123", "some_func"))]) - }; - - var jsonElement = provider.Serialize(); - var newProvider = new InMemoryChatHistoryProvider(jsonElement); - - Assert.Equal(2, newProvider.Count); - Assert.IsType(newProvider[0].Contents[0]); - Assert.IsType(newProvider[1].Contents[0]); - } - - [Fact] - public async Task InvokedAsyncWithEmptyMessagesDoesNotChangeProviderAsync() - { - var provider = new InMemoryChatHistoryProvider(); - var messages = new List(); - - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages); - await provider.InvokedAsync(context, CancellationToken.None); - - Assert.Empty(provider); - } - - [Fact] - public async Task InvokedAsync_WithNullContext_ThrowsArgumentNullExceptionAsync() + public void StateInitializer_IsInvoked_WhenSessionHasNoState() { // Arrange - var provider = new InMemoryChatHistoryProvider(); - - // Act & Assert - await Assert.ThrowsAsync(() => provider.InvokedAsync(null!, CancellationToken.None).AsTask()); - } - - [Fact] - public void DeserializeContructor_WithNullSerializedState_CreatesEmptyProvider() - { - // Act - var provider = new InMemoryChatHistoryProvider(new JsonElement()); - - // Assert - Assert.Empty(provider); - } - - [Fact] - public async Task DeserializeContructor_WithEmptyMessages_DoesNotAddMessagesAsync() - { - // Arrange - var stateWithEmptyMessages = JsonSerializer.SerializeToElement( - new Dictionary { ["messages"] = new List() }, - TestJsonSerializerContext.Default.IDictionaryStringObject); - - // Act - var provider = new InMemoryChatHistoryProvider(stateWithEmptyMessages); - - // Assert - Assert.Empty(provider); - } - - [Fact] - public async Task DeserializeConstructor_WithNullMessages_DoesNotAddMessagesAsync() - { - // Arrange - var stateWithNullMessages = JsonSerializer.SerializeToElement( - new Dictionary { ["messages"] = null! }, - TestJsonSerializerContext.Default.DictionaryStringObject); - - // Act - var provider = new InMemoryChatHistoryProvider(stateWithNullMessages); - - // Assert - Assert.Empty(provider); - } - - [Fact] - public async Task DeserializeConstructor_WithValidMessages_AddsMessagesAsync() - { - // Arrange - var messages = new List + var initialMessages = new List { - new(ChatRole.User, "User message"), - new(ChatRole.Assistant, "Assistant message") + new(ChatRole.User, "Initial message") }; - var state = new Dictionary { ["messages"] = messages }; - var serializedState = JsonSerializer.SerializeToElement( - state, - TestJsonSerializerContext.Default.DictionaryStringObject); + var provider = new InMemoryChatHistoryProvider(new() + { + StateInitializer = _ => new InMemoryChatHistoryProvider.State { Messages = initialMessages } + }); // Act - var provider = new InMemoryChatHistoryProvider(serializedState); + var messages = provider.GetMessages(CreateMockSession()); // Assert - Assert.Equal(2, provider.Count); - Assert.Equal("User message", provider[0].Text); - Assert.Equal("Assistant message", provider[1].Text); + Assert.Single(messages); + Assert.Equal("Initial message", messages[0].Text); } [Fact] - public void IndexerGet_ReturnsCorrectMessage() + public void GetMessages_ReturnsEmptyList_WhenNullSession() { // Arrange var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - provider.Add(message1); - provider.Add(message2); - - // Act & Assert - Assert.Same(message1, provider[0]); - Assert.Same(message2, provider[1]); - } - - [Fact] - public void IndexerSet_UpdatesMessage() - { - // Arrange - var provider = new InMemoryChatHistoryProvider(); - var originalMessage = new ChatMessage(ChatRole.User, "Original"); - var newMessage = new ChatMessage(ChatRole.User, "Updated"); - provider.Add(originalMessage); // Act - provider[0] = newMessage; + var messages = provider.GetMessages(null); // Assert - Assert.Same(newMessage, provider[0]); - Assert.Equal("Updated", provider[0].Text); - } - - [Fact] - public void IsReadOnly_ReturnsFalse() - { - // Arrange - var provider = new InMemoryChatHistoryProvider(); - - // Act & Assert - Assert.False(provider.IsReadOnly); + Assert.Empty(messages); } [Fact] - public void IndexOf_ReturnsCorrectIndex() + public void SetMessages_ThrowsForNullMessages() { // Arrange var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - var message3 = new ChatMessage(ChatRole.User, "Third"); - provider.Add(message1); - provider.Add(message2); // Act & Assert - Assert.Equal(0, provider.IndexOf(message1)); - Assert.Equal(1, provider.IndexOf(message2)); - Assert.Equal(-1, provider.IndexOf(message3)); // Not in provider + Assert.Throws(() => provider.SetMessages(CreateMockSession(), null!)); } [Fact] - public void Insert_InsertsMessageAtCorrectIndex() + public void SetMessages_UpdatesState() { - // Arrange - var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - var insertMessage = new ChatMessage(ChatRole.User, "Inserted"); - provider.Add(message1); - provider.Add(message2); - - // Act - provider.Insert(1, insertMessage); - - // Assert - Assert.Equal(3, provider.Count); - Assert.Same(message1, provider[0]); - Assert.Same(insertMessage, provider[1]); - Assert.Same(message2, provider[2]); - } + var session = CreateMockSession(); - [Fact] - public void RemoveAt_RemovesMessageAtIndex() - { // Arrange var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - var message3 = new ChatMessage(ChatRole.User, "Third"); - provider.Add(message1); - provider.Add(message2); - provider.Add(message3); - - // Act - provider.RemoveAt(1); - - // Assert - Assert.Equal(2, provider.Count); - Assert.Same(message1, provider[0]); - Assert.Same(message3, provider[1]); - } - - [Fact] - public void Clear_RemovesAllMessages() - { - // Arrange - var provider = new InMemoryChatHistoryProvider + var messages = new List { - new ChatMessage(ChatRole.User, "First"), - new ChatMessage(ChatRole.Assistant, "Second") + new(ChatRole.User, "Hello"), + new(ChatRole.Assistant, "World") }; // Act - provider.Clear(); + provider.SetMessages(session, messages); + var retrieved = provider.GetMessages(session); // Assert - Assert.Empty(provider); + Assert.Equal(2, retrieved.Count); + Assert.Equal("Hello", retrieved[0].Text); + Assert.Equal("World", retrieved[1].Text); } [Fact] - public void Contains_ReturnsTrueForExistingMessage() - { - // Arrange - var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - provider.Add(message1); - - // Act & Assert - Assert.Contains(message1, provider); - Assert.DoesNotContain(message2, provider); - } - - [Fact] - public void CopyTo_CopiesMessagesToArray() - { - // Arrange - var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - provider.Add(message1); - provider.Add(message2); - var array = new ChatMessage[4]; - - // Act - provider.CopyTo(array, 1); - - // Assert - Assert.Null(array[0]); - Assert.Same(message1, array[1]); - Assert.Same(message2, array[2]); - Assert.Null(array[3]); - } - - [Fact] - public void Remove_RemovesSpecificMessage() - { - // Arrange - var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - var message3 = new ChatMessage(ChatRole.User, "Third"); - provider.Add(message1); - provider.Add(message2); - provider.Add(message3); - - // Act - var removed = provider.Remove(message2); - - // Assert - Assert.True(removed); - Assert.Equal(2, provider.Count); - Assert.Same(message1, provider[0]); - Assert.Same(message3, provider[1]); - } - - [Fact] - public void Remove_ReturnsFalseForNonExistentMessage() + public async Task InvokedAsyncWithEmptyMessagesDoesNotChangeProviderAsync() { - // Arrange - var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - provider.Add(message1); - - // Act - var removed = provider.Remove(message2); - - // Assert - Assert.False(removed); - Assert.Single(provider); - } + var session = CreateMockSession(); - [Fact] - public void GetEnumerator_Generic_ReturnsAllMessages() - { // Arrange var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - provider.Add(message1); - provider.Add(message2); - - // Act var messages = new List(); - messages.AddRange(provider); + + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, messages); + await provider.InvokedAsync(context, CancellationToken.None); // Assert - Assert.Equal(2, messages.Count); - Assert.Same(message1, messages[0]); - Assert.Same(message2, messages[1]); + Assert.Empty(provider.GetMessages(session)); } [Fact] - public void GetEnumerator_NonGeneric_ReturnsAllMessages() + public async Task InvokedAsync_WithNullContext_ThrowsArgumentNullExceptionAsync() { // Arrange var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - provider.Add(message1); - provider.Add(message2); - - // Act - var messages = new List(); - var enumerator = ((System.Collections.IEnumerable)provider).GetEnumerator(); - while (enumerator.MoveNext()) - { - messages.Add((ChatMessage)enumerator.Current); - } - // Assert - Assert.Equal(2, messages.Count); - Assert.Same(message1, messages[0]); - Assert.Same(message2, messages[1]); + // Act & Assert + await Assert.ThrowsAsync(() => provider.InvokedAsync(null!, CancellationToken.None).AsTask()); } [Fact] public async Task AddMessagesAsync_WithReducer_AfterMessageAdded_InvokesReducerAsync() { + var session = CreateMockSession(); + // Arrange var originalMessages = new List { @@ -516,21 +230,24 @@ public async Task AddMessagesAsync_WithReducer_AfterMessageAdded_InvokesReducerA .Setup(r => r.ReduceAsync(It.Is>(x => x.SequenceEqual(originalMessages)), It.IsAny())) .ReturnsAsync(reducedMessages); - var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded); + var provider = new InMemoryChatHistoryProvider(new() { ChatReducer = reducerMock.Object, ReducerTriggerEvent = InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.AfterMessageAdded }); // Act - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, originalMessages); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, originalMessages); await provider.InvokedAsync(context, CancellationToken.None); // Assert - Assert.Single(provider); - Assert.Equal("Reduced", provider[0].Text); + var messages = provider.GetMessages(session); + Assert.Single(messages); + Assert.Equal("Reduced", messages[0].Text); reducerMock.Verify(r => r.ReduceAsync(It.Is>(x => x.SequenceEqual(originalMessages)), It.IsAny()), Times.Once); } [Fact] public async Task GetMessagesAsync_WithReducer_BeforeMessagesRetrieval_InvokesReducerAsync() { + var session = CreateMockSession(); + // Arrange var originalMessages = new List { @@ -547,15 +264,11 @@ public async Task GetMessagesAsync_WithReducer_BeforeMessagesRetrieval_InvokesRe .Setup(r => r.ReduceAsync(It.Is>(x => x.SequenceEqual(originalMessages)), It.IsAny())) .ReturnsAsync(reducedMessages); - var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.BeforeMessagesRetrieval); - // Add messages directly to the provider for this test - foreach (var msg in originalMessages) - { - provider.Add(msg); - } + var provider = new InMemoryChatHistoryProvider(new() { ChatReducer = reducerMock.Object, ReducerTriggerEvent = InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.BeforeMessagesRetrieval }); + provider.SetMessages(session, new List(originalMessages)); // Act - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, Array.Empty()); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, Array.Empty()); var result = (await provider.InvokingAsync(invokingContext, CancellationToken.None)).ToList(); // Assert @@ -567,6 +280,8 @@ public async Task GetMessagesAsync_WithReducer_BeforeMessagesRetrieval_InvokesRe [Fact] public async Task AddMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeReducerAsync() { + var session = CreateMockSession(); + // Arrange var originalMessages = new List { @@ -575,21 +290,24 @@ public async Task AddMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeRedu var reducerMock = new Mock(); - var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.BeforeMessagesRetrieval); + var provider = new InMemoryChatHistoryProvider(new() { ChatReducer = reducerMock.Object, ReducerTriggerEvent = InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.BeforeMessagesRetrieval }); // Act - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, originalMessages); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, originalMessages); await provider.InvokedAsync(context, CancellationToken.None); // Assert - Assert.Single(provider); - Assert.Equal("Hello", provider[0].Text); + var messages = provider.GetMessages(session); + Assert.Single(messages); + Assert.Equal("Hello", messages[0].Text); reducerMock.Verify(r => r.ReduceAsync(It.IsAny>(), It.IsAny()), Times.Never); } [Fact] public async Task GetMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeReducerAsync() { + var session = CreateMockSession(); + // Arrange var originalMessages = new List { @@ -598,13 +316,11 @@ public async Task GetMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeRedu var reducerMock = new Mock(); - var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded) - { - originalMessages[0] - }; + var provider = new InMemoryChatHistoryProvider(new() { ChatReducer = reducerMock.Object, ReducerTriggerEvent = InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.AfterMessageAdded }); + provider.SetMessages(session, new List(originalMessages)); // Act - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, Array.Empty()); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, Array.Empty()); var result = (await provider.InvokingAsync(invokingContext, CancellationToken.None)).ToList(); // Assert @@ -616,6 +332,8 @@ public async Task GetMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeRedu [Fact] public async Task InvokedAsync_WithException_DoesNotAddMessagesAsync() { + var session = CreateMockSession(); + // Arrange var provider = new InMemoryChatHistoryProvider(); var requestMessages = new List @@ -626,7 +344,7 @@ public async Task InvokedAsync_WithException_DoesNotAddMessagesAsync() { new(ChatRole.Assistant, "Hi there!") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages) { ResponseMessages = responseMessages, InvokeException = new InvalidOperationException("Test exception") @@ -636,7 +354,7 @@ public async Task InvokedAsync_WithException_DoesNotAddMessagesAsync() await provider.InvokedAsync(context, CancellationToken.None); // Assert - Assert.Empty(provider); + Assert.Empty(provider.GetMessages(session)); } [Fact] diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ServiceIdAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ServiceIdAgentSessionTests.cs deleted file mode 100644 index e4a6626f72..0000000000 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ServiceIdAgentSessionTests.cs +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Text.Json; - -namespace Microsoft.Agents.AI.Abstractions.UnitTests; - -/// -/// Tests for . -/// -public class ServiceIdAgentSessionTests -{ - #region Constructor and Property Tests - - [Fact] - public void Constructor_SetsDefaults() - { - // Arrange & Act - var session = new TestServiceIdAgentSession(); - - // Assert - Assert.Null(session.GetServiceSessionId()); - } - - [Fact] - public void Constructor_WithServiceSessionId_SetsProperty() - { - // Arrange & Act - var session = new TestServiceIdAgentSession("service-id-123"); - - // Assert - Assert.Equal("service-id-123", session.GetServiceSessionId()); - } - - [Fact] - public void Constructor_WithSerializedId_SetsProperty() - { - // Arrange - var serviceSessionWrapper = new ServiceIdAgentSession.ServiceIdAgentSessionState { ServiceSessionId = "service-id-456" }; - var json = JsonSerializer.SerializeToElement(serviceSessionWrapper, TestJsonSerializerContext.Default.ServiceIdAgentSessionState); - - // Act - var session = new TestServiceIdAgentSession(json); - - // Assert - Assert.Equal("service-id-456", session.GetServiceSessionId()); - } - - [Fact] - public void Constructor_WithSerializedUndefinedId_SetsProperty() - { - // Arrange - var emptyObject = new EmptyObject(); - var json = JsonSerializer.SerializeToElement(emptyObject, TestJsonSerializerContext.Default.EmptyObject); - - // Act - var session = new TestServiceIdAgentSession(json); - - // Assert - Assert.Null(session.GetServiceSessionId()); - } - - [Fact] - public void Constructor_WithInvalidJson_ThrowsArgumentException() - { - // Arrange - var invalidJson = JsonSerializer.SerializeToElement(42, TestJsonSerializerContext.Default.Int32); - - // Act & Assert - Assert.Throws(() => new TestServiceIdAgentSession(invalidJson)); - } - - #endregion - - #region SerializeAsync Tests - - [Fact] - public void Serialize_ReturnsCorrectJson_WhenServiceSessionIdIsSet() - { - // Arrange - var session = new TestServiceIdAgentSession("service-id-789"); - - // Act - var json = session.Serialize(); - - // Assert - Assert.Equal(JsonValueKind.Object, json.ValueKind); - Assert.True(json.TryGetProperty("serviceSessionId", out var idProperty)); - Assert.Equal("service-id-789", idProperty.GetString()); - } - - [Fact] - public void Serialize_ReturnsUndefinedServiceSessionId_WhenNotSet() - { - // Arrange - var session = new TestServiceIdAgentSession(); - - // Act - var json = session.Serialize(); - - // Assert - Assert.Equal(JsonValueKind.Object, json.ValueKind); - Assert.False(json.TryGetProperty("serviceSessionId", out _)); - } - - #endregion - - // Sealed test subclass to expose protected members for testing - private sealed class TestServiceIdAgentSession : ServiceIdAgentSession - { - public TestServiceIdAgentSession() { } - public TestServiceIdAgentSession(string serviceSessionId) : base(serviceSessionId) { } - public TestServiceIdAgentSession(JsonElement serializedSessionState) : base(serializedSessionState) { } - public string? GetServiceSessionId() => this.ServiceSessionId; - } - - // Helper class to represent empty objects - internal sealed class EmptyObject; -} diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/TestJsonSerializerContext.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/TestJsonSerializerContext.cs index 05d13c2e95..c4f3b7511a 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/TestJsonSerializerContext.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/TestJsonSerializerContext.cs @@ -19,8 +19,5 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; [JsonSerializable(typeof(Dictionary))] [JsonSerializable(typeof(string[]))] [JsonSerializable(typeof(int))] -[JsonSerializable(typeof(InMemoryAgentSession.InMemoryAgentSessionState))] -[JsonSerializable(typeof(ServiceIdAgentSession.ServiceIdAgentSessionState))] -[JsonSerializable(typeof(ServiceIdAgentSessionTests.EmptyObject))] [JsonSerializable(typeof(InMemoryChatHistoryProviderTests.TestAIContent))] internal sealed partial class TestJsonSerializerContext : JsonSerializerContext; diff --git a/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs index e2cbb16b1b..6fdbc7fefb 100644 --- a/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs @@ -2310,23 +2310,18 @@ public async Task GetAIAgentAsync_WithMatchingToolsProvided_CreatesAgentSuccessf #region CreateChatClientAgentOptions - Options Preservation Tests /// - /// Verify that CreateChatClientAgentOptions preserves AIContextProviderFactory. + /// Verify that CreateChatClientAgentOptions preserves AIContextProvider. /// [Fact] - public async Task GetAIAgentAsync_WithAIContextProviderFactory_PreservesFactoryAsync() + public async Task GetAIAgentAsync_WithAIContextProvider_PreservesProviderAsync() { // Arrange AIProjectClient client = this.CreateTestAgentClient(); - bool factoryInvoked = false; var options = new ChatClientAgentOptions { Name = "test-agent", ChatOptions = new ChatOptions { Instructions = "Test" }, - AIContextProviderFactory = (_, _) => - { - factoryInvoked = true; - return new ValueTask(new TestAIContextProvider()); - } + AIContextProvider = new TestAIContextProvider() }; // Act @@ -2334,15 +2329,13 @@ public async Task GetAIAgentAsync_WithAIContextProviderFactory_PreservesFactoryA // Assert Assert.NotNull(agent); - // Verify the factory was captured (though not necessarily invoked yet) - Assert.False(factoryInvoked); // Factory is not invoked during creation } /// - /// Verify that CreateChatClientAgentOptions preserves ChatHistoryProviderFactory. + /// Verify that CreateChatClientAgentOptions preserves ChatHistoryProvider. /// [Fact] - public async Task GetAIAgentAsync_WithChatHistoryProviderFactory_PreservesFactoryAsync() + public async Task GetAIAgentAsync_WithChatHistoryProvider_PreservesProviderAsync() { // Arrange AIProjectClient client = this.CreateTestAgentClient(); @@ -2350,7 +2343,7 @@ public async Task GetAIAgentAsync_WithChatHistoryProviderFactory_PreservesFactor { Name = "test-agent", ChatOptions = new ChatOptions { Instructions = "Test" }, - ChatHistoryProviderFactory = (_, _) => new ValueTask(new TestChatHistoryProvider()) + ChatHistoryProvider = new TestChatHistoryProvider() }; // Act @@ -3160,11 +3153,6 @@ protected override ValueTask InvokedCoreAsync(InvokedContext context, Cancellati { return default; } - - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - return default; - } } } diff --git a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs index e1d3c612c8..cbdbe99a21 100644 --- a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs @@ -3,8 +3,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Text.Json; -using System.Text.Json.Serialization.Metadata; using System.Threading.Tasks; using Azure.Core; using Azure.Identity; @@ -42,7 +40,8 @@ namespace Microsoft.Agents.AI.CosmosNoSql.UnitTests; public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable { private static readonly AIAgent s_mockAgent = new Moq.Mock().Object; - private static readonly AgentSession s_mockSession = new Moq.Mock().Object; + + private static AgentSession CreateMockSession() => new Moq.Mock().Object; // Cosmos DB Emulator connection settings private const string EmulatorEndpoint = "https://localhost:8081"; @@ -157,28 +156,11 @@ public void Constructor_WithConnectionString_ShouldCreateInstance() this.SkipIfEmulatorNotAvailable(); // Act - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, "test-conversation"); - - // Assert - Assert.NotNull(provider); - Assert.Equal("test-conversation", provider.ConversationId); - Assert.Equal(s_testDatabaseId, provider.DatabaseId); - Assert.Equal(TestContainerId, provider.ContainerId); - } - - [SkippableFact] - [Trait("Category", "CosmosDB")] - public void Constructor_WithConnectionStringNoConversationId_ShouldCreateInstance() - { - // Arrange - this.SkipIfEmulatorNotAvailable(); - - // Act - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State("test-conversation")); // Assert Assert.NotNull(provider); - Assert.NotNull(provider.ConversationId); Assert.Equal(s_testDatabaseId, provider.DatabaseId); Assert.Equal(TestContainerId, provider.ContainerId); } @@ -189,18 +171,19 @@ public void Constructor_WithNullConnectionString_ShouldThrowArgumentException() { // Arrange & Act & Assert Assert.Throws(() => - new CosmosChatHistoryProvider((string)null!, s_testDatabaseId, TestContainerId, "test-conversation")); + new CosmosChatHistoryProvider((string)null!, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State("test-conversation"))); } [SkippableFact] [Trait("Category", "CosmosDB")] - public void Constructor_WithEmptyConversationId_ShouldThrowArgumentException() + public void Constructor_WithNullStateInitializer_ShouldThrowArgumentNullException() { // Arrange & Act & Assert this.SkipIfEmulatorNotAvailable(); - Assert.Throws(() => - new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, "")); + Assert.Throws(() => + new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, null!)); } #endregion @@ -213,11 +196,13 @@ public async Task InvokedAsync_WithSingleMessage_ShouldAddMessageAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); var conversationId = Guid.NewGuid().ToString(); - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversationId); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(conversationId)); var message = new ChatMessage(ChatRole.User, "Hello, world!"); - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [message]) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [message]) { ResponseMessages = [] }; @@ -229,7 +214,7 @@ public async Task InvokedAsync_WithSingleMessage_ShouldAddMessageAsync() await Task.Delay(100); // Assert - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var messages = await provider.InvokingAsync(invokingContext); var messageList = messages.ToList(); @@ -279,8 +264,10 @@ public async Task InvokedAsync_WithMultipleMessages_ShouldAddAllMessagesAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); var conversationId = Guid.NewGuid().ToString(); - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversationId); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(conversationId)); var requestMessages = new[] { new ChatMessage(ChatRole.User, "First message"), @@ -293,7 +280,7 @@ public async Task InvokedAsync_WithMultipleMessages_ShouldAddAllMessagesAsync() new ChatMessage(ChatRole.Assistant, "Response message") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages) { ResponseMessages = responseMessages }; @@ -302,7 +289,7 @@ public async Task InvokedAsync_WithMultipleMessages_ShouldAddAllMessagesAsync() await provider.InvokedAsync(context); // Assert - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var retrievedMessages = await provider.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); Assert.Equal(5, messageList.Count); @@ -323,10 +310,12 @@ public async Task InvokingAsync_WithNoMessages_ShouldReturnEmptyAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, Guid.NewGuid().ToString()); + var session = CreateMockSession(); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(Guid.NewGuid().ToString())); // Act - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var messages = await provider.InvokingAsync(invokingContext); // Assert @@ -339,21 +328,25 @@ public async Task InvokingAsync_WithConversationIsolation_ShouldOnlyReturnMessag { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); var conversation1 = Guid.NewGuid().ToString(); var conversation2 = Guid.NewGuid().ToString(); - using var store1 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversation1); - using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversation2); + // Use different stateKey values so the providers don't overwrite each other's state in the shared session + using var store1 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(conversation1), stateKey: "conv1"); + using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(conversation2), stateKey: "conv2"); - var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message for conversation 1")]); - var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message for conversation 2")]); + var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Message for conversation 1")]); + var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Message for conversation 2")]); await store1.InvokedAsync(context1); await store2.InvokedAsync(context2); // Act - var invokingContext1 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); - var invokingContext2 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext1 = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); + var invokingContext2 = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var messages1 = await store1.InvokingAsync(invokingContext1); var messages2 = await store2.InvokingAsync(invokingContext2); @@ -377,8 +370,10 @@ public async Task FullWorkflow_AddAndGet_ShouldWorkCorrectlyAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); var conversationId = $"test-conversation-{Guid.NewGuid():N}"; // Use unique conversation ID - using var originalStore = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversationId); + using var originalStore = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(conversationId)); var messages = new[] { @@ -390,18 +385,21 @@ public async Task FullWorkflow_AddAndGet_ShouldWorkCorrectlyAsync() }; // Act 1: Add messages - var invokedContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages); + var invokedContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, messages); await originalStore.InvokedAsync(invokedContext); // Act 2: Verify messages were added - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var retrievedMessages = await originalStore.InvokingAsync(invokingContext); var retrievedList = retrievedMessages.ToList(); Assert.Equal(5, retrievedList.Count); // Act 3: Create new provider instance for same conversation (test persistence) - using var newProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversationId); - var persistedMessages = await newProvider.InvokingAsync(invokingContext); + using var newProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(conversationId)); + var newSession = CreateMockSession(); + var newInvokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, newSession, []); + var persistedMessages = await newProvider.InvokingAsync(newInvokingContext); var persistedList = persistedMessages.ToList(); // Assert final state @@ -423,7 +421,8 @@ public void Dispose_AfterUse_ShouldNotThrow() { // Arrange this.SkipIfEmulatorNotAvailable(); - var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, Guid.NewGuid().ToString()); + var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(Guid.NewGuid().ToString())); // Act & Assert provider.Dispose(); // Should not throw @@ -435,7 +434,8 @@ public void Dispose_MultipleCalls_ShouldNotThrow() { // Arrange this.SkipIfEmulatorNotAvailable(); - var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, Guid.NewGuid().ToString()); + var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(Guid.NewGuid().ToString())); // Act & Assert provider.Dispose(); // First call @@ -454,11 +454,11 @@ public void Constructor_WithHierarchicalConnectionString_ShouldCreateInstance() this.SkipIfEmulatorNotAvailable(); // Act - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, "tenant-123", "user-456", "session-789"); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State("session-789", "tenant-123", "user-456")); // Assert Assert.NotNull(provider); - Assert.Equal("session-789", provider.ConversationId); Assert.Equal(s_testDatabaseId, provider.DatabaseId); Assert.Equal(HierarchicalTestContainerId, provider.ContainerId); } @@ -472,11 +472,11 @@ public void Constructor_WithHierarchicalEndpoint_ShouldCreateInstance() // Act TokenCredential credential = new DefaultAzureCredential(); - using var provider = new CosmosChatHistoryProvider(EmulatorEndpoint, credential, s_testDatabaseId, HierarchicalTestContainerId, "tenant-123", "user-456", "session-789"); + using var provider = new CosmosChatHistoryProvider(EmulatorEndpoint, credential, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State("session-789", "tenant-123", "user-456")); // Assert Assert.NotNull(provider); - Assert.Equal("session-789", provider.ConversationId); Assert.Equal(s_testDatabaseId, provider.DatabaseId); Assert.Equal(HierarchicalTestContainerId, provider.ContainerId); } @@ -489,46 +489,31 @@ public void Constructor_WithHierarchicalCosmosClient_ShouldCreateInstance() this.SkipIfEmulatorNotAvailable(); using var cosmosClient = new CosmosClient(EmulatorEndpoint, EmulatorKey); - using var provider = new CosmosChatHistoryProvider(cosmosClient, s_testDatabaseId, HierarchicalTestContainerId, "tenant-123", "user-456", "session-789"); + using var provider = new CosmosChatHistoryProvider(cosmosClient, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State("session-789", "tenant-123", "user-456")); // Assert Assert.NotNull(provider); - Assert.Equal("session-789", provider.ConversationId); Assert.Equal(s_testDatabaseId, provider.DatabaseId); Assert.Equal(HierarchicalTestContainerId, provider.ContainerId); } [SkippableFact] [Trait("Category", "CosmosDB")] - public void Constructor_WithHierarchicalNullTenantId_ShouldThrowArgumentException() - { - // Arrange & Act & Assert - this.SkipIfEmulatorNotAvailable(); - - Assert.Throws(() => - new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, null!, "user-456", "session-789")); - } - - [SkippableFact] - [Trait("Category", "CosmosDB")] - public void Constructor_WithHierarchicalEmptyUserId_ShouldThrowArgumentException() + public void State_WithEmptyConversationId_ShouldThrowArgumentException() { // Arrange & Act & Assert - this.SkipIfEmulatorNotAvailable(); - Assert.Throws(() => - new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, "tenant-123", "", "session-789")); + new CosmosChatHistoryProvider.State("")); } [SkippableFact] [Trait("Category", "CosmosDB")] - public void Constructor_WithHierarchicalWhitespaceSessionId_ShouldThrowArgumentException() + public void State_WithWhitespaceConversationId_ShouldThrowArgumentException() { // Arrange & Act & Assert - this.SkipIfEmulatorNotAvailable(); - Assert.Throws(() => - new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, "tenant-123", "user-456", " ")); + new CosmosChatHistoryProvider.State(" ")); } [SkippableFact] @@ -537,14 +522,16 @@ public async Task InvokedAsync_WithHierarchicalPartitioning_ShouldAddMessageWith { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); const string TenantId = "tenant-123"; const string UserId = "user-456"; const string SessionId = "session-789"; // Test hierarchical partitioning constructor with connection string - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State(SessionId, TenantId, UserId)); var message = new ChatMessage(ChatRole.User, "Hello from hierarchical partitioning!"); - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [message]); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [message]); // Act await provider.InvokedAsync(context); @@ -553,7 +540,7 @@ public async Task InvokedAsync_WithHierarchicalPartitioning_ShouldAddMessageWith await Task.Delay(100); // Assert - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var messages = await provider.InvokingAsync(invokingContext); var messageList = messages.ToList(); @@ -589,11 +576,13 @@ public async Task InvokedAsync_WithHierarchicalMultipleMessages_ShouldAddAllMess { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); const string TenantId = "tenant-batch"; const string UserId = "user-batch"; const string SessionId = "session-batch"; // Test hierarchical partitioning constructor with connection string - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State(SessionId, TenantId, UserId)); var messages = new[] { new ChatMessage(ChatRole.User, "First hierarchical message"), @@ -601,7 +590,7 @@ public async Task InvokedAsync_WithHierarchicalMultipleMessages_ShouldAddAllMess new ChatMessage(ChatRole.User, "Third hierarchical message") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, messages); // Act await provider.InvokedAsync(context); @@ -610,7 +599,7 @@ public async Task InvokedAsync_WithHierarchicalMultipleMessages_ShouldAddAllMess await Task.Delay(100); // Assert - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var retrievedMessages = await provider.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); @@ -626,18 +615,22 @@ public async Task InvokingAsync_WithHierarchicalPartitionIsolation_ShouldIsolate { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); const string TenantId = "tenant-isolation"; const string UserId1 = "user-1"; const string UserId2 = "user-2"; const string SessionId = "session-isolation"; // Different userIds create different hierarchical partitions, providing proper isolation - using var store1 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId1, SessionId); - using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId2, SessionId); + // Use different stateKey values so the providers don't overwrite each other's state in the shared session + using var store1 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State(SessionId, TenantId, UserId1), stateKey: "user1"); + using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State(SessionId, TenantId, UserId2), stateKey: "user2"); // Add messages to both stores - var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message from user 1")]); - var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message from user 2")]); + var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Message from user 1")]); + var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Message from user 2")]); await store1.InvokedAsync(context1); await store2.InvokedAsync(context2); @@ -646,8 +639,8 @@ public async Task InvokingAsync_WithHierarchicalPartitionIsolation_ShouldIsolate await Task.Delay(100); // Act & Assert - var invokingContext1 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); - var invokingContext2 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext1 = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); + var invokingContext2 = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var messages1 = await store1.InvokingAsync(invokingContext1); var messageList1 = messages1.ToList(); @@ -664,43 +657,37 @@ public async Task InvokingAsync_WithHierarchicalPartitionIsolation_ShouldIsolate [SkippableFact] [Trait("Category", "CosmosDB")] - public async Task SerializeDeserialize_WithHierarchicalPartitioning_ShouldPreserveStateAsync() + public async Task StateBag_WithHierarchicalPartitioning_ShouldPreserveStateAcrossProviderInstancesAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); const string TenantId = "tenant-serialize"; const string UserId = "user-serialize"; const string SessionId = "session-serialize"; - using var originalStore = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId); + using var originalStore = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State(SessionId, TenantId, UserId)); - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test serialization message")]); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Test serialization message")]); await originalStore.InvokedAsync(context); - // Act - Serialize the provider state - var serializedState = originalStore.Serialize(); - - // Create a new provider from the serialized state - using var cosmosClient = new CosmosClient(EmulatorEndpoint, EmulatorKey); - var serializerOptions = new JsonSerializerOptions - { - TypeInfoResolver = new DefaultJsonTypeInfoResolver() - }; - using var deserializedStore = CosmosChatHistoryProvider.CreateFromSerializedState(cosmosClient, serializedState, s_testDatabaseId, HierarchicalTestContainerId, serializerOptions); - // Wait a moment for eventual consistency await Task.Delay(100); - // Assert - The deserialized provider should have the same functionality - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); - var messages = await deserializedStore.InvokingAsync(invokingContext); + // Act - Create a new provider that uses a different intializer, but we will use the same session. + using var newStore = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State(Guid.NewGuid().ToString())); + + // Assert - The new provider should read the same messages from Cosmos DB + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); + var messages = await newStore.InvokingAsync(invokingContext); var messageList = messages.ToList(); Assert.Single(messageList); Assert.Equal("Test serialization message", messageList[0].Text); - Assert.Equal(SessionId, deserializedStore.ConversationId); - Assert.Equal(s_testDatabaseId, deserializedStore.DatabaseId); - Assert.Equal(HierarchicalTestContainerId, deserializedStore.ContainerId); + Assert.Equal(s_testDatabaseId, newStore.DatabaseId); + Assert.Equal(HierarchicalTestContainerId, newStore.ContainerId); } [SkippableFact] @@ -711,13 +698,17 @@ public async Task HierarchicalAndSimplePartitioning_ShouldCoexistAsync() this.SkipIfEmulatorNotAvailable(); const string SessionId = "coexist-session"; + var session = CreateMockSession(); // Create simple provider using simple partitioning container and hierarchical provider using hierarchical container - using var simpleProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, SessionId); - using var hierarchicalProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, "tenant-coexist", "user-coexist", SessionId); + // Use different stateKey values so the providers don't overwrite each other's state in the shared session + using var simpleProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(SessionId), stateKey: "simple"); + using var hierarchicalProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State(SessionId, "tenant-coexist", "user-coexist"), stateKey: "hierarchical"); // Add messages to both - var simpleContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Simple partitioning message")]); - var hierarchicalContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Hierarchical partitioning message")]); + var simpleContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Simple partitioning message")]); + var hierarchicalContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Hierarchical partitioning message")]); await simpleProvider.InvokedAsync(simpleContext); await hierarchicalProvider.InvokedAsync(hierarchicalContext); @@ -726,7 +717,7 @@ public async Task HierarchicalAndSimplePartitioning_ShouldCoexistAsync() await Task.Delay(100); // Act & Assert - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var simpleMessages = await simpleProvider.InvokingAsync(invokingContext); var simpleMessageList = simpleMessages.ToList(); @@ -747,9 +738,11 @@ public async Task MaxMessagesToRetrieve_ShouldLimitAndReturnMostRecentAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); const string ConversationId = "max-messages-test"; - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, ConversationId); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(ConversationId)); // Add 10 messages var messages = new List(); @@ -759,7 +752,7 @@ public async Task MaxMessagesToRetrieve_ShouldLimitAndReturnMostRecentAsync() await Task.Delay(10); // Small delay to ensure different timestamps } - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, messages); await provider.InvokedAsync(context); // Wait for eventual consistency @@ -767,7 +760,7 @@ public async Task MaxMessagesToRetrieve_ShouldLimitAndReturnMostRecentAsync() // Act - Set max to 5 and retrieve provider.MaxMessagesToRetrieve = 5; - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var retrievedMessages = await provider.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); @@ -786,9 +779,11 @@ public async Task MaxMessagesToRetrieve_Null_ShouldReturnAllMessagesAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); const string ConversationId = "max-messages-null-test"; - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, ConversationId); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(ConversationId)); // Add 10 messages var messages = new List(); @@ -797,14 +792,14 @@ public async Task MaxMessagesToRetrieve_Null_ShouldReturnAllMessagesAsync() messages.Add(new ChatMessage(ChatRole.User, $"Message {i}")); } - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, messages); await provider.InvokedAsync(context); // Wait for eventual consistency await Task.Delay(100); // Act - No limit set (default null) - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var retrievedMessages = await provider.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); diff --git a/dotnet/tests/Microsoft.Agents.AI.DurableTask.UnitTests/DurableAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.DurableTask.UnitTests/DurableAgentSessionTests.cs index 4bf8ebc718..bc06c35ab8 100644 --- a/dotnet/tests/Microsoft.Agents.AI.DurableTask.UnitTests/DurableAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.DurableTask.UnitTests/DurableAgentSessionTests.cs @@ -15,7 +15,7 @@ public void BuiltInSerialization() JsonElement serializedSession = session.Serialize(); // Expected format: "{\"sessionId\":\"@dafx-test-agent@\"}" - string expectedSerializedSession = $"{{\"sessionId\":\"@dafx-{sessionId.Name}@{sessionId.Key}\"}}"; + string expectedSerializedSession = $"{{\"sessionId\":\"@dafx-{sessionId.Name}@{sessionId.Key}\",\"stateBag\":{{}}}}"; Assert.Equal(expectedSerializedSession, serializedSession.ToString()); DurableAgentSession deserializedSession = DurableAgentSession.Deserialize(serializedSession); @@ -33,11 +33,47 @@ public void STJSerialization() string serializedSession = JsonSerializer.Serialize(session, typeof(DurableAgentSession)); // Expected format: "{\"sessionId\":\"@dafx-test-agent@\"}" - string expectedSerializedSession = $"{{\"sessionId\":\"@dafx-{sessionId.Name}@{sessionId.Key}\"}}"; + string expectedSerializedSession = $"{{\"sessionId\":\"@dafx-{sessionId.Name}@{sessionId.Key}\",\"stateBag\":{{}}}}"; Assert.Equal(expectedSerializedSession, serializedSession); DurableAgentSession? deserializedSession = JsonSerializer.Deserialize(serializedSession); Assert.NotNull(deserializedSession); Assert.Equal(sessionId, deserializedSession.SessionId); } + + [Fact] + public void BuiltInSerialization_RoundTrip_PreservesStateBag() + { + // Arrange + AgentSessionId sessionId = AgentSessionId.WithRandomKey("test-agent"); + DurableAgentSession session = new(sessionId); + session.StateBag.SetValue("durableKey", "durableValue"); + + // Act + JsonElement serializedSession = session.Serialize(); + DurableAgentSession deserializedSession = DurableAgentSession.Deserialize(serializedSession); + + // Assert + Assert.Equal(sessionId, deserializedSession.SessionId); + Assert.True(deserializedSession.StateBag.TryGetValue("durableKey", out var value)); + Assert.Equal("durableValue", value); + } + + [Fact] + public void STJSerialization_RoundTrip_PreservesStateBag() + { + // Arrange + AgentSessionId sessionId = AgentSessionId.WithRandomKey("test-agent"); + DurableAgentSession session = new(sessionId); + session.StateBag.SetValue("stjKey", "stjValue"); + + // Act + string serializedSession = JsonSerializer.Serialize(session, typeof(DurableAgentSession)); + DurableAgentSession? deserializedSession = JsonSerializer.Deserialize(serializedSession); + + // Assert + Assert.NotNull(deserializedSession); + Assert.True(deserializedSession.StateBag.TryGetValue("stjKey", out var value)); + Assert.Equal("stjValue", value); + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs index a6e7aab212..f990650e7c 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs @@ -7,6 +7,7 @@ using System.Net.Http; using System.Runtime.CompilerServices; using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using FluentAssertions; @@ -281,10 +282,10 @@ internal sealed class FakeChatClientAgent : AIAgent public override string? Description => "A fake agent for testing"; protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => - new(new FakeInMemoryAgentSession()); + new(new FakeAgentSession()); protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => - new(new FakeInMemoryAgentSession(serializedState, jsonSerializerOptions)); + new(serializedState.Deserialize(jsonSerializerOptions)!); protected override JsonElement SerializeSessionCore(AgentSession session, JsonSerializerOptions? jsonSerializerOptions = null) => throw new NotImplementedException(); @@ -326,15 +327,14 @@ protected override async IAsyncEnumerable RunCoreStreamingA } } - private sealed class FakeInMemoryAgentSession : InMemoryAgentSession + private sealed class FakeAgentSession : AgentSession { - public FakeInMemoryAgentSession() - : base() + public FakeAgentSession() { } - public FakeInMemoryAgentSession(JsonElement serializedSession, JsonSerializerOptions? jsonSerializerOptions = null) - : base(serializedSession, jsonSerializerOptions) + [JsonConstructor] + public FakeAgentSession(AgentSessionStateBag stateBag) : base(stateBag) { } } @@ -348,19 +348,19 @@ internal sealed class FakeMultiMessageAgent : AIAgent public override string? Description => "A fake agent that sends multiple messages for testing"; protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => - new(new FakeInMemoryAgentSession()); + new(new FakeAgentSession()); protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => - new(new FakeInMemoryAgentSession(serializedState, jsonSerializerOptions)); + new(serializedState.Deserialize(jsonSerializerOptions)!); protected override JsonElement SerializeSessionCore(AgentSession session, JsonSerializerOptions? jsonSerializerOptions = null) { - if (session is not FakeInMemoryAgentSession fakeSession) + if (session is not FakeAgentSession fakeSession) { throw new InvalidOperationException("The provided session is not compatible with the agent. Only sessions created by the agent can be serialized."); } - return fakeSession.Serialize(jsonSerializerOptions); + return JsonSerializer.SerializeToElement(fakeSession, jsonSerializerOptions); } protected override async Task RunCoreAsync( @@ -427,19 +427,16 @@ protected override async IAsyncEnumerable RunCoreStreamingA } } - private sealed class FakeInMemoryAgentSession : InMemoryAgentSession + private sealed class FakeAgentSession : AgentSession { - public FakeInMemoryAgentSession() - : base() + public FakeAgentSession() { } - public FakeInMemoryAgentSession(JsonElement serializedSession, JsonSerializerOptions? jsonSerializerOptions = null) - : base(serializedSession, jsonSerializerOptions) + [JsonConstructor] + public FakeAgentSession(AgentSessionStateBag stateBag) : base(stateBag) { } - internal new JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => base.Serialize(jsonSerializerOptions); } public override object? GetService(Type serviceType, object? serviceKey = null) => null; diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/ForwardedPropertiesTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/ForwardedPropertiesTests.cs index afd3db44b3..353641502c 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/ForwardedPropertiesTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/ForwardedPropertiesTests.cs @@ -9,6 +9,7 @@ using System.Runtime.CompilerServices; using System.Text; using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using FluentAssertions; @@ -335,35 +336,31 @@ protected override async IAsyncEnumerable RunCoreStreamingA } protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => - new(new FakeInMemoryAgentSession()); + new(new FakeAgentSession()); protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => - new(new FakeInMemoryAgentSession(serializedState, jsonSerializerOptions)); + new(serializedState.Deserialize(jsonSerializerOptions)!); protected override JsonElement SerializeSessionCore(AgentSession session, JsonSerializerOptions? jsonSerializerOptions = null) { - if (session is not FakeInMemoryAgentSession fakeSession) + if (session is not FakeAgentSession fakeSession) { throw new InvalidOperationException("The provided session is not compatible with the agent. Only sessions created by the agent can be serialized."); } - return fakeSession.Serialize(jsonSerializerOptions); + return JsonSerializer.SerializeToElement(fakeSession, jsonSerializerOptions); } - private sealed class FakeInMemoryAgentSession : InMemoryAgentSession + private sealed class FakeAgentSession : AgentSession { - public FakeInMemoryAgentSession() - : base() + public FakeAgentSession() { } - public FakeInMemoryAgentSession(JsonElement serializedSession, JsonSerializerOptions? jsonSerializerOptions = null) - : base(serializedSession, jsonSerializerOptions) + [JsonConstructor] + public FakeAgentSession(AgentSessionStateBag stateBag) : base(stateBag) { } - - internal new JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => base.Serialize(jsonSerializerOptions); } public override object? GetService(Type serviceType, object? serviceKey = null) => null; diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/SharedStateTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/SharedStateTests.cs index b78ddd0e11..67805edd86 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/SharedStateTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/SharedStateTests.cs @@ -7,6 +7,7 @@ using System.Net.Http; using System.Runtime.CompilerServices; using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using FluentAssertions; @@ -418,35 +419,31 @@ stateObj is JsonElement state && } protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => - new(new FakeInMemoryAgentSession()); + new(new FakeAgentSession()); protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => - new(new FakeInMemoryAgentSession(serializedState, jsonSerializerOptions)); + new(serializedState.Deserialize(jsonSerializerOptions)!); protected override JsonElement SerializeSessionCore(AgentSession session, JsonSerializerOptions? jsonSerializerOptions = null) { - if (session is not FakeInMemoryAgentSession fakeSession) + if (session is not FakeAgentSession fakeSession) { throw new InvalidOperationException("The provided session is not compatible with the agent. Only sessions created by the agent can be serialized."); } - return fakeSession.Serialize(jsonSerializerOptions); + return JsonSerializer.SerializeToElement(fakeSession, jsonSerializerOptions); } - private sealed class FakeInMemoryAgentSession : InMemoryAgentSession + private sealed class FakeAgentSession : AgentSession { - public FakeInMemoryAgentSession() - : base() + public FakeAgentSession() { } - public FakeInMemoryAgentSession(JsonElement serializedSession, JsonSerializerOptions? jsonSerializerOptions = null) - : base(serializedSession, jsonSerializerOptions) + [JsonConstructor] + public FakeAgentSession(AgentSessionStateBag stateBag) : base(stateBag) { } - - internal new JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => base.Serialize(jsonSerializerOptions); } public override object? GetService(Type serviceType, object? serviceKey = null) => null; diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs index fecb9d421b..70070eb28e 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Text; using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.Shared; @@ -426,19 +427,19 @@ private sealed class MultiResponseAgent : AIAgent public override string? Description => "Agent that produces multiple text chunks"; protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => - new(new TestInMemoryAgentSession()); + new(new TestAgentSession()); protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => - new(new TestInMemoryAgentSession(serializedState, jsonSerializerOptions)); + new(serializedState.Deserialize(jsonSerializerOptions)!); protected override JsonElement SerializeSessionCore(AgentSession session, JsonSerializerOptions? jsonSerializerOptions = null) { - if (session is not TestInMemoryAgentSession testSession) + if (session is not TestAgentSession testSession) { throw new InvalidOperationException("The provided session is not compatible with the agent. Only sessions created by the agent can be serialized."); } - return testSession.Serialize(jsonSerializerOptions); + return JsonSerializer.SerializeToElement(testSession, jsonSerializerOptions); } protected override Task RunCoreAsync(IEnumerable messages, AgentSession? session = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) @@ -506,20 +507,16 @@ private RequestDelegate CreateRequestDelegate( }; } - private sealed class TestInMemoryAgentSession : InMemoryAgentSession + private sealed class TestAgentSession : AgentSession { - public TestInMemoryAgentSession() - : base() + public TestAgentSession() { } - public TestInMemoryAgentSession(JsonElement serializedSessionState, JsonSerializerOptions? jsonSerializerOptions = null) - : base(serializedSessionState, jsonSerializerOptions, null) + [JsonConstructor] + public TestAgentSession(AgentSessionStateBag stateBag) : base(stateBag) { } - - internal new JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => base.Serialize(jsonSerializerOptions); } private sealed class TestAgent : AIAgent @@ -529,19 +526,19 @@ private sealed class TestAgent : AIAgent public override string? Description => "Test agent"; protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => - new(new TestInMemoryAgentSession()); + new(new TestAgentSession()); protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => - new(new TestInMemoryAgentSession(serializedState, jsonSerializerOptions)); + new(serializedState.Deserialize(jsonSerializerOptions)!); protected override JsonElement SerializeSessionCore(AgentSession session, JsonSerializerOptions? jsonSerializerOptions = null) { - if (session is not TestInMemoryAgentSession testSession) + if (session is not TestAgentSession testSession) { throw new InvalidOperationException("The provided session is not compatible with the agent. Only sessions created by the agent can be serialized."); } - return testSession.Serialize(jsonSerializerOptions); + return JsonSerializer.SerializeToElement(testSession, jsonSerializerOptions); } protected override Task RunCoreAsync(IEnumerable messages, AgentSession? session = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) diff --git a/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs index a10f1246aa..df353d6cee 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs @@ -19,7 +19,6 @@ public sealed class Mem0ProviderTests : IDisposable private const string SkipReason = "Requires a Mem0 service configured"; // Set to null to enable. private static readonly AIAgent s_mockAgent = new Moq.Mock().Object; - private static readonly AgentSession s_mockSession = new Moq.Mock().Object; private readonly HttpClient _httpClient; @@ -49,17 +48,18 @@ public async Task CanAddAndRetrieveUserMemoriesAsync() var question = new ChatMessage(ChatRole.User, "What is my name?"); var input = new ChatMessage(ChatRole.User, "Hello, my name is Caoimhe."); var storageScope = new Mem0ProviderScope { ThreadId = "it-thread-1", UserId = "it-user-1" }; - var sut = new Mem0Provider(this._httpClient, storageScope); + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope)); - await sut.ClearStoredMemoriesAsync(); - var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); + await sut.ClearStoredMemoriesAsync(mockSession); + var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, mockSession, [question])); Assert.DoesNotContain("Caoimhe", ctxBefore.Messages?[0].Text ?? string.Empty); // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [input])); - var ctxAfterAdding = await GetContextWithRetryAsync(sut, question); - await sut.ClearStoredMemoriesAsync(); - var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession, [input])); + var ctxAfterAdding = await GetContextWithRetryAsync(sut, mockSession, question); + await sut.ClearStoredMemoriesAsync(mockSession); + var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, mockSession, [question])); // Assert Assert.Contains("Caoimhe", ctxAfterAdding.Messages?[0].Text ?? string.Empty); @@ -73,17 +73,18 @@ public async Task CanAddAndRetrieveAgentMemoriesAsync() var question = new ChatMessage(ChatRole.User, "What is your name?"); var assistantIntro = new ChatMessage(ChatRole.Assistant, "Hello, I'm a friendly assistant and my name is Caoimhe."); var storageScope = new Mem0ProviderScope { AgentId = "it-agent-1" }; - var sut = new Mem0Provider(this._httpClient, storageScope); + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope)); - await sut.ClearStoredMemoriesAsync(); - var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); + await sut.ClearStoredMemoriesAsync(mockSession); + var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, mockSession, [question])); Assert.DoesNotContain("Caoimhe", ctxBefore.Messages?[0].Text ?? string.Empty); // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [assistantIntro])); - var ctxAfterAdding = await GetContextWithRetryAsync(sut, question); - await sut.ClearStoredMemoriesAsync(); - var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession, [assistantIntro])); + var ctxAfterAdding = await GetContextWithRetryAsync(sut, mockSession, question); + await sut.ClearStoredMemoriesAsync(mockSession); + var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, mockSession, [question])); // Assert Assert.Contains("Caoimhe", ctxAfterAdding.Messages?[0].Text ?? string.Empty); @@ -96,37 +97,41 @@ public async Task DoesNotLeakMemoriesAcrossAgentScopesAsync() // Arrange var question = new ChatMessage(ChatRole.User, "What is your name?"); var assistantIntro = new ChatMessage(ChatRole.Assistant, "I'm an AI tutor and my name is Caoimhe."); - var sut1 = new Mem0Provider(this._httpClient, new Mem0ProviderScope { AgentId = "it-agent-a" }); - var sut2 = new Mem0Provider(this._httpClient, new Mem0ProviderScope { AgentId = "it-agent-b" }); - - await sut1.ClearStoredMemoriesAsync(); - await sut2.ClearStoredMemoriesAsync(); - - var ctxBefore1 = await sut1.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); - var ctxBefore2 = await sut2.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); + var storageScope1 = new Mem0ProviderScope { AgentId = "it-agent-a" }; + var storageScope2 = new Mem0ProviderScope { AgentId = "it-agent-b" }; + var mockSession1 = new TestAgentSession(); + var mockSession2 = new TestAgentSession(); + var sut1 = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope1)); + var sut2 = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope2)); + + await sut1.ClearStoredMemoriesAsync(mockSession1); + await sut2.ClearStoredMemoriesAsync(mockSession2); + + var ctxBefore1 = await sut1.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, mockSession1, [question])); + var ctxBefore2 = await sut2.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, mockSession2, [question])); Assert.DoesNotContain("Caoimhe", ctxBefore1.Messages?[0].Text ?? string.Empty); Assert.DoesNotContain("Caoimhe", ctxBefore2.Messages?[0].Text ?? string.Empty); // Act - await sut1.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [assistantIntro])); - var ctxAfterAdding1 = await GetContextWithRetryAsync(sut1, question); - var ctxAfterAdding2 = await GetContextWithRetryAsync(sut2, question); + await sut1.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession1, [assistantIntro])); + var ctxAfterAdding1 = await GetContextWithRetryAsync(sut1, mockSession1, question); + var ctxAfterAdding2 = await GetContextWithRetryAsync(sut2, mockSession2, question); // Assert Assert.Contains("Caoimhe", ctxAfterAdding1.Messages?[0].Text ?? string.Empty); Assert.DoesNotContain("Caoimhe", ctxAfterAdding2.Messages?[0].Text ?? string.Empty); // Cleanup - await sut1.ClearStoredMemoriesAsync(); - await sut2.ClearStoredMemoriesAsync(); + await sut1.ClearStoredMemoriesAsync(mockSession1); + await sut2.ClearStoredMemoriesAsync(mockSession2); } - private static async Task GetContextWithRetryAsync(Mem0Provider provider, ChatMessage question, int attempts = 5, int delayMs = 1000) + private static async Task GetContextWithRetryAsync(Mem0Provider provider, AgentSession session, ChatMessage question, int attempts = 5, int delayMs = 1000) { AIContext? ctx = null; for (int i = 0; i < attempts; i++) { - ctx = await provider.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question]), CancellationToken.None); + ctx = await provider.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, session, [question]), CancellationToken.None); var text = ctx.Messages?[0].Text; if (!string.IsNullOrEmpty(text) && text.IndexOf("Caoimhe", StringComparison.OrdinalIgnoreCase) >= 0) { @@ -141,4 +146,12 @@ public void Dispose() { this._httpClient.Dispose(); } + + private sealed class TestAgentSession : AgentSession + { + public TestAgentSession() + { + this.StateBag = new AgentSessionStateBag(); + } + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs index 53c87b09ba..12a94078ee 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs @@ -19,7 +19,6 @@ namespace Microsoft.Agents.AI.Mem0.UnitTests; public sealed class Mem0ProviderTests : IDisposable { private static readonly AIAgent s_mockAgent = new Mock().Object; - private static readonly AgentSession s_mockSession = new Mock().Object; private readonly Mock> _loggerMock; private readonly Mock _loggerFactoryMock; @@ -55,35 +54,16 @@ public void Constructor_Throws_WhenBaseAddressMissing() using HttpClient client = new(); // Act & Assert - var ex = Assert.Throws(() => new Mem0Provider(client, new Mem0ProviderScope() { ThreadId = "tid" })); + var ex = Assert.Throws(() => new Mem0Provider(client, _ => new Mem0Provider.State(new Mem0ProviderScope { ThreadId = "tid" }))); Assert.StartsWith("The HttpClient BaseAddress must be set for Mem0 operations.", ex.Message); } [Fact] - public void Constructor_Throws_WhenNoStorageScopeValueIsSet() + public void Constructor_Throws_WhenStateInitializerIsNull() { // Act & Assert - var ex = Assert.Throws(() => new Mem0Provider(this._httpClient, new Mem0ProviderScope())); - Assert.StartsWith("At least one of ApplicationId, AgentId, ThreadId, or UserId must be provided for the storage scope.", ex.Message); - } - - [Fact] - public void Constructor_Throws_WhenNoSearchScopeValueIsSet() - { - // Act & Assert - var ex = Assert.Throws(() => new Mem0Provider(this._httpClient, new Mem0ProviderScope() { ThreadId = "tid" }, new Mem0ProviderScope())); - Assert.StartsWith("At least one of ApplicationId, AgentId, ThreadId, or UserId must be provided for the search scope.", ex.Message); - } - - [Fact] - public void DeserializingConstructor_Throws_WithEmptyJsonElement() - { - // Arrange - var jsonElement = JsonSerializer.SerializeToElement(new object(), Mem0JsonUtilities.DefaultOptions); - - // Act & Assert - var ex = Assert.Throws(() => new Mem0Provider(this._httpClient, jsonElement)); - Assert.StartsWith("The Mem0Provider state did not contain the required scope properties.", ex.Message); + var ex = Assert.Throws(() => new Mem0Provider(this._httpClient, null!)); + Assert.Contains("stateInitializer", ex.Message); } [Fact] @@ -98,8 +78,9 @@ public async Task InvokingAsync_PerformsSearch_AndReturnsContextMessageAsync() ThreadId = "session", UserId = "user" }; - var sut = new Mem0Provider(this._httpClient, storageScope, options: new() { EnableSensitiveTelemetryData = true }, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "What is my name?")]); + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), options: new() { EnableSensitiveTelemetryData = true }, loggerFactory: this._loggerFactoryMock.Object); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, mockSession, [new ChatMessage(ChatRole.User, "What is my name?")]); // Act var aiContext = await sut.InvokingAsync(invokingContext); @@ -162,9 +143,10 @@ public async Task InvokingAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsy UserId = "user" }; var options = new Mem0ProviderOptions { EnableSensitiveTelemetryData = enableSensitiveTelemetryData }; + var mockSession = new TestAgentSession(); - var sut = new Mem0Provider(this._httpClient, storageScope, options: options, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Who am I?")]); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), options: options, loggerFactory: this._loggerFactoryMock.Object); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, mockSession, [new ChatMessage(ChatRole.User, "Who am I?")]); // Act await sut.InvokingAsync(invokingContext, CancellationToken.None); @@ -204,7 +186,8 @@ public async Task InvokedAsync_PersistsAllowedMessagesAsync() this._handler.EnqueueEmptyOk(); // For second CreateMemory this._handler.EnqueueEmptyOk(); // For third CreateMemory var storageScope = new Mem0ProviderScope { ApplicationId = "a", AgentId = "b", ThreadId = "c", UserId = "d" }; - var sut = new Mem0Provider(this._httpClient, storageScope); + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope)); var requestMessages = new List { @@ -218,7 +201,7 @@ public async Task InvokedAsync_PersistsAllowedMessagesAsync() }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) { ResponseMessages = responseMessages }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession, requestMessages) { ResponseMessages = responseMessages }); // Assert var memoryPosts = this._handler.Requests.Where(r => r.RequestMessage.RequestUri!.AbsolutePath == "/v1/memories/" && r.RequestMessage.Method == HttpMethod.Post).ToList(); @@ -235,7 +218,8 @@ public async Task InvokedAsync_PersistsNothingForFailedRequestAsync() { // Arrange var storageScope = new Mem0ProviderScope { ApplicationId = "a", AgentId = "b", ThreadId = "c", UserId = "d" }; - var sut = new Mem0Provider(this._httpClient, storageScope); + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope)); var requestMessages = new List { @@ -245,7 +229,7 @@ public async Task InvokedAsync_PersistsNothingForFailedRequestAsync() }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) { ResponseMessages = null, InvokeException = new InvalidOperationException("Request Failed") }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession, requestMessages) { ResponseMessages = null, InvokeException = new InvalidOperationException("Request Failed") }); // Assert Assert.Empty(this._handler.Requests); @@ -256,7 +240,8 @@ public async Task InvokedAsync_ShouldNotThrow_WhenStorageFailsAsync() { // Arrange var storageScope = new Mem0ProviderScope { ApplicationId = "a", AgentId = "b", ThreadId = "c", UserId = "d" }; - var sut = new Mem0Provider(this._httpClient, storageScope, loggerFactory: this._loggerFactoryMock.Object); + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), loggerFactory: this._loggerFactoryMock.Object); this._handler.EnqueueEmptyInternalServerError(); var requestMessages = new List @@ -271,7 +256,7 @@ public async Task InvokedAsync_ShouldNotThrow_WhenStorageFailsAsync() }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) { ResponseMessages = responseMessages }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession, requestMessages) { ResponseMessages = responseMessages }); // Assert this._loggerMock.Verify( @@ -310,7 +295,8 @@ public async Task InvokedAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsyn }; var options = new Mem0ProviderOptions { EnableSensitiveTelemetryData = enableSensitiveTelemetryData }; - var sut = new Mem0Provider(this._httpClient, storageScope, options: options, loggerFactory: this._loggerFactoryMock.Object); + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), options: options, loggerFactory: this._loggerFactoryMock.Object); var requestMessages = new List { new(ChatRole.User, "User text") @@ -321,7 +307,7 @@ public async Task InvokedAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsyn }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) { ResponseMessages = responseMessages }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession, requestMessages) { ResponseMessages = responseMessages }); // Assert Assert.Equal(expectedLogCount, this._loggerMock.Invocations.Count); @@ -343,11 +329,12 @@ public async Task ClearStoredMemoriesAsync_SendsDeleteWithQueryAsync() { // Arrange var storageScope = new Mem0ProviderScope { ApplicationId = "app", AgentId = "agent", ThreadId = "session", UserId = "user" }; - var sut = new Mem0Provider(this._httpClient, storageScope); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope)); this._handler.EnqueueEmptyOk(); // for DELETE + var mockSession = new TestAgentSession(); // Act - await sut.ClearStoredMemoriesAsync(); + await sut.ClearStoredMemoriesAsync(mockSession); // Assert var delete = Assert.Single(this._handler.Requests, r => r.RequestMessage.Method == HttpMethod.Delete); @@ -355,70 +342,71 @@ public async Task ClearStoredMemoriesAsync_SendsDeleteWithQueryAsync() } [Fact] - public void Serialize_RoundTripsScopes() + public async Task InvokingAsync_ShouldNotThrow_WhenSearchFailsAsync() { // Arrange - var storageScope = new Mem0ProviderScope { ApplicationId = "app", AgentId = "agent", ThreadId = "session", UserId = "user" }; - var sut = new Mem0Provider(this._httpClient, storageScope, options: new() { ContextPrompt = "Custom:" }, loggerFactory: this._loggerFactoryMock.Object); + var storageScope = new Mem0ProviderScope { ApplicationId = "app" }; + var mockSession = new TestAgentSession(); + var provider = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), loggerFactory: this._loggerFactoryMock.Object); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, mockSession, [new ChatMessage(ChatRole.User, "Q?")]); // Act - var stateElement = sut.Serialize(); - using JsonDocument doc = JsonDocument.Parse(stateElement.GetRawText()); - var storageScopeElement = doc.RootElement.GetProperty("storageScope"); - Assert.Equal("app", storageScopeElement.GetProperty("applicationId").GetString()); - Assert.Equal("agent", storageScopeElement.GetProperty("agentId").GetString()); - Assert.Equal("session", storageScopeElement.GetProperty("threadId").GetString()); - Assert.Equal("user", storageScopeElement.GetProperty("userId").GetString()); - - var sut2 = new Mem0Provider(this._httpClient, stateElement); - var stateElement2 = sut2.Serialize(); + var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); // Assert - using JsonDocument doc2 = JsonDocument.Parse(stateElement2.GetRawText()); - var storageScopeElement2 = doc2.RootElement.GetProperty("storageScope"); - Assert.Equal("app", storageScopeElement2.GetProperty("applicationId").GetString()); - Assert.Equal("agent", storageScopeElement2.GetProperty("agentId").GetString()); - Assert.Equal("session", storageScopeElement2.GetProperty("threadId").GetString()); - Assert.Equal("user", storageScopeElement2.GetProperty("userId").GetString()); + Assert.Null(aiContext.Messages); + Assert.Null(aiContext.Tools); + this._loggerMock.Verify( + l => l.Log( + LogLevel.Error, + It.IsAny(), + It.Is((v, t) => v.ToString()!.Contains("Mem0AIContextProvider: Failed to search Mem0 for memories due to error")), + It.IsAny(), + It.IsAny>()), + Times.Once); } [Fact] - public void Serialize_DoesNotIncludeDefaultContextPrompt() + public async Task StateInitializer_IsCalledOnceAndStoredInStateBagAsync() { // Arrange + this._handler.EnqueueJsonResponse("[]"); + this._handler.EnqueueJsonResponse("[]"); var storageScope = new Mem0ProviderScope { ApplicationId = "app" }; - var sut = new Mem0Provider(this._httpClient, storageScope); + var mockSession = new TestAgentSession(); + int initializerCallCount = 0; + var sut = new Mem0Provider(this._httpClient, _ => + { + initializerCallCount++; + return new Mem0Provider.State(storageScope); + }); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, mockSession, [new ChatMessage(ChatRole.User, "Q?")]); // Act - var stateElement = sut.Serialize(); + await sut.InvokingAsync(invokingContext, CancellationToken.None); + await sut.InvokingAsync(invokingContext, CancellationToken.None); // Assert - using JsonDocument doc = JsonDocument.Parse(stateElement.GetRawText()); - Assert.False(doc.RootElement.TryGetProperty("contextPrompt", out _)); + Assert.Equal(1, initializerCallCount); } [Fact] - public async Task InvokingAsync_ShouldNotThrow_WhenSearchFailsAsync() + public async Task StateKey_CanBeConfiguredViaOptionsAsync() { // Arrange + this._handler.EnqueueJsonResponse("[]"); var storageScope = new Mem0ProviderScope { ApplicationId = "app" }; - var provider = new Mem0Provider(this._httpClient, storageScope, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); + var mockSession = new TestAgentSession(); + const string CustomKey = "MyCustomKey"; + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), options: new() { StateKey = CustomKey }); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, mockSession, [new ChatMessage(ChatRole.User, "Q?")]); // Act - var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); + await sut.InvokingAsync(invokingContext, CancellationToken.None); // Assert - Assert.Null(aiContext.Messages); - Assert.Null(aiContext.Tools); - this._loggerMock.Verify( - l => l.Log( - LogLevel.Error, - It.IsAny(), - It.Is((v, t) => v.ToString()!.Contains("Mem0AIContextProvider: Failed to search Mem0 for memories due to error")), - It.IsAny(), - It.IsAny>()), - Times.Once); + Assert.True(mockSession.StateBag.TryGetValue(CustomKey, out var state, Mem0JsonUtilities.DefaultOptions)); + Assert.NotNull(state); } private static bool ContainsOrdinal(string source, string value) => source.IndexOf(value, StringComparison.Ordinal) >= 0; @@ -465,4 +453,12 @@ public void EnqueueJsonResponse(string json) public void EnqueueEmptyInternalServerError() => this._responses.Enqueue(new HttpResponseMessage(System.Net.HttpStatusCode.InternalServerError)); } + + private sealed class TestAgentSession : AgentSession + { + public TestAgentSession() + { + this.StateBag = new AgentSessionStateBag(); + } + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentOptionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentOptionsTests.cs index 8502550d2c..e9ad2b9a6b 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentOptionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentOptionsTests.cs @@ -1,8 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; using Microsoft.Extensions.AI; using Moq; @@ -23,8 +21,8 @@ public void DefaultConstructor_InitializesWithNullValues() Assert.Null(options.Name); Assert.Null(options.Description); Assert.Null(options.ChatOptions); - Assert.Null(options.ChatHistoryProviderFactory); - Assert.Null(options.AIContextProviderFactory); + Assert.Null(options.ChatHistoryProvider); + Assert.Null(options.AIContextProvider); } [Fact] @@ -36,8 +34,8 @@ public void Constructor_WithNullValues_SetsPropertiesCorrectly() // Assert Assert.Null(options.Name); Assert.Null(options.Description); - Assert.Null(options.AIContextProviderFactory); - Assert.Null(options.ChatHistoryProviderFactory); + Assert.Null(options.AIContextProvider); + Assert.Null(options.ChatHistoryProvider); Assert.NotNull(options.ChatOptions); Assert.Null(options.ChatOptions.Instructions); Assert.Null(options.ChatOptions.Tools); @@ -117,11 +115,8 @@ public void Clone_CreatesDeepCopyWithSameValues() const string Description = "Test description"; var tools = new List { AIFunctionFactory.Create(() => "test") }; - static ValueTask ChatHistoryProviderFactoryAsync( - ChatClientAgentOptions.ChatHistoryProviderFactoryContext ctx, CancellationToken ct) => new(new Mock().Object); - - static ValueTask AIContextProviderFactoryAsync( - ChatClientAgentOptions.AIContextProviderFactoryContext ctx, CancellationToken ct) => new(new Mock().Object); + var mockChatHistoryProvider = new Mock().Object; + var mockAIContextProvider = new Mock().Object; var original = new ChatClientAgentOptions() { @@ -129,8 +124,8 @@ static ValueTask AIContextProviderFactoryAsync( Description = Description, ChatOptions = new() { Tools = tools }, Id = "test-id", - ChatHistoryProviderFactory = ChatHistoryProviderFactoryAsync, - AIContextProviderFactory = AIContextProviderFactoryAsync + ChatHistoryProvider = mockChatHistoryProvider, + AIContextProvider = mockAIContextProvider }; // Act @@ -141,8 +136,8 @@ static ValueTask AIContextProviderFactoryAsync( Assert.Equal(original.Id, clone.Id); Assert.Equal(original.Name, clone.Name); Assert.Equal(original.Description, clone.Description); - Assert.Same(original.ChatHistoryProviderFactory, clone.ChatHistoryProviderFactory); - Assert.Same(original.AIContextProviderFactory, clone.AIContextProviderFactory); + Assert.Same(original.ChatHistoryProvider, clone.ChatHistoryProvider); + Assert.Same(original.AIContextProvider, clone.AIContextProvider); // ChatOptions should be cloned, not the same reference Assert.NotSame(original.ChatOptions, clone.ChatOptions); @@ -154,11 +149,16 @@ static ValueTask AIContextProviderFactoryAsync( public void Clone_WithoutProvidingChatOptions_ClonesCorrectly() { // Arrange + var mockChatHistoryProvider = new Mock().Object; + var mockAIContextProvider = new Mock().Object; + var original = new ChatClientAgentOptions { Id = "test-id", Name = "Test name", - Description = "Test description" + Description = "Test description", + ChatHistoryProvider = mockChatHistoryProvider, + AIContextProvider = mockAIContextProvider }; // Act @@ -170,8 +170,8 @@ public void Clone_WithoutProvidingChatOptions_ClonesCorrectly() Assert.Equal(original.Name, clone.Name); Assert.Equal(original.Description, clone.Description); Assert.Null(original.ChatOptions); - Assert.Null(clone.ChatHistoryProviderFactory); - Assert.Null(clone.AIContextProviderFactory); + Assert.Same(original.ChatHistoryProvider, clone.ChatHistoryProvider); + Assert.Same(original.AIContextProvider, clone.AIContextProvider); } private static void AssertSameTools(IList? expected, IList? actual) diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs index fd311f9225..27a13c3b49 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs @@ -1,12 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections.Generic; using System.Linq; using System.Text.Json; -using System.Threading.Tasks; using Microsoft.Extensions.AI; -using Moq; #pragma warning disable CA1861 // Avoid constant arrays as arguments @@ -24,7 +21,6 @@ public void ConstructorSetsDefaults() // Assert Assert.Null(session.ConversationId); - Assert.Null(session.ChatHistoryProvider); } [Fact] @@ -39,53 +35,6 @@ public void SetConversationIdRoundtrips() // Assert Assert.Equal(ConversationId, session.ConversationId); - Assert.Null(session.ChatHistoryProvider); - } - - [Fact] - public void SetChatHistoryProviderRoundtrips() - { - // Arrange - var session = new ChatClientAgentSession(); - var chatHistoryProvider = new InMemoryChatHistoryProvider(); - - // Act - session.ChatHistoryProvider = chatHistoryProvider; - - // Assert - Assert.Same(chatHistoryProvider, session.ChatHistoryProvider); - Assert.Null(session.ConversationId); - } - - [Fact] - public void SetConversationIdThrowsWhenChatHistoryProviderIsSet() - { - // Arrange - var session = new ChatClientAgentSession - { - ChatHistoryProvider = new InMemoryChatHistoryProvider() - }; - - // Act & Assert - var exception = Assert.Throws(() => session.ConversationId = "new-session-id"); - Assert.Equal("Only the ConversationId or ChatHistoryProvider may be set, but not both and switching from one to another is not supported.", exception.Message); - Assert.NotNull(session.ChatHistoryProvider); - } - - [Fact] - public void SetChatHistoryProviderThrowsWhenConversationIdIsSet() - { - // Arrange - var session = new ChatClientAgentSession - { - ConversationId = "existing-session-id" - }; - var provider = new InMemoryChatHistoryProvider(); - - // Act & Assert - var exception = Assert.Throws(() => session.ChatHistoryProvider = provider); - Assert.Equal("Only the ConversationId or ChatHistoryProvider may be set, but not both and switching from one to another is not supported.", exception.Message); - Assert.NotNull(session.ConversationId); } #endregion Constructor and Property Tests @@ -93,29 +42,33 @@ public void SetChatHistoryProviderThrowsWhenConversationIdIsSet() #region Deserialize Tests [Fact] - public async Task VerifyDeserializeWithMessagesAsync() + public void VerifyDeserializeWithMessages() { // Arrange var json = JsonSerializer.Deserialize(""" { - "chatHistoryProviderState": { "messages": [{"authorName": "testAuthor"}] } + "stateBag": { + "InMemoryChatHistoryProvider.State": { + "messages": [{"authorName": "testAuthor"}] + } + } } """, TestJsonSerializerContext.Default.JsonElement); // Act. - var session = await ChatClientAgentSession.DeserializeAsync(json); + var session = ChatClientAgentSession.Deserialize(json, TestJsonSerializerContext.Default.Options); // Assert Assert.Null(session.ConversationId); - var chatHistoryProvider = session.ChatHistoryProvider as InMemoryChatHistoryProvider; - Assert.NotNull(chatHistoryProvider); - Assert.Single(chatHistoryProvider); - Assert.Equal("testAuthor", chatHistoryProvider[0].AuthorName); + var chatHistoryProvider = new InMemoryChatHistoryProvider(); + var messages = chatHistoryProvider.GetMessages(session); + Assert.Single(messages); + Assert.Equal("testAuthor", messages[0].AuthorName); } [Fact] - public async Task VerifyDeserializeWithIdAsync() + public void VerifyDeserializeWithId() { // Arrange var json = JsonSerializer.Deserialize(""" @@ -125,42 +78,43 @@ public async Task VerifyDeserializeWithIdAsync() """, TestJsonSerializerContext.Default.JsonElement); // Act - var session = await ChatClientAgentSession.DeserializeAsync(json); + var session = ChatClientAgentSession.Deserialize(json); // Assert Assert.Equal("TestConvId", session.ConversationId); - Assert.Null(session.ChatHistoryProvider); } [Fact] - public async Task VerifyDeserializeWithAIContextProviderAsync() + public void VerifyDeserializeWithStateBag() { // Arrange var json = JsonSerializer.Deserialize(""" { "conversationId": "TestConvId", - "aiContextProviderState": ["CP1"] + "stateBag": { + "dog": { + "name": "Fido" + } + } } """, TestJsonSerializerContext.Default.JsonElement); - Mock mockProvider = new(); - // Act - var session = await ChatClientAgentSession.DeserializeAsync(json, aiContextProviderFactory: (_, _, _) => new(mockProvider.Object)); + var session = ChatClientAgentSession.Deserialize(json); // Assert - Assert.Null(session.ChatHistoryProvider); - Assert.Same(session.AIContextProvider, mockProvider.Object); + var dog = session.StateBag.GetValue("dog", TestJsonSerializerContext.Default.Options); + Assert.NotNull(dog); + Assert.Equal("Fido", dog.Name); } [Fact] - public async Task DeserializeWithInvalidJsonThrowsAsync() + public void DeserializeWithInvalidJsonThrows() { // Arrange var invalidJson = JsonSerializer.Deserialize("[42]", TestJsonSerializerContext.Default.JsonElement); - var session = new ChatClientAgentSession(); // Act & Assert - await Assert.ThrowsAsync(() => ChatClientAgentSession.DeserializeAsync(invalidJson)); + Assert.Throws(() => ChatClientAgentSession.Deserialize(invalidJson)); } #endregion Deserialize Tests @@ -195,8 +149,9 @@ public void VerifySessionSerializationWithId() public void VerifySessionSerializationWithMessages() { // Arrange - InMemoryChatHistoryProvider provider = [new(ChatRole.User, "TestContent") { AuthorName = "TestAuthor" }]; - var session = new ChatClientAgentSession { ChatHistoryProvider = provider }; + var provider = new InMemoryChatHistoryProvider(); + var session = new ChatClientAgentSession(); + provider.SetMessages(session, [new(ChatRole.User, "TestContent") { AuthorName = "TestAuthor" }]); // Act var json = session.Serialize(); @@ -206,10 +161,12 @@ public void VerifySessionSerializationWithMessages() Assert.False(json.TryGetProperty("conversationId", out _)); - Assert.True(json.TryGetProperty("chatHistoryProviderState", out var chatHistoryProviderStateProperty)); - Assert.Equal(JsonValueKind.Object, chatHistoryProviderStateProperty.ValueKind); - - Assert.True(chatHistoryProviderStateProperty.TryGetProperty("messages", out var messagesProperty)); + // Messages should be stored in the stateBag + Assert.True(json.TryGetProperty("stateBag", out var stateBagProperty)); + Assert.Equal(JsonValueKind.Object, stateBagProperty.ValueKind); + Assert.True(stateBagProperty.TryGetProperty("InMemoryChatHistoryProvider.State", out var providerStateProperty)); + Assert.Equal(JsonValueKind.Object, providerStateProperty.ValueKind); + Assert.True(providerStateProperty.TryGetProperty("messages", out var messagesProperty)); Assert.Equal(JsonValueKind.Array, messagesProperty.ValueKind); Assert.Single(messagesProperty.EnumerateArray()); @@ -224,29 +181,23 @@ public void VerifySessionSerializationWithMessages() } [Fact] - public void VerifySessionSerializationWithWithAIContextProvider() + public void VerifySessionSerializationWithWithStateBag() { // Arrange - Mock mockProvider = new(); - mockProvider - .Setup(m => m.Serialize(It.IsAny())) - .Returns(JsonSerializer.SerializeToElement(["CP1"], TestJsonSerializerContext.Default.StringArray)); - - var session = new ChatClientAgentSession - { - AIContextProvider = mockProvider.Object - }; + var session = new ChatClientAgentSession(); + session.StateBag.SetValue("dog", new Animal { Name = "Fido" }, TestJsonSerializerContext.Default.Options); // Act var json = session.Serialize(); // Assert Assert.Equal(JsonValueKind.Object, json.ValueKind); - Assert.True(json.TryGetProperty("aiContextProviderState", out var providerStateProperty)); - Assert.Equal(JsonValueKind.Array, providerStateProperty.ValueKind); - Assert.Single(providerStateProperty.EnumerateArray()); - Assert.Equal("CP1", providerStateProperty.EnumerateArray().First().GetString()); - mockProvider.Verify(m => m.Serialize(It.IsAny()), Times.Once); + Assert.True(json.TryGetProperty("stateBag", out var stateBagProperty)); + Assert.Equal(JsonValueKind.Object, stateBagProperty.ValueKind); + Assert.True(stateBagProperty.TryGetProperty("dog", out var dogProperty)); + Assert.Equal(JsonValueKind.Object, dogProperty.ValueKind); + Assert.True(dogProperty.TryGetProperty("name", out var nameProperty)); + Assert.Equal("Fido", nameProperty.GetString()); } /// @@ -258,17 +209,7 @@ public void VerifySessionSerializationWithCustomOptions() // Arrange var session = new ChatClientAgentSession(); JsonSerializerOptions options = new() { PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower }; - options.TypeInfoResolverChain.Add(AgentAbstractionsJsonUtilities.DefaultOptions.TypeInfoResolver!); - - var chatHistoryProviderStateElement = JsonSerializer.SerializeToElement( - new Dictionary { ["Key"] = "TestValue" }, - TestJsonSerializerContext.Default.DictionaryStringObject); - - var chatHistoryProviderMock = new Mock(); - chatHistoryProviderMock - .Setup(m => m.Serialize(options)) - .Returns(chatHistoryProviderStateElement); - session.ChatHistoryProvider = chatHistoryProviderMock.Object; + options.TypeInfoResolverChain.Add(AgentJsonUtilities.DefaultOptions.TypeInfoResolver!); // Act var json = session.Serialize(options); @@ -276,54 +217,29 @@ public void VerifySessionSerializationWithCustomOptions() // Assert Assert.Equal(JsonValueKind.Object, json.ValueKind); - Assert.False(json.TryGetProperty("conversationId", out var idProperty)); - - Assert.True(json.TryGetProperty("chatHistoryProviderState", out var chatHistoryProviderStateProperty)); - Assert.Equal(JsonValueKind.Object, chatHistoryProviderStateProperty.ValueKind); - - Assert.True(chatHistoryProviderStateProperty.TryGetProperty("Key", out var keyProperty)); - Assert.Equal("TestValue", keyProperty.GetString()); - - chatHistoryProviderMock.Verify(m => m.Serialize(options), Times.Once); + // [JsonPropertyName] takes precedence over naming policy + Assert.True(json.TryGetProperty("conversationId", out var _)); } #endregion Serialize Tests - #region GetService Tests - - [Fact] - public void GetService_RequestingAIContextProvider_ReturnsAIContextProvider() - { - // Arrange - var session = new ChatClientAgentSession(); - var mockProvider = new Mock(); - mockProvider - .Setup(m => m.GetService(It.Is(x => x == typeof(AIContextProvider)), null)) - .Returns(mockProvider.Object); - session.AIContextProvider = mockProvider.Object; - - // Act - var result = session.GetService(typeof(AIContextProvider)); - - // Assert - Assert.NotNull(result); - Assert.Same(mockProvider.Object, result); - } + #region StateBag Roundtrip Tests [Fact] - public void GetService_RequestingChatHistoryProvider_ReturnsChatHistoryProvider() + public void VerifyStateBagRoundtrips() { // Arrange var session = new ChatClientAgentSession(); - var chatHistoryProvider = new InMemoryChatHistoryProvider(); - session.ChatHistoryProvider = chatHistoryProvider; + session.StateBag.SetValue("dog", new Animal { Name = "Fido" }, TestJsonSerializerContext.Default.Options); // Act - var result = session.GetService(typeof(ChatHistoryProvider)); + var serializedSession = session.Serialize(); + var deserializedSession = ChatClientAgentSession.Deserialize(serializedSession); // Assert - Assert.NotNull(result); - Assert.Same(chatHistoryProvider, result); + var dog = deserializedSession.StateBag.GetValue("dog", TestJsonSerializerContext.Default.Options); + Assert.NotNull(dog); + Assert.Equal("Fido", dog.Name); } #endregion diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs index 41fb29bfed..6544265711 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs @@ -356,7 +356,7 @@ public async Task RunAsyncInvokesAIContextProviderAndUsesResultAsync() .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Returns(new ValueTask()); - ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = (_, _) => new(mockProvider.Object), ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); + ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProvider = mockProvider.Object, ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); // Act var session = await agent.CreateSessionAsync() as ChatClientAgentSession; @@ -375,11 +375,13 @@ public async Task RunAsyncInvokesAIContextProviderAndUsesResultAsync() Assert.Contains(capturedTools, t => t.Name == "context provider function"); // Verify that the session was updated with the ai context provider, input and response messages - var chatHistoryProvider = Assert.IsType(session!.ChatHistoryProvider); - Assert.Equal(3, chatHistoryProvider.Count); - Assert.Equal("user message", chatHistoryProvider[0].Text); - Assert.Equal("context provider message", chatHistoryProvider[1].Text); - Assert.Equal("response", chatHistoryProvider[2].Text); + var chatHistoryProvider = agent.ChatHistoryProvider as InMemoryChatHistoryProvider; + Assert.NotNull(chatHistoryProvider); + var messages = chatHistoryProvider.GetMessages(session); + Assert.Equal(3, messages.Count); + Assert.Equal("user message", messages[0].Text); + Assert.Equal("context provider message", messages[1].Text); + Assert.Equal("response", messages[2].Text); mockProvider .Protected() @@ -422,7 +424,7 @@ public async Task RunAsyncInvokesAIContextProviderWhenGetResponseFailsAsync() .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Returns(new ValueTask()); - ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = (_, _) => new(mockProvider.Object), ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); + ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProvider = mockProvider.Object, ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); // Act await Assert.ThrowsAsync(() => agent.RunAsync(requestMessages)); @@ -472,7 +474,7 @@ public async Task RunAsyncInvokesAIContextProviderAndSucceedsWithEmptyAIContextA .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ReturnsAsync(new AIContext()); - ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = (_, _) => new(mockProvider.Object), ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); + ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProvider = mockProvider.Object, ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); // Act await agent.RunAsync([new(ChatRole.User, "user message")]); @@ -1300,12 +1302,10 @@ public async Task RunStreamingAsyncUsesChatHistoryProviderWhenNoConversationIdRe It.IsAny>(), It.IsAny(), It.IsAny())).Returns(ToAsyncEnumerableAsync(returnUpdates)); - Mock>> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(new InMemoryChatHistoryProvider()); ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, - ChatHistoryProviderFactory = mockFactory.Object + ChatHistoryProvider = new InMemoryChatHistoryProvider() }); // Act @@ -1313,18 +1313,18 @@ public async Task RunStreamingAsyncUsesChatHistoryProviderWhenNoConversationIdRe await agent.RunStreamingAsync([new(ChatRole.User, "test")], session).ToListAsync(); // Assert - var chatHistoryProvider = Assert.IsType(session!.ChatHistoryProvider); - Assert.Equal(2, chatHistoryProvider.Count); - Assert.Equal("test", chatHistoryProvider[0].Text); - Assert.Equal("what?", chatHistoryProvider[1].Text); - mockFactory.Verify(f => f(It.IsAny(), It.IsAny()), Times.Once); + var chatHistoryProvider = Assert.IsType(agent.GetService(typeof(ChatHistoryProvider))); + var historyMessages = chatHistoryProvider.GetMessages(session); + Assert.Equal(2, historyMessages.Count); + Assert.Equal("test", historyMessages[0].Text); + Assert.Equal("what?", historyMessages[1].Text); } /// - /// Verify that RunStreamingAsync throws when a factory is provided and the chat client returns a conversation id. + /// Verify that RunStreamingAsync throws when a is provided and the chat client returns a conversation id. /// [Fact] - public async Task RunStreamingAsyncThrowsWhenChatHistoryProviderFactoryProvidedAndConversationIdReturnedByChatClientAsync() + public async Task RunStreamingAsyncThrowsWhenChatHistoryProviderProvidedAndConversationIdReturnedByChatClientAsync() { // Arrange Mock mockService = new(); @@ -1338,18 +1338,16 @@ public async Task RunStreamingAsyncThrowsWhenChatHistoryProviderFactoryProvidedA It.IsAny>(), It.IsAny(), It.IsAny())).Returns(ToAsyncEnumerableAsync(returnUpdates)); - Mock>> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(new InMemoryChatHistoryProvider()); ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, - ChatHistoryProviderFactory = mockFactory.Object + ChatHistoryProvider = new InMemoryChatHistoryProvider() }); // Act & Assert ChatClientAgentSession? session = await agent.CreateSessionAsync() as ChatClientAgentSession; var exception = await Assert.ThrowsAsync(async () => await agent.RunStreamingAsync([new(ChatRole.User, "test")], session).ToListAsync()); - Assert.Equal("Only the ConversationId or ChatHistoryProvider may be set, but not both and switching from one to another is not supported.", exception.Message); + Assert.Equal("Only ConversationId or ChatHistoryProvider may be used, but not both. The service returned a conversation id indicating server-side chat history management, but the agent has a ChatHistoryProvider configured.", exception.Message); } /// @@ -1402,7 +1400,7 @@ public async Task RunStreamingAsyncInvokesAIContextProviderAndUsesResultAsync() options: new() { ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] }, - AIContextProviderFactory = (_, _) => new(mockProvider.Object) + AIContextProvider = mockProvider.Object }); // Act @@ -1423,11 +1421,13 @@ public async Task RunStreamingAsyncInvokesAIContextProviderAndUsesResultAsync() Assert.Contains(capturedTools, t => t.Name == "context provider function"); // Verify that the session was updated with the input, ai context provider, and response messages - var chatHistoryProvider = Assert.IsType(session!.ChatHistoryProvider); - Assert.Equal(3, chatHistoryProvider.Count); - Assert.Equal("user message", chatHistoryProvider[0].Text); - Assert.Equal("context provider message", chatHistoryProvider[1].Text); - Assert.Equal("response", chatHistoryProvider[2].Text); + var chatHistoryProvider = agent.ChatHistoryProvider as InMemoryChatHistoryProvider; + Assert.NotNull(chatHistoryProvider); + var historyMessages2 = chatHistoryProvider.GetMessages(session); + Assert.Equal(3, historyMessages2.Count); + Assert.Equal("user message", historyMessages2[0].Text); + Assert.Equal("context provider message", historyMessages2[1].Text); + Assert.Equal("response", historyMessages2[2].Text); mockProvider .Protected() @@ -1476,7 +1476,7 @@ public async Task RunStreamingAsyncInvokesAIContextProviderWhenGetResponseFailsA options: new() { ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] }, - AIContextProviderFactory = (_, _) => new(mockProvider.Object) + AIContextProvider = mockProvider.Object }); // Act diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs index 2eed890292..8f2f5bebd1 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs @@ -365,14 +365,14 @@ public async Task RunAsync_WhenContinuationTokenProvided_SkipsSessionMessagePopu capturedMessages.AddRange(msgs)) .ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "continued response")])); - ChatClientAgent agent = new(mockChatClient.Object); - - // Create a session with both chat history provider and AI context provider - ChatClientAgentSession? session = new() + ChatClientAgent agent = new(mockChatClient.Object, options: new() { ChatHistoryProvider = mockChatHistoryProvider.Object, AIContextProvider = mockContextProvider.Object - }; + }); + + // Create a session + ChatClientAgentSession? session = new(); AgentRunOptions runOptions = new() { @@ -432,14 +432,14 @@ public async Task RunStreamingAsync_WhenContinuationTokenProvided_SkipsSessionMe capturedMessages.AddRange(msgs)) .Returns(ToAsyncEnumerableAsync([new ChatResponseUpdate(role: ChatRole.Assistant, content: "continued response")])); - ChatClientAgent agent = new(mockChatClient.Object); - - // Create a session with both chat history provider and AI context provider - ChatClientAgentSession? session = new() + ChatClientAgent agent = new(mockChatClient.Object, options: new() { ChatHistoryProvider = mockChatHistoryProvider.Object, AIContextProvider = mockContextProvider.Object - }; + }); + + // Create a session + ChatClientAgentSession? session = new(); AgentRunOptions runOptions = new() { @@ -633,8 +633,6 @@ public async Task RunStreamingAsync_WhenResumingStreaming_UsesUpdatesFromInitial It.IsAny())) .Returns(ToAsyncEnumerableAsync(returnUpdates)); - ChatClientAgent agent = new(mockChatClient.Object); - List capturedMessagesAddedToProvider = []; var mockChatHistoryProvider = new Mock(); mockChatHistoryProvider @@ -651,11 +649,13 @@ public async Task RunStreamingAsync_WhenResumingStreaming_UsesUpdatesFromInitial .Callback((context, ct) => capturedInvokedContext = context) .Returns(new ValueTask()); - ChatClientAgentSession? session = new() + ChatClientAgent agent = new(mockChatClient.Object, options: new() { ChatHistoryProvider = mockChatHistoryProvider.Object, AIContextProvider = mockContextProvider.Object - }; + }); + + ChatClientAgentSession? session = new(); AgentRunOptions runOptions = new() { @@ -695,8 +695,6 @@ public async Task RunStreamingAsync_WhenResumingStreaming_UsesInputMessagesFromI It.IsAny())) .Returns(ToAsyncEnumerableAsync(Array.Empty())); - ChatClientAgent agent = new(mockChatClient.Object); - List capturedMessagesAddedToProvider = []; var mockChatHistoryProvider = new Mock(); mockChatHistoryProvider @@ -713,11 +711,13 @@ public async Task RunStreamingAsync_WhenResumingStreaming_UsesInputMessagesFromI .Callback((context, ct) => capturedInvokedContext = context) .Returns(new ValueTask()); - ChatClientAgentSession? session = new() + ChatClientAgent agent = new(mockChatClient.Object, options: new() { ChatHistoryProvider = mockChatHistoryProvider.Object, AIContextProvider = mockContextProvider.Object - }; + }); + + ChatClientAgentSession? session = new(); AgentRunOptions runOptions = new() { diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs index 4de8f01f8e..b15f944103 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs @@ -163,17 +163,19 @@ public async Task RunAsync_UsesDefaultInMemoryChatHistoryProvider_WhenNoConversa await agent.RunAsync([new(ChatRole.User, "test")], session); // Assert - InMemoryChatHistoryProvider chatHistoryProvider = Assert.IsType(session!.ChatHistoryProvider); - Assert.Equal(2, chatHistoryProvider.Count); - Assert.Equal("test", chatHistoryProvider[0].Text); - Assert.Equal("response", chatHistoryProvider[1].Text); + var inMemoryProvider = agent.ChatHistoryProvider as InMemoryChatHistoryProvider; + Assert.NotNull(inMemoryProvider); + var messages = inMemoryProvider.GetMessages(session!); + Assert.Equal(2, messages.Count); + Assert.Equal("test", messages[0].Text); + Assert.Equal("response", messages[1].Text); } /// - /// Verify that RunAsync uses the ChatHistoryProvider factory when the chat client returns no conversation id. + /// Verify that RunAsync uses the ChatHistoryProvider when the chat client returns no conversation id. /// [Fact] - public async Task RunAsync_UsesChatHistoryProviderFactory_WhenProvidedAndNoConversationIdReturnedByChatClientAsync() + public async Task RunAsync_UsesChatHistoryProvider_WhenProvidedAndNoConversationIdReturnedByChatClientAsync() { // Arrange Mock mockService = new(); @@ -193,13 +195,10 @@ public async Task RunAsync_UsesChatHistoryProviderFactory_WhenProvidedAndNoConve .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Returns(new ValueTask()); - Mock>> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(mockChatHistoryProvider.Object); - ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, - ChatHistoryProviderFactory = mockFactory.Object + ChatHistoryProvider = mockChatHistoryProvider.Object }); // Act @@ -207,7 +206,7 @@ public async Task RunAsync_UsesChatHistoryProviderFactory_WhenProvidedAndNoConve await agent.RunAsync([new(ChatRole.User, "test")], session); // Assert - Assert.IsType(session!.ChatHistoryProvider, exactMatch: false); + Assert.Same(mockChatHistoryProvider.Object, agent.ChatHistoryProvider); mockService.Verify( x => x.GetResponseAsync( It.Is>(msgs => msgs.Count() == 2 && msgs.Any(m => m.Text == "Existing Chat History") && msgs.Any(m => m.Text == "test")), @@ -224,7 +223,6 @@ public async Task RunAsync_UsesChatHistoryProviderFactory_WhenProvidedAndNoConve .Verify("InvokedCoreAsync", Times.Once(), ItExpr.Is(x => x.RequestMessages.Count() == 1 && x.ResponseMessages!.Count() == 1), ItExpr.IsAny()); - mockFactory.Verify(f => f(It.IsAny(), It.IsAny()), Times.Once); } /// @@ -243,13 +241,10 @@ public async Task RunAsync_NotifiesChatHistoryProvider_OnFailureAsync() Mock mockChatHistoryProvider = new(); - Mock>> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(mockChatHistoryProvider.Object); - ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, - ChatHistoryProviderFactory = mockFactory.Object + ChatHistoryProvider = mockChatHistoryProvider.Object }); // Act @@ -257,20 +252,19 @@ public async Task RunAsync_NotifiesChatHistoryProvider_OnFailureAsync() await Assert.ThrowsAsync(() => agent.RunAsync([new(ChatRole.User, "test")], session)); // Assert - Assert.IsType(session!.ChatHistoryProvider, exactMatch: false); + Assert.Same(mockChatHistoryProvider.Object, agent.ChatHistoryProvider); mockChatHistoryProvider .Protected() .Verify("InvokedCoreAsync", Times.Once(), ItExpr.Is(x => x.RequestMessages.Count() == 1 && x.ResponseMessages == null && x.InvokeException!.Message == "Test Error"), ItExpr.IsAny()); - mockFactory.Verify(f => f(It.IsAny(), It.IsAny()), Times.Once); } /// - /// Verify that RunAsync throws when a ChatHistoryProvider Factory is provided and the chat client returns a conversation id. + /// Verify that RunAsync throws when a ChatHistoryProvider is provided and the chat client returns a conversation id. /// [Fact] - public async Task RunAsync_Throws_WhenChatHistoryProviderFactoryProvidedAndConversationIdReturnedByChatClientAsync() + public async Task RunAsync_Throws_WhenChatHistoryProviderProvidedAndConversationIdReturnedByChatClientAsync() { // Arrange Mock mockService = new(); @@ -279,18 +273,16 @@ public async Task RunAsync_Throws_WhenChatHistoryProviderFactoryProvidedAndConve It.IsAny>(), It.IsAny(), It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")]) { ConversationId = "ConvId" }); - Mock>> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(new InMemoryChatHistoryProvider()); ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, - ChatHistoryProviderFactory = mockFactory.Object + ChatHistoryProvider = new InMemoryChatHistoryProvider() }); // Act & Assert ChatClientAgentSession? session = await agent.CreateSessionAsync() as ChatClientAgentSession; InvalidOperationException exception = await Assert.ThrowsAsync(() => agent.RunAsync([new(ChatRole.User, "test")], session)); - Assert.Equal("Only the ConversationId or ChatHistoryProvider may be set, but not both and switching from one to another is not supported.", exception.Message); + Assert.Equal("Only ConversationId or ChatHistoryProvider may be used, but not both. The service returned a conversation id indicating server-side chat history management, but the agent has a ChatHistoryProvider configured.", exception.Message); } #endregion @@ -323,25 +315,22 @@ public async Task RunAsync_UsesOverrideChatHistoryProvider_WhenProvidedViaAdditi .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Returns(new ValueTask()); - // Arrange a chat history provider to provide to the agent via a factory at construction time. + // Arrange a chat history provider to provide to the agent at construction time. // This one shouldn't be used since it is being overridden. - Mock mockFactoryChatHistoryProvider = new(); - mockFactoryChatHistoryProvider + Mock mockAgentOptionsChatHistoryProvider = new(); + mockAgentOptionsChatHistoryProvider .Protected() .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ThrowsAsync(FailException.ForFailure("Base ChatHistoryProvider shouldn't be used.")); - mockFactoryChatHistoryProvider + mockAgentOptionsChatHistoryProvider .Protected() .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Throws(FailException.ForFailure("Base ChatHistoryProvider shouldn't be used.")); - Mock>> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(mockFactoryChatHistoryProvider.Object); - ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, - ChatHistoryProviderFactory = mockFactory.Object + ChatHistoryProvider = mockAgentOptionsChatHistoryProvider.Object }); // Act @@ -351,7 +340,7 @@ public async Task RunAsync_UsesOverrideChatHistoryProvider_WhenProvidedViaAdditi await agent.RunAsync([new(ChatRole.User, "test")], session, options: new AgentRunOptions { AdditionalProperties = additionalProperties }); // Assert - Assert.Same(mockFactoryChatHistoryProvider.Object, session!.ChatHistoryProvider); + Assert.Same(mockAgentOptionsChatHistoryProvider.Object, agent.ChatHistoryProvider); mockService.Verify( x => x.GetResponseAsync( It.Is>(msgs => msgs.Count() == 2 && msgs.Any(m => m.Text == "Existing Chat History") && msgs.Any(m => m.Text == "test")), @@ -369,12 +358,12 @@ public async Task RunAsync_UsesOverrideChatHistoryProvider_WhenProvidedViaAdditi ItExpr.Is(x => x.RequestMessages.Count() == 1 && x.ResponseMessages!.Count() == 1), ItExpr.IsAny()); - mockFactoryChatHistoryProvider + mockAgentOptionsChatHistoryProvider .Protected() .Verify>>("InvokingCoreAsync", Times.Never(), ItExpr.IsAny(), ItExpr.IsAny()); - mockFactoryChatHistoryProvider + mockAgentOptionsChatHistoryProvider .Protected() .Verify("InvokedCoreAsync", Times.Never(), ItExpr.IsAny(), diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_CreateSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_CreateSessionTests.cs index 86220a6462..68fd008e9e 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_CreateSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_CreateSessionTests.cs @@ -11,77 +11,6 @@ namespace Microsoft.Agents.AI.UnitTests; /// public class ChatClientAgent_CreateSessionTests { - [Fact] - public async Task CreateSession_UsesAIContextProviderFactory_IfProvidedAsync() - { - // Arrange - var mockChatClient = new Mock(); - var mockContextProvider = new Mock(); - var factoryCalled = false; - var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions - { - ChatOptions = new() { Instructions = "Test instructions" }, - AIContextProviderFactory = (_, _) => - { - factoryCalled = true; - return new ValueTask(mockContextProvider.Object); - } - }); - - // Act - var session = await agent.CreateSessionAsync(); - - // Assert - Assert.True(factoryCalled, "AIContextProviderFactory was not called."); - Assert.IsType(session); - var typedSession = (ChatClientAgentSession)session; - Assert.Same(mockContextProvider.Object, typedSession.AIContextProvider); - } - - [Fact] - public async Task CreateSession_UsesChatHistoryProviderFactory_IfProvidedAsync() - { - // Arrange - var mockChatClient = new Mock(); - var mockChatHistoryProvider = new Mock(); - var factoryCalled = false; - var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions - { - ChatOptions = new() { Instructions = "Test instructions" }, - ChatHistoryProviderFactory = (_, _) => - { - factoryCalled = true; - return new ValueTask(mockChatHistoryProvider.Object); - } - }); - - // Act - var session = await agent.CreateSessionAsync(); - - // Assert - Assert.True(factoryCalled, "ChatHistoryProviderFactory was not called."); - Assert.IsType(session); - var typedSession = (ChatClientAgentSession)session; - Assert.Same(mockChatHistoryProvider.Object, typedSession.ChatHistoryProvider); - } - - [Fact] - public async Task CreateSession_UsesChatHistoryProvider_FromTypedOverloadAsync() - { - // Arrange - var mockChatClient = new Mock(); - var mockChatHistoryProvider = new Mock(); - var agent = new ChatClientAgent(mockChatClient.Object); - - // Act - var session = await agent.CreateSessionAsync(mockChatHistoryProvider.Object); - - // Assert - Assert.IsType(session); - var typedSession = (ChatClientAgentSession)session; - Assert.Same(mockChatHistoryProvider.Object, typedSession.ChatHistoryProvider); - } - [Fact] public async Task CreateSession_UsesConversationId_FromTypedOverloadAsync() { diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeSessionTests.cs deleted file mode 100644 index 014cb1483b..0000000000 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeSessionTests.cs +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Text.Json; -using System.Threading.Tasks; -using Microsoft.Extensions.AI; -using Moq; - -namespace Microsoft.Agents.AI.UnitTests; - -/// -/// Contains unit tests for the ChatClientAgent.DeserializeSession methods. -/// -public class ChatClientAgent_DeserializeSessionTests -{ - [Fact] - public async Task DeserializeSession_UsesAIContextProviderFactory_IfProvidedAsync() - { - // Arrange - var mockChatClient = new Mock(); - var mockContextProvider = new Mock(); - var factoryCalled = false; - var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions - { - ChatOptions = new() { Instructions = "Test instructions" }, - AIContextProviderFactory = (_, _) => - { - factoryCalled = true; - return new ValueTask(mockContextProvider.Object); - } - }); - - var json = JsonSerializer.Deserialize(""" - { - "aiContextProviderState": ["CP1"] - } - """, TestJsonSerializerContext.Default.JsonElement); - - // Act - var session = await agent.DeserializeSessionAsync(json); - - // Assert - Assert.True(factoryCalled, "AIContextProviderFactory was not called."); - Assert.IsType(session); - var typedSession = (ChatClientAgentSession)session; - Assert.Same(mockContextProvider.Object, typedSession.AIContextProvider); - } - - [Fact] - public async Task DeserializeSession_UsesChatHistoryProviderFactory_IfProvidedAsync() - { - // Arrange - var mockChatClient = new Mock(); - var mockChatHistoryProvider = new Mock(); - var factoryCalled = false; - var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions - { - ChatOptions = new() { Instructions = "Test instructions" }, - ChatHistoryProviderFactory = (_, _) => - { - factoryCalled = true; - return new ValueTask(mockChatHistoryProvider.Object); - } - }); - - var json = JsonSerializer.Deserialize(""" - { - "chatHistoryProviderState": { } - } - """, TestJsonSerializerContext.Default.JsonElement); - - // Act - var session = await agent.DeserializeSessionAsync(json); - - // Assert - Assert.True(factoryCalled, "ChatHistoryProviderFactory was not called."); - Assert.IsType(session); - var typedSession = (ChatClientAgentSession)session; - Assert.Same(mockChatHistoryProvider.Object, typedSession.ChatHistoryProvider); - } -} diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs index ec8dda3c45..d46a05252c 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs @@ -18,7 +18,6 @@ namespace Microsoft.Agents.AI.UnitTests.Data; public sealed class TextSearchProviderTests { private static readonly AIAgent s_mockAgent = new Mock().Object; - private static readonly AgentSession s_mockSession = new Mock().Object; private readonly Mock> _loggerMock; private readonly Mock _loggerFactoryMock; @@ -64,11 +63,11 @@ public async Task InvokingAsync_ShouldInjectFormattedResultsAsync(string? overri ContextPrompt = overrideContextPrompt, CitationsPrompt = overrideCitationsPrompt }; - var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options, withLogging ? this._loggerFactoryMock.Object : null); + var provider = new TextSearchProvider(SearchDelegateAsync, options, withLogging ? this._loggerFactoryMock.Object : null); var invokingContext = new AIContextProvider.InvokingContext( s_mockAgent, - s_mockSession, + new TestAgentSession(), [ new ChatMessage(ChatRole.User, "Sample user question?"), new ChatMessage(ChatRole.User, "Additional part") @@ -143,8 +142,8 @@ public async Task InvokingAsync_OnDemand_ShouldExposeSearchToolAsync(string? ove FunctionToolName = overrideName, FunctionToolDescription = overrideDescription }; - var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); + var provider = new TextSearchProvider(this.NoResultSearchAsync, options); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -162,8 +161,8 @@ public async Task InvokingAsync_OnDemand_ShouldExposeSearchToolAsync(string? ove public async Task InvokingAsync_ShouldNotThrow_WhenSearchFailsAsync() { // Arrange - var provider = new TextSearchProvider(this.FailingSearchAsync, default, null, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); + var provider = new TextSearchProvider(this.FailingSearchAsync, loggerFactory: this._loggerFactoryMock.Object); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -203,7 +202,7 @@ public async Task SearchAsync_ShouldReturnFormattedResultsAsync(string? override ContextPrompt = overrideContextPrompt, CitationsPrompt = overrideCitationsPrompt }; - var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); + var provider = new TextSearchProvider(SearchDelegateAsync, options); // Act var formatted = await provider.SearchAsync("Sample user question?", CancellationToken.None); @@ -255,8 +254,8 @@ public async Task InvokingAsync_ShouldUseContextFormatterWhenProvidedAsync() SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, ContextFormatter = r => $"Custom formatted context with {r.Count} results." }; - var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); + var provider = new TextSearchProvider(SearchDelegateAsync, options); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -289,8 +288,8 @@ public async Task InvokingAsync_WithRawRepresentations_ContextFormatterCanAccess SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, ContextFormatter = r => string.Join(",", r.Select(x => ((RawPayload)x.RawRepresentation!).Id)) }; - var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); + var provider = new TextSearchProvider(SearchDelegateAsync, options); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -306,8 +305,8 @@ public async Task InvokingAsync_WithNoResults_ShouldReturnEmptyContextAsync() { // Arrange var options = new TextSearchProviderOptions { SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke }; - var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); + var provider = new TextSearchProvider(this.NoResultSearchAsync, options); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -335,7 +334,7 @@ public async Task InvokingAsync_WithPreviousFailedRequest_ShouldNotIncludeFailed capturedInput = input; return Task.FromResult>([]); // No results needed. } - var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); + var provider = new TextSearchProvider(SearchDelegateAsync, options); // Populate memory with more messages than the limit (A,B,C,D) -> should retain B,C,D var initialMessages = new[] @@ -345,11 +344,13 @@ public async Task InvokingAsync_WithPreviousFailedRequest_ShouldNotIncludeFailed new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages) { InvokeException = new InvalidOperationException("Request Failed") }); + + var session = new TestAgentSession(); + await provider.InvokedAsync(new(s_mockAgent, session, initialMessages) { InvokeException = new InvalidOperationException("Request Failed") }); var invokingContext = new AIContextProvider.InvokingContext( s_mockAgent, - s_mockSession, + session, [ new ChatMessage(ChatRole.User, "E") ]); @@ -377,7 +378,8 @@ public async Task InvokingAsync_WithRecentMessageMemory_ShouldIncludeStoredMessa capturedInput = input; return Task.FromResult>([]); // No results needed. } - var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); + var provider = new TextSearchProvider(SearchDelegateAsync, options); + var session = new TestAgentSession(); // Populate memory with more messages than the limit (A,B,C,D) -> should retain B,C,D var initialMessages = new[] @@ -387,11 +389,11 @@ public async Task InvokingAsync_WithRecentMessageMemory_ShouldIncludeStoredMessa new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages)); + await provider.InvokedAsync(new(s_mockAgent, session, initialMessages)); var invokingContext = new AIContextProvider.InvokingContext( s_mockAgent, - s_mockSession, + session, [ new ChatMessage(ChatRole.User, "E") ]); @@ -419,28 +421,29 @@ public async Task InvokingAsync_WithAccumulatedMemoryAcrossInvocations_ShouldInc capturedInput = input; return Task.FromResult>([]); } - var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); + var provider = new TextSearchProvider(SearchDelegateAsync, options); + var session = new TestAgentSession(); // First memory update (A,B) await provider.InvokedAsync(new( - s_mockAgent, - s_mockSession, - [ - new ChatMessage(ChatRole.User, "A"), - new ChatMessage(ChatRole.Assistant, "B"), - ])); + s_mockAgent, + session, + [ + new ChatMessage(ChatRole.User, "A"), + new ChatMessage(ChatRole.Assistant, "B"), + ])); // Second memory update (C,D,E) await provider.InvokedAsync(new( - s_mockAgent, - s_mockSession, - [ - new ChatMessage(ChatRole.User, "C"), - new ChatMessage(ChatRole.Assistant, "D"), - new ChatMessage(ChatRole.User, "E"), - ])); + s_mockAgent, + session, + [ + new ChatMessage(ChatRole.User, "C"), + new ChatMessage(ChatRole.Assistant, "D"), + new ChatMessage(ChatRole.User, "E"), + ])); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "F")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "F")]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -465,7 +468,8 @@ public async Task InvokingAsync_WithRecentMessageRolesIncluded_ShouldFilterRoles capturedInput = input; return Task.FromResult>([]); // No results needed for this test. } - var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); + var provider = new TextSearchProvider(SearchDelegateAsync, options); + var session = new TestAgentSession(); // Populate memory with mixed roles; only Assistant messages (A1,A2) should be retained. var initialMessages = new[] @@ -475,11 +479,11 @@ public async Task InvokingAsync_WithRecentMessageRolesIncluded_ShouldFilterRoles new ChatMessage(ChatRole.User, "U2"), new ChatMessage(ChatRole.Assistant, "A2"), }; - await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages)); + await provider.InvokedAsync(new(s_mockAgent, session, initialMessages)); var invokingContext = new AIContextProvider.InvokingContext( s_mockAgent, - s_mockSession, + session, [ new ChatMessage(ChatRole.User, "Question?") // Current request message always appended. ]); @@ -496,26 +500,7 @@ public async Task InvokingAsync_WithRecentMessageRolesIncluded_ShouldFilterRoles #region Serialization Tests [Fact] - public void Serialize_WithNoRecentMessages_ShouldReturnEmptyState() - { - // Arrange - var options = new TextSearchProviderOptions - { - SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, - RecentMessageMemoryLimit = 3 - }; - var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options); - - // Act - var state = provider.Serialize(); - - // Assert - Assert.Equal(JsonValueKind.Object, state.ValueKind); - Assert.False(state.TryGetProperty("recentMessagesText", out _)); - } - - [Fact] - public async Task Serialize_WithRecentMessages_ShouldPersistMessagesUpToLimitAsync() + public async Task InvokedAsync_ShouldPersistMessagesToSessionStateBagAsync() { // Arrange var options = new TextSearchProviderOptions @@ -524,7 +509,8 @@ public async Task Serialize_WithRecentMessages_ShouldPersistMessagesUpToLimitAsy RecentMessageMemoryLimit = 3, RecentMessageRolesIncluded = [ChatRole.User, ChatRole.Assistant] }; - var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options); + var provider = new TextSearchProvider(this.NoResultSearchAsync, options); + var session = new TestAgentSession(); var messages = new[] { new ChatMessage(ChatRole.User, "M1"), @@ -533,11 +519,12 @@ public async Task Serialize_WithRecentMessages_ShouldPersistMessagesUpToLimitAsy }; // Act - await provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages)); // Populate recent memory. - var state = provider.Serialize(); + await provider.InvokedAsync(new(s_mockAgent, session, messages)); // Populate recent memory. - // Assert - Assert.True(state.TryGetProperty("recentMessagesText", out var recentProperty)); + // Assert - State should be in the session's StateBag + var stateBagSerialized = session.StateBag.Serialize(); + Assert.True(stateBagSerialized.TryGetProperty("TextSearchProvider.RecentMessagesText", out var stateProperty)); + Assert.True(stateProperty.TryGetProperty("recentMessagesText", out var recentProperty)); Assert.Equal(JsonValueKind.Array, recentProperty.ValueKind); var list = recentProperty.EnumerateArray().Select(e => e.GetString()).ToList(); Assert.Equal(3, list.Count); @@ -545,7 +532,7 @@ public async Task Serialize_WithRecentMessages_ShouldPersistMessagesUpToLimitAsy } [Fact] - public async Task SerializeAndDeserialize_RoundtripRestoresMessagesAsync() + public async Task StateBag_RoundtripRestoresMessagesAsync() { // Arrange var options = new TextSearchProviderOptions @@ -554,7 +541,8 @@ public async Task SerializeAndDeserialize_RoundtripRestoresMessagesAsync() RecentMessageMemoryLimit = 4, RecentMessageRolesIncluded = [ChatRole.User, ChatRole.Assistant] }; - var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options); + var provider = new TextSearchProvider(this.NoResultSearchAsync, options); + var session = new TestAgentSession(); var messages = new[] { new ChatMessage(ChatRole.User, "A"), @@ -562,23 +550,25 @@ public async Task SerializeAndDeserialize_RoundtripRestoresMessagesAsync() new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages)); + await provider.InvokedAsync(new(s_mockAgent, session, messages)); + + // Act - Serialize and deserialize the StateBag + var serializedStateBag = session.StateBag.Serialize(); + var restoredSession = new TestAgentSession(AgentSessionStateBag.Deserialize(serializedStateBag)); - // Act - var state = provider.Serialize(); string? capturedInput = null; Task> SearchDelegate2Async(string input, CancellationToken ct) { capturedInput = input; return Task.FromResult>([]); } - var roundTrippedProvider = new TextSearchProvider(SearchDelegate2Async, state, options: new TextSearchProviderOptions + var newProvider = new TextSearchProvider(SearchDelegate2Async, new TextSearchProviderOptions { SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, RecentMessageMemoryLimit = 4 }); var emptyMessages = Array.Empty(); - await roundTrippedProvider.InvokingAsync(new(s_mockAgent, s_mockSession, emptyMessages), CancellationToken.None); // Trigger search to read memory. + await newProvider.InvokingAsync(new(s_mockAgent, restoredSession, emptyMessages), CancellationToken.None); // Trigger search to read memory. // Assert Assert.NotNull(capturedInput); @@ -586,51 +576,10 @@ public async Task SerializeAndDeserialize_RoundtripRestoresMessagesAsync() } [Fact] - public async Task Deserialize_WithChangedLowerLimit_ShouldTruncateToNewLimitAsync() - { - // Arrange - var initialProvider = new TextSearchProvider(this.NoResultSearchAsync, default, null, new TextSearchProviderOptions - { - SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, - RecentMessageMemoryLimit = 5, - RecentMessageRolesIncluded = [ChatRole.User, ChatRole.Assistant] - }); - var messages = new[] - { - new ChatMessage(ChatRole.User, "L1"), - new ChatMessage(ChatRole.Assistant, "L2"), - new ChatMessage(ChatRole.User, "L3"), - new ChatMessage(ChatRole.Assistant, "L4"), - new ChatMessage(ChatRole.User, "L5"), - }; - await initialProvider.InvokedAsync(new(s_mockAgent, s_mockSession, messages)); - var state = initialProvider.Serialize(); - - string? capturedInput = null; - Task> SearchDelegate2Async(string input, CancellationToken ct) - { - capturedInput = input; - return Task.FromResult>([]); - } - - // Act - var restoredProvider = new TextSearchProvider(SearchDelegate2Async, state, options: new TextSearchProviderOptions - { - SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, - RecentMessageMemoryLimit = 3 // Lower limit - }); - await restoredProvider.InvokingAsync(new(s_mockAgent, s_mockSession, Array.Empty()), CancellationToken.None); - - // Assert - Assert.NotNull(capturedInput); - Assert.Equal("L1\nL2\nL3", capturedInput); - } - - [Fact] - public async Task Deserialize_WithEmptyState_ShouldHaveNoMessagesAsync() + public async Task InvokingAsync_WithEmptyStateBag_ShouldHaveNoMessagesAsync() { // Arrange - var emptyState = JsonSerializer.Deserialize("{}", TestJsonSerializerContext.Default.JsonElement); + var session = new TestAgentSession(); // Fresh session with empty StateBag string? capturedInput = null; Task> SearchDelegate2Async(string input, CancellationToken ct) @@ -640,17 +589,17 @@ public async Task Deserialize_WithEmptyState_ShouldHaveNoMessagesAsync() } // Act - var provider = new TextSearchProvider(SearchDelegate2Async, emptyState, options: new TextSearchProviderOptions + var provider = new TextSearchProvider(SearchDelegate2Async, new TextSearchProviderOptions { SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, RecentMessageMemoryLimit = 3 }); var emptyMessages = Array.Empty(); - await provider.InvokingAsync(new(s_mockAgent, s_mockSession, emptyMessages), CancellationToken.None); + await provider.InvokingAsync(new(s_mockAgent, session, emptyMessages), CancellationToken.None); // Assert Assert.NotNull(capturedInput); - Assert.Equal(string.Empty, capturedInput); // No recent messages serialized => empty input. + Assert.Equal(string.Empty, capturedInput); // No recent messages in StateBag => empty input. } #endregion @@ -669,4 +618,16 @@ private sealed class RawPayload { public string Id { get; set; } = string.Empty; } + + private sealed class TestAgentSession : AgentSession + { + public TestAgentSession() + { + } + + public TestAgentSession(AgentSessionStateBag stateBag) + { + this.StateBag = stateBag; + } + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs index 0a11e74528..164e64677e 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -19,7 +18,6 @@ namespace Microsoft.Agents.AI.Memory.UnitTests; public class ChatHistoryMemoryProviderTests { private static readonly AIAgent s_mockAgent = new Mock().Object; - private static readonly AgentSession s_mockSession = new Mock().Object; private readonly Mock> _loggerMock; private readonly Mock _loggerFactoryMock; @@ -61,29 +59,49 @@ public ChatHistoryMemoryProviderTests() public void Constructor_Throws_ForNullVectorStore() { // Act & Assert - Assert.Throws(() => new ChatHistoryMemoryProvider(null!, "testcollection", 1, new ChatHistoryMemoryProviderScope() { UserId = "UID" })); + Assert.Throws(() => new ChatHistoryMemoryProvider( + null!, + "testcollection", + 1, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }))); } [Fact] public void Constructor_Throws_ForNullCollectionName() { // Act & Assert - Assert.Throws(() => new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, null!, 1, new ChatHistoryMemoryProviderScope() { UserId = "UID" })); + Assert.Throws(() => new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + null!, + 1, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }))); } [Fact] - public void Constructor_Throws_ForNullStorageScope() + public void Constructor_Throws_ForNullStateInitializer() { // Act & Assert - Assert.Throws(() => new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, "testcollection", 1, null!)); + Assert.Throws(() => new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + "testcollection", + 1, + null!)); } [Fact] public void Constructor_Throws_ForInvalidVectorDimensions() { // Act & Assert - Assert.Throws(() => new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, "testcollection", 0, new ChatHistoryMemoryProviderScope() { UserId = "UID" })); - Assert.Throws(() => new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, "testcollection", -5, new ChatHistoryMemoryProviderScope() { UserId = "UID" })); + Assert.Throws(() => new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + "testcollection", + 0, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }))); + Assert.Throws(() => new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + "testcollection", + -5, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }))); } #region InvokedAsync Tests @@ -113,13 +131,17 @@ public async Task InvokedAsync_UpsertsMessages_ToCollectionAsync() UserId = "user1" }; - var provider = new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, TestCollectionName, 1, storeScope); + var provider = new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + TestCollectionName, + 1, + _ => new ChatHistoryMemoryProvider.State(storeScope)); var requestMsgWithValues = new ChatMessage(ChatRole.User, "request text") { MessageId = "req-1", AuthorName = "user1", CreatedAt = new DateTimeOffset(new DateTime(2000, 1, 1), TimeSpan.Zero) }; var requestMsgWithNulls = new ChatMessage(ChatRole.User, "request text nulls"); var responseMsg = new ChatMessage(ChatRole.Assistant, "response text") { MessageId = "resp-1", AuthorName = "assistant" }; - var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsgWithValues, requestMsgWithNulls]) + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, new TestAgentSession(), [requestMsgWithValues, requestMsgWithNulls]) { ResponseMessages = [responseMsg] }; @@ -175,9 +197,9 @@ public async Task InvokedAsync_DoesNotUpsertMessages_WhenInvokeFailedAsync() this._vectorStoreMock.Object, TestCollectionName, 1, - new ChatHistoryMemoryProviderScope() { UserId = "UID" }); + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" })); var requestMsg = new ChatMessage(ChatRole.User, "request text") { MessageId = "req-1" }; - var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg]) + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, new TestAgentSession(), [requestMsg]) { InvokeException = new InvalidOperationException("Invoke failed") }; @@ -203,10 +225,10 @@ public async Task InvokedAsync_DoesNotThrow_WhenUpsertThrowsAsync() this._vectorStoreMock.Object, TestCollectionName, 1, - new ChatHistoryMemoryProviderScope() { UserId = "UID" }, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }), loggerFactory: this._loggerFactoryMock.Object); var requestMsg = new ChatMessage(ChatRole.User, "request text") { MessageId = "req-1" }; - var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg]); + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, new TestAgentSession(), [requestMsg]); // Act await provider.InvokedAsync(invokedContext, CancellationToken.None); @@ -252,12 +274,12 @@ public async Task InvokedAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsyn this._vectorStoreMock.Object, TestCollectionName, 1, - new ChatHistoryMemoryProviderScope { UserId = "user1" }, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "user1" }), options: options, loggerFactory: this._loggerFactoryMock.Object); var requestMsg = new ChatMessage(ChatRole.User, "request text"); - var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg]); + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, new TestAgentSession(), [requestMsg]); // Act await provider.InvokedAsync(invokedContext, CancellationToken.None); @@ -326,11 +348,11 @@ public async Task InvokedAsync_SearchesVectorStoreAsync() this._vectorStoreMock.Object, TestCollectionName, 1, - new ChatHistoryMemoryProviderScope() { UserId = "UID" }, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }), options: providerOptions); var requestMsg = new ChatMessage(ChatRole.User, "requesting relevant history"); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [requestMsg]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [requestMsg]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -378,10 +400,15 @@ public async Task InvokedAsync_CreatesFilter_WhenSearchScopeProvidedAsync() }) .Returns(ToAsyncEnumerableAsync(new List>>())); - var provider = new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, TestCollectionName, 1, options: providerOptions, storageScope: searchScope, searchScope: searchScope); + var provider = new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + TestCollectionName, + 1, + _ => new ChatHistoryMemoryProvider.State(searchScope, searchScope), + options: providerOptions); var requestMsg = new ChatMessage(ChatRole.User, "requesting relevant history"); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [requestMsg]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [requestMsg]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -440,12 +467,11 @@ public async Task InvokingAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsy this._vectorStoreMock.Object, TestCollectionName, 1, - storageScope: scope, - searchScope: scope, + _ => new ChatHistoryMemoryProvider.State(scope, scope), options: options, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "requesting relevant history")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), [new ChatMessage(ChatRole.User, "requesting relevant history")]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -479,56 +505,6 @@ public async Task InvokingAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsy #endregion - #region Serialization Tests - - [Fact] - public void Serialize_Deserialize_RoundtripsScopes() - { - // Arrange - var storageScope = new ChatHistoryMemoryProviderScope - { - ApplicationId = "app", - AgentId = "agent", - SessionId = "session", - UserId = "user" - }; - - var searchScope = new ChatHistoryMemoryProviderScope - { - ApplicationId = "app2", - AgentId = "agent2", - SessionId = "session2", - UserId = "user2" - }; - - var provider = new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, TestCollectionName, 1, storageScope: storageScope, searchScope: searchScope); - - // Act - var stateElement = provider.Serialize(); - - using JsonDocument doc = JsonDocument.Parse(stateElement.GetRawText()); - var storage = doc.RootElement.GetProperty("storageScope"); - Assert.Equal("app", storage.GetProperty("applicationId").GetString()); - Assert.Equal("agent", storage.GetProperty("agentId").GetString()); - Assert.Equal("session", storage.GetProperty("sessionId").GetString()); - Assert.Equal("user", storage.GetProperty("userId").GetString()); - - var search = doc.RootElement.GetProperty("searchScope"); - Assert.Equal("app2", search.GetProperty("applicationId").GetString()); - Assert.Equal("agent2", search.GetProperty("agentId").GetString()); - Assert.Equal("session2", search.GetProperty("sessionId").GetString()); - Assert.Equal("user2", search.GetProperty("userId").GetString()); - - // Act - deserialize and serialize again - var provider2 = new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, TestCollectionName, 1, serializedState: stateElement); - var stateElement2 = provider2.Serialize(); - - // Assert - roundtrip the state - Assert.Equal(stateElement.GetRawText(), stateElement2.GetRawText()); - } - - #endregion - private static async IAsyncEnumerable ToAsyncEnumerableAsync(IEnumerable values) { await Task.Yield(); @@ -537,4 +513,16 @@ private static async IAsyncEnumerable ToAsyncEnumerableAsync(IEnumerable))] [JsonSerializable(typeof(ChatClientAgentSessionTests.Animal))] +[JsonSerializable(typeof(ChatClientAgentSession))] internal sealed partial class TestJsonSerializerContext : JsonSerializerContext; diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs index a21eda21c4..c0c1c92a3c 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs @@ -161,7 +161,7 @@ protected override async IAsyncEnumerable RunCoreStreamingA } } - private sealed class DoubleEchoAgentSession() : InMemoryAgentSession(); + private sealed class DoubleEchoAgentSession() : AgentSession(); [Fact] public async Task BuildConcurrent_AgentsRunInParallelAsync() diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessExecutionTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessExecutionTests.cs index f12ffc6988..922f7a300f 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessExecutionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessExecutionTests.cs @@ -195,5 +195,5 @@ protected override async IAsyncEnumerable RunCoreStreamingA /// /// Simple session implementation for SimpleTestAgent. /// - private sealed class SimpleTestAgentSession : InMemoryAgentSession; + private sealed class SimpleTestAgentSession : AgentSession; } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RoleCheckAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RoleCheckAgent.cs index 48fc432eeb..b1b149de23 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RoleCheckAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RoleCheckAgent.cs @@ -46,5 +46,5 @@ protected override async IAsyncEnumerable RunCoreStreamingA }; } - private sealed class RoleCheckAgentSession : InMemoryAgentSession; + private sealed class RoleCheckAgentSession : AgentSession; } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs index e4c905f814..1d9bc40516 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs @@ -90,4 +90,4 @@ protected override async IAsyncEnumerable RunCoreStreamingA } } -internal sealed class HelloAgentSession() : InMemoryAgentSession(); +internal sealed class HelloAgentSession() : AgentSession(); diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs index 088c862efa..4eda20b956 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs @@ -5,6 +5,7 @@ using System.Linq; using System.Runtime.CompilerServices; using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -16,6 +17,8 @@ internal class TestEchoAgent(string? id = null, string? name = null, string? pre protected override string? IdCore => id; public override string? Name => name ?? base.Name; + public InMemoryChatHistoryProvider ChatHistoryProvider { get; } = new(); + protected override async ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) { return serializedState.Deserialize(jsonSerializerOptions) ?? await this.CreateSessionAsync(cancellationToken); @@ -28,15 +31,15 @@ protected override JsonElement SerializeSessionCore(AgentSession session, JsonSe throw new InvalidOperationException("The provided session is not compatible with the agent. Only sessions created by the agent can be serialized."); } - return typedSession.Serialize(jsonSerializerOptions); + return JsonSerializer.SerializeToElement(typedSession, jsonSerializerOptions); } protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => new(new EchoAgentSession()); - private static ChatMessage UpdateSession(ChatMessage message, InMemoryAgentSession? session = null) + private ChatMessage UpdateSession(ChatMessage message, AgentSession? session = null) { - session?.ChatHistoryProvider.Add(message); + this.ChatHistoryProvider.GetMessages(session).Add(message); return message; } @@ -45,7 +48,7 @@ private IEnumerable EchoMessages(IEnumerable messages, { foreach (ChatMessage message in messages) { - UpdateSession(message, session as InMemoryAgentSession); + this.UpdateSession(message, session); } IEnumerable echoMessages @@ -53,14 +56,14 @@ IEnumerable echoMessages where message.Role == ChatRole.User && !string.IsNullOrEmpty(message.Text) select - UpdateSession(new ChatMessage(ChatRole.Assistant, $"{prefix}{message.Text}") + this.UpdateSession(new ChatMessage(ChatRole.Assistant, $"{prefix}{message.Text}") { AuthorName = this.Name ?? this.Id, CreatedAt = DateTimeOffset.Now, MessageId = Guid.NewGuid().ToString("N") - }, session as InMemoryAgentSession); + }, session); - return echoMessages.Concat(this.GetEpilogueMessages(options).Select(m => UpdateSession(m, session as InMemoryAgentSession))); + return echoMessages.Concat(this.GetEpilogueMessages(options).Select(m => this.UpdateSession(m, session))); } protected virtual IEnumerable GetEpilogueMessages(AgentRunOptions? options = null) @@ -99,11 +102,11 @@ protected override async IAsyncEnumerable RunCoreStreamingA } } - private sealed class EchoAgentSession : InMemoryAgentSession + private sealed class EchoAgentSession : AgentSession { - internal new JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - return base.Serialize(jsonSerializerOptions); - } + internal EchoAgentSession() { } + + [JsonConstructor] + internal EchoAgentSession(AgentSessionStateBag stateBag) : base(stateBag) { } } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestReplayAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestReplayAgent.cs index 032ba9001c..1721cdd4c4 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestReplayAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestReplayAgent.cs @@ -104,5 +104,5 @@ protected override async IAsyncEnumerable RunCoreStreamingA return candidateMessages; } - private sealed class ReplayAgentSession() : InMemoryAgentSession(); + private sealed class ReplayAgentSession() : AgentSession(); } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRequestAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRequestAgent.cs index 63cd8dd6f0..5533d97091 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRequestAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRequestAgent.cs @@ -330,7 +330,7 @@ static TRequest AssertAndExtractRequestContent(ExternalRequest request } } - private sealed class TestRequestAgentSession : InMemoryAgentSession + private sealed class TestRequestAgentSession : AgentSession where TRequest : AIContent where TResponse : AIContent { @@ -343,19 +343,13 @@ public TestRequestAgentSession() public HashSet ServicedRequests { get; } = new(); public HashSet PairedRequests { get; } = new(); - private static JsonElement DeserializeAndExtractState(JsonElement serializedState, - out TestRequestAgentSessionState state, - JsonSerializerOptions? jsonSerializerOptions = null) + public TestRequestAgentSession(JsonElement element, JsonSerializerOptions? jsonSerializerOptions = null) { - state = JsonSerializer.Deserialize(serializedState, jsonSerializerOptions) + var state = JsonSerializer.Deserialize(element, jsonSerializerOptions) ?? throw new ArgumentException("Unable to deserialize session state."); - return state.SessionState; - } + this.StateBag = AgentSessionStateBag.Deserialize(state.SessionState); - public TestRequestAgentSession(JsonElement element, JsonSerializerOptions? jsonSerializerOptions = null) - : base(DeserializeAndExtractState(element, out TestRequestAgentSessionState state, jsonSerializerOptions)) - { this.UnservicedRequests = state.UnservicedRequests.ToDictionary( keySelector: item => item.Key, elementSelector: item => item.Value.As()!); @@ -364,9 +358,9 @@ public TestRequestAgentSession(JsonElement element, JsonSerializerOptions? jsonS this.PairedRequests = state.PairedRequests; } - protected override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) + internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { - JsonElement sessionState = base.Serialize(jsonSerializerOptions); + JsonElement sessionState = this.StateBag.Serialize(); Dictionary portableUnservicedRequests = this.UnservicedRequests.ToDictionary( diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowHostSmokeTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowHostSmokeTests.cs index ac6485131d..e2f01d675f 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowHostSmokeTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowHostSmokeTests.cs @@ -32,18 +32,16 @@ public class WorkflowHostSmokeTests { private sealed class AlwaysFailsAIAgent(bool failByThrowing) : AIAgent { - private sealed class Session : InMemoryAgentSession + private sealed class Session : AgentSession { public Session() { } - public Session(JsonElement serializedSession, JsonSerializerOptions? jsonSerializerOptions = null) - : base(serializedSession, jsonSerializerOptions) - { } + public Session(AgentSessionStateBag stateBag) : base(stateBag) { } } protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) { - return new(new Session(serializedState, jsonSerializerOptions)); + return new(serializedState.Deserialize(jsonSerializerOptions)!); } protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) diff --git a/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs b/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs index 304df28fba..96a9a17ae8 100644 --- a/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs +++ b/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs @@ -30,14 +30,14 @@ public OpenAIChatCompletionFixture(bool useReasoningChatModel) public async Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) { - var typedSession = (ChatClientAgentSession)session; + var chatHistoryProvider = agent.GetService(); - if (typedSession.ChatHistoryProvider is null) + if (chatHistoryProvider is null) { return []; } - return (await typedSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); + return (await chatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); } public Task CreateChatClientAgentAsync( diff --git a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs index 719db6a0b0..e36f8990f6 100644 --- a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs +++ b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs @@ -50,12 +50,14 @@ public async Task> GetChatHistoryAsync(AIAgent agent, AgentSes return [.. previousMessages, responseMessage]; } - if (typedSession.ChatHistoryProvider is null) + var chatHistoryProvider = agent.GetService(); + + if (chatHistoryProvider is null) { return []; } - return (await typedSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); + return (await chatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); } private static ChatMessage ConvertToChatMessage(ResponseItem item)