diff --git a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs index 5d4e77474a..82f76e7599 100644 --- a/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs +++ b/dotnet/samples/GettingStarted/AgentProviders/Agent_With_CustomImplementation/Program.cs @@ -55,14 +55,14 @@ protected override async Task RunCoreAsync(IEnumerable responseMessages = CloneAndToUpperCase(messages, this.Name).ToList(); // Notify the session of the input and output messages. - var invokedContext = new ChatHistoryProvider.InvokedContext(messages, storeMessages) + var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, messages, storeMessages) { ResponseMessages = responseMessages }; @@ -87,14 +87,14 @@ protected override async IAsyncEnumerable RunCoreStreamingA } // Get existing messages from the store - var invokingContext = new ChatHistoryProvider.InvokingContext(messages); + var invokingContext = new ChatHistoryProvider.InvokingContext(this, session, messages); var storeMessages = await typedSession.ChatHistoryProvider.InvokingAsync(invokingContext, cancellationToken); // Clone the input messages and turn them into response messages with upper case text. List responseMessages = CloneAndToUpperCase(messages, this.Name).ToList(); // Notify the session of the input and output messages. - var invokedContext = new ChatHistoryProvider.InvokedContext(messages, storeMessages) + var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, messages, storeMessages) { ResponseMessages = responseMessages }; diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs index f104f12890..f79b0a851d 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs @@ -129,13 +129,30 @@ public sealed class InvokingContext /// /// Initializes a new instance of the class with the specified request messages. /// + /// The agent being invoked. + /// The session associated with the agent invocation. /// The messages to be used by the agent for this invocation. /// is . - public InvokingContext(IEnumerable requestMessages) + public InvokingContext( + AIAgent agent, + AgentSession? session, + IEnumerable requestMessages) { - this.RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages)); + this.Agent = Throw.IfNull(agent); + this.Session = session; + this.RequestMessages = Throw.IfNull(requestMessages); } + /// + /// Gets the agent that is being invoked. + /// + public AIAgent Agent { get; } + + /// + /// Gets the agent session associated with the agent invocation. + /// + public AgentSession? Session { get; } + /// /// Gets the caller provided messages that will be used by the agent for this invocation. /// @@ -158,15 +175,33 @@ public sealed class InvokedContext /// /// Initializes a new instance of the class with the specified request messages. /// + /// The agent being invoked. + /// The session associated with the agent invocation. /// The caller provided messages that were used by the agent for this invocation. /// The messages provided by the for this invocation, if any. /// is . - public InvokedContext(IEnumerable requestMessages, IEnumerable? aiContextProviderMessages) + public InvokedContext( + AIAgent agent, + AgentSession? session, + IEnumerable requestMessages, + IEnumerable? aiContextProviderMessages) { - this.RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages)); + this.Agent = Throw.IfNull(agent); + this.Session = session; + this.RequestMessages = Throw.IfNull(requestMessages); this.AIContextProviderMessages = aiContextProviderMessages; } + /// + /// Gets the agent that is being invoked. + /// + public AIAgent Agent { get; } + + /// + /// Gets the agent session associated with the agent invocation. + /// + public AgentSession? Session { get; } + /// /// Gets the caller provided messages that were used by the agent for this invocation. /// diff --git a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs index d809582ea4..352bae3355 100644 --- a/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs +++ b/dotnet/src/Microsoft.Agents.AI.Abstractions/ChatHistoryProvider.cs @@ -143,13 +143,30 @@ public sealed class InvokingContext /// /// Initializes a new instance of the class with the specified request messages. /// + /// The agent being invoked. + /// The session associated with the agent invocation. /// The new messages to be used by the agent for this invocation. /// is . - public InvokingContext(IEnumerable requestMessages) + public InvokingContext( + AIAgent agent, + AgentSession? session, + IEnumerable requestMessages) { - this.RequestMessages = requestMessages ?? throw new ArgumentNullException(nameof(requestMessages)); + this.Agent = Throw.IfNull(agent); + this.Session = session; + this.RequestMessages = Throw.IfNull(requestMessages); } + /// + /// Gets the agent that is being invoked. + /// + public AIAgent Agent { get; } + + /// + /// Gets the agent session associated with the agent invocation. + /// + public AgentSession? Session { get; } + /// /// Gets the caller provided messages that will be used by the agent for this invocation. /// @@ -172,15 +189,33 @@ public sealed class InvokedContext /// /// Initializes a new instance of the class with the specified request messages. /// + /// The agent being invoked. + /// The session associated with the agent invocation. /// The caller provided messages that were used by the agent for this invocation. /// The messages retrieved from the for this invocation. /// is . - public InvokedContext(IEnumerable requestMessages, IEnumerable? chatHistoryProviderMessages) + public InvokedContext( + AIAgent agent, + AgentSession? session, + IEnumerable requestMessages, + IEnumerable? chatHistoryProviderMessages) { + this.Agent = Throw.IfNull(agent); + this.Session = session; this.RequestMessages = Throw.IfNull(requestMessages); this.ChatHistoryProviderMessages = chatHistoryProviderMessages; } + /// + /// Gets the agent that is being invoked. + /// + public AIAgent Agent { get; } + + /// + /// Gets the agent session associated with the agent invocation. + /// + public AgentSession? Session { get; } + /// /// Gets the caller provided messages that were used by the agent for this invocation. /// diff --git a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs index ee6db4830d..23e120f14f 100644 --- a/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs @@ -231,8 +231,8 @@ protected override async IAsyncEnumerable RunCoreStreamingA } catch (Exception ex) { - await NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false); - await NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); + await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false); + await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); throw; } @@ -246,8 +246,8 @@ protected override async IAsyncEnumerable RunCoreStreamingA } catch (Exception ex) { - await NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false); - await NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); + await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false); + await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); throw; } @@ -273,8 +273,8 @@ protected override async IAsyncEnumerable RunCoreStreamingA } catch (Exception ex) { - await NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false); - await NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); + await this.NotifyChatHistoryProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatOptions, cancellationToken).ConfigureAwait(false); + await this.NotifyAIContextProviderOfFailureAsync(safeSession, ex, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, cancellationToken).ConfigureAwait(false); throw; } } @@ -286,10 +286,10 @@ protected override async IAsyncEnumerable RunCoreStreamingA await this.UpdateSessionWithTypeAndConversationIdAsync(safeSession, chatResponse.ConversationId, cancellationToken).ConfigureAwait(false); // To avoid inconsistent state we only notify the session of the input messages if no error occurs after the initial request. - await NotifyChatHistoryProviderOfNewMessagesAsync(safeSession, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false); + await this.NotifyChatHistoryProviderOfNewMessagesAsync(safeSession, GetInputMessages(inputMessages, continuationToken), chatHistoryProviderMessages, aiContextProviderMessages, chatResponse.Messages, chatOptions, cancellationToken).ConfigureAwait(false); // Notify the AIContextProvider of all new messages. - await NotifyAIContextProviderOfSuccessAsync(safeSession, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false); + await this.NotifyAIContextProviderOfSuccessAsync(safeSession, GetInputMessages(inputMessages, continuationToken), aiContextProviderMessages, chatResponse.Messages, cancellationToken).ConfigureAwait(false); } /// @@ -455,8 +455,8 @@ private async Task RunCoreAsync RunCoreAsync RunCoreAsync /// Notify the when an agent run succeeded, if there is an . /// - private static async Task NotifyAIContextProviderOfSuccessAsync( + private async Task NotifyAIContextProviderOfSuccessAsync( ChatClientAgentSession session, IEnumerable inputMessages, IList? aiContextProviderMessages, @@ -497,7 +497,7 @@ private static async Task NotifyAIContextProviderOfSuccessAsync( { if (session.AIContextProvider is not null) { - await session.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProviderMessages) { ResponseMessages = responseMessages }, + await session.AIContextProvider.InvokedAsync(new(this, session, inputMessages, aiContextProviderMessages) { ResponseMessages = responseMessages }, cancellationToken).ConfigureAwait(false); } } @@ -505,7 +505,7 @@ await session.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProvide /// /// Notify the of any failure during an agent run, if there is an . /// - private static async Task NotifyAIContextProviderOfFailureAsync( + private async Task NotifyAIContextProviderOfFailureAsync( ChatClientAgentSession session, Exception ex, IEnumerable inputMessages, @@ -514,7 +514,7 @@ private static async Task NotifyAIContextProviderOfFailureAsync( { if (session.AIContextProvider is not null) { - await session.AIContextProvider.InvokedAsync(new(inputMessages, aiContextProviderMessages) { InvokeException = ex }, + await session.AIContextProvider.InvokedAsync(new(this, session, inputMessages, aiContextProviderMessages) { InvokeException = ex }, cancellationToken).ConfigureAwait(false); } } @@ -726,7 +726,7 @@ private async Task // Add any existing messages from the session to the messages to be sent to the chat client. if (chatHistoryProvider is not null) { - var invokingContext = new ChatHistoryProvider.InvokingContext(inputMessages); + var invokingContext = new ChatHistoryProvider.InvokingContext(this, typedSession, inputMessages); var providerMessages = await chatHistoryProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false); inputMessagesForChatClient.AddRange(providerMessages); chatHistoryProviderMessages = providerMessages as IList ?? providerMessages.ToList(); @@ -739,7 +739,7 @@ private async Task // messages and options with the additional context. if (typedSession.AIContextProvider is not null) { - var invokingContext = new AIContextProvider.InvokingContext(inputMessages); + var invokingContext = new AIContextProvider.InvokingContext(this, typedSession, inputMessages); var aiContext = await typedSession.AIContextProvider.InvokingAsync(invokingContext, cancellationToken).ConfigureAwait(false); if (aiContext.Messages is { Count: > 0 }) { @@ -812,7 +812,7 @@ private async Task UpdateSessionWithTypeAndConversationIdAsync(ChatClientAgentSe } } - private static Task NotifyChatHistoryProviderOfFailureAsync( + private Task NotifyChatHistoryProviderOfFailureAsync( ChatClientAgentSession session, Exception ex, IEnumerable requestMessages, @@ -827,7 +827,7 @@ private static Task NotifyChatHistoryProviderOfFailureAsync( // If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages. if (provider is not null) { - var invokedContext = new ChatHistoryProvider.InvokedContext(requestMessages, chatHistoryProviderMessages!) + var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, requestMessages, chatHistoryProviderMessages!) { AIContextProviderMessages = aiContextProviderMessages, InvokeException = ex @@ -839,7 +839,7 @@ private static Task NotifyChatHistoryProviderOfFailureAsync( return Task.CompletedTask; } - private static Task NotifyChatHistoryProviderOfNewMessagesAsync( + private Task NotifyChatHistoryProviderOfNewMessagesAsync( ChatClientAgentSession session, IEnumerable requestMessages, IEnumerable? chatHistoryProviderMessages, @@ -854,7 +854,7 @@ private static Task NotifyChatHistoryProviderOfNewMessagesAsync( // If we don't have one, it means that the chat history is service managed and the underlying service is responsible for storing messages. if (provider is not null) { - var invokedContext = new ChatHistoryProvider.InvokedContext(requestMessages, chatHistoryProviderMessages!) + var invokedContext = new ChatHistoryProvider.InvokedContext(this, session, requestMessages, chatHistoryProviderMessages!) { AIContextProviderMessages = aiContextProviderMessages, ResponseMessages = responseMessages diff --git a/dotnet/tests/AgentConformance.IntegrationTests/IAgentFixture.cs b/dotnet/tests/AgentConformance.IntegrationTests/IAgentFixture.cs index 96b40d561b..5548c5aaf9 100644 --- a/dotnet/tests/AgentConformance.IntegrationTests/IAgentFixture.cs +++ b/dotnet/tests/AgentConformance.IntegrationTests/IAgentFixture.cs @@ -15,7 +15,7 @@ public interface IAgentFixture : IAsyncLifetime { AIAgent Agent { get; } - Task> GetChatHistoryAsync(AgentSession session); + Task> GetChatHistoryAsync(AIAgent agent, AgentSession session); Task DeleteSessionAsync(AgentSession session); } diff --git a/dotnet/tests/AgentConformance.IntegrationTests/RunStreamingTests.cs b/dotnet/tests/AgentConformance.IntegrationTests/RunStreamingTests.cs index f9cc732175..18982baaad 100644 --- a/dotnet/tests/AgentConformance.IntegrationTests/RunStreamingTests.cs +++ b/dotnet/tests/AgentConformance.IntegrationTests/RunStreamingTests.cs @@ -106,7 +106,7 @@ public virtual async Task SessionMaintainsHistoryAsync() Assert.Contains("Paris", response1Text); Assert.Contains("Vienna", response2Text); - var chatHistory = await this.Fixture.GetChatHistoryAsync(session); + var chatHistory = await this.Fixture.GetChatHistoryAsync(agent, session); Assert.Equal(4, chatHistory.Count); Assert.Equal(2, chatHistory.Count(x => x.Role == ChatRole.User)); Assert.Equal(2, chatHistory.Count(x => x.Role == ChatRole.Assistant)); diff --git a/dotnet/tests/AgentConformance.IntegrationTests/RunTests.cs b/dotnet/tests/AgentConformance.IntegrationTests/RunTests.cs index 302784a1a8..da1cebaf52 100644 --- a/dotnet/tests/AgentConformance.IntegrationTests/RunTests.cs +++ b/dotnet/tests/AgentConformance.IntegrationTests/RunTests.cs @@ -111,7 +111,7 @@ public virtual async Task SessionMaintainsHistoryAsync() Assert.Contains("Paris", result1.Text); Assert.Contains("Vienna", result2.Text); - var chatHistory = await this.Fixture.GetChatHistoryAsync(session); + var chatHistory = await this.Fixture.GetChatHistoryAsync(agent, session); Assert.Equal(4, chatHistory.Count); Assert.Equal(2, chatHistory.Count(x => x.Role == ChatRole.User)); Assert.Equal(2, chatHistory.Count(x => x.Role == ChatRole.Assistant)); diff --git a/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs b/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs index 5f0fcbca2c..a0e4b64763 100644 --- a/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs +++ b/dotnet/tests/AnthropicChatCompletion.IntegrationTests/AnthropicChatCompletionFixture.cs @@ -35,7 +35,7 @@ public AnthropicChatCompletionFixture(bool useReasoningChatModel, bool useBeta) public IChatClient ChatClient => this._agent.ChatClient; - public async Task> GetChatHistoryAsync(AgentSession session) + public async Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) { var typedSession = (ChatClientAgentSession)session; @@ -44,7 +44,7 @@ public async Task> GetChatHistoryAsync(AgentSession session) return []; } - return (await typedSession.ChatHistoryProvider.InvokingAsync(new([]))).ToList(); + return (await typedSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); } public Task CreateChatClientAgentAsync( diff --git a/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs b/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs index 4b78d30f1c..c655fd7a58 100644 --- a/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs +++ b/dotnet/tests/AzureAI.IntegrationTests/AIProjectClientFixture.cs @@ -33,7 +33,7 @@ public async Task CreateConversationAsync() return response.Value.Id; } - public async Task> GetChatHistoryAsync(AgentSession session) + public async Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) { var chatClientSession = (ChatClientAgentSession)session; @@ -53,7 +53,7 @@ public async Task> GetChatHistoryAsync(AgentSession session) return []; } - return (await chatClientSession.ChatHistoryProvider.InvokingAsync(new([]))).ToList(); + return (await chatClientSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); } private async Task> GetChatHistoryFromResponsesChainAsync(string conversationId) diff --git a/dotnet/tests/AzureAIAgentsPersistent.IntegrationTests/AzureAIAgentsPersistentFixture.cs b/dotnet/tests/AzureAIAgentsPersistent.IntegrationTests/AzureAIAgentsPersistentFixture.cs index 2f59630c38..3e3272d951 100644 --- a/dotnet/tests/AzureAIAgentsPersistent.IntegrationTests/AzureAIAgentsPersistentFixture.cs +++ b/dotnet/tests/AzureAIAgentsPersistent.IntegrationTests/AzureAIAgentsPersistentFixture.cs @@ -24,7 +24,7 @@ public class AzureAIAgentsPersistentFixture : IChatClientAgentFixture public AIAgent Agent => this._agent; - public async Task> GetChatHistoryAsync(AgentSession session) + public async Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) { List messages = []; var typedSession = (ChatClientAgentSession)session; diff --git a/dotnet/tests/CopilotStudio.IntegrationTests/CopilotStudioFixture.cs b/dotnet/tests/CopilotStudio.IntegrationTests/CopilotStudioFixture.cs index 3b3ac7ff7e..dd5fe46ecc 100644 --- a/dotnet/tests/CopilotStudio.IntegrationTests/CopilotStudioFixture.cs +++ b/dotnet/tests/CopilotStudio.IntegrationTests/CopilotStudioFixture.cs @@ -20,7 +20,7 @@ public class CopilotStudioFixture : IAgentFixture { public AIAgent Agent { get; private set; } = null!; - public Task> GetChatHistoryAsync(AgentSession session) => + public Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) => throw new NotSupportedException("CopilotStudio doesn't allow retrieval of chat history."); public Task DeleteSessionAsync(AgentSession session) => diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs index b6aabd081e..94aa73858a 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/AIContextProviderTests.cs @@ -6,17 +6,21 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; +using Moq; namespace Microsoft.Agents.AI.Abstractions.UnitTests; public class AIContextProviderTests { + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + [Fact] public async Task InvokedAsync_ReturnsCompletedTaskAsync() { var provider = new TestAIContextProvider(); var messages = new ReadOnlyCollection([]); - var task = provider.InvokedAsync(new(messages, aiContextProviderMessages: null)); + var task = provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null)); Assert.Equal(default, task); } @@ -31,13 +35,13 @@ public void Serialize_ReturnsEmptyElement() [Fact] public void InvokingContext_Constructor_ThrowsForNullMessages() { - Assert.Throws(() => new AIContextProvider.InvokingContext(null!)); + Assert.Throws(() => new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, null!)); } [Fact] public void InvokedContext_Constructor_ThrowsForNullMessages() { - Assert.Throws(() => new AIContextProvider.InvokedContext(null!, aiContextProviderMessages: null)); + Assert.Throws(() => new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, null!, aiContextProviderMessages: null)); } #region GetService Method Tests @@ -163,7 +167,7 @@ public void InvokingContext_RequestMessages_SetterThrowsForNull() { // Arrange var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); - var context = new AIContextProvider.InvokingContext(messages); + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, messages); // Act & Assert Assert.Throws(() => context.RequestMessages = null!); @@ -175,7 +179,7 @@ public void InvokingContext_RequestMessages_SetterRoundtrips() // Arrange var initialMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); var newMessages = new List { new(ChatRole.User, "New message") }; - var context = new AIContextProvider.InvokingContext(initialMessages); + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, initialMessages); // Act context.RequestMessages = newMessages; @@ -184,6 +188,55 @@ public void InvokingContext_RequestMessages_SetterRoundtrips() Assert.Same(newMessages, context.RequestMessages); } + [Fact] + public void InvokingContext_Agent_ReturnsConstructorValue() + { + // Arrange + var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + + // Act + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, messages); + + // Assert + Assert.Same(s_mockAgent, context.Agent); + } + + [Fact] + public void InvokingContext_Session_ReturnsConstructorValue() + { + // Arrange + var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + + // Act + var context = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, messages); + + // Assert + Assert.Same(s_mockSession, context.Session); + } + + [Fact] + public void InvokingContext_Session_CanBeNull() + { + // Arrange + var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + + // Act + var context = new AIContextProvider.InvokingContext(s_mockAgent, null, messages); + + // Assert + Assert.Null(context.Session); + } + + [Fact] + public void InvokingContext_Constructor_ThrowsForNullAgent() + { + // Arrange + var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + + // Act & Assert + Assert.Throws(() => new AIContextProvider.InvokingContext(null!, s_mockSession, messages)); + } + #endregion #region InvokedContext Tests @@ -193,7 +246,7 @@ public void InvokedContext_RequestMessages_SetterThrowsForNull() { // Arrange var messages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); - var context = new AIContextProvider.InvokedContext(messages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null); // Act & Assert Assert.Throws(() => context.RequestMessages = null!); @@ -205,7 +258,7 @@ public void InvokedContext_RequestMessages_SetterRoundtrips() // Arrange var initialMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); var newMessages = new List { new(ChatRole.User, "New message") }; - var context = new AIContextProvider.InvokedContext(initialMessages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, initialMessages, aiContextProviderMessages: null); // Act context.RequestMessages = newMessages; @@ -220,7 +273,7 @@ public void InvokedContext_AIContextProviderMessages_Roundtrips() // Arrange var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); var aiContextMessages = new List { new(ChatRole.System, "AI context message") }; - var context = new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); // Act context.AIContextProviderMessages = aiContextMessages; @@ -235,7 +288,7 @@ public void InvokedContext_ResponseMessages_Roundtrips() // Arrange var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); var responseMessages = new List { new(ChatRole.Assistant, "Response message") }; - var context = new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); // Act context.ResponseMessages = responseMessages; @@ -250,7 +303,7 @@ public void InvokedContext_InvokeException_Roundtrips() // Arrange var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); var exception = new InvalidOperationException("Test exception"); - var context = new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null); + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); // Act context.InvokeException = exception; @@ -259,6 +312,55 @@ public void InvokedContext_InvokeException_Roundtrips() Assert.Same(exception, context.InvokeException); } + [Fact] + public void InvokedContext_Agent_ReturnsConstructorValue() + { + // Arrange + var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + + // Act + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); + + // Assert + Assert.Same(s_mockAgent, context.Agent); + } + + [Fact] + public void InvokedContext_Session_ReturnsConstructorValue() + { + // Arrange + var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + + // Act + var context = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null); + + // Assert + Assert.Same(s_mockSession, context.Session); + } + + [Fact] + public void InvokedContext_Session_CanBeNull() + { + // Arrange + var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + + // Act + var context = new AIContextProvider.InvokedContext(s_mockAgent, null, requestMessages, aiContextProviderMessages: null); + + // Assert + Assert.Null(context.Session); + } + + [Fact] + public void InvokedContext_Constructor_ThrowsForNullAgent() + { + // Arrange + var requestMessages = new ReadOnlyCollection([new(ChatRole.User, "Hello")]); + + // Act & Assert + Assert.Throws(() => new AIContextProvider.InvokedContext(null!, s_mockSession, requestMessages, aiContextProviderMessages: null)); + } + #endregion private sealed class TestAIContextProvider : AIContextProvider diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs index 84a0242320..a74906c801 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderExtensionsTests.cs @@ -14,6 +14,9 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; /// public sealed class ChatHistoryProviderExtensionsTests { + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + [Fact] public void WithMessageFilters_ReturnsChatHistoryProviderMessageFilter() { @@ -35,7 +38,7 @@ public async Task WithMessageFilters_InvokingFilter_IsAppliedAsync() // Arrange Mock providerMock = new(); List innerMessages = [new(ChatRole.User, "Hello"), new(ChatRole.Assistant, "Hi")]; - ChatHistoryProvider.InvokingContext context = new([new ChatMessage(ChatRole.User, "Test")]); + ChatHistoryProvider.InvokingContext context = new(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]); providerMock .Setup(p => p.InvokingAsync(context, It.IsAny())) @@ -59,7 +62,7 @@ public async Task WithMessageFilters_InvokedFilter_IsAppliedAsync() Mock providerMock = new(); List requestMessages = [new(ChatRole.User, "Hello")]; List chatHistoryProviderMessages = [new(ChatRole.System, "System")]; - ChatHistoryProvider.InvokedContext context = new(requestMessages, chatHistoryProviderMessages) + ChatHistoryProvider.InvokedContext context = new(s_mockAgent, s_mockSession, requestMessages, chatHistoryProviderMessages) { ResponseMessages = [new ChatMessage(ChatRole.Assistant, "Response")] }; @@ -106,7 +109,7 @@ public async Task WithAIContextProviderMessageRemoval_RemovesAIContextProviderMe List requestMessages = [new(ChatRole.User, "Hello")]; List chatHistoryProviderMessages = [new(ChatRole.System, "System")]; List aiContextProviderMessages = [new(ChatRole.System, "Context")]; - ChatHistoryProvider.InvokedContext context = new(requestMessages, chatHistoryProviderMessages) + ChatHistoryProvider.InvokedContext context = new(s_mockAgent, s_mockSession, requestMessages, chatHistoryProviderMessages) { AIContextProviderMessages = aiContextProviderMessages }; diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs index 43a3e78f10..4b955a43c0 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderMessageFilterTests.cs @@ -16,6 +16,9 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; /// public sealed class ChatHistoryProviderMessageFilterTests { + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + [Fact] public void Constructor_WithNullInnerProvider_ThrowsArgumentNullException() { @@ -59,7 +62,7 @@ public async Task InvokingAsync_WithNoOpFilters_ReturnsInnerProviderMessagesAsyn new(ChatRole.User, "Hello"), new(ChatRole.Assistant, "Hi there!") }; - var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Test")]); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]); innerProviderMock .Setup(s => s.InvokingAsync(context, It.IsAny())) @@ -88,7 +91,7 @@ public async Task InvokingAsync_WithInvokingFilter_AppliesFilterAsync() new(ChatRole.Assistant, "Hi there!"), new(ChatRole.User, "How are you?") }; - var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Test")]); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]); innerProviderMock .Setup(s => s.InvokingAsync(context, It.IsAny())) @@ -118,7 +121,7 @@ public async Task InvokingAsync_WithInvokingFilter_CanModifyMessagesAsync() new(ChatRole.User, "Hello"), new(ChatRole.Assistant, "Hi there!") }; - var context = new ChatHistoryProvider.InvokingContext([new ChatMessage(ChatRole.User, "Test")]); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test")]); innerProviderMock .Setup(s => s.InvokingAsync(context, It.IsAny())) @@ -147,7 +150,7 @@ public async Task InvokedAsync_WithInvokedFilter_AppliesFilterAsync() var requestMessages = new List { new(ChatRole.User, "Hello") }; var chatHistoryProviderMessages = new List { new(ChatRole.System, "System") }; var responseMessages = new List { new(ChatRole.Assistant, "Response") }; - var context = new ChatHistoryProvider.InvokedContext(requestMessages, chatHistoryProviderMessages) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, chatHistoryProviderMessages) { ResponseMessages = responseMessages }; @@ -162,7 +165,7 @@ public async Task InvokedAsync_WithInvokedFilter_AppliesFilterAsync() ChatHistoryProvider.InvokedContext InvokedFilter(ChatHistoryProvider.InvokedContext ctx) { var modifiedRequestMessages = ctx.RequestMessages.Select(m => new ChatMessage(m.Role, $"[FILTERED] {m.Text}")).ToList(); - return new ChatHistoryProvider.InvokedContext(modifiedRequestMessages, ctx.ChatHistoryProviderMessages) + return new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, modifiedRequestMessages, ctx.ChatHistoryProviderMessages) { ResponseMessages = ctx.ResponseMessages, AIContextProviderMessages = ctx.AIContextProviderMessages, diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs index a26ef199d9..5e0fbe9817 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/ChatHistoryProviderTests.cs @@ -6,6 +6,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; +using Moq; namespace Microsoft.Agents.AI.Abstractions.UnitTests; @@ -14,6 +15,9 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; /// public class ChatHistoryProviderTests { + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + #region GetService Method Tests [Fact] @@ -82,7 +86,7 @@ public void GetService_Generic_ReturnsNullForUnrelatedType() public void InvokingContext_Constructor_ThrowsForNullMessages() { // Arrange & Act & Assert - Assert.Throws(() => new ChatHistoryProvider.InvokingContext(null!)); + Assert.Throws(() => new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, null!)); } [Fact] @@ -90,7 +94,7 @@ public void InvokingContext_RequestMessages_SetterThrowsForNull() { // Arrange var messages = new List { new(ChatRole.User, "Hello") }; - var context = new ChatHistoryProvider.InvokingContext(messages); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, messages); // Act & Assert Assert.Throws(() => context.RequestMessages = null!); @@ -102,7 +106,7 @@ public void InvokingContext_RequestMessages_SetterRoundtrips() // Arrange var initialMessages = new List { new(ChatRole.User, "Hello") }; var newMessages = new List { new(ChatRole.User, "New message") }; - var context = new ChatHistoryProvider.InvokingContext(initialMessages); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, initialMessages); // Act context.RequestMessages = newMessages; @@ -111,6 +115,55 @@ public void InvokingContext_RequestMessages_SetterRoundtrips() Assert.Same(newMessages, context.RequestMessages); } + [Fact] + public void InvokingContext_Agent_ReturnsConstructorValue() + { + // Arrange + var messages = new List { new(ChatRole.User, "Hello") }; + + // Act + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, messages); + + // Assert + Assert.Same(s_mockAgent, context.Agent); + } + + [Fact] + public void InvokingContext_Session_ReturnsConstructorValue() + { + // Arrange + var messages = new List { new(ChatRole.User, "Hello") }; + + // Act + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, messages); + + // Assert + Assert.Same(s_mockSession, context.Session); + } + + [Fact] + public void InvokingContext_Session_CanBeNull() + { + // Arrange + var messages = new List { new(ChatRole.User, "Hello") }; + + // Act + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, null, messages); + + // Assert + Assert.Null(context.Session); + } + + [Fact] + public void InvokingContext_Constructor_ThrowsForNullAgent() + { + // Arrange + var messages = new List { new(ChatRole.User, "Hello") }; + + // Act & Assert + Assert.Throws(() => new ChatHistoryProvider.InvokingContext(null!, s_mockSession, messages)); + } + #endregion #region InvokedContext Tests @@ -119,7 +172,7 @@ public void InvokingContext_RequestMessages_SetterRoundtrips() public void InvokedContext_Constructor_ThrowsForNullRequestMessages() { // Arrange & Act & Assert - Assert.Throws(() => new ChatHistoryProvider.InvokedContext(null!, [])); + Assert.Throws(() => new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, null!, [])); } [Fact] @@ -127,7 +180,7 @@ public void InvokedContext_RequestMessages_SetterThrowsForNull() { // Arrange var requestMessages = new List { new(ChatRole.User, "Hello") }; - var context = new ChatHistoryProvider.InvokedContext(requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); // Act & Assert Assert.Throws(() => context.RequestMessages = null!); @@ -139,7 +192,7 @@ public void InvokedContext_RequestMessages_SetterRoundtrips() // Arrange var initialMessages = new List { new(ChatRole.User, "Hello") }; var newMessages = new List { new(ChatRole.User, "New message") }; - var context = new ChatHistoryProvider.InvokedContext(initialMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, initialMessages, []); // Act context.RequestMessages = newMessages; @@ -154,7 +207,7 @@ public void InvokedContext_ChatHistoryProviderMessages_SetterRoundtrips() // Arrange var requestMessages = new List { new(ChatRole.User, "Hello") }; var newProviderMessages = new List { new(ChatRole.System, "System message") }; - var context = new ChatHistoryProvider.InvokedContext(requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); // Act context.ChatHistoryProviderMessages = newProviderMessages; @@ -169,7 +222,7 @@ public void InvokedContext_AIContextProviderMessages_Roundtrips() // Arrange var requestMessages = new List { new(ChatRole.User, "Hello") }; var aiContextMessages = new List { new(ChatRole.System, "AI context message") }; - var context = new ChatHistoryProvider.InvokedContext(requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); // Act context.AIContextProviderMessages = aiContextMessages; @@ -184,7 +237,7 @@ public void InvokedContext_ResponseMessages_Roundtrips() // Arrange var requestMessages = new List { new(ChatRole.User, "Hello") }; var responseMessages = new List { new(ChatRole.Assistant, "Response message") }; - var context = new ChatHistoryProvider.InvokedContext(requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); // Act context.ResponseMessages = responseMessages; @@ -199,7 +252,7 @@ public void InvokedContext_InvokeException_Roundtrips() // Arrange var requestMessages = new List { new(ChatRole.User, "Hello") }; var exception = new InvalidOperationException("Test exception"); - var context = new ChatHistoryProvider.InvokedContext(requestMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); // Act context.InvokeException = exception; @@ -208,6 +261,55 @@ public void InvokedContext_InvokeException_Roundtrips() Assert.Same(exception, context.InvokeException); } + [Fact] + public void InvokedContext_Agent_ReturnsConstructorValue() + { + // Arrange + var requestMessages = new List { new(ChatRole.User, "Hello") }; + + // Act + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); + + // Assert + Assert.Same(s_mockAgent, context.Agent); + } + + [Fact] + public void InvokedContext_Session_ReturnsConstructorValue() + { + // Arrange + var requestMessages = new List { new(ChatRole.User, "Hello") }; + + // Act + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []); + + // Assert + Assert.Same(s_mockSession, context.Session); + } + + [Fact] + public void InvokedContext_Session_CanBeNull() + { + // Arrange + var requestMessages = new List { new(ChatRole.User, "Hello") }; + + // Act + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, null, requestMessages, []); + + // Assert + Assert.Null(context.Session); + } + + [Fact] + public void InvokedContext_Constructor_ThrowsForNullAgent() + { + // Arrange + var requestMessages = new List { new(ChatRole.User, "Hello") }; + + // Act & Assert + Assert.Throws(() => new ChatHistoryProvider.InvokedContext(null!, s_mockSession, requestMessages, [])); + } + #endregion private sealed class TestChatHistoryProvider : ChatHistoryProvider diff --git a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs index ff31d0afc9..bf8ff998b9 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Abstractions.UnitTests/InMemoryChatHistoryProviderTests.cs @@ -18,6 +18,9 @@ namespace Microsoft.Agents.AI.Abstractions.UnitTests; /// public class InMemoryChatHistoryProviderTests { + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + [Fact] public void Constructor_Throws_ForNullReducer() => // Arrange & Act & Assert @@ -68,7 +71,7 @@ public async Task InvokedAsyncAddsMessagesAsync() var provider = new InMemoryChatHistoryProvider(); provider.Add(providerMessages[0]); - var context = new ChatHistoryProvider.InvokedContext(requestMessages, providerMessages) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, providerMessages) { AIContextProviderMessages = aiContextProviderMessages, ResponseMessages = responseMessages @@ -87,7 +90,7 @@ public async Task InvokedAsyncWithEmptyDoesNotFailAsync() { var provider = new InMemoryChatHistoryProvider(); - var context = new ChatHistoryProvider.InvokedContext([], []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [], []); await provider.InvokedAsync(context, CancellationToken.None); Assert.Empty(provider); @@ -102,7 +105,7 @@ public async Task InvokingAsyncReturnsAllMessagesAsync() new ChatMessage(ChatRole.Assistant, "Test2") }; - var context = new ChatHistoryProvider.InvokingContext([]); + var context = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var result = (await provider.InvokingAsync(context, CancellationToken.None)).ToList(); Assert.Equal(2, result.Count); @@ -183,7 +186,7 @@ public async Task InvokedAsyncWithEmptyMessagesDoesNotChangeProviderAsync() var provider = new InMemoryChatHistoryProvider(); var messages = new List(); - var context = new ChatHistoryProvider.InvokedContext(messages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); await provider.InvokedAsync(context, CancellationToken.None); Assert.Empty(provider); @@ -520,7 +523,7 @@ public async Task AddMessagesAsync_WithReducer_AfterMessageAdded_InvokesReducerA var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.AfterMessageAdded); // Act - var context = new ChatHistoryProvider.InvokedContext(originalMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, originalMessages, []); await provider.InvokedAsync(context, CancellationToken.None); // Assert @@ -556,7 +559,7 @@ public async Task GetMessagesAsync_WithReducer_BeforeMessagesRetrieval_InvokesRe } // Act - var invokingContext = new ChatHistoryProvider.InvokingContext(Array.Empty()); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, Array.Empty()); var result = (await provider.InvokingAsync(invokingContext, CancellationToken.None)).ToList(); // Assert @@ -579,7 +582,7 @@ public async Task AddMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeRedu var provider = new InMemoryChatHistoryProvider(reducerMock.Object, InMemoryChatHistoryProvider.ChatReducerTriggerEvent.BeforeMessagesRetrieval); // Act - var context = new ChatHistoryProvider.InvokedContext(originalMessages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, originalMessages, []); await provider.InvokedAsync(context, CancellationToken.None); // Assert @@ -605,7 +608,7 @@ public async Task GetMessagesAsync_WithReducer_ButWrongTrigger_DoesNotInvokeRedu }; // Act - var invokingContext = new ChatHistoryProvider.InvokingContext(Array.Empty()); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, Array.Empty()); var result = (await provider.InvokingAsync(invokingContext, CancellationToken.None)).ToList(); // Assert @@ -627,7 +630,7 @@ public async Task InvokedAsync_WithException_DoesNotAddMessagesAsync() { new(ChatRole.Assistant, "Hi there!") }; - var context = new ChatHistoryProvider.InvokedContext(requestMessages, []) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []) { ResponseMessages = responseMessages, InvokeException = new InvalidOperationException("Test exception") diff --git a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs index ab2f58dfd5..f6589ff9e3 100644 --- a/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.CosmosNoSql.UnitTests/CosmosChatHistoryProviderTests.cs @@ -41,6 +41,9 @@ namespace Microsoft.Agents.AI.CosmosNoSql.UnitTests; [Collection("CosmosDB")] public sealed class CosmosChatHistoryProviderTests : IAsyncLifetime, IDisposable { + private static readonly AIAgent s_mockAgent = new Moq.Mock().Object; + private static readonly AgentSession s_mockSession = new Moq.Mock().Object; + // Cosmos DB Emulator connection settings private const string EmulatorEndpoint = "https://localhost:8081"; private const string EmulatorKey = "C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw=="; @@ -214,7 +217,7 @@ public async Task InvokedAsync_WithSingleMessage_ShouldAddMessageAsync() using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversationId); var message = new ChatMessage(ChatRole.User, "Hello, world!"); - var context = new ChatHistoryProvider.InvokedContext([message], []) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [message], []) { ResponseMessages = [] }; @@ -226,7 +229,7 @@ public async Task InvokedAsync_WithSingleMessage_ShouldAddMessageAsync() await Task.Delay(100); // Assert - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var messages = await provider.InvokingAsync(invokingContext); var messageList = messages.ToList(); @@ -293,7 +296,7 @@ public async Task InvokedAsync_WithMultipleMessages_ShouldAddAllMessagesAsync() new ChatMessage(ChatRole.Assistant, "Response message") }; - var context = new ChatHistoryProvider.InvokedContext(requestMessages, []) + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, []) { AIContextProviderMessages = aiContextProviderMessages, ResponseMessages = responseMessages @@ -303,7 +306,7 @@ public async Task InvokedAsync_WithMultipleMessages_ShouldAddAllMessagesAsync() await provider.InvokedAsync(context); // Assert - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var retrievedMessages = await provider.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); Assert.Equal(5, messageList.Count); @@ -327,7 +330,7 @@ public async Task InvokingAsync_WithNoMessages_ShouldReturnEmptyAsync() using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, Guid.NewGuid().ToString()); // Act - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var messages = await provider.InvokingAsync(invokingContext); // Assert @@ -346,15 +349,15 @@ public async Task InvokingAsync_WithConversationIsolation_ShouldOnlyReturnMessag using var store1 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversation1); using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, TestContainerId, conversation2); - var context1 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message for conversation 1")], []); - var context2 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message for conversation 2")], []); + var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message for conversation 1")], []); + var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message for conversation 2")], []); await store1.InvokedAsync(context1); await store2.InvokedAsync(context2); // Act - var invokingContext1 = new ChatHistoryProvider.InvokingContext([]); - var invokingContext2 = new ChatHistoryProvider.InvokingContext([]); + var invokingContext1 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext2 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var messages1 = await store1.InvokingAsync(invokingContext1); var messages2 = await store2.InvokingAsync(invokingContext2); @@ -391,11 +394,11 @@ public async Task FullWorkflow_AddAndGet_ShouldWorkCorrectlyAsync() }; // Act 1: Add messages - var invokedContext = new ChatHistoryProvider.InvokedContext(messages, []); + var invokedContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); await originalStore.InvokedAsync(invokedContext); // Act 2: Verify messages were added - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var retrievedMessages = await originalStore.InvokingAsync(invokingContext); var retrievedList = retrievedMessages.ToList(); Assert.Equal(5, retrievedList.Count); @@ -545,7 +548,7 @@ public async Task InvokedAsync_WithHierarchicalPartitioning_ShouldAddMessageWith using var provider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId); var message = new ChatMessage(ChatRole.User, "Hello from hierarchical partitioning!"); - var context = new ChatHistoryProvider.InvokedContext([message], []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [message], []); // Act await provider.InvokedAsync(context); @@ -554,7 +557,7 @@ public async Task InvokedAsync_WithHierarchicalPartitioning_ShouldAddMessageWith await Task.Delay(100); // Assert - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var messages = await provider.InvokingAsync(invokingContext); var messageList = messages.ToList(); @@ -602,7 +605,7 @@ public async Task InvokedAsync_WithHierarchicalMultipleMessages_ShouldAddAllMess new ChatMessage(ChatRole.User, "Third hierarchical message") }; - var context = new ChatHistoryProvider.InvokedContext(messages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); // Act await provider.InvokedAsync(context); @@ -611,7 +614,7 @@ public async Task InvokedAsync_WithHierarchicalMultipleMessages_ShouldAddAllMess await Task.Delay(100); // Assert - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var retrievedMessages = await provider.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); @@ -637,8 +640,8 @@ public async Task InvokingAsync_WithHierarchicalPartitionIsolation_ShouldIsolate using var store2 = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId2, SessionId); // Add messages to both stores - var context1 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message from user 1")], []); - var context2 = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Message from user 2")], []); + var context1 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message from user 1")], []); + var context2 = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Message from user 2")], []); await store1.InvokedAsync(context1); await store2.InvokedAsync(context2); @@ -647,8 +650,8 @@ public async Task InvokingAsync_WithHierarchicalPartitionIsolation_ShouldIsolate await Task.Delay(100); // Act & Assert - var invokingContext1 = new ChatHistoryProvider.InvokingContext([]); - var invokingContext2 = new ChatHistoryProvider.InvokingContext([]); + var invokingContext1 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); + var invokingContext2 = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var messages1 = await store1.InvokingAsync(invokingContext1); var messageList1 = messages1.ToList(); @@ -675,7 +678,7 @@ public async Task SerializeDeserialize_WithHierarchicalPartitioning_ShouldPreser using var originalStore = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, TenantId, UserId, SessionId); - var context = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Test serialization message")], []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Test serialization message")], []); await originalStore.InvokedAsync(context); // Act - Serialize the provider state @@ -693,7 +696,7 @@ public async Task SerializeDeserialize_WithHierarchicalPartitioning_ShouldPreser await Task.Delay(100); // Assert - The deserialized provider should have the same functionality - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var messages = await deserializedStore.InvokingAsync(invokingContext); var messageList = messages.ToList(); @@ -717,8 +720,8 @@ public async Task HierarchicalAndSimplePartitioning_ShouldCoexistAsync() using var hierarchicalProvider = new CosmosChatHistoryProvider(this._connectionString, s_testDatabaseId, HierarchicalTestContainerId, "tenant-coexist", "user-coexist", SessionId); // Add messages to both - var simpleContext = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Simple partitioning message")], []); - var hierarchicalContext = new ChatHistoryProvider.InvokedContext([new ChatMessage(ChatRole.User, "Hierarchical partitioning message")], []); + var simpleContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Simple partitioning message")], []); + var hierarchicalContext = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Hierarchical partitioning message")], []); await simpleProvider.InvokedAsync(simpleContext); await hierarchicalProvider.InvokedAsync(hierarchicalContext); @@ -727,7 +730,7 @@ public async Task HierarchicalAndSimplePartitioning_ShouldCoexistAsync() await Task.Delay(100); // Act & Assert - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var simpleMessages = await simpleProvider.InvokingAsync(invokingContext); var simpleMessageList = simpleMessages.ToList(); @@ -760,7 +763,7 @@ public async Task MaxMessagesToRetrieve_ShouldLimitAndReturnMostRecentAsync() await Task.Delay(10); // Small delay to ensure different timestamps } - var context = new ChatHistoryProvider.InvokedContext(messages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); await provider.InvokedAsync(context); // Wait for eventual consistency @@ -768,7 +771,7 @@ public async Task MaxMessagesToRetrieve_ShouldLimitAndReturnMostRecentAsync() // Act - Set max to 5 and retrieve provider.MaxMessagesToRetrieve = 5; - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var retrievedMessages = await provider.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); @@ -798,14 +801,14 @@ public async Task MaxMessagesToRetrieve_Null_ShouldReturnAllMessagesAsync() messages.Add(new ChatMessage(ChatRole.User, $"Message {i}")); } - var context = new ChatHistoryProvider.InvokedContext(messages, []); + var context = new ChatHistoryProvider.InvokedContext(s_mockAgent, s_mockSession, messages, []); await provider.InvokedAsync(context); // Wait for eventual consistency await Task.Delay(100); // Act - No limit set (default null) - var invokingContext = new ChatHistoryProvider.InvokingContext([]); + var invokingContext = new ChatHistoryProvider.InvokingContext(s_mockAgent, s_mockSession, []); var retrievedMessages = await provider.InvokingAsync(invokingContext); var messageList = retrievedMessages.ToList(); diff --git a/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs index bacc59833a..81ca4eb588 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Mem0.IntegrationTests/Mem0ProviderTests.cs @@ -18,6 +18,9 @@ public sealed class Mem0ProviderTests : IDisposable { private const string SkipReason = "Requires a Mem0 service configured"; // Set to null to enable. + private static readonly AIAgent s_mockAgent = new Moq.Mock().Object; + private static readonly AgentSession s_mockSession = new Moq.Mock().Object; + private readonly HttpClient _httpClient; public Mem0ProviderTests() @@ -49,14 +52,14 @@ public async Task CanAddAndRetrieveUserMemoriesAsync() var sut = new Mem0Provider(this._httpClient, storageScope); await sut.ClearStoredMemoriesAsync(); - var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext([question])); + var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); Assert.DoesNotContain("Caoimhe", ctxBefore.Messages?[0].Text ?? string.Empty); // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext([input], aiContextProviderMessages: null)); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [input], aiContextProviderMessages: null)); var ctxAfterAdding = await GetContextWithRetryAsync(sut, question); await sut.ClearStoredMemoriesAsync(); - var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext([question])); + var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); // Assert Assert.Contains("Caoimhe", ctxAfterAdding.Messages?[0].Text ?? string.Empty); @@ -73,14 +76,14 @@ public async Task CanAddAndRetrieveAgentMemoriesAsync() var sut = new Mem0Provider(this._httpClient, storageScope); await sut.ClearStoredMemoriesAsync(); - var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext([question])); + var ctxBefore = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); Assert.DoesNotContain("Caoimhe", ctxBefore.Messages?[0].Text ?? string.Empty); // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext([assistantIntro], aiContextProviderMessages: null)); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [assistantIntro], aiContextProviderMessages: null)); var ctxAfterAdding = await GetContextWithRetryAsync(sut, question); await sut.ClearStoredMemoriesAsync(); - var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext([question])); + var ctxAfterClearing = await sut.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); // Assert Assert.Contains("Caoimhe", ctxAfterAdding.Messages?[0].Text ?? string.Empty); @@ -99,13 +102,13 @@ public async Task DoesNotLeakMemoriesAcrossAgentScopesAsync() await sut1.ClearStoredMemoriesAsync(); await sut2.ClearStoredMemoriesAsync(); - var ctxBefore1 = await sut1.InvokingAsync(new AIContextProvider.InvokingContext([question])); - var ctxBefore2 = await sut2.InvokingAsync(new AIContextProvider.InvokingContext([question])); + var ctxBefore1 = await sut1.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); + var ctxBefore2 = await sut2.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question])); Assert.DoesNotContain("Caoimhe", ctxBefore1.Messages?[0].Text ?? string.Empty); Assert.DoesNotContain("Caoimhe", ctxBefore2.Messages?[0].Text ?? string.Empty); // Act - await sut1.InvokedAsync(new AIContextProvider.InvokedContext([assistantIntro], aiContextProviderMessages: null)); + await sut1.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [assistantIntro], aiContextProviderMessages: null)); var ctxAfterAdding1 = await GetContextWithRetryAsync(sut1, question); var ctxAfterAdding2 = await GetContextWithRetryAsync(sut2, question); @@ -123,7 +126,7 @@ private static async Task GetContextWithRetryAsync(Mem0Provider provi AIContext? ctx = null; for (int i = 0; i < attempts; i++) { - ctx = await provider.InvokingAsync(new AIContextProvider.InvokingContext([question]), CancellationToken.None); + ctx = await provider.InvokingAsync(new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [question]), CancellationToken.None); var text = ctx.Messages?[0].Text; if (!string.IsNullOrEmpty(text) && text.IndexOf("Caoimhe", StringComparison.OrdinalIgnoreCase) >= 0) { diff --git a/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs index 832881857d..b886784af9 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs @@ -18,6 +18,9 @@ namespace Microsoft.Agents.AI.Mem0.UnitTests; /// public sealed class Mem0ProviderTests : IDisposable { + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + private readonly Mock> _loggerMock; private readonly Mock _loggerFactoryMock; private readonly RecordingHandler _handler = new(); @@ -96,7 +99,7 @@ public async Task InvokingAsync_PerformsSearch_AndReturnsContextMessageAsync() UserId = "user" }; var sut = new Mem0Provider(this._httpClient, storageScope, options: new() { EnableSensitiveTelemetryData = true }, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "What is my name?")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "What is my name?")]); // Act var aiContext = await sut.InvokingAsync(invokingContext); @@ -161,7 +164,7 @@ public async Task InvokingAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsy var options = new Mem0ProviderOptions { EnableSensitiveTelemetryData = enableSensitiveTelemetryData }; var sut = new Mem0Provider(this._httpClient, storageScope, options: options, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Who am I?")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Who am I?")]); // Act await sut.InvokingAsync(invokingContext, CancellationToken.None); @@ -215,7 +218,7 @@ public async Task InvokedAsync_PersistsAllowedMessagesAsync() }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); // Assert var memoryPosts = this._handler.Requests.Where(r => r.RequestMessage.RequestUri!.AbsolutePath == "/v1/memories/" && r.RequestMessage.Method == HttpMethod.Post).ToList(); @@ -242,7 +245,7 @@ public async Task InvokedAsync_PersistsNothingForFailedRequestAsync() }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = null, InvokeException = new InvalidOperationException("Request Failed") }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = null, InvokeException = new InvalidOperationException("Request Failed") }); // Assert Assert.Empty(this._handler.Requests); @@ -268,7 +271,7 @@ public async Task InvokedAsync_ShouldNotThrow_WhenStorageFailsAsync() }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); // Assert this._loggerMock.Verify( @@ -318,7 +321,7 @@ public async Task InvokedAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsyn }; // Act - await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); + await sut.InvokedAsync(new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages }); // Assert Assert.Equal(expectedLogCount, this._loggerMock.Invocations.Count); @@ -400,7 +403,7 @@ public async Task InvokingAsync_ShouldNotThrow_WhenSearchFailsAsync() // Arrange var storageScope = new Mem0ProviderScope { ApplicationId = "app" }; var provider = new Mem0Provider(this._httpClient, storageScope, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs index 4001b59090..fd311f9225 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/ChatClient/ChatClientAgentSessionTests.cs @@ -327,4 +327,9 @@ public void GetService_RequestingChatHistoryProvider_ReturnsChatHistoryProvider( } #endregion + + internal sealed class Animal + { + public string Name { get; set; } = string.Empty; + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs index 3698ee7065..360c3071ae 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Data/TextSearchProviderTests.cs @@ -17,6 +17,9 @@ namespace Microsoft.Agents.AI.UnitTests.Data; /// public sealed class TextSearchProviderTests { + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + private readonly Mock> _loggerMock; private readonly Mock _loggerFactoryMock; @@ -64,10 +67,12 @@ public async Task InvokingAsync_ShouldInjectFormattedResultsAsync(string? overri var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options, withLogging ? this._loggerFactoryMock.Object : null); var invokingContext = new AIContextProvider.InvokingContext( - [ - new ChatMessage(ChatRole.User, "Sample user question?"), - new ChatMessage(ChatRole.User, "Additional part") - ]); + s_mockAgent, + s_mockSession, + [ + new ChatMessage(ChatRole.User, "Sample user question?"), + new ChatMessage(ChatRole.User, "Additional part") + ]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -139,7 +144,7 @@ public async Task InvokingAsync_OnDemand_ShouldExposeSearchToolAsync(string? ove FunctionToolDescription = overrideDescription }; var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -158,7 +163,7 @@ public async Task InvokingAsync_ShouldNotThrow_WhenSearchFailsAsync() { // Arrange var provider = new TextSearchProvider(this.FailingSearchAsync, default, null, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -251,7 +256,7 @@ public async Task InvokingAsync_ShouldUseContextFormatterWhenProvidedAsync() ContextFormatter = r => $"Custom formatted context with {r.Count} results." }; var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -285,7 +290,7 @@ public async Task InvokingAsync_WithRawRepresentations_ContextFormatterCanAccess ContextFormatter = r => string.Join(",", r.Select(x => ((RawPayload)x.RawRepresentation!).Id)) }; var provider = new TextSearchProvider(SearchDelegateAsync, default, null, options); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -302,7 +307,7 @@ public async Task InvokingAsync_WithNoResults_ShouldReturnEmptyContextAsync() // Arrange var options = new TextSearchProviderOptions { SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke }; var provider = new TextSearchProvider(this.NoResultSearchAsync, default, null, options); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "Q?")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "Q?")]); // Act var aiContext = await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -340,12 +345,14 @@ public async Task InvokingAsync_WithPreviousFailedRequest_ShouldNotIncludeFailed new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(new(initialMessages, aiContextProviderMessages: null) { InvokeException = new InvalidOperationException("Request Failed") }); + await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages, aiContextProviderMessages: null) { InvokeException = new InvalidOperationException("Request Failed") }); var invokingContext = new AIContextProvider.InvokingContext( - [ - new ChatMessage(ChatRole.User, "E") - ]); + s_mockAgent, + s_mockSession, + [ + new ChatMessage(ChatRole.User, "E") + ]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -380,12 +387,14 @@ public async Task InvokingAsync_WithRecentMessageMemory_ShouldIncludeStoredMessa new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(new(initialMessages, aiContextProviderMessages: null)); + await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages, aiContextProviderMessages: null)); var invokingContext = new AIContextProvider.InvokingContext( - [ - new ChatMessage(ChatRole.User, "E") - ]); + s_mockAgent, + s_mockSession, + [ + new ChatMessage(ChatRole.User, "E") + ]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -414,20 +423,24 @@ public async Task InvokingAsync_WithAccumulatedMemoryAcrossInvocations_ShouldInc // First memory update (A,B) await provider.InvokedAsync(new( - [ - new ChatMessage(ChatRole.User, "A"), - new ChatMessage(ChatRole.Assistant, "B"), - ], aiContextProviderMessages: null)); + s_mockAgent, + s_mockSession, + [ + new ChatMessage(ChatRole.User, "A"), + new ChatMessage(ChatRole.Assistant, "B"), + ], aiContextProviderMessages: null)); // Second memory update (C,D,E) await provider.InvokedAsync(new( - [ - new ChatMessage(ChatRole.User, "C"), - new ChatMessage(ChatRole.Assistant, "D"), - new ChatMessage(ChatRole.User, "E"), - ], aiContextProviderMessages: null)); + s_mockAgent, + s_mockSession, + [ + new ChatMessage(ChatRole.User, "C"), + new ChatMessage(ChatRole.Assistant, "D"), + new ChatMessage(ChatRole.User, "E"), + ], aiContextProviderMessages: null)); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "F")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "F")]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -462,12 +475,14 @@ public async Task InvokingAsync_WithRecentMessageRolesIncluded_ShouldFilterRoles new ChatMessage(ChatRole.User, "U2"), new ChatMessage(ChatRole.Assistant, "A2"), }; - await provider.InvokedAsync(new(initialMessages, null)); + await provider.InvokedAsync(new(s_mockAgent, s_mockSession, initialMessages, null)); var invokingContext = new AIContextProvider.InvokingContext( - [ - new ChatMessage(ChatRole.User, "Question?") // Current request message always appended. - ]); + s_mockAgent, + s_mockSession, + [ + new ChatMessage(ChatRole.User, "Question?") // Current request message always appended. + ]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -518,7 +533,7 @@ public async Task Serialize_WithRecentMessages_ShouldPersistMessagesUpToLimitAsy }; // Act - await provider.InvokedAsync(new(messages, aiContextProviderMessages: null)); // Populate recent memory. + await provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null)); // Populate recent memory. var state = provider.Serialize(); // Assert @@ -547,7 +562,7 @@ public async Task SerializeAndDeserialize_RoundtripRestoresMessagesAsync() new ChatMessage(ChatRole.User, "C"), new ChatMessage(ChatRole.Assistant, "D"), }; - await provider.InvokedAsync(new(messages, aiContextProviderMessages: null)); + await provider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null)); // Act var state = provider.Serialize(); @@ -563,7 +578,7 @@ public async Task SerializeAndDeserialize_RoundtripRestoresMessagesAsync() RecentMessageMemoryLimit = 4 }); var emptyMessages = Array.Empty(); - await roundTrippedProvider.InvokingAsync(new(emptyMessages), CancellationToken.None); // Trigger search to read memory. + await roundTrippedProvider.InvokingAsync(new(s_mockAgent, s_mockSession, emptyMessages), CancellationToken.None); // Trigger search to read memory. // Assert Assert.NotNull(capturedInput); @@ -588,7 +603,7 @@ public async Task Deserialize_WithChangedLowerLimit_ShouldTruncateToNewLimitAsyn new ChatMessage(ChatRole.Assistant, "L4"), new ChatMessage(ChatRole.User, "L5"), }; - await initialProvider.InvokedAsync(new(messages, aiContextProviderMessages: null)); + await initialProvider.InvokedAsync(new(s_mockAgent, s_mockSession, messages, aiContextProviderMessages: null)); var state = initialProvider.Serialize(); string? capturedInput = null; @@ -604,7 +619,7 @@ public async Task Deserialize_WithChangedLowerLimit_ShouldTruncateToNewLimitAsyn SearchTime = TextSearchProviderOptions.TextSearchBehavior.BeforeAIInvoke, RecentMessageMemoryLimit = 3 // Lower limit }); - await restoredProvider.InvokingAsync(new(Array.Empty()), CancellationToken.None); + await restoredProvider.InvokingAsync(new(s_mockAgent, s_mockSession, Array.Empty()), CancellationToken.None); // Assert Assert.NotNull(capturedInput); @@ -631,7 +646,7 @@ public async Task Deserialize_WithEmptyState_ShouldHaveNoMessagesAsync() RecentMessageMemoryLimit = 3 }); var emptyMessages = Array.Empty(); - await provider.InvokingAsync(new(emptyMessages), CancellationToken.None); + await provider.InvokingAsync(new(s_mockAgent, s_mockSession, emptyMessages), CancellationToken.None); // Assert Assert.NotNull(capturedInput); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs index f46538c8e4..8d3cad85ae 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Memory/ChatHistoryMemoryProviderTests.cs @@ -18,6 +18,9 @@ namespace Microsoft.Agents.AI.Memory.UnitTests; /// public class ChatHistoryMemoryProviderTests { + private static readonly AIAgent s_mockAgent = new Mock().Object; + private static readonly AgentSession s_mockSession = new Mock().Object; + private readonly Mock> _loggerMock; private readonly Mock _loggerFactoryMock; @@ -116,7 +119,7 @@ public async Task InvokedAsync_UpsertsMessages_ToCollectionAsync() var requestMsgWithNulls = new ChatMessage(ChatRole.User, "request text nulls"); var responseMsg = new ChatMessage(ChatRole.Assistant, "response text") { MessageId = "resp-1", AuthorName = "assistant" }; - var invokedContext = new AIContextProvider.InvokedContext([requestMsgWithValues, requestMsgWithNulls], aiContextProviderMessages: null) + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsgWithValues, requestMsgWithNulls], aiContextProviderMessages: null) { ResponseMessages = [responseMsg] }; @@ -174,7 +177,7 @@ public async Task InvokedAsync_DoesNotUpsertMessages_WhenInvokeFailedAsync() 1, new ChatHistoryMemoryProviderScope() { UserId = "UID" }); var requestMsg = new ChatMessage(ChatRole.User, "request text") { MessageId = "req-1" }; - var invokedContext = new AIContextProvider.InvokedContext([requestMsg], aiContextProviderMessages: null) + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg], aiContextProviderMessages: null) { InvokeException = new InvalidOperationException("Invoke failed") }; @@ -203,7 +206,7 @@ public async Task InvokedAsync_DoesNotThrow_WhenUpsertThrowsAsync() new ChatHistoryMemoryProviderScope() { UserId = "UID" }, loggerFactory: this._loggerFactoryMock.Object); var requestMsg = new ChatMessage(ChatRole.User, "request text") { MessageId = "req-1" }; - var invokedContext = new AIContextProvider.InvokedContext([requestMsg], aiContextProviderMessages: null); + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg], aiContextProviderMessages: null); // Act await provider.InvokedAsync(invokedContext, CancellationToken.None); @@ -254,7 +257,7 @@ public async Task InvokedAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsyn loggerFactory: this._loggerFactoryMock.Object); var requestMsg = new ChatMessage(ChatRole.User, "request text"); - var invokedContext = new AIContextProvider.InvokedContext([requestMsg], aiContextProviderMessages: null); + var invokedContext = new AIContextProvider.InvokedContext(s_mockAgent, s_mockSession, [requestMsg], aiContextProviderMessages: null); // Act await provider.InvokedAsync(invokedContext, CancellationToken.None); @@ -327,7 +330,7 @@ public async Task InvokedAsync_SearchesVectorStoreAsync() options: providerOptions); var requestMsg = new ChatMessage(ChatRole.User, "requesting relevant history"); - var invokingContext = new AIContextProvider.InvokingContext([requestMsg]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [requestMsg]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -378,7 +381,7 @@ public async Task InvokedAsync_CreatesFilter_WhenSearchScopeProvidedAsync() var provider = new ChatHistoryMemoryProvider(this._vectorStoreMock.Object, TestCollectionName, 1, options: providerOptions, storageScope: searchScope, searchScope: searchScope); var requestMsg = new ChatMessage(ChatRole.User, "requesting relevant history"); - var invokingContext = new AIContextProvider.InvokingContext([requestMsg]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [requestMsg]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); @@ -442,7 +445,7 @@ public async Task InvokingAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsy options: options, loggerFactory: this._loggerFactoryMock.Object); - var invokingContext = new AIContextProvider.InvokingContext([new ChatMessage(ChatRole.User, "requesting relevant history")]); + var invokingContext = new AIContextProvider.InvokingContext(s_mockAgent, s_mockSession, [new ChatMessage(ChatRole.User, "requesting relevant history")]); // Act await provider.InvokingAsync(invokingContext, CancellationToken.None); diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestJsonSerializerContext.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestJsonSerializerContext.cs index b145991439..0ac3ab9fbf 100644 --- a/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestJsonSerializerContext.cs +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/TestJsonSerializerContext.cs @@ -14,4 +14,5 @@ namespace Microsoft.Agents.AI.UnitTests; [JsonSerializable(typeof(string))] [JsonSerializable(typeof(string[]))] [JsonSerializable(typeof(Dictionary))] +[JsonSerializable(typeof(ChatClientAgentSessionTests.Animal))] internal sealed partial class TestJsonSerializerContext : JsonSerializerContext; diff --git a/dotnet/tests/OpenAIAssistant.IntegrationTests/OpenAIAssistantFixture.cs b/dotnet/tests/OpenAIAssistant.IntegrationTests/OpenAIAssistantFixture.cs index a90f49f428..bb58f09fb4 100644 --- a/dotnet/tests/OpenAIAssistant.IntegrationTests/OpenAIAssistantFixture.cs +++ b/dotnet/tests/OpenAIAssistant.IntegrationTests/OpenAIAssistantFixture.cs @@ -23,7 +23,7 @@ public class OpenAIAssistantFixture : IChatClientAgentFixture public IChatClient ChatClient => this._agent.ChatClient; - public async Task> GetChatHistoryAsync(AgentSession session) + public async Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) { var typedSession = (ChatClientAgentSession)session; List messages = []; diff --git a/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs b/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs index eeb60620c0..304df28fba 100644 --- a/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs +++ b/dotnet/tests/OpenAIChatCompletion.IntegrationTests/OpenAIChatCompletionFixture.cs @@ -28,7 +28,7 @@ public OpenAIChatCompletionFixture(bool useReasoningChatModel) public IChatClient ChatClient => this._agent.ChatClient; - public async Task> GetChatHistoryAsync(AgentSession session) + public async Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) { var typedSession = (ChatClientAgentSession)session; @@ -37,7 +37,7 @@ public async Task> GetChatHistoryAsync(AgentSession session) return []; } - return (await typedSession.ChatHistoryProvider.InvokingAsync(new([]))).ToList(); + return (await typedSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); } public Task CreateChatClientAgentAsync( diff --git a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs index 2006404239..719db6a0b0 100644 --- a/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs +++ b/dotnet/tests/OpenAIResponse.IntegrationTests/OpenAIResponseFixture.cs @@ -25,7 +25,7 @@ public class OpenAIResponseFixture(bool store) : IChatClientAgentFixture public IChatClient ChatClient => this._agent.ChatClient; - public async Task> GetChatHistoryAsync(AgentSession session) + public async Task> GetChatHistoryAsync(AIAgent agent, AgentSession session) { var typedSession = (ChatClientAgentSession)session; @@ -55,7 +55,7 @@ public async Task> GetChatHistoryAsync(AgentSession session) return []; } - return (await typedSession.ChatHistoryProvider.InvokingAsync(new([]))).ToList(); + return (await typedSession.ChatHistoryProvider.InvokingAsync(new(agent, session, []))).ToList(); } private static ChatMessage ConvertToChatMessage(ResponseItem item)