diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting/AgentHostingServiceCollectionExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Hosting/AgentHostingServiceCollectionExtensions.cs index e12d017343..733a7af9a7 100644 --- a/dotnet/src/Microsoft.Agents.AI.Hosting/AgentHostingServiceCollectionExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Hosting/AgentHostingServiceCollectionExtensions.cs @@ -1,8 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Collections.Generic; -using Microsoft.Agents.AI.Hosting.Local; +using System.Linq; using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Shared.Diagnostics; @@ -29,7 +28,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s return services.AddAIAgent(name, (sp, key) => { var chatClient = sp.GetRequiredService(); - var tools = GetRegisteredToolsForAgent(sp, name); + var tools = sp.GetKeyedServices(name).ToList(); return new ChatClientAgent(chatClient, instructions, key, tools: tools); }); } @@ -49,7 +48,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s Throw.IfNullOrEmpty(name); return services.AddAIAgent(name, (sp, key) => { - var tools = GetRegisteredToolsForAgent(sp, name); + var tools = sp.GetKeyedServices(name).ToList(); return new ChatClientAgent(chatClient, instructions, key, tools: tools); }); } @@ -70,7 +69,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s return services.AddAIAgent(name, (sp, key) => { var chatClient = chatClientServiceKey is null ? sp.GetRequiredService() : sp.GetRequiredKeyedService(chatClientServiceKey); - var tools = GetRegisteredToolsForAgent(sp, name); + var tools = sp.GetKeyedServices(name).ToList(); return new ChatClientAgent(chatClient, instructions, key, tools: tools); }); } @@ -92,7 +91,7 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s return services.AddAIAgent(name, (sp, key) => { var chatClient = chatClientServiceKey is null ? sp.GetRequiredService() : sp.GetRequiredKeyedService(chatClientServiceKey); - var tools = GetRegisteredToolsForAgent(sp, name); + var tools = sp.GetKeyedServices(name).ToList(); return new ChatClientAgent(chatClient, instructions: instructions, name: key, description: description, tools: tools); }); } @@ -127,10 +126,4 @@ public static IHostedAgentBuilder AddAIAgent(this IServiceCollection services, s return new HostedAgentBuilder(name, services); } - - private static IList GetRegisteredToolsForAgent(IServiceProvider serviceProvider, string agentName) - { - var registry = serviceProvider.GetService(); - return registry?.GetTools(agentName) ?? []; - } } diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting/HostedAgentBuilderExtensions.cs b/dotnet/src/Microsoft.Agents.AI.Hosting/HostedAgentBuilderExtensions.cs index d3a437663a..e2c52ff9e0 100644 --- a/dotnet/src/Microsoft.Agents.AI.Hosting/HostedAgentBuilderExtensions.cs +++ b/dotnet/src/Microsoft.Agents.AI.Hosting/HostedAgentBuilderExtensions.cs @@ -1,8 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System; -using System.Linq; -using Microsoft.Agents.AI.Hosting.Local; using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Shared.Diagnostics; @@ -70,18 +68,7 @@ public static IHostedAgentBuilder WithAITool(this IHostedAgentBuilder builder, A Throw.IfNull(builder); Throw.IfNull(tool); - var agentName = builder.Name; - var services = builder.ServiceCollection; - - // Get or create the agent tool registry - var descriptor = services.FirstOrDefault(sd => !sd.IsKeyedService && sd.ServiceType.Equals(typeof(LocalAgentToolRegistry))); - if (descriptor?.ImplementationInstance is not LocalAgentToolRegistry toolRegistry) - { - toolRegistry = new(); - services.Add(ServiceDescriptor.Singleton(toolRegistry)); - } - - toolRegistry.AddTool(agentName, tool); + builder.ServiceCollection.AddKeyedSingleton(builder.Name, tool); return builder; } @@ -105,4 +92,19 @@ public static IHostedAgentBuilder WithAITools(this IHostedAgentBuilder builder, return builder; } + + /// + /// Adds AI tool to an agent being configured with the service collection. + /// + /// The hosted agent builder. + /// A factory function that creates a AI tool using the provided service provider. + public static IHostedAgentBuilder WithAITool(this IHostedAgentBuilder builder, Func factory) + { + Throw.IfNull(builder); + Throw.IfNull(factory); + + builder.ServiceCollection.AddKeyedSingleton(builder.Name, (sp, name) => factory(sp)); + + return builder; + } } diff --git a/dotnet/src/Microsoft.Agents.AI.Hosting/Local/LocalAgentToolRegistry.cs b/dotnet/src/Microsoft.Agents.AI.Hosting/Local/LocalAgentToolRegistry.cs deleted file mode 100644 index 8c87803db3..0000000000 --- a/dotnet/src/Microsoft.Agents.AI.Hosting/Local/LocalAgentToolRegistry.cs +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. - -using System.Collections.Generic; -using Microsoft.Extensions.AI; - -namespace Microsoft.Agents.AI.Hosting.Local; - -internal sealed class LocalAgentToolRegistry -{ - private readonly Dictionary> _toolsByAgentName = []; - - public void AddTool(string agentName, AITool tool) - { - if (!this._toolsByAgentName.TryGetValue(agentName, out var tools)) - { - tools = []; - this._toolsByAgentName[agentName] = tools; - } - - tools.Add(tool); - } - - public IList GetTools(string agentName) - { - return this._toolsByAgentName.TryGetValue(agentName, out var tools) ? tools : []; - } -} diff --git a/dotnet/tests/Microsoft.Agents.AI.Hosting.UnitTests/HostedAgentBuilderToolsExtensionsTests.cs b/dotnet/tests/Microsoft.Agents.AI.Hosting.UnitTests/HostedAgentBuilderToolsExtensionsTests.cs index a229c7e1f8..28b621714f 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Hosting.UnitTests/HostedAgentBuilderToolsExtensionsTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Hosting.UnitTests/HostedAgentBuilderToolsExtensionsTests.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.AI; @@ -17,49 +18,40 @@ public sealed class HostedAgentBuilderToolsExtensionsTests [Fact] public void WithAITool_ThrowsWhenBuilderIsNull() { - // Arrange var tool = new DummyAITool(); - // Act & Assert Assert.Throws(() => HostedAgentBuilderExtensions.WithAITool(null!, tool)); } [Fact] public void WithAITool_ThrowsWhenToolIsNull() { - // Arrange var services = new ServiceCollection(); var builder = services.AddAIAgent("test-agent", "Test instructions"); - // Act & Assert - Assert.Throws(() => builder.WithAITool(null!)); + Assert.Throws(() => builder.WithAITool(tool: null!)); } [Fact] public void WithAITools_ThrowsWhenBuilderIsNull() { - // Arrange var tools = new[] { new DummyAITool() }; - // Act & Assert Assert.Throws(() => HostedAgentBuilderExtensions.WithAITools(null!, tools)); } [Fact] public void WithAITools_ThrowsWhenToolsArrayIsNull() { - // Arrange var services = new ServiceCollection(); var builder = services.AddAIAgent("test-agent", "Test instructions"); - // Act & Assert Assert.Throws(() => builder.WithAITools(null!)); } [Fact] public void RegisteredTools_ResolvesAllToolsForAgent() { - // Arrange var services = new ServiceCollection(); services.AddSingleton(new MockChatClient()); @@ -73,9 +65,13 @@ public void RegisteredTools_ResolvesAllToolsForAgent() var serviceProvider = services.BuildServiceProvider(); - var agent1Tools = ResolveAgentTools(serviceProvider, "test-agent"); + var agent1Tools = ResolveToolsFromAgent(serviceProvider, "test-agent"); Assert.Contains(tool1, agent1Tools); Assert.Contains(tool2, agent1Tools); + + var agent1ToolsDI = ResolveToolsFromDI(serviceProvider, "test-agent"); + Assert.Contains(tool1, agent1ToolsDI); + Assert.Contains(tool2, agent1ToolsDI); } [Fact] @@ -100,21 +96,160 @@ public void RegisteredTools_IsolatedPerAgent() var serviceProvider = services.BuildServiceProvider(); - var agent1Tools = ResolveAgentTools(serviceProvider, "agent1"); - var agent2Tools = ResolveAgentTools(serviceProvider, "agent2"); + var agent1Tools = ResolveToolsFromAgent(serviceProvider, "agent1"); + var agent2Tools = ResolveToolsFromAgent(serviceProvider, "agent2"); + + var agent1ToolsDI = ResolveToolsFromDI(serviceProvider, "agent1"); + var agent2ToolsDI = ResolveToolsFromDI(serviceProvider, "agent2"); Assert.Contains(tool1, agent1Tools); Assert.Contains(tool2, agent1Tools); + Assert.Contains(tool1, agent1ToolsDI); + Assert.Contains(tool2, agent1ToolsDI); + Assert.Contains(tool3, agent2Tools); + Assert.Contains(tool3, agent2ToolsDI); } - private static IList ResolveAgentTools(IServiceProvider serviceProvider, string name) + private static IList ResolveToolsFromAgent(IServiceProvider serviceProvider, string name) { var agent = serviceProvider.GetRequiredKeyedService(name) as ChatClientAgent; Assert.NotNull(agent?.ChatOptions?.Tools); return agent.ChatOptions.Tools; } + private static List ResolveToolsFromDI(IServiceProvider serviceProvider, string name) + { + var tools = serviceProvider.GetKeyedServices(name); + Assert.NotNull(tools); + return tools.ToList(); + } + + [Fact] + public void WithAIToolFactory_ThrowsWhenBuilderIsNull() + { + Assert.Throws(() => HostedAgentBuilderExtensions.WithAITool(null!, CreateTool)); + + static AITool CreateTool(IServiceProvider _) => new DummyAITool(); + } + + [Fact] + public void WithAIToolFactory_ThrowsWhenFactoryIsNull() + { + var services = new ServiceCollection(); + var builder = services.AddAIAgent("test-agent", "Test instructions"); + + Assert.Throws(() => builder.WithAITool(factory: null!)); + } + + [Fact] + public void WithAIToolFactory_RegistersToolFromFactory() + { + var services = new ServiceCollection(); + services.AddSingleton(new MockChatClient()); + + DummyAITool? createdTool = null; + var builder = services.AddAIAgent("test-agent", "Test instructions"); + builder.WithAITool(sp => + { + createdTool = new DummyAITool(); + return createdTool; + }); + + var serviceProvider = services.BuildServiceProvider(); + var tools = ResolveToolsFromDI(serviceProvider, "test-agent"); + + Assert.Single(tools); + Assert.Same(createdTool, tools[0]); + } + + [Fact] + public void WithAIToolFactory_CanAccessServicesFromFactory() + { + var services = new ServiceCollection(); + var mockChatClient = new MockChatClient(); + services.AddSingleton(mockChatClient); + + IChatClient? resolvedChatClient = null; + var builder = services.AddAIAgent("test-agent", "Test instructions"); + builder.WithAITool(sp => + { + resolvedChatClient = sp.GetService(); + return new DummyAITool(); + }); + + var serviceProvider = services.BuildServiceProvider(); + _ = ResolveToolsFromDI(serviceProvider, "test-agent"); + + Assert.Same(mockChatClient, resolvedChatClient); + } + + [Fact] + public void WithAIToolFactory_ToolsAreIsolatedPerAgent() + { + var services = new ServiceCollection(); + services.AddSingleton(new MockChatClient()); + + var tool1 = new DummyAITool(); + var tool2 = new DummyAITool(); + + var builder1 = services.AddAIAgent("agent1", "Agent 1 instructions"); + var builder2 = services.AddAIAgent("agent2", "Agent 2 instructions"); + + builder1.WithAITool(_ => tool1); + builder2.WithAITool(_ => tool2); + + var serviceProvider = services.BuildServiceProvider(); + var agent1Tools = ResolveToolsFromDI(serviceProvider, "agent1"); + var agent2Tools = ResolveToolsFromDI(serviceProvider, "agent2"); + + Assert.Single(agent1Tools); + Assert.Contains(tool1, agent1Tools); + Assert.DoesNotContain(tool2, agent1Tools); + + Assert.Single(agent2Tools); + Assert.Contains(tool2, agent2Tools); + Assert.DoesNotContain(tool1, agent2Tools); + } + + [Fact] + public void WithAIToolFactory_CanCombineWithDirectToolRegistration() + { + var services = new ServiceCollection(); + services.AddSingleton(new MockChatClient()); + + var directTool = new DummyAITool(); + var factoryTool = new DummyAITool(); + + var builder = services.AddAIAgent("test-agent", "Test instructions"); + builder + .WithAITool(directTool) + .WithAITool(_ => factoryTool); + + var serviceProvider = services.BuildServiceProvider(); + var tools = ResolveToolsFromDI(serviceProvider, "test-agent"); + + Assert.Equal(2, tools.Count); + Assert.Contains(directTool, tools); + Assert.Contains(factoryTool, tools); + } + + [Fact] + public void WithAIToolFactory_ToolsAvailableOnAgent() + { + var services = new ServiceCollection(); + services.AddSingleton(new MockChatClient()); + + var factoryTool = new DummyAITool(); + var builder = services.AddAIAgent("test-agent", "Test instructions"); + builder.WithAITool(_ => factoryTool); + + var serviceProvider = services.BuildServiceProvider(); + var agentTools = ResolveToolsFromAgent(serviceProvider, "test-agent"); + + Assert.Contains(factoryTool, agentTools); + } + /// /// Dummy AITool implementation for testing. ///