diff --git a/README.md b/README.md index 163d57f8a..550199676 100644 --- a/README.md +++ b/README.md @@ -37,8 +37,8 @@ dotnet add package ModelContextProtocol --prerelease ## Getting Started (Client) -To get started writing a client, the `McpClientFactory.CreateAsync` method is used to instantiate and connect an `IMcpClient` -to a server. Once you have an `IMcpClient`, you can interact with it, such as to enumerate all available tools and invoke tools. +To get started writing a client, the `McpClient.CreateAsync` method is used to instantiate and connect an `McpClient` +to a server. Once you have an `McpClient`, you can interact with it, such as to enumerate all available tools and invoke tools. ```csharp var clientTransport = new StdioClientTransport(new StdioClientTransportOptions @@ -48,7 +48,7 @@ var clientTransport = new StdioClientTransport(new StdioClientTransportOptions Arguments = ["-y", "@modelcontextprotocol/server-everything"], }); -var client = await McpClientFactory.CreateAsync(clientTransport); +var client = await McpClient.CreateAsync(clientTransport); // Print the list of tools available from the server. foreach (var tool in await client.ListToolsAsync()) @@ -122,14 +122,14 @@ public static class EchoTool } ``` -Tools can have the `IMcpServer` representing the server injected via a parameter to the method, and can use that for interaction with +Tools can have the `McpServer` representing the server injected via a parameter to the method, and can use that for interaction with the connected client. Similarly, arguments may be injected via dependency injection. For example, this tool will use the supplied -`IMcpServer` to make sampling requests back to the client in order to summarize content it downloads from the specified url via +`McpServer` to make sampling requests back to the client in order to summarize content it downloads from the specified url via an `HttpClient` injected via dependency injection. ```csharp [McpServerTool(Name = "SummarizeContentFromUrl"), Description("Summarizes content downloaded from a specific URI")] public static async Task SummarizeDownloadedContent( - IMcpServer thisServer, + McpServer thisServer, HttpClient httpClient, [Description("The url from which to download the content to summarize")] string url, CancellationToken cancellationToken) @@ -224,7 +224,7 @@ McpServerOptions options = new() }, }; -await using IMcpServer server = McpServerFactory.Create(new StdioServerTransport("MyServer"), options); +await using McpServer server = McpServer.Create(new StdioServerTransport("MyServer"), options); await server.RunAsync(); ``` diff --git a/docs/concepts/elicitation/samples/client/Program.cs b/docs/concepts/elicitation/samples/client/Program.cs index 6d5178796..b56960dc7 100644 --- a/docs/concepts/elicitation/samples/client/Program.cs +++ b/docs/concepts/elicitation/samples/client/Program.cs @@ -4,7 +4,7 @@ var endpoint = Environment.GetEnvironmentVariable("ENDPOINT") ?? "http://localhost:3001"; -var clientTransport = new SseClientTransport(new() +var clientTransport = new HttpClientTransport(new() { Endpoint = new Uri(endpoint), TransportMode = HttpTransportMode.StreamableHttp, @@ -27,7 +27,7 @@ } }; -await using var mcpClient = await McpClientFactory.CreateAsync(clientTransport, options); +await using var mcpClient = await McpClient.CreateAsync(clientTransport, options); // var tools = await mcpClient.ListToolsAsync(); diff --git a/docs/concepts/elicitation/samples/server/Tools/InteractiveTools.cs b/docs/concepts/elicitation/samples/server/Tools/InteractiveTools.cs index b6a75e005..b907a805d 100644 --- a/docs/concepts/elicitation/samples/server/Tools/InteractiveTools.cs +++ b/docs/concepts/elicitation/samples/server/Tools/InteractiveTools.cs @@ -13,7 +13,7 @@ public sealed class InteractiveTools // [McpServerTool, Description("A simple game where the user has to guess a number between 1 and 10.")] public async Task GuessTheNumber( - IMcpServer server, // Get the McpServer from DI container + McpServer server, // Get the McpServer from DI container CancellationToken token ) { diff --git a/docs/concepts/logging/samples/client/Program.cs b/docs/concepts/logging/samples/client/Program.cs index b30ca0881..29a15726a 100644 --- a/docs/concepts/logging/samples/client/Program.cs +++ b/docs/concepts/logging/samples/client/Program.cs @@ -4,13 +4,13 @@ var endpoint = Environment.GetEnvironmentVariable("ENDPOINT") ?? "http://localhost:3001"; -var clientTransport = new SseClientTransport(new() +var clientTransport = new HttpClientTransport(new() { Endpoint = new Uri(endpoint), TransportMode = HttpTransportMode.StreamableHttp, }); -await using var mcpClient = await McpClientFactory.CreateAsync(clientTransport); +await using var mcpClient = await McpClient.CreateAsync(clientTransport); // // Verify that the server supports logging diff --git a/docs/concepts/progress/samples/client/Program.cs b/docs/concepts/progress/samples/client/Program.cs index 6dde5de9f..2a5f589de 100644 --- a/docs/concepts/progress/samples/client/Program.cs +++ b/docs/concepts/progress/samples/client/Program.cs @@ -5,7 +5,7 @@ var endpoint = Environment.GetEnvironmentVariable("ENDPOINT") ?? "http://localhost:3001"; -var clientTransport = new SseClientTransport(new() +var clientTransport = new HttpClientTransport(new() { Endpoint = new Uri(endpoint), TransportMode = HttpTransportMode.StreamableHttp, @@ -20,7 +20,7 @@ } }; -await using var mcpClient = await McpClientFactory.CreateAsync(clientTransport, options); +await using var mcpClient = await McpClient.CreateAsync(clientTransport, options); var tools = await mcpClient.ListToolsAsync(); foreach (var tool in tools) diff --git a/docs/concepts/progress/samples/server/Tools/LongRunningTools.cs b/docs/concepts/progress/samples/server/Tools/LongRunningTools.cs index ca2a87663..7fcd1244a 100644 --- a/docs/concepts/progress/samples/server/Tools/LongRunningTools.cs +++ b/docs/concepts/progress/samples/server/Tools/LongRunningTools.cs @@ -10,7 +10,7 @@ public class LongRunningTools { [McpServerTool, Description("Demonstrates a long running tool with progress updates")] public static async Task LongRunningTool( - IMcpServer server, + McpServer server, RequestContext context, int duration = 10, int steps = 5) diff --git a/samples/AspNetCoreMcpServer/Tools/SampleLlmTool.cs b/samples/AspNetCoreMcpServer/Tools/SampleLlmTool.cs index 3ac7f567d..e69477452 100644 --- a/samples/AspNetCoreMcpServer/Tools/SampleLlmTool.cs +++ b/samples/AspNetCoreMcpServer/Tools/SampleLlmTool.cs @@ -12,7 +12,7 @@ public sealed class SampleLlmTool { [McpServerTool(Name = "sampleLLM"), Description("Samples from an LLM using MCP's sampling feature")] public static async Task SampleLLM( - IMcpServer thisServer, + McpServer thisServer, [Description("The prompt to send to the LLM")] string prompt, [Description("Maximum number of tokens to generate")] int maxTokens, CancellationToken cancellationToken) diff --git a/samples/ChatWithTools/Program.cs b/samples/ChatWithTools/Program.cs index ba597ae8a..a84393e15 100644 --- a/samples/ChatWithTools/Program.cs +++ b/samples/ChatWithTools/Program.cs @@ -32,7 +32,7 @@ .UseOpenTelemetry(loggerFactory: loggerFactory, configure: o => o.EnableSensitiveData = true) .Build(); -var mcpClient = await McpClientFactory.CreateAsync( +var mcpClient = await McpClient.CreateAsync( new StdioClientTransport(new() { Command = "npx", diff --git a/samples/EverythingServer/LoggingUpdateMessageSender.cs b/samples/EverythingServer/LoggingUpdateMessageSender.cs index 844aa70d8..5f524ad8a 100644 --- a/samples/EverythingServer/LoggingUpdateMessageSender.cs +++ b/samples/EverythingServer/LoggingUpdateMessageSender.cs @@ -5,7 +5,7 @@ namespace EverythingServer; -public class LoggingUpdateMessageSender(IMcpServer server, Func getMinLevel) : BackgroundService +public class LoggingUpdateMessageSender(McpServer server, Func getMinLevel) : BackgroundService { readonly Dictionary _loggingLevelMap = new() { diff --git a/samples/EverythingServer/SubscriptionMessageSender.cs b/samples/EverythingServer/SubscriptionMessageSender.cs index 774d98523..b071965dc 100644 --- a/samples/EverythingServer/SubscriptionMessageSender.cs +++ b/samples/EverythingServer/SubscriptionMessageSender.cs @@ -2,7 +2,7 @@ using ModelContextProtocol; using ModelContextProtocol.Server; -internal class SubscriptionMessageSender(IMcpServer server, HashSet subscriptions) : BackgroundService +internal class SubscriptionMessageSender(McpServer server, HashSet subscriptions) : BackgroundService { protected override async Task ExecuteAsync(CancellationToken stoppingToken) { diff --git a/samples/EverythingServer/Tools/LongRunningTool.cs b/samples/EverythingServer/Tools/LongRunningTool.cs index 27f6ac20f..405b5e823 100644 --- a/samples/EverythingServer/Tools/LongRunningTool.cs +++ b/samples/EverythingServer/Tools/LongRunningTool.cs @@ -10,7 +10,7 @@ public class LongRunningTool { [McpServerTool(Name = "longRunningOperation"), Description("Demonstrates a long running operation with progress updates")] public static async Task LongRunningOperation( - IMcpServer server, + McpServer server, RequestContext context, int duration = 10, int steps = 5) diff --git a/samples/EverythingServer/Tools/SampleLlmTool.cs b/samples/EverythingServer/Tools/SampleLlmTool.cs index a58675c30..6bbe6e51d 100644 --- a/samples/EverythingServer/Tools/SampleLlmTool.cs +++ b/samples/EverythingServer/Tools/SampleLlmTool.cs @@ -9,7 +9,7 @@ public class SampleLlmTool { [McpServerTool(Name = "sampleLLM"), Description("Samples from an LLM using MCP's sampling feature")] public static async Task SampleLLM( - IMcpServer server, + McpServer server, [Description("The prompt to send to the LLM")] string prompt, [Description("Maximum number of tokens to generate")] int maxTokens, CancellationToken cancellationToken) diff --git a/samples/InMemoryTransport/Program.cs b/samples/InMemoryTransport/Program.cs index 67e2d320c..141692fe9 100644 --- a/samples/InMemoryTransport/Program.cs +++ b/samples/InMemoryTransport/Program.cs @@ -6,7 +6,7 @@ Pipe clientToServerPipe = new(), serverToClientPipe = new(); // Create a server using a stream-based transport over an in-memory pipe. -await using IMcpServer server = McpServerFactory.Create( +await using McpServer server = McpServer.Create( new StreamServerTransport(clientToServerPipe.Reader.AsStream(), serverToClientPipe.Writer.AsStream()), new McpServerOptions() { @@ -21,7 +21,7 @@ _ = server.RunAsync(); // Connect a client using a stream-based transport over the same in-memory pipe. -await using IMcpClient client = await McpClientFactory.CreateAsync( +await using McpClient client = await McpClient.CreateAsync( new StreamClientTransport(clientToServerPipe.Writer.AsStream(), serverToClientPipe.Reader.AsStream())); // List all tools. diff --git a/samples/ProtectedMcpClient/Program.cs b/samples/ProtectedMcpClient/Program.cs index 5871284a7..9dc2410ea 100644 --- a/samples/ProtectedMcpClient/Program.cs +++ b/samples/ProtectedMcpClient/Program.cs @@ -25,7 +25,7 @@ builder.AddConsole(); }); -var transport = new SseClientTransport(new() +var transport = new HttpClientTransport(new() { Endpoint = new Uri(serverUrl), Name = "Secure Weather Client", @@ -40,7 +40,7 @@ } }, httpClient, consoleLoggerFactory); -var client = await McpClientFactory.CreateAsync(transport, loggerFactory: consoleLoggerFactory); +var client = await McpClient.CreateAsync(transport, loggerFactory: consoleLoggerFactory); var tools = await client.ListToolsAsync(); if (tools.Count == 0) diff --git a/samples/QuickstartClient/Program.cs b/samples/QuickstartClient/Program.cs index d5b887ff8..cd1c4c60a 100644 --- a/samples/QuickstartClient/Program.cs +++ b/samples/QuickstartClient/Program.cs @@ -19,7 +19,7 @@ if (command == "http") { // make sure AspNetCoreMcpServer is running - clientTransport = new SseClientTransport(new() + clientTransport = new HttpClientTransport(new() { Endpoint = new Uri("http://localhost:3001") }); @@ -33,7 +33,7 @@ Arguments = arguments, }); } -await using var mcpClient = await McpClientFactory.CreateAsync(clientTransport!); +await using var mcpClient = await McpClient.CreateAsync(clientTransport!); var tools = await mcpClient.ListToolsAsync(); foreach (var tool in tools) @@ -62,7 +62,7 @@ var sb = new StringBuilder(); PromptForInput(); -while(Console.ReadLine() is string query && !"exit".Equals(query, StringComparison.OrdinalIgnoreCase)) +while (Console.ReadLine() is string query && !"exit".Equals(query, StringComparison.OrdinalIgnoreCase)) { if (string.IsNullOrWhiteSpace(query)) { diff --git a/samples/TestServerWithHosting/Tools/SampleLlmTool.cs b/samples/TestServerWithHosting/Tools/SampleLlmTool.cs index a096f9301..2c96b8c35 100644 --- a/samples/TestServerWithHosting/Tools/SampleLlmTool.cs +++ b/samples/TestServerWithHosting/Tools/SampleLlmTool.cs @@ -12,7 +12,7 @@ public sealed class SampleLlmTool { [McpServerTool(Name = "sampleLLM"), Description("Samples from an LLM using MCP's sampling feature")] public static async Task SampleLLM( - IMcpServer thisServer, + McpServer thisServer, [Description("The prompt to send to the LLM")] string prompt, [Description("Maximum number of tokens to generate")] int maxTokens, CancellationToken cancellationToken) diff --git a/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs b/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs index bd0ceabeb..2cfb74d09 100644 --- a/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs +++ b/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs @@ -292,7 +292,7 @@ private async ValueTask GetAuthorizationResultAsync( if (requestServices is null) { // The IAuthorizationPolicyProvider service must be non-null to get to this line, so it's very unexpected for RequestContext.Services to not be set. - throw new InvalidOperationException("RequestContext.Services is not set! The IMcpServer must be initialized with a non-null IServiceProvider."); + throw new InvalidOperationException("RequestContext.Services is not set! The McpServer must be initialized with a non-null IServiceProvider."); } // ASP.NET Core's AuthorizationMiddleware resolves the IAuthorizationService from scoped request services, so we do the same. diff --git a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs index 94de9cb99..8d71f5166 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs @@ -20,7 +20,7 @@ public class HttpServerTransportOptions /// Gets or sets an optional asynchronous callback for running new MCP sessions manually. /// This is useful for running logic before a sessions starts and after it completes. /// - public Func? RunSessionHandler { get; set; } + public Func? RunSessionHandler { get; set; } /// /// Gets or sets whether the server should run in a stateless mode that does not require all requests for a given session diff --git a/src/ModelContextProtocol.AspNetCore/SseHandler.cs b/src/ModelContextProtocol.AspNetCore/SseHandler.cs index fffdd45e3..eefe0d29e 100644 --- a/src/ModelContextProtocol.AspNetCore/SseHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/SseHandler.cs @@ -53,7 +53,7 @@ public async Task HandleSseRequestAsync(HttpContext context) try { - await using var mcpServer = McpServerFactory.Create(transport, mcpServerOptions, loggerFactory, context.RequestServices); + await using var mcpServer = McpServer.Create(transport, mcpServerOptions, loggerFactory, context.RequestServices); context.Features.Set(mcpServer); var runSessionAsync = httpMcpServerOptions.Value.RunSessionHandler ?? StreamableHttpHandler.RunSessionAsync; diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index a31c6fb75..14093facc 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -232,7 +232,7 @@ private async ValueTask CreateSessionAsync( } } - var server = McpServerFactory.Create(transport, mcpServerOptions, loggerFactory, mcpServerServices); + var server = McpServer.Create(transport, mcpServerOptions, loggerFactory, mcpServerServices); context.Features.Set(server); var userIdClaim = statelessId?.UserIdClaim ?? GetUserIdClaim(context.User); @@ -307,7 +307,7 @@ private void ScheduleStatelessSessionIdWrite(HttpContext context, StreamableHttp }; } - internal static Task RunSessionAsync(HttpContext httpContext, IMcpServer session, CancellationToken requestAborted) + internal static Task RunSessionAsync(HttpContext httpContext, McpServer session, CancellationToken requestAborted) => session.RunAsync(requestAborted); // SignalR only checks for ClaimTypes.NameIdentifier in HttpConnectionDispatcher, but AspNetCore.Antiforgery checks that plus the sub and UPN claims. diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs index 7c8a31959..1e8d22dec 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs @@ -7,7 +7,7 @@ namespace ModelContextProtocol.AspNetCore; internal sealed class StreamableHttpSession( string sessionId, StreamableHttpServerTransport transport, - IMcpServer server, + McpServer server, UserIdClaim? userId, StatefulSessionManager sessionManager) : IAsyncDisposable { @@ -20,7 +20,7 @@ internal sealed class StreamableHttpSession( public string Id => sessionId; public StreamableHttpServerTransport Transport => transport; - public IMcpServer Server => server; + public McpServer Server => server; private StatefulSessionManager SessionManager => sessionManager; public CancellationToken SessionClosed => _disposeCts.Token; diff --git a/src/ModelContextProtocol.Core/AssemblyNameHelper.cs b/src/ModelContextProtocol.Core/AssemblyNameHelper.cs new file mode 100644 index 000000000..292ed2f96 --- /dev/null +++ b/src/ModelContextProtocol.Core/AssemblyNameHelper.cs @@ -0,0 +1,9 @@ +using System.Reflection; + +namespace ModelContextProtocol; + +internal static class AssemblyNameHelper +{ + /// Cached naming information used for MCP session name/version when none is specified. + public static AssemblyName DefaultAssemblyName { get; } = (Assembly.GetEntryAssembly() ?? Assembly.GetExecutingAssembly()).GetName(); +} diff --git a/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs index 06f2e0bfb..2e49babcf 100644 --- a/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/AutoDetectingClientSessionTransport.cs @@ -12,14 +12,14 @@ namespace ModelContextProtocol.Client; /// internal sealed partial class AutoDetectingClientSessionTransport : ITransport { - private readonly SseClientTransportOptions _options; + private readonly HttpClientTransportOptions _options; private readonly McpHttpClient _httpClient; private readonly ILoggerFactory? _loggerFactory; private readonly ILogger _logger; private readonly string _name; private readonly Channel _messageChannel; - public AutoDetectingClientSessionTransport(string endpointName, SseClientTransportOptions transportOptions, McpHttpClient httpClient, ILoggerFactory? loggerFactory) + public AutoDetectingClientSessionTransport(string endpointName, HttpClientTransportOptions transportOptions, McpHttpClient httpClient, ILoggerFactory? loggerFactory) { Throw.IfNull(transportOptions); Throw.IfNull(httpClient); diff --git a/src/ModelContextProtocol.Core/Client/SseClientTransport.cs b/src/ModelContextProtocol.Core/Client/HttpClientTransport.cs similarity index 86% rename from src/ModelContextProtocol.Core/Client/SseClientTransport.cs rename to src/ModelContextProtocol.Core/Client/HttpClientTransport.cs index b31c3479b..322b9175e 100644 --- a/src/ModelContextProtocol.Core/Client/SseClientTransport.cs +++ b/src/ModelContextProtocol.Core/Client/HttpClientTransport.cs @@ -13,26 +13,26 @@ namespace ModelContextProtocol.Client; /// Unlike the , this transport connects to an existing server /// rather than launching a new process. /// -public sealed class SseClientTransport : IClientTransport, IAsyncDisposable +public sealed class HttpClientTransport : IClientTransport, IAsyncDisposable { - private readonly SseClientTransportOptions _options; + private readonly HttpClientTransportOptions _options; private readonly McpHttpClient _mcpHttpClient; private readonly ILoggerFactory? _loggerFactory; private readonly HttpClient? _ownedHttpClient; /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// Configuration options for the transport. /// Logger factory for creating loggers used for diagnostic output during transport operations. - public SseClientTransport(SseClientTransportOptions transportOptions, ILoggerFactory? loggerFactory = null) + public HttpClientTransport(HttpClientTransportOptions transportOptions, ILoggerFactory? loggerFactory = null) : this(transportOptions, new HttpClient(), loggerFactory, ownsHttpClient: true) { } /// - /// Initializes a new instance of the class with a provided HTTP client. + /// Initializes a new instance of the class with a provided HTTP client. /// /// Configuration options for the transport. /// The HTTP client instance used for requests. @@ -41,7 +41,7 @@ public SseClientTransport(SseClientTransportOptions transportOptions, ILoggerFac /// to dispose of when the transport is disposed; /// if the caller is retaining ownership of the 's lifetime. /// - public SseClientTransport(SseClientTransportOptions transportOptions, HttpClient httpClient, ILoggerFactory? loggerFactory = null, bool ownsHttpClient = false) + public HttpClientTransport(HttpClientTransportOptions transportOptions, HttpClient httpClient, ILoggerFactory? loggerFactory = null, bool ownsHttpClient = false) { Throw.IfNull(transportOptions); Throw.IfNull(httpClient); diff --git a/src/ModelContextProtocol.Core/Client/SseClientTransportOptions.cs b/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs similarity index 96% rename from src/ModelContextProtocol.Core/Client/SseClientTransportOptions.cs rename to src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs index 4097844cf..94b95eecb 100644 --- a/src/ModelContextProtocol.Core/Client/SseClientTransportOptions.cs +++ b/src/ModelContextProtocol.Core/Client/HttpClientTransportOptions.cs @@ -3,9 +3,9 @@ namespace ModelContextProtocol.Client; /// -/// Provides options for configuring instances. +/// Provides options for configuring instances. /// -public sealed class SseClientTransportOptions +public sealed class HttpClientTransportOptions { /// /// Gets or sets the base address of the server for SSE connections. diff --git a/src/ModelContextProtocol.Core/Client/IClientTransport.cs b/src/ModelContextProtocol.Core/Client/IClientTransport.cs index 525178957..2201e9b4f 100644 --- a/src/ModelContextProtocol.Core/Client/IClientTransport.cs +++ b/src/ModelContextProtocol.Core/Client/IClientTransport.cs @@ -11,7 +11,7 @@ namespace ModelContextProtocol.Client; /// and servers, allowing different transport protocols to be used interchangeably. /// /// -/// When creating an , is typically used, and is +/// When creating an , is typically used, and is /// provided with the based on expected server configuration. /// /// @@ -39,7 +39,7 @@ public interface IClientTransport /// the transport session as well. /// /// - /// This method is used by to initialize the connection. + /// This method is used by to initialize the connection. /// /// /// The transport connection could not be established. diff --git a/src/ModelContextProtocol.Core/Client/IMcpClient.cs b/src/ModelContextProtocol.Core/Client/IMcpClient.cs index 68a92a2d9..43930c030 100644 --- a/src/ModelContextProtocol.Core/Client/IMcpClient.cs +++ b/src/ModelContextProtocol.Core/Client/IMcpClient.cs @@ -1,10 +1,11 @@ -using ModelContextProtocol.Protocol; +using ModelContextProtocol.Protocol; namespace ModelContextProtocol.Client; /// /// Represents an instance of a Model Context Protocol (MCP) client that connects to and communicates with an MCP server. /// +[Obsolete($"Use {nameof(McpClient)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public interface IMcpClient : IMcpEndpoint { /// @@ -44,4 +45,4 @@ public interface IMcpClient : IMcpEndpoint /// /// string? ServerInstructions { get; } -} \ No newline at end of file +} diff --git a/src/ModelContextProtocol.Core/Client/McpClient.Methods.cs b/src/ModelContextProtocol.Core/Client/McpClient.Methods.cs new file mode 100644 index 000000000..560ce31dc --- /dev/null +++ b/src/ModelContextProtocol.Core/Client/McpClient.Methods.cs @@ -0,0 +1,713 @@ +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.Runtime.CompilerServices; +using System.Text.Json; + +namespace ModelContextProtocol.Client; + +/// +/// Represents an instance of a Model Context Protocol (MCP) client session that connects to and communicates with an MCP server. +/// +#pragma warning disable CS0618 // Type or member is obsolete +public abstract partial class McpClient : McpSession, IMcpClient +#pragma warning restore CS0618 // Type or member is obsolete +{ + /// Creates an , connecting it to the specified server. + /// The transport instance used to communicate with the server. + /// + /// A client configuration object which specifies client capabilities and protocol version. + /// If , details based on the current process will be employed. + /// + /// A logger factory for creating loggers for clients. + /// The to monitor for cancellation requests. The default is . + /// An that's connected to the specified server. + /// is . + /// is . + public static async Task CreateAsync( + IClientTransport clientTransport, + McpClientOptions? clientOptions = null, + ILoggerFactory? loggerFactory = null, + CancellationToken cancellationToken = default) + { + Throw.IfNull(clientTransport); + + var transport = await clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false); + var endpointName = clientTransport.Name; + + var clientSession = new McpClientImpl(transport, endpointName, clientOptions, loggerFactory); + try + { + await clientSession.ConnectAsync(cancellationToken).ConfigureAwait(false); + } + catch + { + await clientSession.DisposeAsync().ConfigureAwait(false); + throw; + } + + return clientSession; + } + + /// + /// Sends a ping request to verify server connectivity. + /// + /// The to monitor for cancellation requests. The default is . + /// A task that completes when the ping is successful. + /// Thrown when the server cannot be reached or returns an error response. + public Task PingAsync(CancellationToken cancellationToken = default) + { + var opts = McpJsonUtilities.DefaultOptions; + opts.MakeReadOnly(); + return SendRequestAsync( + RequestMethods.Ping, + parameters: null, + serializerOptions: opts, + cancellationToken: cancellationToken).AsTask(); + } + + /// + /// Retrieves a list of available tools from the server. + /// + /// The serializer options governing tool parameter serialization. If null, the default options will be used. + /// The to monitor for cancellation requests. The default is . + /// A list of all available tools as instances. + public async ValueTask> ListToolsAsync( + JsonSerializerOptions? serializerOptions = null, + CancellationToken cancellationToken = default) + { + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); + + List? tools = null; + string? cursor = null; + do + { + var toolResults = await SendRequestAsync( + RequestMethods.ToolsList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, + McpJsonUtilities.JsonContext.Default.ListToolsResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + + tools ??= new List(toolResults.Tools.Count); + foreach (var tool in toolResults.Tools) + { + tools.Add(new McpClientTool(this, tool, serializerOptions)); + } + + cursor = toolResults.NextCursor; + } + while (cursor is not null); + + return tools; + } + + /// + /// Creates an enumerable for asynchronously enumerating all available tools from the server. + /// + /// The serializer options governing tool parameter serialization. If null, the default options will be used. + /// The to monitor for cancellation requests. The default is . + /// An asynchronous sequence of all available tools as instances. + public async IAsyncEnumerable EnumerateToolsAsync( + JsonSerializerOptions? serializerOptions = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); + + string? cursor = null; + do + { + var toolResults = await SendRequestAsync( + RequestMethods.ToolsList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, + McpJsonUtilities.JsonContext.Default.ListToolsResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + + foreach (var tool in toolResults.Tools) + { + yield return new McpClientTool(this, tool, serializerOptions); + } + + cursor = toolResults.NextCursor; + } + while (cursor is not null); + } + + /// + /// Retrieves a list of available prompts from the server. + /// + /// The to monitor for cancellation requests. The default is . + /// A list of all available prompts as instances. + public async ValueTask> ListPromptsAsync( + CancellationToken cancellationToken = default) + { + List? prompts = null; + string? cursor = null; + do + { + var promptResults = await SendRequestAsync( + RequestMethods.PromptsList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, + McpJsonUtilities.JsonContext.Default.ListPromptsResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + + prompts ??= new List(promptResults.Prompts.Count); + foreach (var prompt in promptResults.Prompts) + { + prompts.Add(new McpClientPrompt(this, prompt)); + } + + cursor = promptResults.NextCursor; + } + while (cursor is not null); + + return prompts; + } + + /// + /// Creates an enumerable for asynchronously enumerating all available prompts from the server. + /// + /// The to monitor for cancellation requests. The default is . + /// An asynchronous sequence of all available prompts as instances. + public async IAsyncEnumerable EnumeratePromptsAsync( + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + string? cursor = null; + do + { + var promptResults = await SendRequestAsync( + RequestMethods.PromptsList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, + McpJsonUtilities.JsonContext.Default.ListPromptsResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + + foreach (var prompt in promptResults.Prompts) + { + yield return new(this, prompt); + } + + cursor = promptResults.NextCursor; + } + while (cursor is not null); + } + + /// + /// Retrieves a specific prompt from the MCP server. + /// + /// The name of the prompt to retrieve. + /// Optional arguments for the prompt. Keys are parameter names, and values are the argument values. + /// The serialization options governing argument serialization. + /// The to monitor for cancellation requests. The default is . + /// A task containing the prompt's result with content and messages. + public ValueTask GetPromptAsync( + string name, + IReadOnlyDictionary? arguments = null, + JsonSerializerOptions? serializerOptions = null, + CancellationToken cancellationToken = default) + { + Throw.IfNullOrWhiteSpace(name); + + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); + + return SendRequestAsync( + RequestMethods.PromptsGet, + new() { Name = name, Arguments = ToArgumentsDictionary(arguments, serializerOptions) }, + McpJsonUtilities.JsonContext.Default.GetPromptRequestParams, + McpJsonUtilities.JsonContext.Default.GetPromptResult, + cancellationToken: cancellationToken); + } + + /// + /// Retrieves a list of available resource templates from the server. + /// + /// The to monitor for cancellation requests. The default is . + /// A list of all available resource templates as instances. + public async ValueTask> ListResourceTemplatesAsync( + CancellationToken cancellationToken = default) + { + List? resourceTemplates = null; + + string? cursor = null; + do + { + var templateResults = await SendRequestAsync( + RequestMethods.ResourcesTemplatesList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, + McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + + resourceTemplates ??= new List(templateResults.ResourceTemplates.Count); + foreach (var template in templateResults.ResourceTemplates) + { + resourceTemplates.Add(new McpClientResourceTemplate(this, template)); + } + + cursor = templateResults.NextCursor; + } + while (cursor is not null); + + return resourceTemplates; + } + + /// + /// Creates an enumerable for asynchronously enumerating all available resource templates from the server. + /// + /// The to monitor for cancellation requests. The default is . + /// An asynchronous sequence of all available resource templates as instances. + public async IAsyncEnumerable EnumerateResourceTemplatesAsync( + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + string? cursor = null; + do + { + var templateResults = await SendRequestAsync( + RequestMethods.ResourcesTemplatesList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, + McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + + foreach (var templateResult in templateResults.ResourceTemplates) + { + yield return new McpClientResourceTemplate(this, templateResult); + } + + cursor = templateResults.NextCursor; + } + while (cursor is not null); + } + + /// + /// Retrieves a list of available resources from the server. + /// + /// The to monitor for cancellation requests. The default is . + /// A list of all available resources as instances. + public async ValueTask> ListResourcesAsync( + CancellationToken cancellationToken = default) + { + List? resources = null; + + string? cursor = null; + do + { + var resourceResults = await SendRequestAsync( + RequestMethods.ResourcesList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, + McpJsonUtilities.JsonContext.Default.ListResourcesResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + + resources ??= new List(resourceResults.Resources.Count); + foreach (var resource in resourceResults.Resources) + { + resources.Add(new McpClientResource(this, resource)); + } + + cursor = resourceResults.NextCursor; + } + while (cursor is not null); + + return resources; + } + + /// + /// Creates an enumerable for asynchronously enumerating all available resources from the server. + /// + /// The to monitor for cancellation requests. The default is . + /// An asynchronous sequence of all available resources as instances. + public async IAsyncEnumerable EnumerateResourcesAsync( + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + string? cursor = null; + do + { + var resourceResults = await SendRequestAsync( + RequestMethods.ResourcesList, + new() { Cursor = cursor }, + McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, + McpJsonUtilities.JsonContext.Default.ListResourcesResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + + foreach (var resource in resourceResults.Resources) + { + yield return new McpClientResource(this, resource); + } + + cursor = resourceResults.NextCursor; + } + while (cursor is not null); + } + + /// + /// Reads a resource from the server. + /// + /// The uri of the resource. + /// The to monitor for cancellation requests. The default is . + public ValueTask ReadResourceAsync( + string uri, CancellationToken cancellationToken = default) + { + Throw.IfNullOrWhiteSpace(uri); + + return SendRequestAsync( + RequestMethods.ResourcesRead, + new() { Uri = uri }, + McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, + McpJsonUtilities.JsonContext.Default.ReadResourceResult, + cancellationToken: cancellationToken); + } + + /// + /// Reads a resource from the server. + /// + /// The uri of the resource. + /// The to monitor for cancellation requests. The default is . + public ValueTask ReadResourceAsync( + Uri uri, CancellationToken cancellationToken = default) + { + Throw.IfNull(uri); + + return ReadResourceAsync(uri.ToString(), cancellationToken); + } + + /// + /// Reads a resource from the server. + /// + /// The uri template of the resource. + /// Arguments to use to format . + /// The to monitor for cancellation requests. The default is . + public ValueTask ReadResourceAsync( + string uriTemplate, IReadOnlyDictionary arguments, CancellationToken cancellationToken = default) + { + Throw.IfNullOrWhiteSpace(uriTemplate); + Throw.IfNull(arguments); + + return SendRequestAsync( + RequestMethods.ResourcesRead, + new() { Uri = UriTemplate.FormatUri(uriTemplate, arguments) }, + McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, + McpJsonUtilities.JsonContext.Default.ReadResourceResult, + cancellationToken: cancellationToken); + } + + /// + /// Requests completion suggestions for a prompt argument or resource reference. + /// + /// The reference object specifying the type and optional URI or name. + /// The name of the argument for which completions are requested. + /// The current value of the argument, used to filter relevant completions. + /// The to monitor for cancellation requests. The default is . + /// A containing completion suggestions. + public ValueTask CompleteAsync(Reference reference, string argumentName, string argumentValue, CancellationToken cancellationToken = default) + { + Throw.IfNull(reference); + Throw.IfNullOrWhiteSpace(argumentName); + + return SendRequestAsync( + RequestMethods.CompletionComplete, + new() + { + Ref = reference, + Argument = new Argument { Name = argumentName, Value = argumentValue } + }, + McpJsonUtilities.JsonContext.Default.CompleteRequestParams, + McpJsonUtilities.JsonContext.Default.CompleteResult, + cancellationToken: cancellationToken); + } + + /// + /// Subscribes to a resource on the server to receive notifications when it changes. + /// + /// The URI of the resource to which to subscribe. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. + public Task SubscribeToResourceAsync(string uri, CancellationToken cancellationToken = default) + { + Throw.IfNullOrWhiteSpace(uri); + + return SendRequestAsync( + RequestMethods.ResourcesSubscribe, + new() { Uri = uri }, + McpJsonUtilities.JsonContext.Default.SubscribeRequestParams, + McpJsonUtilities.JsonContext.Default.EmptyResult, + cancellationToken: cancellationToken).AsTask(); + } + + /// + /// Subscribes to a resource on the server to receive notifications when it changes. + /// + /// The URI of the resource to which to subscribe. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. + public Task SubscribeToResourceAsync(Uri uri, CancellationToken cancellationToken = default) + { + Throw.IfNull(uri); + + return SubscribeToResourceAsync(uri.ToString(), cancellationToken); + } + + /// + /// Unsubscribes from a resource on the server to stop receiving notifications about its changes. + /// + /// The URI of the resource to unsubscribe from. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. + public Task UnsubscribeFromResourceAsync(string uri, CancellationToken cancellationToken = default) + { + Throw.IfNullOrWhiteSpace(uri); + + return SendRequestAsync( + RequestMethods.ResourcesUnsubscribe, + new() { Uri = uri }, + McpJsonUtilities.JsonContext.Default.UnsubscribeRequestParams, + McpJsonUtilities.JsonContext.Default.EmptyResult, + cancellationToken: cancellationToken).AsTask(); + } + + /// + /// Unsubscribes from a resource on the server to stop receiving notifications about its changes. + /// + /// The URI of the resource to unsubscribe from. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. + public Task UnsubscribeFromResourceAsync(Uri uri, CancellationToken cancellationToken = default) + { + Throw.IfNull(uri); + + return UnsubscribeFromResourceAsync(uri.ToString(), cancellationToken); + } + + /// + /// Invokes a tool on the server. + /// + /// The name of the tool to call on the server.. + /// An optional dictionary of arguments to pass to the tool. + /// Optional progress reporter for server notifications. + /// JSON serializer options. + /// A cancellation token. + /// The from the tool execution. + public ValueTask CallToolAsync( + string toolName, + IReadOnlyDictionary? arguments = null, + IProgress? progress = null, + JsonSerializerOptions? serializerOptions = null, + CancellationToken cancellationToken = default) + { + Throw.IfNull(toolName); + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); + + if (progress is not null) + { + return SendRequestWithProgressAsync(toolName, arguments, progress, serializerOptions, cancellationToken); + } + + return SendRequestAsync( + RequestMethods.ToolsCall, + new() + { + Name = toolName, + Arguments = ToArgumentsDictionary(arguments, serializerOptions), + }, + McpJsonUtilities.JsonContext.Default.CallToolRequestParams, + McpJsonUtilities.JsonContext.Default.CallToolResult, + cancellationToken: cancellationToken); + + async ValueTask SendRequestWithProgressAsync( + string toolName, + IReadOnlyDictionary? arguments, + IProgress progress, + JsonSerializerOptions serializerOptions, + CancellationToken cancellationToken) + { + ProgressToken progressToken = new(Guid.NewGuid().ToString("N")); + + await using var _ = RegisterNotificationHandler(NotificationMethods.ProgressNotification, + (notification, cancellationToken) => + { + if (JsonSerializer.Deserialize(notification.Params, McpJsonUtilities.JsonContext.Default.ProgressNotificationParams) is { } pn && + pn.ProgressToken == progressToken) + { + progress.Report(pn.Progress); + } + + return default; + }).ConfigureAwait(false); + + return await SendRequestAsync( + RequestMethods.ToolsCall, + new() + { + Name = toolName, + Arguments = ToArgumentsDictionary(arguments, serializerOptions), + ProgressToken = progressToken, + }, + McpJsonUtilities.JsonContext.Default.CallToolRequestParams, + McpJsonUtilities.JsonContext.Default.CallToolResult, + cancellationToken: cancellationToken).ConfigureAwait(false); + } + } + + /// + /// Converts the contents of a into a pair of + /// and instances to use + /// as inputs into a operation. + /// + /// + /// The created pair of messages and options. + /// is . + internal static (IList Messages, ChatOptions? Options) ToChatClientArguments( + CreateMessageRequestParams requestParams) + { + Throw.IfNull(requestParams); + + ChatOptions? options = null; + + if (requestParams.MaxTokens is int maxTokens) + { + (options ??= new()).MaxOutputTokens = maxTokens; + } + + if (requestParams.Temperature is float temperature) + { + (options ??= new()).Temperature = temperature; + } + + if (requestParams.StopSequences is { } stopSequences) + { + (options ??= new()).StopSequences = stopSequences.ToArray(); + } + + List messages = + (from sm in requestParams.Messages + let aiContent = sm.Content.ToAIContent() + where aiContent is not null + select new ChatMessage(sm.Role == Role.Assistant ? ChatRole.Assistant : ChatRole.User, [aiContent])) + .ToList(); + + return (messages, options); + } + + /// Converts the contents of a into a . + /// The whose contents should be extracted. + /// The created . + /// is . + internal static CreateMessageResult ToCreateMessageResult(ChatResponse chatResponse) + { + Throw.IfNull(chatResponse); + + // The ChatResponse can include multiple messages, of varying modalities, but CreateMessageResult supports + // only either a single blob of text or a single image. Heuristically, we'll use an image if there is one + // in any of the response messages, or we'll use all the text from them concatenated, otherwise. + + ChatMessage? lastMessage = chatResponse.Messages.LastOrDefault(); + + ContentBlock? content = null; + if (lastMessage is not null) + { + foreach (var lmc in lastMessage.Contents) + { + if (lmc is DataContent dc && (dc.HasTopLevelMediaType("image") || dc.HasTopLevelMediaType("audio"))) + { + content = dc.ToContent(); + } + } + } + + return new() + { + Content = content ?? new TextContentBlock { Text = lastMessage?.Text ?? string.Empty }, + Model = chatResponse.ModelId ?? "unknown", + Role = lastMessage?.Role == ChatRole.User ? Role.User : Role.Assistant, + StopReason = chatResponse.FinishReason == ChatFinishReason.Length ? "maxTokens" : "endTurn", + }; + } + + /// + /// Creates a sampling handler for use with that will + /// satisfy sampling requests using the specified . + /// + /// The with which to satisfy sampling requests. + /// The created handler delegate that can be assigned to . + /// is . + public static Func, CancellationToken, ValueTask> CreateSamplingHandler( + IChatClient chatClient) + { + Throw.IfNull(chatClient); + + return async (requestParams, progress, cancellationToken) => + { + Throw.IfNull(requestParams); + + var (messages, options) = ToChatClientArguments(requestParams); + var progressToken = requestParams.ProgressToken; + + List updates = []; + await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) + { + updates.Add(update); + + if (progressToken is not null) + { + progress.Report(new() + { + Progress = updates.Count, + }); + } + } + + return ToCreateMessageResult(updates.ToChatResponse()); + }; + } + + /// + /// Sets the logging level for the server to control which log messages are sent to the client. + /// + /// The minimum severity level of log messages to receive from the server. + /// The to monitor for cancellation requests. The default is . + /// A task representing the asynchronous operation. + public Task SetLoggingLevel(LoggingLevel level, CancellationToken cancellationToken = default) + { + return SendRequestAsync( + RequestMethods.LoggingSetLevel, + new() { Level = level }, + McpJsonUtilities.JsonContext.Default.SetLevelRequestParams, + McpJsonUtilities.JsonContext.Default.EmptyResult, + cancellationToken: cancellationToken).AsTask(); + } + + /// + /// Sets the logging level for the server to control which log messages are sent to the client. + /// + /// The minimum severity level of log messages to receive from the server. + /// The to monitor for cancellation requests. The default is . + /// A task representing the asynchronous operation. + public Task SetLoggingLevel(LogLevel level, CancellationToken cancellationToken = default) => + SetLoggingLevel(McpServerImpl.ToLoggingLevel(level), cancellationToken); + + /// Convers a dictionary with values to a dictionary with values. + private static Dictionary? ToArgumentsDictionary( + IReadOnlyDictionary? arguments, JsonSerializerOptions options) + { + var typeInfo = options.GetTypeInfo(); + + Dictionary? result = null; + if (arguments is not null) + { + result = new(arguments.Count); + foreach (var kvp in arguments) + { + result.Add(kvp.Key, kvp.Value is JsonElement je ? je : JsonSerializer.SerializeToElement(kvp.Value, typeInfo)); + } + } + + return result; + } +} diff --git a/src/ModelContextProtocol.Core/Client/McpClient.cs b/src/ModelContextProtocol.Core/Client/McpClient.cs index dd8c7fe09..c4abe33b7 100644 --- a/src/ModelContextProtocol.Core/Client/McpClient.cs +++ b/src/ModelContextProtocol.Core/Client/McpClient.cs @@ -1,236 +1,49 @@ -using Microsoft.Extensions.Logging; -using ModelContextProtocol.Protocol; -using System.Text.Json; +using ModelContextProtocol.Protocol; namespace ModelContextProtocol.Client; -/// -internal sealed partial class McpClient : McpEndpoint, IMcpClient +/// +/// Represents an instance of a Model Context Protocol (MCP) client session that connects to and communicates with an MCP server. +/// +#pragma warning disable CS0618 // Type or member is obsolete +public abstract partial class McpClient : McpSession, IMcpClient +#pragma warning restore CS0618 // Type or member is obsolete { - private static Implementation DefaultImplementation { get; } = new() - { - Name = DefaultAssemblyName.Name ?? nameof(McpClient), - Version = DefaultAssemblyName.Version?.ToString() ?? "1.0.0", - }; - - private readonly IClientTransport _clientTransport; - private readonly McpClientOptions _options; - - private ITransport? _sessionTransport; - private CancellationTokenSource? _connectCts; - - private ServerCapabilities? _serverCapabilities; - private Implementation? _serverInfo; - private string? _serverInstructions; - /// - /// Initializes a new instance of the class. + /// Gets the capabilities supported by the connected server. /// - /// The transport to use for communication with the server. - /// Options for the client, defining protocol version and capabilities. - /// The logger factory. - public McpClient(IClientTransport clientTransport, McpClientOptions? options, ILoggerFactory? loggerFactory) - : base(loggerFactory) - { - options ??= new(); - - _clientTransport = clientTransport; - _options = options; - - EndpointName = clientTransport.Name; - - if (options.Capabilities is { } capabilities) - { - if (capabilities.NotificationHandlers is { } notificationHandlers) - { - NotificationHandlers.RegisterRange(notificationHandlers); - } - - if (capabilities.Sampling is { } samplingCapability) - { - if (samplingCapability.SamplingHandler is not { } samplingHandler) - { - throw new InvalidOperationException("Sampling capability was set but it did not provide a handler."); - } - - RequestHandlers.Set( - RequestMethods.SamplingCreateMessage, - (request, _, cancellationToken) => samplingHandler( - request, - request?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, - cancellationToken), - McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, - McpJsonUtilities.JsonContext.Default.CreateMessageResult); - } - - if (capabilities.Roots is { } rootsCapability) - { - if (rootsCapability.RootsHandler is not { } rootsHandler) - { - throw new InvalidOperationException("Roots capability was set but it did not provide a handler."); - } - - RequestHandlers.Set( - RequestMethods.RootsList, - (request, _, cancellationToken) => rootsHandler(request, cancellationToken), - McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, - McpJsonUtilities.JsonContext.Default.ListRootsResult); - } - - if (capabilities.Elicitation is { } elicitationCapability) - { - if (elicitationCapability.ElicitationHandler is not { } elicitationHandler) - { - throw new InvalidOperationException("Elicitation capability was set but it did not provide a handler."); - } - - RequestHandlers.Set( - RequestMethods.ElicitationCreate, - (request, _, cancellationToken) => elicitationHandler(request, cancellationToken), - McpJsonUtilities.JsonContext.Default.ElicitRequestParams, - McpJsonUtilities.JsonContext.Default.ElicitResult); - } - } - } - - /// - public string? SessionId - { - get - { - if (_sessionTransport is null) - { - throw new InvalidOperationException("Must have already initialized a session when invoking this property."); - } - - return _sessionTransport.SessionId; - } - } - - /// - public ServerCapabilities ServerCapabilities => _serverCapabilities ?? throw new InvalidOperationException("The client is not connected."); - - /// - public Implementation ServerInfo => _serverInfo ?? throw new InvalidOperationException("The client is not connected."); - - /// - public string? ServerInstructions => _serverInstructions; - - /// - public override string EndpointName { get; } + /// The client is not connected. + public abstract ServerCapabilities ServerCapabilities { get; } /// - /// Asynchronously connects to an MCP server, establishes the transport connection, and completes the initialization handshake. + /// Gets the implementation information of the connected server. /// - public async Task ConnectAsync(CancellationToken cancellationToken = default) - { - _connectCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - cancellationToken = _connectCts.Token; - - try - { - // Connect transport - _sessionTransport = await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false); - InitializeSession(_sessionTransport); - // We don't want the ConnectAsync token to cancel the session after we've successfully connected. - // The base class handles cleaning up the session in DisposeAsync without our help. - StartSession(_sessionTransport, fullSessionCancellationToken: CancellationToken.None); - - // Perform initialization sequence - using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - initializationCts.CancelAfter(_options.InitializationTimeout); - - try - { - // Send initialize request - string requestProtocol = _options.ProtocolVersion ?? McpSession.LatestProtocolVersion; - var initializeResponse = await this.SendRequestAsync( - RequestMethods.Initialize, - new InitializeRequestParams - { - ProtocolVersion = requestProtocol, - Capabilities = _options.Capabilities ?? new ClientCapabilities(), - ClientInfo = _options.ClientInfo ?? DefaultImplementation, - }, - McpJsonUtilities.JsonContext.Default.InitializeRequestParams, - McpJsonUtilities.JsonContext.Default.InitializeResult, - cancellationToken: initializationCts.Token).ConfigureAwait(false); - - // Store server information - if (_logger.IsEnabled(LogLevel.Information)) - { - LogServerCapabilitiesReceived(EndpointName, - capabilities: JsonSerializer.Serialize(initializeResponse.Capabilities, McpJsonUtilities.JsonContext.Default.ServerCapabilities), - serverInfo: JsonSerializer.Serialize(initializeResponse.ServerInfo, McpJsonUtilities.JsonContext.Default.Implementation)); - } - - _serverCapabilities = initializeResponse.Capabilities; - _serverInfo = initializeResponse.ServerInfo; - _serverInstructions = initializeResponse.Instructions; - - // Validate protocol version - bool isResponseProtocolValid = - _options.ProtocolVersion is { } optionsProtocol ? optionsProtocol == initializeResponse.ProtocolVersion : - McpSession.SupportedProtocolVersions.Contains(initializeResponse.ProtocolVersion); - if (!isResponseProtocolValid) - { - LogServerProtocolVersionMismatch(EndpointName, requestProtocol, initializeResponse.ProtocolVersion); - throw new McpException($"Server protocol version mismatch. Expected {requestProtocol}, got {initializeResponse.ProtocolVersion}"); - } - - // Send initialized notification - await this.SendNotificationAsync( - NotificationMethods.InitializedNotification, - new InitializedNotificationParams(), - McpJsonUtilities.JsonContext.Default.InitializedNotificationParams, - cancellationToken: initializationCts.Token).ConfigureAwait(false); + /// + /// + /// This property provides identification details about the connected server, including its name and version. + /// It is populated during the initialization handshake and is available after a successful connection. + /// + /// + /// This information can be useful for logging, debugging, compatibility checks, and displaying server + /// information to users. + /// + /// + /// The client is not connected. + public abstract Implementation ServerInfo { get; } - } - catch (OperationCanceledException oce) when (initializationCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested) - { - LogClientInitializationTimeout(EndpointName); - throw new TimeoutException("Initialization timed out", oce); - } - } - catch (Exception e) - { - LogClientInitializationError(EndpointName, e); - await DisposeAsync().ConfigureAwait(false); - throw; - } - } - - /// - public override async ValueTask DisposeUnsynchronizedAsync() - { - try - { - if (_connectCts is not null) - { - await _connectCts.CancelAsync().ConfigureAwait(false); - _connectCts.Dispose(); - } - - await base.DisposeUnsynchronizedAsync().ConfigureAwait(false); - } - finally - { - if (_sessionTransport is not null) - { - await _sessionTransport.DisposeAsync().ConfigureAwait(false); - } - } - } - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client received server '{ServerInfo}' capabilities: '{Capabilities}'.")] - private partial void LogServerCapabilitiesReceived(string endpointName, string capabilities, string serverInfo); - - [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client initialization error.")] - private partial void LogClientInitializationError(string endpointName, Exception exception); - - [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client initialization timed out.")] - private partial void LogClientInitializationTimeout(string endpointName); - - [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client protocol version mismatch with server. Expected '{Expected}', received '{Received}'.")] - private partial void LogServerProtocolVersionMismatch(string endpointName, string expected, string received); -} \ No newline at end of file + /// + /// Gets any instructions describing how to use the connected server and its features. + /// + /// + /// + /// This property contains instructions provided by the server during initialization that explain + /// how to effectively use its capabilities. These instructions can include details about available + /// tools, expected input formats, limitations, or any other helpful information. + /// + /// + /// This can be used by clients to improve an LLM's understanding of available tools, prompts, and resources. + /// It can be thought of like a "hint" to the model and may be added to a system prompt. + /// + /// + public abstract string? ServerInstructions { get; } +} diff --git a/src/ModelContextProtocol.Core/Client/McpClientExtensions.cs b/src/ModelContextProtocol.Core/Client/McpClientExtensions.cs index 60a9c3a64..e987f30f6 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientExtensions.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientExtensions.cs @@ -1,14 +1,13 @@ using Microsoft.Extensions.AI; -using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; -using ModelContextProtocol.Server; +using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; using System.Text.Json; namespace ModelContextProtocol.Client; /// -/// Provides extension methods for interacting with an . +/// Provides extension methods for interacting with an . /// /// /// @@ -19,6 +18,53 @@ namespace ModelContextProtocol.Client; /// public static class McpClientExtensions { + /// + /// Creates a sampling handler for use with that will + /// satisfy sampling requests using the specified . + /// + /// The with which to satisfy sampling requests. + /// The created handler delegate that can be assigned to . + /// + /// + /// This method creates a function that converts MCP message requests into chat client calls, enabling + /// an MCP client to generate text or other content using an actual AI model via the provided chat client. + /// + /// + /// The handler can process text messages, image messages, and resource messages as defined in the + /// Model Context Protocol. + /// + /// + /// is . + public static Func, CancellationToken, ValueTask> CreateSamplingHandler( + this IChatClient chatClient) + { + Throw.IfNull(chatClient); + + return async (requestParams, progress, cancellationToken) => + { + Throw.IfNull(requestParams); + + var (messages, options) = requestParams.ToChatClientArguments(); + var progressToken = requestParams.ProgressToken; + + List updates = []; + await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) + { + updates.Add(update); + + if (progressToken is not null) + { + progress.Report(new() + { + Progress = updates.Count, + }); + } + } + + return updates.ToChatResponse().ToCreateMessageResult(); + }; + } + /// /// Sends a ping request to verify server connectivity. /// @@ -38,17 +84,9 @@ public static class McpClientExtensions /// /// is . /// Thrown when the server cannot be reached or returns an error response. + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.PingAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public static Task PingAsync(this IMcpClient client, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - return client.SendRequestAsync( - RequestMethods.Ping, - parameters: null, - McpJsonUtilities.JsonContext.Default.Object!, - McpJsonUtilities.JsonContext.Default.Object, - cancellationToken: cancellationToken).AsTask(); - } + => AsClientOrThrow(client).PingAsync(cancellationToken); /// /// Retrieves a list of available tools from the server. @@ -89,39 +127,12 @@ public static Task PingAsync(this IMcpClient client, CancellationToken cancellat /// /// /// is . - public static async ValueTask> ListToolsAsync( + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.ListToolsAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + public static ValueTask> ListToolsAsync( this IMcpClient client, JsonSerializerOptions? serializerOptions = null, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - serializerOptions ??= McpJsonUtilities.DefaultOptions; - serializerOptions.MakeReadOnly(); - - List? tools = null; - string? cursor = null; - do - { - var toolResults = await client.SendRequestAsync( - RequestMethods.ToolsList, - new() { Cursor = cursor }, - McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, - McpJsonUtilities.JsonContext.Default.ListToolsResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - tools ??= new List(toolResults.Tools.Count); - foreach (var tool in toolResults.Tools) - { - tools.Add(new McpClientTool(client, tool, serializerOptions)); - } - - cursor = toolResults.NextCursor; - } - while (cursor is not null); - - return tools; - } + => AsClientOrThrow(client).ListToolsAsync(serializerOptions, cancellationToken); /// /// Creates an enumerable for asynchronously enumerating all available tools from the server. @@ -155,35 +166,12 @@ public static async ValueTask> ListToolsAsync( /// /// /// is . - public static async IAsyncEnumerable EnumerateToolsAsync( + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.EnumerateToolsAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + public static IAsyncEnumerable EnumerateToolsAsync( this IMcpClient client, JsonSerializerOptions? serializerOptions = null, - [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - serializerOptions ??= McpJsonUtilities.DefaultOptions; - serializerOptions.MakeReadOnly(); - - string? cursor = null; - do - { - var toolResults = await client.SendRequestAsync( - RequestMethods.ToolsList, - new() { Cursor = cursor }, - McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, - McpJsonUtilities.JsonContext.Default.ListToolsResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - foreach (var tool in toolResults.Tools) - { - yield return new McpClientTool(client, tool, serializerOptions); - } - - cursor = toolResults.NextCursor; - } - while (cursor is not null); - } + CancellationToken cancellationToken = default) + => AsClientOrThrow(client).EnumerateToolsAsync(serializerOptions, cancellationToken); /// /// Retrieves a list of available prompts from the server. @@ -202,34 +190,10 @@ public static async IAsyncEnumerable EnumerateToolsAsync( /// /// /// is . - public static async ValueTask> ListPromptsAsync( + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.ListPromptsAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + public static ValueTask> ListPromptsAsync( this IMcpClient client, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - List? prompts = null; - string? cursor = null; - do - { - var promptResults = await client.SendRequestAsync( - RequestMethods.PromptsList, - new() { Cursor = cursor }, - McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, - McpJsonUtilities.JsonContext.Default.ListPromptsResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - prompts ??= new List(promptResults.Prompts.Count); - foreach (var prompt in promptResults.Prompts) - { - prompts.Add(new McpClientPrompt(client, prompt)); - } - - cursor = promptResults.NextCursor; - } - while (cursor is not null); - - return prompts; - } + => AsClientOrThrow(client).ListPromptsAsync(cancellationToken); /// /// Creates an enumerable for asynchronously enumerating all available prompts from the server. @@ -258,30 +222,10 @@ public static async ValueTask> ListPromptsAsync( /// /// /// is . - public static async IAsyncEnumerable EnumeratePromptsAsync( - this IMcpClient client, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - string? cursor = null; - do - { - var promptResults = await client.SendRequestAsync( - RequestMethods.PromptsList, - new() { Cursor = cursor }, - McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, - McpJsonUtilities.JsonContext.Default.ListPromptsResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - foreach (var prompt in promptResults.Prompts) - { - yield return new(client, prompt); - } - - cursor = promptResults.NextCursor; - } - while (cursor is not null); - } + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.EnumeratePromptsAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + public static IAsyncEnumerable EnumeratePromptsAsync( + this IMcpClient client, CancellationToken cancellationToken = default) + => AsClientOrThrow(client).EnumeratePromptsAsync(cancellationToken); /// /// Retrieves a specific prompt from the MCP server. @@ -308,26 +252,14 @@ public static async IAsyncEnumerable EnumeratePromptsAsync( /// /// Thrown when the prompt does not exist, when required arguments are missing, or when the server encounters an error processing the prompt. /// is . + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.GetPromptAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public static ValueTask GetPromptAsync( this IMcpClient client, string name, IReadOnlyDictionary? arguments = null, JsonSerializerOptions? serializerOptions = null, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNullOrWhiteSpace(name); - - serializerOptions ??= McpJsonUtilities.DefaultOptions; - serializerOptions.MakeReadOnly(); - - return client.SendRequestAsync( - RequestMethods.PromptsGet, - new() { Name = name, Arguments = ToArgumentsDictionary(arguments, serializerOptions) }, - McpJsonUtilities.JsonContext.Default.GetPromptRequestParams, - McpJsonUtilities.JsonContext.Default.GetPromptResult, - cancellationToken: cancellationToken); - } + => AsClientOrThrow(client).GetPromptAsync(name, arguments, serializerOptions, cancellationToken); /// /// Retrieves a list of available resource templates from the server. @@ -346,35 +278,10 @@ public static ValueTask GetPromptAsync( /// /// /// is . - public static async ValueTask> ListResourceTemplatesAsync( + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.ListResourceTemplatesAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + public static ValueTask> ListResourceTemplatesAsync( this IMcpClient client, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - List? resourceTemplates = null; - - string? cursor = null; - do - { - var templateResults = await client.SendRequestAsync( - RequestMethods.ResourcesTemplatesList, - new() { Cursor = cursor }, - McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, - McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - resourceTemplates ??= new List(templateResults.ResourceTemplates.Count); - foreach (var template in templateResults.ResourceTemplates) - { - resourceTemplates.Add(new McpClientResourceTemplate(client, template)); - } - - cursor = templateResults.NextCursor; - } - while (cursor is not null); - - return resourceTemplates; - } + => AsClientOrThrow(client).ListResourceTemplatesAsync(cancellationToken); /// /// Creates an enumerable for asynchronously enumerating all available resource templates from the server. @@ -403,30 +310,10 @@ public static async ValueTask> ListResourceTemp /// /// /// is . - public static async IAsyncEnumerable EnumerateResourceTemplatesAsync( - this IMcpClient client, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - string? cursor = null; - do - { - var templateResults = await client.SendRequestAsync( - RequestMethods.ResourcesTemplatesList, - new() { Cursor = cursor }, - McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, - McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - foreach (var templateResult in templateResults.ResourceTemplates) - { - yield return new McpClientResourceTemplate(client, templateResult); - } - - cursor = templateResults.NextCursor; - } - while (cursor is not null); - } + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.EnumerateResourceTemplatesAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + public static IAsyncEnumerable EnumerateResourceTemplatesAsync( + this IMcpClient client, CancellationToken cancellationToken = default) + => AsClientOrThrow(client).EnumerateResourceTemplatesAsync(cancellationToken); /// /// Retrieves a list of available resources from the server. @@ -457,35 +344,10 @@ public static async IAsyncEnumerable EnumerateResourc /// /// /// is . - public static async ValueTask> ListResourcesAsync( + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.ListResourcesAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + public static ValueTask> ListResourcesAsync( this IMcpClient client, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - List? resources = null; - - string? cursor = null; - do - { - var resourceResults = await client.SendRequestAsync( - RequestMethods.ResourcesList, - new() { Cursor = cursor }, - McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, - McpJsonUtilities.JsonContext.Default.ListResourcesResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - resources ??= new List(resourceResults.Resources.Count); - foreach (var resource in resourceResults.Resources) - { - resources.Add(new McpClientResource(client, resource)); - } - - cursor = resourceResults.NextCursor; - } - while (cursor is not null); - - return resources; - } + => AsClientOrThrow(client).ListResourcesAsync(cancellationToken); /// /// Creates an enumerable for asynchronously enumerating all available resources from the server. @@ -514,30 +376,10 @@ public static async ValueTask> ListResourcesAsync( /// /// /// is . - public static async IAsyncEnumerable EnumerateResourcesAsync( - this IMcpClient client, [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - string? cursor = null; - do - { - var resourceResults = await client.SendRequestAsync( - RequestMethods.ResourcesList, - new() { Cursor = cursor }, - McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, - McpJsonUtilities.JsonContext.Default.ListResourcesResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - - foreach (var resource in resourceResults.Resources) - { - yield return new McpClientResource(client, resource); - } - - cursor = resourceResults.NextCursor; - } - while (cursor is not null); - } + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.EnumerateResourcesAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 + public static IAsyncEnumerable EnumerateResourcesAsync( + this IMcpClient client, CancellationToken cancellationToken = default) + => AsClientOrThrow(client).EnumerateResourcesAsync(cancellationToken); /// /// Reads a resource from the server. @@ -548,19 +390,10 @@ public static async IAsyncEnumerable EnumerateResourcesAsync( /// is . /// is . /// is empty or composed entirely of whitespace. + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.ReadResourceAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public static ValueTask ReadResourceAsync( this IMcpClient client, string uri, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNullOrWhiteSpace(uri); - - return client.SendRequestAsync( - RequestMethods.ResourcesRead, - new() { Uri = uri }, - McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, - McpJsonUtilities.JsonContext.Default.ReadResourceResult, - cancellationToken: cancellationToken); - } + => AsClientOrThrow(client).ReadResourceAsync(uri, cancellationToken); /// /// Reads a resource from the server. @@ -570,14 +403,10 @@ public static ValueTask ReadResourceAsync( /// The to monitor for cancellation requests. The default is . /// is . /// is . + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.ReadResourceAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public static ValueTask ReadResourceAsync( this IMcpClient client, Uri uri, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNull(uri); - - return ReadResourceAsync(client, uri.ToString(), cancellationToken); - } + => AsClientOrThrow(client).ReadResourceAsync(uri, cancellationToken); /// /// Reads a resource from the server. @@ -589,20 +418,10 @@ public static ValueTask ReadResourceAsync( /// is . /// is . /// is empty or composed entirely of whitespace. + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.ReadResourceAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public static ValueTask ReadResourceAsync( this IMcpClient client, string uriTemplate, IReadOnlyDictionary arguments, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNullOrWhiteSpace(uriTemplate); - Throw.IfNull(arguments); - - return client.SendRequestAsync( - RequestMethods.ResourcesRead, - new() { Uri = UriTemplate.FormatUri(uriTemplate, arguments) }, - McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, - McpJsonUtilities.JsonContext.Default.ReadResourceResult, - cancellationToken: cancellationToken); - } + => AsClientOrThrow(client).ReadResourceAsync(uriTemplate, arguments, cancellationToken); /// /// Requests completion suggestions for a prompt argument or resource reference. @@ -633,23 +452,9 @@ public static ValueTask ReadResourceAsync( /// is . /// is empty or composed entirely of whitespace. /// The server returned an error response. + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.CompleteAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public static ValueTask CompleteAsync(this IMcpClient client, Reference reference, string argumentName, string argumentValue, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNull(reference); - Throw.IfNullOrWhiteSpace(argumentName); - - return client.SendRequestAsync( - RequestMethods.CompletionComplete, - new() - { - Ref = reference, - Argument = new Argument { Name = argumentName, Value = argumentValue } - }, - McpJsonUtilities.JsonContext.Default.CompleteRequestParams, - McpJsonUtilities.JsonContext.Default.CompleteResult, - cancellationToken: cancellationToken); - } + => AsClientOrThrow(client).CompleteAsync(reference, argumentName, argumentValue, cancellationToken); /// /// Subscribes to a resource on the server to receive notifications when it changes. @@ -676,18 +481,9 @@ public static ValueTask CompleteAsync(this IMcpClient client, Re /// is . /// is . /// is empty or composed entirely of whitespace. + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.SubscribeToResourceAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public static Task SubscribeToResourceAsync(this IMcpClient client, string uri, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNullOrWhiteSpace(uri); - - return client.SendRequestAsync( - RequestMethods.ResourcesSubscribe, - new() { Uri = uri }, - McpJsonUtilities.JsonContext.Default.SubscribeRequestParams, - McpJsonUtilities.JsonContext.Default.EmptyResult, - cancellationToken: cancellationToken).AsTask(); - } + => AsClientOrThrow(client).SubscribeToResourceAsync(uri, cancellationToken); /// /// Subscribes to a resource on the server to receive notifications when it changes. @@ -713,13 +509,9 @@ public static Task SubscribeToResourceAsync(this IMcpClient client, string uri, /// /// is . /// is . + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.SubscribeToResourceAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public static Task SubscribeToResourceAsync(this IMcpClient client, Uri uri, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNull(uri); - - return SubscribeToResourceAsync(client, uri.ToString(), cancellationToken); - } + => AsClientOrThrow(client).SubscribeToResourceAsync(uri, cancellationToken); /// /// Unsubscribes from a resource on the server to stop receiving notifications about its changes. @@ -745,18 +537,9 @@ public static Task SubscribeToResourceAsync(this IMcpClient client, Uri uri, Can /// is . /// is . /// is empty or composed entirely of whitespace. + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.UnsubscribeFromResourceAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public static Task UnsubscribeFromResourceAsync(this IMcpClient client, string uri, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNullOrWhiteSpace(uri); - - return client.SendRequestAsync( - RequestMethods.ResourcesUnsubscribe, - new() { Uri = uri }, - McpJsonUtilities.JsonContext.Default.UnsubscribeRequestParams, - McpJsonUtilities.JsonContext.Default.EmptyResult, - cancellationToken: cancellationToken).AsTask(); - } + => AsClientOrThrow(client).UnsubscribeFromResourceAsync(uri, cancellationToken); /// /// Unsubscribes from a resource on the server to stop receiving notifications about its changes. @@ -781,13 +564,9 @@ public static Task UnsubscribeFromResourceAsync(this IMcpClient client, string u /// /// is . /// is . + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.UnsubscribeFromResourceAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public static Task UnsubscribeFromResourceAsync(this IMcpClient client, Uri uri, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNull(uri); - - return UnsubscribeFromResourceAsync(client, uri.ToString(), cancellationToken); - } + => AsClientOrThrow(client).UnsubscribeFromResourceAsync(uri, cancellationToken); /// /// Invokes a tool on the server. @@ -824,6 +603,7 @@ public static Task UnsubscribeFromResourceAsync(this IMcpClient client, Uri uri, /// }); /// /// + [Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.CallToolAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public static ValueTask CallToolAsync( this IMcpClient client, string toolName, @@ -831,62 +611,28 @@ public static ValueTask CallToolAsync( IProgress? progress = null, JsonSerializerOptions? serializerOptions = null, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNull(toolName); - serializerOptions ??= McpJsonUtilities.DefaultOptions; - serializerOptions.MakeReadOnly(); + => AsClientOrThrow(client).CallToolAsync(toolName, arguments, progress, serializerOptions, cancellationToken); - if (progress is not null) + [MethodImpl(MethodImplOptions.AggressiveInlining)] +#pragma warning disable CS0618 // Type or member is obsolete + private static McpClient AsClientOrThrow(IMcpClient client, [CallerMemberName] string memberName = "") +#pragma warning restore CS0618 // Type or member is obsolete + { + if (client is not McpClient mcpClient) { - return SendRequestWithProgressAsync(client, toolName, arguments, progress, serializerOptions, cancellationToken); + ThrowInvalidEndpointType(memberName); } - return client.SendRequestAsync( - RequestMethods.ToolsCall, - new() - { - Name = toolName, - Arguments = ToArgumentsDictionary(arguments, serializerOptions), - }, - McpJsonUtilities.JsonContext.Default.CallToolRequestParams, - McpJsonUtilities.JsonContext.Default.CallToolResult, - cancellationToken: cancellationToken); - - static async ValueTask SendRequestWithProgressAsync( - IMcpClient client, - string toolName, - IReadOnlyDictionary? arguments, - IProgress progress, - JsonSerializerOptions serializerOptions, - CancellationToken cancellationToken) - { - ProgressToken progressToken = new(Guid.NewGuid().ToString("N")); - - await using var _ = client.RegisterNotificationHandler(NotificationMethods.ProgressNotification, - (notification, cancellationToken) => - { - if (JsonSerializer.Deserialize(notification.Params, McpJsonUtilities.JsonContext.Default.ProgressNotificationParams) is { } pn && - pn.ProgressToken == progressToken) - { - progress.Report(pn.Progress); - } - - return default; - }).ConfigureAwait(false); + return mcpClient; - return await client.SendRequestAsync( - RequestMethods.ToolsCall, - new() - { - Name = toolName, - Arguments = ToArgumentsDictionary(arguments, serializerOptions), - ProgressToken = progressToken, - }, - McpJsonUtilities.JsonContext.Default.CallToolRequestParams, - McpJsonUtilities.JsonContext.Default.CallToolResult, - cancellationToken: cancellationToken).ConfigureAwait(false); - } + [DoesNotReturn] + [MethodImpl(MethodImplOptions.NoInlining)] + static void ThrowInvalidEndpointType(string memberName) + => throw new InvalidOperationException( + $"Only arguments assignable to '{nameof(McpClient)}' are supported. " + + $"Prefer using '{nameof(McpClient)}.{memberName}' instead, as " + + $"'{nameof(McpClientExtensions)}.{memberName}' is obsolete and will be " + + $"removed in the future."); } /// @@ -963,132 +709,4 @@ internal static CreateMessageResult ToCreateMessageResult(this ChatResponse chat StopReason = chatResponse.FinishReason == ChatFinishReason.Length ? "maxTokens" : "endTurn", }; } - - /// - /// Creates a sampling handler for use with that will - /// satisfy sampling requests using the specified . - /// - /// The with which to satisfy sampling requests. - /// The created handler delegate that can be assigned to . - /// - /// - /// This method creates a function that converts MCP message requests into chat client calls, enabling - /// an MCP client to generate text or other content using an actual AI model via the provided chat client. - /// - /// - /// The handler can process text messages, image messages, and resource messages as defined in the - /// Model Context Protocol. - /// - /// - /// is . - public static Func, CancellationToken, ValueTask> CreateSamplingHandler( - this IChatClient chatClient) - { - Throw.IfNull(chatClient); - - return async (requestParams, progress, cancellationToken) => - { - Throw.IfNull(requestParams); - - var (messages, options) = requestParams.ToChatClientArguments(); - var progressToken = requestParams.ProgressToken; - - List updates = []; - await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) - { - updates.Add(update); - - if (progressToken is not null) - { - progress.Report(new() - { - Progress = updates.Count, - }); - } - } - - return updates.ToChatResponse().ToCreateMessageResult(); - }; - } - - /// - /// Sets the logging level for the server to control which log messages are sent to the client. - /// - /// The client instance used to communicate with the MCP server. - /// The minimum severity level of log messages to receive from the server. - /// The to monitor for cancellation requests. The default is . - /// A task representing the asynchronous operation. - /// - /// - /// After this request is processed, the server will send log messages at or above the specified - /// logging level as notifications to the client. For example, if is set, - /// the client will receive , , - /// , , and - /// level messages. - /// - /// - /// To receive all log messages, set the level to . - /// - /// - /// Log messages are delivered as notifications to the client and can be captured by registering - /// appropriate event handlers with the client implementation, such as with . - /// - /// - /// is . - public static Task SetLoggingLevel(this IMcpClient client, LoggingLevel level, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - - return client.SendRequestAsync( - RequestMethods.LoggingSetLevel, - new() { Level = level }, - McpJsonUtilities.JsonContext.Default.SetLevelRequestParams, - McpJsonUtilities.JsonContext.Default.EmptyResult, - cancellationToken: cancellationToken).AsTask(); - } - - /// - /// Sets the logging level for the server to control which log messages are sent to the client. - /// - /// The client instance used to communicate with the MCP server. - /// The minimum severity level of log messages to receive from the server. - /// The to monitor for cancellation requests. The default is . - /// A task representing the asynchronous operation. - /// - /// - /// After this request is processed, the server will send log messages at or above the specified - /// logging level as notifications to the client. For example, if is set, - /// the client will receive , , - /// and level messages. - /// - /// - /// To receive all log messages, set the level to . - /// - /// - /// Log messages are delivered as notifications to the client and can be captured by registering - /// appropriate event handlers with the client implementation, such as with . - /// - /// - /// is . - public static Task SetLoggingLevel(this IMcpClient client, LogLevel level, CancellationToken cancellationToken = default) => - SetLoggingLevel(client, McpServer.ToLoggingLevel(level), cancellationToken); - - /// Convers a dictionary with values to a dictionary with values. - private static Dictionary? ToArgumentsDictionary( - IReadOnlyDictionary? arguments, JsonSerializerOptions options) - { - var typeInfo = options.GetTypeInfo(); - - Dictionary? result = null; - if (arguments is not null) - { - result = new(arguments.Count); - foreach (var kvp in arguments) - { - result.Add(kvp.Key, kvp.Value is JsonElement je ? je : JsonSerializer.SerializeToElement(kvp.Value, typeInfo)); - } - } - - return result; - } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Client/McpClientFactory.cs b/src/ModelContextProtocol.Core/Client/McpClientFactory.cs index 30b3a9476..6934eb2b5 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientFactory.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientFactory.cs @@ -1,4 +1,4 @@ -using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging; namespace ModelContextProtocol.Client; @@ -10,6 +10,7 @@ namespace ModelContextProtocol.Client; /// that connect to MCP servers. It handles the creation and connection /// of appropriate implementations through the supplied transport. /// +[Obsolete($"Use {nameof(McpClient)}.{nameof(McpClient.CreateAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public static partial class McpClientFactory { /// Creates an , connecting it to the specified server. @@ -28,27 +29,5 @@ public static async Task CreateAsync( McpClientOptions? clientOptions = null, ILoggerFactory? loggerFactory = null, CancellationToken cancellationToken = default) - { - Throw.IfNull(clientTransport); - - McpClient client = new(clientTransport, clientOptions, loggerFactory); - try - { - await client.ConnectAsync(cancellationToken).ConfigureAwait(false); - if (loggerFactory?.CreateLogger(typeof(McpClientFactory)) is ILogger logger) - { - logger.LogClientCreated(client.EndpointName); - } - } - catch - { - await client.DisposeAsync().ConfigureAwait(false); - throw; - } - - return client; - } - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client created and connected.")] - private static partial void LogClientCreated(this ILogger logger, string endpointName); + => await McpClient.CreateAsync(clientTransport, clientOptions, loggerFactory, cancellationToken); } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Client/McpClientImpl.cs b/src/ModelContextProtocol.Core/Client/McpClientImpl.cs new file mode 100644 index 000000000..639718885 --- /dev/null +++ b/src/ModelContextProtocol.Core/Client/McpClientImpl.cs @@ -0,0 +1,246 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Protocol; +using System.Text.Json; + +namespace ModelContextProtocol.Client; + +/// +internal sealed partial class McpClientImpl : McpClient +{ + private static Implementation DefaultImplementation { get; } = new() + { + Name = AssemblyNameHelper.DefaultAssemblyName.Name ?? nameof(McpClient), + Version = AssemblyNameHelper.DefaultAssemblyName.Version?.ToString() ?? "1.0.0", + }; + + private readonly ILogger _logger; + private readonly ITransport _transport; + private readonly string _endpointName; + private readonly McpClientOptions _options; + private readonly McpSessionHandler _sessionHandler; + private readonly SemaphoreSlim _disposeLock = new(1, 1); + + private CancellationTokenSource? _connectCts; + + private ServerCapabilities? _serverCapabilities; + private Implementation? _serverInfo; + private string? _serverInstructions; + + private bool _disposed; + + /// + /// Initializes a new instance of the class. + /// + /// The transport to use for communication with the server. + /// The name of the endpoint for logging and debug purposes. + /// Options for the client, defining protocol version and capabilities. + /// The logger factory. + internal McpClientImpl(ITransport transport, string endpointName, McpClientOptions? options, ILoggerFactory? loggerFactory) + { + options ??= new(); + + _transport = transport; + _endpointName = $"Client ({options.ClientInfo?.Name ?? DefaultImplementation.Name} {options.ClientInfo?.Version ?? DefaultImplementation.Version})"; + _options = options; + _logger = loggerFactory?.CreateLogger() ?? NullLogger.Instance; + + var notificationHandlers = new NotificationHandlers(); + var requestHandlers = new RequestHandlers(); + + if (options.Capabilities is { } capabilities) + { + RegisterHandlers(capabilities, notificationHandlers, requestHandlers); + } + + _sessionHandler = new McpSessionHandler(isServer: false, transport, endpointName, requestHandlers, notificationHandlers, _logger); + } + + private void RegisterHandlers(ClientCapabilities capabilities, NotificationHandlers notificationHandlers, RequestHandlers requestHandlers) + { + if (capabilities.NotificationHandlers is { } notificationHandlersFromCapabilities) + { + notificationHandlers.RegisterRange(notificationHandlersFromCapabilities); + } + + if (capabilities.Sampling is { } samplingCapability) + { + if (samplingCapability.SamplingHandler is not { } samplingHandler) + { + throw new InvalidOperationException("Sampling capability was set but it did not provide a handler."); + } + + requestHandlers.Set( + RequestMethods.SamplingCreateMessage, + (request, _, cancellationToken) => samplingHandler( + request, + request?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance, + cancellationToken), + McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, + McpJsonUtilities.JsonContext.Default.CreateMessageResult); + } + + if (capabilities.Roots is { } rootsCapability) + { + if (rootsCapability.RootsHandler is not { } rootsHandler) + { + throw new InvalidOperationException("Roots capability was set but it did not provide a handler."); + } + + requestHandlers.Set( + RequestMethods.RootsList, + (request, _, cancellationToken) => rootsHandler(request, cancellationToken), + McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, + McpJsonUtilities.JsonContext.Default.ListRootsResult); + } + + if (capabilities.Elicitation is { } elicitationCapability) + { + if (elicitationCapability.ElicitationHandler is not { } elicitationHandler) + { + throw new InvalidOperationException("Elicitation capability was set but it did not provide a handler."); + } + + requestHandlers.Set( + RequestMethods.ElicitationCreate, + (request, _, cancellationToken) => elicitationHandler(request, cancellationToken), + McpJsonUtilities.JsonContext.Default.ElicitRequestParams, + McpJsonUtilities.JsonContext.Default.ElicitResult); + } + } + + /// + public override string? SessionId => _transport.SessionId; + + /// + public override ServerCapabilities ServerCapabilities => _serverCapabilities ?? throw new InvalidOperationException("The client is not connected."); + + /// + public override Implementation ServerInfo => _serverInfo ?? throw new InvalidOperationException("The client is not connected."); + + /// + public override string? ServerInstructions => _serverInstructions; + + /// + /// Asynchronously connects to an MCP server, establishes the transport connection, and completes the initialization handshake. + /// + public async Task ConnectAsync(CancellationToken cancellationToken = default) + { + _connectCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + cancellationToken = _connectCts.Token; + + try + { + // We don't want the ConnectAsync token to cancel the message processing loop after we've successfully connected. + // The session handler handles cancelling the loop upon its disposal. + _ = _sessionHandler.ProcessMessagesAsync(CancellationToken.None); + + // Perform initialization sequence + using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + initializationCts.CancelAfter(_options.InitializationTimeout); + + try + { + // Send initialize request + string requestProtocol = _options.ProtocolVersion ?? McpSessionHandler.LatestProtocolVersion; + var initializeResponse = await this.SendRequestAsync( + RequestMethods.Initialize, + new InitializeRequestParams + { + ProtocolVersion = requestProtocol, + Capabilities = _options.Capabilities ?? new ClientCapabilities(), + ClientInfo = _options.ClientInfo ?? DefaultImplementation, + }, + McpJsonUtilities.JsonContext.Default.InitializeRequestParams, + McpJsonUtilities.JsonContext.Default.InitializeResult, + cancellationToken: initializationCts.Token).ConfigureAwait(false); + + // Store server information + if (_logger.IsEnabled(LogLevel.Information)) + { + LogServerCapabilitiesReceived(_endpointName, + capabilities: JsonSerializer.Serialize(initializeResponse.Capabilities, McpJsonUtilities.JsonContext.Default.ServerCapabilities), + serverInfo: JsonSerializer.Serialize(initializeResponse.ServerInfo, McpJsonUtilities.JsonContext.Default.Implementation)); + } + + _serverCapabilities = initializeResponse.Capabilities; + _serverInfo = initializeResponse.ServerInfo; + _serverInstructions = initializeResponse.Instructions; + + // Validate protocol version + bool isResponseProtocolValid = + _options.ProtocolVersion is { } optionsProtocol ? optionsProtocol == initializeResponse.ProtocolVersion : + McpSessionHandler.SupportedProtocolVersions.Contains(initializeResponse.ProtocolVersion); + if (!isResponseProtocolValid) + { + LogServerProtocolVersionMismatch(_endpointName, requestProtocol, initializeResponse.ProtocolVersion); + throw new McpException($"Server protocol version mismatch. Expected {requestProtocol}, got {initializeResponse.ProtocolVersion}"); + } + + // Send initialized notification + await this.SendNotificationAsync( + NotificationMethods.InitializedNotification, + new InitializedNotificationParams(), + McpJsonUtilities.JsonContext.Default.InitializedNotificationParams, + cancellationToken: initializationCts.Token).ConfigureAwait(false); + + } + catch (OperationCanceledException oce) when (initializationCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested) + { + LogClientInitializationTimeout(_endpointName); + throw new TimeoutException("Initialization timed out", oce); + } + } + catch (Exception e) + { + LogClientInitializationError(_endpointName, e); + await DisposeAsync().ConfigureAwait(false); + throw; + } + + LogClientConnected(_endpointName); + } + + /// + public override Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) + => _sessionHandler.SendRequestAsync(request, cancellationToken); + + /// + public override Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + => _sessionHandler.SendMessageAsync(message, cancellationToken); + + /// + public override IAsyncDisposable RegisterNotificationHandler(string method, Func handler) + => _sessionHandler.RegisterNotificationHandler(method, handler); + + /// + public override async ValueTask DisposeAsync() + { + using var _ = await _disposeLock.LockAsync().ConfigureAwait(false); + + if (_disposed) + { + return; + } + + _disposed = true; + + await _sessionHandler.DisposeAsync().ConfigureAwait(false); + await _transport.DisposeAsync().ConfigureAwait(false); + } + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client received server '{ServerInfo}' capabilities: '{Capabilities}'.")] + private partial void LogServerCapabilitiesReceived(string endpointName, string capabilities, string serverInfo); + + [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client initialization error.")] + private partial void LogClientInitializationError(string endpointName, Exception exception); + + [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client initialization timed out.")] + private partial void LogClientInitializationTimeout(string endpointName); + + [LoggerMessage(Level = LogLevel.Error, Message = "{EndpointName} client protocol version mismatch with server. Expected '{Expected}', received '{Received}'.")] + private partial void LogServerProtocolVersionMismatch(string endpointName, string expected, string received); + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} client created and connected.")] + private partial void LogClientConnected(string endpointName); +} diff --git a/src/ModelContextProtocol.Core/Client/McpClientOptions.cs b/src/ModelContextProtocol.Core/Client/McpClientOptions.cs index 76099d0d9..d4ed41db4 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientOptions.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientOptions.cs @@ -3,10 +3,10 @@ namespace ModelContextProtocol.Client; /// -/// Provides configuration options for creating instances. +/// Provides configuration options for creating instances. /// /// -/// These options are typically passed to when creating a client. +/// These options are typically passed to when creating a client. /// They define client capabilities, protocol version, and other client-specific settings. /// public sealed class McpClientOptions diff --git a/src/ModelContextProtocol.Core/Client/McpClientPrompt.cs b/src/ModelContextProtocol.Core/Client/McpClientPrompt.cs index 43fc759a0..5a618242f 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientPrompt.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientPrompt.cs @@ -10,8 +10,8 @@ namespace ModelContextProtocol.Client; /// /// This class provides a client-side wrapper around a prompt defined on an MCP server. It allows /// retrieving the prompt's content by sending a request to the server with optional arguments. -/// Instances of this class are typically obtained by calling -/// or . +/// Instances of this class are typically obtained by calling +/// or . /// /// /// Each prompt has a name and optionally a description, and it can be invoked with arguments @@ -20,9 +20,9 @@ namespace ModelContextProtocol.Client; /// public sealed class McpClientPrompt { - private readonly IMcpClient _client; + private readonly McpClient _client; - internal McpClientPrompt(IMcpClient client, Prompt prompt) + internal McpClientPrompt(McpClient client, Prompt prompt) { _client = client; ProtocolPrompt = prompt; @@ -63,7 +63,7 @@ internal McpClientPrompt(IMcpClient client, Prompt prompt) /// The server will process the request and return a result containing messages or other content. /// /// - /// This is a convenience method that internally calls + /// This is a convenience method that internally calls /// with this prompt's name and arguments. /// /// diff --git a/src/ModelContextProtocol.Core/Client/McpClientResource.cs b/src/ModelContextProtocol.Core/Client/McpClientResource.cs index 06f8aff67..19f11bfdf 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientResource.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientResource.cs @@ -9,15 +9,15 @@ namespace ModelContextProtocol.Client; /// /// This class provides a client-side wrapper around a resource defined on an MCP server. It allows /// retrieving the resource's content by sending a request to the server with the resource's URI. -/// Instances of this class are typically obtained by calling -/// or . +/// Instances of this class are typically obtained by calling +/// or . /// /// public sealed class McpClientResource { - private readonly IMcpClient _client; + private readonly McpClient _client; - internal McpClientResource(IMcpClient client, Resource resource) + internal McpClientResource(McpClient client, Resource resource) { _client = client; ProtocolResource = resource; @@ -58,7 +58,7 @@ internal McpClientResource(IMcpClient client, Resource resource) /// A containing the resource's result with content and messages. /// /// - /// This is a convenience method that internally calls . + /// This is a convenience method that internally calls . /// /// public ValueTask ReadAsync( diff --git a/src/ModelContextProtocol.Core/Client/McpClientResourceTemplate.cs b/src/ModelContextProtocol.Core/Client/McpClientResourceTemplate.cs index 4da1bd0c3..033f7cf00 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientResourceTemplate.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientResourceTemplate.cs @@ -9,15 +9,15 @@ namespace ModelContextProtocol.Client; /// /// This class provides a client-side wrapper around a resource template defined on an MCP server. It allows /// retrieving the resource template's content by sending a request to the server with the resource's URI. -/// Instances of this class are typically obtained by calling -/// or . +/// Instances of this class are typically obtained by calling +/// or . /// /// public sealed class McpClientResourceTemplate { - private readonly IMcpClient _client; + private readonly McpClient _client; - internal McpClientResourceTemplate(IMcpClient client, ResourceTemplate resourceTemplate) + internal McpClientResourceTemplate(McpClient client, ResourceTemplate resourceTemplate) { _client = client; ProtocolResourceTemplate = resourceTemplate; diff --git a/src/ModelContextProtocol.Core/Client/McpClientTool.cs b/src/ModelContextProtocol.Core/Client/McpClientTool.cs index 1810e9c56..c7af513ef 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientTool.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientTool.cs @@ -6,11 +6,11 @@ namespace ModelContextProtocol.Client; /// -/// Provides an that calls a tool via an . +/// Provides an that calls a tool via an . /// /// /// -/// The class encapsulates an along with a description of +/// The class encapsulates an along with a description of /// a tool available via that client, allowing it to be invoked as an . This enables integration /// with AI models that support function calling capabilities. /// @@ -19,8 +19,8 @@ namespace ModelContextProtocol.Client; /// and without changing the underlying tool functionality. /// /// -/// Typically, you would get instances of this class by calling the -/// or extension methods on an instance. +/// Typically, you would get instances of this class by calling the +/// or extension methods on an instance. /// /// public sealed class McpClientTool : AIFunction @@ -32,13 +32,13 @@ public sealed class McpClientTool : AIFunction ["Strict"] = false, // some MCP schemas may not meet "strict" requirements }); - private readonly IMcpClient _client; + private readonly McpClient _client; private readonly string _name; private readonly string _description; private readonly IProgress? _progress; internal McpClientTool( - IMcpClient client, + McpClient client, Tool tool, JsonSerializerOptions serializerOptions, string? name = null, diff --git a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs index 479a76279..60950dfa5 100644 --- a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs @@ -15,7 +15,7 @@ namespace ModelContextProtocol.Client; internal sealed partial class SseClientSessionTransport : TransportBase { private readonly McpHttpClient _httpClient; - private readonly SseClientTransportOptions _options; + private readonly HttpClientTransportOptions _options; private readonly Uri _sseEndpoint; private Uri? _messageEndpoint; private readonly CancellationTokenSource _connectionCts; @@ -29,7 +29,7 @@ internal sealed partial class SseClientSessionTransport : TransportBase /// public SseClientSessionTransport( string endpointName, - SseClientTransportOptions transportOptions, + HttpClientTransportOptions transportOptions, McpHttpClient httpClient, Channel? messageChannel, ILoggerFactory? loggerFactory) @@ -42,7 +42,7 @@ public SseClientSessionTransport( _sseEndpoint = transportOptions.Endpoint; _httpClient = httpClient; _connectionCts = new CancellationTokenSource(); - _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; _connectionEstablished = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); } diff --git a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs index c4014ed71..f2fd55f16 100644 --- a/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/StreamableHttpClientSessionTransport.cs @@ -17,7 +17,7 @@ internal sealed partial class StreamableHttpClientSessionTransport : TransportBa private static readonly MediaTypeWithQualityHeaderValue s_textEventStreamMediaType = new("text/event-stream"); private readonly McpHttpClient _httpClient; - private readonly SseClientTransportOptions _options; + private readonly HttpClientTransportOptions _options; private readonly CancellationTokenSource _connectionCts; private readonly ILogger _logger; @@ -29,7 +29,7 @@ internal sealed partial class StreamableHttpClientSessionTransport : TransportBa public StreamableHttpClientSessionTransport( string endpointName, - SseClientTransportOptions transportOptions, + HttpClientTransportOptions transportOptions, McpHttpClient httpClient, Channel? messageChannel, ILoggerFactory? loggerFactory) @@ -41,10 +41,10 @@ public StreamableHttpClientSessionTransport( _options = transportOptions; _httpClient = httpClient; _connectionCts = new CancellationTokenSource(); - _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; // We connect with the initialization request with the MCP transport. This means that any errors won't be observed - // until the first call to SendMessageAsync. Fortunately, that happens internally in McpClientFactory.ConnectAsync + // until the first call to SendMessageAsync. Fortunately, that happens internally in McpClient.ConnectAsync // so we still throw any connection-related Exceptions from there and never expose a pre-connected client to the user. SetConnected(); } @@ -291,7 +291,7 @@ internal static void CopyAdditionalHeaders( { if (!headers.TryAddWithoutValidation(header.Key, header.Value)) { - throw new InvalidOperationException($"Failed to add header '{header.Key}' with value '{header.Value}' from {nameof(SseClientTransportOptions.AdditionalHeaders)}."); + throw new InvalidOperationException($"Failed to add header '{header.Key}' with value '{header.Value}' from {nameof(HttpClientTransportOptions.AdditionalHeaders)}."); } } } diff --git a/src/ModelContextProtocol.Core/IMcpEndpoint.cs b/src/ModelContextProtocol.Core/IMcpEndpoint.cs index ea825e682..beb96521f 100644 --- a/src/ModelContextProtocol.Core/IMcpEndpoint.cs +++ b/src/ModelContextProtocol.Core/IMcpEndpoint.cs @@ -1,4 +1,4 @@ -using ModelContextProtocol.Client; +using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; @@ -26,6 +26,7 @@ namespace ModelContextProtocol; /// All MCP endpoints should be properly disposed after use as they implement . /// /// +[Obsolete($"Use {nameof(McpSession)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public interface IMcpEndpoint : IAsyncDisposable { /// Gets an identifier associated with the current MCP session. diff --git a/src/ModelContextProtocol.Core/McpEndpoint.cs b/src/ModelContextProtocol.Core/McpEndpoint.cs deleted file mode 100644 index 0d0ccbb98..000000000 --- a/src/ModelContextProtocol.Core/McpEndpoint.cs +++ /dev/null @@ -1,144 +0,0 @@ -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; -using ModelContextProtocol.Protocol; -using ModelContextProtocol.Server; -using System.Diagnostics.CodeAnalysis; -using System.Reflection; - -namespace ModelContextProtocol; - -/// -/// Base class for an MCP JSON-RPC endpoint. This covers both MCP clients and servers. -/// It is not supported, nor necessary, to implement both client and server functionality in the same class. -/// If an application needs to act as both a client and a server, it should use separate objects for each. -/// This is especially true as a client represents a connection to one and only one server, and vice versa. -/// Any multi-client or multi-server functionality should be implemented at a higher level of abstraction. -/// -internal abstract partial class McpEndpoint : IAsyncDisposable -{ - /// Cached naming information used for name/version when none is specified. - internal static AssemblyName DefaultAssemblyName { get; } = (Assembly.GetEntryAssembly() ?? Assembly.GetExecutingAssembly()).GetName(); - - private McpSession? _session; - private CancellationTokenSource? _sessionCts; - - private readonly SemaphoreSlim _disposeLock = new(1, 1); - private bool _disposed; - - protected readonly ILogger _logger; - - /// - /// Initializes a new instance of the class. - /// - /// The logger factory. - protected McpEndpoint(ILoggerFactory? loggerFactory = null) - { - _logger = loggerFactory?.CreateLogger(GetType()) ?? NullLogger.Instance; - } - - protected RequestHandlers RequestHandlers { get; } = []; - - protected NotificationHandlers NotificationHandlers { get; } = new(); - - public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) - => GetSessionOrThrow().SendRequestAsync(request, cancellationToken); - - public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) - => GetSessionOrThrow().SendMessageAsync(message, cancellationToken); - - public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => - GetSessionOrThrow().RegisterNotificationHandler(method, handler); - - /// - /// Gets the name of the endpoint for logging and debug purposes. - /// - public abstract string EndpointName { get; } - - /// - /// Task that processes incoming messages from the transport. - /// - protected Task? MessageProcessingTask { get; private set; } - - protected void InitializeSession(ITransport sessionTransport) - { - _session = new McpSession(this is IMcpServer, sessionTransport, EndpointName, RequestHandlers, NotificationHandlers, _logger); - } - - [MemberNotNull(nameof(MessageProcessingTask))] - protected void StartSession(ITransport sessionTransport, CancellationToken fullSessionCancellationToken) - { - _sessionCts = CancellationTokenSource.CreateLinkedTokenSource(fullSessionCancellationToken); - MessageProcessingTask = GetSessionOrThrow().ProcessMessagesAsync(_sessionCts.Token); - } - - protected void CancelSession() => _sessionCts?.Cancel(); - - public async ValueTask DisposeAsync() - { - using var _ = await _disposeLock.LockAsync().ConfigureAwait(false); - - if (_disposed) - { - return; - } - _disposed = true; - - await DisposeUnsynchronizedAsync().ConfigureAwait(false); - } - - /// - /// Cleans up the endpoint and releases resources. - /// - /// - public virtual async ValueTask DisposeUnsynchronizedAsync() - { - LogEndpointShuttingDown(EndpointName); - - try - { - if (_sessionCts is not null) - { - await _sessionCts.CancelAsync().ConfigureAwait(false); - } - - if (MessageProcessingTask is not null) - { - try - { - await MessageProcessingTask.ConfigureAwait(false); - } - catch (OperationCanceledException) - { - // Ignore cancellation - } - } - } - finally - { - _session?.Dispose(); - _sessionCts?.Dispose(); - } - - LogEndpointShutDown(EndpointName); - } - - protected McpSession GetSessionOrThrow() - { -#if NET - ObjectDisposedException.ThrowIf(_disposed, this); -#else - if (_disposed) - { - throw new ObjectDisposedException(GetType().Name); - } -#endif - - return _session ?? throw new InvalidOperationException($"This should be unreachable from public API! Call {nameof(InitializeSession)} before sending messages."); - } - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} shutting down.")] - private partial void LogEndpointShuttingDown(string endpointName); - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} shut down.")] - private partial void LogEndpointShutDown(string endpointName); -} \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/McpEndpointExtensions.cs b/src/ModelContextProtocol.Core/McpEndpointExtensions.cs index 4e4abe5ce..f51289ac2 100644 --- a/src/ModelContextProtocol.Core/McpEndpointExtensions.cs +++ b/src/ModelContextProtocol.Core/McpEndpointExtensions.cs @@ -1,9 +1,9 @@ -using ModelContextProtocol.Client; +using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using System.Text.Json; -using System.Text.Json.Nodes; -using System.Text.Json.Serialization.Metadata; namespace ModelContextProtocol; @@ -34,6 +34,7 @@ public static class McpEndpointExtensions /// The options governing request serialization. /// The to monitor for cancellation requests. The default is . /// A task that represents the asynchronous operation. The task result contains the deserialized result. + [Obsolete($"Use {nameof(McpSession)}.{nameof(McpSession.SendRequestAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public static ValueTask SendRequestAsync( this IMcpEndpoint endpoint, string method, @@ -42,53 +43,7 @@ public static ValueTask SendRequestAsync( RequestId requestId = default, CancellationToken cancellationToken = default) where TResult : notnull - { - serializerOptions ??= McpJsonUtilities.DefaultOptions; - serializerOptions.MakeReadOnly(); - - JsonTypeInfo paramsTypeInfo = serializerOptions.GetTypeInfo(); - JsonTypeInfo resultTypeInfo = serializerOptions.GetTypeInfo(); - return SendRequestAsync(endpoint, method, parameters, paramsTypeInfo, resultTypeInfo, requestId, cancellationToken); - } - - /// - /// Sends a JSON-RPC request and attempts to deserialize the result to . - /// - /// The type of the request parameters to serialize from. - /// The type of the result to deserialize to. - /// The MCP client or server instance. - /// The JSON-RPC method name to invoke. - /// Object representing the request parameters. - /// The type information for request parameter serialization. - /// The type information for request parameter deserialization. - /// The request id for the request. - /// The to monitor for cancellation requests. The default is . - /// A task that represents the asynchronous operation. The task result contains the deserialized result. - internal static async ValueTask SendRequestAsync( - this IMcpEndpoint endpoint, - string method, - TParameters parameters, - JsonTypeInfo parametersTypeInfo, - JsonTypeInfo resultTypeInfo, - RequestId requestId = default, - CancellationToken cancellationToken = default) - where TResult : notnull - { - Throw.IfNull(endpoint); - Throw.IfNullOrWhiteSpace(method); - Throw.IfNull(parametersTypeInfo); - Throw.IfNull(resultTypeInfo); - - JsonRpcRequest jsonRpcRequest = new() - { - Id = requestId, - Method = method, - Params = JsonSerializer.SerializeToNode(parameters, parametersTypeInfo), - }; - - JsonRpcResponse response = await endpoint.SendRequestAsync(jsonRpcRequest, cancellationToken).ConfigureAwait(false); - return JsonSerializer.Deserialize(response.Result, resultTypeInfo) ?? throw new JsonException("Unexpected JSON result in response."); - } + => AsSessionOrThrow(endpoint).SendRequestAsync(method, parameters, serializerOptions, requestId, cancellationToken); /// /// Sends a parameterless notification to the connected endpoint. @@ -104,12 +59,9 @@ internal static async ValueTask SendRequestAsync( /// changes in state. /// /// + [Obsolete($"Use {nameof(McpSession)}.{nameof(McpSession.SendNotificationAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public static Task SendNotificationAsync(this IMcpEndpoint client, string method, CancellationToken cancellationToken = default) - { - Throw.IfNull(client); - Throw.IfNullOrWhiteSpace(method); - return client.SendMessageAsync(new JsonRpcNotification { Method = method }, cancellationToken); - } + => AsSessionOrThrow(client).SendNotificationAsync(method, cancellationToken); /// /// Sends a notification with parameters to the connected endpoint. @@ -135,42 +87,14 @@ public static Task SendNotificationAsync(this IMcpEndpoint client, string method /// but custom methods can also be used for application-specific notifications. /// /// + [Obsolete($"Use {nameof(McpSession)}.{nameof(McpSession.SendNotificationAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public static Task SendNotificationAsync( this IMcpEndpoint endpoint, string method, TParameters parameters, JsonSerializerOptions? serializerOptions = null, CancellationToken cancellationToken = default) - { - serializerOptions ??= McpJsonUtilities.DefaultOptions; - serializerOptions.MakeReadOnly(); - - JsonTypeInfo parametersTypeInfo = serializerOptions.GetTypeInfo(); - return SendNotificationAsync(endpoint, method, parameters, parametersTypeInfo, cancellationToken); - } - - /// - /// Sends a notification to the server with parameters. - /// - /// The MCP client or server instance. - /// The JSON-RPC method name to invoke. - /// Object representing the request parameters. - /// The type information for request parameter serialization. - /// The to monitor for cancellation requests. The default is . - internal static Task SendNotificationAsync( - this IMcpEndpoint endpoint, - string method, - TParameters parameters, - JsonTypeInfo parametersTypeInfo, - CancellationToken cancellationToken = default) - { - Throw.IfNull(endpoint); - Throw.IfNullOrWhiteSpace(method); - Throw.IfNull(parametersTypeInfo); - - JsonNode? parametersJson = JsonSerializer.SerializeToNode(parameters, parametersTypeInfo); - return endpoint.SendMessageAsync(new JsonRpcNotification { Method = method, Params = parametersJson }, cancellationToken); - } + => AsSessionOrThrow(endpoint).SendNotificationAsync(method, parameters, serializerOptions, cancellationToken); /// /// Notifies the connected endpoint of progress for a long-running operation. @@ -191,22 +115,33 @@ internal static Task SendNotificationAsync( /// Progress notifications are sent asynchronously and don't block the operation from continuing. /// /// + [Obsolete($"Use {nameof(McpSession)}.{nameof(McpSession.NotifyProgressAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public static Task NotifyProgressAsync( this IMcpEndpoint endpoint, ProgressToken progressToken, - ProgressNotificationValue progress, + ProgressNotificationValue progress, CancellationToken cancellationToken = default) + => AsSessionOrThrow(endpoint).NotifyProgressAsync(progressToken, progress, cancellationToken); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] +#pragma warning disable CS0618 // Type or member is obsolete + private static McpSession AsSessionOrThrow(IMcpEndpoint endpoint, [CallerMemberName] string memberName = "") +#pragma warning restore CS0618 // Type or member is obsolete { - Throw.IfNull(endpoint); + if (endpoint is not McpSession session) + { + ThrowInvalidEndpointType(memberName); + } + + return session; - return endpoint.SendNotificationAsync( - NotificationMethods.ProgressNotification, - new ProgressNotificationParams - { - ProgressToken = progressToken, - Progress = progress, - }, - McpJsonUtilities.JsonContext.Default.ProgressNotificationParams, - cancellationToken); + [DoesNotReturn] + [MethodImpl(MethodImplOptions.NoInlining)] + static void ThrowInvalidEndpointType(string memberName) + => throw new InvalidOperationException( + $"Only arguments assignable to '{nameof(McpSession)}' are supported. " + + $"Prefer using '{nameof(McpServer)}.{memberName}' instead, as " + + $"'{nameof(McpEndpointExtensions)}.{memberName}' is obsolete and will be " + + $"removed in the future."); } } diff --git a/src/ModelContextProtocol.Core/McpSession.Methods.cs b/src/ModelContextProtocol.Core/McpSession.Methods.cs new file mode 100644 index 000000000..c537732f1 --- /dev/null +++ b/src/ModelContextProtocol.Core/McpSession.Methods.cs @@ -0,0 +1,183 @@ +using ModelContextProtocol.Protocol; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization.Metadata; + +namespace ModelContextProtocol; + +#pragma warning disable CS0618 // Type or member is obsolete +public abstract partial class McpSession : IMcpEndpoint, IAsyncDisposable +#pragma warning restore CS0618 // Type or member is obsolete +{ + /// + /// Sends a JSON-RPC request and attempts to deserialize the result to . + /// + /// The type of the request parameters to serialize from. + /// The type of the result to deserialize to. + /// The JSON-RPC method name to invoke. + /// Object representing the request parameters. + /// The request id for the request. + /// The options governing request serialization. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. The task result contains the deserialized result. + public ValueTask SendRequestAsync( + string method, + TParameters parameters, + JsonSerializerOptions? serializerOptions = null, + RequestId requestId = default, + CancellationToken cancellationToken = default) + where TResult : notnull + { + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); + + JsonTypeInfo paramsTypeInfo = serializerOptions.GetTypeInfo(); + JsonTypeInfo resultTypeInfo = serializerOptions.GetTypeInfo(); + return SendRequestAsync(method, parameters, paramsTypeInfo, resultTypeInfo, requestId, cancellationToken); + } + + /// + /// Sends a JSON-RPC request and attempts to deserialize the result to . + /// + /// The type of the request parameters to serialize from. + /// The type of the result to deserialize to. + /// The JSON-RPC method name to invoke. + /// Object representing the request parameters. + /// The type information for request parameter serialization. + /// The type information for request parameter deserialization. + /// The request id for the request. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous operation. The task result contains the deserialized result. + internal async ValueTask SendRequestAsync( + string method, + TParameters parameters, + JsonTypeInfo parametersTypeInfo, + JsonTypeInfo resultTypeInfo, + RequestId requestId = default, + CancellationToken cancellationToken = default) + where TResult : notnull + { + Throw.IfNullOrWhiteSpace(method); + Throw.IfNull(parametersTypeInfo); + Throw.IfNull(resultTypeInfo); + + JsonRpcRequest jsonRpcRequest = new() + { + Id = requestId, + Method = method, + Params = JsonSerializer.SerializeToNode(parameters, parametersTypeInfo), + }; + + JsonRpcResponse response = await SendRequestAsync(jsonRpcRequest, cancellationToken).ConfigureAwait(false); + return JsonSerializer.Deserialize(response.Result, resultTypeInfo) ?? throw new JsonException("Unexpected JSON result in response."); + } + + /// + /// Sends a parameterless notification to the connected session. + /// + /// The notification method name. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous send operation. + /// + /// + /// This method sends a notification without any parameters. Notifications are one-way messages + /// that don't expect a response. They are commonly used for events, status updates, or to signal + /// changes in state. + /// + /// + public Task SendNotificationAsync(string method, CancellationToken cancellationToken = default) + { + Throw.IfNullOrWhiteSpace(method); + return SendMessageAsync(new JsonRpcNotification { Method = method }, cancellationToken); + } + + /// + /// Sends a notification with parameters to the connected session. + /// + /// The type of the notification parameters to serialize. + /// The JSON-RPC method name for the notification. + /// Object representing the notification parameters. + /// The options governing parameter serialization. If null, default options are used. + /// The to monitor for cancellation requests. The default is . + /// A task that represents the asynchronous send operation. + /// + /// + /// This method sends a notification with parameters to the connected session. Notifications are one-way + /// messages that don't expect a response, commonly used for events, status updates, or signaling changes. + /// + /// + /// The parameters object is serialized to JSON according to the provided serializer options or the default + /// options if none are specified. + /// + /// + /// The Model Context Protocol defines several standard notification methods in , + /// but custom methods can also be used for application-specific notifications. + /// + /// + public Task SendNotificationAsync( + string method, + TParameters parameters, + JsonSerializerOptions? serializerOptions = null, + CancellationToken cancellationToken = default) + { + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); + + JsonTypeInfo parametersTypeInfo = serializerOptions.GetTypeInfo(); + return SendNotificationAsync(method, parameters, parametersTypeInfo, cancellationToken); + } + + /// + /// Sends a notification to the server with parameters. + /// + /// The JSON-RPC method name to invoke. + /// Object representing the request parameters. + /// The type information for request parameter serialization. + /// The to monitor for cancellation requests. The default is . + internal Task SendNotificationAsync( + string method, + TParameters parameters, + JsonTypeInfo parametersTypeInfo, + CancellationToken cancellationToken = default) + { + Throw.IfNullOrWhiteSpace(method); + Throw.IfNull(parametersTypeInfo); + + JsonNode? parametersJson = JsonSerializer.SerializeToNode(parameters, parametersTypeInfo); + return SendMessageAsync(new JsonRpcNotification { Method = method, Params = parametersJson }, cancellationToken); + } + + /// + /// Notifies the connected session of progress for a long-running operation. + /// + /// The identifying the operation for which progress is being reported. + /// The progress update to send, containing information such as percentage complete or status message. + /// The to monitor for cancellation requests. The default is . + /// A task representing the completion of the notification operation (not the operation being tracked). + /// The current session instance is . + /// + /// + /// This method sends a progress notification to the connected session using the Model Context Protocol's + /// standardized progress notification format. Progress updates are identified by a + /// that allows the recipient to correlate multiple updates with a specific long-running operation. + /// + /// + /// Progress notifications are sent asynchronously and don't block the operation from continuing. + /// + /// + public Task NotifyProgressAsync( + ProgressToken progressToken, + ProgressNotificationValue progress, + CancellationToken cancellationToken = default) + { + return SendNotificationAsync( + NotificationMethods.ProgressNotification, + new ProgressNotificationParams + { + ProgressToken = progressToken, + Progress = progress, + }, + McpJsonUtilities.JsonContext.Default.ProgressNotificationParams, + cancellationToken); + } +} diff --git a/src/ModelContextProtocol.Core/McpSession.cs b/src/ModelContextProtocol.Core/McpSession.cs index 75215fee1..241c36d4c 100644 --- a/src/ModelContextProtocol.Core/McpSession.cs +++ b/src/ModelContextProtocol.Core/McpSession.cs @@ -1,796 +1,86 @@ -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; -using System.Collections.Concurrent; -using System.Diagnostics; -using System.Diagnostics.Metrics; -using System.Text.Json; -using System.Text.Json.Nodes; -#if !NET -using System.Threading.Channels; -#endif namespace ModelContextProtocol; /// -/// Class for managing an MCP JSON-RPC session. This covers both MCP clients and servers. +/// Represents a client or server Model Context Protocol (MCP) session. /// -internal sealed partial class McpSession : IDisposable +/// +/// +/// The MCP session provides the core communication functionality used by both clients and servers: +/// +/// Sending JSON-RPC requests and receiving responses. +/// Sending notifications to the connected session. +/// Registering handlers for receiving notifications. +/// +/// +/// +/// serves as the base interface for both and +/// interfaces, providing the common functionality needed for MCP protocol +/// communication. Most applications will use these more specific interfaces rather than working with +/// directly. +/// +/// +/// All MCP sessions should be properly disposed after use as they implement . +/// +/// +#pragma warning disable CS0618 // Type or member is obsolete +public abstract partial class McpSession : IMcpEndpoint, IAsyncDisposable +#pragma warning restore CS0618 // Type or member is obsolete { - private static readonly Histogram s_clientSessionDuration = Diagnostics.CreateDurationHistogram( - "mcp.client.session.duration", "Measures the duration of a client session.", longBuckets: true); - private static readonly Histogram s_serverSessionDuration = Diagnostics.CreateDurationHistogram( - "mcp.server.session.duration", "Measures the duration of a server session.", longBuckets: true); - private static readonly Histogram s_clientOperationDuration = Diagnostics.CreateDurationHistogram( - "mcp.client.operation.duration", "Measures the duration of outbound message.", longBuckets: false); - private static readonly Histogram s_serverOperationDuration = Diagnostics.CreateDurationHistogram( - "mcp.server.operation.duration", "Measures the duration of inbound message processing.", longBuckets: false); - - /// The latest version of the protocol supported by this implementation. - internal const string LatestProtocolVersion = "2025-06-18"; - - /// All protocol versions supported by this implementation. - internal static readonly string[] SupportedProtocolVersions = - [ - "2024-11-05", - "2025-03-26", - LatestProtocolVersion, - ]; - - private readonly bool _isServer; - private readonly string _transportKind; - private readonly ITransport _transport; - private readonly RequestHandlers _requestHandlers; - private readonly NotificationHandlers _notificationHandlers; - private readonly long _sessionStartingTimestamp = Stopwatch.GetTimestamp(); - - private readonly DistributedContextPropagator _propagator = DistributedContextPropagator.Current; - - /// Collection of requests sent on this session and waiting for responses. - private readonly ConcurrentDictionary> _pendingRequests = []; - /// - /// Collection of requests received on this session and currently being handled. The value provides a - /// that can be used to request cancellation of the in-flight handler. - /// - private readonly ConcurrentDictionary _handlingRequests = new(); - private readonly ILogger _logger; - - // This _sessionId is solely used to identify the session in telemetry and logs. - private readonly string _sessionId = Guid.NewGuid().ToString("N"); - private long _lastRequestId; - - /// - /// Initializes a new instance of the class. - /// - /// true if this is a server; false if it's a client. - /// An MCP transport implementation. - /// The name of the endpoint for logging and debug purposes. - /// A collection of request handlers. - /// A collection of notification handlers. - /// The logger. - public McpSession( - bool isServer, - ITransport transport, - string endpointName, - RequestHandlers requestHandlers, - NotificationHandlers notificationHandlers, - ILogger logger) - { - Throw.IfNull(transport); - - _transportKind = transport switch - { - StdioClientSessionTransport or StdioServerTransport => "stdio", - StreamClientSessionTransport or StreamServerTransport => "stream", - SseClientSessionTransport or SseResponseStreamTransport => "sse", - StreamableHttpClientSessionTransport or StreamableHttpServerTransport or StreamableHttpPostTransport => "http", - _ => "unknownTransport" - }; - - _isServer = isServer; - _transport = transport; - EndpointName = endpointName; - _requestHandlers = requestHandlers; - _notificationHandlers = notificationHandlers; - _logger = logger ?? NullLogger.Instance; - LogSessionCreated(EndpointName, _sessionId, _transportKind); - } - - /// - /// Gets and sets the name of the endpoint for logging and debug purposes. - /// - public string EndpointName { get; set; } + /// Gets an identifier associated with the current MCP session. + /// + /// Typically populated in transports supporting multiple sessions such as Streamable HTTP or SSE. + /// Can return if the session hasn't initialized or if the transport doesn't + /// support multiple sessions (as is the case with STDIO). + /// + public abstract string? SessionId { get; } /// - /// Starts processing messages from the transport. This method will block until the transport is disconnected. - /// This is generally started in a background task or thread from the initialization logic of the derived class. + /// Sends a JSON-RPC request to the connected session and waits for a response. /// - public async Task ProcessMessagesAsync(CancellationToken cancellationToken) - { - try - { - await foreach (var message in _transport.MessageReader.ReadAllAsync(cancellationToken).ConfigureAwait(false)) - { - LogMessageRead(EndpointName, message.GetType().Name); - - // Fire and forget the message handling to avoid blocking the transport. - if (message.Context?.ExecutionContext is null) - { - _ = ProcessMessageAsync(); - } - else - { - // Flow the execution context from the HTTP request corresponding to this message if provided. - ExecutionContext.Run(message.Context.ExecutionContext, _ => _ = ProcessMessageAsync(), null); - } - - async Task ProcessMessageAsync() - { - JsonRpcMessageWithId? messageWithId = message as JsonRpcMessageWithId; - CancellationTokenSource? combinedCts = null; - try - { - // Register before we yield, so that the tracking is guaranteed to be there - // when subsequent messages arrive, even if the asynchronous processing happens - // out of order. - if (messageWithId is not null) - { - combinedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - _handlingRequests[messageWithId.Id] = combinedCts; - } - - // If we await the handler without yielding first, the transport may not be able to read more messages, - // which could lead to a deadlock if the handler sends a message back. -#if NET - await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding); -#else - await default(ForceYielding); -#endif - - // Handle the message. - await HandleMessageAsync(message, combinedCts?.Token ?? cancellationToken).ConfigureAwait(false); - } - catch (Exception ex) - { - // Only send responses for request errors that aren't user-initiated cancellation. - bool isUserCancellation = - ex is OperationCanceledException && - !cancellationToken.IsCancellationRequested && - combinedCts?.IsCancellationRequested is true; - - if (!isUserCancellation && message is JsonRpcRequest request) - { - LogRequestHandlerException(EndpointName, request.Method, ex); - - JsonRpcErrorDetail detail = ex is McpException mcpe ? - new() - { - Code = (int)mcpe.ErrorCode, - Message = mcpe.Message, - } : - new() - { - Code = (int)McpErrorCode.InternalError, - Message = "An error occurred.", - }; - - var errorMessage = new JsonRpcError - { - Id = request.Id, - JsonRpc = "2.0", - Error = detail, - Context = new JsonRpcMessageContext { RelatedTransport = request.Context?.RelatedTransport }, - }; - - await SendMessageAsync(errorMessage, cancellationToken).ConfigureAwait(false); - } - else if (ex is not OperationCanceledException) - { - if (_logger.IsEnabled(LogLevel.Trace)) - { - LogMessageHandlerExceptionSensitive(EndpointName, message.GetType().Name, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), ex); - } - else - { - LogMessageHandlerException(EndpointName, message.GetType().Name, ex); - } - } - } - finally - { - if (messageWithId is not null) - { - _handlingRequests.TryRemove(messageWithId.Id, out _); - combinedCts!.Dispose(); - } - } - } - } - } - catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) - { - // Normal shutdown - LogEndpointMessageProcessingCanceled(EndpointName); - } - finally - { - // Fail any pending requests, as they'll never be satisfied. - foreach (var entry in _pendingRequests) - { - entry.Value.TrySetException(new IOException("The server shut down unexpectedly.")); - } - } - } - - private async Task HandleMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken) - { - Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; - string method = GetMethodName(message); - - long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; - - Activity? activity = Diagnostics.ShouldInstrumentMessage(message) ? - Diagnostics.ActivitySource.StartActivity( - CreateActivityName(method), - ActivityKind.Server, - parentContext: _propagator.ExtractActivityContext(message), - links: Diagnostics.ActivityLinkFromCurrent()) : - null; - - TagList tags = default; - bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; - try - { - if (addTags) - { - AddTags(ref tags, activity, message, method); - } - - switch (message) - { - case JsonRpcRequest request: - var result = await HandleRequest(request, cancellationToken).ConfigureAwait(false); - AddResponseTags(ref tags, activity, result, method); - break; - - case JsonRpcNotification notification: - await HandleNotification(notification, cancellationToken).ConfigureAwait(false); - break; - - case JsonRpcMessageWithId messageWithId: - HandleMessageWithId(message, messageWithId); - break; - - default: - LogEndpointHandlerUnexpectedMessageType(EndpointName, message.GetType().Name); - break; - } - } - catch (Exception e) when (addTags) - { - AddExceptionTags(ref tags, activity, e); - throw; - } - finally - { - FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags); - } - } - - private async Task HandleNotification(JsonRpcNotification notification, CancellationToken cancellationToken) - { - // Special-case cancellation to cancel a pending operation. (We'll still subsequently invoke a user-specified handler if one exists.) - if (notification.Method == NotificationMethods.CancelledNotification) - { - try - { - if (GetCancelledNotificationParams(notification.Params) is CancelledNotificationParams cn && - _handlingRequests.TryGetValue(cn.RequestId, out var cts)) - { - await cts.CancelAsync().ConfigureAwait(false); - LogRequestCanceled(EndpointName, cn.RequestId, cn.Reason); - } - } - catch - { - // "Invalid cancellation notifications SHOULD be ignored" - } - } - - // Handle user-defined notifications. - await _notificationHandlers.InvokeHandlers(notification.Method, notification, cancellationToken).ConfigureAwait(false); - } - - private void HandleMessageWithId(JsonRpcMessage message, JsonRpcMessageWithId messageWithId) - { - if (_pendingRequests.TryRemove(messageWithId.Id, out var tcs)) - { - tcs.TrySetResult(message); - } - else - { - LogNoRequestFoundForMessageWithId(EndpointName, messageWithId.Id); - } - } - - private async Task HandleRequest(JsonRpcRequest request, CancellationToken cancellationToken) - { - if (!_requestHandlers.TryGetValue(request.Method, out var handler)) - { - LogNoHandlerFoundForRequest(EndpointName, request.Method); - throw new McpException($"Method '{request.Method}' is not available.", McpErrorCode.MethodNotFound); - } - - LogRequestHandlerCalled(EndpointName, request.Method); - JsonNode? result = await handler(request, cancellationToken).ConfigureAwait(false); - LogRequestHandlerCompleted(EndpointName, request.Method); - - await SendMessageAsync(new JsonRpcResponse - { - Id = request.Id, - Result = result, - Context = request.Context, - }, cancellationToken).ConfigureAwait(false); - - return result; - } - - private CancellationTokenRegistration RegisterCancellation(CancellationToken cancellationToken, JsonRpcRequest request) - { - if (!cancellationToken.CanBeCanceled) - { - return default; - } - - return cancellationToken.Register(static objState => - { - var state = (Tuple)objState!; - _ = state.Item1.SendMessageAsync(new JsonRpcNotification - { - Method = NotificationMethods.CancelledNotification, - Params = JsonSerializer.SerializeToNode(new CancelledNotificationParams { RequestId = state.Item2.Id }, McpJsonUtilities.JsonContext.Default.CancelledNotificationParams), - Context = new JsonRpcMessageContext { RelatedTransport = state.Item2.Context?.RelatedTransport }, - }); - }, Tuple.Create(this, request)); - } - - public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) - { - Throw.IfNullOrWhiteSpace(method); - Throw.IfNull(handler); - - return _notificationHandlers.Register(method, handler); - } + /// The JSON-RPC request to send. + /// The to monitor for cancellation requests. The default is . + /// A task containing the session's response. + /// The transport is not connected, or another error occurs during request processing. + /// An error occured during request processing. + /// + /// This method provides low-level access to send raw JSON-RPC requests. For most use cases, + /// consider using the strongly-typed methods that provide a more convenient API. + /// + public abstract Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default); /// - /// Sends a JSON-RPC request to the server. - /// It is strongly recommended use the capability-specific methods instead of this one. - /// Use this method for custom requests or those not yet covered explicitly by the endpoint implementation. + /// Sends a JSON-RPC message to the connected session. /// - /// The JSON-RPC request to send. + /// + /// The JSON-RPC message to send. This can be any type that implements JsonRpcMessage, such as + /// JsonRpcRequest, JsonRpcResponse, JsonRpcNotification, or JsonRpcError. + /// /// The to monitor for cancellation requests. The default is . - /// A task containing the server's response. - public async Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) - { - cancellationToken.ThrowIfCancellationRequested(); - - Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; - string method = request.Method; - - long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; - using Activity? activity = Diagnostics.ShouldInstrumentMessage(request) ? - Diagnostics.ActivitySource.StartActivity(CreateActivityName(method), ActivityKind.Client) : - null; - - // Set request ID - if (request.Id.Id is null) - { - request = request.WithId(new RequestId(Interlocked.Increment(ref _lastRequestId))); - } - - _propagator.InjectActivityContext(activity, request); - - TagList tags = default; - bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; - - var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - _pendingRequests[request.Id] = tcs; - try - { - if (addTags) - { - AddTags(ref tags, activity, request, method); - } - - if (_logger.IsEnabled(LogLevel.Trace)) - { - LogSendingRequestSensitive(EndpointName, request.Method, JsonSerializer.Serialize(request, McpJsonUtilities.JsonContext.Default.JsonRpcMessage)); - } - else - { - LogSendingRequest(EndpointName, request.Method); - } - - await SendToRelatedTransportAsync(request, cancellationToken).ConfigureAwait(false); - - // Now that the request has been sent, register for cancellation. If we registered before, - // a cancellation request could arrive before the server knew about that request ID, in which - // case the server could ignore it. - LogRequestSentAwaitingResponse(EndpointName, request.Method, request.Id); - JsonRpcMessage? response; - using (var registration = RegisterCancellation(cancellationToken, request)) - { - response = await tcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false); - } - - if (response is JsonRpcError error) - { - LogSendingRequestFailed(EndpointName, request.Method, error.Error.Message, error.Error.Code); - throw new McpException($"Request failed (remote): {error.Error.Message}", (McpErrorCode)error.Error.Code); - } - - if (response is JsonRpcResponse success) - { - if (addTags) - { - AddResponseTags(ref tags, activity, success.Result, method); - } - - if (_logger.IsEnabled(LogLevel.Trace)) - { - LogRequestResponseReceivedSensitive(EndpointName, request.Method, success.Result?.ToJsonString() ?? "null"); - } - else - { - LogRequestResponseReceived(EndpointName, request.Method); - } - - return success; - } - - // Unexpected response type - LogSendingRequestInvalidResponseType(EndpointName, request.Method); - throw new McpException("Invalid response type"); - } - catch (Exception ex) when (addTags) - { - AddExceptionTags(ref tags, activity, ex); - throw; - } - finally - { - _pendingRequests.TryRemove(request.Id, out _); - FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags); - } - } - - public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) - { - Throw.IfNull(message); - - cancellationToken.ThrowIfCancellationRequested(); - - Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; - string method = GetMethodName(message); - - long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; - using Activity? activity = Diagnostics.ShouldInstrumentMessage(message) ? - Diagnostics.ActivitySource.StartActivity(CreateActivityName(method), ActivityKind.Client) : - null; - - TagList tags = default; - bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; - - // propagate trace context - _propagator?.InjectActivityContext(activity, message); - - try - { - if (addTags) - { - AddTags(ref tags, activity, message, method); - } - - if (_logger.IsEnabled(LogLevel.Trace)) - { - LogSendingMessageSensitive(EndpointName, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage)); - } - else - { - LogSendingMessage(EndpointName); - } - - await SendToRelatedTransportAsync(message, cancellationToken).ConfigureAwait(false); - - // If the sent notification was a cancellation notification, cancel the pending request's await, as either the - // server won't be sending a response, or per the specification, the response should be ignored. There are inherent - // race conditions here, so it's possible and allowed for the operation to complete before we get to this point. - if (message is JsonRpcNotification { Method: NotificationMethods.CancelledNotification } notification && - GetCancelledNotificationParams(notification.Params) is CancelledNotificationParams cn && - _pendingRequests.TryRemove(cn.RequestId, out var tcs)) - { - tcs.TrySetCanceled(default); - } - } - catch (Exception ex) when (addTags) - { - AddExceptionTags(ref tags, activity, ex); - throw; - } - finally - { - FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags); - } - } - - // The JsonRpcMessage should be sent over the RelatedTransport if set. This is used to support the - // Streamable HTTP transport where the specification states that the server SHOULD include JSON-RPC responses in - // the HTTP response body for the POST request containing the corresponding JSON-RPC request. - private Task SendToRelatedTransportAsync(JsonRpcMessage message, CancellationToken cancellationToken) - => (message.Context?.RelatedTransport ?? _transport).SendMessageAsync(message, cancellationToken); - - private static CancelledNotificationParams? GetCancelledNotificationParams(JsonNode? notificationParams) - { - try - { - return JsonSerializer.Deserialize(notificationParams, McpJsonUtilities.JsonContext.Default.CancelledNotificationParams); - } - catch - { - return null; - } - } - - private string CreateActivityName(string method) => method; - - private static string GetMethodName(JsonRpcMessage message) => - message switch - { - JsonRpcRequest request => request.Method, - JsonRpcNotification notification => notification.Method, - _ => "unknownMethod" - }; - - private void AddTags(ref TagList tags, Activity? activity, JsonRpcMessage message, string method) - { - tags.Add("mcp.method.name", method); - tags.Add("network.transport", _transportKind); - - // TODO: When using SSE transport, add: - // - server.address and server.port on client spans and metrics - // - client.address and client.port on server spans (not metrics because of cardinality) when using SSE transport - if (activity is { IsAllDataRequested: true }) - { - // session and request id have high cardinality, so not applying to metric tags - activity.AddTag("mcp.session.id", _sessionId); - - if (message is JsonRpcMessageWithId withId) - { - activity.AddTag("mcp.request.id", withId.Id.Id?.ToString()); - } - } - - JsonObject? paramsObj = message switch - { - JsonRpcRequest request => request.Params as JsonObject, - JsonRpcNotification notification => notification.Params as JsonObject, - _ => null - }; - - if (paramsObj == null) - { - return; - } - - string? target = null; - switch (method) - { - case RequestMethods.ToolsCall: - case RequestMethods.PromptsGet: - target = GetStringProperty(paramsObj, "name"); - if (target is not null) - { - tags.Add(method == RequestMethods.ToolsCall ? "mcp.tool.name" : "mcp.prompt.name", target); - } - break; - - case RequestMethods.ResourcesRead: - case RequestMethods.ResourcesSubscribe: - case RequestMethods.ResourcesUnsubscribe: - case NotificationMethods.ResourceUpdatedNotification: - target = GetStringProperty(paramsObj, "uri"); - if (target is not null) - { - tags.Add("mcp.resource.uri", target); - } - break; - } - - if (activity is { IsAllDataRequested: true }) - { - activity.DisplayName = target == null ? method : $"{method} {target}"; - } - } - - private static void AddExceptionTags(ref TagList tags, Activity? activity, Exception e) - { - if (e is AggregateException ae && ae.InnerException is not null and not AggregateException) - { - e = ae.InnerException; - } - - int? intErrorCode = - (int?)((e as McpException)?.ErrorCode) is int errorCode ? errorCode : - e is JsonException ? (int)McpErrorCode.ParseError : - null; - - string? errorType = intErrorCode?.ToString() ?? e.GetType().FullName; - tags.Add("error.type", errorType); - if (intErrorCode is not null) - { - tags.Add("rpc.jsonrpc.error_code", errorType); - } - - if (activity is { IsAllDataRequested: true }) - { - activity.SetStatus(ActivityStatusCode.Error, e.Message); - } - } - - private static void AddResponseTags(ref TagList tags, Activity? activity, JsonNode? response, string method) - { - if (response is JsonObject jsonObject - && jsonObject.TryGetPropertyValue("isError", out var isError) - && isError?.GetValueKind() == JsonValueKind.True) - { - if (activity is { IsAllDataRequested: true }) - { - string? content = null; - if (jsonObject.TryGetPropertyValue("content", out var prop) && prop != null) - { - content = prop.ToJsonString(); - } - - activity.SetStatus(ActivityStatusCode.Error, content); - } - - tags.Add("error.type", method == RequestMethods.ToolsCall ? "tool_error" : "_OTHER"); - } - } - - private static void FinalizeDiagnostics( - Activity? activity, long? startingTimestamp, Histogram durationMetric, ref TagList tags) - { - try - { - if (startingTimestamp is not null) - { - durationMetric.Record(GetElapsed(startingTimestamp.Value).TotalSeconds, tags); - } - - if (activity is { IsAllDataRequested: true }) - { - foreach (var tag in tags) - { - activity.AddTag(tag.Key, tag.Value); - } - } - } - finally - { - activity?.Dispose(); - } - } - - public void Dispose() - { - Histogram durationMetric = _isServer ? s_serverSessionDuration : s_clientSessionDuration; - if (durationMetric.Enabled) - { - TagList tags = default; - tags.Add("network.transport", _transportKind); - - // TODO: Add server.address and server.port on client-side when using SSE transport, - // client.* attributes are not added to metrics because of cardinality - durationMetric.Record(GetElapsed(_sessionStartingTimestamp).TotalSeconds, tags); - } - - // Complete all pending requests with cancellation - foreach (var entry in _pendingRequests) - { - entry.Value.TrySetCanceled(); - } - - _pendingRequests.Clear(); - LogSessionDisposed(EndpointName, _sessionId, _transportKind); - } - -#if !NET - private static readonly double s_timestampToTicks = TimeSpan.TicksPerSecond / (double)Stopwatch.Frequency; -#endif - - private static TimeSpan GetElapsed(long startingTimestamp) => -#if NET - Stopwatch.GetElapsedTime(startingTimestamp); -#else - new((long)(s_timestampToTicks * (Stopwatch.GetTimestamp() - startingTimestamp))); -#endif - - private static string? GetStringProperty(JsonObject parameters, string propName) - { - if (parameters.TryGetPropertyValue(propName, out var prop) && prop?.GetValueKind() is JsonValueKind.String) - { - return prop.GetValue(); - } - - return null; - } - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} message processing canceled.")] - private partial void LogEndpointMessageProcessingCanceled(string endpointName); - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} method '{Method}' request handler called.")] - private partial void LogRequestHandlerCalled(string endpointName, string method); - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} method '{Method}' request handler completed.")] - private partial void LogRequestHandlerCompleted(string endpointName, string method); - - [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} method '{Method}' request handler failed.")] - private partial void LogRequestHandlerException(string endpointName, string method, Exception exception); - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} received request for unknown request ID '{RequestId}'.")] - private partial void LogNoRequestFoundForMessageWithId(string endpointName, RequestId requestId); - - [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} request failed for method '{Method}': {ErrorMessage} ({ErrorCode}).")] - private partial void LogSendingRequestFailed(string endpointName, string method, string errorMessage, int errorCode); - - [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received invalid response for method '{Method}'.")] - private partial void LogSendingRequestInvalidResponseType(string endpointName, string method); - - [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} sending method '{Method}' request.")] - private partial void LogSendingRequest(string endpointName, string method); - - [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} sending method '{Method}' request. Request: '{Request}'.")] - private partial void LogSendingRequestSensitive(string endpointName, string method, string request); - - [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} canceled request '{RequestId}' per client notification. Reason: '{Reason}'.")] - private partial void LogRequestCanceled(string endpointName, RequestId requestId, string? reason); - - [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} Request response received for method {method}")] - private partial void LogRequestResponseReceived(string endpointName, string method); - - [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} Request response received for method {method}. Response: '{Response}'.")] - private partial void LogRequestResponseReceivedSensitive(string endpointName, string method, string response); - - [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} read {MessageType} message from channel.")] - private partial void LogMessageRead(string endpointName, string messageType); - - [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} message handler {MessageType} failed.")] - private partial void LogMessageHandlerException(string endpointName, string messageType, Exception exception); - - [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} message handler {MessageType} failed. Message: '{Message}'.")] - private partial void LogMessageHandlerExceptionSensitive(string endpointName, string messageType, string message, Exception exception); - - [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received unexpected {MessageType} message type.")] - private partial void LogEndpointHandlerUnexpectedMessageType(string endpointName, string messageType); - - [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received request for method '{Method}', but no handler is available.")] - private partial void LogNoHandlerFoundForRequest(string endpointName, string method); - - [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} waiting for response to request '{RequestId}' for method '{Method}'.")] - private partial void LogRequestSentAwaitingResponse(string endpointName, string method, RequestId requestId); - - [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} sending message.")] - private partial void LogSendingMessage(string endpointName); - - [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} sending message. Message: '{Message}'.")] - private partial void LogSendingMessageSensitive(string endpointName, string message); - - [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} session {SessionId} created with transport {TransportKind}")] - private partial void LogSessionCreated(string endpointName, string sessionId, string transportKind); - - [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} session {SessionId} disposed with transport {TransportKind}")] - private partial void LogSessionDisposed(string endpointName, string sessionId, string transportKind); + /// A task that represents the asynchronous send operation. + /// The transport is not connected. + /// is . + /// + /// + /// This method provides low-level access to send any JSON-RPC message. For specific message types, + /// consider using the higher-level methods such as or methods + /// on this class that provide a simpler API. + /// + /// + /// The method will serialize the message and transmit it using the underlying transport mechanism. + /// + /// + public abstract Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default); + + /// Registers a handler to be invoked when a notification for the specified method is received. + /// The notification method. + /// The handler to be invoked. + /// An that will remove the registered handler when disposed. + public abstract IAsyncDisposable RegisterNotificationHandler(string method, Func handler); + + /// + public abstract ValueTask DisposeAsync(); } diff --git a/src/ModelContextProtocol.Core/McpSessionHandler.cs b/src/ModelContextProtocol.Core/McpSessionHandler.cs new file mode 100644 index 000000000..749486e4b --- /dev/null +++ b/src/ModelContextProtocol.Core/McpSessionHandler.cs @@ -0,0 +1,831 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Diagnostics.Metrics; +using System.Text.Json; +using System.Text.Json.Nodes; +#if !NET +using System.Threading.Channels; +#endif + +namespace ModelContextProtocol; + +/// +/// Class for managing an MCP JSON-RPC session. This covers both MCP clients and servers. +/// +internal sealed partial class McpSessionHandler : IAsyncDisposable +{ + private static readonly Histogram s_clientSessionDuration = Diagnostics.CreateDurationHistogram( + "mcp.client.session.duration", "Measures the duration of a client session.", longBuckets: true); + private static readonly Histogram s_serverSessionDuration = Diagnostics.CreateDurationHistogram( + "mcp.server.session.duration", "Measures the duration of a server session.", longBuckets: true); + private static readonly Histogram s_clientOperationDuration = Diagnostics.CreateDurationHistogram( + "mcp.client.operation.duration", "Measures the duration of outbound message.", longBuckets: false); + private static readonly Histogram s_serverOperationDuration = Diagnostics.CreateDurationHistogram( + "mcp.server.operation.duration", "Measures the duration of inbound message processing.", longBuckets: false); + + /// The latest version of the protocol supported by this implementation. + internal const string LatestProtocolVersion = "2025-06-18"; + + /// All protocol versions supported by this implementation. + internal static readonly string[] SupportedProtocolVersions = + [ + "2024-11-05", + "2025-03-26", + LatestProtocolVersion, + ]; + + private readonly bool _isServer; + private readonly string _transportKind; + private readonly ITransport _transport; + private readonly RequestHandlers _requestHandlers; + private readonly NotificationHandlers _notificationHandlers; + private readonly long _sessionStartingTimestamp = Stopwatch.GetTimestamp(); + + private readonly DistributedContextPropagator _propagator = DistributedContextPropagator.Current; + + /// Collection of requests sent on this session and waiting for responses. + private readonly ConcurrentDictionary> _pendingRequests = []; + /// + /// Collection of requests received on this session and currently being handled. The value provides a + /// that can be used to request cancellation of the in-flight handler. + /// + private readonly ConcurrentDictionary _handlingRequests = new(); + private readonly ILogger _logger; + + // This _sessionId is solely used to identify the session in telemetry and logs. + private readonly string _sessionId = Guid.NewGuid().ToString("N"); + private long _lastRequestId; + + private CancellationTokenSource? _messageProcessingCts; + private Task? _messageProcessingTask; + + /// + /// Initializes a new instance of the class. + /// + /// true if this is a server; false if it's a client. + /// An MCP transport implementation. + /// The name of the endpoint for logging and debug purposes. + /// A collection of request handlers. + /// A collection of notification handlers. + /// The logger. + public McpSessionHandler( + bool isServer, + ITransport transport, + string endpointName, + RequestHandlers requestHandlers, + NotificationHandlers notificationHandlers, + ILogger logger) + { + Throw.IfNull(transport); + + _transportKind = transport switch + { + StdioClientSessionTransport or StdioServerTransport => "stdio", + StreamClientSessionTransport or StreamServerTransport => "stream", + SseClientSessionTransport or SseResponseStreamTransport => "sse", + StreamableHttpClientSessionTransport or StreamableHttpServerTransport or StreamableHttpPostTransport => "http", + _ => "unknownTransport" + }; + + _isServer = isServer; + _transport = transport; + EndpointName = endpointName; + _requestHandlers = requestHandlers; + _notificationHandlers = notificationHandlers; + _logger = logger ?? NullLogger.Instance; + LogSessionCreated(EndpointName, _sessionId, _transportKind); + } + + /// + /// Gets and sets the name of the endpoint for logging and debug purposes. + /// + public string EndpointName { get; set; } + + /// + /// Starts processing messages from the transport. This method will block until the transport is disconnected. + /// This is generally started in a background task or thread from the initialization logic of the derived class. + /// + public Task ProcessMessagesAsync(CancellationToken cancellationToken) + { + if (_messageProcessingTask is not null) + { + throw new InvalidOperationException("The message processing loop has already started."); + } + + Debug.Assert(_messageProcessingCts is null); + + _messageProcessingCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + _messageProcessingTask = ProcessMessagesCoreAsync(_messageProcessingCts.Token); + return _messageProcessingTask; + } + + private async Task ProcessMessagesCoreAsync(CancellationToken cancellationToken) + { + try + { + await foreach (var message in _transport.MessageReader.ReadAllAsync(cancellationToken).ConfigureAwait(false)) + { + LogMessageRead(EndpointName, message.GetType().Name); + + // Fire and forget the message handling to avoid blocking the transport. + if (message.Context?.ExecutionContext is null) + { + _ = ProcessMessageAsync(); + } + else + { + // Flow the execution context from the HTTP request corresponding to this message if provided. + ExecutionContext.Run(message.Context.ExecutionContext, _ => _ = ProcessMessageAsync(), null); + } + + async Task ProcessMessageAsync() + { + JsonRpcMessageWithId? messageWithId = message as JsonRpcMessageWithId; + CancellationTokenSource? combinedCts = null; + try + { + // Register before we yield, so that the tracking is guaranteed to be there + // when subsequent messages arrive, even if the asynchronous processing happens + // out of order. + if (messageWithId is not null) + { + combinedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + _handlingRequests[messageWithId.Id] = combinedCts; + } + + // If we await the handler without yielding first, the transport may not be able to read more messages, + // which could lead to a deadlock if the handler sends a message back. +#if NET + await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding); +#else + await default(ForceYielding); +#endif + + // Handle the message. + await HandleMessageAsync(message, combinedCts?.Token ?? cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + // Only send responses for request errors that aren't user-initiated cancellation. + bool isUserCancellation = + ex is OperationCanceledException && + !cancellationToken.IsCancellationRequested && + combinedCts?.IsCancellationRequested is true; + + if (!isUserCancellation && message is JsonRpcRequest request) + { + LogRequestHandlerException(EndpointName, request.Method, ex); + + JsonRpcErrorDetail detail = ex is McpException mcpe ? + new() + { + Code = (int)mcpe.ErrorCode, + Message = mcpe.Message, + } : + new() + { + Code = (int)McpErrorCode.InternalError, + Message = "An error occurred.", + }; + + var errorMessage = new JsonRpcError + { + Id = request.Id, + JsonRpc = "2.0", + Error = detail, + Context = new JsonRpcMessageContext { RelatedTransport = request.Context?.RelatedTransport }, + }; + await SendMessageAsync(errorMessage, cancellationToken).ConfigureAwait(false); + } + else if (ex is not OperationCanceledException) + { + if (_logger.IsEnabled(LogLevel.Trace)) + { + LogMessageHandlerExceptionSensitive(EndpointName, message.GetType().Name, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage), ex); + } + else + { + LogMessageHandlerException(EndpointName, message.GetType().Name, ex); + } + } + } + finally + { + if (messageWithId is not null) + { + _handlingRequests.TryRemove(messageWithId.Id, out _); + combinedCts!.Dispose(); + } + } + } + } + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + // Normal shutdown + LogEndpointMessageProcessingCanceled(EndpointName); + } + finally + { + // Fail any pending requests, as they'll never be satisfied. + foreach (var entry in _pendingRequests) + { + entry.Value.TrySetException(new IOException("The server shut down unexpectedly.")); + } + } + } + + private async Task HandleMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken) + { + Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; + string method = GetMethodName(message); + + long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; + + Activity? activity = Diagnostics.ShouldInstrumentMessage(message) ? + Diagnostics.ActivitySource.StartActivity( + CreateActivityName(method), + ActivityKind.Server, + parentContext: _propagator.ExtractActivityContext(message), + links: Diagnostics.ActivityLinkFromCurrent()) : + null; + + TagList tags = default; + bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; + try + { + if (addTags) + { + AddTags(ref tags, activity, message, method); + } + + switch (message) + { + case JsonRpcRequest request: + var result = await HandleRequest(request, cancellationToken).ConfigureAwait(false); + AddResponseTags(ref tags, activity, result, method); + break; + + case JsonRpcNotification notification: + await HandleNotification(notification, cancellationToken).ConfigureAwait(false); + break; + + case JsonRpcMessageWithId messageWithId: + HandleMessageWithId(message, messageWithId); + break; + + default: + LogEndpointHandlerUnexpectedMessageType(EndpointName, message.GetType().Name); + break; + } + } + catch (Exception e) when (addTags) + { + AddExceptionTags(ref tags, activity, e); + throw; + } + finally + { + FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags); + } + } + + private async Task HandleNotification(JsonRpcNotification notification, CancellationToken cancellationToken) + { + // Special-case cancellation to cancel a pending operation. (We'll still subsequently invoke a user-specified handler if one exists.) + if (notification.Method == NotificationMethods.CancelledNotification) + { + try + { + if (GetCancelledNotificationParams(notification.Params) is CancelledNotificationParams cn && + _handlingRequests.TryGetValue(cn.RequestId, out var cts)) + { + await cts.CancelAsync().ConfigureAwait(false); + LogRequestCanceled(EndpointName, cn.RequestId, cn.Reason); + } + } + catch + { + // "Invalid cancellation notifications SHOULD be ignored" + } + } + + // Handle user-defined notifications. + await _notificationHandlers.InvokeHandlers(notification.Method, notification, cancellationToken).ConfigureAwait(false); + } + + private void HandleMessageWithId(JsonRpcMessage message, JsonRpcMessageWithId messageWithId) + { + if (_pendingRequests.TryRemove(messageWithId.Id, out var tcs)) + { + tcs.TrySetResult(message); + } + else + { + LogNoRequestFoundForMessageWithId(EndpointName, messageWithId.Id); + } + } + + private async Task HandleRequest(JsonRpcRequest request, CancellationToken cancellationToken) + { + if (!_requestHandlers.TryGetValue(request.Method, out var handler)) + { + LogNoHandlerFoundForRequest(EndpointName, request.Method); + throw new McpException($"Method '{request.Method}' is not available.", McpErrorCode.MethodNotFound); + } + + LogRequestHandlerCalled(EndpointName, request.Method); + JsonNode? result = await handler(request, cancellationToken).ConfigureAwait(false); + LogRequestHandlerCompleted(EndpointName, request.Method); + + await SendMessageAsync(new JsonRpcResponse + { + Id = request.Id, + Result = result, + Context = request.Context, + }, cancellationToken).ConfigureAwait(false); + + return result; + } + + private CancellationTokenRegistration RegisterCancellation(CancellationToken cancellationToken, JsonRpcRequest request) + { + if (!cancellationToken.CanBeCanceled) + { + return default; + } + + return cancellationToken.Register(static objState => + { + var state = (Tuple)objState!; + _ = state.Item1.SendMessageAsync(new JsonRpcNotification + { + Method = NotificationMethods.CancelledNotification, + Params = JsonSerializer.SerializeToNode(new CancelledNotificationParams { RequestId = state.Item2.Id }, McpJsonUtilities.JsonContext.Default.CancelledNotificationParams), + Context = new JsonRpcMessageContext { RelatedTransport = state.Item2.Context?.RelatedTransport }, + }); + }, Tuple.Create(this, request)); + } + + public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) + { + Throw.IfNullOrWhiteSpace(method); + Throw.IfNull(handler); + + return _notificationHandlers.Register(method, handler); + } + + /// + /// Sends a JSON-RPC request to the server. + /// It is strongly recommended use the capability-specific methods instead of this one. + /// Use this method for custom requests or those not yet covered explicitly by the endpoint implementation. + /// + /// The JSON-RPC request to send. + /// The to monitor for cancellation requests. The default is . + /// A task containing the server's response. + public async Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) + { + Throw.IfNull(request); + + cancellationToken.ThrowIfCancellationRequested(); + + Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; + string method = request.Method; + + long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; + using Activity? activity = Diagnostics.ShouldInstrumentMessage(request) ? + Diagnostics.ActivitySource.StartActivity(CreateActivityName(method), ActivityKind.Client) : + null; + + // Set request ID + if (request.Id.Id is null) + { + request = request.WithId(new RequestId(Interlocked.Increment(ref _lastRequestId))); + } + + _propagator.InjectActivityContext(activity, request); + + TagList tags = default; + bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + _pendingRequests[request.Id] = tcs; + try + { + if (addTags) + { + AddTags(ref tags, activity, request, method); + } + + if (_logger.IsEnabled(LogLevel.Trace)) + { + LogSendingRequestSensitive(EndpointName, request.Method, JsonSerializer.Serialize(request, McpJsonUtilities.JsonContext.Default.JsonRpcMessage)); + } + else + { + LogSendingRequest(EndpointName, request.Method); + } + + await SendToRelatedTransportAsync(request, cancellationToken).ConfigureAwait(false); + + // Now that the request has been sent, register for cancellation. If we registered before, + // a cancellation request could arrive before the server knew about that request ID, in which + // case the server could ignore it. + LogRequestSentAwaitingResponse(EndpointName, request.Method, request.Id); + JsonRpcMessage? response; + using (var registration = RegisterCancellation(cancellationToken, request)) + { + response = await tcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false); + } + + if (response is JsonRpcError error) + { + LogSendingRequestFailed(EndpointName, request.Method, error.Error.Message, error.Error.Code); + throw new McpException($"Request failed (remote): {error.Error.Message}", (McpErrorCode)error.Error.Code); + } + + if (response is JsonRpcResponse success) + { + if (addTags) + { + AddResponseTags(ref tags, activity, success.Result, method); + } + + if (_logger.IsEnabled(LogLevel.Trace)) + { + LogRequestResponseReceivedSensitive(EndpointName, request.Method, success.Result?.ToJsonString() ?? "null"); + } + else + { + LogRequestResponseReceived(EndpointName, request.Method); + } + + return success; + } + + // Unexpected response type + LogSendingRequestInvalidResponseType(EndpointName, request.Method); + throw new McpException("Invalid response type"); + } + catch (Exception ex) when (addTags) + { + AddExceptionTags(ref tags, activity, ex); + throw; + } + finally + { + _pendingRequests.TryRemove(request.Id, out _); + FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags); + } + } + + public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + { + Throw.IfNull(message); + + cancellationToken.ThrowIfCancellationRequested(); + + Histogram durationMetric = _isServer ? s_serverOperationDuration : s_clientOperationDuration; + string method = GetMethodName(message); + + long? startingTimestamp = durationMetric.Enabled ? Stopwatch.GetTimestamp() : null; + using Activity? activity = Diagnostics.ShouldInstrumentMessage(message) ? + Diagnostics.ActivitySource.StartActivity(CreateActivityName(method), ActivityKind.Client) : + null; + + TagList tags = default; + bool addTags = activity is { IsAllDataRequested: true } || startingTimestamp is not null; + + // propagate trace context + _propagator?.InjectActivityContext(activity, message); + + try + { + if (addTags) + { + AddTags(ref tags, activity, message, method); + } + + if (_logger.IsEnabled(LogLevel.Trace)) + { + LogSendingMessageSensitive(EndpointName, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage)); + } + else + { + LogSendingMessage(EndpointName); + } + + await SendToRelatedTransportAsync(message, cancellationToken).ConfigureAwait(false); + + // If the sent notification was a cancellation notification, cancel the pending request's await, as either the + // server won't be sending a response, or per the specification, the response should be ignored. There are inherent + // race conditions here, so it's possible and allowed for the operation to complete before we get to this point. + if (message is JsonRpcNotification { Method: NotificationMethods.CancelledNotification } notification && + GetCancelledNotificationParams(notification.Params) is CancelledNotificationParams cn && + _pendingRequests.TryRemove(cn.RequestId, out var tcs)) + { + tcs.TrySetCanceled(default); + } + } + catch (Exception ex) when (addTags) + { + AddExceptionTags(ref tags, activity, ex); + throw; + } + finally + { + FinalizeDiagnostics(activity, startingTimestamp, durationMetric, ref tags); + } + } + + // The JsonRpcMessage should be sent over the RelatedTransport if set. This is used to support the + // Streamable HTTP transport where the specification states that the server SHOULD include JSON-RPC responses in + // the HTTP response body for the POST request containing the corresponding JSON-RPC request. + private Task SendToRelatedTransportAsync(JsonRpcMessage message, CancellationToken cancellationToken) + => (message.Context?.RelatedTransport ?? _transport).SendMessageAsync(message, cancellationToken); + + private static CancelledNotificationParams? GetCancelledNotificationParams(JsonNode? notificationParams) + { + try + { + return JsonSerializer.Deserialize(notificationParams, McpJsonUtilities.JsonContext.Default.CancelledNotificationParams); + } + catch + { + return null; + } + } + + private string CreateActivityName(string method) => method; + + private static string GetMethodName(JsonRpcMessage message) => + message switch + { + JsonRpcRequest request => request.Method, + JsonRpcNotification notification => notification.Method, + _ => "unknownMethod" + }; + + private void AddTags(ref TagList tags, Activity? activity, JsonRpcMessage message, string method) + { + tags.Add("mcp.method.name", method); + tags.Add("network.transport", _transportKind); + + // TODO: When using SSE transport, add: + // - server.address and server.port on client spans and metrics + // - client.address and client.port on server spans (not metrics because of cardinality) when using SSE transport + if (activity is { IsAllDataRequested: true }) + { + // session and request id have high cardinality, so not applying to metric tags + activity.AddTag("mcp.session.id", _sessionId); + + if (message is JsonRpcMessageWithId withId) + { + activity.AddTag("mcp.request.id", withId.Id.Id?.ToString()); + } + } + + JsonObject? paramsObj = message switch + { + JsonRpcRequest request => request.Params as JsonObject, + JsonRpcNotification notification => notification.Params as JsonObject, + _ => null + }; + + if (paramsObj == null) + { + return; + } + + string? target = null; + switch (method) + { + case RequestMethods.ToolsCall: + case RequestMethods.PromptsGet: + target = GetStringProperty(paramsObj, "name"); + if (target is not null) + { + tags.Add(method == RequestMethods.ToolsCall ? "mcp.tool.name" : "mcp.prompt.name", target); + } + break; + + case RequestMethods.ResourcesRead: + case RequestMethods.ResourcesSubscribe: + case RequestMethods.ResourcesUnsubscribe: + case NotificationMethods.ResourceUpdatedNotification: + target = GetStringProperty(paramsObj, "uri"); + if (target is not null) + { + tags.Add("mcp.resource.uri", target); + } + break; + } + + if (activity is { IsAllDataRequested: true }) + { + activity.DisplayName = target == null ? method : $"{method} {target}"; + } + } + + private static void AddExceptionTags(ref TagList tags, Activity? activity, Exception e) + { + if (e is AggregateException ae && ae.InnerException is not null and not AggregateException) + { + e = ae.InnerException; + } + + int? intErrorCode = + (int?)((e as McpException)?.ErrorCode) is int errorCode ? errorCode : + e is JsonException ? (int)McpErrorCode.ParseError : + null; + + string? errorType = intErrorCode?.ToString() ?? e.GetType().FullName; + tags.Add("error.type", errorType); + if (intErrorCode is not null) + { + tags.Add("rpc.jsonrpc.error_code", errorType); + } + + if (activity is { IsAllDataRequested: true }) + { + activity.SetStatus(ActivityStatusCode.Error, e.Message); + } + } + + private static void AddResponseTags(ref TagList tags, Activity? activity, JsonNode? response, string method) + { + if (response is JsonObject jsonObject + && jsonObject.TryGetPropertyValue("isError", out var isError) + && isError?.GetValueKind() == JsonValueKind.True) + { + if (activity is { IsAllDataRequested: true }) + { + string? content = null; + if (jsonObject.TryGetPropertyValue("content", out var prop) && prop != null) + { + content = prop.ToJsonString(); + } + + activity.SetStatus(ActivityStatusCode.Error, content); + } + + tags.Add("error.type", method == RequestMethods.ToolsCall ? "tool_error" : "_OTHER"); + } + } + + private static void FinalizeDiagnostics( + Activity? activity, long? startingTimestamp, Histogram durationMetric, ref TagList tags) + { + try + { + if (startingTimestamp is not null) + { + durationMetric.Record(GetElapsed(startingTimestamp.Value).TotalSeconds, tags); + } + + if (activity is { IsAllDataRequested: true }) + { + foreach (var tag in tags) + { + activity.AddTag(tag.Key, tag.Value); + } + } + } + finally + { + activity?.Dispose(); + } + } + + public async ValueTask DisposeAsync() + { + Histogram durationMetric = _isServer ? s_serverSessionDuration : s_clientSessionDuration; + if (durationMetric.Enabled) + { + TagList tags = default; + tags.Add("network.transport", _transportKind); + + // TODO: Add server.address and server.port on client-side when using SSE transport, + // client.* attributes are not added to metrics because of cardinality + durationMetric.Record(GetElapsed(_sessionStartingTimestamp).TotalSeconds, tags); + } + + foreach (var entry in _pendingRequests) + { + entry.Value.TrySetCanceled(); + } + + _pendingRequests.Clear(); + + if (_messageProcessingCts is not null) + { + await _messageProcessingCts.CancelAsync().ConfigureAwait(false); + } + + if (_messageProcessingTask is not null) + { + try + { + await _messageProcessingTask.ConfigureAwait(false); + } + catch (OperationCanceledException) + { + // Ignore cancellation + } + } + + LogSessionDisposed(EndpointName, _sessionId, _transportKind); + } + +#if !NET + private static readonly double s_timestampToTicks = TimeSpan.TicksPerSecond / (double)Stopwatch.Frequency; +#endif + + private static TimeSpan GetElapsed(long startingTimestamp) => +#if NET + Stopwatch.GetElapsedTime(startingTimestamp); +#else + new((long)(s_timestampToTicks * (Stopwatch.GetTimestamp() - startingTimestamp))); +#endif + + private static string? GetStringProperty(JsonObject parameters, string propName) + { + if (parameters.TryGetPropertyValue(propName, out var prop) && prop?.GetValueKind() is JsonValueKind.String) + { + return prop.GetValue(); + } + + return null; + } + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} message processing canceled.")] + private partial void LogEndpointMessageProcessingCanceled(string endpointName); + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} method '{Method}' request handler called.")] + private partial void LogRequestHandlerCalled(string endpointName, string method); + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} method '{Method}' request handler completed.")] + private partial void LogRequestHandlerCompleted(string endpointName, string method); + + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} method '{Method}' request handler failed.")] + private partial void LogRequestHandlerException(string endpointName, string method, Exception exception); + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} received request for unknown request ID '{RequestId}'.")] + private partial void LogNoRequestFoundForMessageWithId(string endpointName, RequestId requestId); + + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} request failed for method '{Method}': {ErrorMessage} ({ErrorCode}).")] + private partial void LogSendingRequestFailed(string endpointName, string method, string errorMessage, int errorCode); + + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received invalid response for method '{Method}'.")] + private partial void LogSendingRequestInvalidResponseType(string endpointName, string method); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} sending method '{Method}' request.")] + private partial void LogSendingRequest(string endpointName, string method); + + [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} sending method '{Method}' request. Request: '{Request}'.")] + private partial void LogSendingRequestSensitive(string endpointName, string method, string request); + + [LoggerMessage(Level = LogLevel.Information, Message = "{EndpointName} canceled request '{RequestId}' per client notification. Reason: '{Reason}'.")] + private partial void LogRequestCanceled(string endpointName, RequestId requestId, string? reason); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} Request response received for method {method}")] + private partial void LogRequestResponseReceived(string endpointName, string method); + + [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} Request response received for method {method}. Response: '{Response}'.")] + private partial void LogRequestResponseReceivedSensitive(string endpointName, string method, string response); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} read {MessageType} message from channel.")] + private partial void LogMessageRead(string endpointName, string messageType); + + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} message handler {MessageType} failed.")] + private partial void LogMessageHandlerException(string endpointName, string messageType, Exception exception); + + [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} message handler {MessageType} failed. Message: '{Message}'.")] + private partial void LogMessageHandlerExceptionSensitive(string endpointName, string messageType, string message, Exception exception); + + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received unexpected {MessageType} message type.")] + private partial void LogEndpointHandlerUnexpectedMessageType(string endpointName, string messageType); + + [LoggerMessage(Level = LogLevel.Warning, Message = "{EndpointName} received request for method '{Method}', but no handler is available.")] + private partial void LogNoHandlerFoundForRequest(string endpointName, string method); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} waiting for response to request '{RequestId}' for method '{Method}'.")] + private partial void LogRequestSentAwaitingResponse(string endpointName, string method, RequestId requestId); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} sending message.")] + private partial void LogSendingMessage(string endpointName); + + [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} sending message. Message: '{Message}'.")] + private partial void LogSendingMessageSensitive(string endpointName, string message); + + [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} session {SessionId} created with transport {TransportKind}")] + private partial void LogSessionCreated(string endpointName, string sessionId, string transportKind); + + [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} session {SessionId} disposed with transport {TransportKind}")] + private partial void LogSessionDisposed(string endpointName, string sessionId, string transportKind); +} diff --git a/src/ModelContextProtocol.Core/Protocol/ClientCapabilities.cs b/src/ModelContextProtocol.Core/Protocol/ClientCapabilities.cs index ebe698135..c065ed6cb 100644 --- a/src/ModelContextProtocol.Core/Protocol/ClientCapabilities.cs +++ b/src/ModelContextProtocol.Core/Protocol/ClientCapabilities.cs @@ -44,7 +44,7 @@ public sealed class ClientCapabilities /// server requests for listing root URIs. Root URIs serve as entry points for resource navigation in the protocol. /// /// - /// The server can use to request the list of + /// The server can use to request the list of /// available roots from the client, which will trigger the client's . /// /// @@ -78,7 +78,7 @@ public sealed class ClientCapabilities /// /// /// Handlers provided via will be registered with the client for the lifetime of the client. - /// For transient handlers, may be used to register a handler that can + /// For transient handlers, may be used to register a handler that can /// then be unregistered by disposing of the returned from the method. /// /// diff --git a/src/ModelContextProtocol.Core/Protocol/ITransport.cs b/src/ModelContextProtocol.Core/Protocol/ITransport.cs index e35b3a6fb..148472e90 100644 --- a/src/ModelContextProtocol.Core/Protocol/ITransport.cs +++ b/src/ModelContextProtocol.Core/Protocol/ITransport.cs @@ -62,8 +62,8 @@ public interface ITransport : IAsyncDisposable /// /// /// This is a core method used by higher-level abstractions in the MCP protocol implementation. - /// Most client code should use the higher-level methods provided by , - /// , , or , + /// Most client code should use the higher-level methods provided by , + /// , or , /// rather than accessing this method directly. /// /// diff --git a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs index 30b6745a9..261796b5f 100644 --- a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs +++ b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs @@ -29,7 +29,7 @@ public class JsonRpcMessageContext /// /// /// This is used to support the Streamable HTTP transport in its default stateful mode. In this mode, - /// the outlives the initial HTTP request context it was created on, and new + /// the outlives the initial HTTP request context it was created on, and new /// JSON-RPC messages can originate from future HTTP requests. This allows the transport to flow the /// context with the JSON-RPC message. This is particularly useful for enabling IHttpContextAccessor /// in tool calls. diff --git a/src/ModelContextProtocol.Core/Protocol/Reference.cs b/src/ModelContextProtocol.Core/Protocol/Reference.cs index a9c87fe49..af95cf330 100644 --- a/src/ModelContextProtocol.Core/Protocol/Reference.cs +++ b/src/ModelContextProtocol.Core/Protocol/Reference.cs @@ -12,7 +12,7 @@ namespace ModelContextProtocol.Protocol; /// /// /// -/// References are commonly used with to request completion suggestions for arguments, +/// References are commonly used with to request completion suggestions for arguments, /// and with other methods that need to reference resources or prompts. /// /// diff --git a/src/ModelContextProtocol.Core/Protocol/SamplingCapability.cs b/src/ModelContextProtocol.Core/Protocol/SamplingCapability.cs index 6e0f1190a..7828ce290 100644 --- a/src/ModelContextProtocol.Core/Protocol/SamplingCapability.cs +++ b/src/ModelContextProtocol.Core/Protocol/SamplingCapability.cs @@ -34,7 +34,7 @@ public sealed class SamplingCapability /// generated content. /// /// - /// You can create a handler using the extension + /// You can create a handler using the extension /// method with any implementation of . /// /// diff --git a/src/ModelContextProtocol.Core/Protocol/ServerCapabilities.cs b/src/ModelContextProtocol.Core/Protocol/ServerCapabilities.cs index 6a4b2e62a..023a869a4 100644 --- a/src/ModelContextProtocol.Core/Protocol/ServerCapabilities.cs +++ b/src/ModelContextProtocol.Core/Protocol/ServerCapabilities.cs @@ -77,7 +77,7 @@ public sealed class ServerCapabilities /// /// /// Handlers provided via will be registered with the server for the lifetime of the server. - /// For transient handlers, may be used to register a handler that can + /// For transient handlers, may be used to register a handler that can /// then be unregistered by disposing of the returned from the method. /// /// diff --git a/src/ModelContextProtocol.Core/README.md b/src/ModelContextProtocol.Core/README.md index beb365c80..f6cffaf68 100644 --- a/src/ModelContextProtocol.Core/README.md +++ b/src/ModelContextProtocol.Core/README.md @@ -27,8 +27,8 @@ dotnet add package ModelContextProtocol.Core --prerelease ## Getting Started (Client) -To get started writing a client, the `McpClientFactory.CreateAsync` method is used to instantiate and connect an `IMcpClient` -to a server. Once you have an `IMcpClient`, you can interact with it, such as to enumerate all available tools and invoke tools. +To get started writing a client, the `McpClient.CreateAsync` method is used to instantiate and connect an `McpClient` +to a server. Once you have an `McpClient`, you can interact with it, such as to enumerate all available tools and invoke tools. ```csharp var clientTransport = new StdioClientTransport(new StdioClientTransportOptions @@ -38,7 +38,7 @@ var clientTransport = new StdioClientTransport(new StdioClientTransportOptions Arguments = ["-y", "@modelcontextprotocol/server-everything"], }); -var client = await McpClientFactory.CreateAsync(clientTransport); +var client = await McpClient.CreateAsync(clientTransport); // Print the list of tools available from the server. foreach (var tool in await client.ListToolsAsync()) diff --git a/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs b/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs index 78346c399..f74dc29b0 100644 --- a/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs +++ b/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs @@ -3,24 +3,23 @@ namespace ModelContextProtocol.Server; -internal sealed class DestinationBoundMcpServer(McpServer server, ITransport? transport) : IMcpServer +internal sealed class DestinationBoundMcpServer(McpServerImpl server, ITransport? transport) : McpServer { - public string EndpointName => server.EndpointName; - public string? SessionId => transport?.SessionId ?? server.SessionId; - public ClientCapabilities? ClientCapabilities => server.ClientCapabilities; - public Implementation? ClientInfo => server.ClientInfo; - public McpServerOptions ServerOptions => server.ServerOptions; - public IServiceProvider? Services => server.Services; - public LoggingLevel? LoggingLevel => server.LoggingLevel; + public override string? SessionId => transport?.SessionId ?? server.SessionId; + public override ClientCapabilities? ClientCapabilities => server.ClientCapabilities; + public override Implementation? ClientInfo => server.ClientInfo; + public override McpServerOptions ServerOptions => server.ServerOptions; + public override IServiceProvider? Services => server.Services; + public override LoggingLevel? LoggingLevel => server.LoggingLevel; - public ValueTask DisposeAsync() => server.DisposeAsync(); + public override ValueTask DisposeAsync() => server.DisposeAsync(); - public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => server.RegisterNotificationHandler(method, handler); + public override IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => server.RegisterNotificationHandler(method, handler); // This will throw because the server must already be running for this class to be constructed, but it should give us a good Exception message. - public Task RunAsync(CancellationToken cancellationToken) => server.RunAsync(cancellationToken); + public override Task RunAsync(CancellationToken cancellationToken) => server.RunAsync(cancellationToken); - public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + public override Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { if (message.Context is not null) { @@ -32,7 +31,7 @@ public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellat return server.SendMessageAsync(message, cancellationToken); } - public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) + public override Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) { if (request.Context is not null) { diff --git a/src/ModelContextProtocol.Core/Server/IMcpServer.cs b/src/ModelContextProtocol.Core/Server/IMcpServer.cs index ec2b87ade..016ad90b3 100644 --- a/src/ModelContextProtocol.Core/Server/IMcpServer.cs +++ b/src/ModelContextProtocol.Core/Server/IMcpServer.cs @@ -1,10 +1,11 @@ -using ModelContextProtocol.Protocol; +using ModelContextProtocol.Protocol; namespace ModelContextProtocol.Server; /// /// Represents an instance of a Model Context Protocol (MCP) server that connects to and communicates with an MCP client. /// +[Obsolete($"Use {nameof(McpServer)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public interface IMcpServer : IMcpEndpoint { /// diff --git a/src/ModelContextProtocol.Core/Server/McpServer.Methods.cs b/src/ModelContextProtocol.Core/Server/McpServer.Methods.cs new file mode 100644 index 000000000..00fc0a7cc --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/McpServer.Methods.cs @@ -0,0 +1,557 @@ +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol; +using System.Collections.Concurrent; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization.Metadata; + +namespace ModelContextProtocol.Server; + +/// +/// Represents an instance of a Model Context Protocol (MCP) server that connects to and communicates with an MCP client. +/// +#pragma warning disable CS0618 // Type or member is obsolete +public abstract partial class McpServer : McpSession, IMcpServer +#pragma warning restore CS0618 // Type or member is obsolete +{ + /// + /// Caches request schemas for elicitation requests based on the type and serializer options. + /// + private static readonly ConditionalWeakTable> s_elicitResultSchemaCache = new(); + + private static Dictionary>? s_elicitAllowedProperties = null; + + /// + /// Creates a new instance of an . + /// + /// Transport to use for the server representing an already-established MCP session. + /// Configuration options for this server, including capabilities. + /// Logger factory to use for logging. If null, logging will be disabled. + /// Optional service provider to create new instances of tools and other dependencies. + /// An instance that should be disposed when no longer needed. + /// is . + /// is . + public static McpServer Create( + ITransport transport, + McpServerOptions serverOptions, + ILoggerFactory? loggerFactory = null, + IServiceProvider? serviceProvider = null) + { + Throw.IfNull(transport); + Throw.IfNull(serverOptions); + + return new McpServerImpl(transport, serverOptions, loggerFactory, serviceProvider); + } + + /// + /// Requests to sample an LLM via the client using the specified request parameters. + /// + /// The parameters for the sampling request. + /// The to monitor for cancellation requests. + /// A task containing the sampling result from the client. + /// The client does not support sampling. + public ValueTask SampleAsync( + CreateMessageRequestParams request, CancellationToken cancellationToken = default) + { + ThrowIfSamplingUnsupported(); + + return SendRequestAsync( + RequestMethods.SamplingCreateMessage, + request, + McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, + McpJsonUtilities.JsonContext.Default.CreateMessageResult, + cancellationToken: cancellationToken); + } + + /// + /// Requests to sample an LLM via the client using the provided chat messages and options. + /// + /// The messages to send as part of the request. + /// The options to use for the request, including model parameters and constraints. + /// The to monitor for cancellation requests. The default is . + /// A task containing the chat response from the model. + /// is . + /// The client does not support sampling. + public async Task SampleAsync( + IEnumerable messages, ChatOptions? options = default, CancellationToken cancellationToken = default) + { + Throw.IfNull(messages); + + StringBuilder? systemPrompt = null; + + if (options?.Instructions is { } instructions) + { + (systemPrompt ??= new()).Append(instructions); + } + + List samplingMessages = []; + foreach (var message in messages) + { + if (message.Role == ChatRole.System) + { + if (systemPrompt is null) + { + systemPrompt = new(); + } + else + { + systemPrompt.AppendLine(); + } + + systemPrompt.Append(message.Text); + continue; + } + + if (message.Role == ChatRole.User || message.Role == ChatRole.Assistant) + { + Role role = message.Role == ChatRole.User ? Role.User : Role.Assistant; + + foreach (var content in message.Contents) + { + switch (content) + { + case TextContent textContent: + samplingMessages.Add(new() + { + Role = role, + Content = new TextContentBlock { Text = textContent.Text }, + }); + break; + + case DataContent dataContent when dataContent.HasTopLevelMediaType("image") || dataContent.HasTopLevelMediaType("audio"): + samplingMessages.Add(new() + { + Role = role, + Content = dataContent.HasTopLevelMediaType("image") ? + new ImageContentBlock + { + MimeType = dataContent.MediaType, + Data = dataContent.Base64Data.ToString(), + } : + new AudioContentBlock + { + MimeType = dataContent.MediaType, + Data = dataContent.Base64Data.ToString(), + }, + }); + break; + } + } + } + } + + ModelPreferences? modelPreferences = null; + if (options?.ModelId is { } modelId) + { + modelPreferences = new() { Hints = [new() { Name = modelId }] }; + } + + var result = await SampleAsync(new() + { + Messages = samplingMessages, + MaxTokens = options?.MaxOutputTokens, + StopSequences = options?.StopSequences?.ToArray(), + SystemPrompt = systemPrompt?.ToString(), + Temperature = options?.Temperature, + ModelPreferences = modelPreferences, + }, cancellationToken).ConfigureAwait(false); + + AIContent? responseContent = result.Content.ToAIContent(); + + return new(new ChatMessage(result.Role is Role.User ? ChatRole.User : ChatRole.Assistant, responseContent is not null ? [responseContent] : [])) + { + ModelId = result.Model, + FinishReason = result.StopReason switch + { + "maxTokens" => ChatFinishReason.Length, + "endTurn" or "stopSequence" or _ => ChatFinishReason.Stop, + } + }; + } + + /// + /// Creates an wrapper that can be used to send sampling requests to the client. + /// + /// The that can be used to issue sampling requests to the client. + /// The client does not support sampling. + public IChatClient AsSamplingChatClient() + { + ThrowIfSamplingUnsupported(); + return new SamplingChatClient(this); + } + + /// Gets an on which logged messages will be sent as notifications to the client. + /// An that can be used to log to the client.. + public ILoggerProvider AsClientLoggerProvider() + { + return new ClientLoggerProvider(this); + } + + /// + /// Requests the client to list the roots it exposes. + /// + /// The parameters for the list roots request. + /// The to monitor for cancellation requests. + /// A task containing the list of roots exposed by the client. + /// The client does not support roots. + public ValueTask RequestRootsAsync( + ListRootsRequestParams request, CancellationToken cancellationToken = default) + { + ThrowIfRootsUnsupported(); + + return SendRequestAsync( + RequestMethods.RootsList, + request, + McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, + McpJsonUtilities.JsonContext.Default.ListRootsResult, + cancellationToken: cancellationToken); + } + + /// + /// Requests additional information from the user via the client, allowing the server to elicit structured data. + /// + /// The parameters for the elicitation request. + /// The to monitor for cancellation requests. + /// A task containing the elicitation result. + /// The client does not support elicitation. + public ValueTask ElicitAsync( + ElicitRequestParams request, CancellationToken cancellationToken = default) + { + ThrowIfElicitationUnsupported(); + + return SendRequestAsync( + RequestMethods.ElicitationCreate, + request, + McpJsonUtilities.JsonContext.Default.ElicitRequestParams, + McpJsonUtilities.JsonContext.Default.ElicitResult, + cancellationToken: cancellationToken); + } + + /// + /// Requests additional information from the user via the client, constructing a request schema from the + /// public serializable properties of and deserializing the response into . + /// + /// The type describing the expected input shape. Only primitive members are supported (string, number, boolean, enum). + /// The message to present to the user. + /// Serializer options that influence property naming and deserialization. + /// The to monitor for cancellation requests. + /// An with the user's response, if accepted. + /// + /// Elicitation uses a constrained subset of JSON Schema and only supports strings, numbers/integers, booleans and string enums. + /// Unsupported member types are ignored when constructing the schema. + /// + public async ValueTask> ElicitAsync( + string message, + JsonSerializerOptions? serializerOptions = null, + CancellationToken cancellationToken = default) + { + ThrowIfElicitationUnsupported(); + + serializerOptions ??= McpJsonUtilities.DefaultOptions; + serializerOptions.MakeReadOnly(); + + var dict = s_elicitResultSchemaCache.GetValue(serializerOptions, _ => new()); + +#if NET + var schema = dict.GetOrAdd(typeof(T), static (t, s) => BuildRequestSchema(t, s), serializerOptions); +#else + var schema = dict.GetOrAdd(typeof(T), type => BuildRequestSchema(type, serializerOptions)); +#endif + + var request = new ElicitRequestParams + { + Message = message, + RequestedSchema = schema, + }; + + var raw = await ElicitAsync(request, cancellationToken).ConfigureAwait(false); + + if (!raw.IsAccepted || raw.Content is null) + { + return new ElicitResult { Action = raw.Action, Content = default }; + } + + var obj = new JsonObject(); + foreach (var kvp in raw.Content) + { + obj[kvp.Key] = JsonNode.Parse(kvp.Value.GetRawText()); + } + + T? typed = JsonSerializer.Deserialize(obj, serializerOptions.GetTypeInfo()); + return new ElicitResult { Action = raw.Action, Content = typed }; + } + + /// + /// Builds a request schema for elicitation based on the public serializable properties of . + /// + /// The type of the schema being built. + /// The serializer options to use. + /// The built request schema. + /// + private static ElicitRequestParams.RequestSchema BuildRequestSchema(Type type, JsonSerializerOptions serializerOptions) + { + var schema = new ElicitRequestParams.RequestSchema(); + var props = schema.Properties; + + JsonTypeInfo typeInfo = serializerOptions.GetTypeInfo(type); + + if (typeInfo.Kind != JsonTypeInfoKind.Object) + { + throw new McpException($"Type '{type.FullName}' is not supported for elicitation requests."); + } + + foreach (JsonPropertyInfo pi in typeInfo.Properties) + { + var def = CreatePrimitiveSchema(pi.PropertyType, serializerOptions); + props[pi.Name] = def; + } + + return schema; + } + + /// + /// Creates a primitive schema definition for the specified type, if supported. + /// + /// The type to create the schema for. + /// The serializer options to use. + /// The created primitive schema definition. + /// Thrown when the type is not supported. + private static ElicitRequestParams.PrimitiveSchemaDefinition CreatePrimitiveSchema(Type type, JsonSerializerOptions serializerOptions) + { + if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)) + { + throw new McpException($"Type '{type.FullName}' is not a supported property type for elicitation requests. Nullable types are not supported."); + } + + var typeInfo = serializerOptions.GetTypeInfo(type); + + if (typeInfo.Kind != JsonTypeInfoKind.None) + { + throw new McpException($"Type '{type.FullName}' is not a supported property type for elicitation requests."); + } + + var jsonElement = AIJsonUtilities.CreateJsonSchema(type, serializerOptions: serializerOptions); + + if (!TryValidateElicitationPrimitiveSchema(jsonElement, type, out var error)) + { + throw new McpException(error); + } + + var primitiveSchemaDefinition = + jsonElement.Deserialize(McpJsonUtilities.JsonContext.Default.PrimitiveSchemaDefinition); + + if (primitiveSchemaDefinition is null) + throw new McpException($"Type '{type.FullName}' is not a supported property type for elicitation requests."); + + return primitiveSchemaDefinition; + } + + /// + /// Validate the produced schema strictly to the subset we support. We only accept an object schema + /// with a supported primitive type keyword and no additional unsupported keywords.Reject things like + /// {}, 'true', or schemas that include unrelated keywords(e.g.items, properties, patternProperties, etc.). + /// + /// The schema to validate. + /// The type of the schema being validated, just for reporting errors. + /// The error message, if validation fails. + /// + private static bool TryValidateElicitationPrimitiveSchema(JsonElement schema, Type type, + [NotNullWhen(false)] out string? error) + { + if (schema.ValueKind is not JsonValueKind.Object) + { + error = $"Schema generated for type '{type.FullName}' is invalid: expected an object schema."; + return false; + } + + if (!schema.TryGetProperty("type", out JsonElement typeProperty) + || typeProperty.ValueKind is not JsonValueKind.String) + { + error = $"Schema generated for type '{type.FullName}' is invalid: missing or invalid 'type' keyword."; + return false; + } + + var typeKeyword = typeProperty.GetString(); + + if (string.IsNullOrEmpty(typeKeyword)) + { + error = $"Schema generated for type '{type.FullName}' is invalid: empty 'type' value."; + return false; + } + + if (typeKeyword is not ("string" or "number" or "integer" or "boolean")) + { + error = $"Schema generated for type '{type.FullName}' is invalid: unsupported primitive type '{typeKeyword}'."; + return false; + } + + s_elicitAllowedProperties ??= new() + { + ["string"] = ["type", "title", "description", "minLength", "maxLength", "format", "enum", "enumNames"], + ["number"] = ["type", "title", "description", "minimum", "maximum"], + ["integer"] = ["type", "title", "description", "minimum", "maximum"], + ["boolean"] = ["type", "title", "description", "default"] + }; + + var allowed = s_elicitAllowedProperties[typeKeyword]; + + foreach (JsonProperty prop in schema.EnumerateObject()) + { + if (!allowed.Contains(prop.Name)) + { + error = $"The property '{type.FullName}.{prop.Name}' is not supported for elicitation."; + return false; + } + } + + error = string.Empty; + return true; + } + + private void ThrowIfSamplingUnsupported() + { + if (ClientCapabilities?.Sampling is null) + { + if (ServerOptions.KnownClientInfo is not null) + { + throw new InvalidOperationException("Sampling is not supported in stateless mode."); + } + + throw new InvalidOperationException("Client does not support sampling."); + } + } + + private void ThrowIfRootsUnsupported() + { + if (ClientCapabilities?.Roots is null) + { + if (ServerOptions.KnownClientInfo is not null) + { + throw new InvalidOperationException("Roots are not supported in stateless mode."); + } + + throw new InvalidOperationException("Client does not support roots."); + } + } + + private void ThrowIfElicitationUnsupported() + { + if (ClientCapabilities?.Elicitation is null) + { + if (ServerOptions.KnownClientInfo is not null) + { + throw new InvalidOperationException("Elicitation is not supported in stateless mode."); + } + + throw new InvalidOperationException("Client does not support elicitation requests."); + } + } + + /// Provides an implementation that's implemented via client sampling. + private sealed class SamplingChatClient : IChatClient + { + private readonly McpServer _server; + + public SamplingChatClient(McpServer server) => _server = server; + + /// + public Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) => + _server.SampleAsync(messages, options, cancellationToken); + + /// + async IAsyncEnumerable IChatClient.GetStreamingResponseAsync( + IEnumerable messages, ChatOptions? options, [EnumeratorCancellation] CancellationToken cancellationToken) + { + var response = await GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); + foreach (var update in response.ToChatResponseUpdates()) + { + yield return update; + } + } + + /// + object? IChatClient.GetService(Type serviceType, object? serviceKey) + { + Throw.IfNull(serviceType); + + return + serviceKey is not null ? null : + serviceType.IsInstanceOfType(this) ? this : + serviceType.IsInstanceOfType(_server) ? _server : + null; + } + + /// + void IDisposable.Dispose() { } // nop + } + + /// + /// Provides an implementation for creating loggers + /// that send logging message notifications to the client for logged messages. + /// + private sealed class ClientLoggerProvider : ILoggerProvider + { + private readonly McpServer _server; + + public ClientLoggerProvider(McpServer server) => _server = server; + + /// + public ILogger CreateLogger(string categoryName) + { + Throw.IfNull(categoryName); + + return new ClientLogger(_server, categoryName); + } + + /// + void IDisposable.Dispose() { } + + private sealed class ClientLogger : ILogger + { + private readonly McpServer _server; + private readonly string _categoryName; + + public ClientLogger(McpServer server, string categoryName) + { + _server = server; + _categoryName = categoryName; + } + + /// + public IDisposable? BeginScope(TState state) where TState : notnull => + null; + + /// + public bool IsEnabled(LogLevel logLevel) => + _server?.LoggingLevel is { } loggingLevel && + McpServerImpl.ToLoggingLevel(logLevel) >= loggingLevel; + + /// + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) + { + if (!IsEnabled(logLevel)) + { + return; + } + + Throw.IfNull(formatter); + + LogInternal(logLevel, formatter(state, exception)); + + void LogInternal(LogLevel level, string message) + { + _ = _server.SendNotificationAsync(NotificationMethods.LoggingMessageNotification, new LoggingMessageNotificationParams + { + Level = McpServerImpl.ToLoggingLevel(level), + Data = JsonSerializer.SerializeToElement(message, McpJsonUtilities.JsonContext.Default.String), + Logger = _categoryName, + }); + } + } + } + } +} diff --git a/src/ModelContextProtocol.Core/Server/McpServer.cs b/src/ModelContextProtocol.Core/Server/McpServer.cs index 6e15e2465..02c17de1a 100644 --- a/src/ModelContextProtocol.Core/Server/McpServer.cs +++ b/src/ModelContextProtocol.Core/Server/McpServer.cs @@ -1,714 +1,64 @@ -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; -using System.Runtime.CompilerServices; -using System.Text.Json.Serialization.Metadata; namespace ModelContextProtocol.Server; -/// -internal sealed partial class McpServer : McpEndpoint, IMcpServer +/// +/// Represents an instance of a Model Context Protocol (MCP) server that connects to and communicates with an MCP client. +/// +#pragma warning disable CS0618 // Type or member is obsolete +public abstract partial class McpServer : McpSession, IMcpServer +#pragma warning restore CS0618 // Type or member is obsolete { - internal static Implementation DefaultImplementation { get; } = new() - { - Name = DefaultAssemblyName.Name ?? nameof(McpServer), - Version = DefaultAssemblyName.Version?.ToString() ?? "1.0.0", - }; - - private readonly ITransport _sessionTransport; - private readonly bool _servicesScopePerRequest; - private readonly List _disposables = []; - - private readonly string _serverOnlyEndpointName; - private string? _endpointName; - private int _started; - - /// Holds a boxed value for the server. + /// + /// Gets the capabilities supported by the client. + /// /// - /// Initialized to non-null the first time SetLevel is used. This is stored as a strong box - /// rather than a nullable to be able to manipulate it atomically. + /// + /// These capabilities are established during the initialization handshake and indicate + /// which features the client supports, such as sampling, roots, and other + /// protocol-specific functionality. + /// + /// + /// Server implementations can check these capabilities to determine which features + /// are available when interacting with the client. + /// /// - private StrongBox? _loggingLevel; + public abstract ClientCapabilities? ClientCapabilities { get; } /// - /// Creates a new instance of . + /// Gets the version and implementation information of the connected client. /// - /// Transport to use for the server representing an already-established session. - /// Configuration options for this server, including capabilities. - /// Make sure to accurately reflect exactly what capabilities the server supports and does not support. - /// Logger factory to use for logging - /// Optional service provider to use for dependency injection - /// The server was incorrectly configured. - public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? loggerFactory, IServiceProvider? serviceProvider) - : base(loggerFactory) - { - Throw.IfNull(transport); - Throw.IfNull(options); - - options ??= new(); - - _sessionTransport = transport; - ServerOptions = options; - Services = serviceProvider; - _serverOnlyEndpointName = $"Server ({options.ServerInfo?.Name ?? DefaultImplementation.Name} {options.ServerInfo?.Version ?? DefaultImplementation.Version})"; - _servicesScopePerRequest = options.ScopeRequests; - - ClientInfo = options.KnownClientInfo; - UpdateEndpointNameWithClientInfo(); - - // Configure all request handlers based on the supplied options. - ServerCapabilities = new(); - ConfigureInitialize(options); - ConfigureTools(options); - ConfigurePrompts(options); - ConfigureResources(options); - ConfigureLogging(options); - ConfigureCompletion(options); - ConfigureExperimental(options); - ConfigurePing(); - - // Register any notification handlers that were provided. - if (options.Capabilities?.NotificationHandlers is { } notificationHandlers) - { - NotificationHandlers.RegisterRange(notificationHandlers); - } - - // Now that everything has been configured, subscribe to any necessary notifications. - if (transport is not StreamableHttpServerTransport streamableHttpTransport || streamableHttpTransport.Stateless is false) - { - Register(ServerOptions.Capabilities?.Tools?.ToolCollection, NotificationMethods.ToolListChangedNotification); - Register(ServerOptions.Capabilities?.Prompts?.PromptCollection, NotificationMethods.PromptListChangedNotification); - Register(ServerOptions.Capabilities?.Resources?.ResourceCollection, NotificationMethods.ResourceListChangedNotification); - - void Register(McpServerPrimitiveCollection? collection, string notificationMethod) - where TPrimitive : IMcpServerPrimitive - { - if (collection is not null) - { - EventHandler changed = (sender, e) => _ = this.SendNotificationAsync(notificationMethod); - collection.Changed += changed; - _disposables.Add(() => collection.Changed -= changed); - } - } - } - - // And initialize the session. - InitializeSession(transport); - } - - /// - public string? SessionId => _sessionTransport.SessionId; - - /// - public ServerCapabilities ServerCapabilities { get; } = new(); - - /// - public ClientCapabilities? ClientCapabilities { get; set; } - - /// - public Implementation? ClientInfo { get; set; } - - /// - public McpServerOptions ServerOptions { get; } - - /// - public IServiceProvider? Services { get; } - - /// - public override string EndpointName => _endpointName ?? _serverOnlyEndpointName; - - /// - public LoggingLevel? LoggingLevel => _loggingLevel?.Value; - - /// - public async Task RunAsync(CancellationToken cancellationToken = default) - { - if (Interlocked.Exchange(ref _started, 1) != 0) - { - throw new InvalidOperationException($"{nameof(RunAsync)} must only be called once."); - } - - try - { - StartSession(_sessionTransport, fullSessionCancellationToken: cancellationToken); - await MessageProcessingTask.ConfigureAwait(false); - } - finally - { - await DisposeAsync().ConfigureAwait(false); - } - } - - public override async ValueTask DisposeUnsynchronizedAsync() - { - _disposables.ForEach(d => d()); - await base.DisposeUnsynchronizedAsync().ConfigureAwait(false); - } - - private void ConfigurePing() - { - SetHandler(RequestMethods.Ping, - async (request, _) => new PingResult(), - McpJsonUtilities.JsonContext.Default.JsonNode, - McpJsonUtilities.JsonContext.Default.PingResult); - } - - private void ConfigureInitialize(McpServerOptions options) - { - RequestHandlers.Set(RequestMethods.Initialize, - async (request, _, _) => - { - ClientCapabilities = request?.Capabilities ?? new(); - ClientInfo = request?.ClientInfo; - - // Use the ClientInfo to update the session EndpointName for logging. - UpdateEndpointNameWithClientInfo(); - GetSessionOrThrow().EndpointName = EndpointName; - - // Negotiate a protocol version. If the server options provide one, use that. - // Otherwise, try to use whatever the client requested as long as it's supported. - // If it's not supported, fall back to the latest supported version. - string? protocolVersion = options.ProtocolVersion; - if (protocolVersion is null) - { - protocolVersion = request?.ProtocolVersion is string clientProtocolVersion && McpSession.SupportedProtocolVersions.Contains(clientProtocolVersion) ? - clientProtocolVersion : - McpSession.LatestProtocolVersion; - } - - return new InitializeResult - { - ProtocolVersion = protocolVersion, - Instructions = options.ServerInstructions, - ServerInfo = options.ServerInfo ?? DefaultImplementation, - Capabilities = ServerCapabilities ?? new(), - }; - }, - McpJsonUtilities.JsonContext.Default.InitializeRequestParams, - McpJsonUtilities.JsonContext.Default.InitializeResult); - } - - private void ConfigureCompletion(McpServerOptions options) - { - if (options.Capabilities?.Completions is not { } completionsCapability) - { - return; - } - - var completeHandler = completionsCapability.CompleteHandler ?? (static async (_, __) => new CompleteResult()); - completeHandler = BuildFilterPipeline(completeHandler, options.Filters.CompleteFilters); - - ServerCapabilities.Completions = new() - { - CompleteHandler = completeHandler - }; - - SetHandler( - RequestMethods.CompletionComplete, - ServerCapabilities.Completions.CompleteHandler, - McpJsonUtilities.JsonContext.Default.CompleteRequestParams, - McpJsonUtilities.JsonContext.Default.CompleteResult); - } - - private void ConfigureExperimental(McpServerOptions options) - { - ServerCapabilities.Experimental = options.Capabilities?.Experimental; - } - - private void ConfigureResources(McpServerOptions options) - { - if (options.Capabilities?.Resources is not { } resourcesCapability) - { - return; - } - - ServerCapabilities.Resources = new(); - - var listResourcesHandler = resourcesCapability.ListResourcesHandler ?? (static async (_, __) => new ListResourcesResult()); - var listResourceTemplatesHandler = resourcesCapability.ListResourceTemplatesHandler ?? (static async (_, __) => new ListResourceTemplatesResult()); - var readResourceHandler = resourcesCapability.ReadResourceHandler ?? (static async (request, _) => throw new McpException($"Unknown resource URI: '{request.Params?.Uri}'", McpErrorCode.InvalidParams)); - var subscribeHandler = resourcesCapability.SubscribeToResourcesHandler ?? (static async (_, __) => new EmptyResult()); - var unsubscribeHandler = resourcesCapability.UnsubscribeFromResourcesHandler ?? (static async (_, __) => new EmptyResult()); - var resources = resourcesCapability.ResourceCollection; - var listChanged = resourcesCapability.ListChanged; - var subscribe = resourcesCapability.Subscribe; - - // Handle resources provided via DI. - if (resources is { IsEmpty: false }) - { - var originalListResourcesHandler = listResourcesHandler; - listResourcesHandler = async (request, cancellationToken) => - { - ListResourcesResult result = originalListResourcesHandler is not null ? - await originalListResourcesHandler(request, cancellationToken).ConfigureAwait(false) : - new(); - - if (request.Params?.Cursor is null) - { - foreach (var r in resources) - { - if (r.ProtocolResource is { } resource) - { - result.Resources.Add(resource); - } - } - } - - return result; - }; - - var originalListResourceTemplatesHandler = listResourceTemplatesHandler; - listResourceTemplatesHandler = async (request, cancellationToken) => - { - ListResourceTemplatesResult result = originalListResourceTemplatesHandler is not null ? - await originalListResourceTemplatesHandler(request, cancellationToken).ConfigureAwait(false) : - new(); - - if (request.Params?.Cursor is null) - { - foreach (var rt in resources) - { - if (rt.IsTemplated) - { - result.ResourceTemplates.Add(rt.ProtocolResourceTemplate); - } - } - } - - return result; - }; - - // Synthesize read resource handler, which covers both resources and resource templates. - var originalReadResourceHandler = readResourceHandler; - readResourceHandler = async (request, cancellationToken) => - { - if (request.MatchedPrimitive is McpServerResource matchedResource) - { - if (await matchedResource.ReadAsync(request, cancellationToken).ConfigureAwait(false) is { } result) - { - return result; - } - } - - return await originalReadResourceHandler(request, cancellationToken).ConfigureAwait(false); - }; - - listChanged = true; - - // TODO: Implement subscribe/unsubscribe logic for resource and resource template collections. - // subscribe = true; - } - - listResourcesHandler = BuildFilterPipeline(listResourcesHandler, options.Filters.ListResourcesFilters); - listResourceTemplatesHandler = BuildFilterPipeline(listResourceTemplatesHandler, options.Filters.ListResourceTemplatesFilters); - readResourceHandler = BuildFilterPipeline(readResourceHandler, options.Filters.ReadResourceFilters, handler => - async (request, cancellationToken) => - { - // Initial handler that sets MatchedPrimitive - if (request.Params?.Uri is { } uri && resources is not null) - { - // First try an O(1) lookup by exact match. - if (resources.TryGetPrimitive(uri, out var resource)) - { - request.MatchedPrimitive = resource; - } - else - { - // Fall back to an O(N) lookup, trying to match against each URI template. - // The number of templates is controlled by the server developer, and the number is expected to be - // not terribly large. If that changes, this can be tweaked to enable a more efficient lookup. - foreach (var resourceTemplate in resources) - { - // Check if this template would handle the request by testing if ReadAsync would succeed - if (resourceTemplate.IsTemplated) - { - // This is a simplified check - a more robust implementation would match the URI pattern - // For now, we'll let the actual handler attempt the match - request.MatchedPrimitive = resourceTemplate; - break; - } - } - } - } - - return await handler(request, cancellationToken).ConfigureAwait(false); - }); - subscribeHandler = BuildFilterPipeline(subscribeHandler, options.Filters.SubscribeToResourcesFilters); - unsubscribeHandler = BuildFilterPipeline(unsubscribeHandler, options.Filters.UnsubscribeFromResourcesFilters); - - ServerCapabilities.Resources.ListResourcesHandler = listResourcesHandler; - ServerCapabilities.Resources.ListResourceTemplatesHandler = listResourceTemplatesHandler; - ServerCapabilities.Resources.ReadResourceHandler = readResourceHandler; - ServerCapabilities.Resources.ResourceCollection = resources; - ServerCapabilities.Resources.SubscribeToResourcesHandler = subscribeHandler; - ServerCapabilities.Resources.UnsubscribeFromResourcesHandler = unsubscribeHandler; - ServerCapabilities.Resources.ListChanged = listChanged; - ServerCapabilities.Resources.Subscribe = subscribe; - - SetHandler( - RequestMethods.ResourcesList, - listResourcesHandler, - McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, - McpJsonUtilities.JsonContext.Default.ListResourcesResult); - - SetHandler( - RequestMethods.ResourcesTemplatesList, - listResourceTemplatesHandler, - McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, - McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult); - - SetHandler( - RequestMethods.ResourcesRead, - readResourceHandler, - McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, - McpJsonUtilities.JsonContext.Default.ReadResourceResult); - - SetHandler( - RequestMethods.ResourcesSubscribe, - subscribeHandler, - McpJsonUtilities.JsonContext.Default.SubscribeRequestParams, - McpJsonUtilities.JsonContext.Default.EmptyResult); - - SetHandler( - RequestMethods.ResourcesUnsubscribe, - unsubscribeHandler, - McpJsonUtilities.JsonContext.Default.UnsubscribeRequestParams, - McpJsonUtilities.JsonContext.Default.EmptyResult); - } - - private void ConfigurePrompts(McpServerOptions options) - { - if (options.Capabilities?.Prompts is not { } promptsCapability) - { - return; - } - - ServerCapabilities.Prompts = new(); - - var listPromptsHandler = promptsCapability.ListPromptsHandler ?? (static async (_, __) => new ListPromptsResult()); - var getPromptHandler = promptsCapability.GetPromptHandler ?? (static async (request, _) => throw new McpException($"Unknown prompt: '{request.Params?.Name}'", McpErrorCode.InvalidParams)); - var prompts = promptsCapability.PromptCollection; - var listChanged = promptsCapability.ListChanged; - - // Handle tools provided via DI by augmenting the handlers to incorporate them. - if (prompts is { IsEmpty: false }) - { - var originalListPromptsHandler = listPromptsHandler; - listPromptsHandler = async (request, cancellationToken) => - { - ListPromptsResult result = originalListPromptsHandler is not null ? - await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(false) : - new(); - - if (request.Params?.Cursor is null) - { - foreach (var p in prompts) - { - result.Prompts.Add(p.ProtocolPrompt); - } - } - - return result; - }; - - var originalGetPromptHandler = getPromptHandler; - getPromptHandler = (request, cancellationToken) => - { - if (request.MatchedPrimitive is McpServerPrompt prompt) - { - return prompt.GetAsync(request, cancellationToken); - } - - return originalGetPromptHandler(request, cancellationToken); - }; - - listChanged = true; - } - - listPromptsHandler = BuildFilterPipeline(listPromptsHandler, options.Filters.ListPromptsFilters); - getPromptHandler = BuildFilterPipeline(getPromptHandler, options.Filters.GetPromptFilters, handler => - (request, cancellationToken) => - { - // Initial handler that sets MatchedPrimitive - if (request.Params?.Name is { } promptName && prompts is not null && - prompts.TryGetPrimitive(promptName, out var prompt)) - { - request.MatchedPrimitive = prompt; - } - - return handler(request, cancellationToken); - }); - - ServerCapabilities.Prompts.ListPromptsHandler = listPromptsHandler; - ServerCapabilities.Prompts.GetPromptHandler = getPromptHandler; - ServerCapabilities.Prompts.PromptCollection = prompts; - ServerCapabilities.Prompts.ListChanged = listChanged; - - SetHandler( - RequestMethods.PromptsList, - listPromptsHandler, - McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, - McpJsonUtilities.JsonContext.Default.ListPromptsResult); - - SetHandler( - RequestMethods.PromptsGet, - getPromptHandler, - McpJsonUtilities.JsonContext.Default.GetPromptRequestParams, - McpJsonUtilities.JsonContext.Default.GetPromptResult); - } - - private void ConfigureTools(McpServerOptions options) - { - if (options.Capabilities?.Tools is not { } toolsCapability) - { - return; - } - - ServerCapabilities.Tools = new(); - - var listToolsHandler = toolsCapability.ListToolsHandler ?? (static async (_, __) => new ListToolsResult()); - var callToolHandler = toolsCapability.CallToolHandler ?? (static async (request, _) => throw new McpException($"Unknown tool: '{request.Params?.Name}'", McpErrorCode.InvalidParams)); - var tools = toolsCapability.ToolCollection; - var listChanged = toolsCapability.ListChanged; - - // Handle tools provided via DI by augmenting the handlers to incorporate them. - if (tools is { IsEmpty: false }) - { - var originalListToolsHandler = listToolsHandler; - listToolsHandler = async (request, cancellationToken) => - { - ListToolsResult result = originalListToolsHandler is not null ? - await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) : - new(); - - if (request.Params?.Cursor is null) - { - foreach (var t in tools) - { - result.Tools.Add(t.ProtocolTool); - } - } - - return result; - }; - - var originalCallToolHandler = callToolHandler; - callToolHandler = (request, cancellationToken) => - { - if (request.MatchedPrimitive is McpServerTool tool) - { - return tool.InvokeAsync(request, cancellationToken); - } - - return originalCallToolHandler(request, cancellationToken); - }; - - listChanged = true; - } - - listToolsHandler = BuildFilterPipeline(listToolsHandler, options.Filters.ListToolsFilters); - callToolHandler = BuildFilterPipeline(callToolHandler, options.Filters.CallToolFilters, handler => - (request, cancellationToken) => - { - // Initial handler that sets MatchedPrimitive - if (request.Params?.Name is { } toolName && tools is not null && - tools.TryGetPrimitive(toolName, out var tool)) - { - request.MatchedPrimitive = tool; - } - - return handler(request, cancellationToken); - }, handler => - async (request, cancellationToken) => - { - // Final handler that provides exception handling only for tool execution - // Only wrap tool execution in try-catch, not tool resolution - if (request.MatchedPrimitive is McpServerTool) - { - try - { - return await handler(request, cancellationToken).ConfigureAwait(false); - } - catch (Exception e) when (e is not OperationCanceledException) - { - ToolCallError(request.Params?.Name ?? string.Empty, e); - - string errorMessage = e is McpException ? - $"An error occurred invoking '{request.Params?.Name}': {e.Message}" : - $"An error occurred invoking '{request.Params?.Name}'."; - - return new() - { - IsError = true, - Content = [new TextContentBlock { Text = errorMessage }], - }; - } - } - else - { - // For unmatched tools, let exceptions bubble up as protocol errors - return await handler(request, cancellationToken).ConfigureAwait(false); - } - }); - - ServerCapabilities.Tools.ListToolsHandler = listToolsHandler; - ServerCapabilities.Tools.CallToolHandler = callToolHandler; - ServerCapabilities.Tools.ToolCollection = tools; - ServerCapabilities.Tools.ListChanged = listChanged; - - SetHandler( - RequestMethods.ToolsList, - listToolsHandler, - McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, - McpJsonUtilities.JsonContext.Default.ListToolsResult); - - SetHandler( - RequestMethods.ToolsCall, - callToolHandler, - McpJsonUtilities.JsonContext.Default.CallToolRequestParams, - McpJsonUtilities.JsonContext.Default.CallToolResult); - } - - private void ConfigureLogging(McpServerOptions options) - { - // We don't require that the handler be provided, as we always store the provided log level to the server. - var setLoggingLevelHandler = options.Capabilities?.Logging?.SetLoggingLevelHandler; - - // Apply filters to the handler - if (setLoggingLevelHandler is not null) - { - setLoggingLevelHandler = BuildFilterPipeline(setLoggingLevelHandler, options.Filters.SetLoggingLevelFilters); - } - - ServerCapabilities.Logging = new(); - ServerCapabilities.Logging.SetLoggingLevelHandler = setLoggingLevelHandler; - - RequestHandlers.Set( - RequestMethods.LoggingSetLevel, - (request, jsonRpcRequest, cancellationToken) => - { - // Store the provided level. - if (request is not null) - { - if (_loggingLevel is null) - { - Interlocked.CompareExchange(ref _loggingLevel, new(request.Level), null); - } - - _loggingLevel.Value = request.Level; - } - - // If a handler was provided, now delegate to it. - if (setLoggingLevelHandler is not null) - { - return InvokeHandlerAsync(setLoggingLevelHandler, request, jsonRpcRequest, cancellationToken); - } - - // Otherwise, consider it handled. - return new ValueTask(EmptyResult.Instance); - }, - McpJsonUtilities.JsonContext.Default.SetLevelRequestParams, - McpJsonUtilities.JsonContext.Default.EmptyResult); - } - - private ValueTask InvokeHandlerAsync( - McpRequestHandler handler, - TParams? args, - JsonRpcRequest jsonRpcRequest, - CancellationToken cancellationToken = default) - { - return _servicesScopePerRequest ? - InvokeScopedAsync(handler, args, jsonRpcRequest, cancellationToken) : - handler(new(new DestinationBoundMcpServer(this, jsonRpcRequest.Context?.RelatedTransport), jsonRpcRequest) { Params = args }, cancellationToken); - - async ValueTask InvokeScopedAsync( - McpRequestHandler handler, - TParams? args, - JsonRpcRequest jsonRpcRequest, - CancellationToken cancellationToken) - { - var scope = Services?.GetService()?.CreateAsyncScope(); - try - { - return await handler( - new RequestContext(new DestinationBoundMcpServer(this, jsonRpcRequest.Context?.RelatedTransport), jsonRpcRequest) - { - Services = scope?.ServiceProvider ?? Services, - Params = args - }, - cancellationToken).ConfigureAwait(false); - } - finally - { - if (scope is not null) - { - await scope.Value.DisposeAsync().ConfigureAwait(false); - } - } - } - } - - private void SetHandler( - string method, - McpRequestHandler handler, - JsonTypeInfo requestTypeInfo, - JsonTypeInfo responseTypeInfo) - { - RequestHandlers.Set(method, - (request, jsonRpcRequest, cancellationToken) => - InvokeHandlerAsync(handler, request, jsonRpcRequest, cancellationToken), - requestTypeInfo, responseTypeInfo); - } - - private static McpRequestHandler BuildFilterPipeline( - McpRequestHandler baseHandler, - List> filters, - McpRequestFilter? initialHandler = null, - McpRequestFilter? finalHandler = null) - { - var current = baseHandler; - - if (finalHandler is not null) - { - current = finalHandler(current); - } - - for (int i = filters.Count - 1; i >= 0; i--) - { - current = filters[i](current); - } - - if (initialHandler is not null) - { - current = initialHandler(current); - } - - return current; - } + /// + /// + /// This property contains identification information about the client that has connected to this server, + /// including its name and version. This information is provided by the client during initialization. + /// + /// + /// Server implementations can use this information for logging, tracking client versions, + /// or implementing client-specific behaviors. + /// + /// + public abstract Implementation? ClientInfo { get; } - private void UpdateEndpointNameWithClientInfo() - { - if (ClientInfo is null) - { - return; - } + /// + /// Gets the options used to construct this server. + /// + /// + /// These options define the server's capabilities, protocol version, and other configuration + /// settings that were used to initialize the server. + /// + public abstract McpServerOptions ServerOptions { get; } - _endpointName = $"{_serverOnlyEndpointName}, Client ({ClientInfo.Name} {ClientInfo.Version})"; - } + /// + /// Gets the service provider for the server. + /// + public abstract IServiceProvider? Services { get; } - /// Maps a to a . - internal static LoggingLevel ToLoggingLevel(LogLevel level) => - level switch - { - LogLevel.Trace => Protocol.LoggingLevel.Debug, - LogLevel.Debug => Protocol.LoggingLevel.Debug, - LogLevel.Information => Protocol.LoggingLevel.Info, - LogLevel.Warning => Protocol.LoggingLevel.Warning, - LogLevel.Error => Protocol.LoggingLevel.Error, - LogLevel.Critical => Protocol.LoggingLevel.Critical, - _ => Protocol.LoggingLevel.Emergency, - }; + /// Gets the last logging level set by the client, or if it's never been set. + public abstract LoggingLevel? LoggingLevel { get; } - [LoggerMessage(Level = LogLevel.Error, Message = "\"{ToolName}\" threw an unhandled exception.")] - private partial void ToolCallError(string toolName, Exception exception); + /// + /// Runs the server, listening for and handling client requests. + /// + public abstract Task RunAsync(CancellationToken cancellationToken = default); } diff --git a/src/ModelContextProtocol.Core/Server/McpServerExtensions.cs b/src/ModelContextProtocol.Core/Server/McpServerExtensions.cs index 97adcc307..79d545285 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerExtensions.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerExtensions.cs @@ -1,13 +1,9 @@ -using Microsoft.Extensions.AI; +using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; -using System.Collections.Concurrent; using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; -using System.Text; using System.Text.Json; -using System.Text.Json.Nodes; -using System.Text.Json.Serialization.Metadata; namespace ModelContextProtocol.Server; @@ -16,13 +12,6 @@ namespace ModelContextProtocol.Server; /// public static class McpServerExtensions { - /// - /// Caches request schemas for elicitation requests based on the type and serializer options. - /// - private static readonly ConditionalWeakTable> s_elicitResultSchemaCache = new(); - - private static Dictionary>? s_elicitAllowedProperties = null; - /// /// Requests to sample an LLM via the client using the specified request parameters. /// @@ -37,19 +26,10 @@ public static class McpServerExtensions /// It allows detailed control over sampling parameters including messages, system prompt, temperature, /// and token limits. /// + [Obsolete($"Use {nameof(McpServer)}.{nameof(McpServer.SampleAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public static ValueTask SampleAsync( this IMcpServer server, CreateMessageRequestParams request, CancellationToken cancellationToken = default) - { - Throw.IfNull(server); - ThrowIfSamplingUnsupported(server); - - return server.SendRequestAsync( - RequestMethods.SamplingCreateMessage, - request, - McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams, - McpJsonUtilities.JsonContext.Default.CreateMessageResult, - cancellationToken: cancellationToken); - } + => AsServerOrThrow(server).SampleAsync(request, cancellationToken); /// /// Requests to sample an LLM via the client using the provided chat messages and options. @@ -66,104 +46,11 @@ public static ValueTask SampleAsync( /// This method converts the provided chat messages into a format suitable for the sampling API, /// handling different content types such as text, images, and audio. /// - public static async Task SampleAsync( + [Obsolete($"Use {nameof(McpServer)}.{nameof(McpServer.SampleAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774] + public static Task SampleAsync( this IMcpServer server, IEnumerable messages, ChatOptions? options = default, CancellationToken cancellationToken = default) - { - Throw.IfNull(server); - Throw.IfNull(messages); - - StringBuilder? systemPrompt = null; - - if (options?.Instructions is { } instructions) - { - (systemPrompt ??= new()).Append(instructions); - } - - List samplingMessages = []; - foreach (var message in messages) - { - if (message.Role == ChatRole.System) - { - if (systemPrompt is null) - { - systemPrompt = new(); - } - else - { - systemPrompt.AppendLine(); - } - - systemPrompt.Append(message.Text); - continue; - } - - if (message.Role == ChatRole.User || message.Role == ChatRole.Assistant) - { - Role role = message.Role == ChatRole.User ? Role.User : Role.Assistant; - - foreach (var content in message.Contents) - { - switch (content) - { - case TextContent textContent: - samplingMessages.Add(new() - { - Role = role, - Content = new TextContentBlock { Text = textContent.Text }, - }); - break; - - case DataContent dataContent when dataContent.HasTopLevelMediaType("image") || dataContent.HasTopLevelMediaType("audio"): - samplingMessages.Add(new() - { - Role = role, - Content = dataContent.HasTopLevelMediaType("image") ? - new ImageContentBlock - { - MimeType = dataContent.MediaType, - Data = dataContent.Base64Data.ToString(), - } : - new AudioContentBlock - { - MimeType = dataContent.MediaType, - Data = dataContent.Base64Data.ToString(), - }, - }); - break; - } - } - } - } - - ModelPreferences? modelPreferences = null; - if (options?.ModelId is { } modelId) - { - modelPreferences = new() { Hints = [new() { Name = modelId }] }; - } - - var result = await server.SampleAsync(new() - { - Messages = samplingMessages, - MaxTokens = options?.MaxOutputTokens, - StopSequences = options?.StopSequences?.ToArray(), - SystemPrompt = systemPrompt?.ToString(), - Temperature = options?.Temperature, - ModelPreferences = modelPreferences, - }, cancellationToken).ConfigureAwait(false); - - AIContent? responseContent = result.Content.ToAIContent(); - - return new(new ChatMessage(result.Role is Role.User ? ChatRole.User : ChatRole.Assistant, responseContent is not null ? [responseContent] : [])) - { - ModelId = result.Model, - FinishReason = result.StopReason switch - { - "maxTokens" => ChatFinishReason.Length, - "endTurn" or "stopSequence" or _ => ChatFinishReason.Stop, - } - }; - } + => AsServerOrThrow(server).SampleAsync(messages, options, cancellationToken); /// /// Creates an wrapper that can be used to send sampling requests to the client. @@ -172,23 +59,16 @@ public static async Task SampleAsync( /// The that can be used to issue sampling requests to the client. /// is . /// The client does not support sampling. + [Obsolete($"Use {nameof(McpServer)}.{nameof(McpServer.AsSamplingChatClient)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public static IChatClient AsSamplingChatClient(this IMcpServer server) - { - Throw.IfNull(server); - ThrowIfSamplingUnsupported(server); - - return new SamplingChatClient(server); - } + => AsServerOrThrow(server).AsSamplingChatClient(); /// Gets an on which logged messages will be sent as notifications to the client. /// The server to wrap as an . /// An that can be used to log to the client.. + [Obsolete($"Use {nameof(McpServer)}.{nameof(McpServer.AsSamplingChatClient)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public static ILoggerProvider AsClientLoggerProvider(this IMcpServer server) - { - Throw.IfNull(server); - - return new ClientLoggerProvider(server); - } + => AsServerOrThrow(server).AsClientLoggerProvider(); /// /// Requests the client to list the roots it exposes. @@ -205,19 +85,10 @@ public static ILoggerProvider AsClientLoggerProvider(this IMcpServer server) /// navigated and accessed by the server. These resources might include file systems, databases, /// or other structured data sources that the client makes available through the protocol. /// + [Obsolete($"Use {nameof(McpServer)}.{nameof(McpServer.RequestRootsAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public static ValueTask RequestRootsAsync( this IMcpServer server, ListRootsRequestParams request, CancellationToken cancellationToken = default) - { - Throw.IfNull(server); - ThrowIfRootsUnsupported(server); - - return server.SendRequestAsync( - RequestMethods.RootsList, - request, - McpJsonUtilities.JsonContext.Default.ListRootsRequestParams, - McpJsonUtilities.JsonContext.Default.ListRootsResult, - cancellationToken: cancellationToken); - } + => AsServerOrThrow(server).RequestRootsAsync(request, cancellationToken); /// /// Requests additional information from the user via the client, allowing the server to elicit structured data. @@ -231,327 +102,30 @@ public static ValueTask RequestRootsAsync( /// /// This method requires the client to support the elicitation capability. /// + [Obsolete($"Use {nameof(McpServer)}.{nameof(McpServer.ElicitAsync)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public static ValueTask ElicitAsync( this IMcpServer server, ElicitRequestParams request, CancellationToken cancellationToken = default) - { - Throw.IfNull(server); - ThrowIfElicitationUnsupported(server); - - return server.SendRequestAsync( - RequestMethods.ElicitationCreate, - request, - McpJsonUtilities.JsonContext.Default.ElicitRequestParams, - McpJsonUtilities.JsonContext.Default.ElicitResult, - cancellationToken: cancellationToken); - } + => AsServerOrThrow(server).ElicitAsync(request, cancellationToken); - /// - /// Requests additional information from the user via the client, constructing a request schema from the - /// public serializable properties of and deserializing the response into . - /// - /// The type describing the expected input shape. Only primitive members are supported (string, number, boolean, enum). - /// The server initiating the request. - /// The message to present to the user. - /// Serializer options that influence property naming and deserialization. - /// The to monitor for cancellation requests. - /// An with the user's response, if accepted. - /// - /// Elicitation uses a constrained subset of JSON Schema and only supports strings, numbers/integers, booleans and string enums. - /// Unsupported member types are ignored when constructing the schema. - /// - public static async ValueTask> ElicitAsync( - this IMcpServer server, - string message, - JsonSerializerOptions? serializerOptions = null, - CancellationToken cancellationToken = default) + [MethodImpl(MethodImplOptions.AggressiveInlining)] +#pragma warning disable CS0618 // Type or member is obsolete + private static McpServer AsServerOrThrow(IMcpServer server, [CallerMemberName] string memberName = "") +#pragma warning restore CS0618 // Type or member is obsolete { - Throw.IfNull(server); - ThrowIfElicitationUnsupported(server); - - serializerOptions ??= McpJsonUtilities.DefaultOptions; - serializerOptions.MakeReadOnly(); - - var dict = s_elicitResultSchemaCache.GetValue(serializerOptions, _ => new()); - -#if NET - var schema = dict.GetOrAdd(typeof(T), static (t, s) => BuildRequestSchema(t, s), serializerOptions); -#else - var schema = dict.GetOrAdd(typeof(T), type => BuildRequestSchema(type, serializerOptions)); -#endif - - var request = new ElicitRequestParams - { - Message = message, - RequestedSchema = schema, - }; - - var raw = await server.ElicitAsync(request, cancellationToken).ConfigureAwait(false); - - if (!raw.IsAccepted || raw.Content is null) - { - return new ElicitResult { Action = raw.Action, Content = default }; - } - - var obj = new JsonObject(); - foreach (var kvp in raw.Content) + if (server is not McpServer mcpServer) { - obj[kvp.Key] = JsonNode.Parse(kvp.Value.GetRawText()); + ThrowInvalidSessionType(memberName); } - T? typed = JsonSerializer.Deserialize(obj, serializerOptions.GetTypeInfo()); - return new ElicitResult { Action = raw.Action, Content = typed }; - } + return mcpServer; - /// - /// Builds a request schema for elicitation based on the public serializable properties of . - /// - /// The type of the schema being built. - /// The serializer options to use. - /// The built request schema. - /// - private static ElicitRequestParams.RequestSchema BuildRequestSchema(Type type, JsonSerializerOptions serializerOptions) - { - var schema = new ElicitRequestParams.RequestSchema(); - var props = schema.Properties; - - JsonTypeInfo typeInfo = serializerOptions.GetTypeInfo(type); - - if (typeInfo.Kind != JsonTypeInfoKind.Object) - { - throw new McpException($"Type '{type.FullName}' is not supported for elicitation requests."); - } - - foreach (JsonPropertyInfo pi in typeInfo.Properties) - { - var def = CreatePrimitiveSchema(pi.PropertyType, serializerOptions); - props[pi.Name] = def; - } - - return schema; - } - - /// - /// Creates a primitive schema definition for the specified type, if supported. - /// - /// The type to create the schema for. - /// The serializer options to use. - /// The created primitive schema definition. - /// Thrown when the type is not supported. - private static ElicitRequestParams.PrimitiveSchemaDefinition CreatePrimitiveSchema(Type type, JsonSerializerOptions serializerOptions) - { - if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>)) - { - throw new McpException($"Type '{type.FullName}' is not a supported property type for elicitation requests. Nullable types are not supported."); - } - - var typeInfo = serializerOptions.GetTypeInfo(type); - - if (typeInfo.Kind != JsonTypeInfoKind.None) - { - throw new McpException($"Type '{type.FullName}' is not a supported property type for elicitation requests."); - } - - var jsonElement = AIJsonUtilities.CreateJsonSchema(type, serializerOptions: serializerOptions); - - if (!TryValidateElicitationPrimitiveSchema(jsonElement, type, out var error)) - { - throw new McpException(error); - } - - var primitiveSchemaDefinition = - jsonElement.Deserialize(McpJsonUtilities.JsonContext.Default.PrimitiveSchemaDefinition); - - if (primitiveSchemaDefinition is null) - throw new McpException($"Type '{type.FullName}' is not a supported property type for elicitation requests."); - - return primitiveSchemaDefinition; - } - - /// - /// Validate the produced schema strictly to the subset we support. We only accept an object schema - /// with a supported primitive type keyword and no additional unsupported keywords.Reject things like - /// {}, 'true', or schemas that include unrelated keywords(e.g.items, properties, patternProperties, etc.). - /// - /// The schema to validate. - /// The type of the schema being validated, just for reporting errors. - /// The error message, if validation fails. - /// - private static bool TryValidateElicitationPrimitiveSchema(JsonElement schema, Type type, - [NotNullWhen(false)] out string? error) - { - if (schema.ValueKind is not JsonValueKind.Object) - { - error = $"Schema generated for type '{type.FullName}' is invalid: expected an object schema."; - return false; - } - - if (!schema.TryGetProperty("type", out JsonElement typeProperty) - || typeProperty.ValueKind is not JsonValueKind.String) - { - error = $"Schema generated for type '{type.FullName}' is invalid: missing or invalid 'type' keyword."; - return false; - } - - var typeKeyword = typeProperty.GetString(); - - if (string.IsNullOrEmpty(typeKeyword)) - { - error = $"Schema generated for type '{type.FullName}' is invalid: empty 'type' value."; - return false; - } - - if (typeKeyword is not ("string" or "number" or "integer" or "boolean")) - { - error = $"Schema generated for type '{type.FullName}' is invalid: unsupported primitive type '{typeKeyword}'."; - return false; - } - - s_elicitAllowedProperties ??= new() - { - ["string"] = ["type", "title", "description", "minLength", "maxLength", "format", "enum", "enumNames"], - ["number"] = ["type", "title", "description", "minimum", "maximum"], - ["integer"] = ["type", "title", "description", "minimum", "maximum"], - ["boolean"] = ["type", "title", "description", "default"] - }; - - var allowed = s_elicitAllowedProperties[typeKeyword]; - - foreach (JsonProperty prop in schema.EnumerateObject()) - { - if (!allowed.Contains(prop.Name)) - { - error = $"The property '{type.FullName}.{prop.Name}' is not supported for elicitation."; - return false; - } - } - - error = string.Empty; - return true; - } - - private static void ThrowIfSamplingUnsupported(IMcpServer server) - { - if (server.ClientCapabilities?.Sampling is null) - { - if (server.ServerOptions.KnownClientInfo is not null) - { - throw new InvalidOperationException("Sampling is not supported in stateless mode."); - } - - throw new InvalidOperationException("Client does not support sampling."); - } - } - - private static void ThrowIfRootsUnsupported(IMcpServer server) - { - if (server.ClientCapabilities?.Roots is null) - { - if (server.ServerOptions.KnownClientInfo is not null) - { - throw new InvalidOperationException("Roots are not supported in stateless mode."); - } - - throw new InvalidOperationException("Client does not support roots."); - } - } - - private static void ThrowIfElicitationUnsupported(IMcpServer server) - { - if (server.ClientCapabilities?.Elicitation is null) - { - if (server.ServerOptions.KnownClientInfo is not null) - { - throw new InvalidOperationException("Elicitation is not supported in stateless mode."); - } - - throw new InvalidOperationException("Client does not support elicitation requests."); - } - } - - /// Provides an implementation that's implemented via client sampling. - private sealed class SamplingChatClient(IMcpServer server) : IChatClient - { - /// - public Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) => - server.SampleAsync(messages, options, cancellationToken); - - /// - async IAsyncEnumerable IChatClient.GetStreamingResponseAsync( - IEnumerable messages, ChatOptions? options, [EnumeratorCancellation] CancellationToken cancellationToken) - { - var response = await GetResponseAsync(messages, options, cancellationToken).ConfigureAwait(false); - foreach (var update in response.ToChatResponseUpdates()) - { - yield return update; - } - } - - /// - object? IChatClient.GetService(Type serviceType, object? serviceKey) - { - Throw.IfNull(serviceType); - - return - serviceKey is not null ? null : - serviceType.IsInstanceOfType(this) ? this : - serviceType.IsInstanceOfType(server) ? server : - null; - } - - /// - void IDisposable.Dispose() { } // nop - } - - /// - /// Provides an implementation for creating loggers - /// that send logging message notifications to the client for logged messages. - /// - private sealed class ClientLoggerProvider(IMcpServer server) : ILoggerProvider - { - /// - public ILogger CreateLogger(string categoryName) - { - Throw.IfNull(categoryName); - - return new ClientLogger(server, categoryName); - } - - /// - void IDisposable.Dispose() { } - - private sealed class ClientLogger(IMcpServer server, string categoryName) : ILogger - { - /// - public IDisposable? BeginScope(TState state) where TState : notnull => - null; - - /// - public bool IsEnabled(LogLevel logLevel) => - server?.LoggingLevel is { } loggingLevel && - McpServer.ToLoggingLevel(logLevel) >= loggingLevel; - - /// - public void Log(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) - { - if (!IsEnabled(logLevel)) - { - return; - } - - Throw.IfNull(formatter); - - Log(logLevel, formatter(state, exception)); - - void Log(LogLevel logLevel, string message) - { - _ = server.SendNotificationAsync(NotificationMethods.LoggingMessageNotification, new LoggingMessageNotificationParams - { - Level = McpServer.ToLoggingLevel(logLevel), - Data = JsonSerializer.SerializeToElement(message, McpJsonUtilities.JsonContext.Default.String), - Logger = categoryName, - }); - } - } - } + [DoesNotReturn] + [MethodImpl(MethodImplOptions.NoInlining)] + static void ThrowInvalidSessionType(string memberName) + => throw new InvalidOperationException( + $"Only arguments assignable to '{nameof(McpServer)}' are supported. " + + $"Prefer using '{nameof(McpServer)}.{memberName}' instead, as " + + $"'{nameof(McpServerExtensions)}.{memberName}' is obsolete and will be " + + $"removed in the future."); } } diff --git a/src/ModelContextProtocol.Core/Server/McpServerFactory.cs b/src/ModelContextProtocol.Core/Server/McpServerFactory.cs index 50d4188b5..00ecd8b13 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerFactory.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerFactory.cs @@ -1,4 +1,4 @@ -using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; namespace ModelContextProtocol.Server; @@ -10,6 +10,7 @@ namespace ModelContextProtocol.Server; /// This is the recommended way to create instances. /// The factory handles proper initialization of server instances with the required dependencies. /// +[Obsolete($"Use {nameof(McpServer)}.{nameof(McpServer.Create)} instead.")] // See: https://github.com/modelcontextprotocol/csharp-sdk/issues/774 public static class McpServerFactory { /// @@ -27,10 +28,5 @@ public static IMcpServer Create( McpServerOptions serverOptions, ILoggerFactory? loggerFactory = null, IServiceProvider? serviceProvider = null) - { - Throw.IfNull(transport); - Throw.IfNull(serverOptions); - - return new McpServer(transport, serverOptions, loggerFactory, serviceProvider); - } + => McpServer.Create(transport, serverOptions, loggerFactory, serviceProvider); } diff --git a/src/ModelContextProtocol.Core/Server/McpServerImpl.cs b/src/ModelContextProtocol.Core/Server/McpServerImpl.cs new file mode 100644 index 000000000..1ece8af23 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/McpServerImpl.cs @@ -0,0 +1,750 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using ModelContextProtocol.Protocol; +using System.Runtime.CompilerServices; +using System.Text.Json.Serialization.Metadata; + +namespace ModelContextProtocol.Server; + +// TODO: Fix merge conflicts in this file. + +/// +internal sealed partial class McpServerImpl : McpServer +{ + internal static Implementation DefaultImplementation { get; } = new() + { + Name = AssemblyNameHelper.DefaultAssemblyName.Name ?? nameof(McpServer), + Version = AssemblyNameHelper.DefaultAssemblyName.Version?.ToString() ?? "1.0.0", + }; + + private readonly ILogger _logger; + private readonly ITransport _sessionTransport; + private readonly bool _servicesScopePerRequest; + private readonly List _disposables = []; + private readonly NotificationHandlers _notificationHandlers; + private readonly RequestHandlers _requestHandlers; + private readonly McpSessionHandler _sessionHandler; + private readonly SemaphoreSlim _disposeLock = new(1, 1); + + private ClientCapabilities? _clientCapabilities; + private Implementation? _clientInfo; + + private readonly string _serverOnlyEndpointName; + private string _endpointName; + private int _started; + + private bool _disposed; + + /// Holds a boxed value for the server. + /// + /// Initialized to non-null the first time SetLevel is used. This is stored as a strong box + /// rather than a nullable to be able to manipulate it atomically. + /// + private StrongBox? _loggingLevel; + + /// + /// Creates a new instance of . + /// + /// Transport to use for the server representing an already-established session. + /// Configuration options for this server, including capabilities. + /// Make sure to accurately reflect exactly what capabilities the server supports and does not support. + /// Logger factory to use for logging + /// Optional service provider to use for dependency injection + /// The server was incorrectly configured. + public McpServerImpl(ITransport transport, McpServerOptions options, ILoggerFactory? loggerFactory, IServiceProvider? serviceProvider) + { + Throw.IfNull(transport); + Throw.IfNull(options); + + options ??= new(); + + _sessionTransport = transport; + ServerOptions = options; + Services = serviceProvider; + _serverOnlyEndpointName = $"Server ({options.ServerInfo?.Name ?? DefaultImplementation.Name} {options.ServerInfo?.Version ?? DefaultImplementation.Version})"; + _endpointName = _serverOnlyEndpointName; + _servicesScopePerRequest = options.ScopeRequests; + _logger = loggerFactory?.CreateLogger() ?? NullLogger.Instance; + + _clientInfo = options.KnownClientInfo; + UpdateEndpointNameWithClientInfo(); + + _notificationHandlers = new(); + _requestHandlers = []; + + // Configure all request handlers based on the supplied options. + ServerCapabilities = new(); + ConfigureInitialize(options); + ConfigureTools(options); + ConfigurePrompts(options); + ConfigureResources(options); + ConfigureLogging(options); + ConfigureCompletion(options); + ConfigureExperimental(options); + ConfigurePing(); + + // Register any notification handlers that were provided. + if (options.Capabilities?.NotificationHandlers is { } notificationHandlers) + { + _notificationHandlers.RegisterRange(notificationHandlers); + } + + // Now that everything has been configured, subscribe to any necessary notifications. + if (transport is not StreamableHttpServerTransport streamableHttpTransport || streamableHttpTransport.Stateless is false) + { + Register(ServerOptions.Capabilities?.Tools?.ToolCollection, NotificationMethods.ToolListChangedNotification); + Register(ServerOptions.Capabilities?.Prompts?.PromptCollection, NotificationMethods.PromptListChangedNotification); + Register(ServerOptions.Capabilities?.Resources?.ResourceCollection, NotificationMethods.ResourceListChangedNotification); + + void Register(McpServerPrimitiveCollection? collection, string notificationMethod) + where TPrimitive : IMcpServerPrimitive + { + if (collection is not null) + { + EventHandler changed = (sender, e) => _ = this.SendNotificationAsync(notificationMethod); + collection.Changed += changed; + _disposables.Add(() => collection.Changed -= changed); + } + } + } + + // And initialize the session. + _sessionHandler = new McpSessionHandler(isServer: true, _sessionTransport, _endpointName!, _requestHandlers, _notificationHandlers, _logger); + } + + /// + public override string? SessionId => _sessionTransport.SessionId; + + /// + public ServerCapabilities ServerCapabilities { get; } = new(); + + /// + public override ClientCapabilities? ClientCapabilities => _clientCapabilities; + + /// + public override Implementation? ClientInfo => _clientInfo; + + /// + public override McpServerOptions ServerOptions { get; } + + /// + public override IServiceProvider? Services { get; } + + /// + public override LoggingLevel? LoggingLevel => _loggingLevel?.Value; + + /// + public override async Task RunAsync(CancellationToken cancellationToken = default) + { + if (Interlocked.Exchange(ref _started, 1) != 0) + { + throw new InvalidOperationException($"{nameof(RunAsync)} must only be called once."); + } + + try + { + await _sessionHandler.ProcessMessagesAsync(cancellationToken).ConfigureAwait(false); + } + finally + { + await DisposeAsync().ConfigureAwait(false); + } + } + + + /// + public override Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) + => _sessionHandler.SendRequestAsync(request, cancellationToken); + + /// + public override Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) + => _sessionHandler.SendMessageAsync(message, cancellationToken); + + /// + public override IAsyncDisposable RegisterNotificationHandler(string method, Func handler) + => _sessionHandler.RegisterNotificationHandler(method, handler); + + /// + public override async ValueTask DisposeAsync() + { + using var _ = await _disposeLock.LockAsync().ConfigureAwait(false); + + if (_disposed) + { + return; + } + + _disposed = true; + + _disposables.ForEach(d => d()); + await _sessionHandler.DisposeAsync().ConfigureAwait(false); + } + + private void ConfigurePing() + { + SetHandler(RequestMethods.Ping, + async (request, _) => new PingResult(), + McpJsonUtilities.JsonContext.Default.JsonNode, + McpJsonUtilities.JsonContext.Default.PingResult); + } + + private void ConfigureInitialize(McpServerOptions options) + { + _requestHandlers.Set(RequestMethods.Initialize, + async (request, _, _) => + { + _clientCapabilities = request?.Capabilities ?? new(); + _clientInfo = request?.ClientInfo; + + // Use the ClientInfo to update the session EndpointName for logging. + UpdateEndpointNameWithClientInfo(); + _sessionHandler.EndpointName = _endpointName; + + // Negotiate a protocol version. If the server options provide one, use that. + // Otherwise, try to use whatever the client requested as long as it's supported. + // If it's not supported, fall back to the latest supported version. + string? protocolVersion = options.ProtocolVersion; + if (protocolVersion is null) + { + protocolVersion = request?.ProtocolVersion is string clientProtocolVersion && McpSessionHandler.SupportedProtocolVersions.Contains(clientProtocolVersion) ? + clientProtocolVersion : + McpSessionHandler.LatestProtocolVersion; + } + + return new InitializeResult + { + ProtocolVersion = protocolVersion, + Instructions = options.ServerInstructions, + ServerInfo = options.ServerInfo ?? DefaultImplementation, + Capabilities = ServerCapabilities ?? new(), + }; + }, + McpJsonUtilities.JsonContext.Default.InitializeRequestParams, + McpJsonUtilities.JsonContext.Default.InitializeResult); + } + + private void ConfigureCompletion(McpServerOptions options) + { + if (options.Capabilities?.Completions is not { } completionsCapability) + { + return; + } + + var completeHandler = completionsCapability.CompleteHandler ?? (static async (_, __) => new CompleteResult()); + completeHandler = BuildFilterPipeline(completeHandler, options.Filters.CompleteFilters); + + ServerCapabilities.Completions = new() + { + CompleteHandler = completeHandler + }; + + SetHandler( + RequestMethods.CompletionComplete, + ServerCapabilities.Completions.CompleteHandler, + McpJsonUtilities.JsonContext.Default.CompleteRequestParams, + McpJsonUtilities.JsonContext.Default.CompleteResult); + } + + private void ConfigureExperimental(McpServerOptions options) + { + ServerCapabilities.Experimental = options.Capabilities?.Experimental; + } + + private void ConfigureResources(McpServerOptions options) + { + if (options.Capabilities?.Resources is not { } resourcesCapability) + { + return; + } + + ServerCapabilities.Resources = new(); + + var listResourcesHandler = resourcesCapability.ListResourcesHandler ?? (static async (_, __) => new ListResourcesResult()); + var listResourceTemplatesHandler = resourcesCapability.ListResourceTemplatesHandler ?? (static async (_, __) => new ListResourceTemplatesResult()); + var readResourceHandler = resourcesCapability.ReadResourceHandler ?? (static async (request, _) => throw new McpException($"Unknown resource URI: '{request.Params?.Uri}'", McpErrorCode.InvalidParams)); + var subscribeHandler = resourcesCapability.SubscribeToResourcesHandler ?? (static async (_, __) => new EmptyResult()); + var unsubscribeHandler = resourcesCapability.UnsubscribeFromResourcesHandler ?? (static async (_, __) => new EmptyResult()); + var resources = resourcesCapability.ResourceCollection; + var listChanged = resourcesCapability.ListChanged; + var subscribe = resourcesCapability.Subscribe; + + // Handle resources provided via DI. + if (resources is { IsEmpty: false }) + { + var originalListResourcesHandler = listResourcesHandler; + listResourcesHandler = async (request, cancellationToken) => + { + ListResourcesResult result = originalListResourcesHandler is not null ? + await originalListResourcesHandler(request, cancellationToken).ConfigureAwait(false) : + new(); + + if (request.Params?.Cursor is null) + { + foreach (var r in resources) + { + if (r.ProtocolResource is { } resource) + { + result.Resources.Add(resource); + } + } + } + + return result; + }; + + var originalListResourceTemplatesHandler = listResourceTemplatesHandler; + listResourceTemplatesHandler = async (request, cancellationToken) => + { + ListResourceTemplatesResult result = originalListResourceTemplatesHandler is not null ? + await originalListResourceTemplatesHandler(request, cancellationToken).ConfigureAwait(false) : + new(); + + if (request.Params?.Cursor is null) + { + foreach (var rt in resources) + { + if (rt.IsTemplated) + { + result.ResourceTemplates.Add(rt.ProtocolResourceTemplate); + } + } + } + + return result; + }; + + // Synthesize read resource handler, which covers both resources and resource templates. + var originalReadResourceHandler = readResourceHandler; + readResourceHandler = async (request, cancellationToken) => + { + if (request.MatchedPrimitive is McpServerResource matchedResource) + { + if (await matchedResource.ReadAsync(request, cancellationToken).ConfigureAwait(false) is { } result) + { + return result; + } + } + + return await originalReadResourceHandler(request, cancellationToken).ConfigureAwait(false); + }; + + listChanged = true; + + // TODO: Implement subscribe/unsubscribe logic for resource and resource template collections. + // subscribe = true; + } + + listResourcesHandler = BuildFilterPipeline(listResourcesHandler, options.Filters.ListResourcesFilters); + listResourceTemplatesHandler = BuildFilterPipeline(listResourceTemplatesHandler, options.Filters.ListResourceTemplatesFilters); + readResourceHandler = BuildFilterPipeline(readResourceHandler, options.Filters.ReadResourceFilters, handler => + async (request, cancellationToken) => + { + // Initial handler that sets MatchedPrimitive + if (request.Params?.Uri is { } uri && resources is not null) + { + // First try an O(1) lookup by exact match. + if (resources.TryGetPrimitive(uri, out var resource)) + { + request.MatchedPrimitive = resource; + } + else + { + // Fall back to an O(N) lookup, trying to match against each URI template. + // The number of templates is controlled by the server developer, and the number is expected to be + // not terribly large. If that changes, this can be tweaked to enable a more efficient lookup. + foreach (var resourceTemplate in resources) + { + // Check if this template would handle the request by testing if ReadAsync would succeed + if (resourceTemplate.IsTemplated) + { + // This is a simplified check - a more robust implementation would match the URI pattern + // For now, we'll let the actual handler attempt the match + request.MatchedPrimitive = resourceTemplate; + break; + } + } + } + } + + return await handler(request, cancellationToken).ConfigureAwait(false); + }); + subscribeHandler = BuildFilterPipeline(subscribeHandler, options.Filters.SubscribeToResourcesFilters); + unsubscribeHandler = BuildFilterPipeline(unsubscribeHandler, options.Filters.UnsubscribeFromResourcesFilters); + + ServerCapabilities.Resources.ListResourcesHandler = listResourcesHandler; + ServerCapabilities.Resources.ListResourceTemplatesHandler = listResourceTemplatesHandler; + ServerCapabilities.Resources.ReadResourceHandler = readResourceHandler; + ServerCapabilities.Resources.ResourceCollection = resources; + ServerCapabilities.Resources.SubscribeToResourcesHandler = subscribeHandler; + ServerCapabilities.Resources.UnsubscribeFromResourcesHandler = unsubscribeHandler; + ServerCapabilities.Resources.ListChanged = listChanged; + ServerCapabilities.Resources.Subscribe = subscribe; + + SetHandler( + RequestMethods.ResourcesList, + listResourcesHandler, + McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, + McpJsonUtilities.JsonContext.Default.ListResourcesResult); + + SetHandler( + RequestMethods.ResourcesTemplatesList, + listResourceTemplatesHandler, + McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, + McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult); + + SetHandler( + RequestMethods.ResourcesRead, + readResourceHandler, + McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, + McpJsonUtilities.JsonContext.Default.ReadResourceResult); + + SetHandler( + RequestMethods.ResourcesSubscribe, + subscribeHandler, + McpJsonUtilities.JsonContext.Default.SubscribeRequestParams, + McpJsonUtilities.JsonContext.Default.EmptyResult); + + SetHandler( + RequestMethods.ResourcesUnsubscribe, + unsubscribeHandler, + McpJsonUtilities.JsonContext.Default.UnsubscribeRequestParams, + McpJsonUtilities.JsonContext.Default.EmptyResult); + } + + private void ConfigurePrompts(McpServerOptions options) + { + if (options.Capabilities?.Prompts is not { } promptsCapability) + { + return; + } + + ServerCapabilities.Prompts = new(); + + var listPromptsHandler = promptsCapability.ListPromptsHandler ?? (static async (_, __) => new ListPromptsResult()); + var getPromptHandler = promptsCapability.GetPromptHandler ?? (static async (request, _) => throw new McpException($"Unknown prompt: '{request.Params?.Name}'", McpErrorCode.InvalidParams)); + var prompts = promptsCapability.PromptCollection; + var listChanged = promptsCapability.ListChanged; + + // Handle tools provided via DI by augmenting the handlers to incorporate them. + if (prompts is { IsEmpty: false }) + { + var originalListPromptsHandler = listPromptsHandler; + listPromptsHandler = async (request, cancellationToken) => + { + ListPromptsResult result = originalListPromptsHandler is not null ? + await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(false) : + new(); + + if (request.Params?.Cursor is null) + { + foreach (var p in prompts) + { + result.Prompts.Add(p.ProtocolPrompt); + } + } + + return result; + }; + + var originalGetPromptHandler = getPromptHandler; + getPromptHandler = (request, cancellationToken) => + { + if (request.MatchedPrimitive is McpServerPrompt prompt) + { + return prompt.GetAsync(request, cancellationToken); + } + + return originalGetPromptHandler(request, cancellationToken); + }; + + listChanged = true; + } + + listPromptsHandler = BuildFilterPipeline(listPromptsHandler, options.Filters.ListPromptsFilters); + getPromptHandler = BuildFilterPipeline(getPromptHandler, options.Filters.GetPromptFilters, handler => + (request, cancellationToken) => + { + // Initial handler that sets MatchedPrimitive + if (request.Params?.Name is { } promptName && prompts is not null && + prompts.TryGetPrimitive(promptName, out var prompt)) + { + request.MatchedPrimitive = prompt; + } + + return handler(request, cancellationToken); + }); + + ServerCapabilities.Prompts.ListPromptsHandler = listPromptsHandler; + ServerCapabilities.Prompts.GetPromptHandler = getPromptHandler; + ServerCapabilities.Prompts.PromptCollection = prompts; + ServerCapabilities.Prompts.ListChanged = listChanged; + + SetHandler( + RequestMethods.PromptsList, + listPromptsHandler, + McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, + McpJsonUtilities.JsonContext.Default.ListPromptsResult); + + SetHandler( + RequestMethods.PromptsGet, + getPromptHandler, + McpJsonUtilities.JsonContext.Default.GetPromptRequestParams, + McpJsonUtilities.JsonContext.Default.GetPromptResult); + } + + private void ConfigureTools(McpServerOptions options) + { + if (options.Capabilities?.Tools is not { } toolsCapability) + { + return; + } + + ServerCapabilities.Tools = new(); + + var listToolsHandler = toolsCapability.ListToolsHandler ?? (static async (_, __) => new ListToolsResult()); + var callToolHandler = toolsCapability.CallToolHandler ?? (static async (request, _) => throw new McpException($"Unknown tool: '{request.Params?.Name}'", McpErrorCode.InvalidParams)); + var tools = toolsCapability.ToolCollection; + var listChanged = toolsCapability.ListChanged; + + // Handle tools provided via DI by augmenting the handlers to incorporate them. + if (tools is { IsEmpty: false }) + { + var originalListToolsHandler = listToolsHandler; + listToolsHandler = async (request, cancellationToken) => + { + ListToolsResult result = originalListToolsHandler is not null ? + await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) : + new(); + + if (request.Params?.Cursor is null) + { + foreach (var t in tools) + { + result.Tools.Add(t.ProtocolTool); + } + } + + return result; + }; + + var originalCallToolHandler = callToolHandler; + callToolHandler = (request, cancellationToken) => + { + if (request.MatchedPrimitive is McpServerTool tool) + { + return tool.InvokeAsync(request, cancellationToken); + } + + return originalCallToolHandler(request, cancellationToken); + }; + + listChanged = true; + } + + listToolsHandler = BuildFilterPipeline(listToolsHandler, options.Filters.ListToolsFilters); + callToolHandler = BuildFilterPipeline(callToolHandler, options.Filters.CallToolFilters, handler => + (request, cancellationToken) => + { + // Initial handler that sets MatchedPrimitive + if (request.Params?.Name is { } toolName && tools is not null && + tools.TryGetPrimitive(toolName, out var tool)) + { + request.MatchedPrimitive = tool; + } + + return handler(request, cancellationToken); + }, handler => + async (request, cancellationToken) => + { + // Final handler that provides exception handling only for tool execution + // Only wrap tool execution in try-catch, not tool resolution + if (request.MatchedPrimitive is McpServerTool) + { + try + { + return await handler(request, cancellationToken).ConfigureAwait(false); + } + catch (Exception e) when (e is not OperationCanceledException) + { + ToolCallError(request.Params?.Name ?? string.Empty, e); + + string errorMessage = e is McpException ? + $"An error occurred invoking '{request.Params?.Name}': {e.Message}" : + $"An error occurred invoking '{request.Params?.Name}'."; + + return new() + { + IsError = true, + Content = [new TextContentBlock { Text = errorMessage }], + }; + } + } + else + { + // For unmatched tools, let exceptions bubble up as protocol errors + return await handler(request, cancellationToken).ConfigureAwait(false); + } + }); + + ServerCapabilities.Tools.ListToolsHandler = listToolsHandler; + ServerCapabilities.Tools.CallToolHandler = callToolHandler; + ServerCapabilities.Tools.ToolCollection = tools; + ServerCapabilities.Tools.ListChanged = listChanged; + + SetHandler( + RequestMethods.ToolsList, + listToolsHandler, + McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, + McpJsonUtilities.JsonContext.Default.ListToolsResult); + + SetHandler( + RequestMethods.ToolsCall, + callToolHandler, + McpJsonUtilities.JsonContext.Default.CallToolRequestParams, + McpJsonUtilities.JsonContext.Default.CallToolResult); + } + + private void ConfigureLogging(McpServerOptions options) + { + // We don't require that the handler be provided, as we always store the provided log level to the server. + var setLoggingLevelHandler = options.Capabilities?.Logging?.SetLoggingLevelHandler; + + // Apply filters to the handler + if (setLoggingLevelHandler is not null) + { + setLoggingLevelHandler = BuildFilterPipeline(setLoggingLevelHandler, options.Filters.SetLoggingLevelFilters); + } + + ServerCapabilities.Logging = new(); + ServerCapabilities.Logging.SetLoggingLevelHandler = setLoggingLevelHandler; + + _requestHandlers.Set( + RequestMethods.LoggingSetLevel, + (request, jsonRpcRequest, cancellationToken) => + { + // Store the provided level. + if (request is not null) + { + if (_loggingLevel is null) + { + Interlocked.CompareExchange(ref _loggingLevel, new(request.Level), null); + } + + _loggingLevel.Value = request.Level; + } + + // If a handler was provided, now delegate to it. + if (setLoggingLevelHandler is not null) + { + return InvokeHandlerAsync(setLoggingLevelHandler, request, jsonRpcRequest, cancellationToken); + } + + // Otherwise, consider it handled. + return new ValueTask(EmptyResult.Instance); + }, + McpJsonUtilities.JsonContext.Default.SetLevelRequestParams, + McpJsonUtilities.JsonContext.Default.EmptyResult); + } + + private ValueTask InvokeHandlerAsync( + McpRequestHandler handler, + TParams? args, + JsonRpcRequest jsonRpcRequest, + CancellationToken cancellationToken = default) + { + return _servicesScopePerRequest ? + InvokeScopedAsync(handler, args, jsonRpcRequest, cancellationToken) : + handler(new(new DestinationBoundMcpServer(this, jsonRpcRequest.Context?.RelatedTransport), jsonRpcRequest) { Params = args }, cancellationToken); + + async ValueTask InvokeScopedAsync( + McpRequestHandler handler, + TParams? args, + JsonRpcRequest jsonRpcRequest, + CancellationToken cancellationToken) + { + var scope = Services?.GetService()?.CreateAsyncScope(); + try + { + return await handler( + new RequestContext(new DestinationBoundMcpServer(this, jsonRpcRequest.Context?.RelatedTransport), jsonRpcRequest) + { + Services = scope?.ServiceProvider ?? Services, + Params = args + }, + cancellationToken).ConfigureAwait(false); + } + finally + { + if (scope is not null) + { + await scope.Value.DisposeAsync().ConfigureAwait(false); + } + } + } + } + + private void SetHandler( + string method, + McpRequestHandler handler, + JsonTypeInfo requestTypeInfo, + JsonTypeInfo responseTypeInfo) + { + _requestHandlers.Set(method, + (request, jsonRpcRequest, cancellationToken) => + InvokeHandlerAsync(handler, request, jsonRpcRequest, cancellationToken), + requestTypeInfo, responseTypeInfo); + } + + private static McpRequestHandler BuildFilterPipeline( + McpRequestHandler baseHandler, + List> filters, + McpRequestFilter? initialHandler = null, + McpRequestFilter? finalHandler = null) + { + var current = baseHandler; + + if (finalHandler is not null) + { + current = finalHandler(current); + } + + for (int i = filters.Count - 1; i >= 0; i--) + { + current = filters[i](current); + } + + if (initialHandler is not null) + { + current = initialHandler(current); + } + + return current; + } + + private void UpdateEndpointNameWithClientInfo() + { + if (ClientInfo is null) + { + return; + } + + _endpointName = $"{_serverOnlyEndpointName}, Client ({ClientInfo.Name} {ClientInfo.Version})"; + } + + /// Maps a to a . + internal static LoggingLevel ToLoggingLevel(LogLevel level) => + level switch + { + LogLevel.Trace => Protocol.LoggingLevel.Debug, + LogLevel.Debug => Protocol.LoggingLevel.Debug, + LogLevel.Information => Protocol.LoggingLevel.Info, + LogLevel.Warning => Protocol.LoggingLevel.Warning, + LogLevel.Error => Protocol.LoggingLevel.Error, + LogLevel.Critical => Protocol.LoggingLevel.Critical, + _ => Protocol.LoggingLevel.Emergency, + }; + + [LoggerMessage(Level = LogLevel.Error, Message = "\"{ToolName}\" threw an unhandled exception.")] + private partial void ToolCallError(string toolName, Exception exception); +} diff --git a/src/ModelContextProtocol.Core/Server/McpServerPrompt.cs b/src/ModelContextProtocol.Core/Server/McpServerPrompt.cs index 746278791..a7fa0e242 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerPrompt.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerPrompt.cs @@ -15,8 +15,8 @@ namespace ModelContextProtocol.Server; /// is an abstract base class that represents an MCP prompt for use in the server (as opposed /// to , which provides the protocol representation of a prompt, and , which /// provides a client-side representation of a prompt). Instances of can be added into a -/// to be picked up automatically when is used to create -/// an , or added into a . +/// to be picked up automatically when is used to create +/// an , or added into a . /// /// /// Most commonly, instances are created using the static methods. @@ -34,7 +34,7 @@ namespace ModelContextProtocol.Server; /// /// /// parameters are automatically bound to a provided by the -/// and that respects any s sent by the client for this operation's +/// and that respects any s sent by the client for this operation's /// . /// /// @@ -45,7 +45,7 @@ namespace ModelContextProtocol.Server; /// /// /// -/// parameters are bound directly to the instance associated +/// parameters are bound directly to the instance associated /// with this request's . Such parameters may be used to understand /// what server is being used to process the request, and to interact with the client issuing the request to that server. /// @@ -210,7 +210,7 @@ public static McpServerPrompt Create( /// is . /// /// Unlike the other overloads of Create, the created by - /// does not provide all of the special parameter handling for MCP-specific concepts, like . + /// does not provide all of the special parameter handling for MCP-specific concepts, like . /// public static McpServerPrompt Create( AIFunction function, diff --git a/src/ModelContextProtocol.Core/Server/McpServerPromptAttribute.cs b/src/ModelContextProtocol.Core/Server/McpServerPromptAttribute.cs index c71e969db..ac9e247f6 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerPromptAttribute.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerPromptAttribute.cs @@ -25,7 +25,7 @@ namespace ModelContextProtocol.Server; /// /// /// parameters are automatically bound to a provided by the -/// and that respects any s sent by the client for this operation's +/// and that respects any s sent by the client for this operation's /// . /// /// @@ -36,7 +36,7 @@ namespace ModelContextProtocol.Server; /// /// /// -/// parameters are bound directly to the instance associated +/// parameters are bound directly to the instance associated /// with this request's . Such parameters may be used to understand /// what server is being used to process the request, and to interact with the client issuing the request to that server. /// diff --git a/src/ModelContextProtocol.Core/Server/McpServerResource.cs b/src/ModelContextProtocol.Core/Server/McpServerResource.cs index 9508cda0a..2a43e3349 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerResource.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerResource.cs @@ -13,7 +13,7 @@ namespace ModelContextProtocol.Server; /// is an abstract base class that represents an MCP resource for use in the server (as opposed /// to or , which provide the protocol representations of a resource). Instances of /// can be added into a to be picked up automatically when -/// is used to create an , or added into a . +/// is used to create an , or added into a . /// /// /// Most commonly, instances are created using the static methods. @@ -35,7 +35,7 @@ namespace ModelContextProtocol.Server; /// /// /// parameters are automatically bound to a provided by the -/// and that respects any s sent by the client for this operation's +/// and that respects any s sent by the client for this operation's /// . /// /// @@ -46,7 +46,7 @@ namespace ModelContextProtocol.Server; /// /// /// -/// parameters are bound directly to the instance associated +/// parameters are bound directly to the instance associated /// with this request's . Such parameters may be used to understand /// what server is being used to process the request, and to interact with the client issuing the request to that server. /// @@ -232,7 +232,7 @@ public static McpServerResource Create( /// is . /// /// Unlike the other overloads of Create, the created by - /// does not provide all of the special parameter handling for MCP-specific concepts, like . + /// does not provide all of the special parameter handling for MCP-specific concepts, like . /// public static McpServerResource Create( AIFunction function, diff --git a/src/ModelContextProtocol.Core/Server/McpServerResourceAttribute.cs b/src/ModelContextProtocol.Core/Server/McpServerResourceAttribute.cs index bc2f138f0..66c593e47 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerResourceAttribute.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerResourceAttribute.cs @@ -23,7 +23,7 @@ namespace ModelContextProtocol.Server; /// /// /// parameters are automatically bound to a provided by the -/// and that respects any s sent by the client for this operation's +/// and that respects any s sent by the client for this operation's /// . /// /// @@ -34,7 +34,7 @@ namespace ModelContextProtocol.Server; /// /// /// -/// parameters are bound directly to the instance associated +/// parameters are bound directly to the instance associated /// with this request's . Such parameters may be used to understand /// what server is being used to process the request, and to interact with the client issuing the request to that server. /// diff --git a/src/ModelContextProtocol.Core/Server/McpServerTool.cs b/src/ModelContextProtocol.Core/Server/McpServerTool.cs index baddf88f8..4136f5913 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerTool.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerTool.cs @@ -15,8 +15,8 @@ namespace ModelContextProtocol.Server; /// is an abstract base class that represents an MCP tool for use in the server (as opposed /// to , which provides the protocol representation of a tool, and , which /// provides a client-side representation of a tool). Instances of can be added into a -/// to be picked up automatically when is used to create -/// an , or added into a . +/// to be picked up automatically when is used to create +/// an , or added into a . /// /// /// Most commonly, instances are created using the static methods. @@ -35,7 +35,7 @@ namespace ModelContextProtocol.Server; /// /// /// parameters are automatically bound to a provided by the -/// and that respects any s sent by the client for this operation's +/// and that respects any s sent by the client for this operation's /// . The parameter is not included in the generated JSON schema. /// /// @@ -47,7 +47,7 @@ namespace ModelContextProtocol.Server; /// /// /// -/// parameters are not included in the JSON schema and are bound directly to the +/// parameters are not included in the JSON schema and are bound directly to the /// instance associated with this request's . Such parameters may be used to understand /// what server is being used to process the request, and to interact with the client issuing the request to that server. /// @@ -212,7 +212,7 @@ public static McpServerTool Create( /// is . /// /// Unlike the other overloads of Create, the created by - /// does not provide all of the special parameter handling for MCP-specific concepts, like . + /// does not provide all of the special parameter handling for MCP-specific concepts, like . /// public static McpServerTool Create( AIFunction function, diff --git a/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs b/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs index d4ea9eb75..7d5bf488b 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerToolAttribute.cs @@ -26,7 +26,7 @@ namespace ModelContextProtocol.Server; /// /// /// parameters are automatically bound to a provided by the -/// and that respects any s sent by the client for this operation's +/// and that respects any s sent by the client for this operation's /// . The parameter is not included in the generated JSON schema. /// /// @@ -38,7 +38,7 @@ namespace ModelContextProtocol.Server; /// /// /// -/// parameters are not included in the JSON schema and are bound directly to the +/// parameters are not included in the JSON schema and are bound directly to the /// instance associated with this request's . Such parameters may be used to understand /// what server is being used to process the request, and to interact with the client issuing the request to that server. /// diff --git a/src/ModelContextProtocol.Core/Server/RequestContext.cs b/src/ModelContextProtocol.Core/Server/RequestContext.cs index 1141a815a..f75cea80b 100644 --- a/src/ModelContextProtocol.Core/Server/RequestContext.cs +++ b/src/ModelContextProtocol.Core/Server/RequestContext.cs @@ -15,7 +15,7 @@ namespace ModelContextProtocol.Server; public sealed class RequestContext { /// The server with which this instance is associated. - private IMcpServer _server; + private McpServer _server; private IDictionary? _items; @@ -24,7 +24,7 @@ public sealed class RequestContext /// /// The server with which this instance is associated. /// The JSON-RPC request associated with this context. - public RequestContext(IMcpServer server, JsonRpcRequest jsonRpcRequest) + public RequestContext(McpServer server, JsonRpcRequest jsonRpcRequest) { Throw.IfNull(server); Throw.IfNull(jsonRpcRequest); @@ -36,7 +36,7 @@ public RequestContext(IMcpServer server, JsonRpcRequest jsonRpcRequest) } /// Gets or sets the server with which this instance is associated. - public IMcpServer Server + public McpServer Server { get => _server; set @@ -63,10 +63,10 @@ public IMcpServer Server /// Gets or sets the services associated with this request. /// - /// This may not be the same instance stored in + /// This may not be the same instance stored in /// if was true, in which case this /// might be a scoped derived from the server's - /// . + /// . /// public IServiceProvider? Services { get; set; } diff --git a/src/ModelContextProtocol.Core/Server/RequestServiceProvider.cs b/src/ModelContextProtocol.Core/Server/RequestServiceProvider.cs index 38af614c2..9359ea157 100644 --- a/src/ModelContextProtocol.Core/Server/RequestServiceProvider.cs +++ b/src/ModelContextProtocol.Core/Server/RequestServiceProvider.cs @@ -6,8 +6,8 @@ namespace ModelContextProtocol.Server; /// Augments a service provider with additional request-related services. internal sealed class RequestServiceProvider(RequestContext request) : - IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, - IDisposable, IAsyncDisposable + IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, + IDisposable, IAsyncDisposable where TRequestParams : RequestParams { private readonly IServiceProvider? _innerServices = request.Services; @@ -18,14 +18,19 @@ internal sealed class RequestServiceProvider(RequestContextGets whether the specified type is in the list of additional types this service provider wraps around the one in a provided request's services. public static bool IsAugmentedWith(Type serviceType) => serviceType == typeof(RequestContext) || + serviceType == typeof(McpServer) || +#pragma warning disable CS0618 // Type or member is obsolete serviceType == typeof(IMcpServer) || +#pragma warning restore CS0618 // Type or member is obsolete serviceType == typeof(IProgress) || serviceType == typeof(ClaimsPrincipal); /// public object? GetService(Type serviceType) => serviceType == typeof(RequestContext) ? request : - serviceType == typeof(IMcpServer) ? request.Server : +#pragma warning disable CS0618 // Type or member is obsolete + serviceType == typeof(McpServer) || serviceType == typeof(IMcpServer) ? request.Server : +#pragma warning restore CS0618 // Type or member is obsolete serviceType == typeof(IProgress) ? (request.Params?.ProgressToken is { } progressToken ? new TokenProgress(request.Server, progressToken) : NullProgress.Instance) : serviceType == typeof(ClaimsPrincipal) ? request.User : diff --git a/src/ModelContextProtocol.Core/Server/StdioServerTransport.cs b/src/ModelContextProtocol.Core/Server/StdioServerTransport.cs index 556a31159..307c180a1 100644 --- a/src/ModelContextProtocol.Core/Server/StdioServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StdioServerTransport.cs @@ -37,7 +37,7 @@ private static string GetServerName(McpServerOptions serverOptions) { Throw.IfNull(serverOptions); - return serverOptions.ServerInfo?.Name ?? McpServer.DefaultImplementation.Name; + return serverOptions.ServerInfo?.Name ?? McpServerImpl.DefaultImplementation.Name; } // Neither WindowsConsoleStream nor UnixConsoleStream respect CancellationTokens or cancel any I/O on Dispose. diff --git a/src/ModelContextProtocol.Core/TokenProgress.cs b/src/ModelContextProtocol.Core/TokenProgress.cs index f222fbf71..6b7a91e00 100644 --- a/src/ModelContextProtocol.Core/TokenProgress.cs +++ b/src/ModelContextProtocol.Core/TokenProgress.cs @@ -4,13 +4,13 @@ namespace ModelContextProtocol; /// /// Provides an tied to a specific progress token and that will issue -/// progress notifications on the supplied endpoint. +/// progress notifications on the supplied session. /// -internal sealed class TokenProgress(IMcpEndpoint endpoint, ProgressToken progressToken) : IProgress +internal sealed class TokenProgress(McpSession session, ProgressToken progressToken) : IProgress { /// public void Report(ProgressNotificationValue value) { - _ = endpoint.NotifyProgressAsync(progressToken, value, CancellationToken.None); + _ = session.NotifyProgressAsync(progressToken, value, CancellationToken.None); } } diff --git a/src/ModelContextProtocol/IMcpServerBuilder.cs b/src/ModelContextProtocol/IMcpServerBuilder.cs index 5ec37eba9..016e9eb3e 100644 --- a/src/ModelContextProtocol/IMcpServerBuilder.cs +++ b/src/ModelContextProtocol/IMcpServerBuilder.cs @@ -3,7 +3,7 @@ namespace Microsoft.Extensions.DependencyInjection; /// -/// Provides a builder for configuring instances. +/// Provides a builder for configuring instances. /// /// /// diff --git a/src/ModelContextProtocol/McpServerBuilderExtensions.cs b/src/ModelContextProtocol/McpServerBuilderExtensions.cs index 8e59d9640..d4c338262 100644 --- a/src/ModelContextProtocol/McpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol/McpServerBuilderExtensions.cs @@ -827,8 +827,8 @@ public static IMcpServerBuilder WithUnsubscribeFromResourcesHandler(this IMcpSer /// and may begin sending log messages at or above the specified level to the client. /// /// - /// Regardless of whether a handler is provided, an should itself handle - /// such notifications by updating its property to return the + /// Regardless of whether a handler is provided, an should itself handle + /// such notifications by updating its property to return the /// most recently set level. /// /// @@ -1180,7 +1180,7 @@ private static void AddSingleSessionServerDependencies(IServiceCollection servic ITransport serverTransport = services.GetRequiredService(); IOptions options = services.GetRequiredService>(); ILoggerFactory? loggerFactory = services.GetService(); - return McpServerFactory.Create(serverTransport, options.Value, loggerFactory, services); + return McpServer.Create(serverTransport, options.Value, loggerFactory, services); }); } #endregion diff --git a/src/ModelContextProtocol/SingleSessionMcpServerHostedService.cs b/src/ModelContextProtocol/SingleSessionMcpServerHostedService.cs index b50e46140..80e8216a8 100644 --- a/src/ModelContextProtocol/SingleSessionMcpServerHostedService.cs +++ b/src/ModelContextProtocol/SingleSessionMcpServerHostedService.cs @@ -10,7 +10,7 @@ namespace ModelContextProtocol; /// /// The host's application lifetime. If available, it will have termination requested when the session's run completes. /// -internal sealed class SingleSessionMcpServerHostedService(IMcpServer session, IHostApplicationLifetime? lifetime = null) : BackgroundService +internal sealed class SingleSessionMcpServerHostedService(McpServer session, IHostApplicationLifetime? lifetime = null) : BackgroundService { /// protected override async Task ExecuteAsync(CancellationToken stoppingToken) diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/AuthEventTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/AuthEventTests.cs index efff68c84..9144121e8 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/AuthEventTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/AuthEventTests.cs @@ -106,7 +106,7 @@ public async Task CanAuthenticate_WithResourceMetadataFromEvent() await app.StartAsync(TestContext.Current.CancellationToken); - await using var transport = new SseClientTransport( + await using var transport = new HttpClientTransport( new() { Endpoint = new(McpServerUrl), @@ -122,7 +122,7 @@ public async Task CanAuthenticate_WithResourceMetadataFromEvent() LoggerFactory ); - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken @@ -142,7 +142,7 @@ public async Task CanAuthenticate_WithDynamicClientRegistration_FromEvent() DynamicClientRegistrationResponse? dcrResponse = null; - await using var transport = new SseClientTransport( + await using var transport = new HttpClientTransport( new() { Endpoint = new(McpServerUrl), @@ -167,7 +167,7 @@ public async Task CanAuthenticate_WithDynamicClientRegistration_FromEvent() LoggerFactory ); - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/AuthTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/AuthTests.cs index b480934ac..fff7d6d42 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/AuthTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/AuthTests.cs @@ -97,7 +97,7 @@ public async Task CanAuthenticate() await app.StartAsync(TestContext.Current.CancellationToken); - await using var transport = new SseClientTransport(new() + await using var transport = new HttpClientTransport(new() { Endpoint = new(McpServerUrl), OAuth = new() @@ -109,7 +109,7 @@ public async Task CanAuthenticate() }, }, HttpClient, LoggerFactory); - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); } @@ -124,12 +124,12 @@ public async Task CannotAuthenticate_WithoutOAuthConfiguration() await app.StartAsync(TestContext.Current.CancellationToken); - await using var transport = new SseClientTransport(new() + await using var transport = new HttpClientTransport(new() { Endpoint = new(McpServerUrl), }, HttpClient, LoggerFactory); - var httpEx = await Assert.ThrowsAsync(async () => await McpClientFactory.CreateAsync( + var httpEx = await Assert.ThrowsAsync(async () => await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); Assert.Equal(HttpStatusCode.Unauthorized, httpEx.StatusCode); @@ -146,7 +146,7 @@ public async Task CannotAuthenticate_WithUnregisteredClient() await app.StartAsync(TestContext.Current.CancellationToken); - await using var transport = new SseClientTransport(new() + await using var transport = new HttpClientTransport(new() { Endpoint = new(McpServerUrl), OAuth = new() @@ -159,7 +159,7 @@ public async Task CannotAuthenticate_WithUnregisteredClient() }, HttpClient, LoggerFactory); // The EqualException is thrown by HandleAuthorizationUrlAsync when the /authorize request gets a 400 - var equalEx = await Assert.ThrowsAsync(async () => await McpClientFactory.CreateAsync( + var equalEx = await Assert.ThrowsAsync(async () => await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); } @@ -174,7 +174,7 @@ public async Task CanAuthenticate_WithDynamicClientRegistration() await app.StartAsync(TestContext.Current.CancellationToken); - await using var transport = new SseClientTransport(new() + await using var transport = new HttpClientTransport(new() { Endpoint = new(McpServerUrl), OAuth = new ClientOAuthOptions() @@ -190,7 +190,7 @@ public async Task CanAuthenticate_WithDynamicClientRegistration() }, }, HttpClient, LoggerFactory); - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); } @@ -205,7 +205,7 @@ public async Task CanAuthenticate_WithTokenRefresh() await app.StartAsync(TestContext.Current.CancellationToken); - await using var transport = new SseClientTransport(new() + await using var transport = new HttpClientTransport(new() { Endpoint = new(McpServerUrl), OAuth = new() @@ -219,7 +219,7 @@ public async Task CanAuthenticate_WithTokenRefresh() // The test-refresh-client should get an expired token first, // then automatically refresh it to get a working token - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); Assert.True(_testOAuthServer.HasIssuedRefreshToken); @@ -236,7 +236,7 @@ public async Task CanAuthenticate_WithExtraParams() await app.StartAsync(TestContext.Current.CancellationToken); - await using var transport = new SseClientTransport(new() + await using var transport = new HttpClientTransport(new() { Endpoint = new(McpServerUrl), OAuth = new() @@ -252,7 +252,7 @@ public async Task CanAuthenticate_WithExtraParams() }, }, HttpClient, LoggerFactory); - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(_lastAuthorizationUri?.Query); @@ -270,7 +270,7 @@ public async Task CannotOverrideExistingParameters_WithExtraParams() await app.StartAsync(TestContext.Current.CancellationToken); - await using var transport = new SseClientTransport(new() + await using var transport = new HttpClientTransport(new() { Endpoint = new(McpServerUrl), OAuth = new() @@ -286,7 +286,7 @@ public async Task CannotOverrideExistingParameters_WithExtraParams() }, }, HttpClient, LoggerFactory); - await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync( + await Assert.ThrowsAsync(() => McpClient.CreateAsync( transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs index 914284584..84d1c1a79 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs @@ -19,14 +19,14 @@ public class AuthorizeAttributeTests(ITestOutputHelper testOutputHelper) : Kestr { private readonly MockLoggerProvider _mockLoggerProvider = new(); - private async Task ConnectAsync() + private async Task ConnectAsync() { - await using var transport = new SseClientTransport(new SseClientTransportOptions + await using var transport = new HttpClientTransport(new HttpClientTransportOptions { Endpoint = new("http://localhost:5000"), }, HttpClient, LoggerFactory); - return await McpClientFactory.CreateAsync(transport, cancellationToken: TestContext.Current.CancellationToken, loggerFactory: LoggerFactory); + return await McpClient.CreateAsync(transport, cancellationToken: TestContext.Current.CancellationToken, loggerFactory: LoggerFactory); } [Fact] @@ -447,7 +447,7 @@ private async Task StartServerWithoutAuthFilters(Action new ClaimsPrincipal(new ClaimsIdentity( - [new Claim("name", name), new Claim(ClaimTypes.NameIdentifier, name), ..roles.Select(role => new Claim("role", role))], + [new Claim("name", name), new Claim(ClaimTypes.NameIdentifier, name), .. roles.Select(role => new Claim("role", role))], "TestAuthType", "name", "role")); [McpServerToolType] diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs index 9b3c91b94..f9aa5a5e9 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/HttpServerIntegrationTests.cs @@ -21,9 +21,9 @@ public override void Dispose() base.Dispose(); } - protected abstract SseClientTransportOptions ClientTransportOptions { get; } + protected abstract HttpClientTransportOptions ClientTransportOptions { get; } - private Task GetClientAsync(McpClientOptions? options = null) + private Task GetClientAsync(McpClientOptions? options = null) { return _fixture.ConnectMcpClientAsync(options, LoggerFactory); } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index 0d867c8f0..bb9746ed7 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -23,21 +23,21 @@ protected void ConfigureStateless(HttpServerTransportOptions options) options.Stateless = Stateless; } - protected async Task ConnectAsync( + protected async Task ConnectAsync( string? path = null, - SseClientTransportOptions? transportOptions = null, + HttpClientTransportOptions? transportOptions = null, McpClientOptions? clientOptions = null) { // Default behavior when no options are provided path ??= UseStreamableHttp ? "/" : "/sse"; - await using var transport = new SseClientTransport(transportOptions ?? new SseClientTransportOptions + await using var transport = new HttpClientTransport(transportOptions ?? new HttpClientTransportOptions { Endpoint = new Uri($"http://localhost:5000{path}"), TransportMode = UseStreamableHttp ? HttpTransportMode.StreamableHttp : HttpTransportMode.Sse, }, HttpClient, LoggerFactory); - return await McpClientFactory.CreateAsync(transport, clientOptions, LoggerFactory, TestContext.Current.CancellationToken); + return await McpClient.CreateAsync(transport, clientOptions, LoggerFactory, TestContext.Current.CancellationToken); } [Fact] @@ -244,7 +244,7 @@ public string EchoClaimsPrincipal(ClaimsPrincipal? user, string message) private class SamplingRegressionTools { [McpServerTool(Name = "sampling-tool")] - public static async Task SamplingToolAsync(IMcpServer server, string prompt, CancellationToken cancellationToken) + public static async Task SamplingToolAsync(McpServer server, string prompt, CancellationToken cancellationToken) { // This tool reproduces the scenario described in https://github.com/modelcontextprotocol/csharp-sdk/issues/464 // 1. The client calls tool with request ID 2, because it's the first request after the initialize request. diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs index 8191f6091..ffec1a4be 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseIntegrationTests.cs @@ -15,15 +15,15 @@ namespace ModelContextProtocol.AspNetCore.Tests; public partial class SseIntegrationTests(ITestOutputHelper outputHelper) : KestrelInMemoryTest(outputHelper) { - private readonly SseClientTransportOptions DefaultTransportOptions = new() + private readonly HttpClientTransportOptions DefaultTransportOptions = new() { Endpoint = new("http://localhost:5000/sse"), Name = "In-memory SSE Client", }; - private Task ConnectMcpClientAsync(HttpClient? httpClient = null, SseClientTransportOptions? transportOptions = null) - => McpClientFactory.CreateAsync( - new SseClientTransport(transportOptions ?? DefaultTransportOptions, httpClient ?? HttpClient, LoggerFactory), + private Task ConnectMcpClientAsync(HttpClient? httpClient = null, HttpClientTransportOptions? transportOptions = null) + => McpClient.CreateAsync( + new HttpClientTransport(transportOptions ?? DefaultTransportOptions, httpClient ?? HttpClient, LoggerFactory), loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); @@ -195,7 +195,7 @@ public async Task AdditionalHeaders_AreSent_InGetAndPostRequests() app.MapMcp(); await app.StartAsync(TestContext.Current.CancellationToken); - var sseOptions = new SseClientTransportOptions + var sseOptions = new HttpClientTransportOptions { Endpoint = new("http://localhost:5000/sse"), Name = "In-memory SSE Client", @@ -222,7 +222,7 @@ public async Task EmptyAdditionalHeadersKey_Throws_InvalidOperationException() app.MapMcp(); await app.StartAsync(TestContext.Current.CancellationToken); - var sseOptions = new SseClientTransportOptions + var sseOptions = new HttpClientTransportOptions { Endpoint = new("http://localhost:5000/sse"), Name = "In-memory SSE Client", @@ -257,7 +257,7 @@ private static void MapAbsoluteEndpointUriMcp(IEndpointRouteBuilder endpoints, b try { var transportTask = transport.RunAsync(cancellationToken: requestAborted); - await using var server = McpServerFactory.Create(transport, optionsSnapshot.Value, loggerFactory, endpoints.ServiceProvider); + await using var server = McpServer.Create(transport, optionsSnapshot.Value, loggerFactory, endpoints.ServiceProvider); try { diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs index 2aa675c84..c382c4385 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTestFixture.cs @@ -18,7 +18,7 @@ public class SseServerIntegrationTestFixture : IAsyncDisposable // multiple tests, so this dispatches the output to the current test. private readonly DelegatingTestOutputHelper _delegatingTestOutputHelper = new(); - private SseClientTransportOptions DefaultTransportOptions { get; set; } = new() + private HttpClientTransportOptions DefaultTransportOptions { get; set; } = new() { Endpoint = new("http://localhost:5000/"), }; @@ -44,16 +44,16 @@ public SseServerIntegrationTestFixture() public HttpClient HttpClient { get; } - public Task ConnectMcpClientAsync(McpClientOptions? options, ILoggerFactory loggerFactory) + public Task ConnectMcpClientAsync(McpClientOptions? options, ILoggerFactory loggerFactory) { - return McpClientFactory.CreateAsync( - new SseClientTransport(DefaultTransportOptions, HttpClient, loggerFactory), + return McpClient.CreateAsync( + new HttpClientTransport(DefaultTransportOptions, HttpClient, loggerFactory), options, loggerFactory, TestContext.Current.CancellationToken); } - public void Initialize(ITestOutputHelper output, SseClientTransportOptions clientTransportOptions) + public void Initialize(ITestOutputHelper output, HttpClientTransportOptions clientTransportOptions) { _delegatingTestOutputHelper.CurrentTestOutputHelper = output; DefaultTransportOptions = clientTransportOptions; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs index 2d4a78685..eb7db0110 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/SseServerIntegrationTests.cs @@ -8,7 +8,7 @@ public class SseServerIntegrationTests(SseServerIntegrationTestFixture fixture, : HttpServerIntegrationTests(fixture, testOutputHelper) { - protected override SseClientTransportOptions ClientTransportOptions => new() + protected override HttpClientTransportOptions ClientTransportOptions => new() { Endpoint = new("http://localhost:5000/sse"), Name = "In-memory SSE Client", diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs index d16e510cc..2ce63a1bc 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerIntegrationTests.cs @@ -5,7 +5,7 @@ namespace ModelContextProtocol.AspNetCore.Tests; public class StatelessServerIntegrationTests(SseServerIntegrationTestFixture fixture, ITestOutputHelper testOutputHelper) : StreamableHttpServerIntegrationTests(fixture, testOutputHelper) { - protected override SseClientTransportOptions ClientTransportOptions => new() + protected override HttpClientTransportOptions ClientTransportOptions => new() { Endpoint = new("http://localhost:5000/stateless"), Name = "In-memory Streamable HTTP Client", diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs index b50a43edc..3c200bb61 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StatelessServerTests.cs @@ -14,7 +14,7 @@ public class StatelessServerTests(ITestOutputHelper outputHelper) : KestrelInMem { private WebApplication? _app; - private readonly SseClientTransportOptions DefaultTransportOptions = new() + private readonly HttpClientTransportOptions DefaultTransportOptions = new() { Endpoint = new("http://localhost:5000/"), Name = "In-memory Streamable HTTP Client", @@ -58,9 +58,9 @@ private async Task StartAsync() HttpClient.DefaultRequestHeaders.Accept.Add(new("text/event-stream")); } - private Task ConnectMcpClientAsync(McpClientOptions? clientOptions = null) - => McpClientFactory.CreateAsync( - new SseClientTransport(DefaultTransportOptions, HttpClient, LoggerFactory), + private Task ConnectMcpClientAsync(McpClientOptions? clientOptions = null) + => McpClient.CreateAsync( + new HttpClientTransport(DefaultTransportOptions, HttpClient, LoggerFactory), clientOptions, LoggerFactory, TestContext.Current.CancellationToken); public async ValueTask DisposeAsync() @@ -194,7 +194,7 @@ public async Task ScopedServices_Resolve_FromRequestScope() } [McpServerTool(Name = "testSamplingErrors")] - public static async Task TestSamplingErrors(IMcpServer server) + public static async Task TestSamplingErrors(McpServer server) { const string expectedSamplingErrorMessage = "Sampling is not supported in stateless mode."; @@ -212,7 +212,7 @@ public static async Task TestSamplingErrors(IMcpServer server) } [McpServerTool(Name = "testRootsErrors")] - public static async Task TestRootsErrors(IMcpServer server) + public static async Task TestRootsErrors(McpServer server) { const string expectedRootsErrorMessage = "Roots are not supported in stateless mode."; @@ -227,7 +227,7 @@ public static async Task TestRootsErrors(IMcpServer server) } [McpServerTool(Name = "testElicitationErrors")] - public static async Task TestElicitationErrors(IMcpServer server) + public static async Task TestElicitationErrors(McpServer server) { const string expectedElicitationErrorMessage = "Elicitation is not supported in stateless mode."; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs index 7ce3516ef..f1cd458f9 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpClientConformanceTests.cs @@ -112,13 +112,13 @@ public async Task CanCallToolOnSessionlessStreamableHttpServer() { await StartAsync(); - await using var transport = new SseClientTransport(new() + await using var transport = new HttpClientTransport(new() { Endpoint = new("http://localhost:5000/mcp"), TransportMode = HttpTransportMode.StreamableHttp, }, HttpClient, LoggerFactory); - await using var client = await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + await using var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); var echoTool = Assert.Single(tools); @@ -132,13 +132,13 @@ public async Task CanCallToolConcurrently() { await StartAsync(); - await using var transport = new SseClientTransport(new() + await using var transport = new HttpClientTransport(new() { Endpoint = new("http://localhost:5000/mcp"), TransportMode = HttpTransportMode.StreamableHttp, }, HttpClient, LoggerFactory); - await using var client = await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + await using var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); var echoTool = Assert.Single(tools); @@ -158,13 +158,13 @@ public async Task SendsDeleteRequestOnDispose() { await StartAsync(enableDelete: true); - await using var transport = new SseClientTransport(new() + await using var transport = new HttpClientTransport(new() { Endpoint = new("http://localhost:5000/mcp"), TransportMode = HttpTransportMode.StreamableHttp, }, HttpClient, LoggerFactory); - await using var client = await McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + await using var client = await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); // Dispose should trigger DELETE request await client.DisposeAsync(); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs index bb184034c..7b2be8f98 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerConformanceTests.cs @@ -252,7 +252,7 @@ public async Task MultipleConcurrentJsonRpcRequests_IsHandled_InParallel() [Fact] public async Task GetRequest_Receives_UnsolicitedNotifications() { - IMcpServer? server = null; + McpServer? server = null; Builder.Services.AddMcpServer() .WithHttpTransport(options => diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs index 3524c60a4..b2b0b5499 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/StreamableHttpServerIntegrationTests.cs @@ -11,7 +11,7 @@ public class StreamableHttpServerIntegrationTests(SseServerIntegrationTestFixtur {"jsonrpc":"2.0","id":"1","method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"IntegrationTestClient","version":"1.0.0"}}} """; - protected override SseClientTransportOptions ClientTransportOptions => new() + protected override HttpClientTransportOptions ClientTransportOptions => new() { Endpoint = new("http://localhost:5000/"), Name = "In-memory Streamable HTTP Client", diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index ddd3701fd..cbfd828c8 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -51,7 +51,7 @@ private static async Task Main(string[] args) using var loggerFactory = CreateLoggerFactory(); await using var stdioTransport = new StdioServerTransport("TestServer", loggerFactory); - await using IMcpServer server = McpServerFactory.Create(stdioTransport, options, loggerFactory); + await using McpServer server = McpServer.Create(stdioTransport, options, loggerFactory); Log.Logger.Information("Server running..."); @@ -61,7 +61,7 @@ private static async Task Main(string[] args) await server.RunAsync(); } - private static async Task RunBackgroundLoop(IMcpServer server, CancellationToken cancellationToken = default) + private static async Task RunBackgroundLoop(McpServer server, CancellationToken cancellationToken = default) { var loggingLevels = (LoggingLevel[])Enum.GetValues(typeof(LoggingLevel)); var random = new Random(); diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientCreationTests.cs similarity index 90% rename from tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs rename to tests/ModelContextProtocol.Tests/Client/McpClientCreationTests.cs index 7516a2186..15127502e 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientCreationTests.cs @@ -1,26 +1,24 @@ -using Microsoft.Extensions.Logging; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; -using Moq; using System.IO.Pipelines; using System.Text.Json; using System.Threading.Channels; namespace ModelContextProtocol.Tests.Client; -public class McpClientFactoryTests +public class McpClientCreationTests { [Fact] public async Task CreateAsync_WithInvalidArgs_Throws() { - await Assert.ThrowsAsync("clientTransport", () => McpClientFactory.CreateAsync(null!, cancellationToken: TestContext.Current.CancellationToken)); + await Assert.ThrowsAsync("clientTransport", () => McpClient.CreateAsync(null!, cancellationToken: TestContext.Current.CancellationToken)); } [Fact] public async Task CreateAsync_NopTransport_ReturnsClient() { // Act - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( new NopTransport(), cancellationToken: TestContext.Current.CancellationToken); @@ -39,7 +37,7 @@ public async Task Cancellation_ThrowsCancellationException(bool preCanceled) cts.Cancel(); } - Task t = McpClientFactory.CreateAsync( + Task t = McpClient.CreateAsync( new StreamClientTransport(new Pipe().Writer.AsStream(), new Pipe().Reader.AsStream()), cancellationToken: cts.Token); if (!preCanceled) @@ -85,9 +83,9 @@ public async Task CreateAsync_WithCapabilitiesOptions(Type transportType) }; var clientTransport = (IClientTransport)Activator.CreateInstance(transportType)!; - IMcpClient? client = null; + McpClient? client = null; - var actionTask = McpClientFactory.CreateAsync(clientTransport, clientOptions, new Mock().Object, CancellationToken.None); + var actionTask = McpClient.CreateAsync(clientTransport, clientOptions, loggerFactory: null, CancellationToken.None); // Act if (clientTransport is FailureTransport) diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs index e3d7ce44c..f4e6062de 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs @@ -1,471 +1,387 @@ -using Microsoft.Extensions.AI; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; -using ModelContextProtocol.Server; using Moq; using System.Text.Json; -using System.Text.Json.Serialization.Metadata; -using System.Threading.Channels; -namespace ModelContextProtocol.Tests.Client; +namespace ModelContextProtocol.Tests; -public class McpClientExtensionsTests : ClientServerTestBase +#pragma warning disable CS0618 // Type or member is obsolete + +public class McpClientExtensionsTests { - public McpClientExtensionsTests(ITestOutputHelper outputHelper) - : base(outputHelper) + [Fact] + public async Task PingAsync_Throws_When_Not_McpClient() { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.PingAsync(TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.PingAsync' instead", ex.Message); } - protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + [Fact] + public async Task GetPromptAsync_Throws_When_Not_McpClient() { - for (int f = 0; f < 10; f++) - { - string name = $"Method{f}"; - mcpServerBuilder.WithTools([McpServerTool.Create((int i) => $"{name} Result {i}", new() { Name = name })]); - } - mcpServerBuilder.WithTools([McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)] (string i) => $"{i} Result", new() { Name = "ValuesSetViaAttr" })]); - mcpServerBuilder.WithTools([McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)] (string i) => $"{i} Result", new() { Name = "ValuesSetViaOptions", Destructive = true, OpenWorld = false, ReadOnly = true })]); + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.GetPromptAsync( + "name", cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.GetPromptAsync' instead", ex.Message); } - [Theory] - [InlineData(null, null)] - [InlineData(0.7f, 50)] - [InlineData(1.0f, 100)] - public async Task CreateSamplingHandler_ShouldHandleTextMessages(float? temperature, int? maxTokens) + [Fact] + public async Task CallToolAsync_Throws_When_Not_McpClient() { - // Arrange - var mockChatClient = new Mock(); - var requestParams = new CreateMessageRequestParams - { - Messages = - [ - new SamplingMessage - { - Role = Role.User, - Content = new TextContentBlock { Text = "Hello" } - } - ], - Temperature = temperature, - MaxTokens = maxTokens, - }; - - var cancellationToken = CancellationToken.None; - var expectedResponse = new[] { - new ChatResponseUpdate - { - ModelId = "test-model", - FinishReason = ChatFinishReason.Stop, - Role = ChatRole.Assistant, - Contents = - [ - new TextContent("Hello, World!") { RawRepresentation = "Hello, World!" } - ] - } - }.ToAsyncEnumerable(); - - mockChatClient - .Setup(client => client.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), cancellationToken)) - .Returns(expectedResponse); - - var handler = McpClientExtensions.CreateSamplingHandler(mockChatClient.Object); - - // Act - var result = await handler(requestParams, Mock.Of>(), cancellationToken); - - // Assert - Assert.NotNull(result); - Assert.Equal("Hello, World!", (result.Content as TextContentBlock)?.Text); - Assert.Equal("test-model", result.Model); - Assert.Equal(Role.Assistant, result.Role); - Assert.Equal("endTurn", result.StopReason); + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.CallToolAsync( + "tool", cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.CallToolAsync' instead", ex.Message); } [Fact] - public async Task CreateSamplingHandler_ShouldHandleImageMessages() + public async Task ListResourcesAsync_Throws_When_Not_McpClient() { - // Arrange - var mockChatClient = new Mock(); - var requestParams = new CreateMessageRequestParams - { - Messages = - [ - new SamplingMessage - { - Role = Role.User, - Content = new ImageContentBlock - { - MimeType = "image/png", - Data = Convert.ToBase64String(new byte[] { 1, 2, 3 }) - } - } - ], - MaxTokens = 100 - }; - - const string expectedData = "SGVsbG8sIFdvcmxkIQ=="; - var cancellationToken = CancellationToken.None; - var expectedResponse = new[] { - new ChatResponseUpdate - { - ModelId = "test-model", - FinishReason = ChatFinishReason.Stop, - Role = ChatRole.Assistant, - Contents = - [ - new DataContent($"data:image/png;base64,{expectedData}") { RawRepresentation = "Hello, World!" } - ] - } - }.ToAsyncEnumerable(); - - mockChatClient - .Setup(client => client.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), cancellationToken)) - .Returns(expectedResponse); - - var handler = McpClientExtensions.CreateSamplingHandler(mockChatClient.Object); - - // Act - var result = await handler(requestParams, Mock.Of>(), cancellationToken); - - // Assert - Assert.NotNull(result); - Assert.Equal(expectedData, (result.Content as ImageContentBlock)?.Data); - Assert.Equal("test-model", result.Model); - Assert.Equal(Role.Assistant, result.Role); - Assert.Equal("endTurn", result.StopReason); + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.ListResourcesAsync( + cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.ListResourcesAsync' instead", ex.Message); } [Fact] - public async Task CreateSamplingHandler_ShouldHandleResourceMessages() + public void EnumerateResourcesAsync_Throws_When_Not_McpClient() { - // Arrange - const string data = "SGVsbG8sIFdvcmxkIQ=="; - string content = $"data:application/octet-stream;base64,{data}"; - var mockChatClient = new Mock(); - var resource = new BlobResourceContents - { - Blob = data, - MimeType = "application/octet-stream", - Uri = "data:application/octet-stream" - }; - - var requestParams = new CreateMessageRequestParams - { - Messages = - [ - new SamplingMessage - { - Role = Role.User, - Content = new EmbeddedResourceBlock { Resource = resource }, - } - ], - MaxTokens = 100 - }; - - var cancellationToken = CancellationToken.None; - var expectedResponse = new[] { - new ChatResponseUpdate - { - ModelId = "test-model", - FinishReason = ChatFinishReason.Stop, - AuthorName = "bot", - Role = ChatRole.Assistant, - Contents = - [ - resource.ToAIContent() - ] - } - }.ToAsyncEnumerable(); - - mockChatClient - .Setup(client => client.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), cancellationToken)) - .Returns(expectedResponse); - - var handler = McpClientExtensions.CreateSamplingHandler(mockChatClient.Object); - - // Act - var result = await handler(requestParams, Mock.Of>(), cancellationToken); - - // Assert - Assert.NotNull(result); - Assert.Equal("test-model", result.Model); - Assert.Equal(Role.Assistant, result.Role); - Assert.Equal("endTurn", result.StopReason); + var client = new Mock(MockBehavior.Strict).Object; + + var ex = Assert.Throws(() => client.EnumerateResourcesAsync(cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.EnumerateResourcesAsync' instead", ex.Message); } [Fact] - public async Task ListToolsAsync_AllToolsReturned() + public async Task SubscribeToResourceAsync_String_Throws_When_Not_McpClient() { - await using IMcpClient client = await CreateMcpClientForServer(); - - var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); - Assert.Equal(12, tools.Count); - var echo = tools.Single(t => t.Name == "Method4"); - var result = await echo.InvokeAsync(new() { ["i"] = 42 }, TestContext.Current.CancellationToken); - Assert.Contains("Method4 Result 42", result?.ToString()); - - var valuesSetViaAttr = tools.Single(t => t.Name == "ValuesSetViaAttr"); - Assert.Null(valuesSetViaAttr.ProtocolTool.Annotations?.Title); - Assert.Null(valuesSetViaAttr.ProtocolTool.Annotations?.ReadOnlyHint); - Assert.Null(valuesSetViaAttr.ProtocolTool.Annotations?.IdempotentHint); - Assert.False(valuesSetViaAttr.ProtocolTool.Annotations?.DestructiveHint); - Assert.True(valuesSetViaAttr.ProtocolTool.Annotations?.OpenWorldHint); - - var valuesSetViaOptions = tools.Single(t => t.Name == "ValuesSetViaOptions"); - Assert.Null(valuesSetViaOptions.ProtocolTool.Annotations?.Title); - Assert.True(valuesSetViaOptions.ProtocolTool.Annotations?.ReadOnlyHint); - Assert.Null(valuesSetViaOptions.ProtocolTool.Annotations?.IdempotentHint); - Assert.True(valuesSetViaOptions.ProtocolTool.Annotations?.DestructiveHint); - Assert.False(valuesSetViaOptions.ProtocolTool.Annotations?.OpenWorldHint); + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.SubscribeToResourceAsync( + "mcp://resource/1", TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.SubscribeToResourceAsync' instead", ex.Message); } [Fact] - public async Task EnumerateToolsAsync_AllToolsReturned() + public async Task SubscribeToResourceAsync_Uri_Throws_When_Not_McpClient() { - await using IMcpClient client = await CreateMcpClientForServer(); + var client = new Mock(MockBehavior.Strict).Object; - await foreach (var tool in client.EnumerateToolsAsync(cancellationToken: TestContext.Current.CancellationToken)) - { - if (tool.Name == "Method4") - { - var result = await tool.InvokeAsync(new() { ["i"] = 42 }, TestContext.Current.CancellationToken); - Assert.Contains("Method4 Result 42", result?.ToString()); - return; - } - } + var ex = await Assert.ThrowsAsync(async () => await client.SubscribeToResourceAsync( + new Uri("mcp://resource/1"), TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.SubscribeToResourceAsync' instead", ex.Message); + } + + [Fact] + public async Task UnsubscribeFromResourceAsync_String_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; - Assert.Fail("Couldn't find target method"); + var ex = await Assert.ThrowsAsync(async () => await client.UnsubscribeFromResourceAsync( + "mcp://resource/1", TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.UnsubscribeFromResourceAsync' instead", ex.Message); } [Fact] - public async Task EnumerateToolsAsync_FlowsJsonSerializerOptions() + public async Task UnsubscribeFromResourceAsync_Uri_Throws_When_Not_McpClient() { - JsonSerializerOptions options = new(JsonSerializerOptions.Default); - await using IMcpClient client = await CreateMcpClientForServer(); - bool hasTools = false; - - await foreach (var tool in client.EnumerateToolsAsync(options, TestContext.Current.CancellationToken)) - { - Assert.Same(options, tool.JsonSerializerOptions); - hasTools = true; - } - - foreach (var tool in await client.ListToolsAsync(options, TestContext.Current.CancellationToken)) - { - Assert.Same(options, tool.JsonSerializerOptions); - } - - Assert.True(hasTools); + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.UnsubscribeFromResourceAsync( + new Uri("mcp://resource/1"), TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.UnsubscribeFromResourceAsync' instead", ex.Message); } [Fact] - public async Task EnumerateToolsAsync_HonorsJsonSerializerOptions() + public async Task ReadResourceAsync_String_Throws_When_Not_McpClient() { - JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; - await using IMcpClient client = await CreateMcpClientForServer(); + var client = new Mock(MockBehavior.Strict).Object; - var tool = (await client.ListToolsAsync(emptyOptions, TestContext.Current.CancellationToken)).First(); - await Assert.ThrowsAsync(async () => await tool.InvokeAsync(new() { ["i"] = 42 }, TestContext.Current.CancellationToken)); + var ex = await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( + "mcp://resource/1", TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.ReadResourceAsync' instead", ex.Message); } [Fact] - public async Task SendRequestAsync_HonorsJsonSerializerOptions() + public async Task ReadResourceAsync_Uri_Throws_When_Not_McpClient() { - JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; - await using IMcpClient client = await CreateMcpClientForServer(); + var client = new Mock(MockBehavior.Strict).Object; - await Assert.ThrowsAsync(async () => await client.SendRequestAsync("Method4", new() { Name = "tool" }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken)); + var ex = await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( + new Uri("mcp://resource/1"), TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.ReadResourceAsync' instead", ex.Message); } [Fact] - public async Task SendNotificationAsync_HonorsJsonSerializerOptions() + public async Task ReadResourceAsync_Template_Throws_When_Not_McpClient() { - JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; - await using IMcpClient client = await CreateMcpClientForServer(); + var client = new Mock(MockBehavior.Strict).Object; - await Assert.ThrowsAsync(() => client.SendNotificationAsync("Method4", new { Value = 42 }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken)); + var ex = await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( + "mcp://resource/{id}", new Dictionary { ["id"] = 1 }, TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.ReadResourceAsync' instead", ex.Message); } [Fact] - public async Task GetPromptsAsync_HonorsJsonSerializerOptions() + public async Task CompleteAsync_Throws_When_Not_McpClient() { - JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; - await using IMcpClient client = await CreateMcpClientForServer(); + var client = new Mock(MockBehavior.Strict).Object; + var reference = new PromptReference { Name = "prompt" }; - await Assert.ThrowsAsync(async () => await client.GetPromptAsync("Prompt", new Dictionary { ["i"] = 42 }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken)); + var ex = await Assert.ThrowsAsync(async () => await client.CompleteAsync( + reference, "arg", "val", TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.CompleteAsync' instead", ex.Message); } [Fact] - public async Task WithName_ChangesToolName() + public async Task ListToolsAsync_Throws_When_Not_McpClient() { - JsonSerializerOptions options = new(JsonSerializerOptions.Default); - await using IMcpClient client = await CreateMcpClientForServer(); + var client = new Mock(MockBehavior.Strict).Object; - var tool = (await client.ListToolsAsync(options, TestContext.Current.CancellationToken)).First(); - var originalName = tool.Name; - var renamedTool = tool.WithName("RenamedTool"); + var ex = await Assert.ThrowsAsync(async () => await client.ListToolsAsync( + cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.ListToolsAsync' instead", ex.Message); + } - Assert.NotNull(renamedTool); - Assert.Equal("RenamedTool", renamedTool.Name); - Assert.Equal(originalName, tool?.Name); + [Fact] + public void EnumerateToolsAsync_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = Assert.Throws(() => client.EnumerateToolsAsync(cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.EnumerateToolsAsync' instead", ex.Message); } [Fact] - public async Task WithDescription_ChangesToolDescription() + public async Task ListPromptsAsync_Throws_When_Not_McpClient() { - JsonSerializerOptions options = new(JsonSerializerOptions.Default); - await using IMcpClient client = await CreateMcpClientForServer(); - var tool = (await client.ListToolsAsync(options, TestContext.Current.CancellationToken)).FirstOrDefault(); - var originalDescription = tool?.Description; - var redescribedTool = tool?.WithDescription("ToolWithNewDescription"); - Assert.NotNull(redescribedTool); - Assert.Equal("ToolWithNewDescription", redescribedTool.Description); - Assert.Equal(originalDescription, tool?.Description); + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.ListPromptsAsync( + cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.ListPromptsAsync' instead", ex.Message); } [Fact] - public async Task WithProgress_ProgressReported() + public void EnumeratePromptsAsync_Throws_When_Not_McpClient() { - const int TotalNotifications = 3; - int remainingProgress = TotalNotifications; - TaskCompletionSource allProgressReceived = new(TaskCreationOptions.RunContinuationsAsynchronously); + var client = new Mock(MockBehavior.Strict).Object; - Server.ServerOptions.Capabilities?.Tools?.ToolCollection?.Add(McpServerTool.Create(async (IProgress progress) => - { - for (int i = 0; i < TotalNotifications; i++) + var ex = Assert.Throws(() => client.EnumeratePromptsAsync(cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.EnumeratePromptsAsync' instead", ex.Message); + } + + [Fact] + public async Task ListResourceTemplatesAsync_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await client.ListResourceTemplatesAsync( + cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.ListResourceTemplatesAsync' instead", ex.Message); + } + + [Fact] + public void EnumerateResourceTemplatesAsync_Throws_When_Not_McpClient() + { + var client = new Mock(MockBehavior.Strict).Object; + + var ex = Assert.Throws(() => client.EnumerateResourceTemplatesAsync(cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpClient.EnumerateResourceTemplatesAsync' instead", ex.Message); + } + + [Fact] + public async Task PingAsync_Forwards_To_McpClient_SendRequestAsync() + { + var mockClient = new Mock { CallBase = true }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse { - progress.Report(new ProgressNotificationValue { Progress = i * 10, Message = "making progress" }); - await Task.Delay(1); - } + Result = JsonSerializer.SerializeToNode(new object(), McpJsonUtilities.DefaultOptions), + }); - await allProgressReceived.Task; + IMcpClient client = mockClient.Object; - return 42; - }, new() { Name = "ProgressReporter" })); + await client.PingAsync(TestContext.Current.CancellationToken); - await using IMcpClient client = await CreateMcpClientForServer(); + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task GetPromptAsync_Forwards_To_McpClient_SendRequestAsync() + { + var mockClient = new Mock { CallBase = true }; - var tool = (await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken)).First(t => t.Name == "ProgressReporter"); + var resultPayload = new GetPromptResult { Messages = [new PromptMessage { Role = Role.User, Content = new TextContentBlock { Text = "hi" } }] }; - IProgress progress = new SynchronousProgress(value => - { - Assert.True(value.Progress >= 0 && value.Progress <= 100); - Assert.Equal("making progress", value.Message); - if (Interlocked.Decrement(ref remainingProgress) == 0) + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse { - allProgressReceived.SetResult(true); - } - }); + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); - Assert.Throws("progress", () => tool.WithProgress(null!)); + IMcpClient client = mockClient.Object; - var result = await tool.WithProgress(progress).InvokeAsync(cancellationToken: TestContext.Current.CancellationToken); - Assert.Contains("42", result?.ToString()); + var result = await client.GetPromptAsync("name", cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal("hi", Assert.IsType(result.Messages[0].Content).Text); + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); } - private sealed class SynchronousProgress(Action callback) : IProgress + [Fact] + public async Task CallToolAsync_Forwards_To_McpClient_SendRequestAsync() { - public void Report(ProgressNotificationValue value) => callback(value); + var mockClient = new Mock { CallBase = true }; + + var callResult = new CallToolResult { Content = [new TextContentBlock { Text = "ok" }] }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(callResult, McpJsonUtilities.DefaultOptions), + }); + + IMcpClient client = mockClient.Object; + + var result = await client.CallToolAsync("tool", cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal("ok", Assert.IsType(result.Content[0]).Text); + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); } [Fact] - public async Task AsClientLoggerProvider_MessagesSentToClient() + public async Task SubscribeToResourceAsync_Forwards_To_McpClient_SendRequestAsync() { - await using IMcpClient client = await CreateMcpClientForServer(); - - ILoggerProvider loggerProvider = Server.AsClientLoggerProvider(); - Assert.Throws("categoryName", () => loggerProvider.CreateLogger(null!)); - - ILogger logger = loggerProvider.CreateLogger("TestLogger"); - Assert.NotNull(logger); - - Assert.Null(logger.BeginScope("")); - - Assert.Null(Server.LoggingLevel); - Assert.False(logger.IsEnabled(LogLevel.Trace)); - Assert.False(logger.IsEnabled(LogLevel.Debug)); - Assert.False(logger.IsEnabled(LogLevel.Information)); - Assert.False(logger.IsEnabled(LogLevel.Warning)); - Assert.False(logger.IsEnabled(LogLevel.Error)); - Assert.False(logger.IsEnabled(LogLevel.Critical)); - - await client.SetLoggingLevel(LoggingLevel.Info, TestContext.Current.CancellationToken); - - DateTime start = DateTime.UtcNow; - while (Server.LoggingLevel is null) - { - await Task.Delay(1, TestContext.Current.CancellationToken); - Assert.True(DateTime.UtcNow - start < TimeSpan.FromSeconds(10), "Timed out waiting for logging level to be set"); - } - - Assert.Equal(LoggingLevel.Info, Server.LoggingLevel); - Assert.False(logger.IsEnabled(LogLevel.Trace)); - Assert.False(logger.IsEnabled(LogLevel.Debug)); - Assert.True(logger.IsEnabled(LogLevel.Information)); - Assert.True(logger.IsEnabled(LogLevel.Warning)); - Assert.True(logger.IsEnabled(LogLevel.Error)); - Assert.True(logger.IsEnabled(LogLevel.Critical)); - - List data = []; - var channel = Channel.CreateUnbounded(); - - await using (client.RegisterNotificationHandler(NotificationMethods.LoggingMessageNotification, - (notification, cancellationToken) => + var mockClient = new Mock { CallBase = true }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse { - Assert.True(channel.Writer.TryWrite(JsonSerializer.Deserialize(notification.Params, McpJsonUtilities.DefaultOptions))); - return default; - })) - { - logger.LogTrace("Trace {Message}", "message"); - logger.LogDebug("Debug {Message}", "message"); - logger.LogInformation("Information {Message}", "message"); - logger.LogWarning("Warning {Message}", "message"); - logger.LogError("Error {Message}", "message"); - logger.LogCritical("Critical {Message}", "message"); - - for (int i = 0; i < 4; i++) + Result = JsonSerializer.SerializeToNode(new EmptyResult(), McpJsonUtilities.DefaultOptions), + }); + + IMcpClient client = mockClient.Object; + + await client.SubscribeToResourceAsync("mcp://resource/1", TestContext.Current.CancellationToken); + + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task UnsubscribeFromResourceAsync_Forwards_To_McpClient_SendRequestAsync() + { + var mockClient = new Mock { CallBase = true }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(new EmptyResult(), McpJsonUtilities.DefaultOptions), + }); + + IMcpClient client = mockClient.Object; + + await client.UnsubscribeFromResourceAsync("mcp://resource/1", TestContext.Current.CancellationToken); + + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task CompleteAsync_Forwards_To_McpClient_SendRequestAsync() + { + var mockClient = new Mock { CallBase = true }; + + var completion = new Completion { Values = ["one", "two"] }; + var resultPayload = new CompleteResult { Completion = completion }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpClient client = mockClient.Object; + + var result = await client.CompleteAsync(new PromptReference { Name = "p" }, "arg", "val", TestContext.Current.CancellationToken); + + Assert.Contains("one", result.Completion.Values); + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task ReadResourceAsync_String_Forwards_To_McpClient_SendRequestAsync() + { + var mockClient = new Mock { CallBase = true }; + + var resultPayload = new ReadResourceResult { Contents = [] }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpClient client = mockClient.Object; + + var result = await client.ReadResourceAsync("mcp://resource/1", TestContext.Current.CancellationToken); + + Assert.NotNull(result); + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task ReadResourceAsync_Uri_Forwards_To_McpClient_SendRequestAsync() + { + var mockClient = new Mock { CallBase = true }; + + var resultPayload = new ReadResourceResult { Contents = [] }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse { - var m = await channel.Reader.ReadAsync(TestContext.Current.CancellationToken); - Assert.NotNull(m); - Assert.NotNull(m.Data); - - Assert.Equal("TestLogger", m.Logger); - - string ? s = JsonSerializer.Deserialize(m.Data.Value, McpJsonUtilities.DefaultOptions); - Assert.NotNull(s); - - if (s.Contains("Information")) - { - Assert.Equal(LoggingLevel.Info, m.Level); - } - else if (s.Contains("Warning")) - { - Assert.Equal(LoggingLevel.Warning, m.Level); - } - else if (s.Contains("Error")) - { - Assert.Equal(LoggingLevel.Error, m.Level); - } - else if (s.Contains("Critical")) - { - Assert.Equal(LoggingLevel.Critical, m.Level); - } - - data.Add(s); - } - - channel.Writer.Complete(); - } - - Assert.False(await channel.Reader.WaitToReadAsync(TestContext.Current.CancellationToken)); - Assert.Equal( - [ - "Critical message", - "Error message", - "Information message", - "Warning message", - ], - data.OrderBy(s => s)); + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpClient client = mockClient.Object; + + var result = await client.ReadResourceAsync(new Uri("mcp://resource/1"), TestContext.Current.CancellationToken); + + Assert.NotNull(result); + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task ReadResourceAsync_Template_Forwards_To_McpClient_SendRequestAsync() + { + var mockClient = new Mock { CallBase = true }; + + var resultPayload = new ReadResourceResult { Contents = [] }; + + mockClient + .Setup(c => c.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpClient client = mockClient.Object; + + var result = await client.ReadResourceAsync("mcp://resource/{id}", new Dictionary { ["id"] = 1 }, TestContext.Current.CancellationToken); + + Assert.NotNull(result); + mockClient.Verify(c => c.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); } -} \ No newline at end of file +} diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientResourceTemplateTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientResourceTemplateTests.cs index 48c3c370d..2599d7485 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientResourceTemplateTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientResourceTemplateTests.cs @@ -73,7 +73,7 @@ public static IEnumerable UriTemplate_InputsProduceExpectedOutputs_Mem public async Task UriTemplate_InputsProduceExpectedOutputs( IReadOnlyDictionary variables, string uriTemplate, object expected) { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.ReadResourceAsync(uriTemplate, variables, TestContext.Current.CancellationToken); Assert.NotNull(result); diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs new file mode 100644 index 000000000..779e31e62 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Client/McpClientTests.cs @@ -0,0 +1,471 @@ +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using Moq; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; +using System.Threading.Channels; + +namespace ModelContextProtocol.Tests.Client; + +public class McpClientTests : ClientServerTestBase +{ + public McpClientTests(ITestOutputHelper outputHelper) + : base(outputHelper) + { + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + for (int f = 0; f < 10; f++) + { + string name = $"Method{f}"; + mcpServerBuilder.WithTools([McpServerTool.Create((int i) => $"{name} Result {i}", new() { Name = name })]); + } + mcpServerBuilder.WithTools([McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)] (string i) => $"{i} Result", new() { Name = "ValuesSetViaAttr" })]); + mcpServerBuilder.WithTools([McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)] (string i) => $"{i} Result", new() { Name = "ValuesSetViaOptions", Destructive = true, OpenWorld = false, ReadOnly = true })]); + } + + [Theory] + [InlineData(null, null)] + [InlineData(0.7f, 50)] + [InlineData(1.0f, 100)] + public async Task CreateSamplingHandler_ShouldHandleTextMessages(float? temperature, int? maxTokens) + { + // Arrange + var mockChatClient = new Mock(); + var requestParams = new CreateMessageRequestParams + { + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = new TextContentBlock { Text = "Hello" } + } + ], + Temperature = temperature, + MaxTokens = maxTokens, + }; + + var cancellationToken = CancellationToken.None; + var expectedResponse = new[] { + new ChatResponseUpdate + { + ModelId = "test-model", + FinishReason = ChatFinishReason.Stop, + Role = ChatRole.Assistant, + Contents = + [ + new TextContent("Hello, World!") { RawRepresentation = "Hello, World!" } + ] + } + }.ToAsyncEnumerable(); + + mockChatClient + .Setup(client => client.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), cancellationToken)) + .Returns(expectedResponse); + + var handler = McpClientExtensions.CreateSamplingHandler(mockChatClient.Object); + + // Act + var result = await handler(requestParams, Mock.Of>(), cancellationToken); + + // Assert + Assert.NotNull(result); + Assert.Equal("Hello, World!", (result.Content as TextContentBlock)?.Text); + Assert.Equal("test-model", result.Model); + Assert.Equal(Role.Assistant, result.Role); + Assert.Equal("endTurn", result.StopReason); + } + + [Fact] + public async Task CreateSamplingHandler_ShouldHandleImageMessages() + { + // Arrange + var mockChatClient = new Mock(); + var requestParams = new CreateMessageRequestParams + { + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = new ImageContentBlock + { + MimeType = "image/png", + Data = Convert.ToBase64String(new byte[] { 1, 2, 3 }) + } + } + ], + MaxTokens = 100 + }; + + const string expectedData = "SGVsbG8sIFdvcmxkIQ=="; + var cancellationToken = CancellationToken.None; + var expectedResponse = new[] { + new ChatResponseUpdate + { + ModelId = "test-model", + FinishReason = ChatFinishReason.Stop, + Role = ChatRole.Assistant, + Contents = + [ + new DataContent($"data:image/png;base64,{expectedData}") { RawRepresentation = "Hello, World!" } + ] + } + }.ToAsyncEnumerable(); + + mockChatClient + .Setup(client => client.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), cancellationToken)) + .Returns(expectedResponse); + + var handler = McpClientExtensions.CreateSamplingHandler(mockChatClient.Object); + + // Act + var result = await handler(requestParams, Mock.Of>(), cancellationToken); + + // Assert + Assert.NotNull(result); + Assert.Equal(expectedData, (result.Content as ImageContentBlock)?.Data); + Assert.Equal("test-model", result.Model); + Assert.Equal(Role.Assistant, result.Role); + Assert.Equal("endTurn", result.StopReason); + } + + [Fact] + public async Task CreateSamplingHandler_ShouldHandleResourceMessages() + { + // Arrange + const string data = "SGVsbG8sIFdvcmxkIQ=="; + string content = $"data:application/octet-stream;base64,{data}"; + var mockChatClient = new Mock(); + var resource = new BlobResourceContents + { + Blob = data, + MimeType = "application/octet-stream", + Uri = "data:application/octet-stream" + }; + + var requestParams = new CreateMessageRequestParams + { + Messages = + [ + new SamplingMessage + { + Role = Role.User, + Content = new EmbeddedResourceBlock { Resource = resource }, + } + ], + MaxTokens = 100 + }; + + var cancellationToken = CancellationToken.None; + var expectedResponse = new[] { + new ChatResponseUpdate + { + ModelId = "test-model", + FinishReason = ChatFinishReason.Stop, + AuthorName = "bot", + Role = ChatRole.Assistant, + Contents = + [ + resource.ToAIContent() + ] + } + }.ToAsyncEnumerable(); + + mockChatClient + .Setup(client => client.GetStreamingResponseAsync(It.IsAny>(), It.IsAny(), cancellationToken)) + .Returns(expectedResponse); + + var handler = McpClientExtensions.CreateSamplingHandler(mockChatClient.Object); + + // Act + var result = await handler(requestParams, Mock.Of>(), cancellationToken); + + // Assert + Assert.NotNull(result); + Assert.Equal("test-model", result.Model); + Assert.Equal(Role.Assistant, result.Role); + Assert.Equal("endTurn", result.StopReason); + } + + [Fact] + public async Task ListToolsAsync_AllToolsReturned() + { + await using McpClient client = await CreateMcpClientForServer(); + + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + Assert.Equal(12, tools.Count); + var echo = tools.Single(t => t.Name == "Method4"); + var result = await echo.InvokeAsync(new() { ["i"] = 42 }, TestContext.Current.CancellationToken); + Assert.Contains("Method4 Result 42", result?.ToString()); + + var valuesSetViaAttr = tools.Single(t => t.Name == "ValuesSetViaAttr"); + Assert.Null(valuesSetViaAttr.ProtocolTool.Annotations?.Title); + Assert.Null(valuesSetViaAttr.ProtocolTool.Annotations?.ReadOnlyHint); + Assert.Null(valuesSetViaAttr.ProtocolTool.Annotations?.IdempotentHint); + Assert.False(valuesSetViaAttr.ProtocolTool.Annotations?.DestructiveHint); + Assert.True(valuesSetViaAttr.ProtocolTool.Annotations?.OpenWorldHint); + + var valuesSetViaOptions = tools.Single(t => t.Name == "ValuesSetViaOptions"); + Assert.Null(valuesSetViaOptions.ProtocolTool.Annotations?.Title); + Assert.True(valuesSetViaOptions.ProtocolTool.Annotations?.ReadOnlyHint); + Assert.Null(valuesSetViaOptions.ProtocolTool.Annotations?.IdempotentHint); + Assert.True(valuesSetViaOptions.ProtocolTool.Annotations?.DestructiveHint); + Assert.False(valuesSetViaOptions.ProtocolTool.Annotations?.OpenWorldHint); + } + + [Fact] + public async Task EnumerateToolsAsync_AllToolsReturned() + { + await using McpClient client = await CreateMcpClientForServer(); + + await foreach (var tool in client.EnumerateToolsAsync(cancellationToken: TestContext.Current.CancellationToken)) + { + if (tool.Name == "Method4") + { + var result = await tool.InvokeAsync(new() { ["i"] = 42 }, TestContext.Current.CancellationToken); + Assert.Contains("Method4 Result 42", result?.ToString()); + return; + } + } + + Assert.Fail("Couldn't find target method"); + } + + [Fact] + public async Task EnumerateToolsAsync_FlowsJsonSerializerOptions() + { + JsonSerializerOptions options = new(JsonSerializerOptions.Default); + await using McpClient client = await CreateMcpClientForServer(); + bool hasTools = false; + + await foreach (var tool in client.EnumerateToolsAsync(options, TestContext.Current.CancellationToken)) + { + Assert.Same(options, tool.JsonSerializerOptions); + hasTools = true; + } + + foreach (var tool in await client.ListToolsAsync(options, TestContext.Current.CancellationToken)) + { + Assert.Same(options, tool.JsonSerializerOptions); + } + + Assert.True(hasTools); + } + + [Fact] + public async Task EnumerateToolsAsync_HonorsJsonSerializerOptions() + { + JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; + await using McpClient client = await CreateMcpClientForServer(); + + var tool = (await client.ListToolsAsync(emptyOptions, TestContext.Current.CancellationToken)).First(); + await Assert.ThrowsAsync(async () => await tool.InvokeAsync(new() { ["i"] = 42 }, TestContext.Current.CancellationToken)); + } + + [Fact] + public async Task SendRequestAsync_HonorsJsonSerializerOptions() + { + JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; + await using McpClient client = await CreateMcpClientForServer(); + + await Assert.ThrowsAsync(async () => await client.SendRequestAsync("Method4", new() { Name = "tool" }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken)); + } + + [Fact] + public async Task SendNotificationAsync_HonorsJsonSerializerOptions() + { + JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; + await using McpClient client = await CreateMcpClientForServer(); + + await Assert.ThrowsAsync(() => client.SendNotificationAsync("Method4", new { Value = 42 }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken)); + } + + [Fact] + public async Task GetPromptsAsync_HonorsJsonSerializerOptions() + { + JsonSerializerOptions emptyOptions = new() { TypeInfoResolver = JsonTypeInfoResolver.Combine() }; + await using McpClient client = await CreateMcpClientForServer(); + + await Assert.ThrowsAsync(async () => await client.GetPromptAsync("Prompt", new Dictionary { ["i"] = 42 }, emptyOptions, cancellationToken: TestContext.Current.CancellationToken)); + } + + [Fact] + public async Task WithName_ChangesToolName() + { + JsonSerializerOptions options = new(JsonSerializerOptions.Default); + await using McpClient client = await CreateMcpClientForServer(); + + var tool = (await client.ListToolsAsync(options, TestContext.Current.CancellationToken)).First(); + var originalName = tool.Name; + var renamedTool = tool.WithName("RenamedTool"); + + Assert.NotNull(renamedTool); + Assert.Equal("RenamedTool", renamedTool.Name); + Assert.Equal(originalName, tool?.Name); + } + + [Fact] + public async Task WithDescription_ChangesToolDescription() + { + JsonSerializerOptions options = new(JsonSerializerOptions.Default); + await using McpClient client = await CreateMcpClientForServer(); + var tool = (await client.ListToolsAsync(options, TestContext.Current.CancellationToken)).FirstOrDefault(); + var originalDescription = tool?.Description; + var redescribedTool = tool?.WithDescription("ToolWithNewDescription"); + Assert.NotNull(redescribedTool); + Assert.Equal("ToolWithNewDescription", redescribedTool.Description); + Assert.Equal(originalDescription, tool?.Description); + } + + [Fact] + public async Task WithProgress_ProgressReported() + { + const int TotalNotifications = 3; + int remainingProgress = TotalNotifications; + TaskCompletionSource allProgressReceived = new(TaskCreationOptions.RunContinuationsAsynchronously); + + Server.ServerOptions.Capabilities?.Tools?.ToolCollection?.Add(McpServerTool.Create(async (IProgress progress) => + { + for (int i = 0; i < TotalNotifications; i++) + { + progress.Report(new ProgressNotificationValue { Progress = i * 10, Message = "making progress" }); + await Task.Delay(1); + } + + await allProgressReceived.Task; + + return 42; + }, new() { Name = "ProgressReporter" })); + + await using McpClient client = await CreateMcpClientForServer(); + + var tool = (await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken)).First(t => t.Name == "ProgressReporter"); + + IProgress progress = new SynchronousProgress(value => + { + Assert.True(value.Progress >= 0 && value.Progress <= 100); + Assert.Equal("making progress", value.Message); + if (Interlocked.Decrement(ref remainingProgress) == 0) + { + allProgressReceived.SetResult(true); + } + }); + + Assert.Throws("progress", () => tool.WithProgress(null!)); + + var result = await tool.WithProgress(progress).InvokeAsync(cancellationToken: TestContext.Current.CancellationToken); + Assert.Contains("42", result?.ToString()); + } + + private sealed class SynchronousProgress(Action callback) : IProgress + { + public void Report(ProgressNotificationValue value) => callback(value); + } + + [Fact] + public async Task AsClientLoggerProvider_MessagesSentToClient() + { + await using McpClient client = await CreateMcpClientForServer(); + + ILoggerProvider loggerProvider = Server.AsClientLoggerProvider(); + Assert.Throws("categoryName", () => loggerProvider.CreateLogger(null!)); + + ILogger logger = loggerProvider.CreateLogger("TestLogger"); + Assert.NotNull(logger); + + Assert.Null(logger.BeginScope("")); + + Assert.Null(Server.LoggingLevel); + Assert.False(logger.IsEnabled(LogLevel.Trace)); + Assert.False(logger.IsEnabled(LogLevel.Debug)); + Assert.False(logger.IsEnabled(LogLevel.Information)); + Assert.False(logger.IsEnabled(LogLevel.Warning)); + Assert.False(logger.IsEnabled(LogLevel.Error)); + Assert.False(logger.IsEnabled(LogLevel.Critical)); + + await client.SetLoggingLevel(LoggingLevel.Info, TestContext.Current.CancellationToken); + + DateTime start = DateTime.UtcNow; + while (Server.LoggingLevel is null) + { + await Task.Delay(1, TestContext.Current.CancellationToken); + Assert.True(DateTime.UtcNow - start < TimeSpan.FromSeconds(10), "Timed out waiting for logging level to be set"); + } + + Assert.Equal(LoggingLevel.Info, Server.LoggingLevel); + Assert.False(logger.IsEnabled(LogLevel.Trace)); + Assert.False(logger.IsEnabled(LogLevel.Debug)); + Assert.True(logger.IsEnabled(LogLevel.Information)); + Assert.True(logger.IsEnabled(LogLevel.Warning)); + Assert.True(logger.IsEnabled(LogLevel.Error)); + Assert.True(logger.IsEnabled(LogLevel.Critical)); + + List data = []; + var channel = Channel.CreateUnbounded(); + + await using (client.RegisterNotificationHandler(NotificationMethods.LoggingMessageNotification, + (notification, cancellationToken) => + { + Assert.True(channel.Writer.TryWrite(JsonSerializer.Deserialize(notification.Params, McpJsonUtilities.DefaultOptions))); + return default; + })) + { + logger.LogTrace("Trace {Message}", "message"); + logger.LogDebug("Debug {Message}", "message"); + logger.LogInformation("Information {Message}", "message"); + logger.LogWarning("Warning {Message}", "message"); + logger.LogError("Error {Message}", "message"); + logger.LogCritical("Critical {Message}", "message"); + + for (int i = 0; i < 4; i++) + { + var m = await channel.Reader.ReadAsync(TestContext.Current.CancellationToken); + Assert.NotNull(m); + Assert.NotNull(m.Data); + + Assert.Equal("TestLogger", m.Logger); + + string ? s = JsonSerializer.Deserialize(m.Data.Value, McpJsonUtilities.DefaultOptions); + Assert.NotNull(s); + + if (s.Contains("Information")) + { + Assert.Equal(LoggingLevel.Info, m.Level); + } + else if (s.Contains("Warning")) + { + Assert.Equal(LoggingLevel.Warning, m.Level); + } + else if (s.Contains("Error")) + { + Assert.Equal(LoggingLevel.Error, m.Level); + } + else if (s.Contains("Critical")) + { + Assert.Equal(LoggingLevel.Critical, m.Level); + } + + data.Add(s); + } + + channel.Writer.Complete(); + } + + Assert.False(await channel.Reader.WaitToReadAsync(TestContext.Current.CancellationToken)); + Assert.Equal( + [ + "Critical message", + "Error message", + "Information message", + "Warning message", + ], + data.OrderBy(s => s)); + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs index ebc7171e2..6f625866a 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs @@ -41,8 +41,8 @@ public void Initialize(ILoggerFactory loggerFactory) _loggerFactory = loggerFactory; } - public Task CreateClientAsync(string clientId, McpClientOptions? clientOptions = null) => - McpClientFactory.CreateAsync(new StdioClientTransport(clientId switch + public Task CreateClientAsync(string clientId, McpClientOptions? clientOptions = null) => + McpClient.CreateAsync(new StdioClientTransport(clientId switch { "everything" => EverythingServerTransportOptions, "test_server" => TestServerTransportOptions, diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index 3e4361a57..211688419 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -471,7 +471,7 @@ public async Task CallTool_Stdio_MemoryServer() ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" } }; - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( new StdioClientTransport(stdioOptions), clientOptions, loggerFactory: LoggerFactory, @@ -495,7 +495,7 @@ public async Task CallTool_Stdio_MemoryServer() public async Task ListToolsAsync_UsingEverythingServer_ToolsAreProperlyCalled() { // Get the MCP client and tools from it. - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( new StdioClientTransport(_fixture.EverythingServerTransportOptions), cancellationToken: TestContext.Current.CancellationToken); var mappedTools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); @@ -527,7 +527,7 @@ public async Task SamplingViaChatClient_RequestResponseProperlyPropagated() var samplingHandler = new OpenAIClient(s_openAIKey).GetChatClient("gpt-4o-mini") .AsIChatClient() .CreateSamplingHandler(); - await using var client = await McpClientFactory.CreateAsync(new StdioClientTransport(_fixture.EverythingServerTransportOptions), new() + await using var client = await McpClient.CreateAsync(new StdioClientTransport(_fixture.EverythingServerTransportOptions), new() { Capabilities = new() { diff --git a/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs b/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs index d9b699b98..ff04c3b19 100644 --- a/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs +++ b/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs @@ -29,11 +29,11 @@ public ClientServerTestBase(ITestOutputHelper testOutputHelper) ServiceProvider = sc.BuildServiceProvider(validateScopes: true); _cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); - Server = ServiceProvider.GetRequiredService(); + Server = ServiceProvider.GetRequiredService(); _serverTask = Server.RunAsync(_cts.Token); } - protected IMcpServer Server { get; } + protected McpServer Server { get; } protected IServiceProvider ServiceProvider { get; } @@ -63,9 +63,9 @@ public async ValueTask DisposeAsync() Dispose(); } - protected async Task CreateMcpClientForServer(McpClientOptions? clientOptions = null) + protected async Task CreateMcpClientForServer(McpClientOptions? clientOptions = null) { - return await McpClientFactory.CreateAsync( + return await McpClient.CreateAsync( new StreamClientTransport( serverInput: _clientToServerPipe.Writer.AsStream(), _serverToClientPipe.Reader.AsStream(), diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsFilterTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsFilterTests.cs index 6a7d0044d..00e67c247 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsFilterTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsFilterTests.cs @@ -129,7 +129,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer [Fact] public async Task AddListResourceTemplatesFilter_Logs_When_ListResourceTemplates_Called() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); await client.ListResourceTemplatesAsync(cancellationToken: TestContext.Current.CancellationToken); @@ -141,7 +141,7 @@ public async Task AddListResourceTemplatesFilter_Logs_When_ListResourceTemplates [Fact] public async Task AddListToolsFilter_Logs_When_ListTools_Called() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); @@ -153,7 +153,7 @@ public async Task AddListToolsFilter_Logs_When_ListTools_Called() [Fact] public async Task AddCallToolFilter_Logs_When_CallTool_Called() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); await client.CallToolAsync("test_tool_method", cancellationToken: TestContext.Current.CancellationToken); @@ -165,7 +165,7 @@ public async Task AddCallToolFilter_Logs_When_CallTool_Called() [Fact] public async Task AddListPromptsFilter_Logs_When_ListPrompts_Called() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); await client.ListPromptsAsync(cancellationToken: TestContext.Current.CancellationToken); @@ -177,7 +177,7 @@ public async Task AddListPromptsFilter_Logs_When_ListPrompts_Called() [Fact] public async Task AddGetPromptFilter_Logs_When_GetPrompt_Called() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); await client.GetPromptAsync("test_prompt_method", cancellationToken: TestContext.Current.CancellationToken); @@ -189,7 +189,7 @@ public async Task AddGetPromptFilter_Logs_When_GetPrompt_Called() [Fact] public async Task AddListResourcesFilter_Logs_When_ListResources_Called() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); await client.ListResourcesAsync(cancellationToken: TestContext.Current.CancellationToken); @@ -201,7 +201,7 @@ public async Task AddListResourcesFilter_Logs_When_ListResources_Called() [Fact] public async Task AddReadResourceFilter_Logs_When_ReadResource_Called() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); await client.ReadResourceAsync("test://resource/123", cancellationToken: TestContext.Current.CancellationToken); @@ -213,7 +213,7 @@ public async Task AddReadResourceFilter_Logs_When_ReadResource_Called() [Fact] public async Task AddCompleteFilter_Logs_When_Complete_Called() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var reference = new PromptReference { Name = "test_prompt_method" }; await client.CompleteAsync(reference, "argument", "value", cancellationToken: TestContext.Current.CancellationToken); @@ -226,7 +226,7 @@ public async Task AddCompleteFilter_Logs_When_Complete_Called() [Fact] public async Task AddSubscribeToResourcesFilter_Logs_When_SubscribeToResources_Called() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); await client.SubscribeToResourceAsync("test://resource/123", cancellationToken: TestContext.Current.CancellationToken); @@ -238,7 +238,7 @@ public async Task AddSubscribeToResourcesFilter_Logs_When_SubscribeToResources_C [Fact] public async Task AddUnsubscribeFromResourcesFilter_Logs_When_UnsubscribeFromResources_Called() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); await client.UnsubscribeFromResourceAsync("test://resource/123", cancellationToken: TestContext.Current.CancellationToken); @@ -250,7 +250,7 @@ public async Task AddUnsubscribeFromResourcesFilter_Logs_When_UnsubscribeFromRes [Fact] public async Task AddSetLoggingLevelFilter_Logs_When_SetLoggingLevel_Called() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); await client.SetLoggingLevel(LoggingLevel.Info, cancellationToken: TestContext.Current.CancellationToken); @@ -262,7 +262,7 @@ public async Task AddSetLoggingLevelFilter_Logs_When_SetLoggingLevel_Called() [Fact] public async Task AddListToolsFilter_Multiple_Filters_Log_In_Expected_Order() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs index 38ef9ab5d..1aea56193 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs @@ -98,7 +98,7 @@ public void Adds_Prompts_To_Server() [Fact] public async Task Can_List_And_Call_Registered_Prompts() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); Assert.Equal(6, prompts.Count); @@ -127,7 +127,7 @@ public async Task Can_List_And_Call_Registered_Prompts() [Fact] public async Task Can_Be_Notified_Of_Prompt_Changes() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); Assert.Equal(6, prompts.Count); @@ -168,7 +168,7 @@ public async Task Can_Be_Notified_Of_Prompt_Changes() [Fact] public async Task TitleAttributeProperty_PropagatedToTitle() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var prompts = await client.ListPromptsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(prompts); @@ -182,7 +182,7 @@ public async Task TitleAttributeProperty_PropagatedToTitle() [Fact] public async Task Throws_When_Prompt_Fails() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); await Assert.ThrowsAsync(async () => await client.GetPromptAsync( nameof(SimplePrompts.ThrowsException), @@ -192,7 +192,7 @@ await Assert.ThrowsAsync(async () => await client.GetPromptAsync( [Fact] public async Task Throws_Exception_On_Unknown_Prompt() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var e = await Assert.ThrowsAsync(async () => await client.GetPromptAsync( "NotRegisteredPrompt", @@ -204,7 +204,7 @@ public async Task Throws_Exception_On_Unknown_Prompt() [Fact] public async Task Throws_Exception_Missing_Parameter() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var e = await Assert.ThrowsAsync(async () => await client.GetPromptAsync( "returns_chat_messages", @@ -238,7 +238,7 @@ public async Task WithPrompts_TargetInstance_UsesTarget() sc.AddMcpServer().WithPrompts(target); McpServerPrompt prompt = sc.BuildServiceProvider().GetServices().First(t => t.ProtocolPrompt.Name == "returns_string"); - var result = await prompt.GetAsync(new RequestContext(new Mock().Object, new JsonRpcRequest { Method = "test", Id = new RequestId("1") }) + var result = await prompt.GetAsync(new RequestContext(new Mock().Object, new JsonRpcRequest { Method = "test", Id = new RequestId("1") }) { Params = new GetPromptRequestParams { diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs index c95fd7671..47f2b224f 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs @@ -126,7 +126,7 @@ public void Adds_Resources_To_Server() [Fact] public async Task Can_List_And_Call_Registered_Resources() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); Assert.NotNull(client.ServerCapabilities.Resources); @@ -145,7 +145,7 @@ public async Task Can_List_And_Call_Registered_Resources() [Fact] public async Task Can_List_And_Call_Registered_ResourceTemplates() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var resources = await client.ListResourceTemplatesAsync(TestContext.Current.CancellationToken); Assert.Equal(3, resources.Count); @@ -162,7 +162,7 @@ public async Task Can_List_And_Call_Registered_ResourceTemplates() [Fact] public async Task Can_Be_Notified_Of_Resource_Changes() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var resources = await client.ListResourcesAsync(TestContext.Current.CancellationToken); Assert.Equal(5, resources.Count); @@ -203,7 +203,7 @@ public async Task Can_Be_Notified_Of_Resource_Changes() [Fact] public async Task TitleAttributeProperty_PropagatedToTitle() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var resources = await client.ListResourcesAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(resources); @@ -221,7 +221,7 @@ public async Task TitleAttributeProperty_PropagatedToTitle() [Fact] public async Task Throws_When_Resource_Fails() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( $"resource://mcp/{nameof(SimpleResources.ThrowsException)}", @@ -231,7 +231,7 @@ await Assert.ThrowsAsync(async () => await client.ReadResourceAsyn [Fact] public async Task Throws_Exception_On_Unknown_Resource() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var e = await Assert.ThrowsAsync(async () => await client.ReadResourceAsync( "test:///NotRegisteredResource", @@ -265,7 +265,7 @@ public async Task WithResources_TargetInstance_UsesTarget() sc.AddMcpServer().WithResources(target); McpServerResource resource = sc.BuildServiceProvider().GetServices().First(t => t.ProtocolResource?.Name == "returns_string"); - var result = await resource.ReadAsync(new RequestContext(new Mock().Object, new JsonRpcRequest { Method = "test", Id = new RequestId("1") }) + var result = await resource.ReadAsync(new RequestContext(new Mock().Object, new JsonRpcRequest { Method = "test", Id = new RequestId("1") }) { Params = new() { diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 6313480f3..a581a81df 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -127,7 +127,7 @@ public void Adds_Tools_To_Server() [Fact] public async Task Can_List_Registered_Tools() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.Equal(16, tools.Count); @@ -156,10 +156,10 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T var stdoutPipe = new Pipe(); await using var transport = new StreamServerTransport(stdinPipe.Reader.AsStream(), stdoutPipe.Writer.AsStream()); - await using var server = McpServerFactory.Create(transport, options, loggerFactory, ServiceProvider); + await using var server = McpServer.Create(transport, options, loggerFactory, ServiceProvider); var serverRunTask = server.RunAsync(TestContext.Current.CancellationToken); - await using (var client = await McpClientFactory.CreateAsync( + await using (var client = await McpClient.CreateAsync( new StreamClientTransport( serverInput: stdinPipe.Writer.AsStream(), serverOutput: stdoutPipe.Reader.AsStream(), @@ -191,7 +191,7 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T [Fact] public async Task Can_Be_Notified_Of_Tool_Changes() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.Equal(16, tools.Count); @@ -232,7 +232,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes() [Fact] public async Task Can_Call_Registered_Tool() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "echo", @@ -251,7 +251,7 @@ public async Task Can_Call_Registered_Tool() [Fact] public async Task Can_Call_Registered_Tool_With_Array_Result() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "echo_array", @@ -274,7 +274,7 @@ public async Task Can_Call_Registered_Tool_With_Array_Result() [Fact] public async Task Can_Call_Registered_Tool_With_Null_Result() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "return_null", @@ -288,7 +288,7 @@ public async Task Can_Call_Registered_Tool_With_Null_Result() [Fact] public async Task Can_Call_Registered_Tool_With_Json_Result() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "return_json", @@ -305,7 +305,7 @@ public async Task Can_Call_Registered_Tool_With_Json_Result() [Fact] public async Task Can_Call_Registered_Tool_With_Int_Result() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "return_integer", @@ -320,7 +320,7 @@ public async Task Can_Call_Registered_Tool_With_Int_Result() [Fact] public async Task Can_Call_Registered_Tool_And_Pass_ComplexType() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "echo_complex", @@ -337,7 +337,7 @@ public async Task Can_Call_Registered_Tool_And_Pass_ComplexType() [Fact] public async Task Can_Call_Registered_Tool_With_Instance_Method() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); string[][] parts = new string[2][]; for (int i = 0; i < 2; i++) @@ -366,7 +366,7 @@ public async Task Can_Call_Registered_Tool_With_Instance_Method() [Fact] public async Task Returns_IsError_Content_And_Logs_Error_When_Tool_Fails() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "throw_exception", @@ -386,7 +386,7 @@ public async Task Returns_IsError_Content_And_Logs_Error_When_Tool_Fails() [Fact] public async Task Throws_Exception_On_Unknown_Tool() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var e = await Assert.ThrowsAsync(async () => await client.CallToolAsync( "NotRegisteredTool", @@ -398,7 +398,7 @@ public async Task Throws_Exception_On_Unknown_Tool() [Fact] public async Task Returns_IsError_Missing_Parameter() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var result = await client.CallToolAsync( "echo", @@ -525,7 +525,7 @@ public async Task WithTools_TargetInstance_UsesTarget() sc.AddMcpServer().WithTools(target, BuilderToolsJsonContext.Default.Options); McpServerTool tool = sc.BuildServiceProvider().GetServices().First(t => t.ProtocolTool.Name == "get_ctor_parameter"); - var result = await tool.InvokeAsync(new RequestContext(new Mock().Object, new JsonRpcRequest { Method = "test", Id = new RequestId("1") }), TestContext.Current.CancellationToken); + var result = await tool.InvokeAsync(new RequestContext(new Mock().Object, new JsonRpcRequest { Method = "test", Id = new RequestId("1") }), TestContext.Current.CancellationToken); Assert.Equal(target.GetCtorParameter(), (result.Content[0] as TextContentBlock)?.Text); } @@ -557,7 +557,7 @@ public IEnumerator GetEnumerator() [Fact] public async Task Recognizes_Parameter_Types() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); @@ -632,7 +632,7 @@ public void Create_ExtractsToolAnnotations_SomeSet() [Fact] public async Task TitleAttributeProperty_PropagatedToTitle() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(tools); @@ -648,7 +648,7 @@ public async Task TitleAttributeProperty_PropagatedToTitle() [Fact] public async Task HandlesIProgressParameter() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(tools); @@ -702,7 +702,7 @@ public async Task HandlesIProgressParameter() [Fact] public async Task CancellationNotificationsPropagateToToolTokens() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(tools); diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs index b940c1c7c..5ddc3c54a 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerScopedTests.cs @@ -22,7 +22,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer [Fact] public async Task InjectScopedServiceAsArgument() { - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(McpServerScopedTestsJsonContext.Default.Options, TestContext.Current.CancellationToken); var tool = tools.First(t => t.Name == "echo_complex"); diff --git a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs index 116c62a15..5ad30d282 100644 --- a/tests/ModelContextProtocol.Tests/DiagnosticTests.cs +++ b/tests/ModelContextProtocol.Tests/DiagnosticTests.cs @@ -128,7 +128,7 @@ await RunConnected(async (client, server) => Assert.Equal("-32602", doesNotExistToolClient.Tags.Single(t => t.Key == "rpc.jsonrpc.error_code").Value); } - private static async Task RunConnected(Func action, List clientToServerLog) + private static async Task RunConnected(Func action, List clientToServerLog) { Pipe clientToServerPipe = new(), serverToClientPipe = new(); StreamServerTransport serverTransport = new(clientToServerPipe.Reader.AsStream(), serverToClientPipe.Writer.AsStream()); @@ -137,7 +137,7 @@ private static async Task RunConnected(Func action Task serverTask; - await using (IMcpServer server = McpServerFactory.Create(serverTransport, new() + await using (McpServer server = McpServer.Create(serverTransport, new() { Capabilities = new() { @@ -153,7 +153,7 @@ private static async Task RunConnected(Func action { serverTask = server.RunAsync(TestContext.Current.CancellationToken); - await using (IMcpClient client = await McpClientFactory.CreateAsync( + await using (McpClient client = await McpClient.CreateAsync( clientTransport, cancellationToken: TestContext.Current.CancellationToken)) { diff --git a/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs b/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs index ffd95076f..842371f88 100644 --- a/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs +++ b/tests/ModelContextProtocol.Tests/DockerEverythingServerTests.cs @@ -36,15 +36,15 @@ public async Task ConnectAndReceiveMessage_EverythingServerWithSse() ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" } }; - var defaultConfig = new SseClientTransportOptions + var defaultConfig = new HttpClientTransportOptions { Endpoint = new Uri($"http://localhost:{port}/sse"), Name = "Everything", }; // Create client and run tests - await using var client = await McpClientFactory.CreateAsync( - new SseClientTransport(defaultConfig), + await using var client = await McpClient.CreateAsync( + new HttpClientTransport(defaultConfig), defaultOptions, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); @@ -63,7 +63,7 @@ public async Task Sampling_Sse_EverythingServer() await using var fixture = new EverythingSseServerFixture(port); await fixture.StartAsync(); - var defaultConfig = new SseClientTransportOptions + var defaultConfig = new HttpClientTransportOptions { Endpoint = new Uri($"http://localhost:{port}/sse"), Name = "Everything", @@ -90,8 +90,8 @@ public async Task Sampling_Sse_EverythingServer() }, }; - await using var client = await McpClientFactory.CreateAsync( - new SseClientTransport(defaultConfig), + await using var client = await McpClient.CreateAsync( + new HttpClientTransport(defaultConfig), defaultOptions, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.Tests/McpEndpointExtensionsTests.cs b/tests/ModelContextProtocol.Tests/McpEndpointExtensionsTests.cs new file mode 100644 index 000000000..613c703c3 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/McpEndpointExtensionsTests.cs @@ -0,0 +1,118 @@ +using ModelContextProtocol.Protocol; +using Moq; +using System.Text.Json; + +namespace ModelContextProtocol.Tests; + +#pragma warning disable CS0618 // Type or member is obsolete + +public class McpEndpointExtensionsTests +{ + [Fact] + public async Task SendRequestAsync_Generic_Throws_When_Not_McpSession() + { + var endpoint = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await McpEndpointExtensions.SendRequestAsync( + endpoint, "method", "param", cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpServer.SendRequestAsync' instead", ex.Message); + } + + [Fact] + public async Task SendNotificationAsync_Parameterless_Throws_When_Not_McpSession() + { + var endpoint = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await McpEndpointExtensions.SendNotificationAsync( + endpoint, "notify", cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpServer.SendNotificationAsync' instead", ex.Message); + } + + [Fact] + public async Task SendNotificationAsync_Generic_Throws_When_Not_McpSession() + { + var endpoint = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await McpEndpointExtensions.SendNotificationAsync( + endpoint, "notify", "payload", cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpServer.SendNotificationAsync' instead", ex.Message); + } + + [Fact] + public async Task NotifyProgressAsync_Throws_When_Not_McpSession() + { + var endpoint = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await McpEndpointExtensions.NotifyProgressAsync( + endpoint, new ProgressToken("t1"), new ProgressNotificationValue { Progress = 0.5f }, cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpServer.NotifyProgressAsync' instead", ex.Message); + } + + [Fact] + public async Task SendRequestAsync_Generic_Forwards_To_McpSession_SendRequestAsync() + { + var mockSession = new Mock { CallBase = true }; + + mockSession + .Setup(s => s.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(42, McpJsonUtilities.DefaultOptions), + }); + + IMcpEndpoint endpoint = mockSession.Object; + + var result = await endpoint.SendRequestAsync("method", "param", cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal(42, result); + mockSession.Verify(s => s.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task SendNotificationAsync_Parameterless_Forwards_To_McpSession_SendMessageAsync() + { + var mockSession = new Mock { CallBase = true }; + + mockSession + .Setup(s => s.SendMessageAsync(It.IsAny(), It.IsAny())) + .Returns(Task.CompletedTask); + + IMcpEndpoint endpoint = mockSession.Object; + + await endpoint.SendNotificationAsync("notify", cancellationToken: TestContext.Current.CancellationToken); + + mockSession.Verify(s => s.SendMessageAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task SendNotificationAsync_Generic_Forwards_To_McpSession_SendMessageAsync() + { + var mockSession = new Mock { CallBase = true }; + + mockSession + .Setup(s => s.SendMessageAsync(It.IsAny(), It.IsAny())) + .Returns(Task.CompletedTask); + + IMcpEndpoint endpoint = mockSession.Object; + + await endpoint.SendNotificationAsync("notify", "payload", cancellationToken: TestContext.Current.CancellationToken); + + mockSession.Verify(s => s.SendMessageAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task NotifyProgressAsync_Forwards_To_McpSession_SendMessageAsync() + { + var mockSession = new Mock { CallBase = true }; + + mockSession + .Setup(s => s.SendMessageAsync(It.IsAny(), It.IsAny())) + .Returns(Task.CompletedTask); + + IMcpEndpoint endpoint = mockSession.Object; + + await endpoint.NotifyProgressAsync(new ProgressToken("progress-token"), new ProgressNotificationValue { Progress = 1 }, cancellationToken: TestContext.Current.CancellationToken); + + mockSession.Verify(s => s.SendMessageAsync(It.IsAny(), It.IsAny()), Times.Once); + } +} \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Protocol/ElicitationTests.cs b/tests/ModelContextProtocol.Tests/Protocol/ElicitationTests.cs index f44743916..22fd69c17 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/ElicitationTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/ElicitationTests.cs @@ -67,7 +67,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer [Fact] public async Task Can_Elicit_Information() { - await using IMcpClient client = await CreateMcpClientForServer(new McpClientOptions + await using McpClient client = await CreateMcpClientForServer(new McpClientOptions { Capabilities = new() { diff --git a/tests/ModelContextProtocol.Tests/Protocol/ElicitationTypedTests.cs b/tests/ModelContextProtocol.Tests/Protocol/ElicitationTypedTests.cs index 11c7995cc..7ac39d591 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/ElicitationTypedTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/ElicitationTypedTests.cs @@ -1,7 +1,6 @@ using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; -using ModelContextProtocol.Server; using System.Text.Json; using System.Text.Json.Serialization; @@ -102,7 +101,7 @@ await request.Server.ElicitAsync( [Fact] public async Task Can_Elicit_Typed_Information() { - await using IMcpClient client = await CreateMcpClientForServer(new McpClientOptions + await using McpClient client = await CreateMcpClientForServer(new McpClientOptions { Capabilities = new() { @@ -198,7 +197,7 @@ public async Task Can_Elicit_Typed_Information() [Fact] public async Task Elicit_Typed_Respects_NamingPolicy() { - await using IMcpClient client = await CreateMcpClientForServer(new McpClientOptions + await using McpClient client = await CreateMcpClientForServer(new McpClientOptions { Capabilities = new() { @@ -242,7 +241,7 @@ public async Task Elicit_Typed_Respects_NamingPolicy() [Fact] public async Task Elicit_Typed_With_Unsupported_Property_Type_Throws() { - await using IMcpClient client = await CreateMcpClientForServer(new McpClientOptions + await using McpClient client = await CreateMcpClientForServer(new McpClientOptions { Capabilities = new() { @@ -267,7 +266,7 @@ public async Task Elicit_Typed_With_Unsupported_Property_Type_Throws() [Fact] public async Task Elicit_Typed_With_Nullable_Property_Type_Throws() { - await using IMcpClient client = await CreateMcpClientForServer(new McpClientOptions + await using McpClient client = await CreateMcpClientForServer(new McpClientOptions { Capabilities = new() { @@ -290,7 +289,7 @@ public async Task Elicit_Typed_With_Nullable_Property_Type_Throws() [Fact] public async Task Elicit_Typed_With_NonObject_Generic_Type_Throws() { - await using IMcpClient client = await CreateMcpClientForServer(new McpClientOptions + await using McpClient client = await CreateMcpClientForServer(new McpClientOptions { Capabilities = new() { diff --git a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs index 0d18667e9..25470650e 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs @@ -13,7 +13,7 @@ public NotificationHandlerTests(ITestOutputHelper testOutputHelper) public async Task RegistrationsAreRemovedWhenDisposed() { const string NotificationName = "somethingsomething"; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); const int Iterations = 10; @@ -40,7 +40,7 @@ public async Task RegistrationsAreRemovedWhenDisposed() public async Task MultipleRegistrationsResultInMultipleCallbacks() { const string NotificationName = "somethingsomething"; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); const int RegistrationCount = 10; @@ -80,7 +80,7 @@ public async Task MultipleRegistrationsResultInMultipleCallbacks() public async Task MultipleHandlersRunEvenIfOneThrows() { const string NotificationName = "somethingsomething"; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); const int RegistrationCount = 10; @@ -122,7 +122,7 @@ public async Task MultipleHandlersRunEvenIfOneThrows() public async Task DisposeAsyncDoesNotCompleteWhileNotificationHandlerRuns(int numberOfDisposals) { const string NotificationName = "somethingsomething"; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var handlerRunning = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var releaseHandler = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); @@ -163,7 +163,7 @@ public async Task DisposeAsyncDoesNotCompleteWhileNotificationHandlerRuns(int nu public async Task DisposeAsyncCompletesImmediatelyWhenInvokedFromHandler(int numberOfDisposals) { const string NotificationName = "somethingsomething"; - await using IMcpClient client = await CreateMcpClientForServer(); + await using McpClient client = await CreateMcpClientForServer(); var handlerRunning = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var releaseHandler = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerExtensionsTests.cs new file mode 100644 index 000000000..5569f993c --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/McpServerExtensionsTests.cs @@ -0,0 +1,195 @@ +using Microsoft.Extensions.AI; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using Moq; +using System.Text.Json; + +namespace ModelContextProtocol.Tests.Server; + +#pragma warning disable CS0618 // Type or member is obsolete + +public class McpServerExtensionsTests +{ + [Fact] + public async Task SampleAsync_Request_Throws_When_Not_McpServer() + { + var server = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await server.SampleAsync( + new CreateMessageRequestParams { Messages = [new SamplingMessage { Role = Role.User, Content = new TextContentBlock { Text = "hi" } }] }, + TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpServer.SampleAsync' instead", ex.Message); + } + + [Fact] + public async Task SampleAsync_Messages_Throws_When_Not_McpServer() + { + var server = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await server.SampleAsync( + [new ChatMessage(ChatRole.User, "hi")], cancellationToken: TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpServer.SampleAsync' instead", ex.Message); + } + + [Fact] + public void AsSamplingChatClient_Throws_When_Not_McpServer() + { + var server = new Mock(MockBehavior.Strict).Object; + + var ex = Assert.Throws(server.AsSamplingChatClient); + Assert.Contains("Prefer using 'McpServer.AsSamplingChatClient' instead", ex.Message); + } + + [Fact] + public void AsClientLoggerProvider_Throws_When_Not_McpServer() + { + var server = new Mock(MockBehavior.Strict).Object; + + var ex = Assert.Throws(server.AsClientLoggerProvider); + Assert.Contains("Prefer using 'McpServer.AsClientLoggerProvider' instead", ex.Message); + } + + [Fact] + public async Task RequestRootsAsync_Throws_When_Not_McpServer() + { + var server = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await server.RequestRootsAsync( + new ListRootsRequestParams(), TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpServer.RequestRootsAsync' instead", ex.Message); + } + + [Fact] + public async Task ElicitAsync_Throws_When_Not_McpServer() + { + var server = new Mock(MockBehavior.Strict).Object; + + var ex = await Assert.ThrowsAsync(async () => await server.ElicitAsync( + new ElicitRequestParams { Message = "hello" }, TestContext.Current.CancellationToken)); + Assert.Contains("Prefer using 'McpServer.ElicitAsync' instead", ex.Message); + } + + [Fact] + public async Task SampleAsync_Request_Forwards_To_McpServer_SendRequestAsync() + { + var mockServer = new Mock { CallBase = true }; + + var resultPayload = new CreateMessageResult + { + Content = new TextContentBlock { Text = "resp" }, + Model = "test-model", + Role = Role.Assistant, + StopReason = "endTurn", + }; + + mockServer + .Setup(s => s.ClientCapabilities) + .Returns(new ClientCapabilities() { Sampling = new() }); + + mockServer + .Setup(s => s.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpServer server = mockServer.Object; + + var result = await server.SampleAsync(new CreateMessageRequestParams + { + Messages = [new SamplingMessage { Role = Role.User, Content = new TextContentBlock { Text = "hi" } }] + }, TestContext.Current.CancellationToken); + + Assert.Equal("test-model", result.Model); + Assert.Equal(Role.Assistant, result.Role); + Assert.Equal("resp", Assert.IsType(result.Content).Text); + mockServer.Verify(s => s.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task SampleAsync_Messages_Forwards_To_McpServer_SendRequestAsync() + { + var mockServer = new Mock { CallBase = true }; + + var resultPayload = new CreateMessageResult + { + Content = new TextContentBlock { Text = "resp" }, + Model = "test-model", + Role = Role.Assistant, + StopReason = "endTurn", + }; + + mockServer + .Setup(s => s.ClientCapabilities) + .Returns(new ClientCapabilities() { Sampling = new() }); + + mockServer + .Setup(s => s.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpServer server = mockServer.Object; + + var chatResponse = await server.SampleAsync([new ChatMessage(ChatRole.User, "hi")], cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal("test-model", chatResponse.ModelId); + var last = chatResponse.Messages.Last(); + Assert.Equal(ChatRole.Assistant, last.Role); + Assert.Equal("resp", last.Text); + mockServer.Verify(s => s.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task RequestRootsAsync_Forwards_To_McpServer_SendRequestAsync() + { + var mockServer = new Mock { CallBase = true }; + + var resultPayload = new ListRootsResult { Roots = [new Root { Uri = "root://a" }] }; + + mockServer + .Setup(s => s.ClientCapabilities) + .Returns(new ClientCapabilities() { Roots = new() }); + + mockServer + .Setup(s => s.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpServer server = mockServer.Object; + + var result = await server.RequestRootsAsync(new ListRootsRequestParams(), TestContext.Current.CancellationToken); + + Assert.Equal("root://a", result.Roots[0].Uri); + mockServer.Verify(s => s.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + public async Task ElicitAsync_Forwards_To_McpServer_SendRequestAsync() + { + var mockServer = new Mock { CallBase = true }; + + var resultPayload = new ElicitResult { Action = "accept" }; + + mockServer + .Setup(s => s.ClientCapabilities) + .Returns(new ClientCapabilities() { Elicitation = new() }); + + mockServer + .Setup(s => s.SendRequestAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new JsonRpcResponse + { + Result = JsonSerializer.SerializeToNode(resultPayload, McpJsonUtilities.DefaultOptions), + }); + + IMcpServer server = mockServer.Object; + + var result = await server.ElicitAsync(new ElicitRequestParams { Message = "hi" }, TestContext.Current.CancellationToken); + + Assert.Equal("accept", result.Action); + mockServer.Verify(s => s.SendRequestAsync(It.IsAny(), It.IsAny()), Times.Once); + } +} diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs deleted file mode 100644 index 034a30bd7..000000000 --- a/tests/ModelContextProtocol.Tests/Server/McpServerFactoryTests.cs +++ /dev/null @@ -1,45 +0,0 @@ -using ModelContextProtocol.Server; -using ModelContextProtocol.Tests.Utils; - -namespace ModelContextProtocol.Tests.Server; - -public class McpServerFactoryTests : LoggedTest -{ - private readonly McpServerOptions _options; - - public McpServerFactoryTests(ITestOutputHelper testOutputHelper) - : base(testOutputHelper) - { - _options = new McpServerOptions - { - ProtocolVersion = "1.0", - InitializationTimeout = TimeSpan.FromSeconds(30) - }; - } - - [Fact] - public async Task Create_Should_Initialize_With_Valid_Parameters() - { - // Arrange & Act - await using var transport = new TestServerTransport(); - await using IMcpServer server = McpServerFactory.Create(transport, _options, LoggerFactory); - - // Assert - Assert.NotNull(server); - } - - [Fact] - public void Create_Throws_For_Null_ServerTransport() - { - // Arrange, Act & Assert - Assert.Throws("transport", () => McpServerFactory.Create(null!, _options, LoggerFactory)); - } - - [Fact] - public async Task Create_Throws_For_Null_Options() - { - // Arrange, Act & Assert - await using var transport = new TestServerTransport(); - Assert.Throws("serverOptions", () => McpServerFactory.Create(transport, null!, LoggerFactory)); - } -} diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs index b2e748730..be271a686 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerLoggingLevelTests.cs @@ -25,7 +25,7 @@ public void CanCreateServerWithLoggingLevelHandler() var provider = services.BuildServiceProvider(); - provider.GetRequiredService(); + provider.GetRequiredService(); } [Fact] @@ -39,7 +39,7 @@ public void AddingLoggingLevelHandlerSetsLoggingCapability() var provider = services.BuildServiceProvider(); - var server = provider.GetRequiredService(); + var server = provider.GetRequiredService(); Assert.NotNull(server.ServerOptions.Capabilities?.Logging); Assert.NotNull(server.ServerOptions.Capabilities.Logging.SetLoggingLevelHandler); @@ -52,7 +52,7 @@ public void ServerWithoutCallingLoggingLevelHandlerDoesNotSetLoggingCapability() services.AddMcpServer() .WithStdioServerTransport(); var provider = services.BuildServiceProvider(); - var server = provider.GetRequiredService(); + var server = provider.GetRequiredService(); Assert.Null(server.ServerOptions.Capabilities?.Logging); } } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs index 307e086a3..41c26f405 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs @@ -1,11 +1,9 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Primitives; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; using Moq; using System.ComponentModel; -using System.Diagnostics; using System.Reflection; using System.Runtime.InteropServices; using System.Text.Json; @@ -43,11 +41,11 @@ public void Create_InvalidArgs_Throws() } [Fact] - public async Task SupportsIMcpServer() + public async Task SupportsMcpServer() { - Mock mockServer = new(); + Mock mockServer = new(); - McpServerPrompt prompt = McpServerPrompt.Create((IMcpServer server) => + McpServerPrompt prompt = McpServerPrompt.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new ChatMessage(ChatRole.User, "Hello"); @@ -73,7 +71,7 @@ public async Task SupportsCtorInjection() sc.AddSingleton(expectedMyService); IServiceProvider services = sc.BuildServiceProvider(); - Mock mockServer = new(); + Mock mockServer = new(); mockServer.SetupGet(s => s.Services).Returns(services); MethodInfo? testMethod = typeof(HasCtorWithSpecialParameters).GetMethod(nameof(HasCtorWithSpecialParameters.TestPrompt)); @@ -96,11 +94,11 @@ public async Task SupportsCtorInjection() private sealed class HasCtorWithSpecialParameters { private readonly MyService _ms; - private readonly IMcpServer _server; + private readonly McpServer _server; private readonly RequestContext _request; private readonly IProgress _progress; - public HasCtorWithSpecialParameters(MyService ms, IMcpServer server, RequestContext request, IProgress progress) + public HasCtorWithSpecialParameters(MyService ms, McpServer server, RequestContext request, IProgress progress) { Assert.NotNull(ms); Assert.NotNull(server); @@ -135,11 +133,11 @@ public async Task SupportsServiceFromDI() Assert.DoesNotContain("actualMyService", prompt.ProtocolPrompt.Arguments?.Select(a => a.Name) ?? []); await Assert.ThrowsAnyAsync(async () => await prompt.GetAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken)); var result = await prompt.GetAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Services = services }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Services = services }, TestContext.Current.CancellationToken); Assert.Equal("Hello", Assert.IsType(result.Messages[0].Content).Text); } @@ -160,7 +158,7 @@ public async Task SupportsOptionalServiceFromDI() }, new() { Services = services }); var result = await prompt.GetAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("Hello", Assert.IsType(result.Messages[0].Content).Text); } @@ -173,7 +171,7 @@ public async Task SupportsDisposingInstantiatedDisposableTargets() _ => new DisposablePromptType()); var result = await prompt1.GetAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("disposals:1", Assert.IsType(result.Messages[0].Content).Text); } @@ -186,7 +184,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableTargets() _ => new AsyncDisposablePromptType()); var result = await prompt1.GetAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("asyncDisposals:1", Assert.IsType(result.Messages[0].Content).Text); } @@ -199,7 +197,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableAndDisposable _ => new AsyncDisposableAndDisposablePromptType()); var result = await prompt1.GetAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("disposals:0, asyncDisposals:1", Assert.IsType(result.Messages[0].Content).Text); } @@ -215,7 +213,7 @@ public async Task CanReturnGetPromptResult() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Same(expected, actual); @@ -232,7 +230,7 @@ public async Task CanReturnText() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -258,7 +256,7 @@ public async Task CanReturnPromptMessage() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -290,7 +288,7 @@ public async Task CanReturnPromptMessages() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -317,7 +315,7 @@ public async Task CanReturnChatMessage() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -349,7 +347,7 @@ public async Task CanReturnChatMessages() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -370,7 +368,7 @@ public async Task ThrowsForNullReturn() }); await Assert.ThrowsAsync(async () => await prompt.GetAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken)); } @@ -383,7 +381,7 @@ public async Task ThrowsForUnexpectedTypeReturn() }); await Assert.ThrowsAsync(async () => await prompt.GetAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken)); } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs index df0b65372..f7f2a7742 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs @@ -60,7 +60,7 @@ public void CanCreateServerWithResource() var provider = services.BuildServiceProvider(); - provider.GetRequiredService(); + provider.GetRequiredService(); } @@ -96,7 +96,7 @@ public void CanCreateServerWithResourceTemplates() var provider = services.BuildServiceProvider(); - provider.GetRequiredService(); + provider.GetRequiredService(); } [Fact] @@ -119,7 +119,7 @@ public void CreatingReadHandlerWithNoListHandlerSucceeds() }); var sp = services.BuildServiceProvider(); - sp.GetRequiredService(); + sp.GetRequiredService(); } [Fact] @@ -143,7 +143,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() McpServerResource t; ReadResourceResult? result; - IMcpServer server = new Mock().Object; + McpServer server = new Mock().Object; t = McpServerResource.Create(() => "42", new() { Name = Name }); Assert.Equal("resource://mcp/Hello", t.ProtocolResourceTemplate.UriTemplate); @@ -153,7 +153,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); - t = McpServerResource.Create((IMcpServer server) => "42", new() { Name = Name }); + t = McpServerResource.Create((McpServer server) => "42", new() { Name = Name }); Assert.Equal("resource://mcp/Hello", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello" } }, @@ -287,7 +287,7 @@ public async Task UriTemplate_NonMatchingUri_ReturnsNull(string uri) McpServerResource t = McpServerResource.Create((string arg1) => arg1, new() { Name = "Hello" }); Assert.Equal("resource://mcp/Hello{?arg1}", t.ProtocolResourceTemplate.UriTemplate); Assert.Null(await t.ReadAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = uri } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = uri } }, TestContext.Current.CancellationToken)); } @@ -298,7 +298,7 @@ public async Task UriTemplate_IsHostCaseInsensitive(string actualUri, string que { McpServerResource t = McpServerResource.Create(() => "resource", new() { UriTemplate = actualUri }); Assert.NotNull(await t.ReadAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = queriedUri } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = queriedUri } }, TestContext.Current.CancellationToken)); } @@ -327,7 +327,7 @@ public async Task UriTemplate_MissingParameter_Throws(string uri) McpServerResource t = McpServerResource.Create((string arg1, int arg2) => arg1, new() { Name = "Hello" }); Assert.Equal("resource://mcp/Hello{?arg1,arg2}", t.ProtocolResourceTemplate.UriTemplate); await Assert.ThrowsAsync(async () => await t.ReadAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = uri } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = uri } }, TestContext.Current.CancellationToken)); } @@ -340,36 +340,36 @@ public async Task UriTemplate_MissingOptionalParameter_Succeeds() ReadResourceResult? result; result = await t.ReadAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello" } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("", ((TextResourceContents)result.Contents[0]).Text); result = await t.ReadAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello?arg1=first" } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello?arg1=first" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("first", ((TextResourceContents)result.Contents[0]).Text); result = await t.ReadAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello?arg2=42" } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello?arg2=42" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); result = await t.ReadAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello?arg1=first&arg2=42" } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello?arg1=first&arg2=42" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("first42", ((TextResourceContents)result.Contents[0]).Text); } [Fact] - public async Task SupportsIMcpServer() + public async Task SupportsMcpServer() { - Mock mockServer = new(); + Mock mockServer = new(); - McpServerResource resource = McpServerResource.Create((IMcpServer server) => + McpServerResource resource = McpServerResource.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return "42"; @@ -391,7 +391,7 @@ public async Task SupportsCtorInjection() sc.AddSingleton(expectedMyService); IServiceProvider services = sc.BuildServiceProvider(); - Mock mockServer = new(); + Mock mockServer = new(); mockServer.SetupGet(s => s.Services).Returns(services); MethodInfo? testMethod = typeof(HasCtorWithSpecialParameters).GetMethod(nameof(HasCtorWithSpecialParameters.TestResource)); @@ -414,11 +414,11 @@ public async Task SupportsCtorInjection() private sealed class HasCtorWithSpecialParameters { private readonly MyService _ms; - private readonly IMcpServer _server; + private readonly McpServer _server; private readonly RequestContext _request; private readonly IProgress _progress; - public HasCtorWithSpecialParameters(MyService ms, IMcpServer server, RequestContext request, IProgress progress) + public HasCtorWithSpecialParameters(MyService ms, McpServer server, RequestContext request, IProgress progress) { Assert.NotNull(ms); Assert.NotNull(server); @@ -477,7 +477,7 @@ public async Task SupportsServiceFromDI(ServiceLifetime injectedArgumentLifetime McpServerResource resource = services.GetRequiredService(); - Mock mockServer = new(); + Mock mockServer = new(); await Assert.ThrowsAnyAsync(async () => await resource.ReadAsync( new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, @@ -506,7 +506,7 @@ public async Task SupportsOptionalServiceFromDI() }, new() { Services = services, Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); @@ -522,7 +522,7 @@ public async Task SupportsDisposingInstantiatedDisposableTargets() _ => new DisposableResourceType()); var result = await resource1.ReadAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "test://static/resource/instanceMethod" } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "test://static/resource/instanceMethod" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("0", ((TextResourceContents)result.Contents[0]).Text); @@ -533,8 +533,8 @@ public async Task SupportsDisposingInstantiatedDisposableTargets() [Fact] public async Task CanReturnReadResult() { - Mock mockServer = new(); - McpServerResource resource = McpServerResource.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerResource resource = McpServerResource.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new ReadResourceResult { Contents = new List { new TextResourceContents { Text = "hello" } } }; @@ -550,8 +550,8 @@ public async Task CanReturnReadResult() [Fact] public async Task CanReturnResourceContents() { - Mock mockServer = new(); - McpServerResource resource = McpServerResource.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerResource resource = McpServerResource.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new TextResourceContents { Text = "hello" }; @@ -567,8 +567,8 @@ public async Task CanReturnResourceContents() [Fact] public async Task CanReturnCollectionOfResourceContents() { - Mock mockServer = new(); - McpServerResource resource = McpServerResource.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerResource resource = McpServerResource.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return (IList) @@ -589,8 +589,8 @@ public async Task CanReturnCollectionOfResourceContents() [Fact] public async Task CanReturnString() { - Mock mockServer = new(); - McpServerResource resource = McpServerResource.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerResource resource = McpServerResource.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return "42"; @@ -606,8 +606,8 @@ public async Task CanReturnString() [Fact] public async Task CanReturnCollectionOfStrings() { - Mock mockServer = new(); - McpServerResource resource = McpServerResource.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerResource resource = McpServerResource.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new List { "42", "43" }; @@ -624,8 +624,8 @@ public async Task CanReturnCollectionOfStrings() [Fact] public async Task CanReturnDataContent() { - Mock mockServer = new(); - McpServerResource resource = McpServerResource.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerResource resource = McpServerResource.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new DataContent(new byte[] { 0, 1, 2 }, "application/octet-stream"); @@ -642,8 +642,8 @@ public async Task CanReturnDataContent() [Fact] public async Task CanReturnCollectionOfAIContent() { - Mock mockServer = new(); - McpServerResource resource = McpServerResource.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerResource resource = McpServerResource.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new List diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index 6750b2cad..61cda7015 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -32,12 +32,38 @@ private static McpServerOptions CreateOptions(ServerCapabilities? capabilities = }; } + [Fact] + public async Task Create_Should_Initialize_With_Valid_Parameters() + { + // Arrange & Act + await using var transport = new TestServerTransport(); + await using McpServer server = McpServer.Create(transport, _options, LoggerFactory); + + // Assert + Assert.NotNull(server); + } + + [Fact] + public void Create_Throws_For_Null_ServerTransport() + { + // Arrange, Act & Assert + Assert.Throws("transport", () => McpServer.Create(null!, _options, LoggerFactory)); + } + + [Fact] + public async Task Create_Throws_For_Null_Options() + { + // Arrange, Act & Assert + await using var transport = new TestServerTransport(); + Assert.Throws("serverOptions", () => McpServer.Create(transport, null!, LoggerFactory)); + } + [Fact] public async Task Constructor_Should_Initialize_With_Valid_Parameters() { // Arrange & Act await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); // Assert Assert.NotNull(server); @@ -47,7 +73,7 @@ public async Task Constructor_Should_Initialize_With_Valid_Parameters() public void Constructor_Throws_For_Null_Transport() { // Arrange, Act & Assert - Assert.Throws(() => McpServerFactory.Create(null!, _options, LoggerFactory)); + Assert.Throws(() => McpServer.Create(null!, _options, LoggerFactory)); } [Fact] @@ -55,7 +81,7 @@ public async Task Constructor_Throws_For_Null_Options() { // Arrange, Act & Assert await using var transport = new TestServerTransport(); - Assert.Throws(() => McpServerFactory.Create(transport, null!, LoggerFactory)); + Assert.Throws(() => McpServer.Create(transport, null!, LoggerFactory)); } [Fact] @@ -63,7 +89,7 @@ public async Task Constructor_Does_Not_Throw_For_Null_Logger() { // Arrange & Act await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, null); + await using var server = McpServer.Create(transport, _options, null); // Assert Assert.NotNull(server); @@ -74,7 +100,7 @@ public async Task Constructor_Does_Not_Throw_For_Null_ServiceProvider() { // Arrange & Act await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory, null); + await using var server = McpServer.Create(transport, _options, LoggerFactory, null); // Assert Assert.NotNull(server); @@ -85,7 +111,7 @@ public async Task RunAsync_Should_Throw_InvalidOperationException_If_Already_Run { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); var runTask = server.RunAsync(TestContext.Current.CancellationToken); // Act & Assert @@ -100,7 +126,7 @@ public async Task SampleAsync_Should_Throw_Exception_If_Client_Does_Not_Support_ { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities()); var action = async () => await server.SampleAsync(new CreateMessageRequestParams { Messages = [] }, CancellationToken.None); @@ -114,7 +140,7 @@ public async Task SampleAsync_Should_SendRequest() { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities { Sampling = new SamplingCapability() }); var runTask = server.RunAsync(TestContext.Current.CancellationToken); @@ -136,7 +162,7 @@ public async Task RequestRootsAsync_Should_Throw_Exception_If_Client_Does_Not_Su { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities()); // Act & Assert @@ -148,7 +174,7 @@ public async Task RequestRootsAsync_Should_SendRequest() { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities { Roots = new RootsCapability() }); var runTask = server.RunAsync(TestContext.Current.CancellationToken); @@ -170,7 +196,7 @@ public async Task ElicitAsync_Should_Throw_Exception_If_Client_Does_Not_Support_ { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities()); // Act & Assert @@ -182,7 +208,7 @@ public async Task ElicitAsync_Should_SendRequest() { // Arrange await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); SetClientCapabilities(server, new ClientCapabilities { Elicitation = new ElicitationCapability() }); var runTask = server.RunAsync(TestContext.Current.CancellationToken); @@ -216,7 +242,7 @@ await Can_Handle_Requests( [Fact] public async Task Can_Handle_Initialize_Requests() { - AssemblyName expectedAssemblyName = (Assembly.GetEntryAssembly() ?? typeof(IMcpServer).Assembly).GetName(); + AssemblyName expectedAssemblyName = (Assembly.GetEntryAssembly() ?? typeof(McpServer).Assembly).GetName(); await Can_Handle_Requests( serverCapabilities: null, method: RequestMethods.Initialize, @@ -510,7 +536,7 @@ private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, s var options = CreateOptions(serverCapabilities); configureOptions?.Invoke(options); - await using var server = McpServerFactory.Create(transport, options, LoggerFactory); + await using var server = McpServer.Create(transport, options, LoggerFactory); var runTask = server.RunAsync(TestContext.Current.CancellationToken); @@ -544,7 +570,7 @@ private async Task Succeeds_Even_If_No_Handler_Assigned(ServerCapabilities serve await using var transport = new TestServerTransport(); var options = CreateOptions(serverCapabilities); - var server = McpServerFactory.Create(transport, options, LoggerFactory); + var server = McpServer.Create(transport, options, LoggerFactory); await server.DisposeAsync(); } @@ -589,7 +615,7 @@ public async Task AsSamplingChatClient_HandlesRequestResponse() public async Task Can_SendMessage_Before_RunAsync() { await using var transport = new TestServerTransport(); - await using var server = McpServerFactory.Create(transport, _options, LoggerFactory); + await using var server = McpServer.Create(transport, _options, LoggerFactory); var logNotification = new JsonRpcNotification { @@ -605,22 +631,22 @@ public async Task Can_SendMessage_Before_RunAsync() Assert.Same(logNotification, transport.SentMessages[0]); } - private static void SetClientCapabilities(IMcpServer server, ClientCapabilities capabilities) + private static void SetClientCapabilities(McpServer server, ClientCapabilities capabilities) { - PropertyInfo? property = server.GetType().GetProperty("ClientCapabilities", BindingFlags.Public | BindingFlags.Instance); - Assert.NotNull(property); - property.SetValue(server, capabilities); + FieldInfo? field = server.GetType().GetField("_clientCapabilities", BindingFlags.NonPublic | BindingFlags.Instance); + Assert.NotNull(field); + field.SetValue(server, capabilities); } - private sealed class TestServerForIChatClient(bool supportsSampling) : IMcpServer + private sealed class TestServerForIChatClient(bool supportsSampling) : McpServer { - public ClientCapabilities? ClientCapabilities => + public override ClientCapabilities? ClientCapabilities => supportsSampling ? new ClientCapabilities { Sampling = new SamplingCapability() } : null; - public McpServerOptions ServerOptions => new(); + public override McpServerOptions ServerOptions => new(); - public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) + public override Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) { CreateMessageRequestParams? rp = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.DefaultOptions); @@ -653,17 +679,17 @@ public Task SendRequestAsync(JsonRpcRequest request, Cancellati }); } - public ValueTask DisposeAsync() => default; + public override ValueTask DisposeAsync() => default; - public string? SessionId => throw new NotImplementedException(); - public Implementation? ClientInfo => throw new NotImplementedException(); - public IServiceProvider? Services => throw new NotImplementedException(); - public LoggingLevel? LoggingLevel => throw new NotImplementedException(); - public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) => + public override string? SessionId => throw new NotImplementedException(); + public override Implementation? ClientInfo => throw new NotImplementedException(); + public override IServiceProvider? Services => throw new NotImplementedException(); + public override LoggingLevel? LoggingLevel => throw new NotImplementedException(); + public override Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task RunAsync(CancellationToken cancellationToken = default) => + public override Task RunAsync(CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => + public override IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => throw new NotImplementedException(); } @@ -683,7 +709,7 @@ public async Task NotifyProgress_Should_Be_Handled() })], }; - var server = McpServerFactory.Create(transport, options, LoggerFactory); + var server = McpServer.Create(transport, options, LoggerFactory); Task serverTask = server.RunAsync(TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index ca2ab7835..b9463e18f 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -48,11 +48,11 @@ public void Create_InvalidArgs_Throws() } [Fact] - public async Task SupportsIMcpServer() + public async Task SupportsMcpServer() { - Mock mockServer = new(); + Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return "42"; @@ -75,7 +75,7 @@ public async Task SupportsCtorInjection() sc.AddSingleton(expectedMyService); IServiceProvider services = sc.BuildServiceProvider(); - Mock mockServer = new(); + Mock mockServer = new(); mockServer.SetupGet(s => s.Services).Returns(services); MethodInfo? testMethod = typeof(HasCtorWithSpecialParameters).GetMethod(nameof(HasCtorWithSpecialParameters.TestTool)); @@ -98,11 +98,11 @@ public async Task SupportsCtorInjection() private sealed class HasCtorWithSpecialParameters { private readonly MyService _ms; - private readonly IMcpServer _server; + private readonly McpServer _server; private readonly RequestContext _request; private readonly IProgress _progress; - public HasCtorWithSpecialParameters(MyService ms, IMcpServer server, RequestContext request, IProgress progress) + public HasCtorWithSpecialParameters(MyService ms, McpServer server, RequestContext request, IProgress progress) { Assert.NotNull(ms); Assert.NotNull(server); @@ -162,7 +162,7 @@ public async Task SupportsServiceFromDI(ServiceLifetime injectedArgumentLifetime Assert.DoesNotContain("actualMyService", JsonSerializer.Serialize(tool.ProtocolTool.InputSchema, McpJsonUtilities.DefaultOptions)); - Mock mockServer = new(); + Mock mockServer = new(); var ex = await Assert.ThrowsAsync(async () => await tool.InvokeAsync( new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), @@ -192,7 +192,7 @@ public async Task SupportsOptionalServiceFromDI() }, new() { Services = services }); var result = await tool.InvokeAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("42", (result.Content[0] as TextContentBlock)?.Text); } @@ -207,7 +207,7 @@ public async Task SupportsDisposingInstantiatedDisposableTargets() options); var result = await tool1.InvokeAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("""{"disposals":1}""", (result.Content[0] as TextContentBlock)?.Text); } @@ -222,7 +222,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableTargets() options); var result = await tool1.InvokeAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("""{"asyncDisposals":1}""", (result.Content[0] as TextContentBlock)?.Text); } @@ -241,7 +241,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableAndDisposable options); var result = await tool1.InvokeAsync( - new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Services = services }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Services = services }, TestContext.Current.CancellationToken); Assert.Equal("""{"asyncDisposals":1,"disposals":0}""", (result.Content[0] as TextContentBlock)?.Text); } @@ -250,8 +250,8 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableAndDisposable [Fact] public async Task CanReturnCollectionOfAIContent() { - Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new List { @@ -282,8 +282,8 @@ public async Task CanReturnCollectionOfAIContent() [InlineData("data:audio/wav;base64,1234", "audio")] public async Task CanReturnSingleAIContent(string data, string type) { - Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return type switch @@ -325,8 +325,8 @@ public async Task CanReturnSingleAIContent(string data, string type) [Fact] public async Task CanReturnNullAIContent() { - Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return (string?)null; @@ -340,8 +340,8 @@ public async Task CanReturnNullAIContent() [Fact] public async Task CanReturnString() { - Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return "42"; @@ -356,8 +356,8 @@ public async Task CanReturnString() [Fact] public async Task CanReturnCollectionOfStrings() { - Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new List { "42", "43" }; @@ -372,8 +372,8 @@ public async Task CanReturnCollectionOfStrings() [Fact] public async Task CanReturnMcpContent() { - Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return new TextContentBlock { Text = "42" }; @@ -389,8 +389,8 @@ public async Task CanReturnMcpContent() [Fact] public async Task CanReturnCollectionOfMcpContent() { - Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return (IList) @@ -416,8 +416,8 @@ public async Task CanReturnCallToolResult() Content = new List { new TextContentBlock { Text = "text" }, new ImageContentBlock { Data = "1234", MimeType = "image/png" } } }; - Mock mockServer = new(); - McpServerTool tool = McpServerTool.Create((IMcpServer server) => + Mock mockServer = new(); + McpServerTool tool = McpServerTool.Create((McpServer server) => { Assert.Same(mockServer.Object, server); return response; @@ -436,7 +436,7 @@ public async Task CanReturnCallToolResult() [Fact] public async Task SupportsSchemaCreateOptions() { - AIJsonSchemaCreateOptions schemaCreateOptions = new () + AIJsonSchemaCreateOptions schemaCreateOptions = new() { TransformSchemaNode = (context, node) => { @@ -462,7 +462,7 @@ public async Task StructuredOutput_Enabled_ReturnsExpectedSchema(T value) { JsonSerializerOptions options = new() { TypeInfoResolver = new DefaultJsonTypeInfoResolver() }; McpServerTool tool = McpServerTool.Create(() => value, new() { Name = "tool", UseStructuredContent = true, SerializerOptions = options }); - var mockServer = new Mock(); + var mockServer = new Mock(); var request = new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new CallToolRequestParams { Name = "tool" }, @@ -480,7 +480,7 @@ public async Task StructuredOutput_Enabled_ReturnsExpectedSchema(T value) public async Task StructuredOutput_Enabled_VoidReturningTools_ReturnsExpectedSchema() { McpServerTool tool = McpServerTool.Create(() => { }); - var mockServer = new Mock(); + var mockServer = new Mock(); var request = new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new CallToolRequestParams { Name = "tool" }, @@ -520,7 +520,7 @@ public async Task StructuredOutput_Disabled_ReturnsExpectedSchema(T value) { JsonSerializerOptions options = new() { TypeInfoResolver = new DefaultJsonTypeInfoResolver() }; McpServerTool tool = McpServerTool.Create(() => value, new() { UseStructuredContent = false, SerializerOptions = options }); - var mockServer = new Mock(); + var mockServer = new Mock(); var request = new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new CallToolRequestParams { Name = "tool" }, diff --git a/tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs b/tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs index f3927be62..d14c376c1 100644 --- a/tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs @@ -35,7 +35,7 @@ public async Task SigInt_DisposesTestServerWithHosting_Gracefully() process.StandardInput.BaseStream, serverName: "TestServerWithHosting"); - await using var client = await McpClientFactory.CreateAsync( + await using var client = await McpClient.CreateAsync( new TestClientTransport(streamServerTransport), loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportAutoDetectTests.cs b/tests/ModelContextProtocol.Tests/Transport/HttpClientTransportAutoDetectTests.cs similarity index 89% rename from tests/ModelContextProtocol.Tests/Transport/SseClientTransportAutoDetectTests.cs rename to tests/ModelContextProtocol.Tests/Transport/HttpClientTransportAutoDetectTests.cs index 8f6fbff2c..768ebf7ea 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportAutoDetectTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/HttpClientTransportAutoDetectTests.cs @@ -4,12 +4,12 @@ namespace ModelContextProtocol.Tests.Transport; -public class SseClientTransportAutoDetectTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper) +public class HttpClientTransportAutoDetectTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper) { [Fact] public async Task AutoDetectMode_UsesStreamableHttp_WhenServerSupportsIt() { - var options = new SseClientTransportOptions + var options = new HttpClientTransportOptions { Endpoint = new Uri("http://localhost"), TransportMode = HttpTransportMode.AutoDetect, @@ -18,7 +18,7 @@ public async Task AutoDetectMode_UsesStreamableHttp_WhenServerSupportsIt() using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(options, httpClient, LoggerFactory); + await using var transport = new HttpClientTransport(options, httpClient, LoggerFactory); // Simulate successful Streamable HTTP response for initialize mockHttpHandler.RequestHandler = (request) => @@ -50,7 +50,7 @@ public async Task AutoDetectMode_UsesStreamableHttp_WhenServerSupportsIt() [Fact] public async Task AutoDetectMode_FallsBackToSse_WhenStreamableHttpFails() { - var options = new SseClientTransportOptions + var options = new HttpClientTransportOptions { Endpoint = new Uri("http://localhost"), TransportMode = HttpTransportMode.AutoDetect, @@ -59,7 +59,7 @@ public async Task AutoDetectMode_FallsBackToSse_WhenStreamableHttpFails() using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(options, httpClient, LoggerFactory); + await using var transport = new HttpClientTransport(options, httpClient, LoggerFactory); var requestCount = 0; diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/HttpClientTransportTests.cs similarity index 86% rename from tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs rename to tests/ModelContextProtocol.Tests/Transport/HttpClientTransportTests.cs index 3ff504304..fc1ac2d88 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/HttpClientTransportTests.cs @@ -5,14 +5,14 @@ namespace ModelContextProtocol.Tests.Transport; -public class SseClientTransportTests : LoggedTest +public class HttpClientTransportTests : LoggedTest { - private readonly SseClientTransportOptions _transportOptions; + private readonly HttpClientTransportOptions _transportOptions; - public SseClientTransportTests(ITestOutputHelper testOutputHelper) + public HttpClientTransportTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { - _transportOptions = new SseClientTransportOptions + _transportOptions = new HttpClientTransportOptions { Endpoint = new Uri("http://localhost:8080"), ConnectionTimeout = TimeSpan.FromSeconds(2), @@ -28,14 +28,14 @@ public SseClientTransportTests(ITestOutputHelper testOutputHelper) [Fact] public void Constructor_Throws_For_Null_Options() { - var exception = Assert.Throws(() => new SseClientTransport(null!, LoggerFactory)); + var exception = Assert.Throws(() => new HttpClientTransport(null!, LoggerFactory)); Assert.Equal("transportOptions", exception.ParamName); } [Fact] public void Constructor_Throws_For_Null_HttpClient() { - var exception = Assert.Throws(() => new SseClientTransport(_transportOptions, httpClient: null!, LoggerFactory)); + var exception = Assert.Throws(() => new HttpClientTransport(_transportOptions, httpClient: null!, LoggerFactory)); Assert.Equal("httpClient", exception.ParamName); } @@ -44,7 +44,7 @@ public async Task ConnectAsync_Should_Connect_Successfully() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, httpClient, LoggerFactory); + await using var transport = new HttpClientTransport(_transportOptions, httpClient, LoggerFactory); bool firstCall = true; @@ -68,7 +68,7 @@ public async Task ConnectAsync_Throws_Exception_On_Failure() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, httpClient, LoggerFactory); + await using var transport = new HttpClientTransport(_transportOptions, httpClient, LoggerFactory); var retries = 0; mockHttpHandler.RequestHandler = (request) => @@ -87,7 +87,7 @@ public async Task SendMessageAsync_Handles_Accepted_Response() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, httpClient, LoggerFactory); + await using var transport = new HttpClientTransport(_transportOptions, httpClient, LoggerFactory); var firstCall = true; mockHttpHandler.RequestHandler = (request) => @@ -125,7 +125,7 @@ public async Task ReceiveMessagesAsync_Handles_Messages() { using var mockHttpHandler = new MockHttpHandler(); using var httpClient = new HttpClient(mockHttpHandler); - await using var transport = new SseClientTransport(_transportOptions, httpClient, LoggerFactory); + await using var transport = new HttpClientTransport(_transportOptions, httpClient, LoggerFactory); var callIndex = 0; mockHttpHandler.RequestHandler = (request) => @@ -165,7 +165,7 @@ public async Task DisposeAsync_Should_Dispose_Resources() }); }; - await using var transport = new SseClientTransport(_transportOptions, httpClient, LoggerFactory); + await using var transport = new HttpClientTransport(_transportOptions, httpClient, LoggerFactory); await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken); await session.DisposeAsync(); diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs index dfde342af..48c2b9533 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioClientTransportTests.cs @@ -8,7 +8,7 @@ namespace ModelContextProtocol.Tests.Transport; public class StdioClientTransportTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper) { public static bool IsStdErrCallbackSupported => !PlatformDetection.IsMonoRuntime; - + [Fact] public async Task CreateAsync_ValidProcessInvalidServer_Throws() { @@ -18,13 +18,13 @@ public async Task CreateAsync_ValidProcessInvalidServer_Throws() new(new() { Command = "cmd", Arguments = ["/C", $"echo \"{id}\" >&2"] }, LoggerFactory) : new(new() { Command = "ls", Arguments = [id] }, LoggerFactory); - IOException e = await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); + IOException e = await Assert.ThrowsAsync(() => McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); if (!RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) { Assert.Contains(id, e.ToString()); } } - + [Fact(Skip = "Platform not supported by this test.", SkipUnless = nameof(IsStdErrCallbackSupported))] public async Task CreateAsync_ValidProcessInvalidServer_StdErrCallbackInvoked() { @@ -46,7 +46,7 @@ public async Task CreateAsync_ValidProcessInvalidServer_StdErrCallbackInvoked() new(new() { Command = "cmd", Arguments = ["/C", $"echo \"{id}\" >&2"], StandardErrorLines = stdErrCallback }, LoggerFactory) : new(new() { Command = "ls", Arguments = [id], StandardErrorLines = stdErrCallback }, LoggerFactory); - await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); + await Assert.ThrowsAsync(() => McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); Assert.InRange(count, 1, int.MaxValue); Assert.Contains(id, sb.ToString());