diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs index 1c9c9a3964..786fbea36b 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs @@ -6,6 +6,7 @@ using System.Runtime.CompilerServices; using System.Text.Json; +using System.Text.Json.Serialization; using Microsoft.Agents.AI; using Microsoft.Extensions.AI; using SampleApp; @@ -28,6 +29,8 @@ internal sealed class UpperCaseParrotAgent : AIAgent { public override string? Name => "UpperCaseParrotAgent"; + public readonly ChatHistoryProvider ChatHistoryProvider = new InMemoryChatHistoryProvider(); + protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => new(new CustomAgentSession()); @@ -38,11 +41,11 @@ protected override ValueTask SerializeSessionCoreAsync(AgentSession throw new ArgumentException($"The provided session is not of type {nameof(CustomAgentSession)}.", nameof(session)); } - return new(typedSession.Serialize(jsonSerializerOptions)); + return new(JsonSerializer.SerializeToElement(typedSession, jsonSerializerOptions)); } protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) - => new(new CustomAgentSession(serializedState, jsonSerializerOptions)); + => new(serializedState.Deserialize(jsonSerializerOptions)!); protected override async Task RunCoreAsync(IEnumerable messages, AgentSession? session = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) { @@ -56,17 +59,14 @@ protected override async Task RunCoreAsync(IEnumerable responseMessages = CloneAndToUpperCase(messages, this.Name).ToList(); // Notify the session of the input and output messages. - var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, messages) - { - ResponseMessages = responseMessages - }; - await typedSession.ChatHistoryProvider.InvokedAsync(invokedContext, cancellationToken); + var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, userAndChatHistoryMessages, responseMessages); + await this.ChatHistoryProvider.InvokedAsync(invokedContext, cancellationToken); return new AgentResponse { @@ -88,17 +88,14 @@ protected override async IAsyncEnumerable RunCoreStreamingA // Get existing messages from the store var invokingContext = new ChatHistoryProvider.InvokingContext(this, session, messages); - var storeMessages = await typedSession.ChatHistoryProvider.InvokingAsync(invokingContext, cancellationToken); + var userAndChatHistoryMessages = await this.ChatHistoryProvider.InvokingAsync(invokingContext, cancellationToken); // Clone the input messages and turn them into response messages with upper case text. List responseMessages = CloneAndToUpperCase(messages, this.Name).ToList(); // Notify the session of the input and output messages. - var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, messages) - { - ResponseMessages = responseMessages - }; - await typedSession.ChatHistoryProvider.InvokedAsync(invokedContext, cancellationToken); + var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, userAndChatHistoryMessages, responseMessages); + await this.ChatHistoryProvider.InvokedAsync(invokedContext, cancellationToken); foreach (var message in responseMessages) { @@ -140,15 +137,16 @@ private static IEnumerable CloneAndToUpperCase(IEnumerable /// A session type for our custom agent that only supports in memory storage of messages. /// - internal sealed class CustomAgentSession : InMemoryAgentSession + internal sealed class CustomAgentSession : AgentSession { - internal CustomAgentSession() { } - - internal CustomAgentSession(JsonElement serializedSessionState, JsonSerializerOptions? jsonSerializerOptions = null) - : base(serializedSessionState, jsonSerializerOptions) { } + internal CustomAgentSession() + { + } - internal new JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => base.Serialize(jsonSerializerOptions); + [JsonConstructor] + internal CustomAgentSession(AgentSessionStateBag stateBag) : base(stateBag) + { + } } } } diff --git a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step01_ChatHistoryMemory/Program.cs b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step01_ChatHistoryMemory/Program.cs index 4e2065e0eb..ff4628ef7a 100644 --- a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step01_ChatHistoryMemory/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step01_ChatHistoryMemory/Program.cs @@ -37,16 +37,21 @@ { ChatOptions = new() { Instructions = "You are good at telling jokes." }, Name = "Joker", - AIContextProviderFactory = (ctx, ct) => new ValueTask(new ChatHistoryMemoryProvider( + AIContextProviders = [new ChatHistoryMemoryProvider( vectorStore, collectionName: "chathistory", vectorDimensions: 3072, - // Configure the scope values under which chat messages will be stored. - // In this case, we are using a fixed user ID and a unique session ID for each new session. - storageScope: new() { UserId = "UID1", SessionId = Guid.NewGuid().ToString() }, - // Configure the scope which would be used to search for relevant prior messages. - // In this case, we are searching for any messages for the user across all sessions. - searchScope: new() { UserId = "UID1" })) + // Callback to configure the initial state of the ChatHistoryMemoryProvider. + // The ChatHistoryMemoryProvider stores its state in the AgentSession and this callback + // will be called whenever the ChatHistoryMemoryProvider cannot find existing state in the session, + // typically the first time it is used with a new session. + session => new ChatHistoryMemoryProvider.State( + // Configure the scope values under which chat messages will be stored. + // In this case, we are using a fixed user ID and a unique session ID for each new session. + storageScope: new() { UserId = "UID1", SessionId = Guid.NewGuid().ToString() }, + // Configure the scope which would be used to search for relevant prior messages. + // In this case, we are searching for any messages for the user across all sessions. + searchScope: new() { UserId = "UID1" }))] }); // Start a new session for the agent conversation. diff --git a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs index a81c496b5e..d1d6a60f9d 100644 --- a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step02_MemoryUsingMem0/Program.cs @@ -34,20 +34,21 @@ .AsAIAgent(new ChatClientAgentOptions() { ChatOptions = new() { Instructions = "You are a friendly travel assistant. Use known memories about the user when responding, and do not invent details." }, - AIContextProviderFactory = (ctx, ct) => new ValueTask(ctx.SerializedState.ValueKind is not JsonValueKind.Null and not JsonValueKind.Undefined - // If each session should have its own Mem0 scope, you can create a new id per session here: - // ? new Mem0Provider(mem0HttpClient, new Mem0ProviderScope() { ThreadId = Guid.NewGuid().ToString() }) - // In this case we are storing memories scoped by application and user instead so that memories are retained across threads. - ? new Mem0Provider(mem0HttpClient, new Mem0ProviderScope() { ApplicationId = "getting-started-agents", UserId = "sample-user" }) - // For cases where we are restoring from serialized state: - : new Mem0Provider(mem0HttpClient, ctx.SerializedState, ctx.JsonSerializerOptions)) + // The stateInitializer can be used to customize the Mem0 scope per session and it will be called each time a session + // is encountered by the Mem0Provider that does not already have Mem0Provider state stored on the session. + // If each session should have its own Mem0 scope, you can create a new id per session via the stateInitializer, e.g.: + // new Mem0Provider(mem0HttpClient, stateInitializer: _ => new(new Mem0ProviderScope() { ThreadId = Guid.NewGuid().ToString() })) + // In our case we are storing memories scoped by application and user instead so that memories are retained across threads. + AIContextProviders = [new Mem0Provider(mem0HttpClient, stateInitializer: _ => new(new Mem0ProviderScope() { ApplicationId = "getting-started-agents", UserId = "sample-user" }))] }); AgentSession session = await agent.CreateSessionAsync(); // Clear any existing memories for this scope to demonstrate fresh behavior. -Mem0Provider mem0Provider = session.GetService()!; -await mem0Provider.ClearStoredMemoriesAsync(); +// Note that the ClearStoredMemoriesAsync method will clear memories +// using the scope stored in the session, or provided via the stateInitializer. +Mem0Provider mem0Provider = agent.GetService()!; +await mem0Provider.ClearStoredMemoriesAsync(session); Console.WriteLine(await agent.RunAsync("Hi there! My name is Taylor and I'm planning a hiking trip to Patagonia in November.", session)); Console.WriteLine(await agent.RunAsync("I'm travelling with my sister and we love finding scenic viewpoints.", session)); diff --git a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs index 4a736674fc..bac72d9b31 100644 --- a/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithMemory/AgentWithMemory_Step03_CustomMemory/Program.cs @@ -36,7 +36,7 @@ AIAgent agent = chatClient.AsAIAgent(new ChatClientAgentOptions() { ChatOptions = new() { Instructions = "You are a friendly assistant. Always address the user by their name." }, - AIContextProviderFactory = (ctx, ct) => new ValueTask(new UserInfoMemory(chatClient.AsIChatClient(), ctx.SerializedState, ctx.JsonSerializerOptions)) + AIContextProviders = [new UserInfoMemory(chatClient.AsIChatClient())] }); // Create a new session for the conversation. @@ -58,10 +58,10 @@ var deserializedSession = await agent.DeserializeSessionAsync(sesionElement); Console.WriteLine(await agent.RunAsync("What is my name and age?", deserializedSession)); -Console.WriteLine("\n>> Read memories from memory component\n"); +Console.WriteLine("\n>> Read memories using memory component\n"); -// It's possible to access the memory component via the session's GetService method. -var userInfo = deserializedSession.GetService()?.UserInfo; +// It's possible to access the memory component via the agent's GetService method. +var userInfo = agent.GetService()?.GetUserInfo(deserializedSession); // Output the user info that was captured by the memory component. Console.WriteLine($"MEMORY - User Name: {userInfo?.UserName}"); @@ -69,12 +69,12 @@ Console.WriteLine("\n>> Use new session with previously created memories\n"); -// It is also possible to set the memories in a memory component on an individual session. +// It is also possible to set the memories using a memory component on an individual session. // This is useful if we want to start a new session, but have it share the same memories as a previous session. var newSession = await agent.CreateSessionAsync(); -if (userInfo is not null && newSession.GetService() is UserInfoMemory newSessionMemory) +if (userInfo is not null && agent.GetService() is UserInfoMemory newSessionMemory) { - newSessionMemory.UserInfo = userInfo; + newSessionMemory.SetUserInfo(newSession, userInfo); } // Invoke the agent and output the text result. @@ -89,28 +89,27 @@ namespace SampleApp internal sealed class UserInfoMemory : AIContextProvider { private readonly IChatClient _chatClient; + private readonly Func _stateInitializer; - public UserInfoMemory(IChatClient chatClient, UserInfo? userInfo = null) + public UserInfoMemory(IChatClient chatClient, Func? stateInitializer = null) { this._chatClient = chatClient; - this.UserInfo = userInfo ?? new UserInfo(); + this._stateInitializer = stateInitializer ?? (_ => new UserInfo()); } - public UserInfoMemory(IChatClient chatClient, JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null) - { - this._chatClient = chatClient; - - this.UserInfo = serializedState.ValueKind == JsonValueKind.Object ? - serializedState.Deserialize(jsonSerializerOptions)! : - new UserInfo(); - } + public UserInfo GetUserInfo(AgentSession session) + => session.StateBag.GetValue(nameof(UserInfoMemory)) ?? new UserInfo(); - public UserInfo UserInfo { get; set; } + public void SetUserInfo(AgentSession session, UserInfo userInfo) + => session.StateBag.SetValue(nameof(UserInfoMemory), userInfo); protected override async ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { + var userInfo = context.Session?.StateBag.GetValue(nameof(UserInfoMemory)) + ?? this._stateInitializer.Invoke(context.Session); + // Try and extract the user name and age from the message if we don't have it already and it's a user message. - if ((this.UserInfo.UserName is null || this.UserInfo.UserAge is null) && context.RequestMessages.Any(x => x.Role == ChatRole.User)) + if ((userInfo.UserName is null || userInfo.UserAge is null) && context.RequestMessages.Any(x => x.Role == ChatRole.User)) { var result = await this._chatClient.GetResponseAsync( context.RequestMessages, @@ -120,36 +119,43 @@ protected override async ValueTask InvokedCoreAsync(InvokedContext context, Canc }, cancellationToken: cancellationToken); - this.UserInfo.UserName ??= result.Result.UserName; - this.UserInfo.UserAge ??= result.Result.UserAge; + userInfo.UserName ??= result.Result.UserName; + userInfo.UserAge ??= result.Result.UserAge; } + + context.Session?.StateBag.SetValue(nameof(UserInfoMemory), userInfo); } protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { + var inputContext = context.AIContext; + var userInfo = context.Session?.StateBag.GetValue(nameof(UserInfoMemory)) + ?? this._stateInitializer.Invoke(context.Session); + StringBuilder instructions = new(); + if (!string.IsNullOrEmpty(inputContext.Instructions)) + { + instructions.AppendLine(inputContext.Instructions); + } // If we don't already know the user's name and age, add instructions to ask for them, otherwise just provide what we have to the context. instructions .AppendLine( - this.UserInfo.UserName is null ? + userInfo.UserName is null ? "Ask the user for their name and politely decline to answer any questions until they provide it." : - $"The user's name is {this.UserInfo.UserName}.") + $"The user's name is {userInfo.UserName}.") .AppendLine( - this.UserInfo.UserAge is null ? + userInfo.UserAge is null ? "Ask the user for their age and politely decline to answer any questions until they provide it." : - $"The user's age is {this.UserInfo.UserAge}."); + $"The user's age is {userInfo.UserAge}."); return new ValueTask(new AIContext { - Instructions = instructions.ToString() + Instructions = instructions.ToString(), + Messages = inputContext.Messages, + Tools = inputContext.Tools }); } - - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - return JsonSerializer.SerializeToElement(this.UserInfo, jsonSerializerOptions); - } } internal sealed class UserInfo diff --git a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs index 516585f7dc..c04601d940 100644 --- a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step01_BasicTextRAG/Program.cs @@ -65,12 +65,16 @@ .AsAIAgent(new ChatClientAgentOptions { ChatOptions = new() { Instructions = "You are a helpful support specialist for Contoso Outdoors. Answer questions using the provided context and cite the source document when available." }, - AIContextProviderFactory = (ctx, ct) => new ValueTask(new TextSearchProvider(SearchAdapter, ctx.SerializedState, ctx.JsonSerializerOptions, textSearchOptions)), - // Since we are using ChatCompletion which stores chat history locally, we can also add a message removal policy + AIContextProviders = [new TextSearchProvider(SearchAdapter, textSearchOptions)], + // Since we are using ChatCompletion which stores chat history locally, we can also add a message filter // that removes messages produced by the TextSearchProvider before they are added to the chat history, so that // we don't bloat chat history with all the search result messages. - ChatHistoryProviderFactory = (ctx, ct) => new ValueTask(new InMemoryChatHistoryProvider(ctx.SerializedState, ctx.JsonSerializerOptions) - .WithAIContextProviderMessageRemoval()), + // By default the chat history provider will store all messages, except for those that came from chat history in the first place. + // We also want to maintain that exclusion here. + ChatHistoryProvider = new InMemoryChatHistoryProvider(new InMemoryChatHistoryProviderOptions + { + StorageInputMessageFilter = messages => messages.Where(m => m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.AIContextProvider && m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.ChatHistory) + }), }); AgentSession session = await agent.CreateSessionAsync(); diff --git a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step02_CustomVectorStoreRAG/Program.cs b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step02_CustomVectorStoreRAG/Program.cs index 4120f2d604..f06fda6a5b 100644 --- a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step02_CustomVectorStoreRAG/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step02_CustomVectorStoreRAG/Program.cs @@ -74,7 +74,7 @@ .AsAIAgent(new ChatClientAgentOptions { ChatOptions = new() { Instructions = "You are a helpful support specialist for the Microsoft Agent Framework. Answer questions using the provided context and cite the source document when available. Keep responses brief." }, - AIContextProviderFactory = (ctx, ct) => new ValueTask(new TextSearchProvider(SearchAdapter, ctx.SerializedState, ctx.JsonSerializerOptions, textSearchOptions)) + AIContextProviders = [new TextSearchProvider(SearchAdapter, textSearchOptions)] }); AgentSession session = await agent.CreateSessionAsync(); diff --git a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step03_CustomRAGDataSource/Program.cs b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step03_CustomRAGDataSource/Program.cs index 06da840df4..d4e3a40756 100644 --- a/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step03_CustomRAGDataSource/Program.cs +++ b/dotnet/samples/GettingStarted/AgentWithRAG/AgentWithRAG_Step03_CustomRAGDataSource/Program.cs @@ -32,7 +32,7 @@ .AsAIAgent(new ChatClientAgentOptions { ChatOptions = new() { Instructions = "You are a helpful support specialist for Contoso Outdoors. Answer questions using the provided context and cite the source document when available." }, - AIContextProviderFactory = (ctx, ct) => new ValueTask(new TextSearchProvider(MockSearchAsync, ctx.SerializedState, ctx.JsonSerializerOptions, textSearchOptions)) + AIContextProviders = [new TextSearchProvider(MockSearchAsync, textSearchOptions)] }); AgentSession session = await agent.CreateSessionAsync(); diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs index 0eaf3d8bc5..cc4ca1a6ed 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step07_3rdPartyChatHistoryStorage/Program.cs @@ -3,7 +3,7 @@ #pragma warning disable CA1869 // Cache and reuse 'JsonSerializerOptions' instances // This sample shows how to create and use a simple AI agent with custom ChatHistoryProvider that stores chat history in a custom storage location. -// The state of the custom ChatHistoryProvider (SessionDbKey) is stored with the agent session, so that when the session is resumed later, +// The state of the custom ChatHistoryProvider (SessionDbKey) is stored in the AgentSession's StateBag, so that when the session is resumed later, // the chat history can be retrieved from the custom storage location. using System.Text.Json; @@ -36,11 +36,8 @@ { ChatOptions = new() { Instructions = "You are good at telling jokes." }, Name = "Joker", - ChatHistoryProviderFactory = (ctx, ct) => new ValueTask( - // Create a new ChatHistoryProvider for this agent that stores chat history in a vector store. - // Each session must get its own copy of the VectorChatHistoryProvider, since the provider - // also contains the id that the chat history is stored under. - new VectorChatHistoryProvider(vectorStore, ctx.SerializedState, ctx.JsonSerializerOptions)) + // Create a new ChatHistoryProvider for this agent that stores chat history in a vector store. + ChatHistoryProvider = new VectorChatHistoryProvider(vectorStore) }); // Start a new session for the agent conversation. @@ -66,48 +63,75 @@ // Run the agent with the session that stores chat history in the vector store a second time. Console.WriteLine(await agent.RunAsync("Now tell the same joke in the voice of a pirate, and add some emojis to the joke.", resumedSession)); -// We can access the VectorChatHistoryProvider via the session's GetService method if we need to read the key under which chat history is stored. -var chatHistoryProvider = resumedSession.GetService()!; -Console.WriteLine($"\nSession is stored in vector store under key: {chatHistoryProvider.SessionDbKey}"); +// We can access the VectorChatHistoryProvider via the agent's GetService method +// if we need to read the key under which chat history is stored. The key is stored +// in the session state, and therefore we need to provide the session when reading it. +var chatHistoryProvider = agent.GetService()!; +Console.WriteLine($"\nSession is stored in vector store under key: {chatHistoryProvider.GetSessionDbKey(resumedSession)}"); namespace SampleApp { /// /// A sample implementation of that stores chat history in a vector store. + /// State (the session DB key) is stored in the so it roundtrips + /// automatically with session serialization. /// internal sealed class VectorChatHistoryProvider : ChatHistoryProvider { private readonly VectorStore _vectorStore; + private readonly Func _stateInitializer; + private readonly string _stateKey; - public VectorChatHistoryProvider(VectorStore vectorStore, JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null) + /// + public override string StateKey => this._stateKey; + + public VectorChatHistoryProvider( + VectorStore vectorStore, + Func? stateInitializer = null, + string? stateKey = null) { this._vectorStore = vectorStore ?? throw new ArgumentNullException(nameof(vectorStore)); + this._stateInitializer = stateInitializer ?? (_ => new State(Guid.NewGuid().ToString("N"))); + this._stateKey = stateKey ?? base.StateKey; + } + + public string GetSessionDbKey(AgentSession session) + => this.GetOrInitializeState(session).SessionDbKey; + + private State GetOrInitializeState(AgentSession? session) + { + if (session?.StateBag.TryGetValue(this._stateKey, out var state) is true && state is not null) + { + return state; + } - if (serializedState.ValueKind is JsonValueKind.String) + state = this._stateInitializer(session); + if (session is not null) { - // Here we can deserialize the session id so that we can access the same messages as before the suspension. - this.SessionDbKey = serializedState.Deserialize(); + session.StateBag.SetValue(this._stateKey, state); } - } - public string? SessionDbKey { get; private set; } + return state; + } protected override async ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { + var state = this.GetOrInitializeState(context.Session); var collection = this._vectorStore.GetCollection("ChatHistory"); await collection.EnsureCollectionExistsAsync(cancellationToken); var records = await collection .GetAsync( - x => x.SessionId == this.SessionDbKey, 10, + x => x.SessionId == state.SessionDbKey, 10, new() { OrderBy = x => x.Descending(y => y.Timestamp) }, cancellationToken) .ToListAsync(cancellationToken); - var messages = records.ConvertAll(x => JsonSerializer.Deserialize(x.SerializedMessage!)!) -; + var messages = records.ConvertAll(x => JsonSerializer.Deserialize(x.SerializedMessage!)!); messages.Reverse(); - return messages; + return messages + .Select(message => message.WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, this.GetType().FullName!)) + .Concat(context.RequestMessages); } protected override async ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) @@ -118,28 +142,39 @@ protected override async ValueTask InvokedCoreAsync(InvokedContext context, Canc return; } - this.SessionDbKey ??= Guid.NewGuid().ToString("N"); + var state = this.GetOrInitializeState(context.Session); var collection = this._vectorStore.GetCollection("ChatHistory"); await collection.EnsureCollectionExistsAsync(cancellationToken); - // Add both request and response messages to the store + // Add both request and response messages to the store, excluding messages that came from chat history. // Optionally messages produced by the AIContextProvider can also be persisted (not shown). - var allNewMessages = context.RequestMessages.Concat(context.ResponseMessages ?? []); + var allNewMessages = context.RequestMessages + .Where(m => m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.ChatHistory) + .Concat(context.ResponseMessages ?? []); await collection.UpsertAsync(allNewMessages.Select(x => new ChatHistoryItem() { - Key = this.SessionDbKey + x.MessageId, + Key = state.SessionDbKey + x.MessageId, Timestamp = DateTimeOffset.UtcNow, - SessionId = this.SessionDbKey, + SessionId = state.SessionDbKey, SerializedMessage = JsonSerializer.Serialize(x), MessageText = x.Text }), cancellationToken); } - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) => - // We have to serialize the session id, so that on deserialization we can retrieve the messages using the same session id. - JsonSerializer.SerializeToElement(this.SessionDbKey); + /// + /// Represents the per-session state stored in the . + /// + public sealed class State + { + public State(string sessionDbKey) + { + this.SessionDbKey = sessionDbKey ?? throw new ArgumentNullException(nameof(sessionDbKey)); + } + + public string SessionDbKey { get; } + } /// /// The data structure used to store chat history items in the vector store. diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs index 77abb7898a..875f3375a4 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step16_ChatReduction/Program.cs @@ -27,7 +27,7 @@ { ChatOptions = new() { Instructions = "You are good at telling jokes." }, Name = "Joker", - ChatHistoryProviderFactory = (ctx, ct) => new ValueTask(new InMemoryChatHistoryProvider(new MessageCountingChatReducer(2), ctx.SerializedState, ctx.JsonSerializerOptions)) + ChatHistoryProvider = new InMemoryChatHistoryProvider(new() { ChatReducer = new MessageCountingChatReducer(2) }) }); AgentSession session = await agent.CreateSessionAsync(); @@ -36,7 +36,10 @@ Console.WriteLine(await agent.RunAsync("Tell me a joke about a pirate.", session)); // Get the chat history to see how many messages are stored. -IList? chatHistory = session.GetService>(); +// We can use the ChatHistoryProvider, that is also used by the agent, to read the +// chat history from the session state, and see how the reducer is affecting the stored messages. +var provider = agent.GetService(); +List? chatHistory = provider?.GetMessages(session); Console.WriteLine($"\nChat history has {chatHistory?.Count} messages.\n"); // Invoke the agent a few more times. diff --git a/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs b/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs index b04cf836bb..ba56d94e93 100644 --- a/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs +++ b/dotnet/samples/GettingStarted/Agents/Agent_Step20_AdditionalAIContext/Program.cs @@ -1,13 +1,12 @@ // Copyright (c) Microsoft. All rights reserved. -// This sample shows how to inject additional AI context into a ChatClientAgent using a custom AIContextProvider component that is attached to the agent. -// The sample also shows how to combine the results from multiple providers into a single class, in order to attach multiple of these to an agent. +// This sample shows how to inject additional AI context into a ChatClientAgent using custom AIContextProvider components that are attached to the agent. +// Multiple providers can be attached to an agent, and they will be called in sequence, each receiving the accumulated context from the previous one. // This mechanism can be used for various purposes, such as injecting RAG search results or memories into the agent's context. // Also note that Agent Framework already provides built-in AIContextProviders for many of these scenarios. #pragma warning disable CA1869 // Cache and reuse 'JsonSerializerOptions' instances -using System.ComponentModel; using System.Text; using System.Text.Json; using Azure.AI.OpenAI; @@ -48,16 +47,20 @@ You are a helpful personal assistant. You manage a TODO list for the user. When the user has completed one of the tasks it can be removed from the TODO list. Only provide the list of TODO items if asked. You remind users of upcoming calendar events when the user interacts with you. """ }, - ChatHistoryProviderFactory = (ctx, ct) => new ValueTask(new InMemoryChatHistoryProvider() - // Use WithAIContextProviderMessageRemoval, so that we don't store the messages from the AI context provider in the chat history. + ChatHistoryProvider = new InMemoryChatHistoryProvider(new InMemoryChatHistoryProviderOptions + { + // Use StorageInputMessageFilter to provide a custom filter for messages stored in chat history. + // By default the chat history provider will store all messages, except for those that came from chat history in the first place. + // In this case, we want to also exclude messages that came from AI context providers. // You may want to store these messages, depending on their content and your requirements. - .WithAIContextProviderMessageRemoval()), - // Add an AI context provider that maintains a todo list for the agent and one that provides upcoming calendar entries. - // Wrap these in an AI context provider that aggregates the other two. - AIContextProviderFactory = (ctx, ct) => new ValueTask(new AggregatingAIContextProvider([ - AggregatingAIContextProvider.CreateFactory((jsonElement, jsonSerializerOptions) => new TodoListAIContextProvider(jsonElement, jsonSerializerOptions)), - AggregatingAIContextProvider.CreateFactory((_, _) => new CalendarSearchAIContextProvider(loadNextThreeCalendarEvents)) - ], ctx.SerializedState, ctx.JsonSerializerOptions)), + StorageInputMessageFilter = messages => messages.Where(m => m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.AIContextProvider && m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.ChatHistory) + }), + // Add multiple AI context providers: one that maintains a todo list and one that provides upcoming calendar entries. + // The agent will call each provider in sequence, accumulating context from each. + AIContextProviders = [ + new TodoListAIContextProvider(), + new CalendarSearchAIContextProvider(loadNextThreeCalendarEvents) + ], }); // Invoke the agent and output the text result. @@ -83,51 +86,67 @@ namespace SampleApp /// internal sealed class TodoListAIContextProvider : AIContextProvider { - private readonly List _todoItems = new(); + private static List GetTodoItems(AgentSession? session) + => session?.StateBag.GetValue>(nameof(TodoListAIContextProvider)) ?? new List(); - public TodoListAIContextProvider(JsonElement jsonElement, JsonSerializerOptions? jsonSerializerOptions = null) - { - // Only try and restore the state if we got an array, since any other json would be invalid or undefined/null meaning - // it's the first time we are running. - if (jsonElement.ValueKind == JsonValueKind.Array) - { - this._todoItems = JsonSerializer.Deserialize>(jsonElement.GetRawText(), jsonSerializerOptions) ?? new List(); - } - } + private static void SetTodoItems(AgentSession? session, List items) + => session?.StateBag.SetValue(nameof(TodoListAIContextProvider), items); protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { + var inputContext = context.AIContext; + var todoItems = GetTodoItems(context.Session); + StringBuilder outputMessageBuilder = new(); outputMessageBuilder.AppendLine("Your todo list contains the following items:"); - if (this._todoItems.Count == 0) + if (todoItems.Count == 0) { outputMessageBuilder.AppendLine(" (no items)"); } else { - for (int i = 0; i < this._todoItems.Count; i++) + for (int i = 0; i < todoItems.Count; i++) { - outputMessageBuilder.AppendLine($"{i}. {this._todoItems[i]}"); + outputMessageBuilder.AppendLine($"{i}. {todoItems[i]}"); } } return new ValueTask(new AIContext { - Tools = [AIFunctionFactory.Create(this.AddTodoItem), AIFunctionFactory.Create(this.RemoveTodoItem)], - Messages = [new MEAI.ChatMessage(ChatRole.User, outputMessageBuilder.ToString())] + Instructions = inputContext.Instructions, + Tools = (inputContext.Tools ?? []).Concat(new AITool[] + { + AIFunctionFactory.Create((string item) => AddTodoItem(context.Session, item), "AddTodoItem", "Adds an item to the todo list."), + AIFunctionFactory.Create((int index) => RemoveTodoItem(context.Session, index), "RemoveTodoItem", "Removes an item from the todo list. Index is zero based.") + }), + Messages = + (inputContext.Messages ?? []) + .Concat( + [ + new MEAI.ChatMessage(ChatRole.User, outputMessageBuilder.ToString()).WithAgentRequestMessageSource(AgentRequestMessageSourceType.AIContextProvider, this.GetType().FullName!) + ]) }); } - [Description("Adds an item to the todo list. Index is zero based.")] - private void RemoveTodoItem(int index) => - this._todoItems.RemoveAt(index); + private static void RemoveTodoItem(AgentSession? session, int index) + { + var items = GetTodoItems(session); + items.RemoveAt(index); + SetTodoItems(session, items); + } - private void AddTodoItem(string item) => - this._todoItems.Add(string.IsNullOrWhiteSpace(item) ? throw new ArgumentException("Item must have a value") : item); + private static void AddTodoItem(AgentSession? session, string item) + { + if (string.IsNullOrWhiteSpace(item)) + { + throw new ArgumentException("Item must have a value"); + } - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) => - JsonSerializer.SerializeToElement(this._todoItems, jsonSerializerOptions); + var items = GetTodoItems(session); + items.Add(item); + SetTodoItems(session, items); + } } /// @@ -137,6 +156,7 @@ internal sealed class CalendarSearchAIContextProvider(Func> loadN { protected override async ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { + var inputContext = context.AIContext; var events = await loadNextThreeCalendarEvents(); StringBuilder outputMessageBuilder = new(); @@ -148,84 +168,16 @@ protected override async ValueTask InvokingCoreAsync(InvokingContext return new() { + Instructions = inputContext.Instructions, Messages = - [ - new MEAI.ChatMessage(ChatRole.User, outputMessageBuilder.ToString()), - ] - }; - } - } - - /// - /// An which aggregates multiple AI context providers into one. - /// Serialized state for the different providers are stored under their type name. - /// Tools and messages from all providers are combined, and instructions are concatenated. - /// - internal sealed class AggregatingAIContextProvider : AIContextProvider - { - private readonly List _providers = new(); - - public AggregatingAIContextProvider(ProviderFactory[] providerFactories, JsonElement jsonElement, JsonSerializerOptions? jsonSerializerOptions) - { - // We received a json object, so let's check if it has some previously serialized state that we can use. - if (jsonElement.ValueKind == JsonValueKind.Object) - { - this._providers = providerFactories - .Select(factory => factory.FactoryMethod(jsonElement.TryGetProperty(factory.ProviderType.Name, out var prop) ? prop : default, jsonSerializerOptions)) - .ToList(); - return; - } - - // We didn't receive any valid json, so we can just construct fresh providers. - this._providers = providerFactories - .Select(factory => factory.FactoryMethod(default, jsonSerializerOptions)) - .ToList(); - } - - protected override async ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) - { - // Invoke all the sub providers. - var tasks = this._providers.Select(provider => provider.InvokingAsync(context, cancellationToken).AsTask()); - var results = await Task.WhenAll(tasks); - - // Combine the results from each sub provider. - return new AIContext - { - Tools = results.SelectMany(r => r.Tools ?? []).ToList(), - Messages = results.SelectMany(r => r.Messages ?? []).ToList(), - Instructions = string.Join("\n", results.Select(r => r.Instructions).Where(s => !string.IsNullOrEmpty(s))) - }; - } - - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - Dictionary elements = new(); - foreach (var provider in this._providers) - { - JsonElement element = provider.Serialize(jsonSerializerOptions); - - // Don't try to store state for any providers that aren't producing any. - if (element.ValueKind != JsonValueKind.Undefined && element.ValueKind != JsonValueKind.Null) - { - elements[provider.GetType().Name] = element; - } - } - - return JsonSerializer.SerializeToElement(elements, jsonSerializerOptions); - } - - public static ProviderFactory CreateFactory(Func factoryMethod) - where TProviderType : AIContextProvider => new() - { - FactoryMethod = (jsonElement, jsonSerializerOptions) => factoryMethod(jsonElement, jsonSerializerOptions), - ProviderType = typeof(TProviderType) + (inputContext.Messages ?? []) + .Concat( + [ + new MEAI.ChatMessage(ChatRole.User, outputMessageBuilder.ToString()).WithAgentRequestMessageSource(AgentRequestMessageSourceType.AIContextProvider, this.GetType().FullName!) + ]) + .ToList(), + Tools = inputContext.Tools }; - - public readonly struct ProviderFactory - { - public Func FactoryMethod { get; init; } - - public Type ProviderType { get; init; } } } } diff --git a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs index aea99b4e3d..5601653e8f 100644 --- a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgent.cs @@ -80,7 +80,7 @@ protected override ValueTask SerializeSessionCoreAsync(AgentSession /// protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) - => new(new A2AAgentSession(serializedState, jsonSerializerOptions)); + => new(A2AAgentSession.Deserialize(serializedState, jsonSerializerOptions)); /// protected override async Task RunCoreAsync(IEnumerable messages, AgentSession? session = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) diff --git a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgentSession.cs b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgentSession.cs index cac9b43a30..045abc736a 100644 --- a/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.A2A/A2AAgentSession.cs @@ -1,66 +1,61 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Diagnostics; using System.Text.Json; +using System.Text.Json.Serialization; namespace Microsoft.Agents.AI.A2A; /// /// Session for A2A based agents. /// +[DebuggerDisplay("{DebuggerDisplay,nq}")] public sealed class A2AAgentSession : AgentSession { internal A2AAgentSession() { } - internal A2AAgentSession(JsonElement serializedSessionState, JsonSerializerOptions? jsonSerializerOptions = null) + [JsonConstructor] + internal A2AAgentSession(string? contextId, string? taskId, AgentSessionStateBag? stateBag) : base(stateBag ?? new()) { - if (serializedSessionState.ValueKind != JsonValueKind.Object) - { - throw new ArgumentException("The serialized session state must be a JSON object.", nameof(serializedSessionState)); - } - - var state = serializedSessionState.Deserialize( - A2AJsonUtilities.DefaultOptions.GetTypeInfo(typeof(A2AAgentSessionState))) as A2AAgentSessionState; - - if (state?.ContextId is string contextId) - { - this.ContextId = contextId; - } - - if (state?.TaskId is string taskId) - { - this.TaskId = taskId; - } + this.ContextId = contextId; + this.TaskId = taskId; } /// /// Gets the ID for the current conversation with the A2A agent. /// + [JsonPropertyName("contextId")] public string? ContextId { get; internal set; } /// /// Gets the ID for the task the agent is currently working on. /// + [JsonPropertyName("taskId")] public string? TaskId { get; internal set; } /// internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { - var state = new A2AAgentSessionState - { - ContextId = this.ContextId, - TaskId = this.TaskId - }; - - return JsonSerializer.SerializeToElement(state, A2AJsonUtilities.DefaultOptions.GetTypeInfo(typeof(A2AAgentSessionState))); + var jso = jsonSerializerOptions ?? A2AJsonUtilities.DefaultOptions; + return JsonSerializer.SerializeToElement(this, jso.GetTypeInfo(typeof(A2AAgentSession))); } - internal sealed class A2AAgentSessionState + internal static A2AAgentSession Deserialize(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null) { - public string? ContextId { get; set; } + if (serializedState.ValueKind != JsonValueKind.Object) + { + throw new ArgumentException("The serialized session state must be a JSON object.", nameof(serializedState)); + } - public string? TaskId { get; set; } + var jso = jsonSerializerOptions ?? A2AJsonUtilities.DefaultOptions; + return serializedState.Deserialize(jso.GetTypeInfo(typeof(A2AAgentSession))) as A2AAgentSession + ?? new A2AAgentSession(); } + + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private string DebuggerDisplay => + $"ContextId = {this.ContextId}, TaskId = {this.TaskId}, StateBag Count = {this.StateBag.Count}"; } diff --git a/dotnet/src/Microsoft.Agents.AI.A2A/A2AJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI.A2A/A2AJsonUtilities.cs index 5f079d9573..3c25e350ae 100644 --- a/dotnet/src/Microsoft.Agents.AI.A2A/A2AJsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI.A2A/A2AJsonUtilities.cs @@ -74,7 +74,7 @@ private static JsonSerializerOptions CreateDefaultOptions() NumberHandling = JsonNumberHandling.AllowReadingFromString)] // A2A agent types - [JsonSerializable(typeof(A2AAgentSession.A2AAgentSessionState))] + [JsonSerializable(typeof(A2AAgentSession))] [ExcludeFromCodeCoverage] private sealed partial class JsonContext : JsonSerializerContext; } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContext.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContext.cs index b05992d93e..9ccfc3e905 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContext.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContext.cs @@ -56,41 +56,44 @@ public sealed class AIContext public string? Instructions { get; set; } /// - /// Gets or sets a collection of messages to add to the conversation history. + /// Gets or sets the sequence of messages to use for the current invocation. /// /// - /// A list of instances to be permanently added to the conversation history, - /// or if no messages should be added. + /// A sequence of instances to be used for the current invocation, + /// or if no messages should be used. /// /// /// - /// Unlike and , messages added through this property become - /// permanent additions to the conversation history. They will persist beyond the current invocation and - /// will be available in future interactions within the same conversation thread. + /// Unlike and , messages added through this property may become + /// permanent additions to the conversation history. + /// If chat history is managed by the underlying AI service, these messages will become part of chat history. + /// If chat history is managed using a , these messages will be passed to the + /// method, + /// and the provider can choose which of these messages to permanently add to the conversation history. /// /// /// This property is useful for: /// - /// Injecting relevant historical context or background information + /// Injecting relevant historical context e.g. memories + /// Injecting relevant background information e.g. via Retrieval Augmented Generation /// Adding system messages that provide ongoing context - /// Including retrieved information that should be part of the conversation record - /// Inserting contextual exchanges that inform the current conversation /// /// /// - public IList? Messages { get; set; } + public IEnumerable? Messages { get; set; } /// - /// Gets or sets a collection of tools or functions to make available to the AI model for the current invocation. + /// Gets or sets a sequence of tools or functions to make available to the AI model for the current invocation. /// /// - /// A list of instances that will be available to the AI model during the current invocation, + /// A sequence of instances that will be available to the AI model during the current invocation, /// or if no additional tools should be provided. /// /// /// - /// These tools are transient and apply only to the current AI model invocation. They are combined with any - /// tools already configured for the agent to provide an expanded set of capabilities for the specific interaction. + /// These tools are transient and apply only to the current AI model invocation. Any existing tools + /// are provided as input to the instances, so context providers can choose to modify or replace the existing tools + /// as needed based on the current context. The resulting set of tools is then passed to the underlying AI model, which may choose to utilize them when generating responses. /// /// /// Context-specific tools enable: @@ -102,5 +105,5 @@ public sealed class AIContext /// /// /// - public IList? Tools { get; set; } + public IEnumerable? Tools { get; set; } } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs index 76ff8752b9..b201e9f8e1 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs @@ -2,8 +2,6 @@ using System; using System.Collections.Generic; -using System.Linq; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -32,24 +30,15 @@ namespace Microsoft.Agents.AI; /// public abstract class AIContextProvider { - private readonly string _sourceId; - - /// - /// Initializes a new instance of the class. - /// - protected AIContextProvider() - { - this._sourceId = this.GetType().FullName!; - } - /// - /// Initializes a new instance of the class with the specified source id. + /// Gets the key used to store the provider state in the . /// - /// The source id to stamp on for each messages produced by the . - protected AIContextProvider(string sourceId) - { - this._sourceId = sourceId; - } + /// + /// The default value is the name of the concrete type (e.g. "TextSearchProvider"). + /// Implementations may override this to provide a custom key, for example when multiple + /// instances of the same provider type are used in the same session. + /// + public virtual string StateKey => this.GetType().Name; /// /// Called at the start of agent invocation to provide additional context. @@ -68,20 +57,8 @@ protected AIContextProvider(string sourceId) /// /// /// - public async ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) - { - var aiContext = await this.InvokingCoreAsync(context, cancellationToken).ConfigureAwait(false); - if (aiContext.Messages is null) - { - return aiContext; - } - - aiContext.Messages = aiContext.Messages - .Select(message => message.AsAgentRequestMessageSourcedMessage(AgentRequestMessageSourceType.AIContextProvider, this._sourceId)) - .ToList(); - - return aiContext; - } + public ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + => this.InvokingCoreAsync(context, cancellationToken); /// /// Called at the start of agent invocation to provide additional context. @@ -112,13 +89,18 @@ public async ValueTask InvokingAsync(InvokingContext context, Cancell /// /// Implementers can use the request and response messages in the provided to: /// - /// Update internal state based on conversation outcomes + /// Update state based on conversation outcomes /// Extract and store memories or preferences from user messages /// Log or audit conversation details /// Perform cleanup or finalization tasks /// /// /// + /// The is passed a reference to the via and + /// allowing it to store state in the . Since an is used with many different sessions, it should + /// not store any session-specific information within its own instance fields. Instead, any session-specific state should be stored in the associated . + /// + /// /// This method is called regardless of whether the invocation succeeded or failed. /// To check if the invocation was successful, inspect the property. /// @@ -150,18 +132,6 @@ public ValueTask InvokedAsync(InvokedContext context, CancellationToken cancella protected virtual ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) => default; - /// - /// Serializes the current object's state to a using the specified serialization options. - /// - /// The JSON serialization options to use for the serialization process. - /// A representation of the object's state, or a default if the provider has no serializable state. - /// - /// The default implementation returns a default . Override this method if the provider - /// maintains state that should be preserved across sessions or distributed scenarios. - /// - public virtual JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => default; - /// Asks the for an object of the specified type . /// The type of object being requested. /// An optional key that can be used to help identify the target service. @@ -203,20 +173,20 @@ public virtual JsonElement Serialize(JsonSerializerOptions? jsonSerializerOption public sealed class InvokingContext { /// - /// Initializes a new instance of the class with the specified request messages. + /// Initializes a new instance of the class. /// /// The agent being invoked. /// The session associated with the agent invocation. - /// The messages to be used by the agent for this invocation. - /// is . + /// The AI context to be used by the agent for this invocation. + /// or is . public InvokingContext( AIAgent agent, AgentSession? session, - IEnumerable requestMessages) + AIContext aiContext) { this.Agent = Throw.IfNull(agent); this.Session = session; - this.RequestMessages = Throw.IfNull(requestMessages); + this.AIContext = Throw.IfNull(aiContext); } /// @@ -230,39 +200,75 @@ public InvokingContext( public AgentSession? Session { get; } /// - /// Gets the caller provided messages that will be used by the agent for this invocation. + /// Gets the being built for the current invocation. Context providers can modify + /// and return or return a new instance to provide additional context for the invocation. /// - /// - /// A collection of instances representing new messages that were provided by the caller. - /// - public IEnumerable RequestMessages { get; set { field = Throw.IfNull(value); } } + /// + /// + /// If multiple instances are used in the same invocation, each + /// will receive the context returned by the previous allowing them to build on top of each other's context. + /// + /// + /// The first in the invocation pipeline will receive an instance + /// that already contains the caller provided messages that will be used by the agent for this invocation. + /// + /// + /// It may also contain messages from chat history, if a is being used. + /// + /// + public AIContext AIContext { get; } } /// /// Contains the context information provided to . /// /// - /// This class provides context about a completed agent invocation, including both the - /// request messages that were used and the response messages that were generated. It also indicates - /// whether the invocation succeeded or failed. + /// This class provides context about a completed agent invocation, including the accumulated + /// request messages (user input, chat history and any others provided by AI context providers) that were used + /// and the response messages that were generated. It also indicates whether the invocation succeeded or failed. /// public sealed class InvokedContext { /// - /// Initializes a new instance of the class with the specified request messages. + /// Initializes a new instance of the class for a successful invocation. /// - /// The agent being invoked. + /// The agent that was invoked. + /// The session associated with the agent invocation. + /// The accumulated request messages (user input, chat history and any others provided by AI context providers) + /// that were used by the agent for this invocation. + /// The response messages generated during this invocation. + /// , , or is . + public InvokedContext( + AIAgent agent, + AgentSession? session, + IEnumerable requestMessages, + IEnumerable responseMessages) + { + this.Agent = Throw.IfNull(agent); + this.Session = session; + this.RequestMessages = Throw.IfNull(requestMessages); + this.ResponseMessages = Throw.IfNull(responseMessages); + } + + /// + /// Initializes a new instance of the class for a failed invocation. + /// + /// The agent that was invoked. /// The session associated with the agent invocation. - /// The caller provided messages that were used by the agent for this invocation. - /// is . + /// The accumulated request messages (user input, chat history and any others provided by AI context providers) + /// that were used by the agent for this invocation. + /// The exception that caused the invocation to fail. + /// , , or is . public InvokedContext( AIAgent agent, AgentSession? session, - IEnumerable requestMessages) + IEnumerable requestMessages, + Exception invokeException) { this.Agent = Throw.IfNull(agent); this.Session = session; this.RequestMessages = Throw.IfNull(requestMessages); + this.InvokeException = Throw.IfNull(invokeException); } /// @@ -276,22 +282,22 @@ public InvokedContext( public AgentSession? Session { get; } /// - /// Gets the caller provided messages that were used by the agent for this invocation. + /// Gets the accumulated request messages (user input, chat history and any others provided by AI context providers) + /// that were used by the agent for this invocation. /// /// - /// A collection of instances representing new messages that were provided by the caller. - /// This does not include any supplied messages. + /// A collection of instances representing all messages that were used by the agent for this invocation. /// - public IEnumerable RequestMessages { get; set { field = Throw.IfNull(value); } } + public IEnumerable RequestMessages { get; } /// /// Gets the collection of response messages generated during this invocation if the invocation succeeded. /// /// /// A collection of instances representing the response, - /// or if the invocation failed or did not produce response messages. + /// or if the invocation failed. /// - public IEnumerable? ResponseMessages { get; set; } + public IEnumerable? ResponseMessages { get; } /// /// Gets the that was thrown during the invocation, if the invocation failed. @@ -299,6 +305,6 @@ public InvokedContext( /// /// The exception that caused the invocation to fail, or if the invocation succeeded. /// - public Exception? InvokeException { get; set; } + public Exception? InvokeException { get; } } } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs index 17fbb9e4c6..f8c8aa9b98 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentAbstractionsJsonUtilities.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Collections.Concurrent; using System.Diagnostics.CodeAnalysis; using System.Text.Encodings.Web; using System.Text.Json; @@ -80,9 +81,9 @@ private static JsonSerializerOptions CreateDefaultOptions() [JsonSerializable(typeof(AgentResponse[]))] [JsonSerializable(typeof(AgentResponseUpdate))] [JsonSerializable(typeof(AgentResponseUpdate[]))] - [JsonSerializable(typeof(ServiceIdAgentSession.ServiceIdAgentSessionState))] - [JsonSerializable(typeof(InMemoryAgentSession.InMemoryAgentSessionState))] [JsonSerializable(typeof(InMemoryChatHistoryProvider.State))] + [JsonSerializable(typeof(AgentSessionStateBag))] + [JsonSerializable(typeof(ConcurrentDictionary))] [ExcludeFromCodeCoverage] private sealed partial class JsonContext : JsonSerializerContext; diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSourceAttribution.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSourceAttribution.cs index 2c606814ce..1515adec9a 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSourceAttribution.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSourceAttribution.cs @@ -63,6 +63,17 @@ public override bool Equals(object? obj) return obj is AgentRequestMessageSourceAttribution other && this.Equals(other); } + /// + /// Returns a string representation of the current instance. + /// + /// A string containing the source type and source identifier. + public override string ToString() + { + return this.SourceId is null + ? $"{this.SourceType}" + : $"{this.SourceType}:{this.SourceId}"; + } + /// /// Returns a hash code for the current instance. /// diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSourceType.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSourceType.cs index 14bbbe3388..744f87bed6 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSourceType.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentRequestMessageSourceType.cs @@ -58,6 +58,12 @@ public bool Equals(AgentRequestMessageSourceType other) /// if is a and its value is the same as this instance; otherwise, . public override bool Equals(object? obj) => obj is AgentRequestMessageSourceType other && this.Equals(other); + /// + /// Returns the string representation of this instance. + /// + /// The string value representing the source of the agent request message. + public override string ToString() => this.Value; + /// /// Returns the hash code for this instance. /// diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs index 4c62eccc99..a154b0a9f5 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSession.cs @@ -1,7 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Diagnostics; using System.Text.Json; +using System.Text.Json.Serialization; using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.AI; @@ -44,6 +46,7 @@ namespace Microsoft.Agents.AI; /// /// /// +[DebuggerDisplay("{DebuggerDisplay,nq}")] public abstract class AgentSession { /// @@ -53,6 +56,20 @@ protected AgentSession() { } + /// + /// Initializes a new instance of the class. + /// + protected AgentSession(AgentSessionStateBag stateBag) + { + this.StateBag = Throw.IfNull(stateBag); + } + + /// + /// Gets any arbitrary state associated with this session. + /// + [JsonPropertyName("stateBag")] + public AgentSessionStateBag StateBag { get; protected set; } = new(); + /// Asks the for an object of the specified type . /// The type of object being requested. /// An optional key that can be used to help identify the target service. @@ -82,4 +99,7 @@ protected AgentSession() /// public TService? GetService(object? serviceKey = null) => this.GetService(typeof(TService), serviceKey) is TService service ? service : default; + + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private string DebuggerDisplay => $"StateBag Count = {this.StateBag.Count}"; } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs new file mode 100644 index 0000000000..d78a866b2c --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBag.cs @@ -0,0 +1,145 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Concurrent; +using System.Text.Json; +using System.Text.Json.Serialization; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Agents.AI; + +/// +/// Provides a thread-safe key-value store for managing session-scoped state with support for type-safe access and JSON +/// serialization options. +/// +/// +/// SessionState enables storing and retrieving objects associated with a session using string keys. +/// Values can be accessed in a type-safe manner and are serialized or deserialized using configurable JSON serializer +/// options. This class is designed for concurrent access and is safe to use across multiple threads. +/// +[JsonConverter(typeof(AgentSessionStateBagJsonConverter))] +public class AgentSessionStateBag +{ + private readonly ConcurrentDictionary _state; + + /// + /// Initializes a new instance of the class. + /// + public AgentSessionStateBag() + { + this._state = new ConcurrentDictionary(); + } + + /// + /// Initializes a new instance of the class. + /// + /// The initial state dictionary. + internal AgentSessionStateBag(ConcurrentDictionary? state) + { + this._state = state ?? new ConcurrentDictionary(); + } + + /// + /// Gets the number of key-value pairs contained in the session state. + /// + public int Count => this._state.Count; + + /// + /// Tries to get a value from the session state. + /// + /// The type of the value to retrieve. + /// The key from which to retrieve the value. + /// The value if found and convertible to the required type; otherwise, null. + /// The JSON serializer options to use for serializing/deserializing the value. + /// if the value was successfully retrieved, otherwise. + public bool TryGetValue(string key, out T? value, JsonSerializerOptions? jsonSerializerOptions = null) + where T : class + { + _ = Throw.IfNullOrWhitespace(key); + var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; + + if (this._state.TryGetValue(key, out var stateValue)) + { + return stateValue.TryReadDeserializedValue(out value, jso); + } + + value = null; + return false; + } + + /// + /// Gets a value from the session state. + /// + /// The type of value to get. + /// The key from which to retrieve the value. + /// The JSON serializer options to use for serializing/deserialing the value. + /// The retrieved value or null if not found. + /// The value could not be deserialized into the required type. + public T? GetValue(string key, JsonSerializerOptions? jsonSerializerOptions = null) + where T : class + { + _ = Throw.IfNullOrWhitespace(key); + var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; + + if (this._state.TryGetValue(key, out var stateValue)) + { + return stateValue.ReadDeserializedValue(jso); + } + + return null; + } + + /// + /// Sets a value in the session state. + /// + /// The type of the value to set. + /// The key to store the value under. + /// The value to set. + /// The JSON serializer options to use for serializing the value. + public void SetValue(string key, T? value, JsonSerializerOptions? jsonSerializerOptions = null) + where T : class + { + _ = Throw.IfNullOrWhitespace(key); + var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; + + var stateValue = this._state.GetOrAdd(key, _ => + new AgentSessionStateBagValue(value, typeof(T), jso)); + + stateValue.SetDeserialized(value, typeof(T), jso); + } + + /// + /// Tries to remove a value from the session state. + /// + /// The key of the value to remove. + /// if the value was successfully removed; otherwise, . + public bool TryRemoveValue(string key) + => this._state.TryRemove(Throw.IfNullOrWhitespace(key), out _); + + /// + /// Serializes all session state values to a JSON object. + /// + /// A representing the serialized session state. + /// Thrown when a session state value is not properly initialized. + public JsonElement Serialize() + { + return JsonSerializer.SerializeToElement(this._state, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ConcurrentDictionary))); + } + + /// + /// Deserializes a JSON object into an instance. + /// + /// The element to deserialize. + /// The deserialized . + public static AgentSessionStateBag Deserialize(JsonElement jsonElement) + { + if (jsonElement.ValueKind is JsonValueKind.Undefined or JsonValueKind.Null) + { + return new AgentSessionStateBag(); + } + + return new AgentSessionStateBag( + jsonElement.Deserialize(AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ConcurrentDictionary))) as ConcurrentDictionary + ?? new ConcurrentDictionary()); + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagJsonConverter.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagJsonConverter.cs new file mode 100644 index 0000000000..bfb6904320 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagJsonConverter.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Microsoft.Agents.AI; + +/// +/// Custom JSON converter for that serializes and deserializes +/// the internal dictionary contents rather than the container object's public properties. +/// +public sealed class AgentSessionStateBagJsonConverter : JsonConverter +{ + /// + public override AgentSessionStateBag Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + var element = JsonElement.ParseValue(ref reader); + return AgentSessionStateBag.Deserialize(element); + } + + /// + public override void Write(Utf8JsonWriter writer, AgentSessionStateBag value, JsonSerializerOptions options) + { + var element = value.Serialize(); + element.WriteTo(writer); + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs new file mode 100644 index 0000000000..0b4849aa1b --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValue.cs @@ -0,0 +1,182 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Microsoft.Agents.AI; + +/// +/// Used to store a value in session state. +/// +[JsonConverter(typeof(AgentSessionStateBagValueJsonConverter))] +internal class AgentSessionStateBagValue +{ + private readonly object _lock = new(); + private DeserializedCache? _cache; + private JsonElement _jsonValue; + + /// + /// Initializes a new instance of the SessionStateValue class with the specified value. + /// + /// The serialized value to associate with the session state. + public AgentSessionStateBagValue(JsonElement jsonValue) + { + this.JsonValue = jsonValue; + } + + /// + /// Initializes a new instance of the SessionStateValue class with the specified value. + /// + /// The value to associate with the session state. Can be any object, including null. + /// The type of the value. + /// The JSON serializer options to use for serializing the value. + public AgentSessionStateBagValue(object? deserializedValue, Type valueType, JsonSerializerOptions jsonSerializerOptions) + { + this._cache = new DeserializedCache(deserializedValue, valueType, jsonSerializerOptions); + } + + /// + /// Gets or sets the value associated with this instance. + /// + public JsonElement JsonValue + { + get + { + lock (this._lock) + { + // We are assuming here that JsonValue will only be read when the object is being serialized, + // which means that we will only call SerializeToElement when serializing and therefore it's + // OK to serialize on each read if the cache is set. + if (this._cache is { } cache) + { + this._jsonValue = JsonSerializer.SerializeToElement(cache.Value, cache.Options.GetTypeInfo(cache.ValueType)); + } + + return this._jsonValue; + } + } + set + { + lock (this._lock) + { + this._jsonValue = value; + this._cache = null; + } + } + } + + /// + /// Tries to read the deserialized value of this session state value. + /// Returns false if the value could not be deserialized into the required type, or if the value is undefined. + /// Returns true and sets the out parameter to null if the value is null. + /// + public bool TryReadDeserializedValue(out T? value, JsonSerializerOptions? jsonSerializerOptions = null) + where T : class + { + var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; + + lock (this._lock) + { + switch (this._cache) + { + case DeserializedCache { Value: null, ValueType: Type cacheValueType } when cacheValueType == typeof(T): + value = null; + return true; + case DeserializedCache { Value: T cacheValue, ValueType: Type cacheValueType } when cacheValueType == typeof(T): + value = cacheValue; + return true; + case DeserializedCache { ValueType: Type cacheValueType } when cacheValueType != typeof(T): + value = null; + return false; + } + + switch (this._jsonValue) + { + case JsonElement jsonElement when jsonElement.ValueKind == JsonValueKind.Undefined: + value = null; + return false; + case JsonElement jsonElement when jsonElement.ValueKind == JsonValueKind.Null: + value = null; + return true; + default: + T? result = this._jsonValue.Deserialize(jso.GetTypeInfo(typeof(T))) as T; + if (result is null) + { + value = null; + return false; + } + + this._cache = new DeserializedCache(result, typeof(T), jso); + + value = result; + return true; + } + } + } + + /// + /// Reads the deserialized value of this session state value, throwing an exception if the value could not be deserialized into the required type or is undefined. + /// + public T? ReadDeserializedValue(JsonSerializerOptions? jsonSerializerOptions = null) + where T : class + { + var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; + + lock (this._lock) + { + switch (this._cache) + { + case DeserializedCache { Value: null, ValueType: Type cacheValueType } when cacheValueType == typeof(T): + return null; + case DeserializedCache { Value: T cacheValue, ValueType: Type cacheValueType } when cacheValueType == typeof(T): + return cacheValue; + case DeserializedCache { ValueType: Type cacheValueType } when cacheValueType != typeof(T): + throw new InvalidOperationException($"The type of the cached value is {cacheValueType.FullName}, but the requested type is {typeof(T).FullName}."); + } + + switch (this._jsonValue) + { + case JsonElement jsonElement when jsonElement.ValueKind == JsonValueKind.Null || jsonElement.ValueKind == JsonValueKind.Undefined: + return null; + default: + T? result = this._jsonValue.Deserialize(jso.GetTypeInfo(typeof(T))) as T; + if (result is null) + { + throw new InvalidOperationException($"Failed to deserialize session state value to type {typeof(T).FullName}."); + } + + this._cache = new DeserializedCache(result, typeof(T), jso); + return result; + } + } + } + + /// + /// Sets the deserialized value of this session state value, updating the cache accordingly. + /// This does not update the JsonValue directly; the JsonValue will be updated on the next read or when the object is serialized. + /// + public void SetDeserialized(T? deserializedValue, Type valueType, JsonSerializerOptions jsonSerializerOptions) + { + lock (this._lock) + { + this._cache = new DeserializedCache(deserializedValue, valueType, jsonSerializerOptions); + } + } + + private readonly struct DeserializedCache + { + public DeserializedCache(object? value, Type valueType, JsonSerializerOptions options) + { + this.Value = value; + this.ValueType = valueType; + this.Options = options; + } + + public object? Value { get; } + + public Type ValueType { get; } + + public JsonSerializerOptions Options { get; } + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValueJsonConverter.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValueJsonConverter.cs new file mode 100644 index 0000000000..27c9dc08a8 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AgentSessionStateBagValueJsonConverter.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Microsoft.Agents.AI; + +/// +/// Custom JSON converter for that serializes and deserializes +/// the directly rather than wrapping it in a container object. +/// +internal sealed class AgentSessionStateBagValueJsonConverter : JsonConverter +{ + /// + public override AgentSessionStateBagValue Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + var element = JsonElement.ParseValue(ref reader); + return new AgentSessionStateBagValue(element); + } + + /// + public override void Write(Utf8JsonWriter writer, AgentSessionStateBagValue value, JsonSerializerOptions options) + { + value.JsonValue.WriteTo(writer); + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs index 086ad02061..d16ca69528 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs @@ -2,8 +2,6 @@ using System; using System.Collections.Generic; -using System.Linq; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -27,34 +25,29 @@ namespace Microsoft.Agents.AI; /// Storing chat messages with proper ordering and metadata preservation /// Retrieving messages in chronological order for agent context /// Managing storage limits through truncation, summarization, or other strategies -/// Supporting serialization for thread persistence and migration /// /// /// +/// The is passed a reference to the via and +/// allowing it to store state in the . Since a is used with many different sessions, it should +/// not store any session-specific information within its own instance fields. Instead, any session-specific state should be stored in the associated . +/// +/// /// A is only relevant for scenarios where the underlying AI service that the agent is using /// does not use in-service chat history storage. /// /// public abstract class ChatHistoryProvider { - private readonly string _sourceId; - - /// - /// Initializes a new instance of the class. - /// - protected ChatHistoryProvider() - { - this._sourceId = this.GetType().FullName!; - } - /// - /// Initializes a new instance of the class with the specified source id. + /// Gets the key used to store the provider state in the . /// - /// The source id to stamp on for each messages produced by the . - protected ChatHistoryProvider(string sourceId) - { - this._sourceId = sourceId; - } + /// + /// The default value is the name of the concrete type (e.g. "InMemoryChatHistoryProvider"). + /// Implementations may override this to provide a custom key, for example when multiple + /// instances of the same provider type are used in the same session. + /// + public virtual string StateKey => this.GetType().Name; /// /// Called at the start of agent invocation to provide messages from the chat history as context for the next agent invocation. @@ -80,17 +73,9 @@ protected ChatHistoryProvider(string sourceId) /// Archiving old messages while keeping active conversation context /// /// - /// - /// Each instance should be associated with a single to ensure proper message isolation - /// and context management. - /// /// - public async ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) - { - var messages = await this.InvokingCoreAsync(context, cancellationToken).ConfigureAwait(false); - - return messages.Select(message => message.AsAgentRequestMessageSourcedMessage(AgentRequestMessageSourceType.ChatHistory, this._sourceId)); - } + public ValueTask> InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + => this.InvokingCoreAsync(context, cancellationToken); /// /// Called at the start of agent invocation to provide messages from the chat history as context for the next agent invocation. @@ -178,13 +163,6 @@ public ValueTask InvokedAsync(InvokedContext context, CancellationToken cancella /// protected abstract ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default); - /// - /// Serializes the current object's state to a using the specified serialization options. - /// - /// The JSON serialization options to use. - /// A representation of the object's state. - public abstract JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null); - /// Asks the for an object of the specified type . /// The type of object being requested. /// An optional key that can be used to help identify the target service. @@ -229,7 +207,7 @@ public sealed class InvokingContext /// /// The agent being invoked. /// The session associated with the agent invocation. - /// The new messages to be used by the agent for this invocation. + /// The messages to be used by the agent for this invocation. /// is . public InvokingContext( AIAgent agent, @@ -252,11 +230,22 @@ public InvokingContext( public AgentSession? Session { get; } /// - /// Gets the caller provided messages that will be used by the agent for this invocation. + /// Gets the messages that will be used by the agent for this invocation. instances can modify + /// and return or return a new message list to add additional messages for the invocation. /// /// - /// A collection of instances representing new messages that were provided by the caller. + /// A collection of instances representing the messages that will be used by the agent for this invocation. /// + /// + /// + /// If multiple instances are used in the same invocation, each + /// will receive the messages returned by the previous allowing them to build on top of each other's context. + /// + /// + /// The first in the invocation pipeline will receive the + /// caller provided messages. + /// + /// public IEnumerable RequestMessages { get; set { field = Throw.IfNull(value); } } } @@ -264,27 +253,52 @@ public InvokingContext( /// Contains the context information provided to . /// /// - /// This class provides context about a completed agent invocation, including both the - /// request messages that were used and the response messages that were generated. It also indicates - /// whether the invocation succeeded or failed. + /// This class provides context about a completed agent invocation, including the accumulated + /// request messages (user input, chat history and any others provided by AI context providers) that were used + /// and the response messages that were generated. It also indicates whether the invocation succeeded or failed. /// public sealed class InvokedContext { /// - /// Initializes a new instance of the class with the specified request messages. + /// Initializes a new instance of the class for a successful invocation. /// - /// The agent being invoked. + /// The agent that was invoked. /// The session associated with the agent invocation. - /// The caller provided messages that were used by the agent for this invocation. - /// is . + /// The accumulated request messages (user input, chat history and any others provided by AI context providers) + /// that were used by the agent for this invocation. + /// The response messages generated during this invocation. + /// , , or is . public InvokedContext( AIAgent agent, AgentSession? session, - IEnumerable requestMessages) + IEnumerable requestMessages, + IEnumerable responseMessages) { this.Agent = Throw.IfNull(agent); this.Session = session; this.RequestMessages = Throw.IfNull(requestMessages); + this.ResponseMessages = Throw.IfNull(responseMessages); + } + + /// + /// Initializes a new instance of the class for a failed invocation. + /// + /// The agent that was invoked. + /// The session associated with the agent invocation. + /// The accumulated request messages (user input, chat history and any others provided by AI context providers) + /// that were used by the agent for this invocation. + /// The exception that caused the invocation to fail. + /// , , or is . + public InvokedContext( + AIAgent agent, + AgentSession? session, + IEnumerable requestMessages, + Exception invokeException) + { + this.Agent = Throw.IfNull(agent); + this.Session = session; + this.RequestMessages = Throw.IfNull(requestMessages); + this.InvokeException = Throw.IfNull(invokeException); } /// @@ -298,22 +312,23 @@ public InvokedContext( public AgentSession? Session { get; } /// - /// Gets the caller provided messages that were used by the agent for this invocation. + /// Gets the accumulated request messages (user input, chat history and any others provided by AI context providers) + /// that were used by the agent for this invocation. /// /// /// A collection of instances representing new messages that were provided by the caller. /// This does not include any supplied messages. /// - public IEnumerable RequestMessages { get; set { field = Throw.IfNull(value); } } + public IEnumerable RequestMessages { get; } /// /// Gets the collection of response messages generated during this invocation if the invocation succeeded. /// /// /// A collection of instances representing the response, - /// or if the invocation failed or did not produce response messages. + /// or if the invocation failed. /// - public IEnumerable? ResponseMessages { get; set; } + public IEnumerable? ResponseMessages { get; } /// /// Gets the that was thrown during the invocation, if the invocation failed. @@ -321,6 +336,6 @@ public InvokedContext( /// /// The exception that caused the invocation to fail, or if the invocation succeeded. /// - public Exception? InvokeException { get; set; } + public Exception? InvokeException { get; } } } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderExtensions.cs deleted file mode 100644 index 4c8ef2489f..0000000000 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderExtensions.cs +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Linq; -using Microsoft.Extensions.AI; - -namespace Microsoft.Agents.AI; - -/// -/// Contains extension methods for the class. -/// -public static class ChatHistoryProviderExtensions -{ - /// - /// Adds message filtering to an existing , so that messages passed to the and messages - /// provided by the can be filtered, updated or replaced. - /// - /// The to add the message filter to. - /// An optional filter function to apply to messages produced by the . If null, no filter is applied at this - /// stage. - /// An optional filter function to apply to the invoked context messages before they are passed to the . If null, no - /// filter is applied at this stage. - /// The with filtering applied. - public static ChatHistoryProvider WithMessageFilters( - this ChatHistoryProvider provider, - Func, IEnumerable>? invokingMessagesFilter = null, - Func? invokedMessagesFilter = null) - { - return new ChatHistoryProviderMessageFilter( - innerProvider: provider, - invokingMessagesFilter: invokingMessagesFilter, - invokedMessagesFilter: invokedMessagesFilter); - } - - /// - /// Decorates the provided so that it does not add - /// messages with to chat history. - /// - /// The to add the message filter to. - /// A new instance that filters out messages so they do not get added. - public static ChatHistoryProvider WithAIContextProviderMessageRemoval(this ChatHistoryProvider provider) - { - return new ChatHistoryProviderMessageFilter( - innerProvider: provider, - invokedMessagesFilter: (ctx) => - { - ctx.RequestMessages = ctx.RequestMessages.Where(x => x.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.AIContextProvider); - return ctx; - }); - } -} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderMessageFilter.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderMessageFilter.cs deleted file mode 100644 index 6cee80986b..0000000000 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProviderMessageFilter.cs +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Text.Json; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.Extensions.AI; -using Microsoft.Shared.Diagnostics; - -namespace Microsoft.Agents.AI; - -/// -/// A decorator that allows filtering the messages -/// passed into and out of an inner . -/// -public sealed class ChatHistoryProviderMessageFilter : ChatHistoryProvider -{ - private readonly ChatHistoryProvider _innerProvider; - private readonly Func, IEnumerable>? _invokingMessagesFilter; - private readonly Func? _invokedMessagesFilter; - - /// - /// Initializes a new instance of the class. - /// - /// Use this constructor to customize how messages are filtered before and after invocation by - /// providing appropriate filter functions. If no filters are provided, the operates without - /// additional filtering. - /// The underlying to be wrapped. Cannot be null. - /// An optional filter function to apply to messages provided by the - /// before they are used by the agent. If null, no filter is applied at this stage. - /// An optional filter function to apply to the invocation context after messages have been produced. If null, no - /// filter is applied at this stage. - /// Thrown if is null. - public ChatHistoryProviderMessageFilter( - ChatHistoryProvider innerProvider, - Func, IEnumerable>? invokingMessagesFilter = null, - Func? invokedMessagesFilter = null) - { - this._innerProvider = Throw.IfNull(innerProvider); - - if (invokingMessagesFilter == null && invokedMessagesFilter == null) - { - throw new ArgumentException("At least one filter function, invokingMessagesFilter or invokedMessagesFilter, must be provided."); - } - - this._invokingMessagesFilter = invokingMessagesFilter; - this._invokedMessagesFilter = invokedMessagesFilter; - } - - /// - protected override async ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) - { - var messages = await this._innerProvider.InvokingAsync(context, cancellationToken).ConfigureAwait(false); - return this._invokingMessagesFilter != null ? this._invokingMessagesFilter(messages) : messages; - } - - /// - protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) - { - if (this._invokedMessagesFilter != null) - { - context = this._invokedMessagesFilter(context); - } - - return this._innerProvider.InvokedAsync(context, cancellationToken); - } - - /// - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - return this._innerProvider.Serialize(jsonSerializerOptions); - } -} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageExtensions.cs index 052e48ce56..0ff4874732 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatMessageExtensions.cs @@ -54,7 +54,7 @@ public static AgentRequestMessageSourceType GetAgentRequestMessageSourceType(thi /// If the message is already tagged with the provided source type and source id, it is returned as is. /// Otherwise, a cloned message is returned with the appropriate tagging in the AdditionalProperties. /// - public static ChatMessage AsAgentRequestMessageSourcedMessage(this ChatMessage message, AgentRequestMessageSourceType sourceType, string? sourceId = null) + public static ChatMessage WithAgentRequestMessageSource(this ChatMessage message, AgentRequestMessageSourceType sourceType, string? sourceId = null) { if (message.AdditionalProperties != null // Check if the message was already tagged with the required source type and source id diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryAgentSession.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryAgentSession.cs deleted file mode 100644 index 05ffafaeb9..0000000000 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryAgentSession.cs +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Text.Json; -using Microsoft.Extensions.AI; - -namespace Microsoft.Agents.AI; - -/// -/// Provides an abstract base class for an that maintain all chat history in local memory. -/// -/// -/// -/// is designed for scenarios where chat history should be stored locally -/// rather than in external services or databases. This approach provides high performance and simplicity while -/// maintaining full control over the conversation data. -/// -/// -/// In-memory threads do not persist conversation data across application restarts -/// unless explicitly serialized and restored. -/// -/// -[DebuggerDisplay("{DebuggerDisplay,nq}")] -public abstract class InMemoryAgentSession : AgentSession -{ - /// - /// Initializes a new instance of the class. - /// - /// - /// An optional instance to use for storing chat messages. - /// If , a new empty will be created. - /// - /// - /// This constructor allows sharing of between sessions or providing pre-configured - /// with specific reduction or processing logic. - /// - protected InMemoryAgentSession(InMemoryChatHistoryProvider? chatHistoryProvider = null) - { - this.ChatHistoryProvider = chatHistoryProvider ?? []; - } - - /// - /// Initializes a new instance of the class. - /// - /// The initial messages to populate the conversation history. - /// is . - /// - /// This constructor is useful for initializing sessions with existing conversation history or - /// for migrating conversations from other storage systems. - /// - protected InMemoryAgentSession(IEnumerable messages) - { - this.ChatHistoryProvider = [.. messages]; - } - - /// - /// Initializes a new instance of the class from previously serialized state. - /// - /// A representing the serialized state of the session. - /// Optional settings for customizing the JSON deserialization process. - /// - /// Optional factory function to create the from its serialized state. - /// If not provided, a default factory will be used that creates a basic . - /// - /// The is not a JSON object. - /// The is invalid or cannot be deserialized to the expected type. - /// - /// This constructor enables restoration of in-memory threads from previously saved state, allowing - /// conversations to be resumed across application restarts or migrated between different instances. - /// - protected InMemoryAgentSession( - JsonElement serializedState, - JsonSerializerOptions? jsonSerializerOptions = null, - Func? chatHistoryProviderFactory = null) - { - if (serializedState.ValueKind != JsonValueKind.Object) - { - throw new ArgumentException("The serialized session state must be a JSON object.", nameof(serializedState)); - } - - var state = serializedState.Deserialize( - AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(InMemoryAgentSessionState))) as InMemoryAgentSessionState; - - this.ChatHistoryProvider = - chatHistoryProviderFactory?.Invoke(state?.ChatHistoryProviderState ?? default, jsonSerializerOptions) ?? - new InMemoryChatHistoryProvider(state?.ChatHistoryProviderState ?? default, jsonSerializerOptions); - } - - /// - /// Gets or sets the used by this thread. - /// - public InMemoryChatHistoryProvider ChatHistoryProvider { get; } - - /// - /// Serializes the current object's state to a using the specified serialization options. - /// - /// The JSON serialization options to use. - /// A representation of the object's state. - protected internal virtual JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - var chatHistoryProviderState = this.ChatHistoryProvider.Serialize(jsonSerializerOptions); - - var state = new InMemoryAgentSessionState - { - ChatHistoryProviderState = chatHistoryProviderState, - }; - - return JsonSerializer.SerializeToElement(state, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(InMemoryAgentSessionState))); - } - - /// - public override object? GetService(Type serviceType, object? serviceKey = null) => - base.GetService(serviceType, serviceKey) ?? this.ChatHistoryProvider?.GetService(serviceType, serviceKey); - - [DebuggerBrowsable(DebuggerBrowsableState.Never)] - private string DebuggerDisplay => $"Count = {this.ChatHistoryProvider.Count}"; - - internal sealed class InMemoryAgentSessionState - { - public JsonElement? ChatHistoryProviderState { get; set; } - } -} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs index 001b3a3bcc..9c535923f4 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProvider.cs @@ -1,11 +1,10 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections; using System.Collections.Generic; -using System.Diagnostics; using System.Linq; using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -14,122 +13,101 @@ namespace Microsoft.Agents.AI; /// -/// Provides an in-memory implementation of with support for message reduction and collection semantics. +/// Provides an in-memory implementation of with support for message reduction. /// /// /// -/// stores chat messages entirely in local memory, providing fast access and manipulation -/// capabilities. It implements both for agent integration and -/// for direct collection manipulation. +/// stores chat messages in the , +/// providing fast access and manipulation capabilities integrated with session state management. /// /// /// This maintains all messages in memory. For long-running conversations or high-volume scenarios, consider using /// message reduction strategies or alternative storage implementations. /// /// -[DebuggerDisplay("Count = {Count}")] -[DebuggerTypeProxy(typeof(DebugView))] -public sealed class InMemoryChatHistoryProvider : ChatHistoryProvider, IList, IReadOnlyList +public sealed class InMemoryChatHistoryProvider : ChatHistoryProvider { - private List _messages; + private static IEnumerable DefaultExcludeChatHistoryFilter(IEnumerable messages) + => messages.Where(m => m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.ChatHistory); + + private readonly string _stateKey; + private readonly Func _stateInitializer; + private readonly JsonSerializerOptions _jsonSerializerOptions; + private readonly Func, IEnumerable> _storageInputMessageFilter; + private readonly Func, IEnumerable>? _retrievalOutputMessageFilter; /// /// Initializes a new instance of the class. /// - /// - /// This constructor creates a basic in-memory without message reduction capabilities. - /// Messages will be stored exactly as added without any automatic processing or reduction. - /// - public InMemoryChatHistoryProvider() + /// + /// Optional configuration options that control the provider's behavior, including state initialization, + /// message reduction, and serialization settings. If , default settings will be used. + /// + public InMemoryChatHistoryProvider(InMemoryChatHistoryProviderOptions? options = null) { - this._messages = []; + this._stateInitializer = options?.StateInitializer ?? (_ => new State()); + this.ChatReducer = options?.ChatReducer; + this.ReducerTriggerEvent = options?.ReducerTriggerEvent ?? InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.BeforeMessagesRetrieval; + this._stateKey = options?.StateKey ?? base.StateKey; + this._jsonSerializerOptions = options?.JsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; + this._storageInputMessageFilter = options?.StorageInputMessageFilter ?? DefaultExcludeChatHistoryFilter; + this._retrievalOutputMessageFilter = options?.RetrievalOutputMessageFilter; } + /// + public override string StateKey => this._stateKey; + /// - /// Initializes a new instance of the class from previously serialized state. + /// Gets the chat reducer used to process or reduce chat messages. If null, no reduction logic will be applied. /// - /// A representing the serialized state of the provider. - /// Optional settings for customizing the JSON deserialization process. - /// The is not a valid JSON object or cannot be deserialized. - /// - /// This constructor enables restoration of messages from previously saved state, allowing - /// conversation history to be preserved across application restarts or migrated between instances. - /// - public InMemoryChatHistoryProvider(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null) - : this(null, serializedState, jsonSerializerOptions, ChatReducerTriggerEvent.BeforeMessagesRetrieval) - { - } + public IChatReducer? ChatReducer { get; } /// - /// Initializes a new instance of the class. + /// Gets the event that triggers the reducer invocation in this provider. /// - /// - /// A instance used to process, reduce, or optimize chat messages. - /// This can be used to implement strategies like message summarization, truncation, or cleanup. - /// - /// - /// Specifies when the message reducer should be invoked. The default is , - /// which applies reduction logic when messages are retrieved for agent consumption. - /// - /// is . - /// - /// Message reducers enable automatic management of message storage by implementing strategies to - /// keep memory usage under control while preserving important conversation context. - /// - public InMemoryChatHistoryProvider(IChatReducer chatReducer, ChatReducerTriggerEvent reducerTriggerEvent = ChatReducerTriggerEvent.BeforeMessagesRetrieval) - : this(chatReducer, default, null, reducerTriggerEvent) - { - Throw.IfNull(chatReducer); - } + public InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent ReducerTriggerEvent { get; } /// - /// Initializes a new instance of the class, with an existing state from a serialized JSON element. + /// Gets the chat messages stored for the specified session. /// - /// An optional instance used to process or reduce chat messages. If null, no reduction logic will be applied. - /// A representing the serialized state of the provider. - /// Optional settings for customizing the JSON deserialization process. - /// The event that should trigger the reducer invocation. - public InMemoryChatHistoryProvider(IChatReducer? chatReducer, JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, ChatReducerTriggerEvent reducerTriggerEvent = ChatReducerTriggerEvent.BeforeMessagesRetrieval) - { - this.ChatReducer = chatReducer; - this.ReducerTriggerEvent = reducerTriggerEvent; - - if (serializedState.ValueKind is JsonValueKind.Object) - { - var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; - var state = serializedState.Deserialize( - jso.GetTypeInfo(typeof(State))) as State; - if (state?.Messages is { } messages) - { - this._messages = messages; - return; - } - } - - this._messages = []; - } + /// The agent session containing the state. + /// A list of chat messages, or an empty list if no state is found. + public List GetMessages(AgentSession? session) + => this.GetOrInitializeState(session).Messages; /// - /// Gets the chat reducer used to process or reduce chat messages. If null, no reduction logic will be applied. + /// Sets the chat messages for the specified session. /// - public IChatReducer? ChatReducer { get; } + /// The agent session containing the state. + /// The messages to store. + /// is . + public void SetMessages(AgentSession? session, List messages) + { + _ = Throw.IfNull(messages); + + var state = this.GetOrInitializeState(session); + state.Messages = messages; + } /// - /// Gets the event that triggers the reducer invocation in this provider. + /// Gets the state from the session's StateBag, or initializes it using the state initializer if not present. /// - public ChatReducerTriggerEvent ReducerTriggerEvent { get; } - - /// - public int Count => this._messages.Count; + /// The agent session containing the StateBag. + /// The provider state, or null if no session is available. + private State GetOrInitializeState(AgentSession? session) + { + if (session?.StateBag.TryGetValue(this._stateKey, out var state, this._jsonSerializerOptions) is true && state is not null) + { + return state; + } - /// - public bool IsReadOnly => ((IList)this._messages).IsReadOnly; + state = this._stateInitializer(session); + if (session is not null) + { + session.StateBag.SetValue(this._stateKey, state, this._jsonSerializerOptions); + } - /// - public ChatMessage this[int index] - { - get => this._messages[index]; - set => this._messages[index] = value; + return state; } /// @@ -137,12 +115,21 @@ protected override async ValueTask> InvokingCoreAsync(I { _ = Throw.IfNull(context); - if (this.ReducerTriggerEvent is ChatReducerTriggerEvent.BeforeMessagesRetrieval && this.ChatReducer is not null) + var state = this.GetOrInitializeState(context.Session); + + if (this.ReducerTriggerEvent is InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.BeforeMessagesRetrieval && this.ChatReducer is not null) { - this._messages = (await this.ChatReducer.ReduceAsync(this._messages, cancellationToken).ConfigureAwait(false)).ToList(); + state.Messages = (await this.ChatReducer.ReduceAsync(state.Messages, cancellationToken).ConfigureAwait(false)).ToList(); } - return this._messages; + IEnumerable output = state.Messages; + if (this._retrievalOutputMessageFilter is not null) + { + output = this._retrievalOutputMessageFilter(output); + } + return output + .Select(message => message.WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, this.GetType().FullName!)) + .Concat(context.RequestMessages); } /// @@ -155,94 +142,27 @@ protected override async ValueTask InvokedCoreAsync(InvokedContext context, Canc return; } + var state = this.GetOrInitializeState(context.Session); + // Add request and response messages to the provider - var allNewMessages = context.RequestMessages.Concat(context.ResponseMessages ?? []); - this._messages.AddRange(allNewMessages); + var allNewMessages = this._storageInputMessageFilter(context.RequestMessages).Concat(context.ResponseMessages ?? []); + state.Messages.AddRange(allNewMessages); - if (this.ReducerTriggerEvent is ChatReducerTriggerEvent.AfterMessageAdded && this.ChatReducer is not null) + if (this.ReducerTriggerEvent is InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.AfterMessageAdded && this.ChatReducer is not null) { - this._messages = (await this.ChatReducer.ReduceAsync(this._messages, cancellationToken).ConfigureAwait(false)).ToList(); + state.Messages = (await this.ChatReducer.ReduceAsync(state.Messages, cancellationToken).ConfigureAwait(false)).ToList(); } } - /// - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - State state = new() - { - Messages = this._messages, - }; - - var jso = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; - return JsonSerializer.SerializeToElement(state, jso.GetTypeInfo(typeof(State))); - } - - /// - public int IndexOf(ChatMessage item) - => this._messages.IndexOf(item); - - /// - public void Insert(int index, ChatMessage item) - => this._messages.Insert(index, item); - - /// - public void RemoveAt(int index) - => this._messages.RemoveAt(index); - - /// - public void Add(ChatMessage item) - => this._messages.Add(item); - - /// - public void Clear() - => this._messages.Clear(); - - /// - public bool Contains(ChatMessage item) - => this._messages.Contains(item); - - /// - public void CopyTo(ChatMessage[] array, int arrayIndex) - => this._messages.CopyTo(array, arrayIndex); - - /// - public bool Remove(ChatMessage item) - => this._messages.Remove(item); - - /// - public IEnumerator GetEnumerator() - => this._messages.GetEnumerator(); - - /// - IEnumerator IEnumerable.GetEnumerator() - => this.GetEnumerator(); - - internal sealed class State - { - public List Messages { get; set; } = []; - } - /// - /// Defines the events that can trigger a reducer in the . + /// Represents the state of a stored in the . /// - public enum ChatReducerTriggerEvent + public sealed class State { /// - /// Trigger the reducer when a new message is added. - /// will only complete when reducer processing is done. - /// - AfterMessageAdded, - - /// - /// Trigger the reducer before messages are retrieved from the provider. - /// The reducer will process the messages before they are returned to the caller. + /// Gets or sets the list of chat messages. /// - BeforeMessagesRetrieval - } - - private sealed class DebugView(InMemoryChatHistoryProvider provider) - { - [DebuggerBrowsable(DebuggerBrowsableState.RootHidden)] - public ChatMessage[] Items => provider._messages.ToArray(); + [JsonPropertyName("messages")] + public List Messages { get; set; } = []; } } diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProviderOptions.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProviderOptions.cs new file mode 100644 index 0000000000..41ab46321f --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/InMemoryChatHistoryProviderOptions.cs @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Text.Json; +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI; + +/// +/// Represents configuration options for . +/// +public sealed class InMemoryChatHistoryProviderOptions +{ + /// + /// Gets or sets an optional delegate that initializes the provider state on the first invocation. + /// If , a default initializer that creates an empty state will be used. + /// + public Func? StateInitializer { get; set; } + + /// + /// Gets or sets an optional instance used to process, reduce, or optimize chat messages. + /// This can be used to implement strategies like message summarization, truncation, or cleanup. + /// + public IChatReducer? ChatReducer { get; set; } + + /// + /// Gets or sets when the message reducer should be invoked. + /// The default is , + /// which applies reduction logic when messages are retrieved for agent consumption. + /// + /// + /// Message reducers enable automatic management of message storage by implementing strategies to + /// keep memory usage under control while preserving important conversation context. + /// + public ChatReducerTriggerEvent ReducerTriggerEvent { get; set; } = ChatReducerTriggerEvent.BeforeMessagesRetrieval; + + /// + /// Gets or sets an optional key to use for storing the state in the . + /// If , a default key will be used. + /// + public string? StateKey { get; set; } + + /// + /// Gets or sets optional JSON serializer options for serializing the state of this provider. + /// This is valuable for cases like when the chat history contains custom types + /// and source generated serializers are required, or Native AOT / Trimming is required. + /// + public JsonSerializerOptions? JsonSerializerOptions { get; set; } + + /// + /// Gets or sets an optional filter function applied to request messages before they are added to storage + /// during . + /// + /// + /// When , the provider defaults to excluding messages with + /// source type to avoid + /// storing messages that came from chat history in the first place. + /// Depending on your requirements, you could provide a different filter, that also excludes + /// messages from e.g. AI context providers. + /// + public Func, IEnumerable>? StorageInputMessageFilter { get; set; } + + /// + /// Gets or sets an optional filter function applied to messages produced by this provider + /// during . + /// + /// + /// This filter is only applied to the messages that the provider itself produces (from its internal storage). + /// + /// + /// When , no filtering is applied to the output messages. + /// + public Func, IEnumerable>? RetrievalOutputMessageFilter { get; set; } + + /// + /// Defines the events that can trigger a reducer in the . + /// + public enum ChatReducerTriggerEvent + { + /// + /// Trigger the reducer when a new message is added. + /// will only complete when reducer processing is done. + /// + AfterMessageAdded, + + /// + /// Trigger the reducer before messages are retrieved from the provider. + /// The reducer will process the messages before they are returned to the caller. + /// + BeforeMessagesRetrieval + } +} diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ServiceIdAgentSession.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ServiceIdAgentSession.cs deleted file mode 100644 index cf00635984..0000000000 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ServiceIdAgentSession.cs +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Diagnostics; -using System.Text.Json; -using Microsoft.Shared.Diagnostics; - -namespace Microsoft.Agents.AI; - -/// -/// Provides a base class for agent sessions that store conversation state remotely in a service and maintain only an identifier reference locally. -/// -/// -/// This class is designed for scenarios where conversation state is managed by an external service (such as a cloud-based AI service) -/// rather than being stored locally. The session maintains only the service identifier needed to reference the remote conversation state. -/// -[DebuggerDisplay("ServiceSessionId = {ServiceSessionId}")] -public abstract class ServiceIdAgentSession : AgentSession -{ - /// - /// Initializes a new instance of the class without a service session identifier. - /// - /// - /// When using this constructor, the will be initially - /// and should be set by derived classes when the remote conversation is created. - /// - protected ServiceIdAgentSession() - { - } - - /// - /// Initializes a new instance of the class with the specified service session identifier. - /// - /// The unique identifier that references the conversation state stored in the remote service. - /// is . - /// is empty or contains only whitespace. - protected ServiceIdAgentSession(string serviceSessionId) - { - this.ServiceSessionId = Throw.IfNullOrEmpty(serviceSessionId); - } - - /// - /// Initializes a new instance of the class from previously serialized state. - /// - /// A representing the serialized state of the session. - /// Optional settings for customizing the JSON deserialization process. - /// The is not a JSON object. - /// The is invalid or cannot be deserialized to the expected type. - /// - /// This constructor enables restoration of a service-backed session from serialized state, typically used - /// when deserializing session information that was previously saved or transmitted across application boundaries. - /// - protected ServiceIdAgentSession( - JsonElement serializedState, - JsonSerializerOptions? jsonSerializerOptions = null) - { - if (serializedState.ValueKind != JsonValueKind.Object) - { - throw new ArgumentException("The serialized session state must be a JSON object.", nameof(serializedState)); - } - - var state = serializedState.Deserialize( - AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ServiceIdAgentSessionState))) as ServiceIdAgentSessionState; - - if (state?.ServiceSessionId is string serviceSessionId) - { - this.ServiceSessionId = serviceSessionId; - } - } - - /// - /// Gets or sets the unique identifier that references the conversation state stored in the remote service. - /// - /// - /// A string identifier that uniquely identifies the conversation within the remote service, - /// or if no remote conversation has been established yet. - /// - /// - /// This identifier is used by derived classes to reference the remote conversation state when making - /// API calls to the backing service. The exact format and meaning of this identifier depends on the - /// specific service implementation. - /// - protected string? ServiceSessionId { get; set; } - - /// - /// Serializes the current object's state to a using the specified serialization options. - /// - /// The JSON serialization options to use. - /// A representation of the object's state. - protected internal virtual JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - var state = new ServiceIdAgentSessionState - { - ServiceSessionId = this.ServiceSessionId, - }; - - return JsonSerializer.SerializeToElement(state, AgentAbstractionsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ServiceIdAgentSessionState))); - } - - internal sealed class ServiceIdAgentSessionState - { - public string? ServiceSessionId { get; set; } - } -} diff --git a/dotnet/src/Microsoft.Agents.AI.AzureAI.Persistent/PersistentAgentsClientExtensions.cs b/dotnet/src/Microsoft.Agents.AI.AzureAI.Persistent/PersistentAgentsClientExtensions.cs index 2058e6760b..660e874711 100644 --- a/dotnet/src/Microsoft.Agents.AI.AzureAI.Persistent/PersistentAgentsClientExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.AzureAI.Persistent/PersistentAgentsClientExtensions.cs @@ -191,8 +191,8 @@ public static ChatClientAgent AsAIAgent( Name = options.Name ?? persistentAgentMetadata.Name, Description = options.Description ?? persistentAgentMetadata.Description, ChatOptions = options.ChatOptions, - AIContextProviderFactory = options.AIContextProviderFactory, - ChatHistoryProviderFactory = options.ChatHistoryProviderFactory, + AIContextProviders = options.AIContextProviders, + ChatHistoryProvider = options.ChatHistoryProvider, UseProvidedChatClientAsIs = options.UseProvidedChatClientAsIs }; diff --git a/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClientExtensions.cs b/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClientExtensions.cs index 8433ff3b6f..c35f49b088 100644 --- a/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClientExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.AzureAI/AzureAIProjectChatClientExtensions.cs @@ -594,8 +594,8 @@ private static ChatClientAgentOptions CreateChatClientAgentOptions(AgentVersion var agentOptions = CreateChatClientAgentOptions(agentVersion, options?.ChatOptions, requireInvocableTools); if (options is not null) { - agentOptions.AIContextProviderFactory = options.AIContextProviderFactory; - agentOptions.ChatHistoryProviderFactory = options.ChatHistoryProviderFactory; + agentOptions.AIContextProviders = options.AIContextProviders; + agentOptions.ChatHistoryProvider = options.ChatHistoryProvider; agentOptions.UseProvidedChatClientAsIs = options.UseProvidedChatClientAsIs; } diff --git a/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs index c631f90f17..bb53f94b50 100644 --- a/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgent.cs @@ -68,7 +68,7 @@ protected override ValueTask SerializeSessionCoreAsync(AgentSession /// protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) - => new(new CopilotStudioAgentSession(serializedState, jsonSerializerOptions)); + => new(CopilotStudioAgentSession.Deserialize(serializedState, jsonSerializerOptions)); /// protected override async Task RunCoreAsync( diff --git a/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgentSession.cs b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgentSession.cs index ec3e21ca91..ba101df082 100644 --- a/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioAgentSession.cs @@ -1,36 +1,58 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.Diagnostics; using System.Text.Json; +using System.Text.Json.Serialization; namespace Microsoft.Agents.AI.CopilotStudio; /// /// Session for CopilotStudio based agents. /// -public sealed class CopilotStudioAgentSession : ServiceIdAgentSession +[DebuggerDisplay("{DebuggerDisplay,nq}")] +public sealed class CopilotStudioAgentSession : AgentSession { internal CopilotStudioAgentSession() { } - internal CopilotStudioAgentSession(JsonElement serializedSessionState, JsonSerializerOptions? jsonSerializerOptions = null) : base(serializedSessionState, jsonSerializerOptions) + [JsonConstructor] + internal CopilotStudioAgentSession(string? conversationId, AgentSessionStateBag? stateBag) : base(stateBag ?? new()) { + this.ConversationId = conversationId; } /// /// Gets the ID for the current conversation with the Copilot Studio agent. /// - public string? ConversationId - { - get { return this.ServiceSessionId; } - internal set { this.ServiceSessionId = value; } - } + [JsonPropertyName("serviceSessionId")] + public string? ConversationId { get; internal set; } /// /// Serializes the current object's state to a using the specified serialization options. /// /// The JSON serialization options to use. /// A representation of the object's state. - internal new JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => base.Serialize(jsonSerializerOptions); + internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) + { + var jso = jsonSerializerOptions ?? CopilotStudioJsonUtilities.DefaultOptions; + return JsonSerializer.SerializeToElement(this, jso.GetTypeInfo(typeof(CopilotStudioAgentSession))); + } + + internal static CopilotStudioAgentSession Deserialize(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null) + { + if (serializedState.ValueKind != JsonValueKind.Object) + { + throw new ArgumentException("The serialized session state must be a JSON object.", nameof(serializedState)); + } + + var jso = jsonSerializerOptions ?? CopilotStudioJsonUtilities.DefaultOptions; + return serializedState.Deserialize(jso.GetTypeInfo(typeof(CopilotStudioAgentSession))) as CopilotStudioAgentSession + ?? new CopilotStudioAgentSession(); + } + + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private string DebuggerDisplay => + $"ConversationId = {this.ConversationId}, StateBag Count = {this.StateBag.Count}"; } diff --git a/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioJsonUtilities.cs new file mode 100644 index 0000000000..44177b0708 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI.CopilotStudio/CopilotStudioJsonUtilities.cs @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Diagnostics.CodeAnalysis; +using System.Text.Encodings.Web; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Microsoft.Agents.AI.CopilotStudio; + +/// +/// Provides utility methods and configurations for JSON serialization operations within the Copilot Studio agent implementation. +/// +internal static partial class CopilotStudioJsonUtilities +{ + /// + /// Gets the default instance used for JSON serialization operations. + /// + public static JsonSerializerOptions DefaultOptions { get; } = CreateDefaultOptions(); + + /// + /// Creates and configures the default JSON serialization options. + /// + /// The configured options. + private static JsonSerializerOptions CreateDefaultOptions() + { + // Copy the configuration from the source generated context. + JsonSerializerOptions options = new(JsonContext.Default.Options) + { + Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping, + }; + + // Chain in the resolvers from both AgentAbstractionsJsonUtilities and our source generated context. + options.TypeInfoResolverChain.Clear(); + options.TypeInfoResolverChain.Add(AgentAbstractionsJsonUtilities.DefaultOptions.TypeInfoResolver!); + options.TypeInfoResolverChain.Add(JsonContext.Default.Options.TypeInfoResolver!); + + options.MakeReadOnly(); + return options; + } + + [JsonSourceGenerationOptions(JsonSerializerDefaults.Web, + UseStringEnumConverter = true, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + NumberHandling = JsonNumberHandling.AllowReadingFromString)] + [JsonSerializable(typeof(CopilotStudioAgentSession))] + [ExcludeFromCodeCoverage] + private sealed partial class JsonContext : JsonSerializerContext; +} diff --git a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs index 85c5865f07..265f3a3675 100644 --- a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosChatHistoryProvider.cs @@ -21,17 +21,16 @@ namespace Microsoft.Agents.AI; [RequiresDynamicCode("The CosmosChatHistoryProvider uses JSON serialization which is incompatible with NativeAOT.")] public sealed class CosmosChatHistoryProvider : ChatHistoryProvider, IDisposable { + private static IEnumerable DefaultExcludeChatHistoryFilter(IEnumerable messages) + => messages.Where(m => m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.ChatHistory); + private readonly CosmosClient _cosmosClient; private readonly Container _container; private readonly bool _ownsClient; + private readonly string _stateKey; + private readonly Func _stateInitializer; private bool _disposed; - // Hierarchical partition key support - private readonly string? _tenantId; - private readonly string? _userId; - private readonly PartitionKey _partitionKey; - private readonly bool _useHierarchicalPartitioning; - /// /// Cached JSON serializer options for .NET 9.0 compatibility. /// @@ -47,6 +46,9 @@ private static JsonSerializerOptions CreateDefaultJsonOptions() return options; } + /// + public override string StateKey => this._stateKey; + /// /// Gets or sets the maximum number of messages to return in a single query batch. /// Default is 100 for optimal performance. @@ -72,11 +74,6 @@ private static JsonSerializerOptions CreateDefaultJsonOptions() /// public int? MessageTtlSeconds { get; set; } = 86400; - /// - /// Gets the conversation ID associated with this provider. - /// - public string ConversationId { get; init; } - /// /// Gets the database ID associated with this provider. /// @@ -88,49 +85,50 @@ private static JsonSerializerOptions CreateDefaultJsonOptions() public string ContainerId { get; init; } /// - /// Internal primary constructor used by all public constructors. + /// A filter function applied to request messages before they are stored + /// during . The default filter excludes messages with the + /// source type. /// - /// The instance to use for Cosmos DB operations. - /// The identifier of the Cosmos DB database. - /// The identifier of the Cosmos DB container. - /// The unique identifier for this conversation thread. - /// Whether this instance owns the CosmosClient and should dispose it. - /// Optional tenant identifier for hierarchical partitioning. - /// Optional user identifier for hierarchical partitioning. - internal CosmosChatHistoryProvider(CosmosClient cosmosClient, string databaseId, string containerId, string conversationId, bool ownsClient, string? tenantId = null, string? userId = null) - { - this._cosmosClient = Throw.IfNull(cosmosClient); - this._container = this._cosmosClient.GetContainer(Throw.IfNullOrWhitespace(databaseId), Throw.IfNullOrWhitespace(containerId)); - this.ConversationId = Throw.IfNullOrWhitespace(conversationId); - this.DatabaseId = databaseId; - this.ContainerId = containerId; - this._ownsClient = ownsClient; + public Func, IEnumerable> StorageInputMessageFilter { get; set { field = Throw.IfNull(value); } } = DefaultExcludeChatHistoryFilter; - // Initialize partitioning mode - this._tenantId = tenantId; - this._userId = userId; - this._useHierarchicalPartitioning = tenantId != null && userId != null; - - this._partitionKey = this._useHierarchicalPartitioning - ? new PartitionKeyBuilder() - .Add(tenantId!) - .Add(userId!) - .Add(conversationId) - .Build() - : new PartitionKey(conversationId); - } + /// + /// Gets or sets an optional filter function applied to messages produced by this provider + /// during . + /// + /// + /// This filter is only applied to the messages that the provider itself produces (from its internal storage). + /// + /// + /// When , no filtering is applied to the output messages. + /// + public Func, IEnumerable>? RetrievalOutputMessageFilter { get; set; } /// - /// Initializes a new instance of the class using a connection string. + /// Initializes a new instance of the class. /// - /// The Cosmos DB connection string. + /// The instance to use for Cosmos DB operations. /// The identifier of the Cosmos DB database. /// The identifier of the Cosmos DB container. - /// Thrown when any required parameter is null. + /// A delegate that initializes the provider state on the first invocation, providing the conversation routing info (conversationId, tenantId, userId). + /// Whether this instance owns the CosmosClient and should dispose it. + /// An optional key to use for storing the state in the . + /// Thrown when or is . /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(string connectionString, string databaseId, string containerId) - : this(connectionString, databaseId, containerId, Guid.NewGuid().ToString("N")) + public CosmosChatHistoryProvider( + CosmosClient cosmosClient, + string databaseId, + string containerId, + Func stateInitializer, + bool ownsClient = false, + string? stateKey = null) { + this._cosmosClient = Throw.IfNull(cosmosClient); + this.DatabaseId = Throw.IfNullOrWhitespace(databaseId); + this.ContainerId = Throw.IfNullOrWhitespace(containerId); + this._container = this._cosmosClient.GetContainer(databaseId, containerId); + this._stateInitializer = Throw.IfNull(stateInitializer); + this._ownsClient = ownsClient; + this._stateKey = stateKey ?? base.StateKey; } /// @@ -139,11 +137,17 @@ public CosmosChatHistoryProvider(string connectionString, string databaseId, str /// The Cosmos DB connection string. /// The identifier of the Cosmos DB database. /// The identifier of the Cosmos DB container. - /// The unique identifier for this conversation thread. + /// A delegate that initializes the provider state on the first invocation. + /// An optional key to use for storing the state in the . /// Thrown when any required parameter is null. /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(string connectionString, string databaseId, string containerId, string conversationId) - : this(new CosmosClient(Throw.IfNullOrWhitespace(connectionString)), databaseId, containerId, conversationId, ownsClient: true) + public CosmosChatHistoryProvider( + string connectionString, + string databaseId, + string containerId, + Func stateInitializer, + string? stateKey = null) + : this(new CosmosClient(Throw.IfNullOrWhitespace(connectionString)), databaseId, containerId, stateInitializer, ownsClient: true, stateKey) { } @@ -154,136 +158,63 @@ public CosmosChatHistoryProvider(string connectionString, string databaseId, str /// The TokenCredential to use for authentication (e.g., DefaultAzureCredential, ManagedIdentityCredential). /// The identifier of the Cosmos DB database. /// The identifier of the Cosmos DB container. + /// A delegate that initializes the provider state on the first invocation. + /// An optional key to use for storing the state in the . /// Thrown when any required parameter is null. /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(string accountEndpoint, TokenCredential tokenCredential, string databaseId, string containerId) - : this(accountEndpoint, tokenCredential, databaseId, containerId, Guid.NewGuid().ToString("N")) + public CosmosChatHistoryProvider( + string accountEndpoint, + TokenCredential tokenCredential, + string databaseId, + string containerId, + Func stateInitializer, + string? stateKey = null) + : this(new CosmosClient(Throw.IfNullOrWhitespace(accountEndpoint), Throw.IfNull(tokenCredential)), databaseId, containerId, stateInitializer, ownsClient: true, stateKey) { } /// - /// Initializes a new instance of the class using a TokenCredential for authentication. + /// Gets the state from the session's StateBag, or initializes it using the state initializer if not present. /// - /// The Cosmos DB account endpoint URI. - /// The TokenCredential to use for authentication (e.g., DefaultAzureCredential, ManagedIdentityCredential). - /// The identifier of the Cosmos DB database. - /// The identifier of the Cosmos DB container. - /// The unique identifier for this conversation thread. - /// Thrown when any required parameter is null. - /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(string accountEndpoint, TokenCredential tokenCredential, string databaseId, string containerId, string conversationId) - : this(new CosmosClient(Throw.IfNullOrWhitespace(accountEndpoint), Throw.IfNull(tokenCredential)), databaseId, containerId, conversationId, ownsClient: true) + /// The agent session containing the StateBag. + /// The provider state, or null if no session is available. + private State GetOrInitializeState(AgentSession? session) { - } - - /// - /// Initializes a new instance of the class using an existing . - /// - /// The instance to use for Cosmos DB operations. - /// The identifier of the Cosmos DB database. - /// The identifier of the Cosmos DB container. - /// Thrown when is null. - /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(CosmosClient cosmosClient, string databaseId, string containerId) - : this(cosmosClient, databaseId, containerId, Guid.NewGuid().ToString("N")) - { - } - - /// - /// Initializes a new instance of the class using an existing . - /// - /// The instance to use for Cosmos DB operations. - /// The identifier of the Cosmos DB database. - /// The identifier of the Cosmos DB container. - /// The unique identifier for this conversation thread. - /// Thrown when is null. - /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(CosmosClient cosmosClient, string databaseId, string containerId, string conversationId) - : this(cosmosClient, databaseId, containerId, conversationId, ownsClient: false) - { - } + if (session?.StateBag.TryGetValue(this._stateKey, out var state, AgentAbstractionsJsonUtilities.DefaultOptions) is true && state is not null) + { + return state; + } - /// - /// Initializes a new instance of the class using a connection string with hierarchical partition keys. - /// - /// The Cosmos DB connection string. - /// The identifier of the Cosmos DB database. - /// The identifier of the Cosmos DB container. - /// The tenant identifier for hierarchical partitioning. - /// The user identifier for hierarchical partitioning. - /// The session identifier for hierarchical partitioning. - /// Thrown when any required parameter is null. - /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(string connectionString, string databaseId, string containerId, string tenantId, string userId, string sessionId) - : this(new CosmosClient(Throw.IfNullOrWhitespace(connectionString)), databaseId, containerId, Throw.IfNullOrWhitespace(sessionId), ownsClient: true, Throw.IfNullOrWhitespace(tenantId), Throw.IfNullOrWhitespace(userId)) - { - } + state = this._stateInitializer(session); + if (session is not null) + { + session.StateBag.SetValue(this._stateKey, state, AgentAbstractionsJsonUtilities.DefaultOptions); + } - /// - /// Initializes a new instance of the class using a TokenCredential for authentication with hierarchical partition keys. - /// - /// The Cosmos DB account endpoint URI. - /// The TokenCredential to use for authentication (e.g., DefaultAzureCredential, ManagedIdentityCredential). - /// The identifier of the Cosmos DB database. - /// The identifier of the Cosmos DB container. - /// The tenant identifier for hierarchical partitioning. - /// The user identifier for hierarchical partitioning. - /// The session identifier for hierarchical partitioning. - /// Thrown when any required parameter is null. - /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(string accountEndpoint, TokenCredential tokenCredential, string databaseId, string containerId, string tenantId, string userId, string sessionId) - : this(new CosmosClient(Throw.IfNullOrWhitespace(accountEndpoint), Throw.IfNull(tokenCredential)), databaseId, containerId, Throw.IfNullOrWhitespace(sessionId), ownsClient: true, Throw.IfNullOrWhitespace(tenantId), Throw.IfNullOrWhitespace(userId)) - { + return state; } /// - /// Initializes a new instance of the class using an existing with hierarchical partition keys. + /// Determines whether hierarchical partitioning should be used based on the state. /// - /// The instance to use for Cosmos DB operations. - /// The identifier of the Cosmos DB database. - /// The identifier of the Cosmos DB container. - /// The tenant identifier for hierarchical partitioning. - /// The user identifier for hierarchical partitioning. - /// The session identifier for hierarchical partitioning. - /// Thrown when is null. - /// Thrown when any string parameter is null or whitespace. - public CosmosChatHistoryProvider(CosmosClient cosmosClient, string databaseId, string containerId, string tenantId, string userId, string sessionId) - : this(cosmosClient, databaseId, containerId, Throw.IfNullOrWhitespace(sessionId), ownsClient: false, Throw.IfNullOrWhitespace(tenantId), Throw.IfNullOrWhitespace(userId)) - { - } + private static bool UseHierarchicalPartitioning(State state) => + state.TenantId is not null && state.UserId is not null; /// - /// Creates a new instance of the class from previously serialized state. + /// Builds the partition key from the state. /// - /// The instance to use for Cosmos DB operations. - /// A representing the serialized state of the provider. - /// The identifier of the Cosmos DB database. - /// The identifier of the Cosmos DB container. - /// Optional settings for customizing the JSON deserialization process. - /// A new instance of initialized from the serialized state. - /// Thrown when is null. - /// Thrown when the serialized state cannot be deserialized. - public static CosmosChatHistoryProvider CreateFromSerializedState(CosmosClient cosmosClient, JsonElement serializedState, string databaseId, string containerId, JsonSerializerOptions? jsonSerializerOptions = null) + private static PartitionKey BuildPartitionKey(State state) { - Throw.IfNull(cosmosClient); - Throw.IfNullOrWhitespace(databaseId); - Throw.IfNullOrWhitespace(containerId); - - if (serializedState.ValueKind is not JsonValueKind.Object) - { - throw new ArgumentException("Invalid serialized state", nameof(serializedState)); - } - - var state = serializedState.Deserialize(jsonSerializerOptions); - if (state?.ConversationIdentifier is not { } conversationId) + if (UseHierarchicalPartitioning(state)) { - throw new ArgumentException("Invalid serialized state", nameof(serializedState)); + return new PartitionKeyBuilder() + .Add(state.TenantId) + .Add(state.UserId) + .Add(state.ConversationId) + .Build(); } - // Use the internal constructor with all parameters to ensure partition key logic is centralized - return state.UseHierarchicalPartitioning && state.TenantId != null && state.UserId != null - ? new CosmosChatHistoryProvider(cosmosClient, databaseId, containerId, conversationId, ownsClient: false, state.TenantId, state.UserId) - : new CosmosChatHistoryProvider(cosmosClient, databaseId, containerId, conversationId, ownsClient: false); + return new PartitionKey(state.ConversationId); } /// @@ -296,15 +227,20 @@ protected override async ValueTask> InvokingCoreAsync(I } #pragma warning restore CA1513 + _ = Throw.IfNull(context); + + var state = this.GetOrInitializeState(context.Session); + var partitionKey = BuildPartitionKey(state); + // Fetch most recent messages in descending order when limit is set, then reverse to ascending var orderDirection = this.MaxMessagesToRetrieve.HasValue ? "DESC" : "ASC"; var query = new QueryDefinition($"SELECT * FROM c WHERE c.conversationId = @conversationId AND c.type = @type ORDER BY c.timestamp {orderDirection}") - .WithParameter("@conversationId", this.ConversationId) + .WithParameter("@conversationId", state.ConversationId) .WithParameter("@type", "ChatMessage"); var iterator = this._container.GetItemQueryIterator(query, requestOptions: new QueryRequestOptions { - PartitionKey = this._partitionKey, + PartitionKey = partitionKey, MaxItemCount = this.MaxItemCount // Configurable query performance }); @@ -343,7 +279,9 @@ protected override async ValueTask> InvokingCoreAsync(I messages.Reverse(); } - return messages; + return (this.RetrievalOutputMessageFilter is not null ? this.RetrievalOutputMessageFilter(messages) : messages) + .Select(message => message.WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, this.GetType().FullName!)) + .Concat(context.RequestMessages); } /// @@ -364,27 +302,30 @@ protected override async ValueTask InvokedCoreAsync(InvokedContext context, Canc } #pragma warning restore CA1513 - var messageList = context.RequestMessages.Concat(context.ResponseMessages ?? []).ToList(); + var state = this.GetOrInitializeState(context.Session); + var messageList = this.StorageInputMessageFilter(context.RequestMessages).Concat(context.ResponseMessages ?? []).ToList(); if (messageList.Count == 0) { return; } + var partitionKey = BuildPartitionKey(state); + // Use transactional batch for atomic operations if (messageList.Count > 1) { - await this.AddMessagesInBatchAsync(messageList, cancellationToken).ConfigureAwait(false); + await this.AddMessagesInBatchAsync(partitionKey, state, messageList, cancellationToken).ConfigureAwait(false); } else { - await this.AddSingleMessageAsync(messageList.First(), cancellationToken).ConfigureAwait(false); + await this.AddSingleMessageAsync(partitionKey, state, messageList.First(), cancellationToken).ConfigureAwait(false); } } /// /// Adds multiple messages using transactional batch operations for atomicity. /// - private async Task AddMessagesInBatchAsync(List messages, CancellationToken cancellationToken) + private async Task AddMessagesInBatchAsync(PartitionKey partitionKey, State state, List messages, CancellationToken cancellationToken) { var currentTimestamp = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); @@ -392,7 +333,7 @@ private async Task AddMessagesInBatchAsync(List messages, Cancellat for (int i = 0; i < messages.Count; i += this.MaxBatchSize) { var batchMessages = messages.Skip(i).Take(this.MaxBatchSize).ToList(); - await this.ExecuteBatchOperationAsync(batchMessages, currentTimestamp, cancellationToken).ConfigureAwait(false); + await this.ExecuteBatchOperationAsync(partitionKey, state, batchMessages, currentTimestamp, cancellationToken).ConfigureAwait(false); } } @@ -400,13 +341,13 @@ private async Task AddMessagesInBatchAsync(List messages, Cancellat /// Executes a single batch operation with enhanced error handling. /// Cosmos SDK handles throttling (429) retries automatically. /// - private async Task ExecuteBatchOperationAsync(List messages, long timestamp, CancellationToken cancellationToken) + private async Task ExecuteBatchOperationAsync(PartitionKey partitionKey, State state, List messages, long timestamp, CancellationToken cancellationToken) { // Create all documents upfront for validation and batch operation var documents = new List(messages.Count); foreach (var message in messages) { - documents.Add(this.CreateMessageDocument(message, timestamp)); + documents.Add(this.CreateMessageDocument(state, message, timestamp)); } // Defensive check: Verify all messages share the same partition key values @@ -414,7 +355,7 @@ private async Task ExecuteBatchOperationAsync(List messages, long t // In simple partitioning, this means same conversationId if (documents.Count > 0) { - if (this._useHierarchicalPartitioning) + if (UseHierarchicalPartitioning(state)) { // Verify all documents have matching hierarchical partition key components var firstDoc = documents[0]; @@ -436,7 +377,7 @@ private async Task ExecuteBatchOperationAsync(List messages, long t // All messages in this store share the same partition key by design // Transactional batches require all items to share the same partition key - var batch = this._container.CreateTransactionalBatch(this._partitionKey); + var batch = this._container.CreateTransactionalBatch(partitionKey); foreach (var document in documents) { @@ -457,7 +398,7 @@ private async Task ExecuteBatchOperationAsync(List messages, long t if (messages.Count == 1) { // Can't split further, use single operation - await this.AddSingleMessageAsync(messages[0], cancellationToken).ConfigureAwait(false); + await this.AddSingleMessageAsync(partitionKey, state, messages[0], cancellationToken).ConfigureAwait(false); return; } @@ -466,21 +407,21 @@ private async Task ExecuteBatchOperationAsync(List messages, long t var firstHalf = messages.Take(midpoint).ToList(); var secondHalf = messages.Skip(midpoint).ToList(); - await this.ExecuteBatchOperationAsync(firstHalf, timestamp, cancellationToken).ConfigureAwait(false); - await this.ExecuteBatchOperationAsync(secondHalf, timestamp, cancellationToken).ConfigureAwait(false); + await this.ExecuteBatchOperationAsync(partitionKey, state, firstHalf, timestamp, cancellationToken).ConfigureAwait(false); + await this.ExecuteBatchOperationAsync(partitionKey, state, secondHalf, timestamp, cancellationToken).ConfigureAwait(false); } } /// /// Adds a single message to the store. /// - private async Task AddSingleMessageAsync(ChatMessage message, CancellationToken cancellationToken) + private async Task AddSingleMessageAsync(PartitionKey partitionKey, State state, ChatMessage message, CancellationToken cancellationToken) { - var document = this.CreateMessageDocument(message, DateTimeOffset.UtcNow.ToUnixTimeSeconds()); + var document = this.CreateMessageDocument(state, message, DateTimeOffset.UtcNow.ToUnixTimeSeconds()); try { - await this._container.CreateItemAsync(document, this._partitionKey, cancellationToken: cancellationToken).ConfigureAwait(false); + await this._container.CreateItemAsync(document, partitionKey, cancellationToken: cancellationToken).ConfigureAwait(false); } catch (CosmosException ex) when (ex.StatusCode == System.Net.HttpStatusCode.RequestEntityTooLarge) { @@ -495,12 +436,14 @@ private async Task AddSingleMessageAsync(ChatMessage message, CancellationToken /// /// Creates a message document with enhanced metadata. /// - private CosmosMessageDocument CreateMessageDocument(ChatMessage message, long timestamp) + private CosmosMessageDocument CreateMessageDocument(State state, ChatMessage message, long timestamp) { + var useHierarchical = UseHierarchicalPartitioning(state); + return new CosmosMessageDocument { Id = Guid.NewGuid().ToString(), - ConversationId = this.ConversationId, + ConversationId = state.ConversationId, Timestamp = timestamp, MessageId = message.MessageId, Role = message.Role.Value, @@ -508,41 +451,20 @@ private CosmosMessageDocument CreateMessageDocument(ChatMessage message, long ti Type = "ChatMessage", // Type discriminator Ttl = this.MessageTtlSeconds, // Configurable TTL // Include hierarchical metadata when using hierarchical partitioning - TenantId = this._useHierarchicalPartitioning ? this._tenantId : null, - UserId = this._useHierarchicalPartitioning ? this._userId : null, - SessionId = this._useHierarchicalPartitioning ? this.ConversationId : null - }; - } - - /// - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { -#pragma warning disable CA1513 // Use ObjectDisposedException.ThrowIf - not available on all target frameworks - if (this._disposed) - { - throw new ObjectDisposedException(this.GetType().FullName); - } -#pragma warning restore CA1513 - - var state = new State - { - ConversationIdentifier = this.ConversationId, - TenantId = this._tenantId, - UserId = this._userId, - UseHierarchicalPartitioning = this._useHierarchicalPartitioning + TenantId = useHierarchical ? state.TenantId : null, + UserId = useHierarchical ? state.UserId : null, + SessionId = useHierarchical ? state.ConversationId : null }; - - var options = jsonSerializerOptions ?? s_defaultJsonOptions; - return JsonSerializer.SerializeToElement(state, options); } /// /// Gets the count of messages in this conversation. /// This is an additional utility method beyond the base contract. /// + /// The agent session to get state from. /// The cancellation token. /// The number of messages in the conversation. - public async Task GetMessageCountAsync(CancellationToken cancellationToken = default) + public async Task GetMessageCountAsync(AgentSession? session, CancellationToken cancellationToken = default) { #pragma warning disable CA1513 // Use ObjectDisposedException.ThrowIf - not available on all target frameworks if (this._disposed) @@ -551,14 +473,17 @@ public async Task GetMessageCountAsync(CancellationToken cancellationToken } #pragma warning restore CA1513 + var state = this.GetOrInitializeState(session); + var partitionKey = BuildPartitionKey(state); + // Efficient count query var query = new QueryDefinition("SELECT VALUE COUNT(1) FROM c WHERE c.conversationId = @conversationId AND c.Type = @type") - .WithParameter("@conversationId", this.ConversationId) + .WithParameter("@conversationId", state.ConversationId) .WithParameter("@type", "ChatMessage"); var iterator = this._container.GetItemQueryIterator(query, requestOptions: new QueryRequestOptions { - PartitionKey = this._partitionKey + PartitionKey = partitionKey }); // COUNT queries always return a result @@ -570,9 +495,10 @@ public async Task GetMessageCountAsync(CancellationToken cancellationToken /// Deletes all messages in this conversation. /// This is an additional utility method beyond the base contract. /// + /// The agent session to get state from. /// The cancellation token. /// The number of messages deleted. - public async Task ClearMessagesAsync(CancellationToken cancellationToken = default) + public async Task ClearMessagesAsync(AgentSession? session, CancellationToken cancellationToken = default) { #pragma warning disable CA1513 // Use ObjectDisposedException.ThrowIf - not available on all target frameworks if (this._disposed) @@ -581,14 +507,17 @@ public async Task ClearMessagesAsync(CancellationToken cancellationToken = } #pragma warning restore CA1513 + var state = this.GetOrInitializeState(session); + var partitionKey = BuildPartitionKey(state); + // Batch delete for efficiency var query = new QueryDefinition("SELECT VALUE c.id FROM c WHERE c.conversationId = @conversationId AND c.Type = @type") - .WithParameter("@conversationId", this.ConversationId) + .WithParameter("@conversationId", state.ConversationId) .WithParameter("@type", "ChatMessage"); var iterator = this._container.GetItemQueryIterator(query, requestOptions: new QueryRequestOptions { - PartitionKey = this._partitionKey, + PartitionKey = partitionKey, MaxItemCount = this.MaxItemCount }); @@ -597,7 +526,7 @@ public async Task ClearMessagesAsync(CancellationToken cancellationToken = while (iterator.HasMoreResults) { var response = await iterator.ReadNextAsync(cancellationToken).ConfigureAwait(false); - var batch = this._container.CreateTransactionalBatch(this._partitionKey); + var batch = this._container.CreateTransactionalBatch(partitionKey); var batchItemCount = 0; foreach (var itemId in response) @@ -632,12 +561,38 @@ public void Dispose() } } - private sealed class State + /// + /// Represents the per-session state of a stored in the . + /// + public sealed class State { - public string ConversationIdentifier { get; set; } = string.Empty; - public string? TenantId { get; set; } - public string? UserId { get; set; } - public bool UseHierarchicalPartitioning { get; set; } + /// + /// Initializes a new instance of the class. + /// + /// The unique identifier for this conversation thread. + /// Optional tenant identifier for hierarchical partitioning. + /// Optional user identifier for hierarchical partitioning. + public State(string conversationId, string? tenantId = null, string? userId = null) + { + this.ConversationId = Throw.IfNullOrWhitespace(conversationId); + this.TenantId = tenantId; + this.UserId = userId; + } + + /// + /// Gets the conversation ID associated with this state. + /// + public string ConversationId { get; } + + /// + /// Gets the tenant identifier for hierarchical partitioning, if any. + /// + public string? TenantId { get; } + + /// + /// Gets the user identifier for hierarchical partitioning, if any. + /// + public string? UserId { get; } } /// diff --git a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosDBChatExtensions.cs b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosDBChatExtensions.cs index 3d93e9dd6a..76b865e4c8 100644 --- a/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosDBChatExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.CosmosNoSql/CosmosDBChatExtensions.cs @@ -2,7 +2,6 @@ using System; using System.Diagnostics.CodeAnalysis; -using System.Threading.Tasks; using Azure.Core; using Microsoft.Azure.Cosmos; @@ -13,6 +12,9 @@ namespace Microsoft.Agents.AI; /// public static class CosmosDBChatExtensions { + private static readonly Func s_defaultStateInitializer = + _ => new CosmosChatHistoryProvider.State(Guid.NewGuid().ToString("N")); + /// /// Configures the agent to use Cosmos DB for message storage with connection string authentication. /// @@ -20,6 +22,7 @@ public static class CosmosDBChatExtensions /// The Cosmos DB connection string. /// The identifier of the Cosmos DB database. /// The identifier of the Cosmos DB container. + /// An optional delegate that initializes the provider state on the first invocation, providing the conversation routing info (conversationId, tenantId, userId). When not provided, a new conversation ID is generated automatically. /// The configured . /// Thrown when is null. /// Thrown when any string parameter is null or whitespace. @@ -29,14 +32,16 @@ public static ChatClientAgentOptions WithCosmosDBChatHistoryProvider( this ChatClientAgentOptions options, string connectionString, string databaseId, - string containerId) + string containerId, + Func? stateInitializer = null) { if (options is null) { throw new ArgumentNullException(nameof(options)); } - options.ChatHistoryProviderFactory = (context, ct) => new ValueTask(new CosmosChatHistoryProvider(connectionString, databaseId, containerId)); + options.ChatHistoryProvider = + new CosmosChatHistoryProvider(connectionString, databaseId, containerId, stateInitializer ?? s_defaultStateInitializer); return options; } @@ -48,6 +53,7 @@ public static ChatClientAgentOptions WithCosmosDBChatHistoryProvider( /// The identifier of the Cosmos DB database. /// The identifier of the Cosmos DB container. /// The TokenCredential to use for authentication (e.g., DefaultAzureCredential, ManagedIdentityCredential). + /// An optional delegate that initializes the provider state on the first invocation, providing the conversation routing info (conversationId, tenantId, userId). When not provided, a new conversation ID is generated automatically. /// The configured . /// Thrown when or is null. /// Thrown when any string parameter is null or whitespace. @@ -58,7 +64,8 @@ public static ChatClientAgentOptions WithCosmosDBChatHistoryProviderUsingManaged string accountEndpoint, string databaseId, string containerId, - TokenCredential tokenCredential) + TokenCredential tokenCredential, + Func? stateInitializer = null) { if (options is null) { @@ -70,7 +77,8 @@ public static ChatClientAgentOptions WithCosmosDBChatHistoryProviderUsingManaged throw new ArgumentNullException(nameof(tokenCredential)); } - options.ChatHistoryProviderFactory = (context, ct) => new ValueTask(new CosmosChatHistoryProvider(accountEndpoint, tokenCredential, databaseId, containerId)); + options.ChatHistoryProvider = + new CosmosChatHistoryProvider(accountEndpoint, tokenCredential, databaseId, containerId, stateInitializer ?? s_defaultStateInitializer); return options; } @@ -81,6 +89,7 @@ public static ChatClientAgentOptions WithCosmosDBChatHistoryProviderUsingManaged /// The instance to use for Cosmos DB operations. /// The identifier of the Cosmos DB database. /// The identifier of the Cosmos DB container. + /// An optional delegate that initializes the provider state on the first invocation, providing the conversation routing info (conversationId, tenantId, userId). When not provided, a new conversation ID is generated automatically. /// The configured . /// Thrown when any required parameter is null. /// Thrown when any string parameter is null or whitespace. @@ -90,14 +99,16 @@ public static ChatClientAgentOptions WithCosmosDBChatHistoryProvider( this ChatClientAgentOptions options, CosmosClient cosmosClient, string databaseId, - string containerId) + string containerId, + Func? stateInitializer = null) { if (options is null) { throw new ArgumentNullException(nameof(options)); } - options.ChatHistoryProviderFactory = (context, ct) => new ValueTask(new CosmosChatHistoryProvider(cosmosClient, databaseId, containerId)); + options.ChatHistoryProvider = + new CosmosChatHistoryProvider(cosmosClient, databaseId, containerId, stateInitializer ?? s_defaultStateInitializer); return options; } } diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/CHANGELOG.md b/dotnet/src/Microsoft.Agents.AI.DurableTask/CHANGELOG.md index d8260fcb84..ff886e2ebe 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/CHANGELOG.md +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/CHANGELOG.md @@ -12,6 +12,7 @@ - Renamed serializedSession parameter to serializedState on DeserializeSessionAsync for consistency ([#3681](https://github.com/microsoft/agent-framework/pull/3681)) - Introduce Core method pattern for Session management methods on AIAgent ([#3699](https://github.com/microsoft/agent-framework/pull/3699)) - Changed AIAgent.SerializeSession to AIAgent.SerializeSessionAsync ([#3879](https://github.com/microsoft/agent-framework/pull/3879)) +- Changed ChatHistory and AIContext Providers to have pipeline semantics ([#3806](https://github.com/microsoft/agent-framework/pull/3806)) ## v1.0.0-preview.251204.1 diff --git a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentSession.cs b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentSession.cs index b9d9807728..ba33c15d32 100644 --- a/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.DurableTask/DurableAgentSession.cs @@ -7,17 +7,22 @@ namespace Microsoft.Agents.AI.DurableTask; /// -/// An agent thread implementation for durable agents. +/// An implementation for durable agents. /// -[DebuggerDisplay("{SessionId}")] +[DebuggerDisplay("{DebuggerDisplay,nq}")] public sealed class DurableAgentSession : AgentSession { - [JsonConstructor] internal DurableAgentSession(AgentSessionId sessionId) { this.SessionId = sessionId; } + [JsonConstructor] + internal DurableAgentSession(AgentSessionId sessionId, AgentSessionStateBag stateBag) : base(stateBag) + { + this.SessionId = sessionId; + } + /// /// Gets the agent session ID. /// @@ -28,9 +33,8 @@ internal DurableAgentSession(AgentSessionId sessionId) /// internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { - return JsonSerializer.SerializeToElement( - this, - DurableAgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(DurableAgentSession))); + var jso = jsonSerializerOptions ?? DurableAgentJsonUtilities.DefaultOptions; + return JsonSerializer.SerializeToElement(this, jso.GetTypeInfo(typeof(DurableAgentSession))); } /// @@ -49,7 +53,11 @@ internal static DurableAgentSession Deserialize(JsonElement serializedSession, J string sessionIdString = sessionIdElement.GetString() ?? throw new JsonException("sessionId property is null."); AgentSessionId sessionId = AgentSessionId.Parse(sessionIdString); - return new DurableAgentSession(sessionId); + AgentSessionStateBag stateBag = serializedSession.TryGetProperty("stateBag", out JsonElement stateBagElement) + ? AgentSessionStateBag.Deserialize(stateBagElement) + : new AgentSessionStateBag(); + + return new DurableAgentSession(sessionId, stateBag); } /// @@ -68,4 +76,8 @@ public override string ToString() { return this.SessionId.ToString(); } + + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private string DebuggerDisplay => + $"SessionId = {this.SessionId}, StateBag Count = {this.StateBag.Count}"; } diff --git a/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgent.cs b/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgent.cs index 8319832613..f0533c8461 100644 --- a/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgent.cs @@ -115,7 +115,7 @@ protected override ValueTask DeserializeSessionCoreAsync( JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) - => new(new GitHubCopilotAgentSession(serializedState, jsonSerializerOptions)); + => new(GitHubCopilotAgentSession.Deserialize(serializedState, jsonSerializerOptions)); /// protected override Task RunCoreAsync( diff --git a/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgentSession.cs b/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgentSession.cs index f514eeb71b..70fe43425e 100644 --- a/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgentSession.cs @@ -1,17 +1,22 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.Diagnostics; using System.Text.Json; +using System.Text.Json.Serialization; namespace Microsoft.Agents.AI.GitHub.Copilot; /// /// Represents a session for a GitHub Copilot agent conversation. /// +[DebuggerDisplay("{DebuggerDisplay,nq}")] public sealed class GitHubCopilotAgentSession : AgentSession { /// /// Gets or sets the session ID for the GitHub Copilot conversation. /// + [JsonPropertyName("sessionId")] public string? SessionId { get; internal set; } /// @@ -21,35 +26,32 @@ internal GitHubCopilotAgentSession() { } - /// - /// Initializes a new instance of the class from serialized data. - /// - /// The serialized thread data. - /// Optional JSON serialization options. - internal GitHubCopilotAgentSession(JsonElement serializedThread, JsonSerializerOptions? jsonSerializerOptions = null) + [JsonConstructor] + internal GitHubCopilotAgentSession(string? sessionId, AgentSessionStateBag? stateBag) : base(stateBag ?? new()) { - // The JSON serialization uses camelCase - if (serializedThread.TryGetProperty("sessionId", out JsonElement sessionIdElement)) - { - this.SessionId = sessionIdElement.GetString(); - } + this.SessionId = sessionId; } /// internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { - State state = new() - { - SessionId = this.SessionId - }; - - return JsonSerializer.SerializeToElement( - state, - GitHubCopilotJsonUtilities.DefaultOptions.GetTypeInfo(typeof(State))); + var jso = jsonSerializerOptions ?? GitHubCopilotJsonUtilities.DefaultOptions; + return JsonSerializer.SerializeToElement(this, jso.GetTypeInfo(typeof(GitHubCopilotAgentSession))); } - internal sealed class State + internal static GitHubCopilotAgentSession Deserialize(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null) { - public string? SessionId { get; set; } + if (serializedState.ValueKind != JsonValueKind.Object) + { + throw new ArgumentException("The serialized session state must be a JSON object.", nameof(serializedState)); + } + + var jso = jsonSerializerOptions ?? GitHubCopilotJsonUtilities.DefaultOptions; + return serializedState.Deserialize(jso.GetTypeInfo(typeof(GitHubCopilotAgentSession))) as GitHubCopilotAgentSession + ?? new GitHubCopilotAgentSession(); } + + [DebuggerBrowsable(DebuggerBrowsableState.Never)] + private string DebuggerDisplay => + $"SessionId = {this.SessionId}, StateBag Count = {this.StateBag.Count}"; } diff --git a/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotJsonUtilities.cs index c5254efd6b..9e97c0585b 100644 --- a/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotJsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotJsonUtilities.cs @@ -42,7 +42,7 @@ private static JsonSerializerOptions CreateDefaultOptions() UseStringEnumConverter = true, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, NumberHandling = JsonNumberHandling.AllowReadingFromString)] - [JsonSerializable(typeof(GitHubCopilotAgentSession.State))] + [JsonSerializable(typeof(GitHubCopilotAgentSession))] [ExcludeFromCodeCoverage] private sealed partial class JsonContext : JsonSerializerContext; } diff --git a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0JsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0JsonUtilities.cs index d139cb0f76..33f92f3ac2 100644 --- a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0JsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0JsonUtilities.cs @@ -65,7 +65,7 @@ private static JsonSerializerOptions CreateDefaultOptions() NumberHandling = JsonNumberHandling.AllowReadingFromString)] // Agent abstraction types - [JsonSerializable(typeof(Mem0Provider.Mem0State))] + [JsonSerializable(typeof(Mem0Provider.State))] [ExcludeFromCodeCoverage] internal sealed partial class JsonContext : JsonSerializerContext; diff --git a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs index a230eb0e4e..aaf2333553 100644 --- a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs @@ -4,7 +4,6 @@ using System.Collections.Generic; using System.Linq; using System.Net.Http; -using System.Text.Json; using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; @@ -27,23 +26,27 @@ public sealed class Mem0Provider : AIContextProvider { private const string DefaultContextPrompt = "## Memories\nConsider the following memories when answering user questions:"; + private static IEnumerable DefaultExternalOnlyFilter(IEnumerable messages) + => messages.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External); + private readonly string _contextPrompt; private readonly bool _enableSensitiveTelemetryData; + private readonly string _stateKey; + private readonly Func _stateInitializer; + private readonly Func, IEnumerable> _searchInputMessageFilter; + private readonly Func, IEnumerable> _storageInputMessageFilter; private readonly Mem0Client _client; private readonly ILogger? _logger; - private readonly Mem0ProviderScope _storageScope; - private readonly Mem0ProviderScope _searchScope; - /// /// Initializes a new instance of the class. /// /// Configured (base address + auth). - /// Optional values to scope the memory storage with. - /// Optional values to scope the memory search with. Defaults to if not provided. + /// A delegate that initializes the provider state on the first invocation, providing the storage and search scopes. /// Provider options. /// Optional logger factory. + /// Thrown when or is . /// /// The base address of the required mem0 service, and any authentication headers, should be set on the /// already, when passed as a parameter here. E.g.: @@ -51,83 +54,60 @@ public sealed class Mem0Provider : AIContextProvider /// using var httpClient = new HttpClient(); /// httpClient.BaseAddress = new Uri("https://api.mem0.ai"); /// httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Token", "<Your APIKey>"); - /// new Mem0AIContextProvider(httpClient); + /// new Mem0Provider(httpClient); /// /// - public Mem0Provider(HttpClient httpClient, Mem0ProviderScope storageScope, Mem0ProviderScope? searchScope = null, Mem0ProviderOptions? options = null, ILoggerFactory? loggerFactory = null) + public Mem0Provider(HttpClient httpClient, Func stateInitializer, Mem0ProviderOptions? options = null, ILoggerFactory? loggerFactory = null) { + Throw.IfNull(httpClient); if (string.IsNullOrWhiteSpace(httpClient.BaseAddress?.AbsoluteUri)) { throw new ArgumentException("The HttpClient BaseAddress must be set for Mem0 operations.", nameof(httpClient)); } + this._stateInitializer = Throw.IfNull(stateInitializer); this._logger = loggerFactory?.CreateLogger(); this._client = new Mem0Client(httpClient); this._contextPrompt = options?.ContextPrompt ?? DefaultContextPrompt; this._enableSensitiveTelemetryData = options?.EnableSensitiveTelemetryData ?? false; - this._storageScope = new Mem0ProviderScope(Throw.IfNull(storageScope)); - this._searchScope = searchScope ?? storageScope; - - if (string.IsNullOrWhiteSpace(this._storageScope.ApplicationId) - && string.IsNullOrWhiteSpace(this._storageScope.AgentId) - && string.IsNullOrWhiteSpace(this._storageScope.ThreadId) - && string.IsNullOrWhiteSpace(this._storageScope.UserId)) - { - throw new ArgumentException("At least one of ApplicationId, AgentId, ThreadId, or UserId must be provided for the storage scope."); - } - - if (string.IsNullOrWhiteSpace(this._searchScope.ApplicationId) - && string.IsNullOrWhiteSpace(this._searchScope.AgentId) - && string.IsNullOrWhiteSpace(this._searchScope.ThreadId) - && string.IsNullOrWhiteSpace(this._searchScope.UserId)) - { - throw new ArgumentException("At least one of ApplicationId, AgentId, ThreadId, or UserId must be provided for the search scope."); - } + this._stateKey = options?.StateKey ?? base.StateKey; + this._searchInputMessageFilter = options?.SearchInputMessageFilter ?? DefaultExternalOnlyFilter; + this._storageInputMessageFilter = options?.StorageInputMessageFilter ?? DefaultExternalOnlyFilter; } + /// + public override string StateKey => this._stateKey; + /// - /// Initializes a new instance of the class, with existing state from a serialized JSON element. + /// Gets the state from the session's StateBag, or initializes it using the StateInitializer if not present. /// - /// Configured (base address + auth). - /// A representing the serialized state of the store. - /// Optional settings for customizing the JSON deserialization process. - /// Provider options. - /// Optional logger factory. - /// - /// - /// The base address of the required mem0 service, and any authentication headers, should be set on the - /// already, when passed as a parameter here. E.g.: - /// - /// using var httpClient = new HttpClient(); - /// httpClient.BaseAddress = new Uri("https://api.mem0.ai"); - /// httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Token", "<Your APIKey>"); - /// new Mem0AIContextProvider(httpClient, state); - /// - /// - public Mem0Provider(HttpClient httpClient, JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, Mem0ProviderOptions? options = null, ILoggerFactory? loggerFactory = null) + /// The agent session containing the StateBag. + /// The provider state, or null if no session is available. + private State? GetOrInitializeState(AgentSession? session) { - if (string.IsNullOrWhiteSpace(httpClient.BaseAddress?.AbsoluteUri)) + if (session?.StateBag.TryGetValue(this._stateKey, out var state, Mem0JsonUtilities.DefaultOptions) is true && state is not null) { - throw new ArgumentException("The HttpClient BaseAddress must be set for Mem0 operations.", nameof(httpClient)); + return state; } - this._logger = loggerFactory?.CreateLogger(); - this._client = new Mem0Client(httpClient); + state = this._stateInitializer(session); - this._contextPrompt = options?.ContextPrompt ?? DefaultContextPrompt; - this._enableSensitiveTelemetryData = options?.EnableSensitiveTelemetryData ?? false; - - var jso = jsonSerializerOptions ?? Mem0JsonUtilities.DefaultOptions; - var state = serializedState.Deserialize(jso.GetTypeInfo(typeof(Mem0State))) as Mem0State; + if (state is null + || state.StorageScope is null + || (state.StorageScope.AgentId is null && state.StorageScope.ThreadId is null && state.StorageScope.UserId is null && state.StorageScope.ApplicationId is null) + || state.SearchScope is null + || (state.SearchScope.AgentId is null && state.SearchScope.ThreadId is null && state.SearchScope.UserId is null && state.SearchScope.ApplicationId is null)) + { + throw new InvalidOperationException("State initializer must return a non-null state with valid storage and search scopes, where at lest one scoping parameter is set for each."); + } - if (state == null || state.StorageScope == null || state.SearchScope == null) + if (session is not null) { - throw new InvalidOperationException("The Mem0Provider state did not contain the required scope properties."); + session.StateBag.SetValue(this._stateKey, state, Mem0JsonUtilities.DefaultOptions); } - this._storageScope = state.StorageScope; - this._searchScope = state.SearchScope; + return state; } /// @@ -135,36 +115,42 @@ protected override async ValueTask InvokingCoreAsync(InvokingContext { Throw.IfNull(context); + var inputContext = context.AIContext; + var state = this.GetOrInitializeState(context.Session); + var searchScope = state?.SearchScope ?? new Mem0ProviderScope(); + string queryText = string.Join( Environment.NewLine, - context.RequestMessages - .Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External) + this._searchInputMessageFilter(inputContext.Messages ?? []) .Where(m => !string.IsNullOrWhiteSpace(m.Text)) .Select(m => m.Text)); try { var memories = (await this._client.SearchAsync( - this._searchScope.ApplicationId, - this._searchScope.AgentId, - this._searchScope.ThreadId, - this._searchScope.UserId, + searchScope.ApplicationId, + searchScope.AgentId, + searchScope.ThreadId, + searchScope.UserId, queryText, cancellationToken).ConfigureAwait(false)).ToList(); var outputMessageText = memories.Count == 0 ? null : $"{this._contextPrompt}\n{string.Join(Environment.NewLine, memories)}"; + var outputMessage = memories.Count == 0 + ? null + : new ChatMessage(ChatRole.User, outputMessageText!).WithAgentRequestMessageSource(AgentRequestMessageSourceType.AIContextProvider, this.GetType().FullName!); if (this._logger?.IsEnabled(LogLevel.Information) is true) { this._logger.LogInformation( "Mem0AIContextProvider: Retrieved {Count} memories. ApplicationId: '{ApplicationId}', AgentId: '{AgentId}', ThreadId: '{ThreadId}', UserId: '{UserId}'.", memories.Count, - this._searchScope.ApplicationId, - this._searchScope.AgentId, - this._searchScope.ThreadId, - this.SanitizeLogData(this._searchScope.UserId)); + searchScope.ApplicationId, + searchScope.AgentId, + searchScope.ThreadId, + this.SanitizeLogData(searchScope.UserId)); if (outputMessageText is not null && this._logger.IsEnabled(LogLevel.Trace)) { @@ -172,16 +158,20 @@ protected override async ValueTask InvokingCoreAsync(InvokingContext "Mem0AIContextProvider: Search Results\nInput:{Input}\nOutput:{MessageText}\nApplicationId: '{ApplicationId}', AgentId: '{AgentId}', ThreadId: '{ThreadId}', UserId: '{UserId}'.", this.SanitizeLogData(queryText), this.SanitizeLogData(outputMessageText), - this._searchScope.ApplicationId, - this._searchScope.AgentId, - this._searchScope.ThreadId, - this.SanitizeLogData(this._searchScope.UserId)); + searchScope.ApplicationId, + searchScope.AgentId, + searchScope.ThreadId, + this.SanitizeLogData(searchScope.UserId)); } } return new AIContext { - Messages = [new ChatMessage(ChatRole.User, outputMessageText)] + Instructions = inputContext.Instructions, + Messages = + (inputContext.Messages ?? []) + .Concat(outputMessage is not null ? [outputMessage] : []), + Tools = inputContext.Tools }; } catch (ArgumentException) @@ -195,12 +185,12 @@ protected override async ValueTask InvokingCoreAsync(InvokingContext this._logger.LogError( ex, "Mem0AIContextProvider: Failed to search Mem0 for memories due to error. ApplicationId: '{ApplicationId}', AgentId: '{AgentId}', ThreadId: '{ThreadId}', UserId: '{UserId}'.", - this._searchScope.ApplicationId, - this._searchScope.AgentId, - this._searchScope.ThreadId, - this.SanitizeLogData(this._searchScope.UserId)); + searchScope.ApplicationId, + searchScope.AgentId, + searchScope.ThreadId, + this.SanitizeLogData(searchScope.UserId)); } - return new AIContext(); + return inputContext; } } @@ -212,12 +202,15 @@ protected override async ValueTask InvokedCoreAsync(InvokedContext context, Canc return; // Do not update memory on failed invocations. } + var state = this.GetOrInitializeState(context.Session); + var storageScope = state?.StorageScope ?? new Mem0ProviderScope(); + try { // Persist request and response messages after invocation. await this.PersistMessagesAsync( - context.RequestMessages - .Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External) + storageScope, + this._storageInputMessageFilter(context.RequestMessages) .Concat(context.ResponseMessages ?? []), cancellationToken).ConfigureAwait(false); } @@ -228,36 +221,39 @@ await this.PersistMessagesAsync( this._logger.LogError( ex, "Mem0AIContextProvider: Failed to send messages to Mem0 due to error. ApplicationId: '{ApplicationId}', AgentId: '{AgentId}', ThreadId: '{ThreadId}', UserId: '{UserId}'.", - this._storageScope.ApplicationId, - this._storageScope.AgentId, - this._storageScope.ThreadId, - this.SanitizeLogData(this._storageScope.UserId)); + storageScope.ApplicationId, + storageScope.AgentId, + storageScope.ThreadId, + this.SanitizeLogData(storageScope.UserId)); } } } /// - /// Clears stored memories for the configured scopes. + /// Clears stored memories for the specified scope. /// + /// The session containing the scope state to clear memories for. /// Cancellation token. - public Task ClearStoredMemoriesAsync(CancellationToken cancellationToken = default) => - this._client.ClearMemoryAsync( - this._storageScope.ApplicationId, - this._storageScope.AgentId, - this._storageScope.ThreadId, - this._storageScope.UserId, - cancellationToken); - - /// - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) + public Task ClearStoredMemoriesAsync(AgentSession session, CancellationToken cancellationToken = default) { - var state = new Mem0State(this._storageScope, this._searchScope); + Throw.IfNull(session); + var state = this.GetOrInitializeState(session); + var storageScope = state?.StorageScope; - var jso = jsonSerializerOptions ?? Mem0JsonUtilities.DefaultOptions; - return JsonSerializer.SerializeToElement(state, jso.GetTypeInfo(typeof(Mem0State))); + if (storageScope is null) + { + return Task.CompletedTask; // Nothing to clear if there is no state. + } + + return this._client.ClearMemoryAsync( + storageScope.ApplicationId, + storageScope.AgentId, + storageScope.ThreadId, + storageScope.UserId, + cancellationToken); } - private async Task PersistMessagesAsync(IEnumerable messages, CancellationToken cancellationToken) + private async Task PersistMessagesAsync(Mem0ProviderScope storageScope, IEnumerable messages, CancellationToken cancellationToken) { foreach (var message in messages) { @@ -277,27 +273,42 @@ private async Task PersistMessagesAsync(IEnumerable messages, Cance } await this._client.CreateMemoryAsync( - this._storageScope.ApplicationId, - this._storageScope.AgentId, - this._storageScope.ThreadId, - this._storageScope.UserId, + storageScope.ApplicationId, + storageScope.AgentId, + storageScope.ThreadId, + storageScope.UserId, message.Text, message.Role.Value, cancellationToken).ConfigureAwait(false); } } - internal sealed class Mem0State + /// + /// Represents the state of a stored in the . + /// + public sealed class State { + /// + /// Initializes a new instance of the class with the specified storage and search scopes. + /// + /// The scope to use when storing memories. + /// The scope to use when searching for memories. If null, the storage scope will be used for searching as well. [JsonConstructor] - public Mem0State(Mem0ProviderScope storageScope, Mem0ProviderScope searchScope) + public State(Mem0ProviderScope storageScope, Mem0ProviderScope? searchScope = null) { - this.StorageScope = storageScope; - this.SearchScope = searchScope; + this.StorageScope = Throw.IfNull(storageScope); + this.SearchScope = searchScope ?? storageScope; } - public Mem0ProviderScope StorageScope { get; set; } - public Mem0ProviderScope SearchScope { get; set; } + /// + /// Gets the scope used when storing memories. + /// + public Mem0ProviderScope StorageScope { get; } + + /// + /// Gets the scope used when searching memories. + /// + public Mem0ProviderScope SearchScope { get; } } private string? SanitizeLogData(string? data) => this._enableSensitiveTelemetryData ? data : ""; diff --git a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0ProviderOptions.cs b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0ProviderOptions.cs index f2d3d89e16..f7d14028d9 100644 --- a/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0ProviderOptions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Mem0/Mem0ProviderOptions.cs @@ -1,5 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.Collections.Generic; +using Microsoft.Extensions.AI; + namespace Microsoft.Agents.AI.Mem0; /// @@ -18,4 +22,30 @@ public sealed class Mem0ProviderOptions /// /// Defaults to . public bool EnableSensitiveTelemetryData { get; set; } + + /// + /// Gets or sets the key used to store the provider state in the session's . + /// + /// Defaults to the provider's type name. + public string? StateKey { get; set; } + + /// + /// Gets or sets an optional filter function applied to request messages when building the search text to use when + /// searching for relevant memories during . + /// + /// + /// When , the provider defaults to including only + /// messages. + /// + public Func, IEnumerable>? SearchInputMessageFilter { get; set; } + + /// + /// Gets or sets an optional filter function applied to request messages when determining which messages to + /// extract memories from during . + /// + /// + /// When , the provider defaults to including only + /// messages. + /// + public Func, IEnumerable>? StorageInputMessageFilter { get; set; } } diff --git a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIAssistantClientExtensions.cs b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIAssistantClientExtensions.cs index d167d1f0b4..c56a63c76e 100644 --- a/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIAssistantClientExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.OpenAI/Extensions/OpenAIAssistantClientExtensions.cs @@ -204,8 +204,8 @@ public static ChatClientAgent AsAIAgent( Name = options.Name ?? assistantMetadata.Name, Description = options.Description ?? assistantMetadata.Description, ChatOptions = options.ChatOptions, - AIContextProviderFactory = options.AIContextProviderFactory, - ChatHistoryProviderFactory = options.ChatHistoryProviderFactory, + AIContextProviders = options.AIContextProviders, + ChatHistoryProvider = options.ChatHistoryProvider, UseProvidedChatClientAsIs = options.UseProvidedChatClientAsIs }; diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs index f631de8e8a..6672b9e2a3 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowChatHistoryProvider.cs @@ -6,48 +6,56 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; -using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.AI.Workflows; internal sealed class WorkflowChatHistoryProvider : ChatHistoryProvider { - private int _bookmark; - private readonly List _chatMessages = []; - - public WorkflowChatHistoryProvider() + private readonly JsonSerializerOptions _jsonSerializerOptions; + + /// + /// Initializes a new instance of the class. + /// + /// + /// Optional JSON serializer options for serializing the state of this provider. + /// This is valuable for cases like when the chat history contains custom types + /// and source generated serializers are required, or Native AOT / Trimming is required. + /// + public WorkflowChatHistoryProvider(JsonSerializerOptions? jsonSerializerOptions = null) { + this._jsonSerializerOptions = jsonSerializerOptions ?? AgentAbstractionsJsonUtilities.DefaultOptions; } - public WorkflowChatHistoryProvider(StoreState state) + internal sealed class StoreState { - this.ImportStoreState(Throw.IfNull(state)); + public int Bookmark { get; set; } + public List Messages { get; set; } = []; } - private void ImportStoreState(StoreState state, bool clearMessages = false) + private StoreState GetOrInitializeState(AgentSession? session) { - if (clearMessages) + if (session?.StateBag.TryGetValue(this.StateKey, out var state, this._jsonSerializerOptions) is true && state is not null) { - this._chatMessages.Clear(); + return state; } - if (state?.Messages is not null) + state = new(); + if (session is not null) { - this._chatMessages.AddRange(state.Messages); + session.StateBag.SetValue(this.StateKey, state, this._jsonSerializerOptions); } - this._bookmark = state?.Bookmark ?? 0; - } - internal sealed class StoreState - { - public int Bookmark { get; set; } - public IList Messages { get; set; } = []; + return state; } - internal void AddMessages(params IEnumerable messages) => this._chatMessages.AddRange(messages); + internal void AddMessages(AgentSession session, params IEnumerable messages) + => this.GetOrInitializeState(session).Messages.AddRange(messages); protected override ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) - => new(this._chatMessages.AsReadOnly()); + => new(this.GetOrInitializeState(context.Session) + .Messages + .Select(message => message.WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, this.GetType().FullName!)) + .Concat(context.RequestMessages)); protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { @@ -56,29 +64,27 @@ protected override ValueTask InvokedCoreAsync(InvokedContext context, Cancellati return default; } - var allNewMessages = context.RequestMessages.Concat(context.ResponseMessages ?? []); - this._chatMessages.AddRange(allNewMessages); + var allNewMessages = context.RequestMessages + .Where(m => m.GetAgentRequestMessageSourceType() != AgentRequestMessageSourceType.ChatHistory) + .Concat(context.ResponseMessages ?? []); + this.GetOrInitializeState(context.Session).Messages.AddRange(allNewMessages); return default; } - public IEnumerable GetFromBookmark() + public IEnumerable GetFromBookmark(AgentSession session) { - for (int i = this._bookmark; i < this._chatMessages.Count; i++) + var state = this.GetOrInitializeState(session); + + for (int i = state.Bookmark; i < state.Messages.Count; i++) { - yield return this._chatMessages[i]; + yield return state.Messages[i]; } } - public void UpdateBookmark() => this._bookmark = this._chatMessages.Count; - - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) + public void UpdateBookmark(AgentSession session) { - StoreState state = this.ExportStoreState(); - - return JsonSerializer.SerializeToElement(state, - WorkflowsJsonUtilities.DefaultOptions.GetTypeInfo(typeof(StoreState))); + var state = this.GetOrInitializeState(session); + state.Bookmark = state.Messages.Count; } - - internal StoreState ExportStoreState() => new() { Bookmark = this._bookmark, Messages = this._chatMessages }; } diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs index c08ba5c3f4..7438ce1a34 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/WorkflowHostAgent.cs @@ -94,7 +94,7 @@ private async ValueTask UpdateSessionAsync(IEnumerable InvokeStageAsync( try { this.LastResponseId = Guid.NewGuid().ToString("N"); - List messages = this.ChatHistoryProvider.GetFromBookmark().ToList(); + List messages = this.ChatHistoryProvider.GetFromBookmark(this).ToList(); #pragma warning disable CA2007 // Analyzer misfiring and not seeing .ConfigureAwait(false) below. await using Checkpointed checkpointed = @@ -240,7 +241,7 @@ IAsyncEnumerable InvokeStageAsync( finally { // Do we want to try to undo the step, and not update the bookmark? - this.ChatHistoryProvider.UpdateBookmark(); + this.ChatHistoryProvider.UpdateBookmark(this); } } @@ -254,12 +255,12 @@ IAsyncEnumerable InvokeStageAsync( internal sealed class SessionState( string runId, CheckpointInfo? lastCheckpoint, - WorkflowChatHistoryProvider.StoreState chatHistoryProviderState, - InMemoryCheckpointManager? checkpointManager = null) + InMemoryCheckpointManager? checkpointManager = null, + AgentSessionStateBag? stateBag = null) { public string RunId { get; } = runId; public CheckpointInfo? LastCheckpoint { get; } = lastCheckpoint; - public WorkflowChatHistoryProvider.StoreState ChatHistoryProviderState { get; } = chatHistoryProviderState; public InMemoryCheckpointManager? CheckpointManager { get; } = checkpointManager; + public AgentSessionStateBag StateBag { get; } = stateBag ?? new(); } } diff --git a/dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs index 5c915b6b01..96ec6dbecb 100644 --- a/dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs @@ -65,9 +65,9 @@ private static JsonSerializerOptions CreateDefaultOptions() NumberHandling = JsonNumberHandling.AllowReadingFromString)] // Agent abstraction types - [JsonSerializable(typeof(ChatClientAgentSession.SessionState))] + [JsonSerializable(typeof(ChatClientAgentSession))] [JsonSerializable(typeof(TextSearchProvider.TextSearchProviderState))] - [JsonSerializable(typeof(ChatHistoryMemoryProvider.ChatHistoryMemoryProviderState))] + [JsonSerializable(typeof(ChatHistoryMemoryProvider.State))] [ExcludeFromCodeCoverage] internal sealed partial class JsonContext : JsonSerializerContext; diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index dc462cf501..6dc6d175a7 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -20,6 +20,7 @@ namespace Microsoft.Agents.AI; public sealed partial class ChatClientAgent : AIAgent { private readonly ChatClientAgentOptions? _agentOptions; + private readonly HashSet _aiContextProviderStateKeys; private readonly AIAgentMetadata _agentMetadata; private readonly ILogger _logger; private readonly Type _chatClientType; @@ -105,6 +106,15 @@ public ChatClientAgent(IChatClient chatClient, ChatClientAgentOptions? options, // If the user has not opted out of using our default decorators, we wrap the chat client. this.ChatClient = options?.UseProvidedChatClientAsIs is true ? chatClient : chatClient.WithDefaultAgentMiddleware(options, services); + // Use the ChatHistoryProvider from options if provided. + // If one was not provided, and we later find out that the underlying service does not manage chat history server-side, + // we will use the default InMemoryChatHistoryProvider at that time. + this.ChatHistoryProvider = options?.ChatHistoryProvider; + this.AIContextProviders = this._agentOptions?.AIContextProviders as IReadOnlyList ?? this._agentOptions?.AIContextProviders?.ToList(); + + // Validate that no two providers share the same StateKey, since they would overwrite each other's state in the session. + this._aiContextProviderStateKeys = ValidateAndCollectStateKeys(this._agentOptions?.AIContextProviders, this.ChatHistoryProvider); + this._logger = (loggerFactory ?? chatClient.GetService() ?? NullLoggerFactory.Instance).CreateLogger(); } @@ -120,6 +130,22 @@ public ChatClientAgent(IChatClient chatClient, ChatClientAgentOptions? options, /// public IChatClient ChatClient { get; } + /// + /// Gets the used by this agent, to support cases where the chat history is not stored by the agent service. + /// + /// + /// This property may be null in case the agent stores messages in the underlying agent service. + /// + public ChatHistoryProvider? ChatHistoryProvider { get; private set; } + + /// + /// Gets the list of instances used by this agent, to support cases where additional context is needed for each agent run. + /// + /// + /// This property may be null in case no additional context providers were configured. + /// + public IReadOnlyList? AIContextProviders { get; } + /// protected override string? IdCore => this._agentOptions?.Id; @@ -206,7 +232,6 @@ protected override async IAsyncEnumerable RunCoreStreamingA (ChatClientAgentSession safeSession, ChatOptions? chatOptions, - List inputMessagesForProviders, List inputMessagesForChatClient, ChatClientAgentContinuationToken? continuationToken) = await this.PrepareSessionAndMessagesAsync(session, inputMessages, options, cancellationToken).ConfigureAwait(false); @@ -230,8 +255,8 @@ protected override async IAsyncEnumerable RunCoreStreamingA } catch (Exception ex) { - await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForProviders, continuationToken), chatOptions, cancellationToken).ConfigureAwait(false); - await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForProviders, continuationToken), cancellationToken).ConfigureAwait(false); + await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForChatClient, continuationToken), chatOptions, cancellationToken).ConfigureAwait(false); + await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForChatClient, continuationToken), cancellationToken).ConfigureAwait(false); throw; } @@ -245,8 +270,8 @@ protected override async IAsyncEnumerable RunCoreStreamingA } catch (Exception ex) { - await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForProviders, continuationToken), chatOptions, cancellationToken).ConfigureAwait(false); - await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForProviders, continuationToken), cancellationToken).ConfigureAwait(false); + await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForChatClient, continuationToken), chatOptions, cancellationToken).ConfigureAwait(false); + await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForChatClient, continuationToken), cancellationToken).ConfigureAwait(false); throw; } @@ -272,8 +297,8 @@ protected override async IAsyncEnumerable RunCoreStreamingA } catch (Exception ex) { - await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForProviders, continuationToken), chatOptions, cancellationToken).ConfigureAwait(false); - await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForProviders, continuationToken), cancellationToken).ConfigureAwait(false); + await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForChatClient, continuationToken), chatOptions, cancellationToken).ConfigureAwait(false); + await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessagesForChatClient, continuationToken), cancellationToken).ConfigureAwait(false); throw; } } @@ -282,13 +307,13 @@ protected override async IAsyncEnumerable RunCoreStreamingA // We can derive the type of supported session from whether we have a conversation id, // so let's update it and set the conversation id for the service session case. - await this.UpdateSessionWithTypeAndConversationIdAsync(safeSession, chatResponse.ConversationId, cancellationToken).ConfigureAwait(false); + this.UpdateSessionConversationId(safeSession, chatResponse.ConversationId, cancellationToken); // To avoid inconsistent state we only notify the session of the input messages if no error occurs after the initial request. - await this.NotifyChatHistoryProviderOfNewMessagesAsync(safeSession, GetInputMessages(inputMessagesForProviders, continuationToken), chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false); + await this.NotifyChatHistoryProviderOfNewMessagesAsync(safeSession, GetInputMessages(inputMessagesForChatClient, continuationToken), chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false); // Notify the AIContextProvider of all new messages. - await this.NotifyAIContextProviderOfSuccessAsync(safeSession, GetInputMessages(inputMessagesForProviders, continuationToken), chatResponse.Messages, cancellationToken).ConfigureAwait(false); + await this.NotifyAIContextProviderOfSuccessAsync(safeSession, GetInputMessages(inputMessagesForChatClient, continuationToken), chatResponse.Messages, cancellationToken).ConfigureAwait(false); } /// @@ -298,24 +323,14 @@ protected override async IAsyncEnumerable RunCoreStreamingA : serviceType == typeof(IChatClient) ? this.ChatClient : serviceType == typeof(ChatOptions) ? this._agentOptions?.ChatOptions : serviceType == typeof(ChatClientAgentOptions) ? this._agentOptions - : this.ChatClient.GetService(serviceType, serviceKey)); + : this.AIContextProviders?.Select(provider => provider.GetService(serviceType, serviceKey)).FirstOrDefault(s => s is not null) + ?? this.ChatHistoryProvider?.GetService(serviceType, serviceKey) + ?? this.ChatClient.GetService(serviceType, serviceKey)); /// - protected override async ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) + protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) { - ChatHistoryProvider? chatHistoryProvider = this._agentOptions?.ChatHistoryProviderFactory is not null - ? await this._agentOptions.ChatHistoryProviderFactory.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }, cancellationToken).ConfigureAwait(false) - : null; - - AIContextProvider? contextProvider = this._agentOptions?.AIContextProviderFactory is not null - ? await this._agentOptions.AIContextProviderFactory.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }, cancellationToken).ConfigureAwait(false) - : null; - - return new ChatClientAgentSession - { - ChatHistoryProvider = chatHistoryProvider, - AIContextProvider = contextProvider - }; + return new(new ChatClientAgentSession()); } /// @@ -336,52 +351,12 @@ protected override async ValueTask CreateSessionCoreAsync(Cancella /// instances that support server-side conversation storage through their underlying . /// /// - public async ValueTask CreateSessionAsync(string conversationId, CancellationToken cancellationToken = default) + public ValueTask CreateSessionAsync(string conversationId, CancellationToken cancellationToken = default) { - AIContextProvider? contextProvider = this._agentOptions?.AIContextProviderFactory is not null - ? await this._agentOptions.AIContextProviderFactory.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }, cancellationToken).ConfigureAwait(false) - : null; - - return new ChatClientAgentSession() + return new(new ChatClientAgentSession() { ConversationId = conversationId, - AIContextProvider = contextProvider - }; - } - - /// - /// Creates a new agent session instance using an existing to continue a conversation. - /// - /// The instance to use for managing the conversation's message history. - /// The to monitor for cancellation requests. - /// - /// A value task representing the asynchronous operation. The task result contains a new instance configured to work with the provided . - /// - /// - /// - /// This method creates threads that do not support server-side conversation storage. - /// Some AI services require server-side conversation storage to function properly, and creating a session - /// with a may not be compatible with these services. - /// - /// - /// Where a service requires server-side conversation storage, use . - /// - /// - /// If the agent detects, during the first run, that the underlying AI service requires server-side conversation storage, - /// the session will throw an exception to indicate that it cannot continue using the provided . - /// - /// - public async ValueTask CreateSessionAsync(ChatHistoryProvider chatHistoryProvider, CancellationToken cancellationToken = default) - { - AIContextProvider? contextProvider = this._agentOptions?.AIContextProviderFactory is not null - ? await this._agentOptions.AIContextProviderFactory.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }, cancellationToken).ConfigureAwait(false) - : null; - - return new ChatClientAgentSession() - { - ChatHistoryProvider = Throw.IfNull(chatHistoryProvider), - AIContextProvider = contextProvider - }; + }); } /// @@ -398,22 +373,9 @@ protected override ValueTask SerializeSessionCoreAsync(AgentSession } /// - protected override async ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) + protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) { - Func>? chatHistoryProviderFactory = this._agentOptions?.ChatHistoryProviderFactory is null ? - null : - (jse, jso, ct) => this._agentOptions.ChatHistoryProviderFactory.Invoke(new() { SerializedState = jse, JsonSerializerOptions = jso }, ct); - - Func>? aiContextProviderFactory = this._agentOptions?.AIContextProviderFactory is null ? - null : - (jse, jso, ct) => this._agentOptions.AIContextProviderFactory.Invoke(new() { SerializedState = jse, JsonSerializerOptions = jso }, ct); - - return await ChatClientAgentSession.DeserializeAsync( - serializedState, - jsonSerializerOptions, - chatHistoryProviderFactory, - aiContextProviderFactory, - cancellationToken).ConfigureAwait(false); + return new(ChatClientAgentSession.Deserialize(serializedState, jsonSerializerOptions)); } #region Private @@ -432,7 +394,6 @@ private async Task RunCoreAsync inputMessagesForProviders, List inputMessagesForChatClient, ChatClientAgentContinuationToken? _) = await this.PrepareSessionAndMessagesAsync(session, inputMessages, options, cancellationToken).ConfigureAwait(false); @@ -453,8 +414,8 @@ private async Task RunCoreAsync RunCoreAsync RunCoreAsync responseMessages, CancellationToken cancellationToken) { - if (session.AIContextProvider is not null) + if (this.AIContextProviders is { Count: > 0 } contextProviders) { - await session.AIContextProvider.InvokedAsync(new(this, session, inputMessages) { ResponseMessages = responseMessages }, - cancellationToken).ConfigureAwait(false); + AIContextProvider.InvokedContext invokedContext = new(this, session, inputMessages, responseMessages); + + foreach (var contextProvider in contextProviders) + { + await contextProvider.InvokedAsync(invokedContext, cancellationToken).ConfigureAwait(false); + } } } @@ -508,10 +473,14 @@ private async Task NotifyAIContextProviderOfFailureAsync( IEnumerable inputMessages, CancellationToken cancellationToken) { - if (session.AIContextProvider is not null) + if (this.AIContextProviders is { Count: > 0 } contextProviders) { - await session.AIContextProvider.InvokedAsync(new(this, session, inputMessages) { InvokeException = ex }, - cancellationToken).ConfigureAwait(false); + AIContextProvider.InvokedContext invokedContext = new(this, session, inputMessages, ex); + + foreach (var contextProvider in contextProviders) + { + await contextProvider.InvokedAsync(invokedContext, cancellationToken).ConfigureAwait(false); + } } } @@ -679,7 +648,6 @@ private async Task <( ChatClientAgentSession AgentSession, ChatOptions? ChatOptions, - List inputMessagesForProviders, List InputMessagesForChatClient, ChatClientAgentContinuationToken? ContinuationToken )> PrepareSessionAndMessagesAsync( @@ -709,52 +677,53 @@ private async Task throw new InvalidOperationException("Input messages are not allowed when continuing a background response using a continuation token."); } - List inputMessagesForProviders = []; - List inputMessagesForChatClient = []; + IEnumerable inputMessagesForChatClient = inputMessages; // Populate the session messages only if we are not continuing an existing response as it's not allowed if (chatOptions?.ContinuationToken is null) { - ChatHistoryProvider? chatHistoryProvider = ResolveChatHistoryProvider(typedSession, chatOptions); + ChatHistoryProvider? chatHistoryProvider = this.ResolveChatHistoryProvider(chatOptions, typedSession); // Add any existing messages from the session to the messages to be sent to the chat client. + // The ChatHistoryProvider returns the merged result (history + input messages). if (chatHistoryProvider is not null) { - var invokingContext = new ChatHistoryProvider.InvokingContext(this, typedSession, inputMessages); - var providerMessages = await chatHistoryProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false); - inputMessagesForChatClient.AddRange(providerMessages); + var invokingContext = new ChatHistoryProvider.InvokingContext(this, typedSession, inputMessagesForChatClient); + inputMessagesForChatClient = await chatHistoryProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false); } - // Add the input messages before getting context from AIContextProvider. - inputMessagesForProviders.AddRange(inputMessages); - inputMessagesForChatClient.AddRange(inputMessages); - // If we have an AIContextProvider, we should get context from it, and update our // messages and options with the additional context. - if (typedSession.AIContextProvider is not null) + // The AIContextProvider returns the accumulated AIContext (original + new contributions). + if (this.AIContextProviders is { Count: > 0 } aiContextProviders) { - var invokingContext = new AIContextProvider.InvokingContext(this, typedSession, inputMessages); - var aiContext = await typedSession.AIContextProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false); - if (aiContext.Messages is { Count: > 0 }) + var aiContext = new AIContext + { + Instructions = chatOptions?.Instructions, + Messages = inputMessagesForChatClient, + Tools = chatOptions?.Tools + }; + + foreach (var aiContextProvider in aiContextProviders) { - inputMessagesForProviders.AddRange(aiContext.Messages); - inputMessagesForChatClient.AddRange(aiContext.Messages); + var invokingContext = new AIContextProvider.InvokingContext(this, typedSession, aiContext); + aiContext = await aiContextProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false); } - if (aiContext.Tools is { Count: > 0 }) + // Materialize the accumulated messages and tools once at the end of the provider pipeline. + inputMessagesForChatClient = aiContext.Messages ?? []; + + var tools = aiContext.Tools as IList ?? aiContext.Tools?.ToList(); + if (chatOptions?.Tools is { Count: > 0 } || tools is { Count: > 0 }) { chatOptions ??= new(); - chatOptions.Tools ??= []; - foreach (AITool tool in aiContext.Tools) - { - chatOptions.Tools.Add(tool); - } + chatOptions.Tools = tools; } - if (aiContext.Instructions is not null) + if (chatOptions?.Instructions is not null || aiContext.Instructions is not null) { chatOptions ??= new(); - chatOptions.Instructions = string.IsNullOrWhiteSpace(chatOptions.Instructions) ? aiContext.Instructions : $"{chatOptions.Instructions}\n{aiContext.Instructions}"; + chatOptions.Instructions = aiContext.Instructions; } } } @@ -777,10 +746,13 @@ private async Task chatOptions.ConversationId = typedSession.ConversationId; } - return (typedSession, chatOptions, inputMessagesForProviders, inputMessagesForChatClient, continuationToken); + // Materialize the accumulated messages once at the end of the provider pipeline, reusing the existing list if possible. + List messagesList = inputMessagesForChatClient as List ?? inputMessagesForChatClient.ToList(); + + return (typedSession, chatOptions, messagesList, continuationToken); } - private async Task UpdateSessionWithTypeAndConversationIdAsync(ChatClientAgentSession session, string? responseConversationId, CancellationToken cancellationToken) + private void UpdateSessionConversationId(ChatClientAgentSession session, string? responseConversationId, CancellationToken cancellationToken) { if (string.IsNullOrWhiteSpace(responseConversationId) && !string.IsNullOrWhiteSpace(session.ConversationId)) { @@ -791,6 +763,14 @@ private async Task UpdateSessionWithTypeAndConversationIdAsync(ChatClientAgentSe if (!string.IsNullOrWhiteSpace(responseConversationId)) { + if (this.ChatHistoryProvider is not null) + { + // The agent has a ChatHistoryProvider configured, but the service returned a conversation id, + // meaning the service manages chat history server-side. Both cannot be used simultaneously. + throw new InvalidOperationException( + $"Only {nameof(ChatClientAgentSession.ConversationId)} or {nameof(this.ChatHistoryProvider)} may be used, but not both. The service returned a conversation id indicating server-side chat history management, but the agent has a {nameof(this.ChatHistoryProvider)} configured."); + } + // If we got a conversation id back from the chat client, it means that the service supports server side session storage // so we should update the session with the new id. session.ConversationId = responseConversationId; @@ -798,11 +778,9 @@ private async Task UpdateSessionWithTypeAndConversationIdAsync(ChatClientAgentSe else { // If the service doesn't use service side chat history storage (i.e. we got no id back from invocation), and - // the session has no ChatHistoryProvider yet, we should update the session with the custom ChatHistoryProvider or - // default InMemoryChatHistoryProvider so that it has somewhere to store the chat history. - session.ChatHistoryProvider ??= this._agentOptions?.ChatHistoryProviderFactory is not null - ? await this._agentOptions.ChatHistoryProviderFactory.Invoke(new() { SerializedState = default, JsonSerializerOptions = null }, cancellationToken).ConfigureAwait(false) - : new InMemoryChatHistoryProvider(); + // the agent has no ChatHistoryProvider yet, we should use the default InMemoryChatHistoryProvider so that + // we have somewhere to store the chat history. + this.ChatHistoryProvider ??= new InMemoryChatHistoryProvider(); } } @@ -813,16 +791,13 @@ private Task NotifyChatHistoryProviderOfFailureAsync( ChatOptions? chatOptions, CancellationToken cancellationToken) { - ChatHistoryProvider? provider = ResolveChatHistoryProvider(session, chatOptions); + ChatHistoryProvider? provider = this.ResolveChatHistoryProvider(chatOptions, session); // Only notify the provider if we have one. // If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages. if (provider is not null) { - var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, requestMessages) - { - InvokeException = ex - }; + var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, requestMessages, ex); return provider.InvokedAsync(invokedContext, cancellationToken).AsTask(); } @@ -837,29 +812,45 @@ private Task NotifyChatHistoryProviderOfNewMessagesAsync( ChatOptions? chatOptions, CancellationToken cancellationToken) { - ChatHistoryProvider? provider = ResolveChatHistoryProvider(session, chatOptions); + ChatHistoryProvider? provider = this.ResolveChatHistoryProvider(chatOptions, session); // Only notify the provider if we have one. // If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages. if (provider is not null) { - var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, requestMessages) - { - ResponseMessages = responseMessages - }; + var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, requestMessages, responseMessages); return provider.InvokedAsync(invokedContext, cancellationToken).AsTask(); } return Task.CompletedTask; } - private static ChatHistoryProvider? ResolveChatHistoryProvider(ChatClientAgentSession session, ChatOptions? chatOptions) + private ChatHistoryProvider? ResolveChatHistoryProvider(ChatOptions? chatOptions, ChatClientAgentSession session) { - ChatHistoryProvider? provider = session.ChatHistoryProvider; + ChatHistoryProvider? provider = this.ChatHistoryProvider; + + if (session.ConversationId is not null && provider is not null) + { + throw new InvalidOperationException( + $"Only {nameof(ChatClientAgentSession.ConversationId)} or {nameof(this.ChatHistoryProvider)} may be used, but not both. The current {nameof(ChatClientAgentSession)} has a {nameof(ChatClientAgentSession.ConversationId)} indicating server-side chat history management, but the agent has a {nameof(this.ChatHistoryProvider)} configured."); + } - // If someone provided an override ChatHistoryProvider via AdditionalProperties, we should use that instead of the one on the session. + // If someone provided an override ChatHistoryProvider via AdditionalProperties, we should use that instead. if (chatOptions?.AdditionalProperties?.TryGetValue(out ChatHistoryProvider? overrideProvider) is true) { + if (session.ConversationId is not null && overrideProvider is not null) + { + throw new InvalidOperationException( + $"Only {nameof(ChatClientAgentSession.ConversationId)} or {nameof(this.ChatHistoryProvider)} may be used, but not both. The current {nameof(ChatClientAgentSession)} has a {nameof(ChatClientAgentSession.ConversationId)} indicating server-side chat history management, but an override {nameof(this.ChatHistoryProvider)} was provided via {nameof(AgentRunOptions.AdditionalProperties)}."); + } + + // Validate that the override provider's StateKey does not clash with any AIContextProvider's StateKey. + if (overrideProvider is not null && this._aiContextProviderStateKeys.Contains(overrideProvider.StateKey)) + { + throw new InvalidOperationException( + $"The ChatHistoryProvider '{overrideProvider.GetType().Name}' uses the state key '{overrideProvider.StateKey}' which is already used by one of the configured AIContextProviders. Each provider must use a unique state key to avoid overwriting each other's state."); + } + provider = overrideProvider; } @@ -906,5 +897,43 @@ private static List GetResponseUpdates(ChatClientAgentContin } private string GetLoggingAgentName() => this.Name ?? "UnnamedAgent"; + + /// + /// Validates that all configured providers have unique values + /// and returns a of the AIContextProvider state keys. + /// + private static HashSet ValidateAndCollectStateKeys(IEnumerable? aiContextProviders, ChatHistoryProvider? chatHistoryProvider) + { + HashSet stateKeys = new(StringComparer.Ordinal); + + if (aiContextProviders is not null) + { + foreach (var provider in aiContextProviders) + { + if (!stateKeys.Add(provider.StateKey)) + { + throw new InvalidOperationException( + $"Multiple providers use the same state key '{provider.StateKey}'. Each provider must use a unique state key to avoid overwriting each other's state."); + } + } + } + + if (chatHistoryProvider is null + && stateKeys.Contains(nameof(InMemoryChatHistoryProvider))) + { + throw new InvalidOperationException( + $"The default {nameof(InMemoryChatHistoryProvider)} uses the state key '{nameof(InMemoryChatHistoryProvider)}', which is already used by one of the configured AIContextProviders. Each provider must use a unique state key to avoid overwriting each other's state. To resolve this, either configure a different state key for the AIContextProvider that is using '{nameof(InMemoryChatHistoryProvider)}' as its state key, or provide a custom ChatHistoryProvider with a unique state key."); + } + + if (chatHistoryProvider is not null + && stateKeys.Contains(chatHistoryProvider.StateKey)) + { + throw new InvalidOperationException( + $"The ChatHistoryProvider '{chatHistoryProvider.GetType().Name}' uses the state key '{chatHistoryProvider.StateKey}' which is already used by one of the configured AIContextProviders. Each provider must use a unique state key to avoid overwriting each other's state. To resolve this, either configure a different state key for the AIContextProvider that is using '{chatHistoryProvider.StateKey}' as its state key, or reconfigure the custom ChatHistoryProvider with a unique state key."); + } + + return stateKeys; + } + #endregion } diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs index 6f8451e2b8..ddca9197ab 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentOptions.cs @@ -1,9 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. -using System; -using System.Text.Json; -using System.Threading; -using System.Threading.Tasks; +using System.Collections.Generic; using Microsoft.Extensions.AI; namespace Microsoft.Agents.AI; @@ -39,17 +36,14 @@ public sealed class ChatClientAgentOptions public ChatOptions? ChatOptions { get; set; } /// - /// Gets or sets a factory function to create an instance of - /// which will be used to provide chat history for this agent. + /// Gets or sets the instance to use for providing chat history for this agent. /// - public Func>? ChatHistoryProviderFactory { get; set; } + public ChatHistoryProvider? ChatHistoryProvider { get; set; } /// - /// Gets or sets a factory function to create an instance of - /// which will be used to create a context provider for each new thread, and can then - /// provide additional context for each agent run. + /// Gets or sets the list of instances to use for providing additional context for each agent run. /// - public Func>? AIContextProviderFactory { get; set; } + public IEnumerable? AIContextProviders { get; set; } /// /// Gets or sets a value indicating whether to use the provided instance as is, @@ -75,41 +69,7 @@ public ChatClientAgentOptions Clone() Name = this.Name, Description = this.Description, ChatOptions = this.ChatOptions?.Clone(), - ChatHistoryProviderFactory = this.ChatHistoryProviderFactory, - AIContextProviderFactory = this.AIContextProviderFactory, + ChatHistoryProvider = this.ChatHistoryProvider, + AIContextProviders = this.AIContextProviders is null ? null : new List(this.AIContextProviders), }; - - /// - /// Context object passed to the to create a new instance of . - /// - public sealed class AIContextProviderFactoryContext - { - /// - /// Gets or sets the serialized state of the , if any. - /// - /// if there is no state, e.g. when the is first created. - public JsonElement SerializedState { get; set; } - - /// - /// Gets or sets the JSON serialization options to use when deserializing the . - /// - public JsonSerializerOptions? JsonSerializerOptions { get; set; } - } - - /// - /// Context object passed to the to create a new instance of . - /// - public sealed class ChatHistoryProviderFactoryContext - { - /// - /// Gets or sets the serialized state of the , if any. - /// - /// if there is no state, e.g. when the is first created. - public JsonElement SerializedState { get; set; } - - /// - /// Gets or sets the JSON serialization options to use when deserializing the . - /// - public JsonSerializerOptions? JsonSerializerOptions { get; set; } - } } diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs index 1a79ae64d1..400bfbcaf6 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgentSession.cs @@ -3,8 +3,7 @@ using System; using System.Diagnostics; using System.Text.Json; -using System.Threading; -using System.Threading.Tasks; +using System.Text.Json.Serialization; using Microsoft.Shared.Diagnostics; namespace Microsoft.Agents.AI; @@ -15,8 +14,6 @@ namespace Microsoft.Agents.AI; [DebuggerDisplay("{DebuggerDisplay,nq}")] public sealed class ChatClientAgentSession : AgentSession { - private ChatHistoryProvider? _chatHistoryProvider; - /// /// Initializes a new instance of the class. /// @@ -24,29 +21,30 @@ internal ChatClientAgentSession() { } + [JsonConstructor] + internal ChatClientAgentSession(string? conversationId, AgentSessionStateBag? stateBag) : base(stateBag ?? new()) + { + this.ConversationId = conversationId; + } + /// - /// Gets or sets the ID of the underlying service thread to support cases where the chat history is stored by the agent service. + /// Gets or sets the ID of the underlying service chat history to support cases where the chat history is stored by the agent service. /// /// /// - /// Note that either or may be set, but not both. - /// If is not null, setting will throw an - /// exception. - /// - /// /// This property may be null in the following cases: /// - /// The thread stores messages via the and not in the agent service. - /// This thread object is new and a server managed thread has not yet been created in the agent service. + /// The agent stores messages via a and not in the agent service. + /// This session object is new and server managed chat history has not yet been created in the agent service. /// /// /// - /// The id may also change over time where the id is pointing at a - /// agent service managed thread, and the default behavior of a service is - /// to fork the thread with each iteration. + /// The id may also change over time where the id is pointing at + /// agent service managed chat history, and the default behavior of a service is + /// to fork the chat history with each iteration. /// /// - /// Attempted to set a conversation ID but a is already set. + [JsonPropertyName("conversationId")] public string? ConversationId { get; @@ -57,149 +55,37 @@ internal set return; } - if (this._chatHistoryProvider is not null) - { - // If we have a ChatHistoryProvider already, we shouldn't switch the session to use a conversation id - // since it means that the session contents will essentially be deleted, and the session will not work - // with the original agent anymore. - throw new InvalidOperationException("Only the ConversationId or ChatHistoryProvider may be set, but not both and switching from one to another is not supported."); - } - field = Throw.IfNullOrWhitespace(value); } } - /// - /// Gets or sets the used by this thread, for cases where messages should be stored in a custom location. - /// - /// - /// - /// Note that either or may be set, but not both. - /// If is not null, and is set, - /// will be reverted to null, and vice versa. - /// - /// - /// This property may be null in the following cases: - /// - /// The thread stores messages in the agent service and just has an id to the remove thread, instead of in an . - /// This thread object is new it is not yet clear whether it will be backed by a server managed thread or an . - /// - /// - /// - public ChatHistoryProvider? ChatHistoryProvider - { - get => this._chatHistoryProvider; - internal set - { - if (this._chatHistoryProvider is null && value is null) - { - return; - } - - if (!string.IsNullOrWhiteSpace(this.ConversationId)) - { - // If we have a conversation id already, we shouldn't switch the session to use a ChatHistoryProvider - // since it means that the session will not work with the original agent anymore. - throw new InvalidOperationException("Only the ConversationId or ChatHistoryProvider may be set, but not both and switching from one to another is not supported."); - } - - this._chatHistoryProvider = Throw.IfNull(value); - } - } - - /// - /// Gets or sets the used by this thread to provide additional context to the AI model before each invocation. - /// - public AIContextProvider? AIContextProvider { get; internal set; } - /// /// Creates a new instance of the class from previously serialized state. /// /// A representing the serialized state of the session. - /// Optional settings for customizing the JSON deserialization process. - /// - /// An optional factory function to create a custom from its serialized state. - /// If not provided, the default will be used. - /// - /// - /// An optional factory function to create a custom from its serialized state. - /// If not provided, no context provider will be configured. - /// - /// The to monitor for cancellation requests. - /// A task representing the asynchronous operation. The task result contains the deserialized . - internal static async Task DeserializeAsync( - JsonElement serializedState, - JsonSerializerOptions? jsonSerializerOptions = null, - Func>? chatHistoryProviderFactory = null, - Func>? aiContextProviderFactory = null, - CancellationToken cancellationToken = default) + /// Optional JSON serialization options to use instead of the default options. + /// The deserialized . + internal static ChatClientAgentSession Deserialize(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null) { if (serializedState.ValueKind != JsonValueKind.Object) { throw new ArgumentException("The serialized session state must be a JSON object.", nameof(serializedState)); } - var state = serializedState.Deserialize( - AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(SessionState))) as SessionState; - - var session = new ChatClientAgentSession(); - - session.AIContextProvider = aiContextProviderFactory is not null - ? await aiContextProviderFactory.Invoke(state?.AIContextProviderState ?? default, jsonSerializerOptions, cancellationToken).ConfigureAwait(false) - : null; - - if (state?.ConversationId is string sessionId) - { - session.ConversationId = sessionId; - - // Since we have an ID, we should not have a ChatHistoryProvider and we can return here. - return session; - } - - session._chatHistoryProvider = - chatHistoryProviderFactory is not null - ? await chatHistoryProviderFactory.Invoke(state?.ChatHistoryProviderState ?? default, jsonSerializerOptions, cancellationToken).ConfigureAwait(false) - : new InMemoryChatHistoryProvider(state?.ChatHistoryProviderState ?? default, jsonSerializerOptions); // default to an in-memory ChatHistoryProvider - - return session; + var jso = jsonSerializerOptions ?? AgentJsonUtilities.DefaultOptions; + return serializedState.Deserialize(jso.GetTypeInfo(typeof(ChatClientAgentSession))) as ChatClientAgentSession + ?? new ChatClientAgentSession(); } /// internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { - JsonElement? chatHistoryProviderState = this._chatHistoryProvider?.Serialize(jsonSerializerOptions); - - JsonElement? aiContextProviderState = this.AIContextProvider?.Serialize(jsonSerializerOptions); - - var state = new SessionState - { - ConversationId = this.ConversationId, - ChatHistoryProviderState = chatHistoryProviderState is { ValueKind: not JsonValueKind.Undefined } ? chatHistoryProviderState : null, - AIContextProviderState = aiContextProviderState is { ValueKind: not JsonValueKind.Undefined } ? aiContextProviderState : null, - }; - - return JsonSerializer.SerializeToElement(state, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(SessionState))); + var jso = jsonSerializerOptions ?? AgentJsonUtilities.DefaultOptions; + return JsonSerializer.SerializeToElement(this, jso.GetTypeInfo(typeof(ChatClientAgentSession))); } - /// - public override object? GetService(Type serviceType, object? serviceKey = null) => - base.GetService(serviceType, serviceKey) - ?? this.AIContextProvider?.GetService(serviceType, serviceKey) - ?? this.ChatHistoryProvider?.GetService(serviceType, serviceKey); - [DebuggerBrowsable(DebuggerBrowsableState.Never)] private string DebuggerDisplay => - this.ConversationId is { } conversationId ? $"ConversationId = {conversationId}" : - this._chatHistoryProvider is InMemoryChatHistoryProvider inMemoryChatHistoryProvider ? $"Count = {inMemoryChatHistoryProvider.Count}" : - this._chatHistoryProvider is { } chatHistoryProvider ? $"ChatHistoryProvider = {chatHistoryProvider.GetType().Name}" : - "Count = 0"; - - internal sealed class SessionState - { - public string? ConversationId { get; set; } - - public JsonElement? ChatHistoryProviderState { get; set; } - - public JsonElement? AIContextProviderState { get; set; } - } + this.ConversationId is { } conversationId ? $"ConversationId = {conversationId}, StateBag Count = {this.StateBag.Count}" : + $"StateBag Count = {this.StateBag.Count}"; } diff --git a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs index b8e495152c..9d163f79cf 100644 --- a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs @@ -4,7 +4,6 @@ using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -42,17 +41,24 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable private const string DefaultFunctionToolName = "Search"; private const string DefaultFunctionToolDescription = "Allows searching for related previous chat history to help answer the user question."; + private static IEnumerable DefaultExternalOnlyFilter(IEnumerable messages) + => messages.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External); + +#pragma warning disable CA2213 // VectorStore is not owned by this class - caller is responsible for disposal private readonly VectorStore _vectorStore; +#pragma warning restore CA2213 private readonly VectorStoreCollection> _collection; private readonly int _maxResults; private readonly string _contextPrompt; private readonly bool _enableSensitiveTelemetryData; private readonly ChatHistoryMemoryProviderOptions.SearchBehavior _searchTime; - private readonly AITool[] _tools; + private readonly string _toolName; + private readonly string _toolDescription; private readonly ILogger? _logger; - - private readonly ChatHistoryMemoryProviderScope _storageScope; - private readonly ChatHistoryMemoryProviderScope _searchScope; + private readonly string _stateKey; + private readonly Func _stateInitializer; + private readonly Func, IEnumerable> _searchInputMessageFilter; + private readonly Func, IEnumerable> _storageInputMessageFilter; private bool _collectionInitialized; private readonly SemaphoreSlim _initializationLock = new(1, 1); @@ -64,93 +70,32 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable /// The vector store to use for storing and retrieving chat history. /// The name of the collection for storing chat history in the vector store. /// The number of dimensions to use for the chat history vector store embeddings. - /// Optional values to scope the chat history storage with. - /// Optional values to scope the chat history search with. Where values are null, no filtering is done using those values. Defaults to if not provided. + /// A delegate that initializes the provider state on the first invocation, providing the storage and search scopes. /// Optional configuration options. /// Optional logger factory. - /// Thrown when is . + /// Thrown when or is . public ChatHistoryMemoryProvider( VectorStore vectorStore, string collectionName, int vectorDimensions, - ChatHistoryMemoryProviderScope storageScope, - ChatHistoryMemoryProviderScope? searchScope = null, + Func stateInitializer, ChatHistoryMemoryProviderOptions? options = null, ILoggerFactory? loggerFactory = null) - : this( - vectorStore, - collectionName, - vectorDimensions, - new ChatHistoryMemoryProviderState - { - StorageScope = new(Throw.IfNull(storageScope)), - SearchScope = searchScope ?? new(storageScope), - }, - options, - loggerFactory) { - } + this._vectorStore = Throw.IfNull(vectorStore); + this._stateInitializer = Throw.IfNull(stateInitializer); - /// - /// Initializes a new instance of the class from previously serialized state. - /// - /// The vector store to use for storing and retrieving chat history. - /// The name of the collection for storing chat history in the vector store. - /// The number of dimensions to use for the chat history vector store embeddings. - /// A representing the serialized state of the provider. - /// Optional settings for customizing the JSON deserialization process. - /// Optional configuration options. - /// Optional logger factory. - public ChatHistoryMemoryProvider( - VectorStore vectorStore, - string collectionName, - int vectorDimensions, - JsonElement serializedState, - JsonSerializerOptions? jsonSerializerOptions = null, - ChatHistoryMemoryProviderOptions? options = null, - ILoggerFactory? loggerFactory = null) - : this( - vectorStore, - collectionName, - vectorDimensions, - DeserializeState(serializedState, jsonSerializerOptions), - options, - loggerFactory) - { - } - - private ChatHistoryMemoryProvider( - VectorStore vectorStore, - string collectionName, - int vectorDimensions, - ChatHistoryMemoryProviderState? state = null, - ChatHistoryMemoryProviderOptions? options = null, - ILoggerFactory? loggerFactory = null) - { - this._vectorStore = vectorStore ?? throw new ArgumentNullException(nameof(vectorStore)); options ??= new ChatHistoryMemoryProviderOptions(); this._maxResults = options.MaxResults.HasValue ? Throw.IfLessThanOrEqual(options.MaxResults.Value, 0) : DefaultMaxResults; this._contextPrompt = options.ContextPrompt ?? DefaultContextPrompt; this._enableSensitiveTelemetryData = options.EnableSensitiveTelemetryData; this._searchTime = options.SearchTime; + this._stateKey = options.StateKey ?? base.StateKey; this._logger = loggerFactory?.CreateLogger(); - - if (state == null || state.StorageScope == null || state.SearchScope == null) - { - throw new InvalidOperationException($"The {nameof(ChatHistoryMemoryProvider)} state did not contain the required scope properties."); - } - - this._storageScope = state.StorageScope; - this._searchScope = state.SearchScope; - - // Create on-demand search tool (only used when behavior is OnDemandFunctionCalling) - this._tools = - [ - AIFunctionFactory.Create( - (Func>)this.SearchTextAsync, - name: options.FunctionToolName ?? DefaultFunctionToolName, - description: options.FunctionToolDescription ?? DefaultFunctionToolDescription) - ]; + this._toolName = options.FunctionToolName ?? DefaultFunctionToolName; + this._toolDescription = options.FunctionToolDescription ?? DefaultFunctionToolDescription; + this._searchInputMessageFilter = options.SearchInputMessageFilter ?? DefaultExternalOnlyFilter; + this._storageInputMessageFilter = options.StorageInputMessageFilter ?? DefaultExternalOnlyFilter; // Create a definition so that we can use the dimensions provided at runtime. var definition = new VectorStoreCollectionDefinition @@ -174,41 +119,93 @@ private ChatHistoryMemoryProvider( this._collection = this._vectorStore.GetDynamicCollection(Throw.IfNullOrWhitespace(collectionName), definition); } + /// + public override string StateKey => this._stateKey; + + /// + /// Gets the state from the session's StateBag, or initializes it using the StateInitializer if not present. + /// + /// The agent session containing the StateBag. + /// The provider state, or null if no session is available. + private State? GetOrInitializeState(AgentSession? session) + { + if (session?.StateBag.TryGetValue(this._stateKey, out var state, AgentJsonUtilities.DefaultOptions) is true && state is not null) + { + return state; + } + + state = this._stateInitializer(session); + if (state is not null && session is not null) + { + session.StateBag.SetValue(this._stateKey, state, AgentJsonUtilities.DefaultOptions); + } + + return state; + } + /// protected override async ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { _ = Throw.IfNull(context); + var inputContext = context.AIContext; + var state = this.GetOrInitializeState(context.Session); + var searchScope = state?.SearchScope ?? new ChatHistoryMemoryProviderScope(); + if (this._searchTime == ChatHistoryMemoryProviderOptions.SearchBehavior.OnDemandFunctionCalling) { - // Expose search tool for on-demand invocation by the model - return new AIContext { Tools = this._tools }; + Task InlineSearchAsync(string userQuestion, CancellationToken ct) + => this.SearchTextAsync(userQuestion, searchScope, ct); + + // Create on-demand search tool (only used when behavior is OnDemandFunctionCalling) + AITool[] tools = + [ + AIFunctionFactory.Create( + InlineSearchAsync, + name: this._toolName, + description: this._toolDescription) + ]; + + // Expose search tool for on-demand invocation by the model, accumulated with the input context + return new AIContext + { + Instructions = inputContext.Instructions, + Messages = inputContext.Messages, + Tools = (inputContext.Tools ?? []).Concat(tools) + }; } try { // Get the text from the current request messages - var requestText = string.Join("\n", context.RequestMessages - .Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External) + var requestText = string.Join("\n", + this._searchInputMessageFilter(inputContext.Messages ?? []) .Where(m => m != null && !string.IsNullOrWhiteSpace(m.Text)) .Select(m => m.Text)); if (string.IsNullOrWhiteSpace(requestText)) { - return new AIContext(); + return inputContext; } // Search for relevant chat history - var contextText = await this.SearchTextAsync(requestText, cancellationToken).ConfigureAwait(false); + var contextText = await this.SearchTextAsync(requestText, searchScope, cancellationToken).ConfigureAwait(false); if (string.IsNullOrWhiteSpace(contextText)) { - return new AIContext(); + return inputContext; } return new AIContext { - Messages = [new ChatMessage(ChatRole.User, contextText)] + Instructions = inputContext.Instructions, + Messages = + (inputContext.Messages ?? []) + .Concat( + [ + new ChatMessage(ChatRole.User, contextText).WithAgentRequestMessageSource(AgentRequestMessageSourceType.AIContextProvider, this.GetType().FullName!) + ]), + Tools = inputContext.Tools }; } catch (Exception ex) @@ -218,13 +215,13 @@ protected override async ValueTask InvokingCoreAsync(InvokingContext this._logger.LogError( ex, "ChatHistoryMemoryProvider: Failed to search for chat history due to error. ApplicationId: '{ApplicationId}', AgentId: '{AgentId}', SessionId: '{SessionId}', UserId: '{UserId}'.", - this._searchScope.ApplicationId, - this._searchScope.AgentId, - this._searchScope.SessionId, - this.SanitizeLogData(this._searchScope.UserId)); + searchScope.ApplicationId, + searchScope.AgentId, + searchScope.SessionId, + this.SanitizeLogData(searchScope.UserId)); } - return new AIContext(); + return inputContext; } } @@ -239,13 +236,15 @@ protected override async ValueTask InvokedCoreAsync(InvokedContext context, Canc return; } + var state = this.GetOrInitializeState(context.Session); + var storageScope = state?.StorageScope ?? new ChatHistoryMemoryProviderScope(); + try { // Ensure the collection is initialized var collection = await this.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false); - List> itemsToStore = context.RequestMessages - .Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External) + List> itemsToStore = this._storageInputMessageFilter(context.RequestMessages) .Concat(context.ResponseMessages ?? []) .Select(message => new Dictionary { @@ -253,10 +252,10 @@ protected override async ValueTask InvokedCoreAsync(InvokedContext context, Canc ["Role"] = message.Role.ToString(), ["MessageId"] = message.MessageId, ["AuthorName"] = message.AuthorName, - ["ApplicationId"] = this._storageScope?.ApplicationId, - ["AgentId"] = this._storageScope?.AgentId, - ["UserId"] = this._storageScope?.UserId, - ["SessionId"] = this._storageScope?.SessionId, + ["ApplicationId"] = storageScope.ApplicationId, + ["AgentId"] = storageScope.AgentId, + ["UserId"] = storageScope.UserId, + ["SessionId"] = storageScope.SessionId, ["Content"] = message.Text, ["CreatedAt"] = message.CreatedAt?.ToString("O") ?? DateTimeOffset.UtcNow.ToString("O"), ["ContentEmbedding"] = message.Text, @@ -275,10 +274,10 @@ protected override async ValueTask InvokedCoreAsync(InvokedContext context, Canc this._logger.LogError( ex, "ChatHistoryMemoryProvider: Failed to add messages to chat history vector store due to error. ApplicationId: '{ApplicationId}', AgentId: '{AgentId}', SessionId: '{SessionId}', UserId: '{UserId}'.", - this._searchScope.ApplicationId, - this._searchScope.AgentId, - this._searchScope.SessionId, - this.SanitizeLogData(this._searchScope.UserId)); + storageScope.ApplicationId, + storageScope.AgentId, + storageScope.SessionId, + this.SanitizeLogData(storageScope.UserId)); } } } @@ -287,16 +286,17 @@ protected override async ValueTask InvokedCoreAsync(InvokedContext context, Canc /// Function callable by the AI model (when enabled) to perform an ad-hoc chat history search. /// /// The query text. + /// The scope to filter search results with. /// Cancellation token. /// Formatted search results (may be empty). - internal async Task SearchTextAsync(string userQuestion, CancellationToken cancellationToken = default) + private async Task SearchTextAsync(string userQuestion, ChatHistoryMemoryProviderScope searchScope, CancellationToken cancellationToken = default) { if (string.IsNullOrWhiteSpace(userQuestion)) { return string.Empty; } - var results = await this.SearchChatHistoryAsync(userQuestion, this._maxResults, cancellationToken).ConfigureAwait(false); + var results = await this.SearchChatHistoryAsync(userQuestion, searchScope, this._maxResults, cancellationToken).ConfigureAwait(false); if (!results.Any()) { return string.Empty; @@ -317,10 +317,10 @@ internal async Task SearchTextAsync(string userQuestion, CancellationTok "ChatHistoryMemoryProvider: Search Results\nInput:{Input}\nOutput:{MessageText}\n ApplicationId: '{ApplicationId}', AgentId: '{AgentId}', SessionId: '{SessionId}', UserId: '{UserId}'.", this.SanitizeLogData(userQuestion), this.SanitizeLogData(formatted), - this._searchScope.ApplicationId, - this._searchScope.AgentId, - this._searchScope.SessionId, - this.SanitizeLogData(this._searchScope.UserId)); + searchScope.ApplicationId, + searchScope.AgentId, + searchScope.SessionId, + this.SanitizeLogData(searchScope.UserId)); } return formatted; @@ -330,11 +330,13 @@ internal async Task SearchTextAsync(string userQuestion, CancellationTok /// Searches for relevant chat history items based on the provided query text. /// /// The text to search for. + /// The scope to filter search results with. /// The maximum number of results to return. /// The cancellation token. /// A list of relevant chat history items. private async Task>> SearchChatHistoryAsync( string queryText, + ChatHistoryMemoryProviderScope searchScope, int top, CancellationToken cancellationToken = default) { @@ -345,10 +347,10 @@ internal async Task SearchTextAsync(string userQuestion, CancellationTok var collection = await this.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false); - string? applicationId = this._searchScope.ApplicationId; - string? agentId = this._searchScope.AgentId; - string? userId = this._searchScope.UserId; - string? sessionId = this._searchScope.SessionId; + string? applicationId = searchScope.ApplicationId; + string? agentId = searchScope.AgentId; + string? userId = searchScope.UserId; + string? sessionId = searchScope.SessionId; Expression, bool>>? filter = null; if (applicationId != null) @@ -401,10 +403,10 @@ internal async Task SearchTextAsync(string userQuestion, CancellationTok this._logger.LogInformation( "ChatHistoryMemoryProvider: Retrieved {Count} search results. ApplicationId: '{ApplicationId}', AgentId: '{AgentId}', SessionId: '{SessionId}', UserId: '{UserId}'.", results.Count, - this._searchScope.ApplicationId, - this._searchScope.AgentId, - this._searchScope.SessionId, - this.SanitizeLogData(this._searchScope.UserId)); + searchScope.ApplicationId, + searchScope.AgentId, + searchScope.SessionId, + this.SanitizeLogData(searchScope.UserId)); } return results; @@ -465,39 +467,32 @@ public void Dispose() GC.SuppressFinalize(this); } + private string? SanitizeLogData(string? data) => this._enableSensitiveTelemetryData ? data : ""; + /// - /// Serializes the current provider state to a including storage and search scopes. + /// Represents the state of a stored in the . /// - /// Optional serializer options. - /// Serialized provider state. - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - var state = new ChatHistoryMemoryProviderState - { - StorageScope = this._storageScope, - SearchScope = this._searchScope, - }; - - var jso = jsonSerializerOptions ?? AgentJsonUtilities.DefaultOptions; - return JsonSerializer.SerializeToElement(state, jso.GetTypeInfo(typeof(ChatHistoryMemoryProviderState))); - } - - private static ChatHistoryMemoryProviderState? DeserializeState(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions) + public sealed class State { - if (serializedState.ValueKind != JsonValueKind.Object) + /// + /// Initializes a new instance of the class with the specified storage and search scopes. + /// + /// The scope to use when storing chat history messages. + /// The scope to use when searching for relevant chat history messages. If null, the storage scope will be used for searching as well. + public State(ChatHistoryMemoryProviderScope storageScope, ChatHistoryMemoryProviderScope? searchScope = null) { - return null; + this.StorageScope = Throw.IfNull(storageScope); + this.SearchScope = searchScope ?? storageScope; } - var jso = jsonSerializerOptions ?? AgentJsonUtilities.DefaultOptions; - return serializedState.Deserialize(jso.GetTypeInfo(typeof(ChatHistoryMemoryProviderState))) as ChatHistoryMemoryProviderState; - } - - private string? SanitizeLogData(string? data) => this._enableSensitiveTelemetryData ? data : ""; + /// + /// Gets or sets the scope used when storing chat history messages. + /// + public ChatHistoryMemoryProviderScope StorageScope { get; } - internal sealed class ChatHistoryMemoryProviderState - { - public ChatHistoryMemoryProviderScope? StorageScope { get; set; } - public ChatHistoryMemoryProviderScope? SearchScope { get; set; } + /// + /// Gets or sets the scope used when searching chat history messages. + /// + public ChatHistoryMemoryProviderScope SearchScope { get; } } } diff --git a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProviderOptions.cs b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProviderOptions.cs index e09de68a59..6c92a426f3 100644 --- a/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProviderOptions.cs +++ b/dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProviderOptions.cs @@ -1,5 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. +using System; +using System.Collections.Generic; +using Microsoft.Extensions.AI; + namespace Microsoft.Agents.AI; /// @@ -44,6 +48,35 @@ public sealed class ChatHistoryMemoryProviderOptions /// Defaults to . public bool EnableSensitiveTelemetryData { get; set; } + /// + /// Gets or sets the key used to store provider state in the . + /// + /// + /// Defaults to the provider's type name. Override this if you need multiple + /// instances with separate state in the same session. + /// + public string? StateKey { get; set; } + + /// + /// Gets or sets an optional filter function applied to request messages when constructing the search text to use + /// to search for relevant chat history during . + /// + /// + /// When , the provider defaults to including only + /// messages. + /// + public Func, IEnumerable>? SearchInputMessageFilter { get; set; } + + /// + /// Gets or sets an optional filter function applied to request messages when storing recent chat history + /// during . + /// + /// + /// When , the provider defaults to including only + /// messages. + /// + public Func, IEnumerable>? StorageInputMessageFilter { get; set; } + /// /// Behavior choices for the provider. /// diff --git a/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs b/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs index f29aadf808..f038fa3c38 100644 --- a/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs @@ -4,7 +4,6 @@ using System.Collections.Generic; using System.Linq; using System.Text; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -40,30 +39,31 @@ public sealed class TextSearchProvider : AIContextProvider private const string DefaultContextPrompt = "## Additional Context\nConsider the following information from source documents when responding to the user:"; private const string DefaultCitationsPrompt = "Include citations to the source document with document name and link if document name and link is available."; + private static IEnumerable DefaultExternalOnlyFilter(IEnumerable messages) + => messages.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External); + private readonly Func>> _searchAsync; private readonly ILogger? _logger; private readonly AITool[] _tools; - private readonly Queue _recentMessagesText; private readonly List _recentMessageRolesIncluded; private readonly int _recentMessageMemoryLimit; private readonly TextSearchProviderOptions.TextSearchBehavior _searchTime; private readonly string _contextPrompt; private readonly string _citationsPrompt; + private readonly string _stateKey; private readonly Func, string>? _contextFormatter; + private readonly Func, IEnumerable> _searchInputMessageFilter; + private readonly Func, IEnumerable> _storageInputMessageFilter; /// /// Initializes a new instance of the class. /// /// Delegate that executes the search logic. Must not be . - /// A representing the serialized provider state. - /// Optional serializer options (unused - source generated context is used). /// Optional configuration options. /// Optional logger factory. /// Thrown when is . public TextSearchProvider( Func>> searchAsync, - JsonElement serializedState, - JsonSerializerOptions? jsonSerializerOptions = null, TextSearchProviderOptions? options = null, ILoggerFactory? loggerFactory = null) { @@ -75,26 +75,10 @@ public TextSearchProvider( this._searchTime = options?.SearchTime ?? TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke; this._contextPrompt = options?.ContextPrompt ?? DefaultContextPrompt; this._citationsPrompt = options?.CitationsPrompt ?? DefaultCitationsPrompt; + this._stateKey = options?.StateKey ?? base.StateKey; this._contextFormatter = options?.ContextFormatter; - - // Restore recent messages from serialized state if provided - List? restoredMessages = null; - if (serializedState.ValueKind is JsonValueKind.Null or JsonValueKind.Undefined) - { - this._recentMessagesText = new(); - } - else - { - var jso = jsonSerializerOptions ?? AgentJsonUtilities.DefaultOptions; - var state = serializedState.Deserialize(jso.GetTypeInfo(typeof(TextSearchProviderState))) as TextSearchProviderState; - if (state?.RecentMessagesText is { Count: > 0 }) - { - restoredMessages = state.RecentMessagesText; - } - - // Restore recent messages respecting the limit (may truncate if limit changed afterwards). - this._recentMessagesText = restoredMessages is null ? new() : new(restoredMessages.Take(this._recentMessageMemoryLimit)); - } + this._searchInputMessageFilter = options?.SearchInputMessageFilter ?? DefaultExternalOnlyFilter; + this._storageInputMessageFilter = options?.StorageInputMessageFilter ?? DefaultExternalOnlyFilter; // Create the on-demand search tool (only used if behavior is OnDemandFunctionCalling) this._tools = @@ -106,21 +90,35 @@ public TextSearchProvider( ]; } + /// + public override string StateKey => this._stateKey; + /// protected override async ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { + var inputContext = context.AIContext; + if (this._searchTime != TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke) { - // Expose the search tool for on-demand invocation. - return new AIContext { Tools = this._tools }; // No automatic message injection. + // Expose the search tool for on-demand invocation, accumulated with the input context. + return new AIContext + { + Instructions = inputContext.Instructions, + Messages = inputContext.Messages, + Tools = (inputContext.Tools ?? []).Concat(this._tools) + }; } + // Retrieve recent messages from the session state bag. + var recentMessagesText = context.Session?.StateBag.GetValue(this._stateKey, AgentJsonUtilities.DefaultOptions)?.RecentMessagesText + ?? []; + // Aggregate text from memory + current request messages. var sbInput = new StringBuilder(); - var requestMessagesText = context.RequestMessages - .Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External) + var requestMessagesText = + this._searchInputMessageFilter(inputContext.Messages ?? []) .Where(x => !string.IsNullOrWhiteSpace(x?.Text)).Select(x => x.Text); - foreach (var messageText in this._recentMessagesText.Concat(requestMessagesText)) + foreach (var messageText in recentMessagesText.Concat(requestMessagesText)) { if (sbInput.Length > 0) { @@ -144,7 +142,7 @@ protected override async ValueTask InvokingCoreAsync(InvokingContext if (materialized.Count == 0) { - return new AIContext(); + return inputContext; } // Format search results @@ -157,13 +155,20 @@ protected override async ValueTask InvokingCoreAsync(InvokingContext return new AIContext { - Messages = [new ChatMessage(ChatRole.User, formatted) { AdditionalProperties = new AdditionalPropertiesDictionary() { ["IsTextSearchProviderOutput"] = true } }] + Instructions = inputContext.Instructions, + Messages = + (inputContext.Messages ?? []) + .Concat( + [ + new ChatMessage(ChatRole.User, formatted).WithAgentRequestMessageSource(AgentRequestMessageSourceType.AIContextProvider, this.GetType().FullName!) + ]), + Tools = inputContext.Tools }; } catch (Exception ex) { this._logger?.LogError(ex, "TextSearchProvider: Failed to search for data due to error"); - return new AIContext(); + return inputContext; } } @@ -176,58 +181,42 @@ protected override ValueTask InvokedCoreAsync(InvokedContext context, Cancellati return default; // Memory disabled. } + if (context.Session is null) + { + return default; // No session to store state in. + } + if (context.InvokeException is not null) { return default; // Do not update memory on failed invocations. } - var messagesText = context.RequestMessages - .Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External) + // Retrieve existing recent messages from the session state bag. + var recentMessagesText = context.Session.StateBag.GetValue(this._stateKey, AgentJsonUtilities.DefaultOptions)?.RecentMessagesText + ?? []; + + var newMessagesText = this._storageInputMessageFilter(context.RequestMessages) .Concat(context.ResponseMessages ?? []) .Where(m => this._recentMessageRolesIncluded.Contains(m.Role) && - !string.IsNullOrWhiteSpace(m.Text) && - // Filter out any messages that were added by this class in InvokingAsync, since we don't want - // a feedback loop where previous search results are used to find new search results. - (m.AdditionalProperties == null || m.AdditionalProperties.TryGetValue("IsTextSearchProviderOutput", out bool isTextSearchProviderOutput) == false || !isTextSearchProviderOutput)) - .Select(m => m.Text) - .ToList(); - if (messagesText.Count > limit) - { - // If the current request/response exceeds the limit, only keep the most recent messages from it. - messagesText = messagesText.Skip(messagesText.Count - limit).ToList(); - } + !string.IsNullOrWhiteSpace(m.Text)) + .Select(m => m.Text); - foreach (var message in messagesText) - { - this._recentMessagesText.Enqueue(message); - } + // Combine existing messages with new messages, then take the most recent up to the limit. + var allMessages = recentMessagesText.Concat(newMessagesText).ToList(); + var updatedMessages = allMessages.Count > limit + ? allMessages.Skip(allMessages.Count - limit).ToList() + : allMessages; - while (this._recentMessagesText.Count > limit) - { - this._recentMessagesText.Dequeue(); - } + // Store updated state back to the session state bag. + context.Session.StateBag.SetValue( + this._stateKey, + new TextSearchProviderState { RecentMessagesText = updatedMessages }, + AgentJsonUtilities.DefaultOptions); return default; } - /// - /// Serializes the current provider state to a containing any overridden prompts or descriptions. - /// - /// Optional serializer options (ignored, source generated context is used). - /// A with overridden values, or default if nothing was overridden. - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - // Only persist values that differ from defaults plus recent memory configuration & messages. - TextSearchProviderState state = new(); - if (this._recentMessageMemoryLimit > 0 && this._recentMessagesText.Count > 0) - { - state.RecentMessagesText = this._recentMessagesText.Take(this._recentMessageMemoryLimit).ToList(); - } - - return JsonSerializer.SerializeToElement(state, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(TextSearchProviderState))); - } - /// /// Function callable by the AI model (when enabled) to perform an ad-hoc search. /// diff --git a/dotnet/src/Microsoft.Agents.AI/TextSearchProviderOptions.cs b/dotnet/src/Microsoft.Agents.AI/TextSearchProviderOptions.cs index e90a6efa63..837470b776 100644 --- a/dotnet/src/Microsoft.Agents.AI/TextSearchProviderOptions.cs +++ b/dotnet/src/Microsoft.Agents.AI/TextSearchProviderOptions.cs @@ -59,6 +59,35 @@ public sealed class TextSearchProviderOptions /// public int RecentMessageMemoryLimit { get; set; } + /// + /// Gets or sets the key used to store provider state in the . + /// + /// + /// Defaults to the provider's type name. Override this if you need multiple + /// instances with separate state in the same session. + /// + public string? StateKey { get; set; } + + /// + /// Gets or sets an optional filter function applied to request messages when constructing the search input + /// text during . + /// + /// + /// When , the provider defaults to including only + /// messages. + /// + public Func, IEnumerable>? SearchInputMessageFilter { get; set; } + + /// + /// Gets or sets an optional filter function applied to request messages when updating the recent message + /// memory during . + /// + /// + /// When , the provider defaults to including only + /// messages. + /// + public Func, IEnumerable>? StorageInputMessageFilter { get; set; } + /// /// Gets or sets the list of types to filter recent messages to /// when deciding which recent messages to include when constructing the search input. diff --git a/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs b/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs index a0e4b64763..f36cb119d9 100644 --- a/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs +++ b/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs @@ -37,14 +37,14 @@ public AnthropicChatCompletionFixture(bool useReasoningChatModel, bool useBeta) public async Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) { - var typedSession = (ChatClientAgentSession)session; + var chatHistoryProvider = agent.GetService(); - if (typedSession.ChatHistoryProvider is null) + if (chatHistoryProvider is null) { return []; } - return (await typedSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); + return (await chatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); } public Task CreateChatClientAgentAsync( diff --git a/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs b/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs index c655fd7a58..dd926174c0 100644 --- a/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs +++ b/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs @@ -48,12 +48,14 @@ public async Task> GetChatHistoryAsync(AIAgent agent, AgentSes return await this.GetChatHistoryFromResponsesChainAsync(chatClientSession.ConversationId); } - if (chatClientSession.ChatHistoryProvider is null) + var chatHistoryProvider = agent.GetService(); + + if (chatHistoryProvider is null) { return []; } - return (await chatClientSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); + return (await chatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); } private async Task> GetChatHistoryFromResponsesChainAsync(string conversationId) diff --git a/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentSessionTests.cs index 66018e0131..8c3e89adf6 100644 --- a/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.A2A.UnitTests/A2AAgentSessionTests.cs @@ -21,10 +21,28 @@ public void Constructor_RoundTrip_SerializationPreservesState() // Act JsonElement serialized = originalSession.Serialize(); - A2AAgentSession deserializedSession = new(serialized); + A2AAgentSession deserializedSession = A2AAgentSession.Deserialize(serialized); // Assert Assert.Equal(originalSession.ContextId, deserializedSession.ContextId); Assert.Equal(originalSession.TaskId, deserializedSession.TaskId); } + + [Fact] + public void Constructor_RoundTrip_SerializationPreservesStateBag() + { + // Arrange + A2AAgentSession originalSession = new() { ContextId = "ctx-1", TaskId = "task-1" }; + originalSession.StateBag.SetValue("testKey", "testValue"); + + // Act + JsonElement serialized = originalSession.Serialize(); + A2AAgentSession deserializedSession = A2AAgentSession.Deserialize(serialized); + + // Assert + Assert.Equal("ctx-1", deserializedSession.ContextId); + Assert.Equal("task-1", deserializedSession.TaskId); + Assert.True(deserializedSession.StateBag.TryGetValue("testKey", out var value)); + Assert.Equal("testValue", value); + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs index aa41a03efc..14a0f81e08 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; using System.Collections.ObjectModel; -using System.Linq; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -16,110 +15,6 @@ public class AIContextProviderTests private static readonly AIAgent s_mockAgent = new Mock().Object; private static readonly AgentSession s_mockSession = new Mock().Object; - #region InvokingAsync Message Stamping Tests - - [Fact] - public async Task InvokingAsync_StampsMessagesWithSourceTypeAndSourceIdAsync() - { - // Arrange - var provider = new TestAIContextProviderWithMessages(); - var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); - - // Act - AIContext aiContext = await provider.InvokingAsync(context); - - // Assert - Assert.NotNull(aiContext.Messages); - ChatMessage message = aiContext.Messages.Single(); - Assert.NotNull(message.AdditionalProperties); - Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, out object? attribution)); - var typedAttribution = Assert.IsType(attribution); - Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, typedAttribution.SourceType); - Assert.Equal(typeof(TestAIContextProviderWithMessages).FullName, typedAttribution.SourceId); - } - - [Fact] - public async Task InvokingAsync_WithCustomSourceId_StampsMessagesWithCustomSourceIdAsync() - { - // Arrange - const string CustomSourceId = "CustomContextSource"; - var provider = new TestAIContextProviderWithCustomSource(CustomSourceId); - var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); - - // Act - AIContext aiContext = await provider.InvokingAsync(context); - - // Assert - Assert.NotNull(aiContext.Messages); - ChatMessage message = aiContext.Messages.Single(); - Assert.NotNull(message.AdditionalProperties); - Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, out object? attribution)); - var typedAttribution = Assert.IsType(attribution); - Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, typedAttribution.SourceType); - Assert.Equal(CustomSourceId, typedAttribution.SourceId); - } - - [Fact] - public async Task InvokingAsync_DoesNotReStampAlreadyStampedMessagesAsync() - { - // Arrange - var provider = new TestAIContextProviderWithPreStampedMessages(); - var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); - - // Act - AIContext aiContext = await provider.InvokingAsync(context); - - // Assert - Assert.NotNull(aiContext.Messages); - ChatMessage message = aiContext.Messages.Single(); - Assert.NotNull(message.AdditionalProperties); - Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, out object? attribution)); - var typedAttribution = Assert.IsType(attribution); - Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, typedAttribution.SourceType); - Assert.Equal(typeof(TestAIContextProviderWithPreStampedMessages).FullName, typedAttribution.SourceId); - } - - [Fact] - public async Task InvokingAsync_StampsMultipleMessagesAsync() - { - // Arrange - var provider = new TestAIContextProviderWithMultipleMessages(); - var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); - - // Act - AIContext aiContext = await provider.InvokingAsync(context); - - // Assert - Assert.NotNull(aiContext.Messages); - List messageList = aiContext.Messages.ToList(); - Assert.Equal(3, messageList.Count); - - foreach (ChatMessage message in messageList) - { - Assert.NotNull(message.AdditionalProperties); - Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, out object? attribution)); - var typedAttribution = Assert.IsType(attribution); - Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, typedAttribution.SourceType); - Assert.Equal(typeof(TestAIContextProviderWithMultipleMessages).FullName, typedAttribution.SourceId); - } - } - - [Fact] - public async Task InvokingAsync_WithNullMessages_ReturnsContextWithoutStampingAsync() - { - // Arrange - var provider = new TestAIContextProvider(); - var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); - - // Act - AIContext aiContext = await provider.InvokingAsync(context); - - // Assert - Assert.Null(aiContext.Messages); - } - - #endregion - #region Basic Tests [Fact] @@ -130,25 +25,12 @@ public async Task InvokedAsync_ReturnsCompletedTaskAsync() var messages = new ReadOnlyCollection([]); // Act - ValueTask task = provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages)); + ValueTask task = provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, [])); // Assert Assert.Equal(default, task); } - [Fact] - public void Serialize_ReturnsEmptyElement() - { - // Arrange - var provider = new TestAIContextProvider(); - - // Act - var actual = provider.Serialize(); - - // Assert - Assert.Equal(default, actual); - } - [Fact] public void InvokingContext_Constructor_ThrowsForNullMessages() { @@ -160,7 +42,7 @@ public void InvokingContext_Constructor_ThrowsForNullMessages() public void InvokedContext_Constructor_ThrowsForNullMessages() { // Act & Assert - Assert.Throws(() => new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, null!)); + Assert.Throws(() => new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, null!, [])); } #endregion @@ -284,39 +166,33 @@ public void GetService_Generic_ReturnsNullForUnrelatedType() #region InvokingContext Tests [Fact] - public void InvokingContext_RequestMessages_SetterThrowsForNull() + public void InvokingContext_Constructor_ThrowsForNullAIContext() { - // Arrange - var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); - var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, messages); - // Act & Assert - Assert.Throws(() => context.RequestMessages = null!); + Assert.Throws(() => new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, null!)); } [Fact] - public void InvokingContext_RequestMessages_SetterRoundtrips() + public void InvokingContext_AIContext_ConstructorValueRoundtrips() { // Arrange - var initialMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); - var newMessages = new List { new(ChatRole.User, "New message") }; - var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, initialMessages); + var aiContext = new AIContext { Messages = [new ChatMessage(ChatRole.User, "Hello")] }; // Act - context.RequestMessages = newMessages; + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, aiContext); // Assert - Assert.Same(newMessages, context.RequestMessages); + Assert.Same(aiContext, context.AIContext); } [Fact] public void InvokingContext_Agent_ReturnsConstructorValue() { // Arrange - var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + var aiContext = new AIContext { Messages = [new ChatMessage(ChatRole.User, "Hello")] }; // Act - var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, messages); + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, aiContext); // Assert Assert.Same(s_mockAgent, context.Agent); @@ -326,10 +202,10 @@ public void InvokingContext_Agent_ReturnsConstructorValue() public void InvokingContext_Session_ReturnsConstructorValue() { // Arrange - var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + var aiContext = new AIContext { Messages = [new ChatMessage(ChatRole.User, "Hello")] }; // Act - var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, messages); + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, aiContext); // Assert Assert.Same(s_mockSession, context.Session); @@ -339,10 +215,10 @@ public void InvokingContext_Session_ReturnsConstructorValue() public void InvokingContext_Session_CanBeNull() { // Arrange - var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + var aiContext = new AIContext { Messages = [new ChatMessage(ChatRole.User, "Hello")] }; // Act - var context = new AIContextProvider.InvokingContext(s_mockAgent, null, messages); + var context = new AIContextProvider.InvokingContext(s_mockAgent, null, aiContext); // Assert Assert.Null(context.Session); @@ -352,52 +228,25 @@ public void InvokingContext_Session_CanBeNull() public void InvokingContext_Constructor_ThrowsForNullAgent() { // Arrange - var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + var aiContext = new AIContext { Messages = [new ChatMessage(ChatRole.User, "Hello")] }; // Act & Assert - Assert.Throws(() => new AIContextProvider.InvokingContext(null!, s_mockSession, messages)); + Assert.Throws(() => new AIContextProvider.InvokingContext(null!, s_mockSession, aiContext)); } #endregion #region InvokedContext Tests - [Fact] - public void InvokedContext_RequestMessages_SetterThrowsForNull() - { - // Arrange - var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); - var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, messages); - - // Act & Assert - Assert.Throws(() => context.RequestMessages = null!); - } - - [Fact] - public void InvokedContext_RequestMessages_SetterRoundtrips() - { - // Arrange - var initialMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); - var newMessages = new List { new(ChatRole.User, "New message") }; - var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, initialMessages); - - // Act - context.RequestMessages = newMessages; - - // Assert - Assert.Same(newMessages, context.RequestMessages); - } - [Fact] public void InvokedContext_ResponseMessages_Roundtrips() { // Arrange var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); var responseMessages = new List { new(ChatRole.Assistant, "Response message") }; - var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); // Act - context.ResponseMessages = responseMessages; + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, responseMessages); // Assert Assert.Same(responseMessages, context.ResponseMessages); @@ -409,10 +258,9 @@ public void InvokedContext_InvokeException_Roundtrips() // Arrange var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); var exception = new InvalidOperationException("Test exception"); - var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); // Act - context.InvokeException = exception; + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, exception); // Assert Assert.Same(exception, context.InvokeException); @@ -425,7 +273,7 @@ public void InvokedContext_Agent_ReturnsConstructorValue() var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); // Act - var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); // Assert Assert.Same(s_mockAgent, context.Agent); @@ -438,7 +286,7 @@ public void InvokedContext_Session_ReturnsConstructorValue() var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); // Act - var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); // Assert Assert.Same(s_mockSession, context.Session); @@ -451,7 +299,7 @@ public void InvokedContext_Session_CanBeNull() var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); // Act - var context = new AIContextProvider.InvokedContext(s_mockAgent, null, requestMessages); + var context = new AIContextProvider.InvokedContext(s_mockAgent, null, requestMessages, []); // Assert Assert.Null(context.Session); @@ -464,65 +312,34 @@ public void InvokedContext_Constructor_ThrowsForNullAgent() var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); // Act & Assert - Assert.Throws(() => new AIContextProvider.InvokedContext(null!, s_mockSession, requestMessages)); + Assert.Throws(() => new AIContextProvider.InvokedContext(null!, s_mockSession, requestMessages, [])); } - #endregion - - private sealed class TestAIContextProvider : AIContextProvider + [Fact] + public void InvokedContext_SuccessConstructor_ThrowsForNullResponseMessages() { - protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) - => new(new AIContext()); - } + // Arrange + var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); - private sealed class TestAIContextProviderWithMessages : AIContextProvider - { - protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) - => new(new AIContext - { - Messages = [new ChatMessage(ChatRole.System, "Context Message")] - }); + // Act & Assert + Assert.Throws(() => new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, (IEnumerable)null!)); } - private sealed class TestAIContextProviderWithCustomSource : AIContextProvider + [Fact] + public void InvokedContext_FailureConstructor_ThrowsForNullException() { - public TestAIContextProviderWithCustomSource(string sourceId) : base(sourceId) - { - } + // Arrange + var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); - protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) - => new(new AIContext - { - Messages = [new ChatMessage(ChatRole.System, "Context Message")] - }); + // Act & Assert + Assert.Throws(() => new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, (Exception)null!)); } - private sealed class TestAIContextProviderWithPreStampedMessages : AIContextProvider - { - protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) - { - var message = new ChatMessage(ChatRole.System, "Pre-stamped Message"); - message.AdditionalProperties = new AdditionalPropertiesDictionary - { - [AgentRequestMessageSourceAttribution.AdditionalPropertiesKey] = new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, this.GetType().FullName!) - }; - return new(new AIContext - { - Messages = [message] - }); - } - } + #endregion - private sealed class TestAIContextProviderWithMultipleMessages : AIContextProvider + private sealed class TestAIContextProvider : AIContextProvider { protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) - => new(new AIContext - { - Messages = [ - new ChatMessage(ChatRole.System, "Message 1"), - new ChatMessage(ChatRole.User, "Message 2"), - new ChatMessage(ChatRole.Assistant, "Message 3") - ] - }); + => new(new AIContext()); } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextTests.cs index b1ba6060ea..c925f098b3 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextTests.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System.Linq; using Microsoft.Extensions.AI; namespace Microsoft.Agents.AI.Abstractions.UnitTests; @@ -33,9 +34,10 @@ public void SetMessagesRoundtrips() }; Assert.NotNull(context.Messages); - Assert.Equal(2, context.Messages.Count); - Assert.Equal("Hello", context.Messages[0].Text); - Assert.Equal("Hi there!", context.Messages[1].Text); + var messages = context.Messages.ToList(); + Assert.Equal(2, messages.Count); + Assert.Equal("Hello", messages[0].Text); + Assert.Equal("Hi there!", messages[1].Text); } [Fact] @@ -51,8 +53,9 @@ public void SetAIFunctionsRoundtrips() }; Assert.NotNull(context.Tools); - Assert.Equal(2, context.Tools.Count); - Assert.Equal("Function1", context.Tools[0].Name); - Assert.Equal("Function2", context.Tools[1].Name); + var tools = context.Tools.ToList(); + Assert.Equal(2, tools.Count); + Assert.Equal("Function1", tools[0].Name); + Assert.Equal("Function2", tools[1].Name); } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceAttributionTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceAttributionTests.cs index 5aee121097..70c7a2e06f 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceAttributionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceAttributionTests.cs @@ -390,6 +390,49 @@ public void EqualityOperator_WithDifferentSourceIdOnly_ReturnsFalse() #endregion + #region ToString Tests + + [Fact] + public void ToString_WithSourceId_ReturnsTypeColonId() + { + // Arrange + AgentRequestMessageSourceAttribution attribution = new(AgentRequestMessageSourceType.AIContextProvider, "MyProvider"); + + // Act + string result = attribution.ToString(); + + // Assert + Assert.Equal("AIContextProvider:MyProvider", result); + } + + [Fact] + public void ToString_WithNullSourceId_ReturnsTypeOnly() + { + // Arrange + AgentRequestMessageSourceAttribution attribution = new(AgentRequestMessageSourceType.ChatHistory, null); + + // Act + string result = attribution.ToString(); + + // Assert + Assert.Equal("ChatHistory", result); + } + + [Fact] + public void ToString_Default_ReturnsExternalOnly() + { + // Arrange + AgentRequestMessageSourceAttribution attribution = default; + + // Act + string result = attribution.ToString(); + + // Assert + Assert.Equal("External", result); + } + + #endregion + #region Inequality Operator Tests [Fact] diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceTypeTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceTypeTests.cs index 000505fe32..973228828b 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceTypeTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentRequestMessageSourceTypeTests.cs @@ -414,6 +414,46 @@ public void InequalityOperator_DifferentStaticInstances_ReturnsTrue() #endregion + #region ToString Tests + + [Fact] + public void ToString_ReturnsValue() + { + // Arrange + AgentRequestMessageSourceType source = new("CustomSource"); + + // Act + string result = source.ToString(); + + // Assert + Assert.Equal("CustomSource", result); + } + + [Fact] + public void ToString_StaticExternal_ReturnsExternal() + { + // Arrange & Act + string result = AgentRequestMessageSourceType.External.ToString(); + + // Assert + Assert.Equal("External", result); + } + + [Fact] + public void ToString_Default_ReturnsExternal() + { + // Arrange + AgentRequestMessageSourceType source = default; + + // Act + string result = source.ToString(); + + // Assert + Assert.Equal("External", result); + } + + #endregion + #region IEquatable Tests [Fact] diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs new file mode 100644 index 0000000000..a51f6dcb86 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionStateBagTests.cs @@ -0,0 +1,840 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json; +using Microsoft.Agents.AI.Abstractions.UnitTests.Models; + +namespace Microsoft.Agents.AI.Abstractions.UnitTests; + +/// +/// Contains tests for the class. +/// +public sealed class AgentSessionStateBagTests +{ + #region Constructor Tests + + [Fact] + public void Constructor_Default_CreatesEmptyStateBag() + { + // Act + var stateBag = new AgentSessionStateBag(); + + // Assert + Assert.False(stateBag.TryGetValue("nonexistent", out _)); + } + + #endregion + + #region SetValue Tests + + [Fact] + public void SetValue_WithValidKeyAndValue_StoresValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + stateBag.SetValue("key1", "value1"); + + // Assert + Assert.True(stateBag.TryGetValue("key1", out var result)); + Assert.Equal("value1", result); + } + + [Fact] + public void SetValue_WithNullKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.SetValue(null!, "value")); + } + + [Fact] + public void SetValue_WithEmptyKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.SetValue("", "value")); + } + + [Fact] + public void SetValue_WithWhitespaceKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.SetValue(" ", "value")); + } + + [Fact] + public void SetValue_OverwritesExistingValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "originalValue"); + + // Act + stateBag.SetValue("key1", "newValue"); + + // Assert + Assert.Equal("newValue", stateBag.GetValue("key1")); + } + + #endregion + + #region GetValue Tests + + [Fact] + public void GetValue_WithExistingKey_ReturnsValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "value1"); + + // Act + var result = stateBag.GetValue("key1"); + + // Assert + Assert.Equal("value1", result); + } + + [Fact] + public void GetValue_WithNonexistentKey_ReturnsNull() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + var result = stateBag.GetValue("nonexistent"); + + // Assert + Assert.Null(result); + } + + [Fact] + public void GetValue_WithNullKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.GetValue(null!)); + } + + [Fact] + public void GetValue_WithEmptyKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.GetValue("")); + } + + [Fact] + public void GetValue_CachesDeserializedValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "value1"); + + // Act + var result1 = stateBag.GetValue("key1"); + var result2 = stateBag.GetValue("key1"); + + // Assert + Assert.Same(result1, result2); + } + + #endregion + + #region TryGetValue Tests + + [Fact] + public void TryGetValue_WithExistingKey_ReturnsTrueAndValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "value1"); + + // Act + var found = stateBag.TryGetValue("key1", out var result); + + // Assert + Assert.True(found); + Assert.Equal("value1", result); + } + + [Fact] + public void TryGetValue_WithNonexistentKey_ReturnsFalseAndNull() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + var found = stateBag.TryGetValue("nonexistent", out var result); + + // Assert + Assert.False(found); + Assert.Null(result); + } + + [Fact] + public void TryGetValue_WithNullKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.TryGetValue(null!, out _)); + } + + [Fact] + public void TryGetValue_WithEmptyKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.TryGetValue("", out _)); + } + + #endregion + + #region Null Value Tests + + [Fact] + public void SetValue_WithNullValue_StoresNull() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + stateBag.SetValue("key1", null); + + // Assert + Assert.Equal(1, stateBag.Count); + } + + [Fact] + public void TryGetValue_WithNullValue_ReturnsTrueAndNull() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", null); + + // Act + var found = stateBag.TryGetValue("key1", out var result); + + // Assert + Assert.True(found); + Assert.Null(result); + } + + [Fact] + public void GetValue_WithNullValue_ReturnsNull() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", null); + + // Act + var result = stateBag.GetValue("key1"); + + // Assert + Assert.Null(result); + } + + [Fact] + public void SetValue_OverwriteWithNull_ReturnsNull() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "value1"); + + // Act + stateBag.SetValue("key1", null); + + // Assert + Assert.True(stateBag.TryGetValue("key1", out var result)); + Assert.Null(result); + } + + [Fact] + public void SetValue_OverwriteNullWithValue_ReturnsValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", null); + + // Act + stateBag.SetValue("key1", "newValue"); + + // Assert + Assert.True(stateBag.TryGetValue("key1", out var result)); + Assert.Equal("newValue", result); + } + + [Fact] + public void SerializeDeserialize_WithNullValue_SerializesAsNull() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("nullKey", null); + + // Act + var json = stateBag.Serialize(); + + // Assert - null values are serialized as JSON null + Assert.Equal(JsonValueKind.Object, json.ValueKind); + Assert.True(json.TryGetProperty("nullKey", out var nullElement)); + Assert.Equal(JsonValueKind.Null, nullElement.ValueKind); + } + + #endregion + + #region TryRemoveValue Tests + + [Fact] + public void TryRemoveValue_ExistingKey_ReturnsTrueAndRemoves() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "value1"); + + // Act + var removed = stateBag.TryRemoveValue("key1"); + + // Assert + Assert.True(removed); + Assert.Equal(0, stateBag.Count); + Assert.False(stateBag.TryGetValue("key1", out _)); + } + + [Fact] + public void TryRemoveValue_NonexistentKey_ReturnsFalse() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + var removed = stateBag.TryRemoveValue("nonexistent"); + + // Assert + Assert.False(removed); + } + + [Fact] + public void TryRemoveValue_WithNullKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.TryRemoveValue(null!)); + } + + [Fact] + public void TryRemoveValue_WithEmptyKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.TryRemoveValue("")); + } + + [Fact] + public void TryRemoveValue_WithWhitespaceKey_ThrowsArgumentException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act & Assert + Assert.Throws(() => stateBag.TryRemoveValue(" ")); + } + + [Fact] + public void TryRemoveValue_DoesNotAffectOtherKeys() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "value1"); + stateBag.SetValue("key2", "value2"); + + // Act + stateBag.TryRemoveValue("key1"); + + // Assert + Assert.Equal(1, stateBag.Count); + Assert.False(stateBag.TryGetValue("key1", out _)); + Assert.True(stateBag.TryGetValue("key2", out var value)); + Assert.Equal("value2", value); + } + + [Fact] + public void TryRemoveValue_ThenSetValue_Works() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "original"); + + // Act + stateBag.TryRemoveValue("key1"); + stateBag.SetValue("key1", "replacement"); + + // Assert + Assert.True(stateBag.TryGetValue("key1", out var result)); + Assert.Equal("replacement", result); + } + + #endregion + + #region Serialize/Deserialize Tests + + [Fact] + public void Serialize_EmptyStateBag_ReturnsEmptyObject() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + var json = stateBag.Serialize(); + + // Assert + Assert.Equal(JsonValueKind.Object, json.ValueKind); + } + + [Fact] + public void Serialize_WithStringValue_ReturnsJsonWithValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("stringKey", "stringValue"); + + // Act + var json = stateBag.Serialize(); + + // Assert + Assert.Equal(JsonValueKind.Object, json.ValueKind); + Assert.True(json.TryGetProperty("stringKey", out _)); + } + + [Fact] + public void Deserialize_FromJsonDocument_ReturnsEmptyStateBag() + { + // Arrange + var emptyJson = JsonDocument.Parse("{}").RootElement; + + // Act + var stateBag = AgentSessionStateBag.Deserialize(emptyJson); + + // Assert + Assert.False(stateBag.TryGetValue("nonexistent", out _)); + } + + [Fact] + public void Deserialize_NullElement_ReturnsEmptyStateBag() + { + // Arrange + var nullJson = default(JsonElement); + + // Act + var stateBag = AgentSessionStateBag.Deserialize(nullJson); + + // Assert + Assert.False(stateBag.TryGetValue("nonexistent", out _)); + } + + [Fact] + public void SerializeDeserialize_WithStringValue_Roundtrips() + { + // Arrange + var originalStateBag = new AgentSessionStateBag(); + originalStateBag.SetValue("stringKey", "stringValue"); + + // Act + var json = originalStateBag.Serialize(); + var restoredStateBag = AgentSessionStateBag.Deserialize(json); + + // Assert + Assert.Equal("stringValue", restoredStateBag.GetValue("stringKey")); + } + + #endregion + + #region Thread Safety Tests + + [Fact] + public async System.Threading.Tasks.Task SetValue_MultipleConcurrentWrites_DoesNotThrowAsync() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + var tasks = new System.Threading.Tasks.Task[100]; + + // Act + for (int i = 0; i < 100; i++) + { + int index = i; + tasks[i] = System.Threading.Tasks.Task.Run(() => stateBag.SetValue($"key{index}", $"value{index}")); + } + + await System.Threading.Tasks.Task.WhenAll(tasks); + + // Assert + for (int i = 0; i < 100; i++) + { + Assert.True(stateBag.TryGetValue($"key{i}", out var value)); + Assert.Equal($"value{i}", value); + } + } + + [Fact] + public async System.Threading.Tasks.Task ConcurrentWritesAndSerialize_DoesNotThrowAsync() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("shared", "initial"); + var tasks = new System.Threading.Tasks.Task[100]; + + // Act - concurrently write and serialize the same key + for (int i = 0; i < 100; i++) + { + int index = i; + tasks[i] = System.Threading.Tasks.Task.Run(() => + { + stateBag.SetValue("shared", $"value{index}"); + _ = stateBag.Serialize(); + }); + } + + await System.Threading.Tasks.Task.WhenAll(tasks); + + // Assert - should have some value and serialize without error + Assert.True(stateBag.TryGetValue("shared", out var result)); + Assert.NotNull(result); + var json = stateBag.Serialize(); + Assert.Equal(JsonValueKind.Object, json.ValueKind); + } + + [Fact] + public async System.Threading.Tasks.Task ConcurrentReadsAndWrites_DoesNotThrowAsync() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key", "initial"); + var tasks = new System.Threading.Tasks.Task[200]; + + // Act - half readers, half writers on the same key + for (int i = 0; i < 200; i++) + { + int index = i; + tasks[i] = (index % 2 == 0) + ? System.Threading.Tasks.Task.Run(() => stateBag.GetValue("key")) + : System.Threading.Tasks.Task.Run(() => stateBag.SetValue("key", $"value{index}")); + } + + await System.Threading.Tasks.Task.WhenAll(tasks); + + // Assert - should have a consistent value + Assert.True(stateBag.TryGetValue("key", out var result)); + Assert.NotNull(result); + } + + #endregion + + #region Complex Object Tests + + [Fact] + public void SetValue_WithComplexObject_StoresValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + var animal = new Animal { Id = 1, FullName = "Buddy", Species = Species.Bear }; + + // Act + stateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); + + // Assert + Animal? result = stateBag.GetValue("animal", TestJsonSerializerContext.Default.Options); + Assert.NotNull(result); + Assert.Equal(1, result.Id); + Assert.Equal("Buddy", result.FullName); + Assert.Equal(Species.Bear, result.Species); + } + + [Fact] + public void GetValue_WithComplexObject_CachesDeserializedValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + var animal = new Animal { Id = 2, FullName = "Whiskers", Species = Species.Tiger }; + stateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); + + // Act + Animal? result1 = stateBag.GetValue("animal", TestJsonSerializerContext.Default.Options); + Animal? result2 = stateBag.GetValue("animal", TestJsonSerializerContext.Default.Options); + + // Assert + Assert.Same(result1, result2); + } + + [Fact] + public void TryGetValue_WithComplexObject_ReturnsTrueAndValue() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + var animal = new Animal { Id = 3, FullName = "Goldie", Species = Species.Walrus }; + stateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); + + // Act + bool found = stateBag.TryGetValue("animal", out Animal? result, TestJsonSerializerContext.Default.Options); + + // Assert + Assert.True(found); + Assert.NotNull(result); + Assert.Equal(3, result.Id); + Assert.Equal("Goldie", result.FullName); + Assert.Equal(Species.Walrus, result.Species); + } + + [Fact] + public void SerializeDeserialize_WithComplexObject_Roundtrips() + { + // Arrange + var originalStateBag = new AgentSessionStateBag(); + var animal = new Animal { Id = 4, FullName = "Polly", Species = Species.Bear }; + originalStateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); + + // Act + JsonElement json = originalStateBag.Serialize(); + AgentSessionStateBag restoredStateBag = AgentSessionStateBag.Deserialize(json); + + // Assert + Animal? restoredAnimal = restoredStateBag.GetValue("animal", TestJsonSerializerContext.Default.Options); + Assert.NotNull(restoredAnimal); + Assert.Equal(4, restoredAnimal.Id); + Assert.Equal("Polly", restoredAnimal.FullName); + Assert.Equal(Species.Bear, restoredAnimal.Species); + } + + [Fact] + public void Serialize_WithComplexObject_ReturnsJsonWithProperties() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + var animal = new Animal { Id = 7, FullName = "Spot", Species = Species.Walrus }; + stateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); + + // Act + JsonElement json = stateBag.Serialize(); + + // Assert + Assert.Equal(JsonValueKind.Object, json.ValueKind); + Assert.True(json.TryGetProperty("animal", out JsonElement animalElement)); + Assert.Equal(JsonValueKind.Object, animalElement.ValueKind); + Assert.Equal(7, animalElement.GetProperty("id").GetInt32()); + Assert.Equal("Spot", animalElement.GetProperty("fullName").GetString()); + Assert.Equal("Walrus", animalElement.GetProperty("species").GetString()); + } + + #endregion + + #region Type Mismatch Tests + + [Fact] + public void TryGetValue_WithDifferentTypeAfterSet_ReturnsFalse() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "hello"); + + // Act + var found = stateBag.TryGetValue("key1", out var result, TestJsonSerializerContext.Default.Options); + + // Assert + Assert.False(found); + Assert.Null(result); + } + + [Fact] + public void GetValue_WithDifferentTypeAfterSet_ThrowsInvalidOperationException() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "hello"); + + // Act & Assert + Assert.Throws(() => stateBag.GetValue("key1", TestJsonSerializerContext.Default.Options)); + } + + [Fact] + public void TryGetValue_WithDifferentTypeAfterDeserializedRead_ReturnsFalse() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "hello"); + + // First read caches the value as string + var cachedValue = stateBag.GetValue("key1"); + Assert.Equal("hello", cachedValue); + + // Act - request as a different type + var found = stateBag.TryGetValue("key1", out var result, TestJsonSerializerContext.Default.Options); + + // Assert + Assert.False(found); + Assert.Null(result); + } + + [Fact] + public void GetValue_WithDifferentTypeAfterDeserializedRoundtrip_ThrowsInvalidOperationException() + { + // Arrange + var originalStateBag = new AgentSessionStateBag(); + originalStateBag.SetValue("key1", "hello"); + + // Round-trip through serialization + var json = originalStateBag.Serialize(); + var restoredStateBag = AgentSessionStateBag.Deserialize(json); + + // First read caches the value as string + var cachedValue = restoredStateBag.GetValue("key1"); + Assert.Equal("hello", cachedValue); + + // Act & Assert - request as a different type + Assert.Throws(() => restoredStateBag.GetValue("key1", TestJsonSerializerContext.Default.Options)); + } + + [Fact] + public void TryGetValue_ComplexTypeAfterSetString_ReturnsFalse() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("animal", "not an animal"); + + // Act + var found = stateBag.TryGetValue("animal", out var result, TestJsonSerializerContext.Default.Options); + + // Assert + Assert.False(found); + Assert.Null(result); + } + + [Fact] + public void GetValue_TypeMismatch_ExceptionMessageContainsBothTypeNames() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key1", "hello"); + + // Act + var exception = Assert.Throws(() => stateBag.GetValue("key1", TestJsonSerializerContext.Default.Options)); + + // Assert + Assert.Contains(typeof(string).FullName!, exception.Message); + Assert.Contains(typeof(Animal).FullName!, exception.Message); + } + + #endregion + + #region JsonSerializer Integration Tests + + [Fact] + public void JsonSerializerSerialize_EmptyStateBag_ReturnsEmptyObject() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + + // Act + var json = JsonSerializer.Serialize(stateBag, AgentAbstractionsJsonUtilities.DefaultOptions); + + // Assert + Assert.Equal("{}", json); + } + + [Fact] + public void JsonSerializerSerialize_WithStringValue_ProducesSameOutputAsSerializeMethod() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("stringKey", "stringValue"); + + // Act + var jsonFromSerializer = JsonSerializer.Serialize(stateBag, AgentAbstractionsJsonUtilities.DefaultOptions); + var jsonFromMethod = stateBag.Serialize().GetRawText(); + + // Assert + Assert.Equal(jsonFromMethod, jsonFromSerializer); + } + + [Fact] + public void JsonSerializerRoundtrip_WithStringValue_PreservesData() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("greeting", "hello world"); + + // Act + var json = JsonSerializer.Serialize(stateBag, AgentAbstractionsJsonUtilities.DefaultOptions); + var restored = JsonSerializer.Deserialize(json, AgentAbstractionsJsonUtilities.DefaultOptions); + + // Assert + Assert.NotNull(restored); + Assert.Equal("hello world", restored!.GetValue("greeting")); + } + + [Fact] + public void JsonSerializerRoundtrip_WithComplexObject_PreservesData() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + var animal = new Animal { Id = 10, FullName = "Rex", Species = Species.Tiger }; + stateBag.SetValue("animal", animal, TestJsonSerializerContext.Default.Options); + + // Act + var json = JsonSerializer.Serialize(stateBag, AgentAbstractionsJsonUtilities.DefaultOptions); + var restored = JsonSerializer.Deserialize(json, AgentAbstractionsJsonUtilities.DefaultOptions); + + // Assert + Assert.NotNull(restored); + var restoredAnimal = restored!.GetValue("animal", TestJsonSerializerContext.Default.Options); + Assert.NotNull(restoredAnimal); + Assert.Equal(10, restoredAnimal!.Id); + Assert.Equal("Rex", restoredAnimal.FullName); + Assert.Equal(Species.Tiger, restoredAnimal.Species); + } + + [Fact] + public void JsonSerializerDeserialize_NullJson_ReturnsNull() + { + // Arrange + const string Json = "null"; + + // Act + var stateBag = JsonSerializer.Deserialize(Json, AgentAbstractionsJsonUtilities.DefaultOptions); + + // Assert + Assert.Null(stateBag); + } + +#if NET10_0_OR_GREATER + [Fact] + public void JsonSerializerSerialize_WithUnknownType_Throws() + { + // Arrange + var stateBag = new AgentSessionStateBag(); + stateBag.SetValue("key", new { Name = "Test" }); // Anonymous type which cannot be deserialized + + // Act & Assert + Assert.Throws(() => JsonSerializer.Serialize(stateBag, AgentAbstractionsJsonUtilities.DefaultOptions)); + } +#endif + + #endregion +} diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionTests.cs index 5a776c9fb0..b80f0a4fd2 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AgentSessionTests.cs @@ -11,6 +11,21 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; /// public class AgentSessionTests { + #region StateBag Tests + + [Fact] + public void StateBag_Values_Roundtrips() + { + // Arrange + var session = new TestAgentSession(); + + // Act & Assert + session.StateBag.SetValue("key1", "value1"); + Assert.Equal("value1", session.StateBag.GetValue("key1")); + } + + #endregion + #region GetService Method Tests /// diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs deleted file mode 100644 index 1fcbe37e25..0000000000 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Collections.Generic; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.Extensions.AI; -using Moq; -using Moq.Protected; - -namespace Microsoft.Agents.AI.Abstractions.UnitTests; - -/// -/// Contains tests for the class. -/// -public sealed class ChatHistoryProviderExtensionsTests -{ - private static readonly AIAgent s_mockAgent = new Mock().Object; - private static readonly AgentSession s_mockSession = new Mock().Object; - - [Fact] - public void WithMessageFilters_ReturnsChatHistoryProviderMessageFilter() - { - // Arrange - Mock providerMock = new(); - - // Act - ChatHistoryProvider result = providerMock.Object.WithMessageFilters( - invokingMessagesFilter: msgs => msgs, - invokedMessagesFilter: ctx => ctx); - - // Assert - Assert.IsType(result); - } - - [Fact] - public async Task WithMessageFilters_InvokingFilter_IsAppliedAsync() - { - // Arrange - Mock providerMock = new(); - List innerMessages = [new(ChatRole.User, "Hello"), new(ChatRole.Assistant, "Hi")]; - ChatHistoryProvider.InvokingContext context = new(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]); - - providerMock - .Protected() - .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) - .ReturnsAsync(innerMessages); - - ChatHistoryProvider filtered = providerMock.Object.WithMessageFilters( - invokingMessagesFilter: msgs => msgs.Where(m => m.Role == ChatRole.User)); - - // Act - List result = (await filtered.InvokingAsync(context, CancellationToken.None)).ToList(); - - // Assert - Assert.Single(result); - Assert.Equal(ChatRole.User, result[0].Role); - } - - [Fact] - public async Task WithMessageFilters_InvokedFilter_IsAppliedAsync() - { - // Arrange - Mock providerMock = new(); - List requestMessages = - [ - new(ChatRole.System, "System") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "TestSource") } } }, - new(ChatRole.User, "Hello") - ]; - ChatHistoryProvider.InvokedContext context = new(s_mockAgent, s_mockSession, requestMessages) - { - ResponseMessages = [new ChatMessage(ChatRole.Assistant, "Response")] - }; - - ChatHistoryProvider.InvokedContext? capturedContext = null; - providerMock - .Protected() - .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) - .Callback((ctx, _) => capturedContext = ctx) - .Returns(default(ValueTask)); - - ChatHistoryProvider filtered = providerMock.Object.WithMessageFilters( - invokedMessagesFilter: ctx => - { - ctx.ResponseMessages = null; - return ctx; - }); - - // Act - await filtered.InvokedAsync(context, CancellationToken.None); - - // Assert - Assert.NotNull(capturedContext); - Assert.Null(capturedContext.ResponseMessages); - } - - [Fact] - public void WithAIContextProviderMessageRemoval_ReturnsChatHistoryProviderMessageFilter() - { - // Arrange - Mock providerMock = new(); - - // Act - ChatHistoryProvider result = providerMock.Object.WithAIContextProviderMessageRemoval(); - - // Assert - Assert.IsType(result); - } - - [Fact] - public async Task WithAIContextProviderMessageRemoval_RemovesAIContextProviderMessagesAsync() - { - // Arrange - Mock providerMock = new(); - List requestMessages = - [ - new(ChatRole.System, "System") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "TestSource") } } }, - new(ChatRole.User, "Hello"), - new(ChatRole.System, "Context") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "TestContextSource") } } } - ]; - ChatHistoryProvider.InvokedContext context = new(s_mockAgent, s_mockSession, requestMessages); - - ChatHistoryProvider.InvokedContext? capturedContext = null; - providerMock - .Protected() - .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) - .Callback((ctx, _) => capturedContext = ctx) - .Returns(default(ValueTask)); - - ChatHistoryProvider filtered = providerMock.Object.WithAIContextProviderMessageRemoval(); - - // Act - await filtered.InvokedAsync(context, CancellationToken.None); - - // Assert - Assert.NotNull(capturedContext); - Assert.Equal(2, capturedContext.RequestMessages.Count()); - Assert.Contains("System", capturedContext.RequestMessages.Select(x => x.Text)); - Assert.Contains("Hello", capturedContext.RequestMessages.Select(x => x.Text)); - } -} diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs deleted file mode 100644 index 75d8f554c5..0000000000 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text.Json; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.Extensions.AI; -using Moq; -using Moq.Protected; - -namespace Microsoft.Agents.AI.Abstractions.UnitTests; - -/// -/// Contains tests for the class. -/// -public sealed class ChatHistoryProviderMessageFilterTests -{ - private static readonly AIAgent s_mockAgent = new Mock().Object; - private static readonly AgentSession s_mockSession = new Mock().Object; - - [Fact] - public void Constructor_WithNullInnerProvider_ThrowsArgumentNullException() - { - // Arrange, Act & Assert - Assert.Throws(() => new ChatHistoryProviderMessageFilter(null!)); - } - - [Fact] - public void Constructor_WithOnlyInnerProvider_Throws() - { - // Arrange - var innerProviderMock = new Mock(); - - // Act & Assert - Assert.Throws(() => new ChatHistoryProviderMessageFilter(innerProviderMock.Object)); - } - - [Fact] - public void Constructor_WithAllParameters_CreatesInstance() - { - // Arrange - var innerProviderMock = new Mock(); - - IEnumerable InvokingFilter(IEnumerable msgs) => msgs; - ChatHistoryProvider.InvokedContext InvokedFilter(ChatHistoryProvider.InvokedContext ctx) => ctx; - - // Act - var filter = new ChatHistoryProviderMessageFilter(innerProviderMock.Object, InvokingFilter, InvokedFilter); - - // Assert - Assert.NotNull(filter); - } - - [Fact] - public async Task InvokingAsync_WithNoOpFilters_ReturnsInnerProviderMessagesAsync() - { - // Arrange - var innerProviderMock = new Mock(); - var expectedMessages = new List - { - new(ChatRole.User, "Hello"), - new(ChatRole.Assistant, "Hi there!") - }; - var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]); - - innerProviderMock - .Protected() - .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) - .ReturnsAsync(expectedMessages); - - var filter = new ChatHistoryProviderMessageFilter(innerProviderMock.Object, x => x, x => x); - - // Act - var result = (await filter.InvokingAsync(context, CancellationToken.None)).ToList(); - - // Assert - Assert.Equal(2, result.Count); - Assert.Equal("Hello", result[0].Text); - Assert.Equal("Hi there!", result[1].Text); - innerProviderMock - .Protected() - .Verify>>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); - } - - [Fact] - public async Task InvokingAsync_WithInvokingFilter_AppliesFilterAsync() - { - // Arrange - var innerProviderMock = new Mock(); - var innerMessages = new List - { - new(ChatRole.User, "Hello"), - new(ChatRole.Assistant, "Hi there!"), - new(ChatRole.User, "How are you?") - }; - var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]); - - innerProviderMock - .Protected() - .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) - .ReturnsAsync(innerMessages); - - // Filter to only user messages - IEnumerable InvokingFilter(IEnumerable msgs) => msgs.Where(m => m.Role == ChatRole.User); - - var filter = new ChatHistoryProviderMessageFilter(innerProviderMock.Object, InvokingFilter); - - // Act - var result = (await filter.InvokingAsync(context, CancellationToken.None)).ToList(); - - // Assert - Assert.Equal(2, result.Count); - Assert.All(result, msg => Assert.Equal(ChatRole.User, msg.Role)); - innerProviderMock - .Protected() - .Verify>>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); - } - - [Fact] - public async Task InvokingAsync_WithInvokingFilter_CanModifyMessagesAsync() - { - // Arrange - var innerProviderMock = new Mock(); - var innerMessages = new List - { - new(ChatRole.User, "Hello"), - new(ChatRole.Assistant, "Hi there!") - }; - var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]); - - innerProviderMock - .Protected() - .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) - .ReturnsAsync(innerMessages); - - // Filter that transforms messages - IEnumerable InvokingFilter(IEnumerable msgs) => - msgs.Select(m => new ChatMessage(m.Role, $"[FILTERED] {m.Text}")); - - var filter = new ChatHistoryProviderMessageFilter(innerProviderMock.Object, InvokingFilter); - - // Act - var result = (await filter.InvokingAsync(context, CancellationToken.None)).ToList(); - - // Assert - Assert.Equal(2, result.Count); - Assert.Equal("[FILTERED] Hello", result[0].Text); - Assert.Equal("[FILTERED] Hi there!", result[1].Text); - } - - [Fact] - public async Task InvokedAsync_WithInvokedFilter_AppliesFilterAsync() - { - // Arrange - var innerProviderMock = new Mock(); - List requestMessages = - [ - new(ChatRole.System, "System") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "TestSource") } } }, - new(ChatRole.User, "Hello"), - ]; - var responseMessages = new List { new(ChatRole.Assistant, "Response") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) - { - ResponseMessages = responseMessages - }; - - ChatHistoryProvider.InvokedContext? capturedContext = null; - innerProviderMock - .Protected() - .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) - .Callback((ctx, ct) => capturedContext = ctx) - .Returns(default(ValueTask)); - - // Filter that modifies the context - ChatHistoryProvider.InvokedContext InvokedFilter(ChatHistoryProvider.InvokedContext ctx) - { - var modifiedRequestMessages = ctx.RequestMessages.Where(x => x.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External).Select(m => new ChatMessage(m.Role, $"[FILTERED] {m.Text}")).ToList(); - return new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, modifiedRequestMessages) - { - ResponseMessages = ctx.ResponseMessages, - InvokeException = ctx.InvokeException - }; - } - - var filter = new ChatHistoryProviderMessageFilter(innerProviderMock.Object, invokedMessagesFilter: InvokedFilter); - - // Act - await filter.InvokedAsync(context, CancellationToken.None); - - // Assert - Assert.NotNull(capturedContext); - Assert.Single(capturedContext.RequestMessages); - Assert.Equal("[FILTERED] Hello", capturedContext.RequestMessages.First().Text); - innerProviderMock - .Protected() - .Verify("InvokedCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); - } - - [Fact] - public void Serialize_DelegatesToInnerProvider() - { - // Arrange - var innerProviderMock = new Mock(); - var expectedJson = JsonSerializer.SerializeToElement("data", TestJsonSerializerContext.Default.String); - - innerProviderMock - .Setup(s => s.Serialize(It.IsAny())) - .Returns(expectedJson); - - var filter = new ChatHistoryProviderMessageFilter(innerProviderMock.Object, x => x, x => x); - - // Act - var result = filter.Serialize(); - - // Assert - Assert.Equal(expectedJson.GetRawText(), result.GetRawText()); - innerProviderMock.Verify(s => s.Serialize(null), Times.Once); - } -} diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs index 8b07366b03..7fccca8d1b 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -19,92 +18,6 @@ public class ChatHistoryProviderTests private static readonly AIAgent s_mockAgent = new Mock().Object; private static readonly AgentSession s_mockSession = new Mock().Object; - #region InvokingAsync Message Stamping Tests - - [Fact] - public async Task InvokingAsync_StampsMessagesWithSourceTypeAndSourceIdAsync() - { - // Arrange - var provider = new TestChatHistoryProvider(); - var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); - - // Act - IEnumerable messages = await provider.InvokingAsync(context); - - // Assert - ChatMessage message = messages.Single(); - Assert.NotNull(message.AdditionalProperties); - Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, out object? attribution)); - var typedAttribution = Assert.IsType(attribution); - Assert.Equal(AgentRequestMessageSourceType.ChatHistory, typedAttribution.SourceType); - Assert.Equal(typeof(TestChatHistoryProvider).FullName, typedAttribution.SourceId); - } - - [Fact] - public async Task InvokingAsync_WithCustomSourceId_StampsMessagesWithCustomSourceIdAsync() - { - // Arrange - const string CustomSourceId = "CustomHistorySource"; - var provider = new TestChatHistoryProviderWithCustomSource(CustomSourceId); - var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); - - // Act - IEnumerable messages = await provider.InvokingAsync(context); - - // Assert - ChatMessage message = messages.Single(); - Assert.NotNull(message.AdditionalProperties); - Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, out object? attribution)); - var typedAttribution = Assert.IsType(attribution); - Assert.Equal(AgentRequestMessageSourceType.ChatHistory, typedAttribution.SourceType); - Assert.Equal(CustomSourceId, typedAttribution.SourceId); - } - - [Fact] - public async Task InvokingAsync_DoesNotReStampAlreadyStampedMessagesAsync() - { - // Arrange - var provider = new TestChatHistoryProviderWithPreStampedMessages(); - var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); - - // Act - IEnumerable messages = await provider.InvokingAsync(context); - - // Assert - ChatMessage message = messages.Single(); - Assert.NotNull(message.AdditionalProperties); - Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, out object? attribution)); - var typedAttribution = Assert.IsType(attribution); - Assert.Equal(AgentRequestMessageSourceType.ChatHistory, typedAttribution.SourceType); - Assert.Equal(typeof(TestChatHistoryProviderWithPreStampedMessages).FullName, typedAttribution.SourceId); - } - - [Fact] - public async Task InvokingAsync_StampsMultipleMessagesAsync() - { - // Arrange - var provider = new TestChatHistoryProviderWithMultipleMessages(); - var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Request")]); - - // Act - IEnumerable messages = await provider.InvokingAsync(context); - - // Assert - List messageList = messages.ToList(); - Assert.Equal(3, messageList.Count); - - foreach (ChatMessage message in messageList) - { - Assert.NotNull(message.AdditionalProperties); - Assert.True(message.AdditionalProperties.TryGetValue(AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, out object? attribution)); - var typedAttribution = Assert.IsType(attribution); - Assert.Equal(AgentRequestMessageSourceType.ChatHistory, typedAttribution.SourceType); - Assert.Equal(typeof(TestChatHistoryProviderWithMultipleMessages).FullName, typedAttribution.SourceId); - } - } - - #endregion - #region GetService Method Tests [Fact] @@ -259,33 +172,7 @@ public void InvokingContext_Constructor_ThrowsForNullAgent() public void InvokedContext_Constructor_ThrowsForNullRequestMessages() { // Arrange & Act & Assert - Assert.Throws(() => new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, null!)); - } - - [Fact] - public void InvokedContext_RequestMessages_SetterThrowsForNull() - { - // Arrange - var requestMessages = new List { new(ChatRole.User, "Hello") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); - - // Act & Assert - Assert.Throws(() => context.RequestMessages = null!); - } - - [Fact] - public void InvokedContext_RequestMessages_SetterRoundtrips() - { - // Arrange - var initialMessages = new List { new(ChatRole.User, "Hello") }; - var newMessages = new List { new(ChatRole.User, "New message") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, initialMessages); - - // Act - context.RequestMessages = newMessages; - - // Assert - Assert.Same(newMessages, context.RequestMessages); + Assert.Throws(() => new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, null!, [])); } [Fact] @@ -294,10 +181,9 @@ public void InvokedContext_ResponseMessages_Roundtrips() // Arrange var requestMessages = new List { new(ChatRole.User, "Hello") }; var responseMessages = new List { new(ChatRole.Assistant, "Response message") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); // Act - context.ResponseMessages = responseMessages; + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, responseMessages); // Assert Assert.Same(responseMessages, context.ResponseMessages); @@ -309,10 +195,9 @@ public void InvokedContext_InvokeException_Roundtrips() // Arrange var requestMessages = new List { new(ChatRole.User, "Hello") }; var exception = new InvalidOperationException("Test exception"); - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); // Act - context.InvokeException = exception; + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, exception); // Assert Assert.Same(exception, context.InvokeException); @@ -325,7 +210,7 @@ public void InvokedContext_Agent_ReturnsConstructorValue() var requestMessages = new List { new(ChatRole.User, "Hello") }; // Act - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); // Assert Assert.Same(s_mockAgent, context.Agent); @@ -338,7 +223,7 @@ public void InvokedContext_Session_ReturnsConstructorValue() var requestMessages = new List { new(ChatRole.User, "Hello") }; // Act - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); // Assert Assert.Same(s_mockSession, context.Session); @@ -351,7 +236,7 @@ public void InvokedContext_Session_CanBeNull() var requestMessages = new List { new(ChatRole.User, "Hello") }; // Act - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, null, requestMessages); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, null, requestMessages, []); // Assert Assert.Null(context.Session); @@ -364,71 +249,37 @@ public void InvokedContext_Constructor_ThrowsForNullAgent() var requestMessages = new List { new(ChatRole.User, "Hello") }; // Act & Assert - Assert.Throws(() => new ChatHistoryProvider.InvokedContext(null!, s_mockSession, requestMessages)); + Assert.Throws(() => new ChatHistoryProvider.InvokedContext(null!, s_mockSession, requestMessages, [])); } - #endregion - - private sealed class TestChatHistoryProvider : ChatHistoryProvider + [Fact] + public void InvokedContext_SuccessConstructor_ThrowsForNullResponseMessages() { - protected override ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) - => new([new ChatMessage(ChatRole.User, "Test Message")]); - - protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) - => default; + // Arrange + var requestMessages = new List { new(ChatRole.User, "Hello") }; - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => default; + // Act & Assert + Assert.Throws(() => new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, (IEnumerable)null!)); } - private sealed class TestChatHistoryProviderWithCustomSource : ChatHistoryProvider + [Fact] + public void InvokedContext_FailureConstructor_ThrowsForNullException() { - public TestChatHistoryProviderWithCustomSource(string sourceId) : base(sourceId) - { - } - - protected override ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) - => new([new ChatMessage(ChatRole.User, "Test Message")]); - - protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) - => default; + // Arrange + var requestMessages = new List { new(ChatRole.User, "Hello") }; - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => default; + // Act & Assert + Assert.Throws(() => new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, (Exception)null!)); } - private sealed class TestChatHistoryProviderWithPreStampedMessages : ChatHistoryProvider - { - protected override ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) - { - var message = new ChatMessage(ChatRole.User, "Pre-stamped Message"); - message.AdditionalProperties = new AdditionalPropertiesDictionary - { - [AgentRequestMessageSourceAttribution.AdditionalPropertiesKey] = new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, this.GetType().FullName!) - }; - return new([message]); - } - - protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) - => default; - - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => default; - } + #endregion - private sealed class TestChatHistoryProviderWithMultipleMessages : ChatHistoryProvider + private sealed class TestChatHistoryProvider : ChatHistoryProvider { protected override ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) - => new([ - new ChatMessage(ChatRole.User, "Message 1"), - new ChatMessage(ChatRole.Assistant, "Message 2"), - new ChatMessage(ChatRole.User, "Message 3") - ]); + => new(new ChatMessage[] { new(ChatRole.User, "Test Message") }.Concat(context.RequestMessages)); protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) => default; - - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => default; } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatMessageExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatMessageExtensionsTests.cs index 97050c3071..05fd576798 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatMessageExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatMessageExtensionsTests.cs @@ -350,7 +350,7 @@ public void AsAgentRequestMessageSourcedMessage_WithNoAdditionalProperties_Retur ChatMessage message = new(ChatRole.User, "Hello"); // Act - ChatMessage result = message.AsAgentRequestMessageSourcedMessage(AgentRequestMessageSourceType.External, "TestSourceId"); + ChatMessage result = message.WithAgentRequestMessageSource(AgentRequestMessageSourceType.External, "TestSourceId"); // Assert Assert.NotSame(message, result); @@ -368,7 +368,7 @@ public void AsAgentRequestMessageSourcedMessage_WithNullAdditionalProperties_Ret }; // Act - ChatMessage result = message.AsAgentRequestMessageSourcedMessage(AgentRequestMessageSourceType.AIContextProvider, "ProviderSourceId"); + ChatMessage result = message.WithAgentRequestMessageSource(AgentRequestMessageSourceType.AIContextProvider, "ProviderSourceId"); // Assert Assert.NotSame(message, result); @@ -389,7 +389,7 @@ public void AsAgentRequestMessageSourcedMessage_WithMatchingSourceTypeAndSourceI }; // Act - ChatMessage result = message.AsAgentRequestMessageSourcedMessage(AgentRequestMessageSourceType.ChatHistory, "HistoryId"); + ChatMessage result = message.WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, "HistoryId"); // Assert Assert.Same(message, result); @@ -408,7 +408,7 @@ public void AsAgentRequestMessageSourcedMessage_WithDifferentSourceType_ReturnsC }; // Act - ChatMessage result = message.AsAgentRequestMessageSourcedMessage(AgentRequestMessageSourceType.AIContextProvider, "SourceId"); + ChatMessage result = message.WithAgentRequestMessageSource(AgentRequestMessageSourceType.AIContextProvider, "SourceId"); // Assert Assert.NotSame(message, result); @@ -429,7 +429,7 @@ public void AsAgentRequestMessageSourcedMessage_WithDifferentSourceId_ReturnsClo }; // Act - ChatMessage result = message.AsAgentRequestMessageSourcedMessage(AgentRequestMessageSourceType.External, "NewId"); + ChatMessage result = message.WithAgentRequestMessageSource(AgentRequestMessageSourceType.External, "NewId"); // Assert Assert.NotSame(message, result); @@ -444,7 +444,7 @@ public void AsAgentRequestMessageSourcedMessage_WithDefaultNullSourceId_ReturnsC ChatMessage message = new(ChatRole.User, "Hello"); // Act - ChatMessage result = message.AsAgentRequestMessageSourcedMessage(AgentRequestMessageSourceType.ChatHistory); + ChatMessage result = message.WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory); // Assert Assert.NotSame(message, result); @@ -465,7 +465,7 @@ public void AsAgentRequestMessageSourcedMessage_WithMatchingSourceTypeAndNullSou }; // Act - ChatMessage result = message.AsAgentRequestMessageSourcedMessage(AgentRequestMessageSourceType.External); + ChatMessage result = message.WithAgentRequestMessageSource(AgentRequestMessageSourceType.External); // Assert Assert.Same(message, result); @@ -478,7 +478,7 @@ public void AsAgentRequestMessageSourcedMessage_DoesNotModifyOriginalMessage() ChatMessage message = new(ChatRole.User, "Hello"); // Act - ChatMessage result = message.AsAgentRequestMessageSourcedMessage(AgentRequestMessageSourceType.AIContextProvider, "ProviderId"); + ChatMessage result = message.WithAgentRequestMessageSource(AgentRequestMessageSourceType.AIContextProvider, "ProviderId"); // Assert Assert.Null(message.AdditionalProperties); @@ -499,7 +499,7 @@ public void AsAgentRequestMessageSourcedMessage_WithWrongAttributionType_Returns }; // Act - ChatMessage result = message.AsAgentRequestMessageSourcedMessage(AgentRequestMessageSourceType.External, "SourceId"); + ChatMessage result = message.WithAgentRequestMessageSource(AgentRequestMessageSourceType.External, "SourceId"); // Assert Assert.NotSame(message, result); @@ -514,7 +514,7 @@ public void AsAgentRequestMessageSourcedMessage_PreservesMessageContent() ChatMessage message = new(ChatRole.Assistant, "Test content"); // Act - ChatMessage result = message.AsAgentRequestMessageSourcedMessage(AgentRequestMessageSourceType.ChatHistory, "HistoryId"); + ChatMessage result = message.WithAgentRequestMessageSource(AgentRequestMessageSourceType.ChatHistory, "HistoryId"); // Assert Assert.Equal(ChatRole.Assistant, result.Role); diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryAgentSessionTests.cs deleted file mode 100644 index a3a4bf7e0e..0000000000 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryAgentSessionTests.cs +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text.Json; -using Microsoft.Extensions.AI; - -namespace Microsoft.Agents.AI.Abstractions.UnitTests; - -/// -/// Contains tests for . -/// -public class InMemoryAgentSessionTests -{ - #region Constructor and Property Tests - - [Fact] - public void Constructor_SetsDefaultChatHistoryProvider() - { - // Arrange & Act - var session = new TestInMemoryAgentSession(); - - // Assert - Assert.NotNull(session.GetChatHistoryProvider()); - Assert.Empty(session.GetChatHistoryProvider()); - } - - [Fact] - public void Constructor_WithChatHistoryProvider_SetsProperty() - { - // Arrange - InMemoryChatHistoryProvider provider = [new(ChatRole.User, "Hello")]; - - // Act - var session = new TestInMemoryAgentSession(provider); - - // Assert - Assert.Same(provider, session.GetChatHistoryProvider()); - Assert.Single(session.GetChatHistoryProvider()); - Assert.Equal("Hello", session.GetChatHistoryProvider()[0].Text); - } - - [Fact] - public void Constructor_WithMessages_SetsProperty() - { - // Arrange - var messages = new List { new(ChatRole.User, "Hi") }; - - // Act - var session = new TestInMemoryAgentSession(messages); - - // Assert - Assert.NotNull(session.GetChatHistoryProvider()); - Assert.Single(session.GetChatHistoryProvider()); - Assert.Equal("Hi", session.GetChatHistoryProvider()[0].Text); - } - - [Fact] - public void Constructor_WithSerializedState_SetsProperty() - { - // Arrange - InMemoryChatHistoryProvider provider = [new(ChatRole.User, "TestMsg")]; - var providerState = provider.Serialize(); - var sessionStateWrapper = new InMemoryAgentSession.InMemoryAgentSessionState { ChatHistoryProviderState = providerState }; - var json = JsonSerializer.SerializeToElement(sessionStateWrapper, TestJsonSerializerContext.Default.InMemoryAgentSessionState); - - // Act - var session = new TestInMemoryAgentSession(json); - - // Assert - Assert.NotNull(session.GetChatHistoryProvider()); - Assert.Single(session.GetChatHistoryProvider()); - Assert.Equal("TestMsg", session.GetChatHistoryProvider()[0].Text); - } - - [Fact] - public void Constructor_WithInvalidJson_ThrowsArgumentException() - { - // Arrange - var invalidJson = JsonSerializer.SerializeToElement(42, TestJsonSerializerContext.Default.Int32); - - // Act & Assert - Assert.Throws(() => new TestInMemoryAgentSession(invalidJson)); - } - - #endregion - - #region SerializeAsync Tests - - [Fact] - public void Serialize_ReturnsCorrectJson_WhenMessagesExist() - { - // Arrange - var session = new TestInMemoryAgentSession([new(ChatRole.User, "TestContent")]); - - // Act - var json = session.Serialize(); - - // Assert - Assert.Equal(JsonValueKind.Object, json.ValueKind); - Assert.True(json.TryGetProperty("chatHistoryProviderState", out var providerStateProperty)); - Assert.Equal(JsonValueKind.Object, providerStateProperty.ValueKind); - Assert.True(providerStateProperty.TryGetProperty("messages", out var messagesProperty)); - Assert.Equal(JsonValueKind.Array, messagesProperty.ValueKind); - var messagesList = messagesProperty.EnumerateArray().ToList(); - Assert.Single(messagesList); - } - - [Fact] - public void Serialize_ReturnsEmptyMessages_WhenNoMessages() - { - // Arrange - var session = new TestInMemoryAgentSession(); - - // Act - var json = session.Serialize(); - - // Assert - Assert.Equal(JsonValueKind.Object, json.ValueKind); - Assert.True(json.TryGetProperty("chatHistoryProviderState", out var providerStateProperty)); - Assert.Equal(JsonValueKind.Object, providerStateProperty.ValueKind); - Assert.True(providerStateProperty.TryGetProperty("messages", out var messagesProperty)); - Assert.Equal(JsonValueKind.Array, messagesProperty.ValueKind); - Assert.Empty(messagesProperty.EnumerateArray()); - } - - #endregion - - #region GetService Tests - - [Fact] - public void GetService_RequestingChatHistoryProvider_ReturnsChatHistoryProvider() - { - // Arrange - var session = new TestInMemoryAgentSession(); - - // Act & Assert - Assert.NotNull(session.GetService(typeof(ChatHistoryProvider))); - Assert.Same(session.GetChatHistoryProvider(), session.GetService(typeof(ChatHistoryProvider))); - Assert.Same(session.GetChatHistoryProvider(), session.GetService(typeof(InMemoryChatHistoryProvider))); - } - - #endregion - - // Sealed test subclass to expose protected members for testing - private sealed class TestInMemoryAgentSession : InMemoryAgentSession - { - public TestInMemoryAgentSession() { } - public TestInMemoryAgentSession(InMemoryChatHistoryProvider? provider) : base(provider) { } - public TestInMemoryAgentSession(IEnumerable messages) : base(messages) { } - public TestInMemoryAgentSession(JsonElement serializedSessionState) : base(serializedSessionState) { } - public InMemoryChatHistoryProvider GetChatHistoryProvider() => this.ChatHistoryProvider; - } -} diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs index cecce5bea1..b907529241 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs @@ -3,9 +3,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Text.Encodings.Web; -using System.Text.Json; -using System.Text.Json.Serialization.Metadata; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -19,22 +16,18 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; public class InMemoryChatHistoryProviderTests { private static readonly AIAgent s_mockAgent = new Mock().Object; - private static readonly AgentSession s_mockSession = new Mock().Object; - [Fact] - public void Constructor_Throws_ForNullReducer() => - // Arrange & Act & Assert - Assert.Throws(() => new InMemoryChatHistoryProvider(null!)); + private static AgentSession CreateMockSession() => new Mock().Object; [Fact] public void Constructor_DefaultsToBeforeMessageRetrieval_ForNotProvidedTriggerEvent() { // Arrange & Act var reducerMock = new Mock(); - var provider = new InMemoryChatHistoryProvider(reducerMock.Object); + var provider = new InMemoryChatHistoryProvider(new() { ChatReducer = reducerMock.Object }); // Assert - Assert.Equal(InMemoryChatHistoryProvider.ChatReducerTriggerEvent.BeforeMessagesRetrieval, provider.ReducerTriggerEvent); + Assert.Equal(InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.BeforeMessagesRetrieval, provider.ReducerTriggerEvent); } [Fact] @@ -42,464 +35,211 @@ public void Constructor_Arguments_SetOnPropertiesCorrectly() { // Arrange & Act var reducerMock = new Mock(); - var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded); + var provider = new InMemoryChatHistoryProvider(new() { ChatReducer = reducerMock.Object, ReducerTriggerEvent = InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.AfterMessageAdded }); // Assert Assert.Same(reducerMock.Object, provider.ChatReducer); - Assert.Equal(InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded, provider.ReducerTriggerEvent); - } - - [Fact] - public async Task InvokedAsyncAddsMessagesAsync() - { - var requestMessages = new List - { - new(ChatRole.User, "Hello"), - new(ChatRole.System, "additional context") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "TestSource") } } }, - }; - var responseMessages = new List - { - new(ChatRole.Assistant, "Hi there!") - }; - var providerMessages = new List() - { - new(ChatRole.System, "original instructions") - }; - - var provider = new InMemoryChatHistoryProvider(); - provider.Add(providerMessages[0]); - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) - { - ResponseMessages = responseMessages - }; - await provider.InvokedAsync(context, CancellationToken.None); - - Assert.Equal(4, provider.Count); - Assert.Equal("original instructions", provider[0].Text); - Assert.Equal("Hello", provider[1].Text); - Assert.Equal("additional context", provider[2].Text); - Assert.Equal("Hi there!", provider[3].Text); + Assert.Equal(InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.AfterMessageAdded, provider.ReducerTriggerEvent); } [Fact] - public async Task InvokedAsyncWithEmptyDoesNotFailAsync() + public void StateKey_ReturnsDefaultKey_WhenNoOptionsProvided() { + // Arrange & Act var provider = new InMemoryChatHistoryProvider(); - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, []); - await provider.InvokedAsync(context, CancellationToken.None); - - Assert.Empty(provider); - } - - [Fact] - public async Task InvokingAsyncReturnsAllMessagesAsync() - { - var provider = new InMemoryChatHistoryProvider - { - new ChatMessage(ChatRole.User, "Test1"), - new ChatMessage(ChatRole.Assistant, "Test2") - }; - - var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); - var result = (await provider.InvokingAsync(context, CancellationToken.None)).ToList(); - - Assert.Equal(2, result.Count); - Assert.Contains(result, m => m.Text == "Test1"); - Assert.Contains(result, m => m.Text == "Test2"); + // Assert + Assert.Equal("InMemoryChatHistoryProvider", provider.StateKey); } [Fact] - public async Task DeserializeConstructorWithEmptyElementAsync() + public void StateKey_ReturnsCustomKey_WhenSetViaOptions() { - var emptyObject = JsonSerializer.Deserialize("{}", TestJsonSerializerContext.Default.JsonElement); - - var newProvider = new InMemoryChatHistoryProvider(emptyObject); + // Arrange & Act + var provider = new InMemoryChatHistoryProvider(new() { StateKey = "custom-key" }); - Assert.Empty(newProvider); + // Assert + Assert.Equal("custom-key", provider.StateKey); } [Fact] - public async Task SerializeAndDeserializeConstructorRoundtripsAsync() + public async Task InvokedAsyncAddsMessagesAsync() { - var provider = new InMemoryChatHistoryProvider - { - new ChatMessage(ChatRole.User, "A"), - new ChatMessage(ChatRole.Assistant, "B") - }; - - var jsonElement = provider.Serialize(); - var newProvider = new InMemoryChatHistoryProvider(jsonElement); - - Assert.Equal(2, newProvider.Count); - Assert.Equal("A", newProvider[0].Text); - Assert.Equal("B", newProvider[1].Text); - } + var session = CreateMockSession(); - [Fact] - public async Task SerializeAndDeserializeConstructorRoundtripsWithCustomAIContentAsync() - { - JsonSerializerOptions options = new(TestJsonSerializerContext.Default.Options) + // Arrange + var requestMessages = new List { - TypeInfoResolver = JsonTypeInfoResolver.Combine(AgentAbstractionsJsonUtilities.DefaultOptions.TypeInfoResolver, TestJsonSerializerContext.Default), - Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping, + new(ChatRole.User, "Hello"), + new(ChatRole.System, "additional context") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "TestSource") } } }, }; - options.AddAIContentType(typeDiscriminatorId: "testContent"); - - var provider = new InMemoryChatHistoryProvider + var responseMessages = new List { - new ChatMessage(ChatRole.User, [new TestAIContent("foo data")]), + new(ChatRole.Assistant, "Hi there!") }; - - var jsonElement = provider.Serialize(options); - var newProvider = new InMemoryChatHistoryProvider(jsonElement, options); - - Assert.Single(newProvider); - var actualTestAIContent = Assert.IsType(newProvider[0].Contents[0]); - Assert.Equal("foo data", actualTestAIContent.TestData); - } - - [Fact] - public async Task SerializeAndDeserializeWorksWithExperimentalContentTypesAsync() - { - var provider = new InMemoryChatHistoryProvider + var providerMessages = new List() { - new ChatMessage(ChatRole.User, [new FunctionApprovalRequestContent("call123", new FunctionCallContent("call123", "some_func"))]), - new ChatMessage(ChatRole.Assistant, [new FunctionApprovalResponseContent("call123", true, new FunctionCallContent("call123", "some_func"))]) + new(ChatRole.System, "original instructions") }; - var jsonElement = provider.Serialize(); - var newProvider = new InMemoryChatHistoryProvider(jsonElement); - - Assert.Equal(2, newProvider.Count); - Assert.IsType(newProvider[0].Contents[0]); - Assert.IsType(newProvider[1].Contents[0]); - } - - [Fact] - public async Task InvokedAsyncWithEmptyMessagesDoesNotChangeProviderAsync() - { var provider = new InMemoryChatHistoryProvider(); - var messages = new List(); - - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages); + provider.SetMessages(session, providerMessages); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages, responseMessages); await provider.InvokedAsync(context, CancellationToken.None); - Assert.Empty(provider); - } - - [Fact] - public async Task InvokedAsync_WithNullContext_ThrowsArgumentNullExceptionAsync() - { - // Arrange - var provider = new InMemoryChatHistoryProvider(); - - // Act & Assert - await Assert.ThrowsAsync(() => provider.InvokedAsync(null!, CancellationToken.None).AsTask()); - } - - [Fact] - public void DeserializeContructor_WithNullSerializedState_CreatesEmptyProvider() - { - // Act - var provider = new InMemoryChatHistoryProvider(new JsonElement()); - // Assert - Assert.Empty(provider); + var messages = provider.GetMessages(session); + Assert.Equal(4, messages.Count); + Assert.Equal("original instructions", messages[0].Text); + Assert.Equal("Hello", messages[1].Text); + Assert.Equal("additional context", messages[2].Text); + Assert.Equal("Hi there!", messages[3].Text); } [Fact] - public async Task DeserializeContructor_WithEmptyMessages_DoesNotAddMessagesAsync() + public async Task InvokedAsyncWithEmptyDoesNotFailAsync() { - // Arrange - var stateWithEmptyMessages = JsonSerializer.SerializeToElement( - new Dictionary { ["messages"] = new List() }, - TestJsonSerializerContext.Default.IDictionaryStringObject); - - // Act - var provider = new InMemoryChatHistoryProvider(stateWithEmptyMessages); + var session = CreateMockSession(); - // Assert - Assert.Empty(provider); - } - - [Fact] - public async Task DeserializeConstructor_WithNullMessages_DoesNotAddMessagesAsync() - { // Arrange - var stateWithNullMessages = JsonSerializer.SerializeToElement( - new Dictionary { ["messages"] = null! }, - TestJsonSerializerContext.Default.DictionaryStringObject); - - // Act - var provider = new InMemoryChatHistoryProvider(stateWithNullMessages); + var provider = new InMemoryChatHistoryProvider(); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [], []); + await provider.InvokedAsync(context, CancellationToken.None); // Assert - Assert.Empty(provider); + Assert.Empty(provider.GetMessages(session)); } [Fact] - public async Task DeserializeConstructor_WithValidMessages_AddsMessagesAsync() + public async Task InvokingAsyncReturnsAllMessagesAsync() { + var session = CreateMockSession(); + // Arrange - var messages = new List + var requestMessages = new List { - new(ChatRole.User, "User message"), - new(ChatRole.Assistant, "Assistant message") + new(ChatRole.User, "Hello"), }; - var state = new Dictionary { ["messages"] = messages }; - var serializedState = JsonSerializer.SerializeToElement( - state, - TestJsonSerializerContext.Default.DictionaryStringObject); - - // Act - var provider = new InMemoryChatHistoryProvider(serializedState); - // Assert - Assert.Equal(2, provider.Count); - Assert.Equal("User message", provider[0].Text); - Assert.Equal("Assistant message", provider[1].Text); - } - - [Fact] - public void IndexerGet_ReturnsCorrectMessage() - { - // Arrange var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - provider.Add(message1); - provider.Add(message2); - - // Act & Assert - Assert.Same(message1, provider[0]); - Assert.Same(message2, provider[1]); - } - - [Fact] - public void IndexerSet_UpdatesMessage() - { - // Arrange - var provider = new InMemoryChatHistoryProvider(); - var originalMessage = new ChatMessage(ChatRole.User, "Original"); - var newMessage = new ChatMessage(ChatRole.User, "Updated"); - provider.Add(originalMessage); + provider.SetMessages(session, + [ + new ChatMessage(ChatRole.User, "Test1"), + new ChatMessage(ChatRole.Assistant, "Test2") + ]); - // Act - provider[0] = newMessage; + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, requestMessages); + var result = (await provider.InvokingAsync(context, CancellationToken.None)).ToList(); // Assert - Assert.Same(newMessage, provider[0]); - Assert.Equal("Updated", provider[0].Text); - } - - [Fact] - public void IsReadOnly_ReturnsFalse() - { - // Arrange - var provider = new InMemoryChatHistoryProvider(); - - // Act & Assert - Assert.False(provider.IsReadOnly); - } - - [Fact] - public void IndexOf_ReturnsCorrectIndex() - { - // Arrange - var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - var message3 = new ChatMessage(ChatRole.User, "Third"); - provider.Add(message1); - provider.Add(message2); + Assert.Equal(3, result.Count); + Assert.Contains(result, m => m.Text == "Test1"); + Assert.Contains(result, m => m.Text == "Test2"); + Assert.Contains(result, m => m.Text == "Hello"); - // Act & Assert - Assert.Equal(0, provider.IndexOf(message1)); - Assert.Equal(1, provider.IndexOf(message2)); - Assert.Equal(-1, provider.IndexOf(message3)); // Not in provider + Assert.Equal(AgentRequestMessageSourceType.ChatHistory, result[0].GetAgentRequestMessageSourceType()); + Assert.Equal(AgentRequestMessageSourceType.ChatHistory, result[1].GetAgentRequestMessageSourceType()); + Assert.Equal(AgentRequestMessageSourceType.External, result[2].GetAgentRequestMessageSourceType()); } [Fact] - public void Insert_InsertsMessageAtCorrectIndex() + public void StateInitializer_IsInvoked_WhenSessionHasNoState() { // Arrange - var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - var insertMessage = new ChatMessage(ChatRole.User, "Inserted"); - provider.Add(message1); - provider.Add(message2); + var initialMessages = new List + { + new(ChatRole.User, "Initial message") + }; + var provider = new InMemoryChatHistoryProvider(new() + { + StateInitializer = _ => new InMemoryChatHistoryProvider.State { Messages = initialMessages } + }); // Act - provider.Insert(1, insertMessage); + var messages = provider.GetMessages(CreateMockSession()); // Assert - Assert.Equal(3, provider.Count); - Assert.Same(message1, provider[0]); - Assert.Same(insertMessage, provider[1]); - Assert.Same(message2, provider[2]); + Assert.Single(messages); + Assert.Equal("Initial message", messages[0].Text); } [Fact] - public void RemoveAt_RemovesMessageAtIndex() + public void GetMessages_ReturnsEmptyList_WhenNullSession() { // Arrange var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - var message3 = new ChatMessage(ChatRole.User, "Third"); - provider.Add(message1); - provider.Add(message2); - provider.Add(message3); // Act - provider.RemoveAt(1); + var messages = provider.GetMessages(null); // Assert - Assert.Equal(2, provider.Count); - Assert.Same(message1, provider[0]); - Assert.Same(message3, provider[1]); + Assert.Empty(messages); } [Fact] - public void Clear_RemovesAllMessages() - { - // Arrange - var provider = new InMemoryChatHistoryProvider - { - new ChatMessage(ChatRole.User, "First"), - new ChatMessage(ChatRole.Assistant, "Second") - }; - - // Act - provider.Clear(); - - // Assert - Assert.Empty(provider); - } - - [Fact] - public void Contains_ReturnsTrueForExistingMessage() + public void SetMessages_ThrowsForNullMessages() { // Arrange var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - provider.Add(message1); // Act & Assert - Assert.Contains(message1, provider); - Assert.DoesNotContain(message2, provider); + Assert.Throws(() => provider.SetMessages(CreateMockSession(), null!)); } [Fact] - public void CopyTo_CopiesMessagesToArray() + public void SetMessages_UpdatesState() { - // Arrange - var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - provider.Add(message1); - provider.Add(message2); - var array = new ChatMessage[4]; - - // Act - provider.CopyTo(array, 1); - - // Assert - Assert.Null(array[0]); - Assert.Same(message1, array[1]); - Assert.Same(message2, array[2]); - Assert.Null(array[3]); - } + var session = CreateMockSession(); - [Fact] - public void Remove_RemovesSpecificMessage() - { // Arrange var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - var message3 = new ChatMessage(ChatRole.User, "Third"); - provider.Add(message1); - provider.Add(message2); - provider.Add(message3); + var messages = new List + { + new(ChatRole.User, "Hello"), + new(ChatRole.Assistant, "World") + }; // Act - var removed = provider.Remove(message2); + provider.SetMessages(session, messages); + var retrieved = provider.GetMessages(session); // Assert - Assert.True(removed); - Assert.Equal(2, provider.Count); - Assert.Same(message1, provider[0]); - Assert.Same(message3, provider[1]); + Assert.Equal(2, retrieved.Count); + Assert.Equal("Hello", retrieved[0].Text); + Assert.Equal("World", retrieved[1].Text); } [Fact] - public void Remove_ReturnsFalseForNonExistentMessage() + public async Task InvokedAsyncWithEmptyMessagesDoesNotChangeProviderAsync() { - // Arrange - var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - provider.Add(message1); - - // Act - var removed = provider.Remove(message2); - - // Assert - Assert.False(removed); - Assert.Single(provider); - } + var session = CreateMockSession(); - [Fact] - public void GetEnumerator_Generic_ReturnsAllMessages() - { // Arrange var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - provider.Add(message1); - provider.Add(message2); - - // Act var messages = new List(); - messages.AddRange(provider); + + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, messages, []); + await provider.InvokedAsync(context, CancellationToken.None); // Assert - Assert.Equal(2, messages.Count); - Assert.Same(message1, messages[0]); - Assert.Same(message2, messages[1]); + Assert.Empty(provider.GetMessages(session)); } [Fact] - public void GetEnumerator_NonGeneric_ReturnsAllMessages() + public async Task InvokedAsync_WithNullContext_ThrowsArgumentNullExceptionAsync() { // Arrange var provider = new InMemoryChatHistoryProvider(); - var message1 = new ChatMessage(ChatRole.User, "First"); - var message2 = new ChatMessage(ChatRole.Assistant, "Second"); - provider.Add(message1); - provider.Add(message2); - - // Act - var messages = new List(); - var enumerator = ((System.Collections.IEnumerable)provider).GetEnumerator(); - while (enumerator.MoveNext()) - { - messages.Add((ChatMessage)enumerator.Current); - } - // Assert - Assert.Equal(2, messages.Count); - Assert.Same(message1, messages[0]); - Assert.Same(message2, messages[1]); + // Act & Assert + await Assert.ThrowsAsync(() => provider.InvokedAsync(null!, CancellationToken.None).AsTask()); } [Fact] public async Task AddMessagesAsync_WithReducer_AfterMessageAdded_InvokesReducerAsync() { + var session = CreateMockSession(); + // Arrange var originalMessages = new List { @@ -516,21 +256,24 @@ public async Task AddMessagesAsync_WithReducer_AfterMessageAdded_InvokesReducerA .Setup(r => r.ReduceAsync(It.Is>(x => x.SequenceEqual(originalMessages)), It.IsAny())) .ReturnsAsync(reducedMessages); - var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded); + var provider = new InMemoryChatHistoryProvider(new() { ChatReducer = reducerMock.Object, ReducerTriggerEvent = InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.AfterMessageAdded }); // Act - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, originalMessages); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, originalMessages, []); await provider.InvokedAsync(context, CancellationToken.None); // Assert - Assert.Single(provider); - Assert.Equal("Reduced", provider[0].Text); + var messages = provider.GetMessages(session); + Assert.Single(messages); + Assert.Equal("Reduced", messages[0].Text); reducerMock.Verify(r => r.ReduceAsync(It.Is>(x => x.SequenceEqual(originalMessages)), It.IsAny()), Times.Once); } [Fact] public async Task GetMessagesAsync_WithReducer_BeforeMessagesRetrieval_InvokesReducerAsync() { + var session = CreateMockSession(); + // Arrange var originalMessages = new List { @@ -547,15 +290,11 @@ public async Task GetMessagesAsync_WithReducer_BeforeMessagesRetrieval_InvokesRe .Setup(r => r.ReduceAsync(It.Is>(x => x.SequenceEqual(originalMessages)), It.IsAny())) .ReturnsAsync(reducedMessages); - var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.BeforeMessagesRetrieval); - // Add messages directly to the provider for this test - foreach (var msg in originalMessages) - { - provider.Add(msg); - } + var provider = new InMemoryChatHistoryProvider(new() { ChatReducer = reducerMock.Object, ReducerTriggerEvent = InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.BeforeMessagesRetrieval }); + provider.SetMessages(session, new List(originalMessages)); // Act - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, Array.Empty()); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, Array.Empty()); var result = (await provider.InvokingAsync(invokingContext, CancellationToken.None)).ToList(); // Assert @@ -567,6 +306,8 @@ public async Task GetMessagesAsync_WithReducer_BeforeMessagesRetrieval_InvokesRe [Fact] public async Task AddMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeReducerAsync() { + var session = CreateMockSession(); + // Arrange var originalMessages = new List { @@ -575,21 +316,24 @@ public async Task AddMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeRedu var reducerMock = new Mock(); - var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.BeforeMessagesRetrieval); + var provider = new InMemoryChatHistoryProvider(new() { ChatReducer = reducerMock.Object, ReducerTriggerEvent = InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.BeforeMessagesRetrieval }); // Act - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, originalMessages); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, originalMessages, []); await provider.InvokedAsync(context, CancellationToken.None); // Assert - Assert.Single(provider); - Assert.Equal("Hello", provider[0].Text); + var messages = provider.GetMessages(session); + Assert.Single(messages); + Assert.Equal("Hello", messages[0].Text); reducerMock.Verify(r => r.ReduceAsync(It.IsAny>(), It.IsAny()), Times.Never); } [Fact] public async Task GetMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeReducerAsync() { + var session = CreateMockSession(); + // Arrange var originalMessages = new List { @@ -598,13 +342,11 @@ public async Task GetMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeRedu var reducerMock = new Mock(); - var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded) - { - originalMessages[0] - }; + var provider = new InMemoryChatHistoryProvider(new() { ChatReducer = reducerMock.Object, ReducerTriggerEvent = InMemoryChatHistoryProviderOptions.ChatReducerTriggerEvent.AfterMessageAdded }); + provider.SetMessages(session, new List(originalMessages)); // Act - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, Array.Empty()); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, Array.Empty()); var result = (await provider.InvokingAsync(invokingContext, CancellationToken.None)).ToList(); // Assert @@ -616,27 +358,21 @@ public async Task GetMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeRedu [Fact] public async Task InvokedAsync_WithException_DoesNotAddMessagesAsync() { + var session = CreateMockSession(); + // Arrange var provider = new InMemoryChatHistoryProvider(); var requestMessages = new List { new(ChatRole.User, "Hello") }; - var responseMessages = new List - { - new(ChatRole.Assistant, "Hi there!") - }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) - { - ResponseMessages = responseMessages, - InvokeException = new InvalidOperationException("Test exception") - }; + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages, new InvalidOperationException("Test exception")); // Act await provider.InvokedAsync(context, CancellationToken.None); // Assert - Assert.Empty(provider); + Assert.Empty(provider.GetMessages(session)); } [Fact] @@ -649,6 +385,85 @@ public async Task InvokingAsync_WithNullContext_ThrowsArgumentNullExceptionAsync await Assert.ThrowsAsync(() => provider.InvokingAsync(null!, CancellationToken.None).AsTask()); } + [Fact] + public async Task InvokedAsync_DefaultFilter_ExcludesChatHistoryMessagesAsync() + { + // Arrange + var session = CreateMockSession(); + var provider = new InMemoryChatHistoryProvider(); + var requestMessages = new List + { + new(ChatRole.User, "External message"), + new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } }, + new(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } }, + }; + + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages, [new ChatMessage(ChatRole.Assistant, "Response")]); + + // Act + await provider.InvokedAsync(context, CancellationToken.None); + + // Assert - ChatHistory message excluded, AIContextProvider message included + var messages = provider.GetMessages(session); + Assert.Equal(3, messages.Count); + Assert.Equal("External message", messages[0].Text); + Assert.Equal("From context provider", messages[1].Text); + Assert.Equal("Response", messages[2].Text); + } + + [Fact] + public async Task InvokedAsync_CustomFilter_OverridesDefaultAsync() + { + // Arrange + var session = CreateMockSession(); + var provider = new InMemoryChatHistoryProvider(new InMemoryChatHistoryProviderOptions + { + StorageInputMessageFilter = messages => messages.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External) + }); + var requestMessages = new List + { + new(ChatRole.User, "External message"), + new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } }, + new(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } }, + }; + + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages, [new ChatMessage(ChatRole.Assistant, "Response")]); + + // Act + await provider.InvokedAsync(context, CancellationToken.None); + + // Assert - Custom filter keeps only External messages (both ChatHistory and AIContextProvider excluded) + var messages = provider.GetMessages(session); + Assert.Equal(2, messages.Count); + Assert.Equal("External message", messages[0].Text); + Assert.Equal("Response", messages[1].Text); + } + + [Fact] + public async Task InvokingAsync_OutputFilter_FiltersOutputMessagesAsync() + { + // Arrange + var session = CreateMockSession(); + var provider = new InMemoryChatHistoryProvider(new InMemoryChatHistoryProviderOptions + { + RetrievalOutputMessageFilter = messages => messages.Where(m => m.Role == ChatRole.User) + }); + provider.SetMessages(session, + [ + new ChatMessage(ChatRole.User, "User message"), + new ChatMessage(ChatRole.Assistant, "Assistant message"), + new ChatMessage(ChatRole.System, "System message") + ]); + + // Act + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); + var result = (await provider.InvokingAsync(context, CancellationToken.None)).ToList(); + + // Assert - Only user messages pass through the output filter + Assert.Single(result); + Assert.Equal("User message", result[0].Text); + } + public class TestAIContent(string testData) : AIContent { public string TestData => testData; diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ServiceIdAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ServiceIdAgentSessionTests.cs deleted file mode 100644 index e4a6626f72..0000000000 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ServiceIdAgentSessionTests.cs +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System; -using System.Text.Json; - -namespace Microsoft.Agents.AI.Abstractions.UnitTests; - -/// -/// Tests for . -/// -public class ServiceIdAgentSessionTests -{ - #region Constructor and Property Tests - - [Fact] - public void Constructor_SetsDefaults() - { - // Arrange & Act - var session = new TestServiceIdAgentSession(); - - // Assert - Assert.Null(session.GetServiceSessionId()); - } - - [Fact] - public void Constructor_WithServiceSessionId_SetsProperty() - { - // Arrange & Act - var session = new TestServiceIdAgentSession("service-id-123"); - - // Assert - Assert.Equal("service-id-123", session.GetServiceSessionId()); - } - - [Fact] - public void Constructor_WithSerializedId_SetsProperty() - { - // Arrange - var serviceSessionWrapper = new ServiceIdAgentSession.ServiceIdAgentSessionState { ServiceSessionId = "service-id-456" }; - var json = JsonSerializer.SerializeToElement(serviceSessionWrapper, TestJsonSerializerContext.Default.ServiceIdAgentSessionState); - - // Act - var session = new TestServiceIdAgentSession(json); - - // Assert - Assert.Equal("service-id-456", session.GetServiceSessionId()); - } - - [Fact] - public void Constructor_WithSerializedUndefinedId_SetsProperty() - { - // Arrange - var emptyObject = new EmptyObject(); - var json = JsonSerializer.SerializeToElement(emptyObject, TestJsonSerializerContext.Default.EmptyObject); - - // Act - var session = new TestServiceIdAgentSession(json); - - // Assert - Assert.Null(session.GetServiceSessionId()); - } - - [Fact] - public void Constructor_WithInvalidJson_ThrowsArgumentException() - { - // Arrange - var invalidJson = JsonSerializer.SerializeToElement(42, TestJsonSerializerContext.Default.Int32); - - // Act & Assert - Assert.Throws(() => new TestServiceIdAgentSession(invalidJson)); - } - - #endregion - - #region SerializeAsync Tests - - [Fact] - public void Serialize_ReturnsCorrectJson_WhenServiceSessionIdIsSet() - { - // Arrange - var session = new TestServiceIdAgentSession("service-id-789"); - - // Act - var json = session.Serialize(); - - // Assert - Assert.Equal(JsonValueKind.Object, json.ValueKind); - Assert.True(json.TryGetProperty("serviceSessionId", out var idProperty)); - Assert.Equal("service-id-789", idProperty.GetString()); - } - - [Fact] - public void Serialize_ReturnsUndefinedServiceSessionId_WhenNotSet() - { - // Arrange - var session = new TestServiceIdAgentSession(); - - // Act - var json = session.Serialize(); - - // Assert - Assert.Equal(JsonValueKind.Object, json.ValueKind); - Assert.False(json.TryGetProperty("serviceSessionId", out _)); - } - - #endregion - - // Sealed test subclass to expose protected members for testing - private sealed class TestServiceIdAgentSession : ServiceIdAgentSession - { - public TestServiceIdAgentSession() { } - public TestServiceIdAgentSession(string serviceSessionId) : base(serviceSessionId) { } - public TestServiceIdAgentSession(JsonElement serializedSessionState) : base(serializedSessionState) { } - public string? GetServiceSessionId() => this.ServiceSessionId; - } - - // Helper class to represent empty objects - internal sealed class EmptyObject; -} diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/TestJsonSerializerContext.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/TestJsonSerializerContext.cs index 05d13c2e95..c4f3b7511a 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/TestJsonSerializerContext.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/TestJsonSerializerContext.cs @@ -19,8 +19,5 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; [JsonSerializable(typeof(Dictionary))] [JsonSerializable(typeof(string[]))] [JsonSerializable(typeof(int))] -[JsonSerializable(typeof(InMemoryAgentSession.InMemoryAgentSessionState))] -[JsonSerializable(typeof(ServiceIdAgentSession.ServiceIdAgentSessionState))] -[JsonSerializable(typeof(ServiceIdAgentSessionTests.EmptyObject))] [JsonSerializable(typeof(InMemoryChatHistoryProviderTests.TestAIContent))] internal sealed partial class TestJsonSerializerContext : JsonSerializerContext; diff --git a/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs index e2cbb16b1b..2f2e276ae9 100644 --- a/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.AzureAI.UnitTests/AzureAIProjectChatClientExtensionsTests.cs @@ -2310,23 +2310,18 @@ public async Task GetAIAgentAsync_WithMatchingToolsProvided_CreatesAgentSuccessf #region CreateChatClientAgentOptions - Options Preservation Tests /// - /// Verify that CreateChatClientAgentOptions preserves AIContextProviderFactory. + /// Verify that CreateChatClientAgentOptions preserves AIContextProviders. /// [Fact] - public async Task GetAIAgentAsync_WithAIContextProviderFactory_PreservesFactoryAsync() + public async Task GetAIAgentAsync_WithAIContextProviders_PreservesProviderAsync() { // Arrange AIProjectClient client = this.CreateTestAgentClient(); - bool factoryInvoked = false; var options = new ChatClientAgentOptions { Name = "test-agent", ChatOptions = new ChatOptions { Instructions = "Test" }, - AIContextProviderFactory = (_, _) => - { - factoryInvoked = true; - return new ValueTask(new TestAIContextProvider()); - } + AIContextProviders = [new TestAIContextProvider()] }; // Act @@ -2334,15 +2329,13 @@ public async Task GetAIAgentAsync_WithAIContextProviderFactory_PreservesFactoryA // Assert Assert.NotNull(agent); - // Verify the factory was captured (though not necessarily invoked yet) - Assert.False(factoryInvoked); // Factory is not invoked during creation } /// - /// Verify that CreateChatClientAgentOptions preserves ChatHistoryProviderFactory. + /// Verify that CreateChatClientAgentOptions preserves ChatHistoryProvider. /// [Fact] - public async Task GetAIAgentAsync_WithChatHistoryProviderFactory_PreservesFactoryAsync() + public async Task GetAIAgentAsync_WithChatHistoryProvider_PreservesProviderAsync() { // Arrange AIProjectClient client = this.CreateTestAgentClient(); @@ -2350,7 +2343,7 @@ public async Task GetAIAgentAsync_WithChatHistoryProviderFactory_PreservesFactor { Name = "test-agent", ChatOptions = new ChatOptions { Instructions = "Test" }, - ChatHistoryProviderFactory = (_, _) => new ValueTask(new TestChatHistoryProvider()) + ChatHistoryProvider = new TestChatHistoryProvider() }; // Act @@ -3142,7 +3135,7 @@ private sealed class TestAIContextProvider : AIContextProvider { protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { - return new ValueTask(new AIContext()); + return new ValueTask(context.AIContext); } } @@ -3153,18 +3146,13 @@ private sealed class TestChatHistoryProvider : ChatHistoryProvider { protected override ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) { - return new ValueTask>(Array.Empty()); + return new ValueTask>(context.RequestMessages); } protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) { return default; } - - public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - return default; - } } } diff --git a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs index e7276f70b1..12bd467e71 100644 --- a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs @@ -3,8 +3,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Text.Json; -using System.Text.Json.Serialization.Metadata; using System.Threading.Tasks; using Azure.Core; using Azure.Identity; @@ -42,7 +40,8 @@ namespace Microsoft.Agents.AI.CosmosNoSql.UnitTests; public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable { private static readonly AIAgent s_mockAgent = new Moq.Mock().Object; - private static readonly AgentSession s_mockSession = new Moq.Mock().Object; + + private static AgentSession CreateMockSession() => new Moq.Mock().Object; // Cosmos DB Emulator connection settings private const string EmulatorEndpoint = "https://localhost:8081"; @@ -151,34 +150,46 @@ private void SkipIfEmulatorNotAvailable() [SkippableFact] [Trait("Category", "CosmosDB")] - public void Constructor_WithConnectionString_ShouldCreateInstance() + public void StateKey_ReturnsDefaultKey_WhenNoStateKeyProvided() { // Arrange & Act this.SkipIfEmulatorNotAvailable(); - // Act - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, "test-conversation"); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State("test-conversation")); // Assert - Assert.NotNull(provider); - Assert.Equal("test-conversation", provider.ConversationId); - Assert.Equal(s_testDatabaseId, provider.DatabaseId); - Assert.Equal(TestContainerId, provider.ContainerId); + Assert.Equal("CosmosChatHistoryProvider", provider.StateKey); } [SkippableFact] [Trait("Category", "CosmosDB")] - public void Constructor_WithConnectionStringNoConversationId_ShouldCreateInstance() + public void StateKey_ReturnsCustomKey_WhenSetViaConstructor() { - // Arrange + // Arrange & Act + this.SkipIfEmulatorNotAvailable(); + + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State("test-conversation"), + stateKey: "custom-key"); + + // Assert + Assert.Equal("custom-key", provider.StateKey); + } + + [SkippableFact] + [Trait("Category", "CosmosDB")] + public void Constructor_WithConnectionString_ShouldCreateInstance() + { + // Arrange & Act this.SkipIfEmulatorNotAvailable(); // Act - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State("test-conversation")); // Assert Assert.NotNull(provider); - Assert.NotNull(provider.ConversationId); Assert.Equal(s_testDatabaseId, provider.DatabaseId); Assert.Equal(TestContainerId, provider.ContainerId); } @@ -189,18 +200,19 @@ public void Constructor_WithNullConnectionString_ShouldThrowArgumentException() { // Arrange & Act & Assert Assert.Throws(() => - new CosmosChatHistoryProvider((string)null!, s_testDatabaseId, TestContainerId, "test-conversation")); + new CosmosChatHistoryProvider((string)null!, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State("test-conversation"))); } [SkippableFact] [Trait("Category", "CosmosDB")] - public void Constructor_WithEmptyConversationId_ShouldThrowArgumentException() + public void Constructor_WithNullStateInitializer_ShouldThrowArgumentNullException() { // Arrange & Act & Assert this.SkipIfEmulatorNotAvailable(); - Assert.Throws(() => - new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, "")); + Assert.Throws(() => + new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, null!)); } #endregion @@ -213,14 +225,13 @@ public async Task InvokedAsync_WithSingleMessage_ShouldAddMessageAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); var conversationId = Guid.NewGuid().ToString(); - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversationId); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(conversationId)); var message = new ChatMessage(ChatRole.User, "Hello, world!"); - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [message]) - { - ResponseMessages = [] - }; + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [message], []); // Act await provider.InvokedAsync(context); @@ -229,7 +240,7 @@ public async Task InvokedAsync_WithSingleMessage_ShouldAddMessageAsync() await Task.Delay(100); // Assert - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var messages = await provider.InvokingAsync(invokingContext); var messageList = messages.ToList(); @@ -279,8 +290,10 @@ public async Task InvokedAsync_WithMultipleMessages_ShouldAddAllMessagesAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); var conversationId = Guid.NewGuid().ToString(); - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversationId); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(conversationId)); var requestMessages = new[] { new ChatMessage(ChatRole.User, "First message"), @@ -293,16 +306,13 @@ public async Task InvokedAsync_WithMultipleMessages_ShouldAddAllMessagesAsync() new ChatMessage(ChatRole.Assistant, "Response message") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) - { - ResponseMessages = responseMessages - }; + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages, responseMessages); // Act await provider.InvokedAsync(context); // Assert - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var retrievedMessages = await provider.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); Assert.Equal(5, messageList.Count); @@ -323,10 +333,12 @@ public async Task InvokingAsync_WithNoMessages_ShouldReturnEmptyAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, Guid.NewGuid().ToString()); + var session = CreateMockSession(); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(Guid.NewGuid().ToString())); // Act - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var messages = await provider.InvokingAsync(invokingContext); // Assert @@ -339,21 +351,25 @@ public async Task InvokingAsync_WithConversationIsolation_ShouldOnlyReturnMessag { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); var conversation1 = Guid.NewGuid().ToString(); var conversation2 = Guid.NewGuid().ToString(); - using var store1 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversation1); - using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversation2); + // Use different stateKey values so the providers don't overwrite each other's state in the shared session + using var store1 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(conversation1), stateKey: "conv1"); + using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(conversation2), stateKey: "conv2"); - var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message for conversation 1")]); - var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message for conversation 2")]); + var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Message for conversation 1")], []); + var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Message for conversation 2")], []); await store1.InvokedAsync(context1); await store2.InvokedAsync(context2); // Act - var invokingContext1 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); - var invokingContext2 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext1 = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); + var invokingContext2 = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var messages1 = await store1.InvokingAsync(invokingContext1); var messages2 = await store2.InvokingAsync(invokingContext2); @@ -365,6 +381,8 @@ public async Task InvokingAsync_WithConversationIsolation_ShouldOnlyReturnMessag Assert.Single(messageList2); Assert.Equal("Message for conversation 1", messageList1[0].Text); Assert.Equal("Message for conversation 2", messageList2[0].Text); + Assert.Equal(AgentRequestMessageSourceType.ChatHistory, messageList1[0].GetAgentRequestMessageSourceType()); + Assert.Equal(AgentRequestMessageSourceType.ChatHistory, messageList2[0].GetAgentRequestMessageSourceType()); } #endregion @@ -377,8 +395,10 @@ public async Task FullWorkflow_AddAndGet_ShouldWorkCorrectlyAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); var conversationId = $"test-conversation-{Guid.NewGuid():N}"; // Use unique conversation ID - using var originalStore = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversationId); + using var originalStore = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(conversationId)); var messages = new[] { @@ -390,18 +410,21 @@ public async Task FullWorkflow_AddAndGet_ShouldWorkCorrectlyAsync() }; // Act 1: Add messages - var invokedContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages); + var invokedContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, messages, []); await originalStore.InvokedAsync(invokedContext); // Act 2: Verify messages were added - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var retrievedMessages = await originalStore.InvokingAsync(invokingContext); var retrievedList = retrievedMessages.ToList(); Assert.Equal(5, retrievedList.Count); // Act 3: Create new provider instance for same conversation (test persistence) - using var newProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversationId); - var persistedMessages = await newProvider.InvokingAsync(invokingContext); + using var newProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(conversationId)); + var newSession = CreateMockSession(); + var newInvokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, newSession, []); + var persistedMessages = await newProvider.InvokingAsync(newInvokingContext); var persistedList = persistedMessages.ToList(); // Assert final state @@ -423,7 +446,8 @@ public void Dispose_AfterUse_ShouldNotThrow() { // Arrange this.SkipIfEmulatorNotAvailable(); - var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, Guid.NewGuid().ToString()); + var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(Guid.NewGuid().ToString())); // Act & Assert provider.Dispose(); // Should not throw @@ -435,7 +459,8 @@ public void Dispose_MultipleCalls_ShouldNotThrow() { // Arrange this.SkipIfEmulatorNotAvailable(); - var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, Guid.NewGuid().ToString()); + var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(Guid.NewGuid().ToString())); // Act & Assert provider.Dispose(); // First call @@ -454,11 +479,11 @@ public void Constructor_WithHierarchicalConnectionString_ShouldCreateInstance() this.SkipIfEmulatorNotAvailable(); // Act - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, "tenant-123", "user-456", "session-789"); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State("session-789", "tenant-123", "user-456")); // Assert Assert.NotNull(provider); - Assert.Equal("session-789", provider.ConversationId); Assert.Equal(s_testDatabaseId, provider.DatabaseId); Assert.Equal(HierarchicalTestContainerId, provider.ContainerId); } @@ -472,11 +497,11 @@ public void Constructor_WithHierarchicalEndpoint_ShouldCreateInstance() // Act TokenCredential credential = new DefaultAzureCredential(); - using var provider = new CosmosChatHistoryProvider(EmulatorEndpoint, credential, s_testDatabaseId, HierarchicalTestContainerId, "tenant-123", "user-456", "session-789"); + using var provider = new CosmosChatHistoryProvider(EmulatorEndpoint, credential, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State("session-789", "tenant-123", "user-456")); // Assert Assert.NotNull(provider); - Assert.Equal("session-789", provider.ConversationId); Assert.Equal(s_testDatabaseId, provider.DatabaseId); Assert.Equal(HierarchicalTestContainerId, provider.ContainerId); } @@ -489,46 +514,31 @@ public void Constructor_WithHierarchicalCosmosClient_ShouldCreateInstance() this.SkipIfEmulatorNotAvailable(); using var cosmosClient = new CosmosClient(EmulatorEndpoint, EmulatorKey); - using var provider = new CosmosChatHistoryProvider(cosmosClient, s_testDatabaseId, HierarchicalTestContainerId, "tenant-123", "user-456", "session-789"); + using var provider = new CosmosChatHistoryProvider(cosmosClient, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State("session-789", "tenant-123", "user-456")); // Assert Assert.NotNull(provider); - Assert.Equal("session-789", provider.ConversationId); Assert.Equal(s_testDatabaseId, provider.DatabaseId); Assert.Equal(HierarchicalTestContainerId, provider.ContainerId); } [SkippableFact] [Trait("Category", "CosmosDB")] - public void Constructor_WithHierarchicalNullTenantId_ShouldThrowArgumentException() - { - // Arrange & Act & Assert - this.SkipIfEmulatorNotAvailable(); - - Assert.Throws(() => - new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, null!, "user-456", "session-789")); - } - - [SkippableFact] - [Trait("Category", "CosmosDB")] - public void Constructor_WithHierarchicalEmptyUserId_ShouldThrowArgumentException() + public void State_WithEmptyConversationId_ShouldThrowArgumentException() { // Arrange & Act & Assert - this.SkipIfEmulatorNotAvailable(); - Assert.Throws(() => - new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, "tenant-123", "", "session-789")); + new CosmosChatHistoryProvider.State("")); } [SkippableFact] [Trait("Category", "CosmosDB")] - public void Constructor_WithHierarchicalWhitespaceSessionId_ShouldThrowArgumentException() + public void State_WithWhitespaceConversationId_ShouldThrowArgumentException() { // Arrange & Act & Assert - this.SkipIfEmulatorNotAvailable(); - Assert.Throws(() => - new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, "tenant-123", "user-456", " ")); + new CosmosChatHistoryProvider.State(" ")); } [SkippableFact] @@ -537,14 +547,16 @@ public async Task InvokedAsync_WithHierarchicalPartitioning_ShouldAddMessageWith { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); const string TenantId = "tenant-123"; const string UserId = "user-456"; const string SessionId = "session-789"; // Test hierarchical partitioning constructor with connection string - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State(SessionId, TenantId, UserId)); var message = new ChatMessage(ChatRole.User, "Hello from hierarchical partitioning!"); - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [message]); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [message], []); // Act await provider.InvokedAsync(context); @@ -553,7 +565,7 @@ public async Task InvokedAsync_WithHierarchicalPartitioning_ShouldAddMessageWith await Task.Delay(100); // Assert - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var messages = await provider.InvokingAsync(invokingContext); var messageList = messages.ToList(); @@ -589,11 +601,13 @@ public async Task InvokedAsync_WithHierarchicalMultipleMessages_ShouldAddAllMess { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); const string TenantId = "tenant-batch"; const string UserId = "user-batch"; const string SessionId = "session-batch"; // Test hierarchical partitioning constructor with connection string - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State(SessionId, TenantId, UserId)); var messages = new[] { new ChatMessage(ChatRole.User, "First hierarchical message"), @@ -601,7 +615,7 @@ public async Task InvokedAsync_WithHierarchicalMultipleMessages_ShouldAddAllMess new ChatMessage(ChatRole.User, "Third hierarchical message") }; - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, messages, []); // Act await provider.InvokedAsync(context); @@ -610,7 +624,7 @@ public async Task InvokedAsync_WithHierarchicalMultipleMessages_ShouldAddAllMess await Task.Delay(100); // Assert - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var retrievedMessages = await provider.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); @@ -626,18 +640,22 @@ public async Task InvokingAsync_WithHierarchicalPartitionIsolation_ShouldIsolate { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); const string TenantId = "tenant-isolation"; const string UserId1 = "user-1"; const string UserId2 = "user-2"; const string SessionId = "session-isolation"; // Different userIds create different hierarchical partitions, providing proper isolation - using var store1 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId1, SessionId); - using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId2, SessionId); + // Use different stateKey values so the providers don't overwrite each other's state in the shared session + using var store1 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State(SessionId, TenantId, UserId1), stateKey: "user1"); + using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State(SessionId, TenantId, UserId2), stateKey: "user2"); // Add messages to both stores - var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message from user 1")]); - var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message from user 2")]); + var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Message from user 1")], []); + var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Message from user 2")], []); await store1.InvokedAsync(context1); await store2.InvokedAsync(context2); @@ -646,8 +664,8 @@ public async Task InvokingAsync_WithHierarchicalPartitionIsolation_ShouldIsolate await Task.Delay(100); // Act & Assert - var invokingContext1 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); - var invokingContext2 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext1 = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); + var invokingContext2 = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var messages1 = await store1.InvokingAsync(invokingContext1); var messageList1 = messages1.ToList(); @@ -664,43 +682,37 @@ public async Task InvokingAsync_WithHierarchicalPartitionIsolation_ShouldIsolate [SkippableFact] [Trait("Category", "CosmosDB")] - public async Task SerializeDeserialize_WithHierarchicalPartitioning_ShouldPreserveStateAsync() + public async Task StateBag_WithHierarchicalPartitioning_ShouldPreserveStateAcrossProviderInstancesAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); const string TenantId = "tenant-serialize"; const string UserId = "user-serialize"; const string SessionId = "session-serialize"; - using var originalStore = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId); + using var originalStore = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State(SessionId, TenantId, UserId)); - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test serialization message")]); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Test serialization message")], []); await originalStore.InvokedAsync(context); - // Act - Serialize the provider state - var serializedState = originalStore.Serialize(); - - // Create a new provider from the serialized state - using var cosmosClient = new CosmosClient(EmulatorEndpoint, EmulatorKey); - var serializerOptions = new JsonSerializerOptions - { - TypeInfoResolver = new DefaultJsonTypeInfoResolver() - }; - using var deserializedStore = CosmosChatHistoryProvider.CreateFromSerializedState(cosmosClient, serializedState, s_testDatabaseId, HierarchicalTestContainerId, serializerOptions); - // Wait a moment for eventual consistency await Task.Delay(100); - // Assert - The deserialized provider should have the same functionality - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); - var messages = await deserializedStore.InvokingAsync(invokingContext); + // Act - Create a new provider that uses a different intializer, but we will use the same session. + using var newStore = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State(Guid.NewGuid().ToString())); + + // Assert - The new provider should read the same messages from Cosmos DB + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); + var messages = await newStore.InvokingAsync(invokingContext); var messageList = messages.ToList(); Assert.Single(messageList); Assert.Equal("Test serialization message", messageList[0].Text); - Assert.Equal(SessionId, deserializedStore.ConversationId); - Assert.Equal(s_testDatabaseId, deserializedStore.DatabaseId); - Assert.Equal(HierarchicalTestContainerId, deserializedStore.ContainerId); + Assert.Equal(s_testDatabaseId, newStore.DatabaseId); + Assert.Equal(HierarchicalTestContainerId, newStore.ContainerId); } [SkippableFact] @@ -711,13 +723,17 @@ public async Task HierarchicalAndSimplePartitioning_ShouldCoexistAsync() this.SkipIfEmulatorNotAvailable(); const string SessionId = "coexist-session"; + var session = CreateMockSession(); // Create simple provider using simple partitioning container and hierarchical provider using hierarchical container - using var simpleProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, SessionId); - using var hierarchicalProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, "tenant-coexist", "user-coexist", SessionId); + // Use different stateKey values so the providers don't overwrite each other's state in the shared session + using var simpleProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(SessionId), stateKey: "simple"); + using var hierarchicalProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, + _ => new CosmosChatHistoryProvider.State(SessionId, "tenant-coexist", "user-coexist"), stateKey: "hierarchical"); // Add messages to both - var simpleContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Simple partitioning message")]); - var hierarchicalContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Hierarchical partitioning message")]); + var simpleContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Simple partitioning message")], []); + var hierarchicalContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, [new ChatMessage(ChatRole.User, "Hierarchical partitioning message")], []); await simpleProvider.InvokedAsync(simpleContext); await hierarchicalProvider.InvokedAsync(hierarchicalContext); @@ -726,7 +742,7 @@ public async Task HierarchicalAndSimplePartitioning_ShouldCoexistAsync() await Task.Delay(100); // Act & Assert - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var simpleMessages = await simpleProvider.InvokingAsync(invokingContext); var simpleMessageList = simpleMessages.ToList(); @@ -747,9 +763,11 @@ public async Task MaxMessagesToRetrieve_ShouldLimitAndReturnMostRecentAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); const string ConversationId = "max-messages-test"; - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, ConversationId); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(ConversationId)); // Add 10 messages var messages = new List(); @@ -759,7 +777,7 @@ public async Task MaxMessagesToRetrieve_ShouldLimitAndReturnMostRecentAsync() await Task.Delay(10); // Small delay to ensure different timestamps } - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, messages, []); await provider.InvokedAsync(context); // Wait for eventual consistency @@ -767,7 +785,7 @@ public async Task MaxMessagesToRetrieve_ShouldLimitAndReturnMostRecentAsync() // Act - Set max to 5 and retrieve provider.MaxMessagesToRetrieve = 5; - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var retrievedMessages = await provider.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); @@ -786,9 +804,11 @@ public async Task MaxMessagesToRetrieve_Null_ShouldReturnAllMessagesAsync() { // Arrange this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); const string ConversationId = "max-messages-null-test"; - using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, ConversationId); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(ConversationId)); // Add 10 messages var messages = new List(); @@ -797,14 +817,14 @@ public async Task MaxMessagesToRetrieve_Null_ShouldReturnAllMessagesAsync() messages.Add(new ChatMessage(ChatRole.User, $"Message {i}")); } - var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, messages, []); await provider.InvokedAsync(context); // Wait for eventual consistency await Task.Delay(100); // Act - No limit set (default null) - var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); var retrievedMessages = await provider.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); @@ -815,4 +835,119 @@ public async Task MaxMessagesToRetrieve_Null_ShouldReturnAllMessagesAsync() } #endregion + + #region Message Filter Tests + + [SkippableFact] + [Trait("Category", "CosmosDB")] + public async Task InvokedAsync_DefaultFilter_ExcludesChatHistoryMessagesFromStorageAsync() + { + // Arrange + this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); + var conversationId = Guid.NewGuid().ToString(); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(conversationId)); + + var requestMessages = new[] + { + new ChatMessage(ChatRole.User, "External message"), + new ChatMessage(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } }, + new ChatMessage(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } }, + }; + + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages, [new ChatMessage(ChatRole.Assistant, "Response")]); + + // Act + await provider.InvokedAsync(context); + + // Wait for eventual consistency + await Task.Delay(100); + + // Assert - ChatHistory message excluded, External + AIContextProvider + Response stored + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); + var messages = (await provider.InvokingAsync(invokingContext)).ToList(); + Assert.Equal(3, messages.Count); + Assert.Equal("External message", messages[0].Text); + Assert.Equal("From context provider", messages[1].Text); + Assert.Equal("Response", messages[2].Text); + } + + [SkippableFact] + [Trait("Category", "CosmosDB")] + public async Task InvokedAsync_CustomStorageInputFilter_OverridesDefaultAsync() + { + // Arrange + this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); + var conversationId = Guid.NewGuid().ToString(); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(conversationId)) + { + // Custom filter: only store External messages (also exclude AIContextProvider) + StorageInputMessageFilter = messages => messages.Where(m => m.GetAgentRequestMessageSourceType() == AgentRequestMessageSourceType.External) + }; + + var requestMessages = new[] + { + new ChatMessage(ChatRole.User, "External message"), + new ChatMessage(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } }, + new ChatMessage(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } }, + }; + + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages, [new ChatMessage(ChatRole.Assistant, "Response")]); + + // Act + await provider.InvokedAsync(context); + + // Wait for eventual consistency + await Task.Delay(100); + + // Assert - Custom filter: only External + Response stored (both ChatHistory and AIContextProvider excluded) + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); + var messages = (await provider.InvokingAsync(invokingContext)).ToList(); + Assert.Equal(2, messages.Count); + Assert.Equal("External message", messages[0].Text); + Assert.Equal("Response", messages[1].Text); + } + + [SkippableFact] + [Trait("Category", "CosmosDB")] + public async Task InvokingAsync_RetrievalOutputFilter_FiltersRetrievedMessagesAsync() + { + // Arrange + this.SkipIfEmulatorNotAvailable(); + var session = CreateMockSession(); + var conversationId = Guid.NewGuid().ToString(); + using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, + _ => new CosmosChatHistoryProvider.State(conversationId)) + { + // Only return User messages when retrieving + RetrievalOutputMessageFilter = messages => messages.Where(m => m.Role == ChatRole.User) + }; + + var requestMessages = new[] + { + new ChatMessage(ChatRole.User, "User message"), + new ChatMessage(ChatRole.System, "System message"), + }; + + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, session, requestMessages, [new ChatMessage(ChatRole.Assistant, "Assistant response")]); + + await provider.InvokedAsync(context); + + // Wait for eventual consistency + await Task.Delay(100); + + // Act + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, session, []); + var messages = (await provider.InvokingAsync(invokingContext)).ToList(); + + // Assert - Only User messages returned (System and Assistant filtered by RetrievalOutputMessageFilter) + Assert.Single(messages); + Assert.Equal("User message", messages[0].Text); + Assert.Equal(ChatRole.User, messages[0].Role); + } + + #endregion } diff --git a/dotnet/tests/Microsoft.Agents.AI.DurableTask.UnitTests/DurableAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.DurableTask.UnitTests/DurableAgentSessionTests.cs index 4bf8ebc718..bc06c35ab8 100644 --- a/dotnet/tests/Microsoft.Agents.AI.DurableTask.UnitTests/DurableAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.DurableTask.UnitTests/DurableAgentSessionTests.cs @@ -15,7 +15,7 @@ public void BuiltInSerialization() JsonElement serializedSession = session.Serialize(); // Expected format: "{\"sessionId\":\"@dafx-test-agent@\"}" - string expectedSerializedSession = $"{{\"sessionId\":\"@dafx-{sessionId.Name}@{sessionId.Key}\"}}"; + string expectedSerializedSession = $"{{\"sessionId\":\"@dafx-{sessionId.Name}@{sessionId.Key}\",\"stateBag\":{{}}}}"; Assert.Equal(expectedSerializedSession, serializedSession.ToString()); DurableAgentSession deserializedSession = DurableAgentSession.Deserialize(serializedSession); @@ -33,11 +33,47 @@ public void STJSerialization() string serializedSession = JsonSerializer.Serialize(session, typeof(DurableAgentSession)); // Expected format: "{\"sessionId\":\"@dafx-test-agent@\"}" - string expectedSerializedSession = $"{{\"sessionId\":\"@dafx-{sessionId.Name}@{sessionId.Key}\"}}"; + string expectedSerializedSession = $"{{\"sessionId\":\"@dafx-{sessionId.Name}@{sessionId.Key}\",\"stateBag\":{{}}}}"; Assert.Equal(expectedSerializedSession, serializedSession); DurableAgentSession? deserializedSession = JsonSerializer.Deserialize(serializedSession); Assert.NotNull(deserializedSession); Assert.Equal(sessionId, deserializedSession.SessionId); } + + [Fact] + public void BuiltInSerialization_RoundTrip_PreservesStateBag() + { + // Arrange + AgentSessionId sessionId = AgentSessionId.WithRandomKey("test-agent"); + DurableAgentSession session = new(sessionId); + session.StateBag.SetValue("durableKey", "durableValue"); + + // Act + JsonElement serializedSession = session.Serialize(); + DurableAgentSession deserializedSession = DurableAgentSession.Deserialize(serializedSession); + + // Assert + Assert.Equal(sessionId, deserializedSession.SessionId); + Assert.True(deserializedSession.StateBag.TryGetValue("durableKey", out var value)); + Assert.Equal("durableValue", value); + } + + [Fact] + public void STJSerialization_RoundTrip_PreservesStateBag() + { + // Arrange + AgentSessionId sessionId = AgentSessionId.WithRandomKey("test-agent"); + DurableAgentSession session = new(sessionId); + session.StateBag.SetValue("stjKey", "stjValue"); + + // Act + string serializedSession = JsonSerializer.Serialize(session, typeof(DurableAgentSession)); + DurableAgentSession? deserializedSession = JsonSerializer.Deserialize(serializedSession); + + // Assert + Assert.NotNull(deserializedSession); + Assert.True(deserializedSession.StateBag.TryGetValue("stjKey", out var value)); + Assert.Equal("stjValue", value); + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs index 03dfe63d99..1be1eddef1 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/BasicStreamingTests.cs @@ -7,6 +7,7 @@ using System.Net.Http; using System.Runtime.CompilerServices; using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using FluentAssertions; @@ -281,10 +282,10 @@ internal sealed class FakeChatClientAgent : AIAgent public override string? Description => "A fake agent for testing"; protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => - new(new FakeInMemoryAgentSession()); + new(new FakeAgentSession()); protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => - new(new FakeInMemoryAgentSession(serializedState, jsonSerializerOptions)); + new(serializedState.Deserialize(jsonSerializerOptions)!); protected override ValueTask SerializeSessionCoreAsync(AgentSession session, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => throw new NotImplementedException(); @@ -326,15 +327,14 @@ protected override async IAsyncEnumerable RunCoreStreamingA } } - private sealed class FakeInMemoryAgentSession : InMemoryAgentSession + private sealed class FakeAgentSession : AgentSession { - public FakeInMemoryAgentSession() - : base() + public FakeAgentSession() { } - public FakeInMemoryAgentSession(JsonElement serializedSession, JsonSerializerOptions? jsonSerializerOptions = null) - : base(serializedSession, jsonSerializerOptions) + [JsonConstructor] + public FakeAgentSession(AgentSessionStateBag stateBag) : base(stateBag) { } } @@ -348,19 +348,19 @@ internal sealed class FakeMultiMessageAgent : AIAgent public override string? Description => "A fake agent that sends multiple messages for testing"; protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => - new(new FakeInMemoryAgentSession()); + new(new FakeAgentSession()); protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => - new(new FakeInMemoryAgentSession(serializedState, jsonSerializerOptions)); + new(serializedState.Deserialize(jsonSerializerOptions)!); protected override ValueTask SerializeSessionCoreAsync(AgentSession session, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) { - if (session is not FakeInMemoryAgentSession fakeSession) + if (session is not FakeAgentSession fakeSession) { throw new InvalidOperationException("The provided session is not compatible with the agent. Only sessions created by the agent can be serialized."); } - return new(fakeSession.Serialize(jsonSerializerOptions)); + return new(JsonSerializer.SerializeToElement(fakeSession, jsonSerializerOptions)); } protected override async Task RunCoreAsync( @@ -427,19 +427,16 @@ protected override async IAsyncEnumerable RunCoreStreamingA } } - private sealed class FakeInMemoryAgentSession : InMemoryAgentSession + private sealed class FakeAgentSession : AgentSession { - public FakeInMemoryAgentSession() - : base() + public FakeAgentSession() { } - public FakeInMemoryAgentSession(JsonElement serializedSession, JsonSerializerOptions? jsonSerializerOptions = null) - : base(serializedSession, jsonSerializerOptions) + [JsonConstructor] + public FakeAgentSession(AgentSessionStateBag stateBag) : base(stateBag) { } - internal new JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => base.Serialize(jsonSerializerOptions); } public override object? GetService(Type serviceType, object? serviceKey = null) => null; diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/ForwardedPropertiesTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/ForwardedPropertiesTests.cs index 67108676ac..844900372b 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/ForwardedPropertiesTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/ForwardedPropertiesTests.cs @@ -9,6 +9,7 @@ using System.Runtime.CompilerServices; using System.Text; using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using FluentAssertions; @@ -335,35 +336,31 @@ protected override async IAsyncEnumerable RunCoreStreamingA } protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => - new(new FakeInMemoryAgentSession()); + new(new FakeAgentSession()); protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => - new(new FakeInMemoryAgentSession(serializedState, jsonSerializerOptions)); + new(serializedState.Deserialize(jsonSerializerOptions)!); protected override ValueTask SerializeSessionCoreAsync(AgentSession session, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) { - if (session is not FakeInMemoryAgentSession fakeSession) + if (session is not FakeAgentSession fakeSession) { throw new InvalidOperationException("The provided session is not compatible with the agent. Only sessions created by the agent can be serialized."); } - return new(fakeSession.Serialize(jsonSerializerOptions)); + return new(JsonSerializer.SerializeToElement(fakeSession, jsonSerializerOptions)); } - private sealed class FakeInMemoryAgentSession : InMemoryAgentSession + private sealed class FakeAgentSession : AgentSession { - public FakeInMemoryAgentSession() - : base() + public FakeAgentSession() { } - public FakeInMemoryAgentSession(JsonElement serializedSession, JsonSerializerOptions? jsonSerializerOptions = null) - : base(serializedSession, jsonSerializerOptions) + [JsonConstructor] + public FakeAgentSession(AgentSessionStateBag stateBag) : base(stateBag) { } - - internal new JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => base.Serialize(jsonSerializerOptions); } public override object? GetService(Type serviceType, object? serviceKey = null) => null; diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/SharedStateTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/SharedStateTests.cs index 9ff3dde1a4..78441254d9 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/SharedStateTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.IntegrationTests/SharedStateTests.cs @@ -7,6 +7,7 @@ using System.Net.Http; using System.Runtime.CompilerServices; using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using FluentAssertions; @@ -418,35 +419,31 @@ stateObj is JsonElement state && } protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => - new(new FakeInMemoryAgentSession()); + new(new FakeAgentSession()); protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => - new(new FakeInMemoryAgentSession(serializedState, jsonSerializerOptions)); + new(serializedState.Deserialize(jsonSerializerOptions)!); protected override ValueTask SerializeSessionCoreAsync(AgentSession session, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) { - if (session is not FakeInMemoryAgentSession fakeSession) + if (session is not FakeAgentSession fakeSession) { throw new InvalidOperationException("The provided session is not compatible with the agent. Only sessions created by the agent can be serialized."); } - return new(fakeSession.Serialize(jsonSerializerOptions)); + return new(JsonSerializer.SerializeToElement(fakeSession, jsonSerializerOptions)); } - private sealed class FakeInMemoryAgentSession : InMemoryAgentSession + private sealed class FakeAgentSession : AgentSession { - public FakeInMemoryAgentSession() - : base() + public FakeAgentSession() { } - public FakeInMemoryAgentSession(JsonElement serializedSession, JsonSerializerOptions? jsonSerializerOptions = null) - : base(serializedSession, jsonSerializerOptions) + [JsonConstructor] + public FakeAgentSession(AgentSessionStateBag stateBag) : base(stateBag) { } - - internal new JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => base.Serialize(jsonSerializerOptions); } public override object? GetService(Type serviceType, object? serviceKey = null) => null; diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs index 16efbd0e6e..383a7d1062 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.UnitTests/AGUIEndpointRouteBuilderExtensionsTests.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Text; using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using Microsoft.Agents.AI.Hosting.AGUI.AspNetCore.Shared; @@ -426,19 +427,19 @@ private sealed class MultiResponseAgent : AIAgent public override string? Description => "Agent that produces multiple text chunks"; protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => - new(new TestInMemoryAgentSession()); + new(new TestAgentSession()); protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => - new(new TestInMemoryAgentSession(serializedState, jsonSerializerOptions)); + new(serializedState.Deserialize(jsonSerializerOptions)!); protected override ValueTask SerializeSessionCoreAsync(AgentSession session, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) { - if (session is not TestInMemoryAgentSession testSession) + if (session is not TestAgentSession testSession) { throw new InvalidOperationException("The provided session is not compatible with the agent. Only sessions created by the agent can be serialized."); } - return new(testSession.Serialize(jsonSerializerOptions)); + return new(JsonSerializer.SerializeToElement(testSession, jsonSerializerOptions)); } protected override Task RunCoreAsync(IEnumerable messages, AgentSession? session = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) @@ -506,20 +507,16 @@ private RequestDelegate CreateRequestDelegate( }; } - private sealed class TestInMemoryAgentSession : InMemoryAgentSession + private sealed class TestAgentSession : AgentSession { - public TestInMemoryAgentSession() - : base() + public TestAgentSession() { } - public TestInMemoryAgentSession(JsonElement serializedSessionState, JsonSerializerOptions? jsonSerializerOptions = null) - : base(serializedSessionState, jsonSerializerOptions, null) + [JsonConstructor] + public TestAgentSession(AgentSessionStateBag stateBag) : base(stateBag) { } - - internal new JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - => base.Serialize(jsonSerializerOptions); } private sealed class TestAgent : AIAgent @@ -529,19 +526,19 @@ private sealed class TestAgent : AIAgent public override string? Description => "Test agent"; protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => - new(new TestInMemoryAgentSession()); + new(new TestAgentSession()); protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) => - new(new TestInMemoryAgentSession(serializedState, jsonSerializerOptions)); + new(serializedState.Deserialize(jsonSerializerOptions)!); protected override ValueTask SerializeSessionCoreAsync(AgentSession session, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) { - if (session is not TestInMemoryAgentSession testSession) + if (session is not TestAgentSession testSession) { throw new InvalidOperationException("The provided session is not compatible with the agent. Only sessions created by the agent can be serialized."); } - return new(testSession.Serialize(jsonSerializerOptions)); + return new(JsonSerializer.SerializeToElement(testSession, jsonSerializerOptions)); } protected override Task RunCoreAsync(IEnumerable messages, AgentSession? session = null, AgentRunOptions? options = null, CancellationToken cancellationToken = default) diff --git a/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs index a10f1246aa..72e9f6bdff 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.Collections.Generic; +using System.Linq; using System.Net.Http; using System.Net.Http.Headers; using System.Threading; @@ -19,7 +21,6 @@ public sealed class Mem0ProviderTests : IDisposable private const string SkipReason = "Requires a Mem0 service configured"; // Set to null to enable. private static readonly AIAgent s_mockAgent = new Moq.Mock().Object; - private static readonly AgentSession s_mockSession = new Moq.Mock().Object; private readonly HttpClient _httpClient; @@ -49,21 +50,22 @@ public async Task CanAddAndRetrieveUserMemoriesAsync() var question = new ChatMessage(ChatRole.User, "What is my name?"); var input = new ChatMessage(ChatRole.User, "Hello, my name is Caoimhe."); var storageScope = new Mem0ProviderScope { ThreadId = "it-thread-1", UserId = "it-user-1" }; - var sut = new Mem0Provider(this._httpClient, storageScope); + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope)); - await sut.ClearStoredMemoriesAsync(); - var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); - Assert.DoesNotContain("Caoimhe", ctxBefore.Messages?[0].Text ?? string.Empty); + await sut.ClearStoredMemoriesAsync(mockSession); + var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, mockSession, new AIContext { Messages = new List { question } })); + Assert.DoesNotContain("Caoimhe", ctxBefore.Messages?.LastOrDefault()?.Text ?? string.Empty); // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [input])); - var ctxAfterAdding = await GetContextWithRetryAsync(sut, question); - await sut.ClearStoredMemoriesAsync(); - var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession, [input], [])); + var ctxAfterAdding = await GetContextWithRetryAsync(sut, mockSession, question); + await sut.ClearStoredMemoriesAsync(mockSession); + var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, mockSession, new AIContext { Messages = new List { question } })); // Assert - Assert.Contains("Caoimhe", ctxAfterAdding.Messages?[0].Text ?? string.Empty); - Assert.DoesNotContain("Caoimhe", ctxAfterClearing.Messages?[0].Text ?? string.Empty); + Assert.Contains("Caoimhe", ctxAfterAdding.Messages?.LastOrDefault()?.Text ?? string.Empty); + Assert.DoesNotContain("Caoimhe", ctxAfterClearing.Messages?.LastOrDefault()?.Text ?? string.Empty); } [Fact(Skip = SkipReason)] @@ -73,21 +75,22 @@ public async Task CanAddAndRetrieveAgentMemoriesAsync() var question = new ChatMessage(ChatRole.User, "What is your name?"); var assistantIntro = new ChatMessage(ChatRole.Assistant, "Hello, I'm a friendly assistant and my name is Caoimhe."); var storageScope = new Mem0ProviderScope { AgentId = "it-agent-1" }; - var sut = new Mem0Provider(this._httpClient, storageScope); + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope)); - await sut.ClearStoredMemoriesAsync(); - var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); - Assert.DoesNotContain("Caoimhe", ctxBefore.Messages?[0].Text ?? string.Empty); + await sut.ClearStoredMemoriesAsync(mockSession); + var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, mockSession, new AIContext { Messages = new List { question } })); + Assert.DoesNotContain("Caoimhe", ctxBefore.Messages?.LastOrDefault()?.Text ?? string.Empty); // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [assistantIntro])); - var ctxAfterAdding = await GetContextWithRetryAsync(sut, question); - await sut.ClearStoredMemoriesAsync(); - var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession, [assistantIntro], [])); + var ctxAfterAdding = await GetContextWithRetryAsync(sut, mockSession, question); + await sut.ClearStoredMemoriesAsync(mockSession); + var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, mockSession, new AIContext { Messages = new List { question } })); // Assert - Assert.Contains("Caoimhe", ctxAfterAdding.Messages?[0].Text ?? string.Empty); - Assert.DoesNotContain("Caoimhe", ctxAfterClearing.Messages?[0].Text ?? string.Empty); + Assert.Contains("Caoimhe", ctxAfterAdding.Messages?.LastOrDefault()?.Text ?? string.Empty); + Assert.DoesNotContain("Caoimhe", ctxAfterClearing.Messages?.LastOrDefault()?.Text ?? string.Empty); } [Fact(Skip = SkipReason)] @@ -96,38 +99,42 @@ public async Task DoesNotLeakMemoriesAcrossAgentScopesAsync() // Arrange var question = new ChatMessage(ChatRole.User, "What is your name?"); var assistantIntro = new ChatMessage(ChatRole.Assistant, "I'm an AI tutor and my name is Caoimhe."); - var sut1 = new Mem0Provider(this._httpClient, new Mem0ProviderScope { AgentId = "it-agent-a" }); - var sut2 = new Mem0Provider(this._httpClient, new Mem0ProviderScope { AgentId = "it-agent-b" }); + var storageScope1 = new Mem0ProviderScope { AgentId = "it-agent-a" }; + var storageScope2 = new Mem0ProviderScope { AgentId = "it-agent-b" }; + var mockSession1 = new TestAgentSession(); + var mockSession2 = new TestAgentSession(); + var sut1 = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope1)); + var sut2 = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope2)); - await sut1.ClearStoredMemoriesAsync(); - await sut2.ClearStoredMemoriesAsync(); + await sut1.ClearStoredMemoriesAsync(mockSession1); + await sut2.ClearStoredMemoriesAsync(mockSession2); - var ctxBefore1 = await sut1.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); - var ctxBefore2 = await sut2.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); - Assert.DoesNotContain("Caoimhe", ctxBefore1.Messages?[0].Text ?? string.Empty); - Assert.DoesNotContain("Caoimhe", ctxBefore2.Messages?[0].Text ?? string.Empty); + var ctxBefore1 = await sut1.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, mockSession1, new AIContext { Messages = new List { question } })); + var ctxBefore2 = await sut2.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, mockSession2, new AIContext { Messages = new List { question } })); + Assert.DoesNotContain("Caoimhe", ctxBefore1.Messages?.LastOrDefault()?.Text ?? string.Empty); + Assert.DoesNotContain("Caoimhe", ctxBefore2.Messages?.LastOrDefault()?.Text ?? string.Empty); // Act - await sut1.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [assistantIntro])); - var ctxAfterAdding1 = await GetContextWithRetryAsync(sut1, question); - var ctxAfterAdding2 = await GetContextWithRetryAsync(sut2, question); + await sut1.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession1, [assistantIntro], [])); + var ctxAfterAdding1 = await GetContextWithRetryAsync(sut1, mockSession1, question); + var ctxAfterAdding2 = await GetContextWithRetryAsync(sut2, mockSession2, question); // Assert - Assert.Contains("Caoimhe", ctxAfterAdding1.Messages?[0].Text ?? string.Empty); - Assert.DoesNotContain("Caoimhe", ctxAfterAdding2.Messages?[0].Text ?? string.Empty); + Assert.Contains("Caoimhe", ctxAfterAdding1.Messages?.LastOrDefault()?.Text ?? string.Empty); + Assert.DoesNotContain("Caoimhe", ctxAfterAdding2.Messages?.LastOrDefault()?.Text ?? string.Empty); // Cleanup - await sut1.ClearStoredMemoriesAsync(); - await sut2.ClearStoredMemoriesAsync(); + await sut1.ClearStoredMemoriesAsync(mockSession1); + await sut2.ClearStoredMemoriesAsync(mockSession2); } - private static async Task GetContextWithRetryAsync(Mem0Provider provider, ChatMessage question, int attempts = 5, int delayMs = 1000) + private static async Task GetContextWithRetryAsync(Mem0Provider provider, AgentSession session, ChatMessage question, int attempts = 5, int delayMs = 1000) { AIContext? ctx = null; for (int i = 0; i < attempts; i++) { - ctx = await provider.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question]), CancellationToken.None); - var text = ctx.Messages?[0].Text; + ctx = await provider.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, session, new AIContext { Messages = new List { question } }), CancellationToken.None); + var text = ctx.Messages?.LastOrDefault()?.Text; if (!string.IsNullOrEmpty(text) && text.IndexOf("Caoimhe", StringComparison.OrdinalIgnoreCase) >= 0) { break; @@ -141,4 +148,12 @@ public void Dispose() { this._httpClient.Dispose(); } + + private sealed class TestAgentSession : AgentSession + { + public TestAgentSession() + { + this.StateBag = new AgentSessionStateBag(); + } + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs index 53c87b09ba..7ad77b7df5 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs @@ -19,7 +19,6 @@ namespace Microsoft.Agents.AI.Mem0.UnitTests; public sealed class Mem0ProviderTests : IDisposable { private static readonly AIAgent s_mockAgent = new Mock().Object; - private static readonly AgentSession s_mockSession = new Mock().Object; private readonly Mock> _loggerMock; private readonly Mock _loggerFactoryMock; @@ -55,35 +54,39 @@ public void Constructor_Throws_WhenBaseAddressMissing() using HttpClient client = new(); // Act & Assert - var ex = Assert.Throws(() => new Mem0Provider(client, new Mem0ProviderScope() { ThreadId = "tid" })); + var ex = Assert.Throws(() => new Mem0Provider(client, _ => new Mem0Provider.State(new Mem0ProviderScope { ThreadId = "tid" }))); Assert.StartsWith("The HttpClient BaseAddress must be set for Mem0 operations.", ex.Message); } [Fact] - public void Constructor_Throws_WhenNoStorageScopeValueIsSet() + public void Constructor_Throws_WhenStateInitializerIsNull() { // Act & Assert - var ex = Assert.Throws(() => new Mem0Provider(this._httpClient, new Mem0ProviderScope())); - Assert.StartsWith("At least one of ApplicationId, AgentId, ThreadId, or UserId must be provided for the storage scope.", ex.Message); + var ex = Assert.Throws(() => new Mem0Provider(this._httpClient, null!)); + Assert.Contains("stateInitializer", ex.Message); } [Fact] - public void Constructor_Throws_WhenNoSearchScopeValueIsSet() + public void StateKey_ReturnsDefaultKey_WhenNoOptionsProvided() { - // Act & Assert - var ex = Assert.Throws(() => new Mem0Provider(this._httpClient, new Mem0ProviderScope() { ThreadId = "tid" }, new Mem0ProviderScope())); - Assert.StartsWith("At least one of ApplicationId, AgentId, ThreadId, or UserId must be provided for the search scope.", ex.Message); + // Arrange & Act + var provider = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(new Mem0ProviderScope { ThreadId = "tid" })); + + // Assert + Assert.Equal("Mem0Provider", provider.StateKey); } [Fact] - public void DeserializingConstructor_Throws_WithEmptyJsonElement() + public void StateKey_ReturnsCustomKey_WhenSetViaOptions() { - // Arrange - var jsonElement = JsonSerializer.SerializeToElement(new object(), Mem0JsonUtilities.DefaultOptions); + // Arrange & Act + var provider = new Mem0Provider( + this._httpClient, + _ => new Mem0Provider.State(new Mem0ProviderScope { ThreadId = "tid" }), + new Mem0ProviderOptions { StateKey = "custom-key" }); - // Act & Assert - var ex = Assert.Throws(() => new Mem0Provider(this._httpClient, jsonElement)); - Assert.StartsWith("The Mem0Provider state did not contain the required scope properties.", ex.Message); + // Assert + Assert.Equal("custom-key", provider.StateKey); } [Fact] @@ -98,8 +101,9 @@ public async Task InvokingAsync_PerformsSearch_AndReturnsContextMessageAsync() ThreadId = "session", UserId = "user" }; - var sut = new Mem0Provider(this._httpClient, storageScope, options: new() { EnableSensitiveTelemetryData = true }, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "What is my name?")]); + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), options: new() { EnableSensitiveTelemetryData = true }, loggerFactory: this._loggerFactoryMock.Object); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, mockSession, new AIContext { Messages = new List { new(ChatRole.User, "What is my name?") } }); // Act var aiContext = await sut.InvokingAsync(invokingContext); @@ -114,9 +118,13 @@ public async Task InvokingAsync_PerformsSearch_AndReturnsContextMessageAsync() Assert.Equal("What is my name?", doc.RootElement.GetProperty("query").GetString()); Assert.NotNull(aiContext.Messages); - var contextMessage = Assert.Single(aiContext.Messages); + var messages = aiContext.Messages.ToList(); + Assert.Equal(2, messages.Count); + Assert.Equal(AgentRequestMessageSourceType.External, messages[0].GetAgentRequestMessageSourceType()); + var contextMessage = messages[1]; Assert.Equal(ChatRole.User, contextMessage.Role); Assert.Contains("Name is Caoimhe", contextMessage.Text); + Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, contextMessage.GetAgentRequestMessageSourceType()); this._loggerMock.Verify( l => l.Log( @@ -162,9 +170,10 @@ public async Task InvokingAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsy UserId = "user" }; var options = new Mem0ProviderOptions { EnableSensitiveTelemetryData = enableSensitiveTelemetryData }; + var mockSession = new TestAgentSession(); - var sut = new Mem0Provider(this._httpClient, storageScope, options: options, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Who am I?")]); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), options: options, loggerFactory: this._loggerFactoryMock.Object); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, mockSession, new AIContext { Messages = new List { new(ChatRole.User, "Who am I?") } }); // Act await sut.InvokingAsync(invokingContext, CancellationToken.None); @@ -204,7 +213,8 @@ public async Task InvokedAsync_PersistsAllowedMessagesAsync() this._handler.EnqueueEmptyOk(); // For second CreateMemory this._handler.EnqueueEmptyOk(); // For third CreateMemory var storageScope = new Mem0ProviderScope { ApplicationId = "a", AgentId = "b", ThreadId = "c", UserId = "d" }; - var sut = new Mem0Provider(this._httpClient, storageScope); + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope)); var requestMessages = new List { @@ -218,7 +228,7 @@ public async Task InvokedAsync_PersistsAllowedMessagesAsync() }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) { ResponseMessages = responseMessages }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession, requestMessages, responseMessages)); // Assert var memoryPosts = this._handler.Requests.Where(r => r.RequestMessage.RequestUri!.AbsolutePath == "/v1/memories/" && r.RequestMessage.Method == HttpMethod.Post).ToList(); @@ -235,7 +245,8 @@ public async Task InvokedAsync_PersistsNothingForFailedRequestAsync() { // Arrange var storageScope = new Mem0ProviderScope { ApplicationId = "a", AgentId = "b", ThreadId = "c", UserId = "d" }; - var sut = new Mem0Provider(this._httpClient, storageScope); + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope)); var requestMessages = new List { @@ -245,7 +256,7 @@ public async Task InvokedAsync_PersistsNothingForFailedRequestAsync() }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) { ResponseMessages = null, InvokeException = new InvalidOperationException("Request Failed") }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession, requestMessages, new InvalidOperationException("Request Failed"))); // Assert Assert.Empty(this._handler.Requests); @@ -256,7 +267,8 @@ public async Task InvokedAsync_ShouldNotThrow_WhenStorageFailsAsync() { // Arrange var storageScope = new Mem0ProviderScope { ApplicationId = "a", AgentId = "b", ThreadId = "c", UserId = "d" }; - var sut = new Mem0Provider(this._httpClient, storageScope, loggerFactory: this._loggerFactoryMock.Object); + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), loggerFactory: this._loggerFactoryMock.Object); this._handler.EnqueueEmptyInternalServerError(); var requestMessages = new List @@ -271,7 +283,7 @@ public async Task InvokedAsync_ShouldNotThrow_WhenStorageFailsAsync() }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) { ResponseMessages = responseMessages }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession, requestMessages, responseMessages)); // Assert this._loggerMock.Verify( @@ -310,7 +322,8 @@ public async Task InvokedAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsyn }; var options = new Mem0ProviderOptions { EnableSensitiveTelemetryData = enableSensitiveTelemetryData }; - var sut = new Mem0Provider(this._httpClient, storageScope, options: options, loggerFactory: this._loggerFactoryMock.Object); + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), options: options, loggerFactory: this._loggerFactoryMock.Object); var requestMessages = new List { new(ChatRole.User, "User text") @@ -321,7 +334,7 @@ public async Task InvokedAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsyn }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages) { ResponseMessages = responseMessages }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession, requestMessages, responseMessages)); // Assert Assert.Equal(expectedLogCount, this._loggerMock.Invocations.Count); @@ -343,11 +356,12 @@ public async Task ClearStoredMemoriesAsync_SendsDeleteWithQueryAsync() { // Arrange var storageScope = new Mem0ProviderScope { ApplicationId = "app", AgentId = "agent", ThreadId = "session", UserId = "user" }; - var sut = new Mem0Provider(this._httpClient, storageScope); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope)); this._handler.EnqueueEmptyOk(); // for DELETE + var mockSession = new TestAgentSession(); // Act - await sut.ClearStoredMemoriesAsync(); + await sut.ClearStoredMemoriesAsync(mockSession); // Assert var delete = Assert.Single(this._handler.Requests, r => r.RequestMessage.Method == HttpMethod.Delete); @@ -355,70 +369,182 @@ public async Task ClearStoredMemoriesAsync_SendsDeleteWithQueryAsync() } [Fact] - public void Serialize_RoundTripsScopes() + public async Task InvokingAsync_ShouldNotThrow_WhenSearchFailsAsync() { // Arrange - var storageScope = new Mem0ProviderScope { ApplicationId = "app", AgentId = "agent", ThreadId = "session", UserId = "user" }; - var sut = new Mem0Provider(this._httpClient, storageScope, options: new() { ContextPrompt = "Custom:" }, loggerFactory: this._loggerFactoryMock.Object); + var storageScope = new Mem0ProviderScope { ApplicationId = "app" }; + var mockSession = new TestAgentSession(); + var provider = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), loggerFactory: this._loggerFactoryMock.Object); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, mockSession, new AIContext { Messages = new List { new(ChatRole.User, "Q?") } }); // Act - var stateElement = sut.Serialize(); - using JsonDocument doc = JsonDocument.Parse(stateElement.GetRawText()); - var storageScopeElement = doc.RootElement.GetProperty("storageScope"); - Assert.Equal("app", storageScopeElement.GetProperty("applicationId").GetString()); - Assert.Equal("agent", storageScopeElement.GetProperty("agentId").GetString()); - Assert.Equal("session", storageScopeElement.GetProperty("threadId").GetString()); - Assert.Equal("user", storageScopeElement.GetProperty("userId").GetString()); - - var sut2 = new Mem0Provider(this._httpClient, stateElement); - var stateElement2 = sut2.Serialize(); + var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); // Assert - using JsonDocument doc2 = JsonDocument.Parse(stateElement2.GetRawText()); - var storageScopeElement2 = doc2.RootElement.GetProperty("storageScope"); - Assert.Equal("app", storageScopeElement2.GetProperty("applicationId").GetString()); - Assert.Equal("agent", storageScopeElement2.GetProperty("agentId").GetString()); - Assert.Equal("session", storageScopeElement2.GetProperty("threadId").GetString()); - Assert.Equal("user", storageScopeElement2.GetProperty("userId").GetString()); + Assert.NotNull(aiContext.Messages); + Assert.Single(aiContext.Messages); + Assert.Null(aiContext.Tools); + this._loggerMock.Verify( + l => l.Log( + LogLevel.Error, + It.IsAny(), + It.Is((v, t) => v.ToString()!.Contains("Mem0AIContextProvider: Failed to search Mem0 for memories due to error")), + It.IsAny(), + It.IsAny>()), + Times.Once); } [Fact] - public void Serialize_DoesNotIncludeDefaultContextPrompt() + public async Task StateInitializer_IsCalledOnceAndStoredInStateBagAsync() { // Arrange + this._handler.EnqueueJsonResponse("[]"); + this._handler.EnqueueJsonResponse("[]"); var storageScope = new Mem0ProviderScope { ApplicationId = "app" }; - var sut = new Mem0Provider(this._httpClient, storageScope); + var mockSession = new TestAgentSession(); + int initializerCallCount = 0; + var sut = new Mem0Provider(this._httpClient, _ => + { + initializerCallCount++; + return new Mem0Provider.State(storageScope); + }); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, mockSession, new AIContext { Messages = new List { new(ChatRole.User, "Q?") } }); // Act - var stateElement = sut.Serialize(); + await sut.InvokingAsync(invokingContext, CancellationToken.None); + await sut.InvokingAsync(invokingContext, CancellationToken.None); // Assert - using JsonDocument doc = JsonDocument.Parse(stateElement.GetRawText()); - Assert.False(doc.RootElement.TryGetProperty("contextPrompt", out _)); + Assert.Equal(1, initializerCallCount); } [Fact] - public async Task InvokingAsync_ShouldNotThrow_WhenSearchFailsAsync() + public async Task StateKey_CanBeConfiguredViaOptionsAsync() { // Arrange + this._handler.EnqueueJsonResponse("[]"); var storageScope = new Mem0ProviderScope { ApplicationId = "app" }; - var provider = new Mem0Provider(this._httpClient, storageScope, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); + var mockSession = new TestAgentSession(); + const string CustomKey = "MyCustomKey"; + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), options: new() { StateKey = CustomKey }); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, mockSession, new AIContext { Messages = new List { new(ChatRole.User, "Q?") } }); // Act - var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); + await sut.InvokingAsync(invokingContext, CancellationToken.None); // Assert - Assert.Null(aiContext.Messages); - Assert.Null(aiContext.Tools); - this._loggerMock.Verify( - l => l.Log( - LogLevel.Error, - It.IsAny(), - It.Is((v, t) => v.ToString()!.Contains("Mem0AIContextProvider: Failed to search Mem0 for memories due to error")), - It.IsAny(), - It.IsAny>()), - Times.Once); + Assert.True(mockSession.StateBag.TryGetValue(CustomKey, out var state, Mem0JsonUtilities.DefaultOptions)); + Assert.NotNull(state); + } + + [Fact] + public async Task InvokingAsync_DefaultFilter_ExcludesNonExternalMessagesFromSearchAsync() + { + // Arrange + this._handler.EnqueueJsonResponse("[]"); // Empty search results + var storageScope = new Mem0ProviderScope { ApplicationId = "app", AgentId = "agent", ThreadId = "session", UserId = "user" }; + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope)); + + var requestMessages = new List + { + new(ChatRole.User, "External message"), + new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } }, + new(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } }, + }; + + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, mockSession, new AIContext { Messages = requestMessages }); + + // Act + await sut.InvokingAsync(invokingContext, CancellationToken.None); + + // Assert - Search query should only contain the External message + var searchRequest = Assert.Single(this._handler.Requests, r => r.RequestMessage.Method == HttpMethod.Post); + using JsonDocument doc = JsonDocument.Parse(searchRequest.RequestBody); + Assert.Equal("External message", doc.RootElement.GetProperty("query").GetString()); + } + + [Fact] + public async Task InvokingAsync_CustomSearchInputFilter_OverridesDefaultAsync() + { + // Arrange + this._handler.EnqueueJsonResponse("[]"); // Empty search results + var storageScope = new Mem0ProviderScope { ApplicationId = "app", AgentId = "agent", ThreadId = "session", UserId = "user" }; + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), options: new Mem0ProviderOptions + { + SearchInputMessageFilter = messages => messages // No filtering + }); + + var requestMessages = new List + { + new(ChatRole.User, "External message"), + new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } }, + }; + + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, mockSession, new AIContext { Messages = requestMessages }); + + // Act + await sut.InvokingAsync(invokingContext, CancellationToken.None); + + // Assert - Search query should contain all messages (custom identity filter) + var searchRequest = Assert.Single(this._handler.Requests, r => r.RequestMessage.Method == HttpMethod.Post); + using JsonDocument doc = JsonDocument.Parse(searchRequest.RequestBody); + var queryText = doc.RootElement.GetProperty("query").GetString(); + Assert.Contains("External message", queryText); + Assert.Contains("From history", queryText); + } + + [Fact] + public async Task InvokedAsync_DefaultFilter_ExcludesNonExternalMessagesFromStorageAsync() + { + // Arrange + this._handler.EnqueueEmptyOk(); // For the one message that should be stored + var storageScope = new Mem0ProviderScope { ApplicationId = "a", AgentId = "b", ThreadId = "c", UserId = "d" }; + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope)); + + var requestMessages = new List + { + new(ChatRole.User, "External message"), + new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } }, + }; + + // Act + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession, requestMessages, [])); + + // Assert - Only the External message should be persisted + var memoryPosts = this._handler.Requests.Where(r => r.RequestMessage.RequestUri!.AbsolutePath == "/v1/memories/" && r.RequestMessage.Method == HttpMethod.Post).ToList(); + Assert.Single(memoryPosts); + Assert.Contains("External message", memoryPosts[0].RequestBody); + Assert.DoesNotContain(memoryPosts, r => ContainsOrdinal(r.RequestBody, "From history")); + } + + [Fact] + public async Task InvokedAsync_CustomStorageInputFilter_OverridesDefaultAsync() + { + // Arrange + this._handler.EnqueueEmptyOk(); // For first CreateMemory + this._handler.EnqueueEmptyOk(); // For second CreateMemory + var storageScope = new Mem0ProviderScope { ApplicationId = "a", AgentId = "b", ThreadId = "c", UserId = "d" }; + var mockSession = new TestAgentSession(); + var sut = new Mem0Provider(this._httpClient, _ => new Mem0Provider.State(storageScope), options: new Mem0ProviderOptions + { + StorageInputMessageFilter = messages => messages // No filtering - store everything + }); + + var requestMessages = new List + { + new(ChatRole.User, "External message"), + new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } }, + }; + + // Act + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, mockSession, requestMessages, [])); + + // Assert - Both messages should be persisted (identity filter overrides default) + var memoryPosts = this._handler.Requests.Where(r => r.RequestMessage.RequestUri!.AbsolutePath == "/v1/memories/" && r.RequestMessage.Method == HttpMethod.Post).ToList(); + Assert.Equal(2, memoryPosts.Count); } private static bool ContainsOrdinal(string source, string value) => source.IndexOf(value, StringComparison.Ordinal) >= 0; @@ -465,4 +591,12 @@ public void EnqueueJsonResponse(string json) public void EnqueueEmptyInternalServerError() => this._responses.Enqueue(new HttpResponseMessage(System.Net.HttpStatusCode.InternalServerError)); } + + private sealed class TestAgentSession : AgentSession + { + public TestAgentSession() + { + this.StateBag = new AgentSessionStateBag(); + } + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentOptionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentOptionsTests.cs index 8502550d2c..f69fb3d636 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentOptionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentOptionsTests.cs @@ -1,8 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; -using System.Threading; -using System.Threading.Tasks; using Microsoft.Extensions.AI; using Moq; @@ -23,8 +21,8 @@ public void DefaultConstructor_InitializesWithNullValues() Assert.Null(options.Name); Assert.Null(options.Description); Assert.Null(options.ChatOptions); - Assert.Null(options.ChatHistoryProviderFactory); - Assert.Null(options.AIContextProviderFactory); + Assert.Null(options.ChatHistoryProvider); + Assert.Null(options.AIContextProviders); } [Fact] @@ -36,8 +34,8 @@ public void Constructor_WithNullValues_SetsPropertiesCorrectly() // Assert Assert.Null(options.Name); Assert.Null(options.Description); - Assert.Null(options.AIContextProviderFactory); - Assert.Null(options.ChatHistoryProviderFactory); + Assert.Null(options.AIContextProviders); + Assert.Null(options.ChatHistoryProvider); Assert.NotNull(options.ChatOptions); Assert.Null(options.ChatOptions.Instructions); Assert.Null(options.ChatOptions.Tools); @@ -117,11 +115,8 @@ public void Clone_CreatesDeepCopyWithSameValues() const string Description = "Test description"; var tools = new List { AIFunctionFactory.Create(() => "test") }; - static ValueTask ChatHistoryProviderFactoryAsync( - ChatClientAgentOptions.ChatHistoryProviderFactoryContext ctx, CancellationToken ct) => new(new Mock().Object); - - static ValueTask AIContextProviderFactoryAsync( - ChatClientAgentOptions.AIContextProviderFactoryContext ctx, CancellationToken ct) => new(new Mock().Object); + var mockChatHistoryProvider = new Mock().Object; + var mockAIContextProvider = new Mock().Object; var original = new ChatClientAgentOptions() { @@ -129,8 +124,8 @@ static ValueTask AIContextProviderFactoryAsync( Description = Description, ChatOptions = new() { Tools = tools }, Id = "test-id", - ChatHistoryProviderFactory = ChatHistoryProviderFactoryAsync, - AIContextProviderFactory = AIContextProviderFactoryAsync + ChatHistoryProvider = mockChatHistoryProvider, + AIContextProviders = [mockAIContextProvider] }; // Act @@ -141,8 +136,8 @@ static ValueTask AIContextProviderFactoryAsync( Assert.Equal(original.Id, clone.Id); Assert.Equal(original.Name, clone.Name); Assert.Equal(original.Description, clone.Description); - Assert.Same(original.ChatHistoryProviderFactory, clone.ChatHistoryProviderFactory); - Assert.Same(original.AIContextProviderFactory, clone.AIContextProviderFactory); + Assert.Same(original.ChatHistoryProvider, clone.ChatHistoryProvider); + Assert.Equal(original.AIContextProviders, clone.AIContextProviders); // ChatOptions should be cloned, not the same reference Assert.NotSame(original.ChatOptions, clone.ChatOptions); @@ -154,11 +149,16 @@ static ValueTask AIContextProviderFactoryAsync( public void Clone_WithoutProvidingChatOptions_ClonesCorrectly() { // Arrange + var mockChatHistoryProvider = new Mock().Object; + var mockAIContextProvider = new Mock().Object; + var original = new ChatClientAgentOptions { Id = "test-id", Name = "Test name", - Description = "Test description" + Description = "Test description", + ChatHistoryProvider = mockChatHistoryProvider, + AIContextProviders = [mockAIContextProvider] }; // Act @@ -170,8 +170,8 @@ public void Clone_WithoutProvidingChatOptions_ClonesCorrectly() Assert.Equal(original.Name, clone.Name); Assert.Equal(original.Description, clone.Description); Assert.Null(original.ChatOptions); - Assert.Null(clone.ChatHistoryProviderFactory); - Assert.Null(clone.AIContextProviderFactory); + Assert.Same(original.ChatHistoryProvider, clone.ChatHistoryProvider); + Assert.Equal(original.AIContextProviders, clone.AIContextProviders); } private static void AssertSameTools(IList? expected, IList? actual) diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs index fd311f9225..1f5e5aa9cd 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs @@ -1,12 +1,9 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections.Generic; using System.Linq; using System.Text.Json; -using System.Threading.Tasks; using Microsoft.Extensions.AI; -using Moq; #pragma warning disable CA1861 // Avoid constant arrays as arguments @@ -24,7 +21,6 @@ public void ConstructorSetsDefaults() // Assert Assert.Null(session.ConversationId); - Assert.Null(session.ChatHistoryProvider); } [Fact] @@ -39,53 +35,6 @@ public void SetConversationIdRoundtrips() // Assert Assert.Equal(ConversationId, session.ConversationId); - Assert.Null(session.ChatHistoryProvider); - } - - [Fact] - public void SetChatHistoryProviderRoundtrips() - { - // Arrange - var session = new ChatClientAgentSession(); - var chatHistoryProvider = new InMemoryChatHistoryProvider(); - - // Act - session.ChatHistoryProvider = chatHistoryProvider; - - // Assert - Assert.Same(chatHistoryProvider, session.ChatHistoryProvider); - Assert.Null(session.ConversationId); - } - - [Fact] - public void SetConversationIdThrowsWhenChatHistoryProviderIsSet() - { - // Arrange - var session = new ChatClientAgentSession - { - ChatHistoryProvider = new InMemoryChatHistoryProvider() - }; - - // Act & Assert - var exception = Assert.Throws(() => session.ConversationId = "new-session-id"); - Assert.Equal("Only the ConversationId or ChatHistoryProvider may be set, but not both and switching from one to another is not supported.", exception.Message); - Assert.NotNull(session.ChatHistoryProvider); - } - - [Fact] - public void SetChatHistoryProviderThrowsWhenConversationIdIsSet() - { - // Arrange - var session = new ChatClientAgentSession - { - ConversationId = "existing-session-id" - }; - var provider = new InMemoryChatHistoryProvider(); - - // Act & Assert - var exception = Assert.Throws(() => session.ChatHistoryProvider = provider); - Assert.Equal("Only the ConversationId or ChatHistoryProvider may be set, but not both and switching from one to another is not supported.", exception.Message); - Assert.NotNull(session.ConversationId); } #endregion Constructor and Property Tests @@ -93,29 +42,33 @@ public void SetChatHistoryProviderThrowsWhenConversationIdIsSet() #region Deserialize Tests [Fact] - public async Task VerifyDeserializeWithMessagesAsync() + public void VerifyDeserializeWithMessages() { // Arrange var json = JsonSerializer.Deserialize(""" { - "chatHistoryProviderState": { "messages": [{"authorName": "testAuthor"}] } + "stateBag": { + "InMemoryChatHistoryProvider": { + "messages": [{"authorName": "testAuthor"}] + } + } } """, TestJsonSerializerContext.Default.JsonElement); // Act. - var session = await ChatClientAgentSession.DeserializeAsync(json); + var session = ChatClientAgentSession.Deserialize(json, TestJsonSerializerContext.Default.Options); // Assert Assert.Null(session.ConversationId); - var chatHistoryProvider = session.ChatHistoryProvider as InMemoryChatHistoryProvider; - Assert.NotNull(chatHistoryProvider); - Assert.Single(chatHistoryProvider); - Assert.Equal("testAuthor", chatHistoryProvider[0].AuthorName); + var chatHistoryProvider = new InMemoryChatHistoryProvider(); + var messages = chatHistoryProvider.GetMessages(session); + Assert.Single(messages); + Assert.Equal("testAuthor", messages[0].AuthorName); } [Fact] - public async Task VerifyDeserializeWithIdAsync() + public void VerifyDeserializeWithId() { // Arrange var json = JsonSerializer.Deserialize(""" @@ -125,42 +78,43 @@ public async Task VerifyDeserializeWithIdAsync() """, TestJsonSerializerContext.Default.JsonElement); // Act - var session = await ChatClientAgentSession.DeserializeAsync(json); + var session = ChatClientAgentSession.Deserialize(json); // Assert Assert.Equal("TestConvId", session.ConversationId); - Assert.Null(session.ChatHistoryProvider); } [Fact] - public async Task VerifyDeserializeWithAIContextProviderAsync() + public void VerifyDeserializeWithStateBag() { // Arrange var json = JsonSerializer.Deserialize(""" { "conversationId": "TestConvId", - "aiContextProviderState": ["CP1"] + "stateBag": { + "dog": { + "name": "Fido" + } + } } """, TestJsonSerializerContext.Default.JsonElement); - Mock mockProvider = new(); - // Act - var session = await ChatClientAgentSession.DeserializeAsync(json, aiContextProviderFactory: (_, _, _) => new(mockProvider.Object)); + var session = ChatClientAgentSession.Deserialize(json); // Assert - Assert.Null(session.ChatHistoryProvider); - Assert.Same(session.AIContextProvider, mockProvider.Object); + var dog = session.StateBag.GetValue("dog", TestJsonSerializerContext.Default.Options); + Assert.NotNull(dog); + Assert.Equal("Fido", dog.Name); } [Fact] - public async Task DeserializeWithInvalidJsonThrowsAsync() + public void DeserializeWithInvalidJsonThrows() { // Arrange var invalidJson = JsonSerializer.Deserialize("[42]", TestJsonSerializerContext.Default.JsonElement); - var session = new ChatClientAgentSession(); // Act & Assert - await Assert.ThrowsAsync(() => ChatClientAgentSession.DeserializeAsync(invalidJson)); + Assert.Throws(() => ChatClientAgentSession.Deserialize(invalidJson)); } #endregion Deserialize Tests @@ -195,8 +149,9 @@ public void VerifySessionSerializationWithId() public void VerifySessionSerializationWithMessages() { // Arrange - InMemoryChatHistoryProvider provider = [new(ChatRole.User, "TestContent") { AuthorName = "TestAuthor" }]; - var session = new ChatClientAgentSession { ChatHistoryProvider = provider }; + var provider = new InMemoryChatHistoryProvider(); + var session = new ChatClientAgentSession(); + provider.SetMessages(session, [new(ChatRole.User, "TestContent") { AuthorName = "TestAuthor" }]); // Act var json = session.Serialize(); @@ -206,10 +161,12 @@ public void VerifySessionSerializationWithMessages() Assert.False(json.TryGetProperty("conversationId", out _)); - Assert.True(json.TryGetProperty("chatHistoryProviderState", out var chatHistoryProviderStateProperty)); - Assert.Equal(JsonValueKind.Object, chatHistoryProviderStateProperty.ValueKind); - - Assert.True(chatHistoryProviderStateProperty.TryGetProperty("messages", out var messagesProperty)); + // Messages should be stored in the stateBag + Assert.True(json.TryGetProperty("stateBag", out var stateBagProperty)); + Assert.Equal(JsonValueKind.Object, stateBagProperty.ValueKind); + Assert.True(stateBagProperty.TryGetProperty("InMemoryChatHistoryProvider", out var providerStateProperty)); + Assert.Equal(JsonValueKind.Object, providerStateProperty.ValueKind); + Assert.True(providerStateProperty.TryGetProperty("messages", out var messagesProperty)); Assert.Equal(JsonValueKind.Array, messagesProperty.ValueKind); Assert.Single(messagesProperty.EnumerateArray()); @@ -224,29 +181,23 @@ public void VerifySessionSerializationWithMessages() } [Fact] - public void VerifySessionSerializationWithWithAIContextProvider() + public void VerifySessionSerializationWithWithStateBag() { // Arrange - Mock mockProvider = new(); - mockProvider - .Setup(m => m.Serialize(It.IsAny())) - .Returns(JsonSerializer.SerializeToElement(["CP1"], TestJsonSerializerContext.Default.StringArray)); - - var session = new ChatClientAgentSession - { - AIContextProvider = mockProvider.Object - }; + var session = new ChatClientAgentSession(); + session.StateBag.SetValue("dog", new Animal { Name = "Fido" }, TestJsonSerializerContext.Default.Options); // Act var json = session.Serialize(); // Assert Assert.Equal(JsonValueKind.Object, json.ValueKind); - Assert.True(json.TryGetProperty("aiContextProviderState", out var providerStateProperty)); - Assert.Equal(JsonValueKind.Array, providerStateProperty.ValueKind); - Assert.Single(providerStateProperty.EnumerateArray()); - Assert.Equal("CP1", providerStateProperty.EnumerateArray().First().GetString()); - mockProvider.Verify(m => m.Serialize(It.IsAny()), Times.Once); + Assert.True(json.TryGetProperty("stateBag", out var stateBagProperty)); + Assert.Equal(JsonValueKind.Object, stateBagProperty.ValueKind); + Assert.True(stateBagProperty.TryGetProperty("dog", out var dogProperty)); + Assert.Equal(JsonValueKind.Object, dogProperty.ValueKind); + Assert.True(dogProperty.TryGetProperty("name", out var nameProperty)); + Assert.Equal("Fido", nameProperty.GetString()); } /// @@ -258,17 +209,7 @@ public void VerifySessionSerializationWithCustomOptions() // Arrange var session = new ChatClientAgentSession(); JsonSerializerOptions options = new() { PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower }; - options.TypeInfoResolverChain.Add(AgentAbstractionsJsonUtilities.DefaultOptions.TypeInfoResolver!); - - var chatHistoryProviderStateElement = JsonSerializer.SerializeToElement( - new Dictionary { ["Key"] = "TestValue" }, - TestJsonSerializerContext.Default.DictionaryStringObject); - - var chatHistoryProviderMock = new Mock(); - chatHistoryProviderMock - .Setup(m => m.Serialize(options)) - .Returns(chatHistoryProviderStateElement); - session.ChatHistoryProvider = chatHistoryProviderMock.Object; + options.TypeInfoResolverChain.Add(AgentJsonUtilities.DefaultOptions.TypeInfoResolver!); // Act var json = session.Serialize(options); @@ -276,54 +217,29 @@ public void VerifySessionSerializationWithCustomOptions() // Assert Assert.Equal(JsonValueKind.Object, json.ValueKind); - Assert.False(json.TryGetProperty("conversationId", out var idProperty)); - - Assert.True(json.TryGetProperty("chatHistoryProviderState", out var chatHistoryProviderStateProperty)); - Assert.Equal(JsonValueKind.Object, chatHistoryProviderStateProperty.ValueKind); - - Assert.True(chatHistoryProviderStateProperty.TryGetProperty("Key", out var keyProperty)); - Assert.Equal("TestValue", keyProperty.GetString()); - - chatHistoryProviderMock.Verify(m => m.Serialize(options), Times.Once); + // [JsonPropertyName] takes precedence over naming policy + Assert.True(json.TryGetProperty("conversationId", out var _)); } #endregion Serialize Tests - #region GetService Tests - - [Fact] - public void GetService_RequestingAIContextProvider_ReturnsAIContextProvider() - { - // Arrange - var session = new ChatClientAgentSession(); - var mockProvider = new Mock(); - mockProvider - .Setup(m => m.GetService(It.Is(x => x == typeof(AIContextProvider)), null)) - .Returns(mockProvider.Object); - session.AIContextProvider = mockProvider.Object; - - // Act - var result = session.GetService(typeof(AIContextProvider)); - - // Assert - Assert.NotNull(result); - Assert.Same(mockProvider.Object, result); - } + #region StateBag Roundtrip Tests [Fact] - public void GetService_RequestingChatHistoryProvider_ReturnsChatHistoryProvider() + public void VerifyStateBagRoundtrips() { // Arrange var session = new ChatClientAgentSession(); - var chatHistoryProvider = new InMemoryChatHistoryProvider(); - session.ChatHistoryProvider = chatHistoryProvider; + session.StateBag.SetValue("dog", new Animal { Name = "Fido" }, TestJsonSerializerContext.Default.Options); // Act - var result = session.GetService(typeof(ChatHistoryProvider)); + var serializedSession = session.Serialize(); + var deserializedSession = ChatClientAgentSession.Deserialize(serializedSession); // Assert - Assert.NotNull(result); - Assert.Same(chatHistoryProvider, result); + var dog = deserializedSession.StateBag.GetValue("dog", TestJsonSerializerContext.Default.Options); + Assert.NotNull(dog); + Assert.Equal("Fido", dog.Name); } #endregion diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs index c23e8cffaf..7b33cbbd5f 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentTests.cs @@ -45,6 +45,154 @@ public void VerifyChatClientAgentDefinition() Assert.Equal("FunctionInvokingChatClient", agent.ChatClient.GetType().Name); } + /// + /// Verify that the constructor throws when two AIContextProviders use the same StateKey. + /// + [Fact] + public void Constructor_ThrowsWhenDuplicateAIContextProviderStateKeys() + { + // Arrange + var chatClient = new Mock().Object; + var provider1 = new TestAIContextProvider("SharedKey"); + var provider2 = new TestAIContextProvider("SharedKey"); + + // Act & Assert + var ex = Assert.Throws(() => + new ChatClientAgent(chatClient, options: new() + { + AIContextProviders = [provider1, provider2] + })); + + Assert.Contains("SharedKey", ex.Message); + } + + /// + /// Verify that the constructor throws when an AIContextProvider uses the same StateKey as the default InMemoryChatHistoryProvider + /// and no explicit ChatHistoryProvider is configured. + /// + [Fact] + public void Constructor_ThrowsWhenAIContextProviderStateKeyClashesWithDefaultInMemoryChatHistoryProvider() + { + // Arrange + var chatClient = new Mock().Object; + var contextProvider = new TestAIContextProvider(nameof(InMemoryChatHistoryProvider)); + + // Act & Assert + var ex = Assert.Throws(() => + new ChatClientAgent(chatClient, options: new() + { + AIContextProviders = [contextProvider] + })); + + Assert.Contains(nameof(InMemoryChatHistoryProvider), ex.Message); + } + + /// + /// Verify that the constructor throws when a ChatHistoryProvider uses the same StateKey as an AIContextProvider. + /// + [Fact] + public void Constructor_ThrowsWhenChatHistoryProviderStateKeyClashesWithAIContextProvider() + { + // Arrange + var chatClient = new Mock().Object; + var contextProvider = new TestAIContextProvider("SharedKey"); + var historyProvider = new TestChatHistoryProvider("SharedKey"); + + // Act & Assert + var ex = Assert.Throws(() => + new ChatClientAgent(chatClient, options: new() + { + AIContextProviders = [contextProvider], + ChatHistoryProvider = historyProvider + })); + + Assert.Contains("SharedKey", ex.Message); + Assert.Contains(nameof(ChatHistoryProvider), ex.Message); + } + + /// + /// Verify that the constructor succeeds when all providers use unique StateKeys. + /// + [Fact] + public void Constructor_SucceedsWithUniqueProviderStateKeys() + { + // Arrange + var chatClient = new Mock().Object; + var contextProvider1 = new TestAIContextProvider("Key1"); + var contextProvider2 = new TestAIContextProvider("Key2"); + var historyProvider = new TestChatHistoryProvider("Key3"); + + // Act & Assert - should not throw + _ = new ChatClientAgent(chatClient, options: new() + { + AIContextProviders = [contextProvider1, contextProvider2], + ChatHistoryProvider = historyProvider + }); + } + + /// + /// Verify that RunAsync throws when an override ChatHistoryProvider's StateKey clashes with an AIContextProvider. + /// + [Fact] + public async Task RunAsync_ThrowsWhenOverrideChatHistoryProviderStateKeyClashesWithAIContextProviderAsync() + { + // Arrange + Mock mockService = new(); + mockService.Setup( + s => s.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); + + var contextProvider = new TestAIContextProvider("SharedKey"); + var overrideHistoryProvider = new TestChatHistoryProvider("SharedKey"); + + ChatClientAgent agent = new(mockService.Object, options: new() + { + AIContextProviders = [contextProvider] + }); + + // Act & Assert + ChatClientAgentSession? session = await agent.CreateSessionAsync() as ChatClientAgentSession; + AdditionalPropertiesDictionary additionalProperties = new(); + additionalProperties.Add(overrideHistoryProvider); + + var ex = await Assert.ThrowsAsync(() => + agent.RunAsync([new(ChatRole.User, "test")], session, options: new AgentRunOptions { AdditionalProperties = additionalProperties })); + + Assert.Contains("SharedKey", ex.Message); + } + + /// + /// Verify that RunAsync succeeds when an override ChatHistoryProvider uses the same StateKey as the default ChatHistoryProvider. + /// + [Fact] + public async Task RunAsync_SucceedsWhenOverrideChatHistoryProviderSharesKeyWithDefaultAsync() + { + // Arrange + Mock mockService = new(); + mockService.Setup( + s => s.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")])); + + var defaultHistoryProvider = new TestChatHistoryProvider("SameKey"); + var overrideHistoryProvider = new TestChatHistoryProvider("SameKey"); + + ChatClientAgent agent = new(mockService.Object, options: new() + { + ChatHistoryProvider = defaultHistoryProvider + }); + + // Act & Assert - should not throw + ChatClientAgentSession? session = await agent.CreateSessionAsync() as ChatClientAgentSession; + AdditionalPropertiesDictionary additionalProperties = new(); + additionalProperties.Add(overrideHistoryProvider); + + await agent.RunAsync([new(ChatRole.User, "test")], session, options: new AgentRunOptions { AdditionalProperties = additionalProperties }); + } + #endregion #region RunAsync Tests @@ -345,18 +493,19 @@ public async Task RunAsyncInvokesAIContextProviderAndUsesResultAsync() mockProvider .Protected() .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) - .ReturnsAsync(new AIContext - { - Messages = aiContextProviderMessages, - Instructions = "context provider instructions", - Tools = [AIFunctionFactory.Create(() => { }, "context provider function")] - }); + .Returns((AIContextProvider.InvokingContext ctx, CancellationToken _) => + new ValueTask(new AIContext + { + Messages = (ctx.AIContext.Messages ?? []).Concat(aiContextProviderMessages), + Instructions = ctx.AIContext.Instructions + "\ncontext provider instructions", + Tools = (ctx.AIContext.Tools ?? []).Concat(new[] { AIFunctionFactory.Create(() => { }, "context provider function") }) + })); mockProvider .Protected() .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Returns(new ValueTask()); - ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = (_, _) => new(mockProvider.Object), ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); + ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviders = [mockProvider.Object], ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); // Act var session = await agent.CreateSessionAsync() as ChatClientAgentSession; @@ -375,11 +524,13 @@ public async Task RunAsyncInvokesAIContextProviderAndUsesResultAsync() Assert.Contains(capturedTools, t => t.Name == "context provider function"); // Verify that the session was updated with the ai context provider, input and response messages - var chatHistoryProvider = Assert.IsType(session!.ChatHistoryProvider); - Assert.Equal(3, chatHistoryProvider.Count); - Assert.Equal("user message", chatHistoryProvider[0].Text); - Assert.Equal("context provider message", chatHistoryProvider[1].Text); - Assert.Equal("response", chatHistoryProvider[2].Text); + var chatHistoryProvider = agent.ChatHistoryProvider as InMemoryChatHistoryProvider; + Assert.NotNull(chatHistoryProvider); + var messages = chatHistoryProvider.GetMessages(session); + Assert.Equal(3, messages.Count); + Assert.Equal("user message", messages[0].Text); + Assert.Equal("context provider message", messages[1].Text); + Assert.Equal("response", messages[2].Text); mockProvider .Protected() @@ -413,16 +564,17 @@ public async Task RunAsyncInvokesAIContextProviderWhenGetResponseFailsAsync() mockProvider .Protected() .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) - .ReturnsAsync(new AIContext - { - Messages = aiContextProviderMessages, - }); + .Returns((AIContextProvider.InvokingContext ctx, CancellationToken _) => + new ValueTask(new AIContext + { + Messages = (ctx.AIContext.Messages ?? []).Concat(aiContextProviderMessages), + })); mockProvider .Protected() .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Returns(new ValueTask()); - ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = (_, _) => new(mockProvider.Object), ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); + ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviders = [mockProvider.Object], ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); // Act await Assert.ThrowsAsync(() => agent.RunAsync(requestMessages)); @@ -470,9 +622,15 @@ public async Task RunAsyncInvokesAIContextProviderAndSucceedsWithEmptyAIContextA mockProvider .Protected() .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) - .ReturnsAsync(new AIContext()); + .Returns((AIContextProvider.InvokingContext ctx, CancellationToken _) => + new ValueTask(new AIContext + { + Instructions = ctx.AIContext.Instructions, + Messages = ctx.AIContext.Messages, + Tools = ctx.AIContext.Tools + })); - ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviderFactory = (_, _) => new(mockProvider.Object), ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); + ChatClientAgent agent = new(mockService.Object, options: new() { AIContextProviders = [mockProvider.Object], ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } }); // Act await agent.RunAsync([new(ChatRole.User, "user message")]); @@ -490,6 +648,299 @@ public async Task RunAsyncInvokesAIContextProviderAndSucceedsWithEmptyAIContextA .Verify>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); } + /// + /// Verify that RunAsync invokes multiple AIContextProviders in sequence, each receiving the accumulated context. + /// + [Fact] + public async Task RunAsyncInvokesMultipleAIContextProvidersInOrderAsync() + { + // Arrange + ChatMessage[] requestMessages = [new(ChatRole.User, "user message")]; + ChatMessage[] responseMessages = [new(ChatRole.Assistant, "response")]; + Mock mockService = new(); + List capturedMessages = []; + string capturedInstructions = string.Empty; + List capturedTools = []; + mockService + .Setup(s => s.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Callback, ChatOptions, CancellationToken>((msgs, opts, ct) => + { + capturedMessages.AddRange(msgs); + capturedInstructions = opts.Instructions ?? string.Empty; + if (opts.Tools is not null) + { + capturedTools.AddRange(opts.Tools); + } + }) + .ReturnsAsync(new ChatResponse(responseMessages)); + + // Provider 1: adds a system message and a tool + var mockProvider1 = new Mock(); + mockProvider1.SetupGet(p => p.StateKey).Returns("Provider1"); + mockProvider1 + .Protected() + .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns((AIContextProvider.InvokingContext ctx, CancellationToken _) => + new ValueTask(new AIContext + { + Messages = (ctx.AIContext.Messages ?? []).Concat([new ChatMessage(ChatRole.System, "provider1 context")]).ToList(), + Instructions = ctx.AIContext.Instructions + "\nprovider1 instructions", + Tools = (ctx.AIContext.Tools ?? []).Concat([AIFunctionFactory.Create(() => { }, "provider1 function")]).ToList() + })); + mockProvider1 + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(new ValueTask()); + + // Provider 2: adds another system message and verifies it receives accumulated context from provider 1 + AIContext? provider2ReceivedContext = null; + var mockProvider2 = new Mock(); + mockProvider2.SetupGet(p => p.StateKey).Returns("Provider2"); + mockProvider2 + .Protected() + .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns((AIContextProvider.InvokingContext ctx, CancellationToken _) => + { + provider2ReceivedContext = ctx.AIContext; + return new ValueTask(new AIContext + { + Messages = (ctx.AIContext.Messages ?? []).Concat([new ChatMessage(ChatRole.System, "provider2 context")]).ToList(), + Instructions = ctx.AIContext.Instructions + "\nprovider2 instructions", + Tools = (ctx.AIContext.Tools ?? []).Concat([AIFunctionFactory.Create(() => { }, "provider2 function")]).ToList() + }); + }); + mockProvider2 + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(new ValueTask()); + + ChatClientAgent agent = new(mockService.Object, options: new() + { + AIContextProviders = [mockProvider1.Object, mockProvider2.Object], + ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] } + }); + + // Act + var session = await agent.CreateSessionAsync() as ChatClientAgentSession; + await agent.RunAsync(requestMessages, session); + + // Assert + // Provider 2 should have received accumulated context from provider 1 + Assert.NotNull(provider2ReceivedContext); + Assert.Contains(provider2ReceivedContext.Messages!, m => m.Text == "provider1 context"); + Assert.Contains("provider1 instructions", provider2ReceivedContext.Instructions); + + // Final captured messages should contain user message + both provider contexts + Assert.Equal(3, capturedMessages.Count); + Assert.Equal("user message", capturedMessages[0].Text); + Assert.Equal("provider1 context", capturedMessages[1].Text); + Assert.Equal("provider2 context", capturedMessages[2].Text); + + // Instructions should be accumulated + Assert.Equal("base instructions\nprovider1 instructions\nprovider2 instructions", capturedInstructions); + + // Tools should contain base + both provider tools + Assert.Equal(3, capturedTools.Count); + Assert.Contains(capturedTools, t => t.Name == "base function"); + Assert.Contains(capturedTools, t => t.Name == "provider1 function"); + Assert.Contains(capturedTools, t => t.Name == "provider2 function"); + + // Both providers should have been invoked + mockProvider1 + .Protected() + .Verify>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); + mockProvider2 + .Protected() + .Verify>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); + + // Both providers should have been notified of success + mockProvider1 + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), ItExpr.Is(x => + x.ResponseMessages == responseMessages && + x.InvokeException == null), ItExpr.IsAny()); + mockProvider2 + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), ItExpr.Is(x => + x.ResponseMessages == responseMessages && + x.InvokeException == null), ItExpr.IsAny()); + } + + /// + /// Verify that RunAsync invokes InvokedCoreAsync on all AIContextProviders when the downstream GetResponse call fails. + /// + [Fact] + public async Task RunAsyncInvokesMultipleAIContextProvidersOnFailureAsync() + { + // Arrange + ChatMessage[] requestMessages = [new(ChatRole.User, "user message")]; + Mock mockService = new(); + mockService + .Setup(s => s.GetResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .ThrowsAsync(new InvalidOperationException("downstream failure")); + + var mockProvider1 = new Mock(); + mockProvider1.SetupGet(p => p.StateKey).Returns("Provider1"); + mockProvider1 + .Protected() + .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns((AIContextProvider.InvokingContext ctx, CancellationToken _) => + new ValueTask(new AIContext + { + Messages = ctx.AIContext.Messages?.ToList(), + Instructions = ctx.AIContext.Instructions, + Tools = ctx.AIContext.Tools + })); + mockProvider1 + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(new ValueTask()); + + var mockProvider2 = new Mock(); + mockProvider2.SetupGet(p => p.StateKey).Returns("Provider2"); + mockProvider2 + .Protected() + .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns((AIContextProvider.InvokingContext ctx, CancellationToken _) => + new ValueTask(new AIContext + { + Messages = ctx.AIContext.Messages?.ToList(), + Instructions = ctx.AIContext.Instructions, + Tools = ctx.AIContext.Tools + })); + mockProvider2 + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(new ValueTask()); + + ChatClientAgent agent = new(mockService.Object, options: new() + { + AIContextProviders = [mockProvider1.Object, mockProvider2.Object], + ChatOptions = new() { Instructions = "base instructions" } + }); + + // Act + await Assert.ThrowsAsync(() => agent.RunAsync(requestMessages)); + + // Assert - both providers should have been notified of the failure + mockProvider1 + .Protected() + .Verify>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); + mockProvider2 + .Protected() + .Verify>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); + + mockProvider1 + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), ItExpr.Is(x => + x.InvokeException is InvalidOperationException), ItExpr.IsAny()); + mockProvider2 + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), ItExpr.Is(x => + x.InvokeException is InvalidOperationException), ItExpr.IsAny()); + } + + /// + /// Verify that RunStreamingAsync invokes multiple AIContextProviders in sequence. + /// + [Fact] + public async Task RunStreamingAsyncInvokesMultipleAIContextProvidersAsync() + { + // Arrange + ChatMessage[] requestMessages = [new(ChatRole.User, "user message")]; + ChatResponseUpdate[] responseUpdates = [new(ChatRole.Assistant, "response")]; + Mock mockService = new(); + List capturedMessages = []; + string capturedInstructions = string.Empty; + mockService + .Setup(s => s.GetStreamingResponseAsync( + It.IsAny>(), + It.IsAny(), + It.IsAny())) + .Callback, ChatOptions, CancellationToken>((msgs, opts, ct) => + { + capturedMessages.AddRange(msgs); + capturedInstructions = opts.Instructions ?? string.Empty; + }) + .Returns(ToAsyncEnumerableAsync(responseUpdates)); + + var mockProvider1 = new Mock(); + mockProvider1.SetupGet(p => p.StateKey).Returns("Provider1"); + mockProvider1 + .Protected() + .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns((AIContextProvider.InvokingContext ctx, CancellationToken _) => + new ValueTask(new AIContext + { + Messages = (ctx.AIContext.Messages ?? []).Concat([new ChatMessage(ChatRole.System, "provider1 context")]).ToList(), + Instructions = ctx.AIContext.Instructions + "\nprovider1 instructions", + Tools = ctx.AIContext.Tools + })); + mockProvider1 + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(new ValueTask()); + + var mockProvider2 = new Mock(); + mockProvider2.SetupGet(p => p.StateKey).Returns("Provider2"); + mockProvider2 + .Protected() + .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns((AIContextProvider.InvokingContext ctx, CancellationToken _) => + new ValueTask(new AIContext + { + Messages = (ctx.AIContext.Messages ?? []).Concat([new ChatMessage(ChatRole.System, "provider2 context")]).ToList(), + Instructions = ctx.AIContext.Instructions + "\nprovider2 instructions", + Tools = ctx.AIContext.Tools + })); + mockProvider2 + .Protected() + .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(new ValueTask()); + + ChatClientAgent agent = new( + mockService.Object, + options: new() + { + ChatOptions = new() { Instructions = "base instructions" }, + AIContextProviders = [mockProvider1.Object, mockProvider2.Object] + }); + + // Act + var session = await agent.CreateSessionAsync() as ChatClientAgentSession; + var updates = agent.RunStreamingAsync(requestMessages, session); + _ = await updates.ToAgentResponseAsync(); + + // Assert + Assert.Equal(3, capturedMessages.Count); + Assert.Equal("user message", capturedMessages[0].Text); + Assert.Equal("provider1 context", capturedMessages[1].Text); + Assert.Equal("provider2 context", capturedMessages[2].Text); + Assert.Equal("base instructions\nprovider1 instructions\nprovider2 instructions", capturedInstructions); + + // Both providers should have been invoked and notified + mockProvider1 + .Protected() + .Verify>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); + mockProvider2 + .Protected() + .Verify>("InvokingCoreAsync", Times.Once(), ItExpr.IsAny(), ItExpr.IsAny()); + mockProvider1 + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), ItExpr.Is(x => + x.InvokeException == null), ItExpr.IsAny()); + mockProvider2 + .Protected() + .Verify("InvokedCoreAsync", Times.Once(), ItExpr.Is(x => + x.InvokeException == null), ItExpr.IsAny()); + } + #endregion #region RunAsync Structured Output Tests @@ -1300,12 +1751,10 @@ public async Task RunStreamingAsyncUsesChatHistoryProviderWhenNoConversationIdRe It.IsAny>(), It.IsAny(), It.IsAny())).Returns(ToAsyncEnumerableAsync(returnUpdates)); - Mock>> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(new InMemoryChatHistoryProvider()); ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, - ChatHistoryProviderFactory = mockFactory.Object + ChatHistoryProvider = new InMemoryChatHistoryProvider() }); // Act @@ -1313,11 +1762,11 @@ public async Task RunStreamingAsyncUsesChatHistoryProviderWhenNoConversationIdRe await agent.RunStreamingAsync([new(ChatRole.User, "test")], session).ToListAsync(); // Assert - var chatHistoryProvider = Assert.IsType(session!.ChatHistoryProvider); - Assert.Equal(2, chatHistoryProvider.Count); - Assert.Equal("test", chatHistoryProvider[0].Text); - Assert.Equal("what?", chatHistoryProvider[1].Text); - mockFactory.Verify(f => f(It.IsAny(), It.IsAny()), Times.Once); + var chatHistoryProvider = Assert.IsType(agent.GetService(typeof(ChatHistoryProvider))); + var historyMessages = chatHistoryProvider.GetMessages(session); + Assert.Equal(2, historyMessages.Count); + Assert.Equal("test", historyMessages[0].Text); + Assert.Equal("what?", historyMessages[1].Text); } /// @@ -1360,10 +1809,10 @@ public async Task RunStreamingAsyncIncludesChatHistoryInMessagesToChatClientAsyn } /// - /// Verify that RunStreamingAsync throws when a factory is provided and the chat client returns a conversation id. + /// Verify that RunStreamingAsync throws when a is provided and the chat client returns a conversation id. /// [Fact] - public async Task RunStreamingAsyncThrowsWhenChatHistoryProviderFactoryProvidedAndConversationIdReturnedByChatClientAsync() + public async Task RunStreamingAsyncThrowsWhenChatHistoryProviderProvidedAndConversationIdReturnedByChatClientAsync() { // Arrange Mock mockService = new(); @@ -1377,18 +1826,16 @@ public async Task RunStreamingAsyncThrowsWhenChatHistoryProviderFactoryProvidedA It.IsAny>(), It.IsAny(), It.IsAny())).Returns(ToAsyncEnumerableAsync(returnUpdates)); - Mock>> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(new InMemoryChatHistoryProvider()); ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, - ChatHistoryProviderFactory = mockFactory.Object + ChatHistoryProvider = new InMemoryChatHistoryProvider() }); // Act & Assert ChatClientAgentSession? session = await agent.CreateSessionAsync() as ChatClientAgentSession; var exception = await Assert.ThrowsAsync(async () => await agent.RunStreamingAsync([new(ChatRole.User, "test")], session).ToListAsync()); - Assert.Equal("Only the ConversationId or ChatHistoryProvider may be set, but not both and switching from one to another is not supported.", exception.Message); + Assert.Equal("Only ConversationId or ChatHistoryProvider may be used, but not both. The service returned a conversation id indicating server-side chat history management, but the agent has a ChatHistoryProvider configured.", exception.Message); } /// @@ -1425,12 +1872,13 @@ public async Task RunStreamingAsyncInvokesAIContextProviderAndUsesResultAsync() mockProvider .Protected() .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) - .ReturnsAsync(new AIContext - { - Messages = aiContextProviderMessages, - Instructions = "context provider instructions", - Tools = [AIFunctionFactory.Create(() => { }, "context provider function")] - }); + .Returns((AIContextProvider.InvokingContext ctx, CancellationToken _) => + new ValueTask(new AIContext + { + Messages = (ctx.AIContext.Messages ?? []).Concat(aiContextProviderMessages), + Instructions = ctx.AIContext.Instructions + "\ncontext provider instructions", + Tools = (ctx.AIContext.Tools ?? []).Concat(new[] { AIFunctionFactory.Create(() => { }, "context provider function") }) + })); mockProvider .Protected() .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) @@ -1441,7 +1889,7 @@ public async Task RunStreamingAsyncInvokesAIContextProviderAndUsesResultAsync() options: new() { ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] }, - AIContextProviderFactory = (_, _) => new(mockProvider.Object) + AIContextProviders = [mockProvider.Object] }); // Act @@ -1462,11 +1910,13 @@ public async Task RunStreamingAsyncInvokesAIContextProviderAndUsesResultAsync() Assert.Contains(capturedTools, t => t.Name == "context provider function"); // Verify that the session was updated with the input, ai context provider, and response messages - var chatHistoryProvider = Assert.IsType(session!.ChatHistoryProvider); - Assert.Equal(3, chatHistoryProvider.Count); - Assert.Equal("user message", chatHistoryProvider[0].Text); - Assert.Equal("context provider message", chatHistoryProvider[1].Text); - Assert.Equal("response", chatHistoryProvider[2].Text); + var chatHistoryProvider = agent.ChatHistoryProvider as InMemoryChatHistoryProvider; + Assert.NotNull(chatHistoryProvider); + var historyMessages2 = chatHistoryProvider.GetMessages(session); + Assert.Equal(3, historyMessages2.Count); + Assert.Equal("user message", historyMessages2[0].Text); + Assert.Equal("context provider message", historyMessages2[1].Text); + Assert.Equal("response", historyMessages2[2].Text); mockProvider .Protected() @@ -1501,10 +1951,11 @@ public async Task RunStreamingAsyncInvokesAIContextProviderWhenGetResponseFailsA mockProvider .Protected() .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) - .ReturnsAsync(new AIContext - { - Messages = aiContextProviderMessages, - }); + .Returns((AIContextProvider.InvokingContext ctx, CancellationToken _) => + new ValueTask(new AIContext + { + Messages = (ctx.AIContext.Messages ?? []).Concat(aiContextProviderMessages), + })); mockProvider .Protected() .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) @@ -1515,7 +1966,7 @@ public async Task RunStreamingAsyncInvokesAIContextProviderWhenGetResponseFailsA options: new() { ChatOptions = new() { Instructions = "base instructions", Tools = [AIFunctionFactory.Create(() => { }, "base function")] }, - AIContextProviderFactory = (_, _) => new(mockProvider.Object) + AIContextProviders = [mockProvider.Object] }); // Act @@ -1565,4 +2016,23 @@ private enum Species [JsonSourceGenerationOptions(UseStringEnumConverter = true, PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)] [JsonSerializable(typeof(Animal))] private sealed partial class JsonContext2 : JsonSerializerContext; + + private sealed class TestAIContextProvider(string stateKey) : AIContextProvider + { + public override string StateKey => stateKey; + + protected override ValueTask InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) + => new(context.AIContext); + } + + private sealed class TestChatHistoryProvider(string stateKey) : ChatHistoryProvider + { + public override string StateKey => stateKey; + + protected override ValueTask> InvokingCoreAsync(InvokingContext context, CancellationToken cancellationToken = default) + => new(context.RequestMessages); + + protected override ValueTask InvokedCoreAsync(InvokedContext context, CancellationToken cancellationToken = default) + => default; + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs index 2eed890292..84285ff9c4 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_BackgroundResponsesTests.cs @@ -339,6 +339,7 @@ public async Task RunAsync_WhenContinuationTokenProvided_SkipsSessionMessagePopu // Create a mock chat history provider that would normally provide messages var mockChatHistoryProvider = new Mock(); + mockChatHistoryProvider.SetupGet(p => p.StateKey).Returns("ChatHistoryProvider"); mockChatHistoryProvider .Protected() .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) @@ -346,6 +347,7 @@ public async Task RunAsync_WhenContinuationTokenProvided_SkipsSessionMessagePopu // Create a mock AI context provider that would normally provide context var mockContextProvider = new Mock(); + mockContextProvider.SetupGet(p => p.StateKey).Returns("Provider1"); mockContextProvider .Protected() .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) @@ -365,14 +367,14 @@ public async Task RunAsync_WhenContinuationTokenProvided_SkipsSessionMessagePopu capturedMessages.AddRange(msgs)) .ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "continued response")])); - ChatClientAgent agent = new(mockChatClient.Object); - - // Create a session with both chat history provider and AI context provider - ChatClientAgentSession? session = new() + ChatClientAgent agent = new(mockChatClient.Object, options: new() { ChatHistoryProvider = mockChatHistoryProvider.Object, - AIContextProvider = mockContextProvider.Object - }; + AIContextProviders = [mockContextProvider.Object] + }); + + // Create a session + ChatClientAgentSession? session = new(); AgentRunOptions runOptions = new() { @@ -406,6 +408,7 @@ public async Task RunStreamingAsync_WhenContinuationTokenProvided_SkipsSessionMe // Create a mock chat history provider that would normally provide messages var mockChatHistoryProvider = new Mock(); + mockChatHistoryProvider.SetupGet(p => p.StateKey).Returns("ChatHistoryProvider"); mockChatHistoryProvider .Protected() .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) @@ -413,6 +416,7 @@ public async Task RunStreamingAsync_WhenContinuationTokenProvided_SkipsSessionMe // Create a mock AI context provider that would normally provide context var mockContextProvider = new Mock(); + mockContextProvider.SetupGet(p => p.StateKey).Returns("Provider1"); mockContextProvider .Protected() .Setup>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) @@ -432,14 +436,14 @@ public async Task RunStreamingAsync_WhenContinuationTokenProvided_SkipsSessionMe capturedMessages.AddRange(msgs)) .Returns(ToAsyncEnumerableAsync([new ChatResponseUpdate(role: ChatRole.Assistant, content: "continued response")])); - ChatClientAgent agent = new(mockChatClient.Object); - - // Create a session with both chat history provider and AI context provider - ChatClientAgentSession? session = new() + ChatClientAgent agent = new(mockChatClient.Object, options: new() { ChatHistoryProvider = mockChatHistoryProvider.Object, - AIContextProvider = mockContextProvider.Object - }; + AIContextProviders = [mockContextProvider.Object] + }); + + // Create a session + ChatClientAgentSession? session = new(); AgentRunOptions runOptions = new() { @@ -633,10 +637,9 @@ public async Task RunStreamingAsync_WhenResumingStreaming_UsesUpdatesFromInitial It.IsAny())) .Returns(ToAsyncEnumerableAsync(returnUpdates)); - ChatClientAgent agent = new(mockChatClient.Object); - List capturedMessagesAddedToProvider = []; var mockChatHistoryProvider = new Mock(); + mockChatHistoryProvider.SetupGet(p => p.StateKey).Returns("ChatHistoryProvider"); mockChatHistoryProvider .Protected() .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) @@ -645,17 +648,20 @@ public async Task RunStreamingAsync_WhenResumingStreaming_UsesUpdatesFromInitial AIContextProvider.InvokedContext? capturedInvokedContext = null; var mockContextProvider = new Mock(); + mockContextProvider.SetupGet(p => p.StateKey).Returns("Provider1"); mockContextProvider .Protected() .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Callback((context, ct) => capturedInvokedContext = context) .Returns(new ValueTask()); - ChatClientAgentSession? session = new() + ChatClientAgent agent = new(mockChatClient.Object, options: new() { ChatHistoryProvider = mockChatHistoryProvider.Object, - AIContextProvider = mockContextProvider.Object - }; + AIContextProviders = [mockContextProvider.Object] + }); + + ChatClientAgentSession? session = new(); AgentRunOptions runOptions = new() { @@ -695,10 +701,9 @@ public async Task RunStreamingAsync_WhenResumingStreaming_UsesInputMessagesFromI It.IsAny())) .Returns(ToAsyncEnumerableAsync(Array.Empty())); - ChatClientAgent agent = new(mockChatClient.Object); - List capturedMessagesAddedToProvider = []; var mockChatHistoryProvider = new Mock(); + mockChatHistoryProvider.SetupGet(p => p.StateKey).Returns("ChatHistoryProvider"); mockChatHistoryProvider .Protected() .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) @@ -707,17 +712,20 @@ public async Task RunStreamingAsync_WhenResumingStreaming_UsesInputMessagesFromI AIContextProvider.InvokedContext? capturedInvokedContext = null; var mockContextProvider = new Mock(); + mockContextProvider.SetupGet(p => p.StateKey).Returns("Provider1"); mockContextProvider .Protected() .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Callback((context, ct) => capturedInvokedContext = context) .Returns(new ValueTask()); - ChatClientAgentSession? session = new() + ChatClientAgent agent = new(mockChatClient.Object, options: new() { ChatHistoryProvider = mockChatHistoryProvider.Object, - AIContextProvider = mockContextProvider.Object - }; + AIContextProviders = [mockContextProvider.Object] + }); + + ChatClientAgentSession? session = new(); AgentRunOptions runOptions = new() { diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs index 4de8f01f8e..4731805af7 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_ChatHistoryManagementTests.cs @@ -163,17 +163,19 @@ public async Task RunAsync_UsesDefaultInMemoryChatHistoryProvider_WhenNoConversa await agent.RunAsync([new(ChatRole.User, "test")], session); // Assert - InMemoryChatHistoryProvider chatHistoryProvider = Assert.IsType(session!.ChatHistoryProvider); - Assert.Equal(2, chatHistoryProvider.Count); - Assert.Equal("test", chatHistoryProvider[0].Text); - Assert.Equal("response", chatHistoryProvider[1].Text); + var inMemoryProvider = agent.ChatHistoryProvider as InMemoryChatHistoryProvider; + Assert.NotNull(inMemoryProvider); + var messages = inMemoryProvider.GetMessages(session!); + Assert.Equal(2, messages.Count); + Assert.Equal("test", messages[0].Text); + Assert.Equal("response", messages[1].Text); } /// - /// Verify that RunAsync uses the ChatHistoryProvider factory when the chat client returns no conversation id. + /// Verify that RunAsync uses the ChatHistoryProvider when the chat client returns no conversation id. /// [Fact] - public async Task RunAsync_UsesChatHistoryProviderFactory_WhenProvidedAndNoConversationIdReturnedByChatClientAsync() + public async Task RunAsync_UsesChatHistoryProvider_WhenProvidedAndNoConversationIdReturnedByChatClientAsync() { // Arrange Mock mockService = new(); @@ -187,19 +189,17 @@ public async Task RunAsync_UsesChatHistoryProviderFactory_WhenProvidedAndNoConve mockChatHistoryProvider .Protected() .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) - .ReturnsAsync([new ChatMessage(ChatRole.User, "Existing Chat History")]); + .Returns((ChatHistoryProvider.InvokingContext ctx, CancellationToken _) => + new ValueTask>(new List { new(ChatRole.User, "Existing Chat History") }.Concat(ctx.RequestMessages).ToList())); mockChatHistoryProvider .Protected() .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Returns(new ValueTask()); - Mock>> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(mockChatHistoryProvider.Object); - ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, - ChatHistoryProviderFactory = mockFactory.Object + ChatHistoryProvider = mockChatHistoryProvider.Object }); // Act @@ -207,7 +207,7 @@ public async Task RunAsync_UsesChatHistoryProviderFactory_WhenProvidedAndNoConve await agent.RunAsync([new(ChatRole.User, "test")], session); // Assert - Assert.IsType(session!.ChatHistoryProvider, exactMatch: false); + Assert.Same(mockChatHistoryProvider.Object, agent.ChatHistoryProvider); mockService.Verify( x => x.GetResponseAsync( It.Is>(msgs => msgs.Count() == 2 && msgs.Any(m => m.Text == "Existing Chat History") && msgs.Any(m => m.Text == "test")), @@ -222,9 +222,8 @@ public async Task RunAsync_UsesChatHistoryProviderFactory_WhenProvidedAndNoConve mockChatHistoryProvider .Protected() .Verify("InvokedCoreAsync", Times.Once(), - ItExpr.Is(x => x.RequestMessages.Count() == 1 && x.ResponseMessages!.Count() == 1), + ItExpr.Is(x => x.RequestMessages.Count() == 2 && x.ResponseMessages!.Count() == 1), ItExpr.IsAny()); - mockFactory.Verify(f => f(It.IsAny(), It.IsAny()), Times.Once); } /// @@ -242,14 +241,16 @@ public async Task RunAsync_NotifiesChatHistoryProvider_OnFailureAsync() It.IsAny())).Throws(new InvalidOperationException("Test Error")); Mock mockChatHistoryProvider = new(); - - Mock>> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(mockChatHistoryProvider.Object); + mockChatHistoryProvider + .Protected() + .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns((ChatHistoryProvider.InvokingContext ctx, CancellationToken _) => + new ValueTask>(ctx.RequestMessages.ToList())); ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, - ChatHistoryProviderFactory = mockFactory.Object + ChatHistoryProvider = mockChatHistoryProvider.Object }); // Act @@ -257,20 +258,19 @@ public async Task RunAsync_NotifiesChatHistoryProvider_OnFailureAsync() await Assert.ThrowsAsync(() => agent.RunAsync([new(ChatRole.User, "test")], session)); // Assert - Assert.IsType(session!.ChatHistoryProvider, exactMatch: false); + Assert.Same(mockChatHistoryProvider.Object, agent.ChatHistoryProvider); mockChatHistoryProvider .Protected() .Verify("InvokedCoreAsync", Times.Once(), ItExpr.Is(x => x.RequestMessages.Count() == 1 && x.ResponseMessages == null && x.InvokeException!.Message == "Test Error"), ItExpr.IsAny()); - mockFactory.Verify(f => f(It.IsAny(), It.IsAny()), Times.Once); } /// - /// Verify that RunAsync throws when a ChatHistoryProvider Factory is provided and the chat client returns a conversation id. + /// Verify that RunAsync throws when a ChatHistoryProvider is provided and the chat client returns a conversation id. /// [Fact] - public async Task RunAsync_Throws_WhenChatHistoryProviderFactoryProvidedAndConversationIdReturnedByChatClientAsync() + public async Task RunAsync_Throws_WhenChatHistoryProviderProvidedAndConversationIdReturnedByChatClientAsync() { // Arrange Mock mockService = new(); @@ -279,18 +279,16 @@ public async Task RunAsync_Throws_WhenChatHistoryProviderFactoryProvidedAndConve It.IsAny>(), It.IsAny(), It.IsAny())).ReturnsAsync(new ChatResponse([new(ChatRole.Assistant, "response")]) { ConversationId = "ConvId" }); - Mock>> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(new InMemoryChatHistoryProvider()); ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, - ChatHistoryProviderFactory = mockFactory.Object + ChatHistoryProvider = new InMemoryChatHistoryProvider() }); // Act & Assert ChatClientAgentSession? session = await agent.CreateSessionAsync() as ChatClientAgentSession; InvalidOperationException exception = await Assert.ThrowsAsync(() => agent.RunAsync([new(ChatRole.User, "test")], session)); - Assert.Equal("Only the ConversationId or ChatHistoryProvider may be set, but not both and switching from one to another is not supported.", exception.Message); + Assert.Equal("Only ConversationId or ChatHistoryProvider may be used, but not both. The service returned a conversation id indicating server-side chat history management, but the agent has a ChatHistoryProvider configured.", exception.Message); } #endregion @@ -317,31 +315,29 @@ public async Task RunAsync_UsesOverrideChatHistoryProvider_WhenProvidedViaAdditi mockOverrideChatHistoryProvider .Protected() .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) - .ReturnsAsync([new ChatMessage(ChatRole.User, "Existing Chat History")]); + .Returns((ChatHistoryProvider.InvokingContext ctx, CancellationToken _) => + new ValueTask>(new List { new(ChatRole.User, "Existing Chat History") }.Concat(ctx.RequestMessages).ToList())); mockOverrideChatHistoryProvider .Protected() .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Returns(new ValueTask()); - // Arrange a chat history provider to provide to the agent via a factory at construction time. + // Arrange a chat history provider to provide to the agent at construction time. // This one shouldn't be used since it is being overridden. - Mock mockFactoryChatHistoryProvider = new(); - mockFactoryChatHistoryProvider + Mock mockAgentOptionsChatHistoryProvider = new(); + mockAgentOptionsChatHistoryProvider .Protected() .Setup>>("InvokingCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .ThrowsAsync(FailException.ForFailure("Base ChatHistoryProvider shouldn't be used.")); - mockFactoryChatHistoryProvider + mockAgentOptionsChatHistoryProvider .Protected() .Setup("InvokedCoreAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Throws(FailException.ForFailure("Base ChatHistoryProvider shouldn't be used.")); - Mock>> mockFactory = new(); - mockFactory.Setup(f => f(It.IsAny(), It.IsAny())).ReturnsAsync(mockFactoryChatHistoryProvider.Object); - ChatClientAgent agent = new(mockService.Object, options: new() { ChatOptions = new() { Instructions = "test instructions" }, - ChatHistoryProviderFactory = mockFactory.Object + ChatHistoryProvider = mockAgentOptionsChatHistoryProvider.Object }); // Act @@ -351,7 +347,7 @@ public async Task RunAsync_UsesOverrideChatHistoryProvider_WhenProvidedViaAdditi await agent.RunAsync([new(ChatRole.User, "test")], session, options: new AgentRunOptions { AdditionalProperties = additionalProperties }); // Assert - Assert.Same(mockFactoryChatHistoryProvider.Object, session!.ChatHistoryProvider); + Assert.Same(mockAgentOptionsChatHistoryProvider.Object, agent.ChatHistoryProvider); mockService.Verify( x => x.GetResponseAsync( It.Is>(msgs => msgs.Count() == 2 && msgs.Any(m => m.Text == "Existing Chat History") && msgs.Any(m => m.Text == "test")), @@ -366,15 +362,15 @@ public async Task RunAsync_UsesOverrideChatHistoryProvider_WhenProvidedViaAdditi mockOverrideChatHistoryProvider .Protected() .Verify("InvokedCoreAsync", Times.Once(), - ItExpr.Is(x => x.RequestMessages.Count() == 1 && x.ResponseMessages!.Count() == 1), + ItExpr.Is(x => x.RequestMessages.Count() == 2 && x.ResponseMessages!.Count() == 1), ItExpr.IsAny()); - mockFactoryChatHistoryProvider + mockAgentOptionsChatHistoryProvider .Protected() .Verify>>("InvokingCoreAsync", Times.Never(), ItExpr.IsAny(), ItExpr.IsAny()); - mockFactoryChatHistoryProvider + mockAgentOptionsChatHistoryProvider .Protected() .Verify("InvokedCoreAsync", Times.Never(), ItExpr.IsAny(), diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_CreateSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_CreateSessionTests.cs index 86220a6462..68fd008e9e 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_CreateSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_CreateSessionTests.cs @@ -11,77 +11,6 @@ namespace Microsoft.Agents.AI.UnitTests; /// public class ChatClientAgent_CreateSessionTests { - [Fact] - public async Task CreateSession_UsesAIContextProviderFactory_IfProvidedAsync() - { - // Arrange - var mockChatClient = new Mock(); - var mockContextProvider = new Mock(); - var factoryCalled = false; - var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions - { - ChatOptions = new() { Instructions = "Test instructions" }, - AIContextProviderFactory = (_, _) => - { - factoryCalled = true; - return new ValueTask(mockContextProvider.Object); - } - }); - - // Act - var session = await agent.CreateSessionAsync(); - - // Assert - Assert.True(factoryCalled, "AIContextProviderFactory was not called."); - Assert.IsType(session); - var typedSession = (ChatClientAgentSession)session; - Assert.Same(mockContextProvider.Object, typedSession.AIContextProvider); - } - - [Fact] - public async Task CreateSession_UsesChatHistoryProviderFactory_IfProvidedAsync() - { - // Arrange - var mockChatClient = new Mock(); - var mockChatHistoryProvider = new Mock(); - var factoryCalled = false; - var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions - { - ChatOptions = new() { Instructions = "Test instructions" }, - ChatHistoryProviderFactory = (_, _) => - { - factoryCalled = true; - return new ValueTask(mockChatHistoryProvider.Object); - } - }); - - // Act - var session = await agent.CreateSessionAsync(); - - // Assert - Assert.True(factoryCalled, "ChatHistoryProviderFactory was not called."); - Assert.IsType(session); - var typedSession = (ChatClientAgentSession)session; - Assert.Same(mockChatHistoryProvider.Object, typedSession.ChatHistoryProvider); - } - - [Fact] - public async Task CreateSession_UsesChatHistoryProvider_FromTypedOverloadAsync() - { - // Arrange - var mockChatClient = new Mock(); - var mockChatHistoryProvider = new Mock(); - var agent = new ChatClientAgent(mockChatClient.Object); - - // Act - var session = await agent.CreateSessionAsync(mockChatHistoryProvider.Object); - - // Assert - Assert.IsType(session); - var typedSession = (ChatClientAgentSession)session; - Assert.Same(mockChatHistoryProvider.Object, typedSession.ChatHistoryProvider); - } - [Fact] public async Task CreateSession_UsesConversationId_FromTypedOverloadAsync() { diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeSessionTests.cs deleted file mode 100644 index 014cb1483b..0000000000 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgent_DeserializeSessionTests.cs +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Text.Json; -using System.Threading.Tasks; -using Microsoft.Extensions.AI; -using Moq; - -namespace Microsoft.Agents.AI.UnitTests; - -/// -/// Contains unit tests for the ChatClientAgent.DeserializeSession methods. -/// -public class ChatClientAgent_DeserializeSessionTests -{ - [Fact] - public async Task DeserializeSession_UsesAIContextProviderFactory_IfProvidedAsync() - { - // Arrange - var mockChatClient = new Mock(); - var mockContextProvider = new Mock(); - var factoryCalled = false; - var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions - { - ChatOptions = new() { Instructions = "Test instructions" }, - AIContextProviderFactory = (_, _) => - { - factoryCalled = true; - return new ValueTask(mockContextProvider.Object); - } - }); - - var json = JsonSerializer.Deserialize(""" - { - "aiContextProviderState": ["CP1"] - } - """, TestJsonSerializerContext.Default.JsonElement); - - // Act - var session = await agent.DeserializeSessionAsync(json); - - // Assert - Assert.True(factoryCalled, "AIContextProviderFactory was not called."); - Assert.IsType(session); - var typedSession = (ChatClientAgentSession)session; - Assert.Same(mockContextProvider.Object, typedSession.AIContextProvider); - } - - [Fact] - public async Task DeserializeSession_UsesChatHistoryProviderFactory_IfProvidedAsync() - { - // Arrange - var mockChatClient = new Mock(); - var mockChatHistoryProvider = new Mock(); - var factoryCalled = false; - var agent = new ChatClientAgent(mockChatClient.Object, new ChatClientAgentOptions - { - ChatOptions = new() { Instructions = "Test instructions" }, - ChatHistoryProviderFactory = (_, _) => - { - factoryCalled = true; - return new ValueTask(mockChatHistoryProvider.Object); - } - }); - - var json = JsonSerializer.Deserialize(""" - { - "chatHistoryProviderState": { } - } - """, TestJsonSerializerContext.Default.JsonElement); - - // Act - var session = await agent.DeserializeSessionAsync(json); - - // Assert - Assert.True(factoryCalled, "ChatHistoryProviderFactory was not called."); - Assert.IsType(session); - var typedSession = (ChatClientAgentSession)session; - Assert.Same(mockChatHistoryProvider.Object, typedSession.ChatHistoryProvider); - } -} diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs index ec8dda3c45..602fe40e08 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs @@ -18,7 +18,6 @@ namespace Microsoft.Agents.AI.UnitTests.Data; public sealed class TextSearchProviderTests { private static readonly AIAgent s_mockAgent = new Mock().Object; - private static readonly AgentSession s_mockSession = new Mock().Object; private readonly Mock> _loggerMock; private readonly Mock _loggerFactoryMock; @@ -39,6 +38,28 @@ public TextSearchProviderTests() .Returns(true); } + [Fact] + public void StateKey_ReturnsDefaultKey_WhenNoOptionsProvided() + { + // Arrange & Act + var provider = new TextSearchProvider((_, _) => Task.FromResult>([])); + + // Assert + Assert.Equal("TextSearchProvider", provider.StateKey); + } + + [Fact] + public void StateKey_ReturnsCustomKey_WhenSetViaOptions() + { + // Arrange & Act + var provider = new TextSearchProvider( + (_, _) => Task.FromResult>([]), + new TextSearchProviderOptions { StateKey = "custom-key" }); + + // Assert + Assert.Equal("custom-key", provider.StateKey); + } + [Theory] [InlineData(null, null, true)] [InlineData("Custom context prompt", "Custom citations prompt", false)] @@ -64,15 +85,19 @@ public async Task InvokingAsync_ShouldInjectFormattedResultsAsync(string? overri ContextPrompt = overrideContextPrompt, CitationsPrompt = overrideCitationsPrompt }; - var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options, withLogging ? this._loggerFactoryMock.Object : null); + var provider = new TextSearchProvider(SearchDelegateAsync, options, withLogging ? this._loggerFactoryMock.Object : null); var invokingContext = new AIContextProvider.InvokingContext( s_mockAgent, - s_mockSession, - [ - new ChatMessage(ChatRole.User, "Sample user question?"), - new ChatMessage(ChatRole.User, "Additional part") - ]); + new TestAgentSession(), + new AIContext + { + Messages = new List + { + new(ChatRole.User, "Sample user question?"), + new(ChatRole.User, "Additional part") + } + }); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -81,9 +106,15 @@ public async Task InvokingAsync_ShouldInjectFormattedResultsAsync(string? overri Assert.Equal("Sample user question?\nAdditional part", capturedInput); Assert.Null(aiContext.Instructions); // TextSearchProvider uses a user message for context injection. Assert.NotNull(aiContext.Messages); - Assert.Single(aiContext.Messages!); - var message = aiContext.Messages!.Single(); + var messages = aiContext.Messages!.ToList(); + Assert.Equal(3, messages.Count); // 2 input messages + 1 search result message + Assert.Equal("Sample user question?", messages[0].Text); + Assert.Equal("Additional part", messages[1].Text); + Assert.Equal(AgentRequestMessageSourceType.External, messages[0].GetAgentRequestMessageSourceType()); + Assert.Equal(AgentRequestMessageSourceType.External, messages[1].GetAgentRequestMessageSourceType()); + var message = messages.Last(); Assert.Equal(ChatRole.User, message.Role); + Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, message.GetAgentRequestMessageSourceType()); string text = message.Text!; if (overrideContextPrompt is null) @@ -143,17 +174,21 @@ public async Task InvokingAsync_OnDemand_ShouldExposeSearchToolAsync(string? ove FunctionToolName = overrideName, FunctionToolDescription = overrideDescription }; - var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); + var provider = new TextSearchProvider(this.NoResultSearchAsync, options); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), new AIContext { Messages = new List { new(ChatRole.User, "Q?") } }); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); // Assert - Assert.Null(aiContext.Messages); // No automatic injection. + Assert.NotNull(aiContext.Messages); // Input messages are preserved. + var messages = aiContext.Messages!.ToList(); + Assert.Single(messages); + Assert.Equal("Q?", messages[0].Text); Assert.NotNull(aiContext.Tools); - Assert.Single(aiContext.Tools); - var tool = aiContext.Tools.Single(); + var tools = aiContext.Tools!.ToList(); + Assert.Single(tools); + var tool = tools[0]; Assert.Equal(expectedName, tool.Name); Assert.Equal(expectedDescription, tool.Description); } @@ -162,14 +197,17 @@ public async Task InvokingAsync_OnDemand_ShouldExposeSearchToolAsync(string? ove public async Task InvokingAsync_ShouldNotThrow_WhenSearchFailsAsync() { // Arrange - var provider = new TextSearchProvider(this.FailingSearchAsync, default, null, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); + var provider = new TextSearchProvider(this.FailingSearchAsync, loggerFactory: this._loggerFactoryMock.Object); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), new AIContext { Messages = new List { new(ChatRole.User, "Q?") } }); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); // Assert - Assert.Null(aiContext.Messages); + Assert.NotNull(aiContext.Messages); // Input messages are preserved on error. + var messages = aiContext.Messages!.ToList(); + Assert.Single(messages); + Assert.Equal("Q?", messages[0].Text); Assert.Null(aiContext.Tools); this._loggerMock.Verify( l => l.Log( @@ -203,7 +241,7 @@ public async Task SearchAsync_ShouldReturnFormattedResultsAsync(string? override ContextPrompt = overrideContextPrompt, CitationsPrompt = overrideCitationsPrompt }; - var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); + var provider = new TextSearchProvider(SearchDelegateAsync, options); // Act var formatted = await provider.SearchAsync("Sample user question?", CancellationToken.None); @@ -255,16 +293,18 @@ public async Task InvokingAsync_ShouldUseContextFormatterWhenProvidedAsync() SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, ContextFormatter = r => $"Custom formatted context with {r.Count} results." }; - var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); + var provider = new TextSearchProvider(SearchDelegateAsync, options); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), new AIContext { Messages = new List { new(ChatRole.User, "Q?") } }); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); // Assert Assert.NotNull(aiContext.Messages); - Assert.Single(aiContext.Messages!); - Assert.Equal("Custom formatted context with 2 results.", aiContext.Messages![0].Text); + var messages = aiContext.Messages!.ToList(); + Assert.Equal(2, messages.Count); // 1 input message + 1 formatted result message + Assert.Equal("Q?", messages[0].Text); + Assert.Equal("Custom formatted context with 2 results.", messages[1].Text); } [Fact] @@ -289,16 +329,18 @@ public async Task InvokingAsync_WithRawRepresentations_ContextFormatterCanAccess SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, ContextFormatter = r => string.Join(",", r.Select(x => ((RawPayload)x.RawRepresentation!).Id)) }; - var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); + var provider = new TextSearchProvider(SearchDelegateAsync, options); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), new AIContext { Messages = new List { new(ChatRole.User, "Q?") } }); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); // Assert Assert.NotNull(aiContext.Messages); - Assert.Single(aiContext.Messages!); - Assert.Equal("R1,R2", aiContext.Messages![0].Text); + var messages = aiContext.Messages!.ToList(); + Assert.Equal(2, messages.Count); // 1 input message + 1 formatted result message + Assert.Equal("Q?", messages[0].Text); + Assert.Equal("R1,R2", messages[1].Text); } [Fact] @@ -306,18 +348,155 @@ public async Task InvokingAsync_WithNoResults_ShouldReturnEmptyContextAsync() { // Arrange var options = new TextSearchProviderOptions { SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke }; - var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); + var provider = new TextSearchProvider(this.NoResultSearchAsync, options); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), new AIContext { Messages = new List { new(ChatRole.User, "Q?") } }); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); // Assert - Assert.Null(aiContext.Messages); + Assert.NotNull(aiContext.Messages); // Input messages are preserved when no results found. + var messages = aiContext.Messages!.ToList(); + Assert.Single(messages); + Assert.Equal("Q?", messages[0].Text); Assert.Null(aiContext.Instructions); Assert.Null(aiContext.Tools); } + #region Message Filter Tests + + [Fact] + public async Task InvokingAsync_DefaultFilter_ExcludesNonExternalMessagesFromSearchInputAsync() + { + // Arrange + string? capturedInput = null; + Task> SearchDelegateAsync(string input, CancellationToken ct) + { + capturedInput = input; + return Task.FromResult>([]); + } + + var provider = new TextSearchProvider(SearchDelegateAsync); + var requestMessages = new List + { + new(ChatRole.User, "External message"), + new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } }, + new(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } }, + }; + + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), new AIContext { Messages = requestMessages }); + + // Act + await provider.InvokingAsync(invokingContext, CancellationToken.None); + + // Assert - Only external messages should be used for search input + Assert.Equal("External message", capturedInput); + } + + [Fact] + public async Task InvokingAsync_CustomSearchInputFilter_OverridesDefaultAsync() + { + // Arrange + string? capturedInput = null; + Task> SearchDelegateAsync(string input, CancellationToken ct) + { + capturedInput = input; + return Task.FromResult>([]); + } + + var provider = new TextSearchProvider(SearchDelegateAsync, new TextSearchProviderOptions + { + SearchInputMessageFilter = messages => messages.Where(m => m.Role == ChatRole.System) + }); + var requestMessages = new List + { + new(ChatRole.User, "User message"), + new(ChatRole.System, "System message"), + }; + + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), new AIContext { Messages = requestMessages }); + + // Act + await provider.InvokingAsync(invokingContext, CancellationToken.None); + + // Assert - Custom filter keeps only System messages + Assert.Equal("System message", capturedInput); + } + + [Fact] + public async Task InvokedAsync_DefaultFilter_ExcludesNonExternalMessagesFromStorageAsync() + { + // Arrange + var options = new TextSearchProviderOptions + { + RecentMessageMemoryLimit = 10, + RecentMessageRolesIncluded = [ChatRole.User, ChatRole.System] + }; + string? capturedInput = null; + Task> SearchDelegateAsync(string input, CancellationToken ct) + { + capturedInput = input; + return Task.FromResult>([]); + } + var provider = new TextSearchProvider(SearchDelegateAsync, options); + var session = new TestAgentSession(); + + var requestMessages = new List + { + new(ChatRole.User, "External message"), + new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } }, + new(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } }, + }; + + // Store messages via InvokedAsync + await provider.InvokedAsync(new(s_mockAgent, session, requestMessages, [])); + + // Now invoke to read stored memory + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, session, new AIContext { Messages = [new ChatMessage(ChatRole.User, "Next")] }); + await provider.InvokingAsync(invokingContext, CancellationToken.None); + + // Assert - Only "External message" was stored in memory, so search input = "External message" + "Next" + Assert.Equal("External message\nNext", capturedInput); + } + + [Fact] + public async Task InvokedAsync_CustomStorageInputFilter_OverridesDefaultAsync() + { + // Arrange + var options = new TextSearchProviderOptions + { + RecentMessageMemoryLimit = 10, + RecentMessageRolesIncluded = [ChatRole.User, ChatRole.System], + StorageInputMessageFilter = messages => messages // No filtering - store everything + }; + string? capturedInput = null; + Task> SearchDelegateAsync(string input, CancellationToken ct) + { + capturedInput = input; + return Task.FromResult>([]); + } + var provider = new TextSearchProvider(SearchDelegateAsync, options); + var session = new TestAgentSession(); + + var requestMessages = new List + { + new(ChatRole.User, "External message"), + new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } }, + }; + + // Store messages via InvokedAsync + await provider.InvokedAsync(new(s_mockAgent, session, requestMessages, [])); + + // Now invoke to read stored memory + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, session, new AIContext { Messages = [new ChatMessage(ChatRole.User, "Next")] }); + await provider.InvokingAsync(invokingContext, CancellationToken.None); + + // Assert - Both messages stored (identity filter), so search input includes all + current + Assert.Equal("External message\nFrom history\nNext", capturedInput); + } + + #endregion + #region Recent Message Memory Tests [Fact] @@ -335,7 +514,7 @@ public async Task InvokingAsync_WithPreviousFailedRequest_ShouldNotIncludeFailed capturedInput = input; return Task.FromResult>([]); // No results needed. } - var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); + var provider = new TextSearchProvider(SearchDelegateAsync, options); // Populate memory with more messages than the limit (A,B,C,D) -> should retain B,C,D var initialMessages = new[] @@ -345,14 +524,14 @@ public async Task InvokingAsync_WithPreviousFailedRequest_ShouldNotIncludeFailed new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages) { InvokeException = new InvalidOperationException("Request Failed") }); + + var session = new TestAgentSession(); + await provider.InvokedAsync(new(s_mockAgent, session, initialMessages, new InvalidOperationException("Request Failed"))); var invokingContext = new AIContextProvider.InvokingContext( s_mockAgent, - s_mockSession, - [ - new ChatMessage(ChatRole.User, "E") - ]); + session, + new AIContext { Messages = new List { new(ChatRole.User, "E") } }); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -377,7 +556,8 @@ public async Task InvokingAsync_WithRecentMessageMemory_ShouldIncludeStoredMessa capturedInput = input; return Task.FromResult>([]); // No results needed. } - var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); + var provider = new TextSearchProvider(SearchDelegateAsync, options); + var session = new TestAgentSession(); // Populate memory with more messages than the limit (A,B,C,D) -> should retain B,C,D var initialMessages = new[] @@ -387,14 +567,12 @@ public async Task InvokingAsync_WithRecentMessageMemory_ShouldIncludeStoredMessa new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages)); + await provider.InvokedAsync(new(s_mockAgent, session, initialMessages, [])); var invokingContext = new AIContextProvider.InvokingContext( s_mockAgent, - s_mockSession, - [ - new ChatMessage(ChatRole.User, "E") - ]); + session, + new AIContext { Messages = new List { new(ChatRole.User, "E") } }); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -419,28 +597,31 @@ public async Task InvokingAsync_WithAccumulatedMemoryAcrossInvocations_ShouldInc capturedInput = input; return Task.FromResult>([]); } - var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); + var provider = new TextSearchProvider(SearchDelegateAsync, options); + var session = new TestAgentSession(); // First memory update (A,B) await provider.InvokedAsync(new( - s_mockAgent, - s_mockSession, - [ - new ChatMessage(ChatRole.User, "A"), - new ChatMessage(ChatRole.Assistant, "B"), - ])); + s_mockAgent, + session, + [ + new ChatMessage(ChatRole.User, "A"), + new ChatMessage(ChatRole.Assistant, "B"), + ], + [])); // Second memory update (C,D,E) await provider.InvokedAsync(new( - s_mockAgent, - s_mockSession, - [ - new ChatMessage(ChatRole.User, "C"), - new ChatMessage(ChatRole.Assistant, "D"), - new ChatMessage(ChatRole.User, "E"), - ])); + s_mockAgent, + session, + [ + new ChatMessage(ChatRole.User, "C"), + new ChatMessage(ChatRole.Assistant, "D"), + new ChatMessage(ChatRole.User, "E"), + ], + [])); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "F")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, session, new AIContext { Messages = new List { new(ChatRole.User, "F") } }); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -465,7 +646,8 @@ public async Task InvokingAsync_WithRecentMessageRolesIncluded_ShouldFilterRoles capturedInput = input; return Task.FromResult>([]); // No results needed for this test. } - var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); + var provider = new TextSearchProvider(SearchDelegateAsync, options); + var session = new TestAgentSession(); // Populate memory with mixed roles; only Assistant messages (A1,A2) should be retained. var initialMessages = new[] @@ -475,14 +657,12 @@ public async Task InvokingAsync_WithRecentMessageRolesIncluded_ShouldFilterRoles new ChatMessage(ChatRole.User, "U2"), new ChatMessage(ChatRole.Assistant, "A2"), }; - await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages)); + await provider.InvokedAsync(new(s_mockAgent, session, initialMessages, [])); var invokingContext = new AIContextProvider.InvokingContext( s_mockAgent, - s_mockSession, - [ - new ChatMessage(ChatRole.User, "Question?") // Current request message always appended. - ]); + session, + new AIContext { Messages = new List { new(ChatRole.User, "Question?") } }); // Current request message always appended. // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -496,26 +676,7 @@ public async Task InvokingAsync_WithRecentMessageRolesIncluded_ShouldFilterRoles #region Serialization Tests [Fact] - public void Serialize_WithNoRecentMessages_ShouldReturnEmptyState() - { - // Arrange - var options = new TextSearchProviderOptions - { - SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, - RecentMessageMemoryLimit = 3 - }; - var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options); - - // Act - var state = provider.Serialize(); - - // Assert - Assert.Equal(JsonValueKind.Object, state.ValueKind); - Assert.False(state.TryGetProperty("recentMessagesText", out _)); - } - - [Fact] - public async Task Serialize_WithRecentMessages_ShouldPersistMessagesUpToLimitAsync() + public async Task InvokedAsync_ShouldPersistMessagesToSessionStateBagAsync() { // Arrange var options = new TextSearchProviderOptions @@ -524,7 +685,8 @@ public async Task Serialize_WithRecentMessages_ShouldPersistMessagesUpToLimitAsy RecentMessageMemoryLimit = 3, RecentMessageRolesIncluded = [ChatRole.User, ChatRole.Assistant] }; - var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options); + var provider = new TextSearchProvider(this.NoResultSearchAsync, options); + var session = new TestAgentSession(); var messages = new[] { new ChatMessage(ChatRole.User, "M1"), @@ -533,11 +695,12 @@ public async Task Serialize_WithRecentMessages_ShouldPersistMessagesUpToLimitAsy }; // Act - await provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages)); // Populate recent memory. - var state = provider.Serialize(); + await provider.InvokedAsync(new(s_mockAgent, session, messages, [])); // Populate recent memory. - // Assert - Assert.True(state.TryGetProperty("recentMessagesText", out var recentProperty)); + // Assert - State should be in the session's StateBag + var stateBagSerialized = session.StateBag.Serialize(); + Assert.True(stateBagSerialized.TryGetProperty("TextSearchProvider", out var stateProperty)); + Assert.True(stateProperty.TryGetProperty("recentMessagesText", out var recentProperty)); Assert.Equal(JsonValueKind.Array, recentProperty.ValueKind); var list = recentProperty.EnumerateArray().Select(e => e.GetString()).ToList(); Assert.Equal(3, list.Count); @@ -545,7 +708,7 @@ public async Task Serialize_WithRecentMessages_ShouldPersistMessagesUpToLimitAsy } [Fact] - public async Task SerializeAndDeserialize_RoundtripRestoresMessagesAsync() + public async Task StateBag_RoundtripRestoresMessagesAsync() { // Arrange var options = new TextSearchProviderOptions @@ -554,7 +717,8 @@ public async Task SerializeAndDeserialize_RoundtripRestoresMessagesAsync() RecentMessageMemoryLimit = 4, RecentMessageRolesIncluded = [ChatRole.User, ChatRole.Assistant] }; - var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options); + var provider = new TextSearchProvider(this.NoResultSearchAsync, options); + var session = new TestAgentSession(); var messages = new[] { new ChatMessage(ChatRole.User, "A"), @@ -562,23 +726,24 @@ public async Task SerializeAndDeserialize_RoundtripRestoresMessagesAsync() new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages)); + await provider.InvokedAsync(new(s_mockAgent, session, messages, [])); + + // Act - Serialize and deserialize the StateBag + var serializedStateBag = session.StateBag.Serialize(); + var restoredSession = new TestAgentSession(AgentSessionStateBag.Deserialize(serializedStateBag)); - // Act - var state = provider.Serialize(); string? capturedInput = null; Task> SearchDelegate2Async(string input, CancellationToken ct) { capturedInput = input; return Task.FromResult>([]); } - var roundTrippedProvider = new TextSearchProvider(SearchDelegate2Async, state, options: new TextSearchProviderOptions + var newProvider = new TextSearchProvider(SearchDelegate2Async, new TextSearchProviderOptions { SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, RecentMessageMemoryLimit = 4 }); - var emptyMessages = Array.Empty(); - await roundTrippedProvider.InvokingAsync(new(s_mockAgent, s_mockSession, emptyMessages), CancellationToken.None); // Trigger search to read memory. + await newProvider.InvokingAsync(new(s_mockAgent, restoredSession, new AIContext()), CancellationToken.None); // Trigger search to read memory. // Assert Assert.NotNull(capturedInput); @@ -586,25 +751,10 @@ public async Task SerializeAndDeserialize_RoundtripRestoresMessagesAsync() } [Fact] - public async Task Deserialize_WithChangedLowerLimit_ShouldTruncateToNewLimitAsync() + public async Task InvokingAsync_WithEmptyStateBag_ShouldHaveNoMessagesAsync() { // Arrange - var initialProvider = new TextSearchProvider(this.NoResultSearchAsync, default, null, new TextSearchProviderOptions - { - SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, - RecentMessageMemoryLimit = 5, - RecentMessageRolesIncluded = [ChatRole.User, ChatRole.Assistant] - }); - var messages = new[] - { - new ChatMessage(ChatRole.User, "L1"), - new ChatMessage(ChatRole.Assistant, "L2"), - new ChatMessage(ChatRole.User, "L3"), - new ChatMessage(ChatRole.Assistant, "L4"), - new ChatMessage(ChatRole.User, "L5"), - }; - await initialProvider.InvokedAsync(new(s_mockAgent, s_mockSession, messages)); - var state = initialProvider.Serialize(); + var session = new TestAgentSession(); // Fresh session with empty StateBag string? capturedInput = null; Task> SearchDelegate2Async(string input, CancellationToken ct) @@ -614,43 +764,16 @@ public async Task Deserialize_WithChangedLowerLimit_ShouldTruncateToNewLimitAsyn } // Act - var restoredProvider = new TextSearchProvider(SearchDelegate2Async, state, options: new TextSearchProviderOptions - { - SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, - RecentMessageMemoryLimit = 3 // Lower limit - }); - await restoredProvider.InvokingAsync(new(s_mockAgent, s_mockSession, Array.Empty()), CancellationToken.None); - - // Assert - Assert.NotNull(capturedInput); - Assert.Equal("L1\nL2\nL3", capturedInput); - } - - [Fact] - public async Task Deserialize_WithEmptyState_ShouldHaveNoMessagesAsync() - { - // Arrange - var emptyState = JsonSerializer.Deserialize("{}", TestJsonSerializerContext.Default.JsonElement); - - string? capturedInput = null; - Task> SearchDelegate2Async(string input, CancellationToken ct) - { - capturedInput = input; - return Task.FromResult>([]); - } - - // Act - var provider = new TextSearchProvider(SearchDelegate2Async, emptyState, options: new TextSearchProviderOptions + var provider = new TextSearchProvider(SearchDelegate2Async, new TextSearchProviderOptions { SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, RecentMessageMemoryLimit = 3 }); - var emptyMessages = Array.Empty(); - await provider.InvokingAsync(new(s_mockAgent, s_mockSession, emptyMessages), CancellationToken.None); + await provider.InvokingAsync(new(s_mockAgent, session, new AIContext()), CancellationToken.None); // Assert Assert.NotNull(capturedInput); - Assert.Equal(string.Empty, capturedInput); // No recent messages serialized => empty input. + Assert.Equal(string.Empty, capturedInput); // No recent messages in StateBag => empty input. } #endregion @@ -669,4 +792,16 @@ private sealed class RawPayload { public string Id { get; set; } = string.Empty; } + + private sealed class TestAgentSession : AgentSession + { + public TestAgentSession() + { + } + + public TestAgentSession(AgentSessionStateBag stateBag) + { + this.StateBag = stateBag; + } + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs index 0a11e74528..be73260b15 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -19,7 +18,6 @@ namespace Microsoft.Agents.AI.Memory.UnitTests; public class ChatHistoryMemoryProviderTests { private static readonly AIAgent s_mockAgent = new Mock().Object; - private static readonly AgentSession s_mockSession = new Mock().Object; private readonly Mock> _loggerMock; private readonly Mock _loggerFactoryMock; @@ -57,33 +55,82 @@ public ChatHistoryMemoryProviderTests() .Returns(this._vectorStoreCollectionMock.Object); } + [Fact] + public void StateKey_ReturnsDefaultKey_WhenNoOptionsProvided() + { + // Arrange & Act + var provider = new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + TestCollectionName, + 1, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" })); + + // Assert + Assert.Equal("ChatHistoryMemoryProvider", provider.StateKey); + } + + [Fact] + public void StateKey_ReturnsCustomKey_WhenSetViaOptions() + { + // Arrange & Act + var provider = new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + TestCollectionName, + 1, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }), + new ChatHistoryMemoryProviderOptions { StateKey = "custom-key" }); + + // Assert + Assert.Equal("custom-key", provider.StateKey); + } + [Fact] public void Constructor_Throws_ForNullVectorStore() { // Act & Assert - Assert.Throws(() => new ChatHistoryMemoryProvider(null!, "testcollection", 1, new ChatHistoryMemoryProviderScope() { UserId = "UID" })); + Assert.Throws(() => new ChatHistoryMemoryProvider( + null!, + "testcollection", + 1, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }))); } [Fact] public void Constructor_Throws_ForNullCollectionName() { // Act & Assert - Assert.Throws(() => new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, null!, 1, new ChatHistoryMemoryProviderScope() { UserId = "UID" })); + Assert.Throws(() => new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + null!, + 1, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }))); } [Fact] - public void Constructor_Throws_ForNullStorageScope() + public void Constructor_Throws_ForNullStateInitializer() { // Act & Assert - Assert.Throws(() => new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, "testcollection", 1, null!)); + Assert.Throws(() => new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + "testcollection", + 1, + null!)); } [Fact] public void Constructor_Throws_ForInvalidVectorDimensions() { // Act & Assert - Assert.Throws(() => new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, "testcollection", 0, new ChatHistoryMemoryProviderScope() { UserId = "UID" })); - Assert.Throws(() => new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, "testcollection", -5, new ChatHistoryMemoryProviderScope() { UserId = "UID" })); + Assert.Throws(() => new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + "testcollection", + 0, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }))); + Assert.Throws(() => new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + "testcollection", + -5, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }))); } #region InvokedAsync Tests @@ -113,16 +160,17 @@ public async Task InvokedAsync_UpsertsMessages_ToCollectionAsync() UserId = "user1" }; - var provider = new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, TestCollectionName, 1, storeScope); + var provider = new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + TestCollectionName, + 1, + _ => new ChatHistoryMemoryProvider.State(storeScope)); var requestMsgWithValues = new ChatMessage(ChatRole.User, "request text") { MessageId = "req-1", AuthorName = "user1", CreatedAt = new DateTimeOffset(new DateTime(2000, 1, 1), TimeSpan.Zero) }; var requestMsgWithNulls = new ChatMessage(ChatRole.User, "request text nulls"); var responseMsg = new ChatMessage(ChatRole.Assistant, "response text") { MessageId = "resp-1", AuthorName = "assistant" }; - var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsgWithValues, requestMsgWithNulls]) - { - ResponseMessages = [responseMsg] - }; + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, new TestAgentSession(), [requestMsgWithValues, requestMsgWithNulls], [responseMsg]); // Act await provider.InvokedAsync(invokedContext, CancellationToken.None); @@ -175,12 +223,9 @@ public async Task InvokedAsync_DoesNotUpsertMessages_WhenInvokeFailedAsync() this._vectorStoreMock.Object, TestCollectionName, 1, - new ChatHistoryMemoryProviderScope() { UserId = "UID" }); + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" })); var requestMsg = new ChatMessage(ChatRole.User, "request text") { MessageId = "req-1" }; - var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg]) - { - InvokeException = new InvalidOperationException("Invoke failed") - }; + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, new TestAgentSession(), [requestMsg], new InvalidOperationException("Invoke failed")); // Act await provider.InvokedAsync(invokedContext, CancellationToken.None); @@ -203,10 +248,10 @@ public async Task InvokedAsync_DoesNotThrow_WhenUpsertThrowsAsync() this._vectorStoreMock.Object, TestCollectionName, 1, - new ChatHistoryMemoryProviderScope() { UserId = "UID" }, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }), loggerFactory: this._loggerFactoryMock.Object); var requestMsg = new ChatMessage(ChatRole.User, "request text") { MessageId = "req-1" }; - var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg]); + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, new TestAgentSession(), [requestMsg], []); // Act await provider.InvokedAsync(invokedContext, CancellationToken.None); @@ -252,12 +297,12 @@ public async Task InvokedAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsyn this._vectorStoreMock.Object, TestCollectionName, 1, - new ChatHistoryMemoryProviderScope { UserId = "user1" }, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "user1" }), options: options, loggerFactory: this._loggerFactoryMock.Object); var requestMsg = new ChatMessage(ChatRole.User, "request text"); - var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg]); + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, new TestAgentSession(), [requestMsg], []); // Act await provider.InvokedAsync(invokedContext, CancellationToken.None); @@ -326,14 +371,14 @@ public async Task InvokedAsync_SearchesVectorStoreAsync() this._vectorStoreMock.Object, TestCollectionName, 1, - new ChatHistoryMemoryProviderScope() { UserId = "UID" }, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }), options: providerOptions); var requestMsg = new ChatMessage(ChatRole.User, "requesting relevant history"); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [requestMsg]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), new AIContext { Messages = new List { requestMsg } }); // Act - await provider.InvokingAsync(invokingContext, CancellationToken.None); + var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); // Assert this._vectorStoreCollectionMock.Verify( @@ -343,6 +388,12 @@ public async Task InvokedAsync_SearchesVectorStoreAsync() It.IsAny>>(), It.IsAny()), Times.Once); + + Assert.NotNull(aiContext.Messages); + var messages = aiContext.Messages.ToList(); + Assert.Equal(2, messages.Count); + Assert.Equal(AgentRequestMessageSourceType.External, messages[0].GetAgentRequestMessageSourceType()); + Assert.Equal(AgentRequestMessageSourceType.AIContextProvider, messages[1].GetAgentRequestMessageSourceType()); } [Fact] @@ -378,10 +429,15 @@ public async Task InvokedAsync_CreatesFilter_WhenSearchScopeProvidedAsync() }) .Returns(ToAsyncEnumerableAsync(new List>>())); - var provider = new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, TestCollectionName, 1, options: providerOptions, storageScope: searchScope, searchScope: searchScope); + var provider = new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + TestCollectionName, + 1, + _ => new ChatHistoryMemoryProvider.State(searchScope, searchScope), + options: providerOptions); var requestMsg = new ChatMessage(ChatRole.User, "requesting relevant history"); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [requestMsg]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), new AIContext { Messages = new List { requestMsg } }); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -440,12 +496,11 @@ public async Task InvokingAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsy this._vectorStoreMock.Object, TestCollectionName, 1, - storageScope: scope, - searchScope: scope, + _ => new ChatHistoryMemoryProvider.State(scope, scope), options: options, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "requesting relevant history")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), new AIContext { Messages = new List { new(ChatRole.User, "requesting relevant history") } }); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -479,52 +534,178 @@ public async Task InvokingAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsy #endregion - #region Serialization Tests + #region Message Filter Tests [Fact] - public void Serialize_Deserialize_RoundtripsScopes() + public async Task InvokingAsync_DefaultFilter_ExcludesNonExternalMessagesFromSearchAsync() { // Arrange - var storageScope = new ChatHistoryMemoryProviderScope + var providerOptions = new ChatHistoryMemoryProviderOptions { - ApplicationId = "app", - AgentId = "agent", - SessionId = "session", - UserId = "user" + SearchTime = ChatHistoryMemoryProviderOptions.SearchBehavior.BeforeAIInvoke, }; - var searchScope = new ChatHistoryMemoryProviderScope + string? capturedQuery = null; + this._vectorStoreCollectionMock + .Setup(c => c.SearchAsync( + It.IsAny(), + It.IsAny(), + It.IsAny>>(), + It.IsAny())) + .Callback>, CancellationToken>((query, _, _, _) => capturedQuery = query) + .Returns(ToAsyncEnumerableAsync(new List>>())); + + var provider = new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + TestCollectionName, + 1, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }), + options: providerOptions); + + var requestMessages = new List + { + new(ChatRole.User, "External message"), + new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } }, + new(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } }, + }; + + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), new AIContext { Messages = requestMessages }); + + // Act + await provider.InvokingAsync(invokingContext, CancellationToken.None); + + // Assert - Only External message used for search query + Assert.Equal("External message", capturedQuery); + } + + [Fact] + public async Task InvokingAsync_CustomSearchInputFilter_OverridesDefaultAsync() + { + // Arrange + var providerOptions = new ChatHistoryMemoryProviderOptions + { + SearchTime = ChatHistoryMemoryProviderOptions.SearchBehavior.BeforeAIInvoke, + SearchInputMessageFilter = messages => messages // No filtering + }; + + string? capturedQuery = null; + this._vectorStoreCollectionMock + .Setup(c => c.SearchAsync( + It.IsAny(), + It.IsAny(), + It.IsAny>>(), + It.IsAny())) + .Callback>, CancellationToken>((query, _, _, _) => capturedQuery = query) + .Returns(ToAsyncEnumerableAsync(new List>>())); + + var provider = new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + TestCollectionName, + 1, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }), + options: providerOptions); + + var requestMessages = new List { - ApplicationId = "app2", - AgentId = "agent2", - SessionId = "session2", - UserId = "user2" + new(ChatRole.User, "External message"), + new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } }, }; - var provider = new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, TestCollectionName, 1, storageScope: storageScope, searchScope: searchScope); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, new TestAgentSession(), new AIContext { Messages = requestMessages }); // Act - var stateElement = provider.Serialize(); - - using JsonDocument doc = JsonDocument.Parse(stateElement.GetRawText()); - var storage = doc.RootElement.GetProperty("storageScope"); - Assert.Equal("app", storage.GetProperty("applicationId").GetString()); - Assert.Equal("agent", storage.GetProperty("agentId").GetString()); - Assert.Equal("session", storage.GetProperty("sessionId").GetString()); - Assert.Equal("user", storage.GetProperty("userId").GetString()); - - var search = doc.RootElement.GetProperty("searchScope"); - Assert.Equal("app2", search.GetProperty("applicationId").GetString()); - Assert.Equal("agent2", search.GetProperty("agentId").GetString()); - Assert.Equal("session2", search.GetProperty("sessionId").GetString()); - Assert.Equal("user2", search.GetProperty("userId").GetString()); - - // Act - deserialize and serialize again - var provider2 = new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, TestCollectionName, 1, serializedState: stateElement); - var stateElement2 = provider2.Serialize(); - - // Assert - roundtrip the state - Assert.Equal(stateElement.GetRawText(), stateElement2.GetRawText()); + await provider.InvokingAsync(invokingContext, CancellationToken.None); + + // Assert - Both messages should be included in search query (identity filter) + Assert.NotNull(capturedQuery); + Assert.Contains("External message", capturedQuery); + Assert.Contains("From history", capturedQuery); + } + + [Fact] + public async Task InvokedAsync_DefaultFilter_ExcludesNonExternalMessagesFromStorageAsync() + { + // Arrange + var stored = new List>(); + + this._vectorStoreCollectionMock + .Setup(c => c.UpsertAsync(It.IsAny>>(), It.IsAny())) + .Callback>, CancellationToken>((items, ct) => + { + if (items != null) + { + stored.AddRange(items); + } + }) + .Returns(Task.CompletedTask); + + var provider = new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + TestCollectionName, + 1, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" })); + + var requestMessages = new List + { + new(ChatRole.User, "External message"), + new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } }, + new(ChatRole.System, "From context provider") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.AIContextProvider, "ContextSource") } } }, + }; + + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, new TestAgentSession(), requestMessages, [new ChatMessage(ChatRole.Assistant, "Response")]); + + // Act + await provider.InvokedAsync(invokedContext, CancellationToken.None); + + // Assert - Only External message + response stored (ChatHistory and AIContextProvider excluded by default) + Assert.Equal(2, stored.Count); + Assert.Equal("External message", stored[0]["Content"]); + Assert.Equal("Response", stored[1]["Content"]); + } + + [Fact] + public async Task InvokedAsync_CustomStorageInputFilter_OverridesDefaultAsync() + { + // Arrange + var stored = new List>(); + + this._vectorStoreCollectionMock + .Setup(c => c.UpsertAsync(It.IsAny>>(), It.IsAny())) + .Callback>, CancellationToken>((items, ct) => + { + if (items != null) + { + stored.AddRange(items); + } + }) + .Returns(Task.CompletedTask); + + var provider = new ChatHistoryMemoryProvider( + this._vectorStoreMock.Object, + TestCollectionName, + 1, + _ => new ChatHistoryMemoryProvider.State(new ChatHistoryMemoryProviderScope { UserId = "UID" }), + options: new ChatHistoryMemoryProviderOptions + { + StorageInputMessageFilter = messages => messages // No filtering - store everything + }); + + var requestMessages = new List + { + new(ChatRole.User, "External message"), + new(ChatRole.System, "From history") { AdditionalProperties = new() { { AgentRequestMessageSourceAttribution.AdditionalPropertiesKey, new AgentRequestMessageSourceAttribution(AgentRequestMessageSourceType.ChatHistory, "HistorySource") } } }, + }; + + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, new TestAgentSession(), requestMessages, [new ChatMessage(ChatRole.Assistant, "Response")]); + + // Act + await provider.InvokedAsync(invokedContext, CancellationToken.None); + + // Assert - All messages stored (identity filter overrides default) + Assert.Equal(3, stored.Count); + Assert.Equal("External message", stored[0]["Content"]); + Assert.Equal("From history", stored[1]["Content"]); + Assert.Equal("Response", stored[2]["Content"]); } #endregion @@ -537,4 +718,16 @@ private static async IAsyncEnumerable ToAsyncEnumerableAsync(IEnumerable))] [JsonSerializable(typeof(ChatClientAgentSessionTests.Animal))] +[JsonSerializable(typeof(ChatClientAgentSession))] internal sealed partial class TestJsonSerializerContext : JsonSerializerContext; diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs index 3dfe605f2e..eb017f07a3 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/AgentWorkflowBuilderTests.cs @@ -161,7 +161,7 @@ protected override async IAsyncEnumerable RunCoreStreamingA } } - private sealed class DoubleEchoAgentSession() : InMemoryAgentSession(); + private sealed class DoubleEchoAgentSession() : AgentSession(); [Fact] public async Task BuildConcurrent_AgentsRunInParallelAsync() diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessExecutionTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessExecutionTests.cs index dc51338aa3..e45cc39d2f 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessExecutionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/InProcessExecutionTests.cs @@ -195,5 +195,5 @@ protected override async IAsyncEnumerable RunCoreStreamingA /// /// Simple session implementation for SimpleTestAgent. /// - private sealed class SimpleTestAgentSession : InMemoryAgentSession; + private sealed class SimpleTestAgentSession : AgentSession; } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RoleCheckAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RoleCheckAgent.cs index dde1d1feed..385c427b37 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RoleCheckAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/RoleCheckAgent.cs @@ -46,5 +46,5 @@ protected override async IAsyncEnumerable RunCoreStreamingA }; } - private sealed class RoleCheckAgentSession : InMemoryAgentSession; + private sealed class RoleCheckAgentSession : AgentSession; } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs index afade362a1..1e00e2e649 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/Sample/06_GroupChat_Workflow.cs @@ -90,4 +90,4 @@ protected override async IAsyncEnumerable RunCoreStreamingA } } -internal sealed class HelloAgentSession() : InMemoryAgentSession(); +internal sealed class HelloAgentSession() : AgentSession(); diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs index 9d5eca42bf..b1c7c8cbe3 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestEchoAgent.cs @@ -5,6 +5,7 @@ using System.Linq; using System.Runtime.CompilerServices; using System.Text.Json; +using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -16,6 +17,8 @@ internal class TestEchoAgent(string? id = null, string? name = null, string? pre protected override string? IdCore => id; public override string? Name => name ?? base.Name; + public InMemoryChatHistoryProvider ChatHistoryProvider { get; } = new(); + protected override async ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) { return serializedState.Deserialize(jsonSerializerOptions) ?? await this.CreateSessionAsync(cancellationToken); @@ -28,15 +31,15 @@ protected override ValueTask SerializeSessionCoreAsync(AgentSession throw new InvalidOperationException("The provided session is not compatible with the agent. Only sessions created by the agent can be serialized."); } - return new(typedSession.Serialize(jsonSerializerOptions)); + return new(JsonSerializer.SerializeToElement(typedSession, jsonSerializerOptions)); } protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) => new(new EchoAgentSession()); - private static ChatMessage UpdateSession(ChatMessage message, InMemoryAgentSession? session = null) + private ChatMessage UpdateSession(ChatMessage message, AgentSession? session = null) { - session?.ChatHistoryProvider.Add(message); + this.ChatHistoryProvider.GetMessages(session).Add(message); return message; } @@ -45,7 +48,7 @@ private IEnumerable EchoMessages(IEnumerable messages, { foreach (ChatMessage message in messages) { - UpdateSession(message, session as InMemoryAgentSession); + this.UpdateSession(message, session); } IEnumerable echoMessages @@ -53,14 +56,14 @@ IEnumerable echoMessages where message.Role == ChatRole.User && !string.IsNullOrEmpty(message.Text) select - UpdateSession(new ChatMessage(ChatRole.Assistant, $"{prefix}{message.Text}") + this.UpdateSession(new ChatMessage(ChatRole.Assistant, $"{prefix}{message.Text}") { AuthorName = this.Name ?? this.Id, CreatedAt = DateTimeOffset.Now, MessageId = Guid.NewGuid().ToString("N") - }, session as InMemoryAgentSession); + }, session); - return echoMessages.Concat(this.GetEpilogueMessages(options).Select(m => UpdateSession(m, session as InMemoryAgentSession))); + return echoMessages.Concat(this.GetEpilogueMessages(options).Select(m => this.UpdateSession(m, session))); } protected virtual IEnumerable GetEpilogueMessages(AgentRunOptions? options = null) @@ -99,11 +102,11 @@ protected override async IAsyncEnumerable RunCoreStreamingA } } - private sealed class EchoAgentSession : InMemoryAgentSession + private sealed class EchoAgentSession : AgentSession { - internal new JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) - { - return base.Serialize(jsonSerializerOptions); - } + internal EchoAgentSession() { } + + [JsonConstructor] + internal EchoAgentSession(AgentSessionStateBag stateBag) : base(stateBag) { } } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestReplayAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestReplayAgent.cs index 2dd33a67e4..3a117b9492 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestReplayAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestReplayAgent.cs @@ -104,5 +104,5 @@ protected override async IAsyncEnumerable RunCoreStreamingA return candidateMessages; } - private sealed class ReplayAgentSession() : InMemoryAgentSession(); + private sealed class ReplayAgentSession() : AgentSession(); } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRequestAgent.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRequestAgent.cs index 65a49add96..0b49f4de05 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRequestAgent.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/TestRequestAgent.cs @@ -330,7 +330,7 @@ static TRequest AssertAndExtractRequestContent(ExternalRequest request } } - private sealed class TestRequestAgentSession : InMemoryAgentSession + private sealed class TestRequestAgentSession : AgentSession where TRequest : AIContent where TResponse : AIContent { @@ -343,19 +343,13 @@ public TestRequestAgentSession() public HashSet ServicedRequests { get; } = new(); public HashSet PairedRequests { get; } = new(); - private static JsonElement DeserializeAndExtractState(JsonElement serializedState, - out TestRequestAgentSessionState state, - JsonSerializerOptions? jsonSerializerOptions = null) + public TestRequestAgentSession(JsonElement element, JsonSerializerOptions? jsonSerializerOptions = null) { - state = JsonSerializer.Deserialize(serializedState, jsonSerializerOptions) + var state = JsonSerializer.Deserialize(element, jsonSerializerOptions) ?? throw new ArgumentException("Unable to deserialize session state."); - return state.SessionState; - } + this.StateBag = AgentSessionStateBag.Deserialize(state.SessionState); - public TestRequestAgentSession(JsonElement element, JsonSerializerOptions? jsonSerializerOptions = null) - : base(DeserializeAndExtractState(element, out TestRequestAgentSessionState state, jsonSerializerOptions)) - { this.UnservicedRequests = state.UnservicedRequests.ToDictionary( keySelector: item => item.Key, elementSelector: item => item.Value.As()!); @@ -364,9 +358,9 @@ public TestRequestAgentSession(JsonElement element, JsonSerializerOptions? jsonS this.PairedRequests = state.PairedRequests; } - protected override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) + internal JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) { - JsonElement sessionState = base.Serialize(jsonSerializerOptions); + JsonElement sessionState = this.StateBag.Serialize(); Dictionary portableUnservicedRequests = this.UnservicedRequests.ToDictionary( diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowHostSmokeTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowHostSmokeTests.cs index 5a041699d1..a68a5bac75 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowHostSmokeTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/WorkflowHostSmokeTests.cs @@ -32,18 +32,16 @@ public class WorkflowHostSmokeTests { private sealed class AlwaysFailsAIAgent(bool failByThrowing) : AIAgent { - private sealed class Session : InMemoryAgentSession + private sealed class Session : AgentSession { public Session() { } - public Session(JsonElement serializedSession, JsonSerializerOptions? jsonSerializerOptions = null) - : base(serializedSession, jsonSerializerOptions) - { } + public Session(AgentSessionStateBag stateBag) : base(stateBag) { } } protected override ValueTask DeserializeSessionCoreAsync(JsonElement serializedState, JsonSerializerOptions? jsonSerializerOptions = null, CancellationToken cancellationToken = default) { - return new(new Session(serializedState, jsonSerializerOptions)); + return new(serializedState.Deserialize(jsonSerializerOptions)!); } protected override ValueTask CreateSessionCoreAsync(CancellationToken cancellationToken = default) diff --git a/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs b/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs index 304df28fba..96a9a17ae8 100644 --- a/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs +++ b/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs @@ -30,14 +30,14 @@ public OpenAIChatCompletionFixture(bool useReasoningChatModel) public async Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) { - var typedSession = (ChatClientAgentSession)session; + var chatHistoryProvider = agent.GetService(); - if (typedSession.ChatHistoryProvider is null) + if (chatHistoryProvider is null) { return []; } - return (await typedSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); + return (await chatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); } public Task CreateChatClientAgentAsync( diff --git a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs index 719db6a0b0..e36f8990f6 100644 --- a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs +++ b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs @@ -50,12 +50,14 @@ public async Task> GetChatHistoryAsync(AIAgent agent, AgentSes return [.. previousMessages, responseMessage]; } - if (typedSession.ChatHistoryProvider is null) + var chatHistoryProvider = agent.GetService(); + + if (chatHistoryProvider is null) { return []; } - return (await typedSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); + return (await chatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); } private static ChatMessage ConvertToChatMessage(ResponseItem item)