Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dotnet/samples/01-get-started/04_memory/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ namespace SampleApp
internal sealed class UserInfoMemory : AIContextProvider
{
private readonly ProviderSessionState<UserInfo> _sessionState;
private IReadOnlyList<string>? _stateKeys;
private readonly IChatClient _chatClient;

public UserInfoMemory(IChatClient chatClient, Func<AgentSession?, UserInfo>? stateInitializer = null)
Expand All @@ -99,7 +100,7 @@ public UserInfoMemory(IChatClient chatClient, Func<AgentSession?, UserInfo>? sta
this._chatClient = chatClient;
}

public override string StateKey => this._sessionState.StateKey;
public override IReadOnlyList<string> StateKeys => this._stateKeys ??= [this._sessionState.StateKey];

public UserInfo GetUserInfo(AgentSession session)
=> this._sessionState.GetOrInitializeState(session);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ namespace SampleApp
internal sealed class VectorChatHistoryProvider : ChatHistoryProvider
{
private readonly ProviderSessionState<State> _sessionState;
private IReadOnlyList<string>? _stateKeys;
private readonly VectorStore _vectorStore;

public VectorChatHistoryProvider(
Expand All @@ -92,7 +93,7 @@ public VectorChatHistoryProvider(
this._vectorStore = vectorStore ?? throw new ArgumentNullException(nameof(vectorStore));
}

public override string StateKey => this._sessionState.StateKey;
public override IReadOnlyList<string> StateKeys => this._stateKeys ??= [this._sessionState.StateKey];

public string GetSessionDbKey(AgentSession session)
=> this._sessionState.GetOrInitializeState(session).SessionDbKey;
Expand Down
13 changes: 8 additions & 5 deletions dotnet/src/Microsoft.Agents.AI.Abstractions/AIContextProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ private static IEnumerable<ChatMessage> DefaultExternalOnlyFilter(IEnumerable<Ch
private static IEnumerable<ChatMessage> DefaultNoopFilter(IEnumerable<ChatMessage> messages)
=> messages;

private IReadOnlyList<string>? _stateKeys;

/// <summary>
/// Initializes a new instance of the <see cref="AIContextProvider"/> class.
/// </summary>
Expand Down Expand Up @@ -68,14 +70,15 @@ protected AIContextProvider(
protected Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> StoreInputResponseMessageFilter { get; }

/// <summary>
/// Gets the key used to store the provider state in the <see cref="AgentSession.StateBag"/>.
/// Gets the set of keys used to store the provider state in the <see cref="AgentSession.StateBag"/>.
/// </summary>
/// <remarks>
/// The default value is the name of the concrete type (e.g. <c>"TextSearchProvider"</c>).
/// Implementations may override this to provide a custom key, for example when multiple
/// instances of the same provider type are used in the same session.
/// The default value is a single-element set containing the name of the concrete type (e.g. <c>"TextSearchProvider"</c>).
/// Implementations may override this to provide custom keys, for example when multiple
/// instances of the same provider type are used in the same session, or when a provider
/// stores state under more than one key.
/// </remarks>
public virtual string StateKey => this.GetType().Name;
public virtual IReadOnlyList<string> StateKeys => this._stateKeys ??= [this.GetType().Name];

/// <summary>
/// Called at the start of agent invocation to provide additional context.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ private static IEnumerable<ChatMessage> DefaultExcludeChatHistoryFilter(IEnumera
private static IEnumerable<ChatMessage> DefaultNoopFilter(IEnumerable<ChatMessage> messages)
=> messages;

private IReadOnlyList<string>? _stateKeys;
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>>? _provideOutputMessageFilter;
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> _storeInputRequestMessageFilter;
private readonly Func<IEnumerable<ChatMessage>, IEnumerable<ChatMessage>> _storeInputResponseMessageFilter;
Expand All @@ -66,14 +67,15 @@ protected ChatHistoryProvider(
}

/// <summary>
/// Gets the key used to store the provider state in the <see cref="AgentSession.StateBag"/>.
/// Gets the set of keys used to store the provider state in the <see cref="AgentSession.StateBag"/>.
/// </summary>
/// <remarks>
/// The default value is the name of the concrete type (e.g. <c>"InMemoryChatHistoryProvider"</c>).
/// Implementations may override this to provide a custom key, for example when multiple
/// instances of the same provider type are used in the same session.
/// The default value is a single-element set containing the name of the concrete type (e.g. <c>"InMemoryChatHistoryProvider"</c>).
/// Implementations may override this to provide custom keys, for example when multiple
/// instances of the same provider type are used in the same session, or when a provider
/// stores state under more than one key.
/// </remarks>
public virtual string StateKey => this.GetType().Name;
public virtual IReadOnlyList<string> StateKeys => this._stateKeys ??= [this.GetType().Name];

/// <summary>
/// Called at the start of agent invocation to provide messages for the next agent invocation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ namespace Microsoft.Agents.AI;
public sealed class InMemoryChatHistoryProvider : ChatHistoryProvider
{
private readonly ProviderSessionState<State> _sessionState;
private IReadOnlyList<string>? _stateKeys;

/// <summary>
/// Initializes a new instance of the <see cref="InMemoryChatHistoryProvider"/> class.
Expand All @@ -50,7 +51,7 @@ public InMemoryChatHistoryProvider(InMemoryChatHistoryProviderOptions? options =
}

/// <inheritdoc />
public override string StateKey => this._sessionState.StateKey;
public override IReadOnlyList<string> StateKeys => this._stateKeys ??= [this._sessionState.StateKey];

/// <summary>
/// Gets the chat reducer used to process or reduce chat messages. If null, no reduction logic will be applied.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ namespace Microsoft.Agents.AI;
public sealed class CosmosChatHistoryProvider : ChatHistoryProvider, IDisposable
{
private readonly ProviderSessionState<State> _sessionState;
private IReadOnlyList<string>? _stateKeys;
private readonly CosmosClient _cosmosClient;
private readonly Container _container;
private readonly bool _ownsClient;
Expand Down Expand Up @@ -114,7 +115,7 @@ public CosmosChatHistoryProvider(
}

/// <inheritdoc />
public override string StateKey => this._sessionState.StateKey;
public override IReadOnlyList<string> StateKeys => this._stateKeys ??= [this._sessionState.StateKey];

/// <summary>
/// Initializes a new instance of the <see cref="CosmosChatHistoryProvider"/> class using a connection string.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public sealed class FoundryMemoryProvider : AIContextProvider
private const string DefaultContextPrompt = "## Memories\nConsider the following memories when answering user questions:";

private readonly ProviderSessionState<State> _sessionState;
private IReadOnlyList<string>? _stateKeys;
private readonly string _contextPrompt;
private readonly string _memoryStoreName;
private readonly int _maxMemories;
Expand Down Expand Up @@ -82,7 +83,7 @@ public FoundryMemoryProvider(
}

/// <inheritdoc />
public override string StateKey => this._sessionState.StateKey;
public override IReadOnlyList<string> StateKeys => this._stateKeys ??= [this._sessionState.StateKey];

private static Func<AgentSession?, State> ValidateStateInitializer(Func<AgentSession?, State> stateInitializer) =>
session =>
Expand Down
3 changes: 2 additions & 1 deletion dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public sealed class Mem0Provider : MessageAIContextProvider
private const string DefaultContextPrompt = "## Memories\nConsider the following memories when answering user questions:";

private readonly ProviderSessionState<State> _sessionState;
private IReadOnlyList<string>? _stateKeys;
private readonly string _contextPrompt;
private readonly bool _enableSensitiveTelemetryData;

Expand Down Expand Up @@ -72,7 +73,7 @@ public Mem0Provider(HttpClient httpClient, Func<AgentSession?, State> stateIniti
}

/// <inheritdoc />
public override string StateKey => this._sessionState.StateKey;
public override IReadOnlyList<string> StateKeys => this._stateKeys ??= [this._sessionState.StateKey];

private static Func<AgentSession?, State> ValidateStateInitializer(Func<AgentSession?, State> stateInitializer) =>
session =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace Microsoft.Agents.AI.Workflows;
internal sealed class WorkflowChatHistoryProvider : ChatHistoryProvider
{
private readonly ProviderSessionState<StoreState> _sessionState;
private IReadOnlyList<string>? _stateKeys;

/// <summary>
/// Initializes a new instance of the <see cref="WorkflowChatHistoryProvider"/> class.
Expand All @@ -30,7 +31,7 @@ public WorkflowChatHistoryProvider(JsonSerializerOptions? jsonSerializerOptions
}

/// <inheritdoc />
public override string StateKey => this._sessionState.StateKey;
public override IReadOnlyList<string> StateKeys => this._stateKeys ??= [this._sessionState.StateKey];

internal sealed class StoreState
{
Expand Down
40 changes: 27 additions & 13 deletions dotnet/src/Microsoft.Agents.AI/ChatClient/ChatClientAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ public ChatClientAgent(IChatClient chatClient, ChatClientAgentOptions? options,
this.ChatHistoryProvider = options?.ChatHistoryProvider ?? new InMemoryChatHistoryProvider();
this.AIContextProviders = this._agentOptions?.AIContextProviders as IReadOnlyList<AIContextProvider> ?? this._agentOptions?.AIContextProviders?.ToList();

// Validate that no two providers share the same StateKey, since they would overwrite each other's state in the session.
// Validate that no two providers share any StateKeys, since they would overwrite each other's state in the session.
this._aiContextProviderStateKeys = ValidateAndCollectStateKeys(this._agentOptions?.AIContextProviders, this.ChatHistoryProvider);

this._logger = (loggerFactory ?? chatClient.GetService<ILoggerFactory>() ?? NullLoggerFactory.Instance).CreateLogger<ChatClientAgent>();
Expand Down Expand Up @@ -824,11 +824,17 @@ private Task NotifyChatHistoryProviderOfNewMessagesAsync(
$"Only {nameof(ChatClientAgentSession.ConversationId)} or {nameof(this.ChatHistoryProvider)} may be used, but not both. The current {nameof(ChatClientAgentSession)} has a {nameof(ChatClientAgentSession.ConversationId)} indicating server-side chat history management, but an override {nameof(this.ChatHistoryProvider)} was provided via {nameof(AgentRunOptions.AdditionalProperties)}.");
}

// Validate that the override provider's StateKey does not clash with any AIContextProvider's StateKey.
if (overrideProvider is not null && this._aiContextProviderStateKeys.Contains(overrideProvider.StateKey))
// Validate that the override provider's StateKeys do not clash with any AIContextProvider's StateKeys.
if (overrideProvider is not null)
{
throw new InvalidOperationException(
$"The ChatHistoryProvider '{overrideProvider.GetType().Name}' uses the state key '{overrideProvider.StateKey}' which is already used by one of the configured AIContextProviders. Each provider must use a unique state key to avoid overwriting each other's state.");
foreach (var key in overrideProvider.StateKeys)
{
if (this._aiContextProviderStateKeys.Contains(key))
{
throw new InvalidOperationException(
$"The ChatHistoryProvider '{overrideProvider.GetType().Name}' uses state key '{key}' which is already used by one of the configured AIContextProviders. Each provider must use unique state keys to avoid overwriting each other's state.");
}
}
}

provider = overrideProvider;
Expand Down Expand Up @@ -879,7 +885,7 @@ private static List<ChatResponseUpdate> GetResponseUpdates(ChatClientAgentContin
private string GetLoggingAgentName() => this.Name ?? "UnnamedAgent";

/// <summary>
/// Validates that all configured providers have unique <see cref="AIContextProvider.StateKey"/> values
/// Validates that all configured providers have unique <see cref="AIContextProvider.StateKeys"/> values
/// and returns a <see cref="HashSet{T}"/> of the AIContextProvider state keys.
/// </summary>
private static HashSet<string> ValidateAndCollectStateKeys(IEnumerable<AIContextProvider>? aiContextProviders, ChatHistoryProvider? chatHistoryProvider)
Expand All @@ -890,10 +896,13 @@ private static HashSet<string> ValidateAndCollectStateKeys(IEnumerable<AIContext
{
foreach (var provider in aiContextProviders)
{
if (!stateKeys.Add(provider.StateKey))
foreach (var key in provider.StateKeys)
{
throw new InvalidOperationException(
$"Multiple providers use the same state key '{provider.StateKey}'. Each provider must use a unique state key to avoid overwriting each other's state.");
if (!stateKeys.Add(key))
{
throw new InvalidOperationException(
$"Multiple providers use the same state key '{key}'. Each provider must use a unique state key to avoid overwriting each other's state.");
}
}
Comment thread
westey-m marked this conversation as resolved.
}
}
Expand All @@ -905,11 +914,16 @@ private static HashSet<string> ValidateAndCollectStateKeys(IEnumerable<AIContext
$"The default {nameof(InMemoryChatHistoryProvider)} uses the state key '{nameof(InMemoryChatHistoryProvider)}', which is already used by one of the configured AIContextProviders. Each provider must use a unique state key to avoid overwriting each other's state. To resolve this, either configure a different state key for the AIContextProvider that is using '{nameof(InMemoryChatHistoryProvider)}' as its state key, or provide a custom ChatHistoryProvider with a unique state key.");
}

if (chatHistoryProvider is not null
&& stateKeys.Contains(chatHistoryProvider.StateKey))
if (chatHistoryProvider is not null)
{
throw new InvalidOperationException(
$"The ChatHistoryProvider '{chatHistoryProvider.GetType().Name}' uses the state key '{chatHistoryProvider.StateKey}' which is already used by one of the configured AIContextProviders. Each provider must use a unique state key to avoid overwriting each other's state. To resolve this, either configure a different state key for the AIContextProvider that is using '{chatHistoryProvider.StateKey}' as its state key, or reconfigure the custom ChatHistoryProvider with a unique state key.");
foreach (var key in chatHistoryProvider.StateKeys)
{
if (stateKeys.Contains(key))
{
throw new InvalidOperationException(
$"The ChatHistoryProvider '{chatHistoryProvider.GetType().Name}' uses state key '{key}' which is already used by one of the configured AIContextProviders. Each provider must use unique state keys to avoid overwriting each other's state. To resolve this, either configure different state keys for the AIContextProvider that shares keys with the ChatHistoryProvider, or reconfigure the custom ChatHistoryProvider with unique state keys.");
}
}
}

return stateKeys;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public sealed class ChatHistoryMemoryProvider : MessageAIContextProvider, IDispo
private const string ContentEmbeddingField = "ContentEmbedding";

private readonly ProviderSessionState<State> _sessionState;
private IReadOnlyList<string>? _stateKeys;

#pragma warning disable CA2213 // VectorStore is not owned by this class - caller is responsible for disposal
private readonly VectorStore _vectorStore;
Expand Down Expand Up @@ -128,7 +129,7 @@ public ChatHistoryMemoryProvider(
}

/// <inheritdoc />
public override string StateKey => this._sessionState.StateKey;
public override IReadOnlyList<string> StateKeys => this._stateKeys ??= [this._sessionState.StateKey];

/// <inheritdoc />
protected override async ValueTask<AIContext> ProvideAIContextAsync(AIContextProvider.InvokingContext context, CancellationToken cancellationToken = default)
Expand Down
3 changes: 2 additions & 1 deletion dotnet/src/Microsoft.Agents.AI/TextSearchProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public sealed class TextSearchProvider : MessageAIContextProvider
private const string DefaultCitationsPrompt = "Include citations to the source document with document name and link if document name and link is available.";

private readonly ProviderSessionState<TextSearchProviderState> _sessionState;
private IReadOnlyList<string>? _stateKeys;
private readonly Func<string, CancellationToken, Task<IEnumerable<TextSearchResult>>> _searchAsync;
private readonly ILogger<TextSearchProvider>? _logger;
private readonly AITool[] _tools;
Expand Down Expand Up @@ -88,7 +89,7 @@ public TextSearchProvider(
}

/// <inheritdoc />
public override string StateKey => this._sessionState.StateKey;
public override IReadOnlyList<string> StateKeys => this._stateKeys ??= [this._sessionState.StateKey];

/// <inheritdoc />
protected override async ValueTask<AIContext> ProvideAIContextAsync(AIContextProvider.InvokingContext context, CancellationToken cancellationToken = default)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,23 +43,25 @@ public void Constructor_Arguments_SetOnPropertiesCorrectly()
}

[Fact]
public void StateKey_ReturnsDefaultKey_WhenNoOptionsProvided()
public void StateKeys_ReturnsDefaultKey_WhenNoOptionsProvided()
{
// Arrange & Act
var provider = new InMemoryChatHistoryProvider();

// Assert
Assert.Equal("InMemoryChatHistoryProvider", provider.StateKey);
Assert.Single(provider.StateKeys);
Assert.Contains("InMemoryChatHistoryProvider", provider.StateKeys);
}

[Fact]
public void StateKey_ReturnsCustomKey_WhenSetViaOptions()
public void StateKeys_ReturnsCustomKey_WhenSetViaOptions()
{
// Arrange & Act
var provider = new InMemoryChatHistoryProvider(new() { StateKey = "custom-key" });

// Assert
Assert.Equal("custom-key", provider.StateKey);
Assert.Single(provider.StateKeys);
Assert.Contains("custom-key", provider.StateKeys);
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ private void SkipIfEmulatorNotAvailable()

[SkippableFact]
[Trait("Category", "CosmosDB")]
public void StateKey_ReturnsDefaultKey_WhenNoStateKeyProvided()
public void StateKeys_ReturnsDefaultKey_WhenNoStateKeyProvided()
{
// Arrange & Act
this.SkipIfEmulatorNotAvailable();
Expand All @@ -159,12 +159,13 @@ public void StateKey_ReturnsDefaultKey_WhenNoStateKeyProvided()
_ => new CosmosChatHistoryProvider.State("test-conversation"));

// Assert
Assert.Equal("CosmosChatHistoryProvider", provider.StateKey);
Assert.Single(provider.StateKeys);
Assert.Contains("CosmosChatHistoryProvider", provider.StateKeys);
}

[SkippableFact]
[Trait("Category", "CosmosDB")]
public void StateKey_ReturnsCustomKey_WhenSetViaConstructor()
public void StateKeys_ReturnsCustomKey_WhenSetViaConstructor()
{
// Arrange & Act
this.SkipIfEmulatorNotAvailable();
Expand All @@ -174,7 +175,8 @@ public void StateKey_ReturnsCustomKey_WhenSetViaConstructor()
stateKey: "custom-key");

// Assert
Assert.Equal("custom-key", provider.StateKey);
Assert.Single(provider.StateKeys);
Assert.Contains("custom-key", provider.StateKeys);
}

[SkippableFact]
Expand Down
Loading
Loading