From 58dc73a221221833fd8ba6278eb3b0edfe9b082a Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Tue, 26 Aug 2025 08:17:43 -0700 Subject: [PATCH 1/7] Add middleware and authz support for server-side handlers --- docs/concepts/filters.md | 262 ++++++++++++ docs/concepts/toc.yml | 4 + .../Properties/launchSettings.json | 4 +- .../AuthorizationFilterSetup.cs | 222 +++++++++++ .../HttpMcpServerBuilderExtensions.cs | 4 + .../SseHandler.cs | 3 +- .../StreamableHttpHandler.cs | 38 +- .../StreamableHttpSession.cs | 20 +- src/ModelContextProtocol.Core/McpSession.cs | 18 +- .../Protocol/JsonRpcMessage.cs | 26 +- .../Protocol/JsonRpcMessageContext.cs | 61 +++ .../Protocol/JsonRpcRequest.cs | 4 +- .../Protocol/Prompt.cs | 7 + .../Protocol/Resource.cs | 7 + .../Protocol/ResourceTemplate.cs | 8 + .../Protocol/Tool.cs | 15 +- .../RequestHandlers.cs | 8 +- .../Server/AIFunctionMcpServerPrompt.cs | 15 +- .../Server/AIFunctionMcpServerResource.cs | 19 +- .../Server/AIFunctionMcpServerTool.cs | 74 ++-- .../Server/DestinationBoundMcpServer.cs | 18 +- .../Server/IMcpServerPrimitive.cs | 9 + .../Server/McpServer.cs | 180 +++++++-- .../Server/McpServerFilters.cs | 161 ++++++++ .../Server/McpServerOptions.cs | 10 + .../Server/McpServerPrompt.cs | 21 +- .../Server/McpServerPromptCreateOptions.cs | 10 + .../Server/McpServerResource.cs | 21 +- .../Server/McpServerResourceCreateOptions.cs | 10 + .../Server/McpServerTool.cs | 27 +- .../Server/McpServerToolCreateOptions.cs | 12 +- .../Server/RequestContext.cs | 30 +- ...eProvider.cs => RequestServiceProvider.cs} | 25 +- .../Server/SseResponseStreamTransport.cs | 16 +- .../Server/SseWriter.cs | 2 + .../Server/StreamableHttpPostTransport.cs | 69 ++-- .../Server/StreamableHttpServerTransport.cs | 30 +- .../McpServerBuilderExtensions.cs | 272 +++++++++++++ .../AuthorizeAttributeTests.cs | 374 ++++++++++++++++++ .../MapMcpSseTests.cs | 2 +- .../MapMcpTests.cs | 40 ++ .../ClientServerTestBase.cs | 3 +- .../McpServerBuilderExtensionsFilterTests.cs | 314 +++++++++++++++ .../McpServerBuilderExtensionsToolsTests.cs | 19 +- .../Server/McpServerPromptTests.cs | 46 ++- .../Server/McpServerResourceTests.cs | 84 ++-- .../Server/McpServerToolTests.cs | 110 ++---- 47 files changed, 2374 insertions(+), 360 deletions(-) create mode 100644 docs/concepts/filters.md create mode 100644 src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs create mode 100644 src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs create mode 100644 src/ModelContextProtocol.Core/Server/McpServerFilters.cs rename src/ModelContextProtocol.Core/Server/{AugmentedServiceProvider.cs => RequestServiceProvider.cs} (69%) create mode 100644 tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs create mode 100644 tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsFilterTests.cs diff --git a/docs/concepts/filters.md b/docs/concepts/filters.md new file mode 100644 index 000000000..9a8a43265 --- /dev/null +++ b/docs/concepts/filters.md @@ -0,0 +1,262 @@ +--- +title: Filters +author: halter73 +description: MCP Server Handler Filters +uid: filters +--- + +# MCP Server Handler Filters + +This document describes the filter functionality in the MCP Server, which allows you to add middleware-style filters to handler pipelines. + +## Overview + +For each handler type in the MCP Server, there are corresponding `AddXXXFilter` methods in `McpServerBuilderExtensions.cs` that allow you to add filters to the handler pipeline. The filters are stored in `McpServerOptions.Filters` and applied during server configuration. + +## Available Filter Methods + +The following filter methods are available: + +- `AddListResourceTemplatesFilter` - Filter for list resource templates handlers +- `AddListToolsFilter` - Filter for list tools handlers +- `AddCallToolFilter` - Filter for call tool handlers +- `AddListPromptsFilter` - Filter for list prompts handlers +- `AddGetPromptFilter` - Filter for get prompt handlers +- `AddListResourcesFilter` - Filter for list resources handlers +- `AddReadResourceFilter` - Filter for read resource handlers +- `AddCompleteFilter` - Filter for completion handlers +- `AddSubscribeToResourcesFilter` - Filter for resource subscription handlers +- `AddUnsubscribeFromResourcesFilter` - Filter for resource unsubscription handlers +- `AddSetLoggingLevelFilter` - Filter for logging level handlers + +## Usage + +Filters are functions that take a handler and return a new handler, allowing you to wrap the original handler with additional functionality: + +```csharp +services.AddMcpServer() + .WithListToolsHandler(async (context, cancellationToken) => + { + // Your base handler logic + return new ListToolsResult { Tools = GetTools() }; + }) + .AddListToolsFilter(next => async (context, cancellationToken) => + { + // Pre-processing logic + Console.WriteLine("Before handler execution"); + + var result = await next(context, cancellationToken); + + // Post-processing logic + Console.WriteLine("After handler execution"); + return result; + }); +``` + +## Filter Execution Order + +```csharp +services.AddMcpServer() + .WithListToolsHandler(baseHandler) + .AddListToolsFilter(filter1) // Executes first (outermost) + .AddListToolsFilter(filter2) // Executes second + .AddListToolsFilter(filter3); // Executes third (closest to handler) +``` + +Execution flow: `filter1 -> filter2 -> filter3 -> baseHandler -> filter3 -> filter2 -> filter1` + +## Common Use Cases + +### Logging +```csharp +.AddListToolsFilter(next => async (context, cancellationToken) => +{ + Console.WriteLine($"Processing request from {context.Meta.ProgressToken}"); + var result = await next(context, cancellationToken); + Console.WriteLine($"Returning {result.Tools?.Count ?? 0} tools"); + return result; +}); +``` + +### Error Handling +```csharp +.AddCallToolFilter(next => async (context, cancellationToken) => +{ + try + { + return await next(context, cancellationToken); + } + catch (Exception ex) + { + return new CallToolResult + { + Content = new[] { new TextContent { Type = "text", Text = $"Error: {ex.Message}" } }, + IsError = true + }; + } +}); +``` + +### Performance Monitoring +```csharp +.AddListToolsFilter(next => async (context, cancellationToken) => +{ + var stopwatch = Stopwatch.StartNew(); + var result = await next(context, cancellationToken); + stopwatch.Stop(); + Console.WriteLine($"Handler took {stopwatch.ElapsedMilliseconds}ms"); + return result; +}); +``` + +### Caching +```csharp +.AddListResourcesFilter(next => async (context, cancellationToken) => +{ + var cacheKey = $"resources:{context.Params.Cursor}"; + if (cache.TryGetValue(cacheKey, out var cached)) + return cached; + + var result = await next(context, cancellationToken); + cache.Set(cacheKey, result, TimeSpan.FromMinutes(5)); + return result; +}); +``` + +## Built-in Authorization Filters + +When using the ASP.NET Core integration (`ModelContextProtocol.AspNetCore`), authorization filters are automatically configured to support `[Authorize]` and `[AllowAnonymous]` attributes on MCP server tools, prompts, and resources. + +### Authorization Attributes Support + +The MCP server automatically respects the following authorization attributes: + +- **`[Authorize]`** - Requires authentication for access +- **`[Authorize(Roles = "RoleName")]`** - Requires specific roles +- **`[Authorize(Policy = "PolicyName")]`** - Requires specific authorization policies +- **`[AllowAnonymous]`** - Explicitly allows anonymous access (overrides `[Authorize]`) + +### Tool Authorization + +Tools can be decorated with authorization attributes to control access: + +```csharp +[McpServerToolType] +public class WeatherTools +{ + [McpServerTool, Description("Gets public weather data")] + public static string GetWeather(string location) + { + return $"Weather for {location}: Sunny, 25°C"; + } + + [McpServerTool, Description("Gets detailed weather forecast")] + [Authorize] // Requires authentication + public static string GetDetailedForecast(string location) + { + return $"Detailed forecast for {location}: ..."; + } + + [McpServerTool, Description("Manages weather alerts")] + [Authorize(Roles = "Admin")] // Requires Admin role + public static string ManageWeatherAlerts(string alertType) + { + return $"Managing alert: {alertType}"; + } +} +``` + +### Class-Level Authorization + +You can apply authorization at the class level, which affects all tools in the class: + +```csharp +[McpServerToolType] +[Authorize] // All tools require authentication +public class AdminTools +{ + [McpServerTool, Description("Admin-only tool")] + public static string AdminOperation() + { + return "Admin operation completed"; + } + + [McpServerTool, Description("Public tool accessible to anonymous users")] + [AllowAnonymous] // Overrides class-level [Authorize] + public static string PublicOperation() + { + return "Public operation completed"; + } +} +``` + +### How Authorization Filters Work + +The authorization filters work differently for list operations versus individual operations: + +#### List Operations (ListTools, ListPrompts, ListResources) +For list operations, the filters automatically remove unauthorized items from the results. Users only see tools, prompts, or resources they have permission to access. + +#### Individual Operations (CallTool, GetPrompt, ReadResource) +For individual operations, the filters return authorization errors when access is denied: + +- **Tools**: Returns a `CallToolResult` with `IsError = true` and an error message +- **Prompts**: Throws an `McpException` with "Access forbidden" message +- **Resources**: Throws an `McpException` with "Access forbidden" message + +### Setup Requirements + +To use authorization features, you must configure authentication and authorization in your ASP.NET Core application: + +```csharp +var builder = WebApplication.CreateBuilder(args); + +// Add authentication +builder.Services.AddAuthentication("Bearer") + .AddJwtBearer("Bearer", options => { /* JWT configuration */ }); + +// Add authorization (required for [Authorize] attributes to work) +builder.Services.AddAuthorization(); + +// Add MCP server +builder.Services.AddMcpServer() + .WithTools(); + +var app = builder.Build(); + +// Use authentication and authorization middleware +app.UseAuthentication(); +app.UseAuthorization(); + +app.MapMcp(); +app.Run(); +``` + +### Custom Authorization Filters + +You can also create custom authorization filters using the filter methods: + +```csharp +.AddCallToolFilter(next => async (context, cancellationToken) => +{ + // Custom authorization logic + if (context.User?.Identity?.IsAuthenticated != true) + { + return new CallToolResult + { + Content = [new TextContent { Text = "Custom: Authentication required" }], + IsError = true + }; + } + + return await next(context, cancellationToken); +}); +``` + +### RequestContext + +Within filters, you have access to: + +- `context.User` - The current user's `ClaimsPrincipal` +- `context.Services` - The request's service provider for resolving authorization services +- `context.MatchedPrimitive` - The matched tool/prompt/resource with its metadata including authorization attributes via `context.MatchedPrimitive.Metadata` diff --git a/docs/concepts/toc.yml b/docs/concepts/toc.yml index 939f21fc0..2f7c930fa 100644 --- a/docs/concepts/toc.yml +++ b/docs/concepts/toc.yml @@ -13,3 +13,7 @@ items: items: - name: Logging uid: logging +- name: Server Features + items: + - name: Filters + uid: filters \ No newline at end of file diff --git a/samples/AspNetCoreMcpServer/Properties/launchSettings.json b/samples/AspNetCoreMcpServer/Properties/launchSettings.json index a5b8a22f6..6670029e1 100644 --- a/samples/AspNetCoreMcpServer/Properties/launchSettings.json +++ b/samples/AspNetCoreMcpServer/Properties/launchSettings.json @@ -7,7 +7,7 @@ "applicationUrl": "http://localhost:3001", "environmentVariables": { "ASPNETCORE_ENVIRONMENT": "Development", - "OTEL_SERVICE_NAME": "aspnetcore-mcp-server", + "OTEL_SERVICE_NAME": "aspnetcore-mcp-server" } }, "https": { @@ -16,7 +16,7 @@ "applicationUrl": "https://localhost:7133;http://localhost:3001", "environmentVariables": { "ASPNETCORE_ENVIRONMENT": "Development", - "OTEL_SERVICE_NAME": "aspnetcore-mcp-server", + "OTEL_SERVICE_NAME": "aspnetcore-mcp-server" } } } diff --git a/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs b/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs new file mode 100644 index 000000000..7d2c30f28 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs @@ -0,0 +1,222 @@ +using System.Security.Claims; +using Microsoft.AspNetCore.Authorization; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// Evaluates authorization policies from endpoint metadata. +/// +internal sealed class AuthorizationFilterSetup(IAuthorizationPolicyProvider? policyProvider = null) : IConfigureOptions +{ + public void Configure(McpServerOptions options) + { + ConfigureListToolsFilter(options); + ConfigureCallToolFilter(options); + + ConfigureListResourcesFilter(options); + ConfigureListResourceTemplatesFilter(options); + ConfigureReadResourceFilter(options); + + ConfigureListPromptsFilter(options); + ConfigureGetPromptFilter(options); + } + + private void ConfigureListToolsFilter(McpServerOptions options) + { + options.Filters.ListToolsFilters.Add(next => async (context, cancellationToken) => + { + var result = await next(context, cancellationToken); + await FilterAuthorizedItemsAsync( + result.Tools, static tool => tool.McpServerTool, + context.User, context.Services, context); + return result; + }); + } + + private void ConfigureCallToolFilter(McpServerOptions options) + { + options.Filters.CallToolFilters.Add(next => async (context, cancellationToken) => + { + var authResult = await GetAuthorizationResultAsync(context.User, context.MatchedPrimitive, context.Services, context); + if (!authResult.Succeeded) + { + return new CallToolResult + { + Content = [new TextContentBlock { Text = "Access forbidden: This tool requires authorization." }], + IsError = true + }; + } + + return await next(context, cancellationToken); + }); + } + + private void ConfigureListResourcesFilter(McpServerOptions options) + { + options.Filters.ListResourcesFilters.Add(next => async (context, cancellationToken) => + { + var result = await next(context, cancellationToken); + await FilterAuthorizedItemsAsync( + result.Resources, static resource => resource.McpServerResource, + context.User, context.Services, context); + return result; + }); + } + + private void ConfigureListResourceTemplatesFilter(McpServerOptions options) + { + options.Filters.ListResourceTemplatesFilters.Add(next => async (context, cancellationToken) => + { + var result = await next(context, cancellationToken); + await FilterAuthorizedItemsAsync( + result.ResourceTemplates, static resourceTemplate => resourceTemplate.McpServerResource, + context.User, context.Services, context); + return result; + }); + } + + private void ConfigureReadResourceFilter(McpServerOptions options) + { + options.Filters.ReadResourceFilters.Add(next => async (context, cancellationToken) => + { + var authResult = await GetAuthorizationResultAsync(context.User, context.MatchedPrimitive, context.Services, context); + if (!authResult.Succeeded) + { + throw new McpException("Access forbidden: This resource requires authorization.", McpErrorCode.InvalidRequest); + } + + return await next(context, cancellationToken); + }); + } + + private void ConfigureListPromptsFilter(McpServerOptions options) + { + options.Filters.ListPromptsFilters.Add(next => async (context, cancellationToken) => + { + var result = await next(context, cancellationToken); + await FilterAuthorizedItemsAsync( + result.Prompts, static prompt => prompt.McpServerPrompt, + context.User, context.Services, context); + return result; + }); + } + + private void ConfigureGetPromptFilter(McpServerOptions options) + { + options.Filters.GetPromptFilters.Add(next => async (context, cancellationToken) => + { + var authResult = await GetAuthorizationResultAsync(context.User, context.MatchedPrimitive, context.Services, context); + if (!authResult.Succeeded) + { + throw new McpException("Access forbidden: This prompt requires authorization.", McpErrorCode.InvalidRequest); + } + + return await next(context, cancellationToken); + }); + } + + /// + /// Filters a collection of items based on authorization policies in their metadata. + /// For list operations where we need to filter results by authorization. + /// + private async ValueTask FilterAuthorizedItemsAsync(IList items, Func primitiveSelector, + ClaimsPrincipal? user, IServiceProvider? requestServices, object context) + { + for (int i = items.Count - 1; i >= 0; i--) + { + var authorizationResult = await GetAuthorizationResultAsync( + user, primitiveSelector(items[i]), requestServices, context); + + if (!authorizationResult.Succeeded) + { + items.RemoveAt(i); + } + } + } + + private async ValueTask GetAuthorizationResultAsync( + ClaimsPrincipal? user, IMcpServerPrimitive? primitive, IServiceProvider? requestServices, object context) + { + // If no primitive was found for this request or there is IAllowAnonymous metadata anywhere on the class or method, + // the request should go through as normal. + if (primitive is null || primitive.Metadata.Any(static m => m is IAllowAnonymous)) + { + return AuthorizationResult.Success(); + } + + // There are no [Authorize] style attributes applied to the method or containing class. Any fallback policies + // have already been enforced at the HTTP request level by the ASP.NET Core authorization middleware. + if (!primitive.Metadata.Any(static m => m is IAuthorizeData or AuthorizationPolicy or IAuthorizationRequirementData)) + { + return AuthorizationResult.Success(); + } + + if (policyProvider is null) + { + throw new InvalidOperationException($"You must call AddAuthorization() because an authorization related attribute was found on {primitive.Id}"); + } + + // TODO: Cache policy lookup. We would probably use a singleton (not-static) ConditionalWeakTable. + var policy = await CombineAsync(policyProvider, primitive.Metadata); + if (policy is null) + { + return AuthorizationResult.Success(); + } + + if (requestServices is null) + { + // The IAuthorizationPolicyProvider service must be non-null to get to this line, so it's very unexpected for RequestContext.Services to not be set. + throw new InvalidOperationException("RequestContext.Services is not set! The IMcpServer must be initialized with a non-null IServiceProvider."); + } + + // ASP.NET Core's AuthorizationMiddleware resolves the IAuthorizationService from scoped request services, so we do the same. + var authService = requestServices.GetRequiredService(); + return await authService.AuthorizeAsync(user ?? new ClaimsPrincipal(new ClaimsIdentity()), context, policy); + } + + /// + /// Combines authorization policies and requirements from endpoint metadata without considering . + /// + /// The authorization policy provider. + /// The endpoint metadata collection. + /// The combined authorization policy, or null if no authorization is required. + private static async ValueTask CombineAsync(IAuthorizationPolicyProvider policyProvider, IReadOnlyList endpointMetadata) + { + // https://github.com/dotnet/aspnetcore/issues/63365 tracks adding this as public API to AuthorizationPolicy itself. + // Copied from https://github.com/dotnet/aspnetcore/blob/9f2977bf9cfb539820983bda3bedf81c8cda9f20/src/Security/Authorization/Policy/src/AuthorizationMiddleware.cs#L116-L138 + var authorizeData = endpointMetadata.OfType(); + var policies = endpointMetadata.OfType(); + + var policy = await AuthorizationPolicy.CombineAsync(policyProvider, authorizeData, policies); + + AuthorizationPolicyBuilder? reqPolicyBuilder = null; + + foreach (var m in endpointMetadata) + { + if (m is not IAuthorizationRequirementData requirementData) + { + continue; + } + + reqPolicyBuilder ??= new AuthorizationPolicyBuilder(); + foreach (var requirement in requirementData.GetRequirements()) + { + reqPolicyBuilder.AddRequirements(requirement); + } + } + + if (reqPolicyBuilder is null) + { + return policy; + } + + // Combine policy with requirements or just use requirements if no policy + return (policy is null) + ? reqPolicyBuilder.Build() + : AuthorizationPolicy.Combine(policy, reqPolicyBuilder.Build()); + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs index 2d6b29fd9..70835a83d 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs @@ -1,4 +1,5 @@ using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Options; using ModelContextProtocol.AspNetCore; using ModelContextProtocol.Server; @@ -29,6 +30,9 @@ public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder builder.Services.AddHostedService(); builder.Services.AddDataProtection(); + // Register authorization filter setup for automatic filter configuration + builder.Services.TryAddEnumerable(ServiceDescriptor.Singleton, AuthorizationFilterSetup>()); + if (configureOptions is not null) { builder.Services.Configure(configureOptions); diff --git a/src/ModelContextProtocol.AspNetCore/SseHandler.cs b/src/ModelContextProtocol.AspNetCore/SseHandler.cs index 6ed72fb64..fffdd45e3 100644 --- a/src/ModelContextProtocol.AspNetCore/SseHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/SseHandler.cs @@ -2,7 +2,6 @@ using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; using System.Collections.Concurrent; using System.Diagnostics; @@ -97,7 +96,7 @@ public async Task HandleMessageRequestAsync(HttpContext context) return; } - var message = (JsonRpcMessage?)await context.Request.ReadFromJsonAsync(McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage)), context.RequestAborted); + var message = await StreamableHttpHandler.ReadJsonRpcMessageAsync(context); if (message is null) { await Results.BadRequest("No message in request body.").ExecuteAsync(context); diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index bfbd805de..8e3a72eb0 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -8,7 +8,6 @@ using ModelContextProtocol.AspNetCore.Stateless; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; -using System.IO.Pipelines; using System.Security.Claims; using System.Security.Cryptography; using System.Text.Json; @@ -26,6 +25,8 @@ internal sealed class StreamableHttpHandler( IServiceProvider applicationServices) { private const string McpSessionIdHeaderName = "Mcp-Session-Id"; + + private static readonly JsonTypeInfo s_messageTypeInfo = GetRequiredJsonTypeInfo(); private static readonly JsonTypeInfo s_errorTypeInfo = GetRequiredJsonTypeInfo(); public HttpServerTransportOptions HttpServerTransportOptions => httpServerTransportOptions.Value; @@ -55,8 +56,17 @@ await WriteJsonRpcErrorAsync(context, await using var _ = await session.AcquireReferenceAsync(context.RequestAborted); + var message = await ReadJsonRpcMessageAsync(context); + if (message is null) + { + await WriteJsonRpcErrorAsync(context, + "Bad Request: The POST body did not contain a valid JSON-RPC message.", + StatusCodes.Status400BadRequest); + return; + } + InitializeSseResponse(context); - var wroteResponse = await session.Transport.HandlePostRequest(new HttpDuplexPipe(context), context.RequestAborted); + var wroteResponse = await session.Transport.HandlePostRequest(message, context.Response.Body, context.RequestAborted); if (!wroteResponse) { // We wound up writing nothing, so there should be no Content-Type response header. @@ -264,6 +274,22 @@ internal static string MakeNewSessionId() return WebEncoders.Base64UrlEncode(buffer); } + internal static async Task ReadJsonRpcMessageAsync(HttpContext context) + { + // Implementation for reading a JSON-RPC message from the request body + var message = await context.Request.ReadFromJsonAsync(s_messageTypeInfo, context.RequestAborted); + + if (context.User?.Identity?.IsAuthenticated ?? false) + { + message?.Context = new() + { + User = context.User, + }; + } + + return message; + } + private void ScheduleStatelessSessionIdWrite(HttpContext context, StreamableHttpServerTransport transport) { transport.OnInitRequestReceived = initRequestParams => @@ -304,17 +330,11 @@ internal static Task RunSessionAsync(HttpContext httpContext, IMcpServer session return null; } - private static JsonTypeInfo GetRequiredJsonTypeInfo() => (JsonTypeInfo)McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(T)); + internal static JsonTypeInfo GetRequiredJsonTypeInfo() => (JsonTypeInfo)McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(T)); private static bool MatchesApplicationJsonMediaType(MediaTypeHeaderValue acceptHeaderValue) => acceptHeaderValue.MatchesMediaType("application/json"); private static bool MatchesTextEventStreamMediaType(MediaTypeHeaderValue acceptHeaderValue) => acceptHeaderValue.MatchesMediaType("text/event-stream"); - - private sealed class HttpDuplexPipe(HttpContext context) : IDuplexPipe - { - public PipeReader Input => context.Request.BodyReader; - public PipeWriter Output => context.Response.BodyWriter; - } } diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs index ffeafada7..7c8a31959 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpSession.cs @@ -100,15 +100,17 @@ public async ValueTask DisposeAsync() try { - await _disposeCts.CancelAsync(); - try { + // Dispose transport first to complete the incoming MessageReader gracefully and avoid a potentially unnecessary OCE. + await transport.DisposeAsync(); + await _disposeCts.CancelAsync(); + await ServerRunTask; } finally { - await DisposeServerThenTransportAsync(); + await server.DisposeAsync(); } } catch (OperationCanceledException) @@ -124,18 +126,6 @@ public async ValueTask DisposeAsync() } } - private async ValueTask DisposeServerThenTransportAsync() - { - try - { - await server.DisposeAsync(); - } - finally - { - await transport.DisposeAsync(); - } - } - private sealed class UnreferenceDisposable(StreamableHttpSession session) : IAsyncDisposable { public ValueTask DisposeAsync() diff --git a/src/ModelContextProtocol.Core/McpSession.cs b/src/ModelContextProtocol.Core/McpSession.cs index da9542055..75215fee1 100644 --- a/src/ModelContextProtocol.Core/McpSession.cs +++ b/src/ModelContextProtocol.Core/McpSession.cs @@ -116,14 +116,14 @@ public async Task ProcessMessagesAsync(CancellationToken cancellationToken) LogMessageRead(EndpointName, message.GetType().Name); // Fire and forget the message handling to avoid blocking the transport. - if (message.ExecutionContext is null) + if (message.Context?.ExecutionContext is null) { _ = ProcessMessageAsync(); } else { // Flow the execution context from the HTTP request corresponding to this message if provided. - ExecutionContext.Run(message.ExecutionContext, _ => _ = ProcessMessageAsync(), null); + ExecutionContext.Run(message.Context.ExecutionContext, _ => _ = ProcessMessageAsync(), null); } async Task ProcessMessageAsync() @@ -176,13 +176,15 @@ ex is OperationCanceledException && Message = "An error occurred.", }; - await SendMessageAsync(new JsonRpcError + var errorMessage = new JsonRpcError { Id = request.Id, JsonRpc = "2.0", Error = detail, - RelatedTransport = request.RelatedTransport, - }, cancellationToken).ConfigureAwait(false); + Context = new JsonRpcMessageContext { RelatedTransport = request.Context?.RelatedTransport }, + }; + + await SendMessageAsync(errorMessage, cancellationToken).ConfigureAwait(false); } else if (ex is not OperationCanceledException) { @@ -329,7 +331,7 @@ await SendMessageAsync(new JsonRpcResponse { Id = request.Id, Result = result, - RelatedTransport = request.RelatedTransport, + Context = request.Context, }, cancellationToken).ConfigureAwait(false); return result; @@ -349,7 +351,7 @@ private CancellationTokenRegistration RegisterCancellation(CancellationToken can { Method = NotificationMethods.CancelledNotification, Params = JsonSerializer.SerializeToNode(new CancelledNotificationParams { RequestId = state.Item2.Id }, McpJsonUtilities.JsonContext.Default.CancelledNotificationParams), - RelatedTransport = state.Item2.RelatedTransport, + Context = new JsonRpcMessageContext { RelatedTransport = state.Item2.Context?.RelatedTransport }, }); }, Tuple.Create(this, request)); } @@ -527,7 +529,7 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can // Streamable HTTP transport where the specification states that the server SHOULD include JSON-RPC responses in // the HTTP response body for the POST request containing the corresponding JSON-RPC request. private Task SendToRelatedTransportAsync(JsonRpcMessage message, CancellationToken cancellationToken) - => (message.RelatedTransport ?? _transport).SendMessageAsync(message, cancellationToken); + => (message.Context?.RelatedTransport ?? _transport).SendMessageAsync(message, cancellationToken); private static CancelledNotificationParams? GetCancelledNotificationParams(JsonNode? notificationParams) { diff --git a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs index b3176937c..ae15453db 100644 --- a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs +++ b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessage.cs @@ -1,5 +1,6 @@ using ModelContextProtocol.Server; using System.ComponentModel; +using System.Security.Claims; using System.Text.Json; using System.Text.Json.Serialization; @@ -29,28 +30,21 @@ private protected JsonRpcMessage() public string JsonRpc { get; init; } = "2.0"; /// - /// Gets or sets the transport the was received on or should be sent over. + /// Gets or sets the contextual information for this JSON-RPC message. /// /// - /// This is used to support the Streamable HTTP transport where the specification states that the server - /// SHOULD include JSON-RPC responses in the HTTP response body for the POST request containing - /// the corresponding JSON-RPC request. It may be for other transports. + /// This property contains transport-specific and runtime context information that accompanies + /// JSON-RPC messages but is not serialized as part of the JSON-RPC payload. This includes + /// transport references, execution context, and authenticated user information. /// - [JsonIgnore] - public ITransport? RelatedTransport { get; set; } - - /// - /// Gets or sets the that should be used to run any handlers - /// /// - /// This is used to support the Streamable HTTP transport in its default stateful mode. In this mode, - /// the outlives the initial HTTP request context it was created on, and new - /// JSON-RPC messages can originate from future HTTP requests. This allows the transport to flow the - /// context with the JSON-RPC message. This is particularly useful for enabling IHttpContextAccessor - /// in tool calls. + /// This property should only be set when implementing a custom + /// that needs to pass additional per-message context or to pass a + /// to + /// or . /// [JsonIgnore] - public ExecutionContext? ExecutionContext { get; set; } + public JsonRpcMessageContext? Context { get; set; } /// /// Provides a for messages, diff --git a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs new file mode 100644 index 000000000..30b6745a9 --- /dev/null +++ b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs @@ -0,0 +1,61 @@ +using ModelContextProtocol.Server; +using System.Security.Claims; +using System.Text.Json.Serialization; + +namespace ModelContextProtocol.Protocol; + +/// +/// Contains contextual information for JSON-RPC messages that is not part of the JSON-RPC protocol specification. +/// +/// +/// This class holds transport-specific and runtime context information that accompanies JSON-RPC messages +/// but is not serialized as part of the JSON-RPC payload. This includes transport references, execution context, +/// and authenticated user information. +/// +public class JsonRpcMessageContext +{ + /// + /// Gets or sets the transport the was received on or should be sent over. + /// + /// + /// This is used to support the Streamable HTTP transport where the specification states that the server + /// SHOULD include JSON-RPC responses in the HTTP response body for the POST request containing + /// the corresponding JSON-RPC request. It may be for other transports. + /// + public ITransport? RelatedTransport { get; set; } + + /// + /// Gets or sets the that should be used to run any handlers + /// + /// + /// This is used to support the Streamable HTTP transport in its default stateful mode. In this mode, + /// the outlives the initial HTTP request context it was created on, and new + /// JSON-RPC messages can originate from future HTTP requests. This allows the transport to flow the + /// context with the JSON-RPC message. This is particularly useful for enabling IHttpContextAccessor + /// in tool calls. + /// + public ExecutionContext? ExecutionContext { get; set; } + + /// + /// Gets or sets the authenticated user associated with this JSON-RPC message. + /// + /// + /// + /// This property contains the representing the authenticated user + /// who initiated this JSON-RPC message. This enables request handlers to access user identity + /// and authorization information without requiring dependency on HTTP context accessors + /// or other HTTP-specific abstractions. + /// + /// + /// The user information is automatically populated by the transport layer when processing + /// incoming HTTP requests in ASP.NET Core scenarios. For other transport types or scenarios + /// where user authentication is not applicable, this property may be . + /// + /// + /// This property is particularly useful in the Streamable HTTP transport where JSON-RPC messages + /// may outlive the original HTTP request context, allowing user identity to be preserved + /// throughout the message processing pipeline. + /// + /// + public ClaimsPrincipal? User { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Protocol/JsonRpcRequest.cs b/src/ModelContextProtocol.Core/Protocol/JsonRpcRequest.cs index ed6c8982a..e80b25f47 100644 --- a/src/ModelContextProtocol.Core/Protocol/JsonRpcRequest.cs +++ b/src/ModelContextProtocol.Core/Protocol/JsonRpcRequest.cs @@ -9,7 +9,7 @@ namespace ModelContextProtocol.Protocol; /// /// Requests are messages that require a response from the receiver. Each request includes a unique ID /// that will be included in the corresponding response message (either a success response or an error). -/// +/// /// The receiver of a request message is expected to execute the specified method with the provided parameters /// and return either a with the result, or a /// if the method execution fails. @@ -36,7 +36,7 @@ internal JsonRpcRequest WithId(RequestId id) Id = id, Method = Method, Params = Params, - RelatedTransport = RelatedTransport, + Context = Context, }; } } diff --git a/src/ModelContextProtocol.Core/Protocol/Prompt.cs b/src/ModelContextProtocol.Core/Protocol/Prompt.cs index 1a5004065..fcd3053f5 100644 --- a/src/ModelContextProtocol.Core/Protocol/Prompt.cs +++ b/src/ModelContextProtocol.Core/Protocol/Prompt.cs @@ -1,5 +1,6 @@ using System.Text.Json.Nodes; using System.Text.Json.Serialization; +using ModelContextProtocol.Server; namespace ModelContextProtocol.Protocol; @@ -59,4 +60,10 @@ public sealed class Prompt : IBaseMetadata /// [JsonPropertyName("_meta")] public JsonObject? Meta { get; set; } + + /// + /// Gets or sets the callable server prompt corresponding to this metadata if any. + /// + [JsonIgnore] + public McpServerPrompt? McpServerPrompt { get; set; } } diff --git a/src/ModelContextProtocol.Core/Protocol/Resource.cs b/src/ModelContextProtocol.Core/Protocol/Resource.cs index 63dce7fdc..1b8a0e9cd 100644 --- a/src/ModelContextProtocol.Core/Protocol/Resource.cs +++ b/src/ModelContextProtocol.Core/Protocol/Resource.cs @@ -1,5 +1,6 @@ using System.Text.Json.Nodes; using System.Text.Json.Serialization; +using ModelContextProtocol.Server; namespace ModelContextProtocol.Protocol; @@ -87,4 +88,10 @@ public sealed class Resource : IBaseMetadata /// [JsonPropertyName("_meta")] public JsonObject? Meta { get; init; } + + /// + /// Gets or sets the callable server resource corresponding to this metadata if any. + /// + [JsonIgnore] + public McpServerResource? McpServerResource { get; set; } } diff --git a/src/ModelContextProtocol.Core/Protocol/ResourceTemplate.cs b/src/ModelContextProtocol.Core/Protocol/ResourceTemplate.cs index d2959d182..f0f294985 100644 --- a/src/ModelContextProtocol.Core/Protocol/ResourceTemplate.cs +++ b/src/ModelContextProtocol.Core/Protocol/ResourceTemplate.cs @@ -1,5 +1,6 @@ using System.Text.Json.Nodes; using System.Text.Json.Serialization; +using ModelContextProtocol.Server; namespace ModelContextProtocol.Protocol; @@ -84,6 +85,12 @@ public sealed class ResourceTemplate : IBaseMetadata [JsonIgnore] public bool IsTemplated => UriTemplate.Contains('{'); + /// + /// Gets or sets the callable server resource corresponding to this metadata if any. + /// + [JsonIgnore] + public McpServerResource? McpServerResource { get; set; } + /// Converts the into a . /// A if is ; otherwise, . public Resource? AsResource() @@ -102,6 +109,7 @@ public sealed class ResourceTemplate : IBaseMetadata MimeType = MimeType, Annotations = Annotations, Meta = Meta, + McpServerResource = McpServerResource, }; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Protocol/Tool.cs b/src/ModelContextProtocol.Core/Protocol/Tool.cs index c09598ca7..1c4716691 100644 --- a/src/ModelContextProtocol.Core/Protocol/Tool.cs +++ b/src/ModelContextProtocol.Core/Protocol/Tool.cs @@ -1,6 +1,7 @@ using System.Text.Json; using System.Text.Json.Nodes; using System.Text.Json.Serialization; +using ModelContextProtocol.Server; namespace ModelContextProtocol.Protocol; @@ -43,7 +44,7 @@ public sealed class Tool : IBaseMetadata /// if an invalid schema is provided. /// /// - /// The schema typically defines the properties (parameters) that the tool accepts, + /// The schema typically defines the properties (parameters) that the tool accepts, /// their types, and which ones are required. This helps AI models understand /// how to structure their calls to the tool. /// @@ -52,9 +53,9 @@ public sealed class Tool : IBaseMetadata /// /// [JsonPropertyName("inputSchema")] - public JsonElement InputSchema - { - get => field; + public JsonElement InputSchema + { + get => field; set { if (!McpJsonUtilities.IsValidMcpToolSchema(value)) @@ -114,4 +115,10 @@ public JsonElement? OutputSchema /// [JsonPropertyName("_meta")] public JsonObject? Meta { get; set; } + + /// + /// Gets or sets the callable server tool corresponding to this metadata if any. + /// + [JsonIgnore] + public McpServerTool? McpServerTool { get; set; } } diff --git a/src/ModelContextProtocol.Core/RequestHandlers.cs b/src/ModelContextProtocol.Core/RequestHandlers.cs index 854a4bddf..0c2b54fa5 100644 --- a/src/ModelContextProtocol.Core/RequestHandlers.cs +++ b/src/ModelContextProtocol.Core/RequestHandlers.cs @@ -23,13 +23,13 @@ internal sealed class RequestHandlers : Dictionary /// - /// The handler function receives the deserialized request object and a cancellation token, and should return - /// a response object that will be serialized back to the client. + /// The handler function receives the deserialized request object, the full JSON-RPC request, and a cancellation token, + /// and should return a response object that will be serialized back to the client. /// /// public void Set( string method, - Func> handler, + Func> handler, JsonTypeInfo requestTypeInfo, JsonTypeInfo responseTypeInfo) { @@ -41,7 +41,7 @@ public void Set( this[method] = async (request, cancellationToken) => { TRequest? typedRequest = JsonSerializer.Deserialize(request.Params, requestTypeInfo); - object? result = await handler(typedRequest, request.RelatedTransport, cancellationToken).ConfigureAwait(false); + object? result = await handler(typedRequest, request, cancellationToken).ConfigureAwait(false); return JsonSerializer.SerializeToNode(result, responseTypeInfo); }; } diff --git a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerPrompt.cs b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerPrompt.cs index d651d7ee3..ef068c551 100644 --- a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerPrompt.cs +++ b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerPrompt.cs @@ -11,6 +11,7 @@ namespace ModelContextProtocol.Server; /// Provides an that's implemented via an . internal sealed class AIFunctionMcpServerPrompt : McpServerPrompt { + private readonly IReadOnlyList _metadata; /// /// Creates an instance for a method, specified via a instance. /// @@ -136,7 +137,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( Arguments = args, }; - return new AIFunctionMcpServerPrompt(function, prompt); + return new AIFunctionMcpServerPrompt(function, prompt, options?.Metadata ?? []); } private static McpServerPromptCreateOptions DeriveOptions(MethodInfo method, McpServerPromptCreateOptions? options) @@ -154,6 +155,9 @@ private static McpServerPromptCreateOptions DeriveOptions(MethodInfo method, Mcp newOptions.Description ??= descAttr.Description; } + // Set metadata if not already provided + newOptions.Metadata ??= AIFunctionMcpServerTool.CreateMetadata(method); + return newOptions; } @@ -161,15 +165,20 @@ private static McpServerPromptCreateOptions DeriveOptions(MethodInfo method, Mcp internal AIFunction AIFunction { get; } /// Initializes a new instance of the class. - private AIFunctionMcpServerPrompt(AIFunction function, Prompt prompt) + private AIFunctionMcpServerPrompt(AIFunction function, Prompt prompt, IReadOnlyList metadata) { AIFunction = function; ProtocolPrompt = prompt; + ProtocolPrompt.McpServerPrompt = this; + _metadata = metadata; } /// public override Prompt ProtocolPrompt { get; } + /// + public override IReadOnlyList Metadata => _metadata; + /// public override async ValueTask GetAsync( RequestContext request, CancellationToken cancellationToken = default) @@ -177,7 +186,7 @@ public override async ValueTask GetAsync( Throw.IfNull(request); cancellationToken.ThrowIfCancellationRequested(); - request.Services = new RequestServiceProvider(request, request.Services); + request.Services = new RequestServiceProvider(request); AIFunctionArguments arguments = new() { Services = request.Services }; if (request.Params?.Arguments is { } argDict) diff --git a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs index a8b0d2486..69b8deb8d 100644 --- a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs +++ b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerResource.cs @@ -6,6 +6,7 @@ using System.Diagnostics; using System.Globalization; using System.Reflection; +using System.Runtime.CompilerServices; using System.Text; using System.Text.Json; using System.Text.RegularExpressions; @@ -17,6 +18,7 @@ internal sealed class AIFunctionMcpServerResource : McpServerResource { private readonly Regex? _uriParser; private readonly string[] _templateVariableNames = []; + private readonly IReadOnlyList _metadata; /// /// Creates an instance for a method, specified via a instance. @@ -218,7 +220,7 @@ private static AIFunctionFactoryOptions CreateAIFunctionFactoryOptions( MimeType = options?.MimeType ?? "application/octet-stream", }; - return new AIFunctionMcpServerResource(function, resource); + return new AIFunctionMcpServerResource(function, resource, options?.Metadata ?? []); } private static McpServerResourceCreateOptions DeriveOptions(MemberInfo member, McpServerResourceCreateOptions? options) @@ -238,6 +240,12 @@ private static McpServerResourceCreateOptions DeriveOptions(MemberInfo member, M newOptions.Description ??= descAttr.Description; } + // Set metadata if not already provided and the member is a MethodInfo + if (member is MethodInfo method) + { + newOptions.Metadata ??= AIFunctionMcpServerTool.CreateMetadata(method); + } + return newOptions; } @@ -270,11 +278,13 @@ private static string DeriveUriTemplate(string name, AIFunction function) internal AIFunction AIFunction { get; } /// Initializes a new instance of the class. - private AIFunctionMcpServerResource(AIFunction function, ResourceTemplate resourceTemplate) + private AIFunctionMcpServerResource(AIFunction function, ResourceTemplate resourceTemplate, IReadOnlyList metadata) { AIFunction = function; ProtocolResourceTemplate = resourceTemplate; + ProtocolResourceTemplate.McpServerResource = this; ProtocolResource = resourceTemplate.AsResource(); + _metadata = metadata; if (ProtocolResource is null) { @@ -289,6 +299,9 @@ private AIFunctionMcpServerResource(AIFunction function, ResourceTemplate resour /// public override Resource? ProtocolResource { get; } + /// + public override IReadOnlyList Metadata => _metadata; + /// public override async ValueTask ReadAsync( RequestContext request, CancellationToken cancellationToken = default) @@ -316,7 +329,7 @@ private AIFunctionMcpServerResource(AIFunction function, ResourceTemplate resour } // Build up the arguments for the AIFunction call, including all of the name/value pairs from the URI. - request.Services = new RequestServiceProvider(request, request.Services); + request.Services = new RequestServiceProvider(request); AIFunctionArguments arguments = new() { Services = request.Services }; // For templates, populate the arguments from the URI template. diff --git a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs index 664ede5ab..cb4758486 100644 --- a/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs +++ b/src/ModelContextProtocol.Core/Server/AIFunctionMcpServerTool.cs @@ -1,7 +1,5 @@ using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Protocol; using System.ComponentModel; using System.Diagnostics; @@ -15,8 +13,8 @@ namespace ModelContextProtocol.Server; /// Provides an that's implemented via an . internal sealed partial class AIFunctionMcpServerTool : McpServerTool { - private readonly ILogger _logger; private readonly bool _structuredOutputRequiresWrapping; + private readonly IReadOnlyList _metadata; /// /// Creates an instance for a method, specified via a instance. @@ -26,7 +24,7 @@ internal sealed partial class AIFunctionMcpServerTool : McpServerTool McpServerToolCreateOptions? options) { Throw.IfNull(method); - + options = DeriveOptions(method.Method, options); return Create(method.Method, method.Target, options); @@ -146,7 +144,7 @@ options.OpenWorld is not null || } } - return new AIFunctionMcpServerTool(function, tool, options?.Services, structuredOutputRequiresWrapping); + return new AIFunctionMcpServerTool(function, tool, options?.Services, structuredOutputRequiresWrapping, options?.Metadata ?? []); } private static McpServerToolCreateOptions DeriveOptions(MethodInfo method, McpServerToolCreateOptions? options) @@ -186,6 +184,9 @@ private static McpServerToolCreateOptions DeriveOptions(MethodInfo method, McpSe newOptions.Description ??= descAttr.Description; } + // Set metadata if not already provided + newOptions.Metadata ??= CreateMetadata(method); + return newOptions; } @@ -193,17 +194,22 @@ private static McpServerToolCreateOptions DeriveOptions(MethodInfo method, McpSe internal AIFunction AIFunction { get; } /// Initializes a new instance of the class. - private AIFunctionMcpServerTool(AIFunction function, Tool tool, IServiceProvider? serviceProvider, bool structuredOutputRequiresWrapping) + private AIFunctionMcpServerTool(AIFunction function, Tool tool, IServiceProvider? serviceProvider, bool structuredOutputRequiresWrapping, IReadOnlyList metadata) { AIFunction = function; ProtocolTool = tool; - _logger = serviceProvider?.GetService()?.CreateLogger() ?? (ILogger)NullLogger.Instance; + ProtocolTool.McpServerTool = this; + _structuredOutputRequiresWrapping = structuredOutputRequiresWrapping; + _metadata = metadata; } /// public override Tool ProtocolTool { get; } + /// + public override IReadOnlyList Metadata => _metadata; + /// public override async ValueTask InvokeAsync( RequestContext request, CancellationToken cancellationToken = default) @@ -211,7 +217,7 @@ public override async ValueTask InvokeAsync( Throw.IfNull(request); cancellationToken.ThrowIfCancellationRequested(); - request.Services = new RequestServiceProvider(request, request.Services); + request.Services = new RequestServiceProvider(request); AIFunctionArguments arguments = new() { Services = request.Services }; if (request.Params?.Arguments is { } argDict) @@ -223,24 +229,7 @@ public override async ValueTask InvokeAsync( } object? result; - try - { - result = await AIFunction.InvokeAsync(arguments, cancellationToken).ConfigureAwait(false); - } - catch (Exception e) when (e is not OperationCanceledException) - { - ToolCallError(request.Params?.Name ?? string.Empty, e); - - string errorMessage = e is McpException ? - $"An error occurred invoking '{request.Params?.Name}': {e.Message}" : - $"An error occurred invoking '{request.Params?.Name}'."; - - return new() - { - IsError = true, - Content = [new TextContentBlock { Text = errorMessage }], - }; - } + result = await AIFunction.InvokeAsync(arguments, cancellationToken).ConfigureAwait(false); JsonNode? structuredContent = CreateStructuredResponse(result); return result switch @@ -257,27 +246,27 @@ public override async ValueTask InvokeAsync( Content = [], StructuredContent = structuredContent, }, - + string text => new() { Content = [new TextContentBlock { Text = text }], StructuredContent = structuredContent, }, - + ContentBlock content => new() { Content = [content], StructuredContent = structuredContent, }, - + IEnumerable contentItems => ConvertAIContentEnumerableToCallToolResult(contentItems, structuredContent), - + IEnumerable contents => new() { Content = [.. contents], StructuredContent = structuredContent, }, - + CallToolResult callToolResponse => callToolResponse, _ => new() @@ -336,6 +325,26 @@ static bool IsAsyncMethod(MethodInfo method) } } + /// Creates metadata from attributes on the specified method and its declaring class, with the MethodInfo as the first item. + internal static IReadOnlyList CreateMetadata(MethodInfo method) + { + // Add the MethodInfo to the start of the metadata similar to what RouteEndpointDataSource does for minimal endpoints. + List metadata = [method]; + + // Add class-level attributes first, since those are less specific. + if (method.DeclaringType is not null) + { + metadata.AddRange(method.DeclaringType.GetCustomAttributes()); + } + + // Add method-level attributes second, since those are more specific. + // When metadata conflicts, later metadata usually takes precedence with exceptions for metadata like + // IAllowAnonymous which always take precedence over IAuthorizeData no matter the order. + metadata.AddRange(method.GetCustomAttributes()); + + return metadata.AsReadOnly(); + } + /// Regex that flags runs of characters other than ASCII digits or letters. #if NET [GeneratedRegex("[^0-9A-Za-z]+")] @@ -446,7 +455,4 @@ private static CallToolResult ConvertAIContentEnumerableToCallToolResult(IEnumer IsError = allErrorContent && hasAny }; } - - [LoggerMessage(Level = LogLevel.Error, Message = "\"{ToolName}\" threw an unhandled exception.")] - private partial void ToolCallError(string toolName, Exception exception); } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs b/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs index d286d1ef4..78346c399 100644 --- a/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs +++ b/src/ModelContextProtocol.Core/Server/DestinationBoundMcpServer.cs @@ -22,15 +22,25 @@ internal sealed class DestinationBoundMcpServer(McpServer server, ITransport? tr public Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { - Debug.Assert(message.RelatedTransport is null); - message.RelatedTransport = transport; + if (message.Context is not null) + { + throw new ArgumentException("Only transports can provide a JsonRpcMessageContext."); + } + + message.Context = new JsonRpcMessageContext(); + message.Context.RelatedTransport = transport; return server.SendMessageAsync(message, cancellationToken); } public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) { - Debug.Assert(request.RelatedTransport is null); - request.RelatedTransport = transport; + if (request.Context is not null) + { + throw new ArgumentException("Only transports can provide a JsonRpcMessageContext."); + } + + request.Context = new JsonRpcMessageContext(); + request.Context.RelatedTransport = transport; return server.SendRequestAsync(request, cancellationToken); } } diff --git a/src/ModelContextProtocol.Core/Server/IMcpServerPrimitive.cs b/src/ModelContextProtocol.Core/Server/IMcpServerPrimitive.cs index 597fdec97..f3ec62219 100644 --- a/src/ModelContextProtocol.Core/Server/IMcpServerPrimitive.cs +++ b/src/ModelContextProtocol.Core/Server/IMcpServerPrimitive.cs @@ -7,4 +7,13 @@ public interface IMcpServerPrimitive { /// Gets the unique identifier of the primitive. string Id { get; } + + /// + /// Gets the metadata for this primitive instance. + /// + /// + /// Contains attributes from the associated MethodInfo and declaring class (if any), + /// with class-level attributes appearing before method-level attributes. + /// + IReadOnlyList Metadata { get; } } diff --git a/src/ModelContextProtocol.Core/Server/McpServer.cs b/src/ModelContextProtocol.Core/Server/McpServer.cs index 6c5858f91..0056b1ae0 100644 --- a/src/ModelContextProtocol.Core/Server/McpServer.cs +++ b/src/ModelContextProtocol.Core/Server/McpServer.cs @@ -7,7 +7,7 @@ namespace ModelContextProtocol.Server; /// -internal sealed class McpServer : McpEndpoint, IMcpServer +internal sealed partial class McpServer : McpEndpoint, IMcpServer { internal static Implementation DefaultImplementation { get; } = new() { @@ -195,9 +195,12 @@ private void ConfigureCompletion(McpServerOptions options) return; } + var completeHandler = completionsCapability.CompleteHandler ?? (static async (_, __) => new CompleteResult()); + completeHandler = BuildFilterPipeline(completeHandler, options.Filters.CompleteFilters); + ServerCapabilities.Completions = new() { - CompleteHandler = completionsCapability.CompleteHandler ?? (static async (_, __) => new CompleteResult()) + CompleteHandler = completeHandler }; SetHandler( @@ -279,30 +282,14 @@ await originalListResourceTemplatesHandler(request, cancellationToken).Configure var originalReadResourceHandler = readResourceHandler; readResourceHandler = async (request, cancellationToken) => { - if (request.Params?.Uri is string uri) + if (request.MatchedPrimitive is McpServerResource matchedResource) { - // First try an O(1) lookup by exact match. - if (resources.TryGetPrimitive(uri, out var resource)) + if (await matchedResource.ReadAsync(request, cancellationToken).ConfigureAwait(false) is { } result) { - if (await resource.ReadAsync(request, cancellationToken).ConfigureAwait(false) is { } result) - { - return result; - } - } - - // Fall back to an O(N) lookup, trying to match against each URI template. - // The number of templates is controlled by the server developer, and the number is expected to be - // not terribly large. If that changes, this can be tweaked to enable a more efficient lookup. - foreach (var resourceTemplate in resources) - { - if (await resourceTemplate.ReadAsync(request, cancellationToken).ConfigureAwait(false) is { } result) - { - return result; - } + return result; } } - // Finally fall back to the handler. return await originalReadResourceHandler(request, cancellationToken).ConfigureAwait(false); }; @@ -312,6 +299,43 @@ await originalListResourceTemplatesHandler(request, cancellationToken).Configure // subscribe = true; } + listResourcesHandler = BuildFilterPipeline(listResourcesHandler, options.Filters.ListResourcesFilters); + listResourceTemplatesHandler = BuildFilterPipeline(listResourceTemplatesHandler, options.Filters.ListResourceTemplatesFilters); + readResourceHandler = BuildFilterPipeline(readResourceHandler, options.Filters.ReadResourceFilters, handler => + async (request, cancellationToken) => + { + // Initial handler that sets MatchedPrimitive + if (request.Params?.Uri is { } uri && resources is not null) + { + // First try an O(1) lookup by exact match. + if (resources.TryGetPrimitive(uri, out var resource)) + { + request.MatchedPrimitive = resource; + } + else + { + // Fall back to an O(N) lookup, trying to match against each URI template. + // The number of templates is controlled by the server developer, and the number is expected to be + // not terribly large. If that changes, this can be tweaked to enable a more efficient lookup. + foreach (var resourceTemplate in resources) + { + // Check if this template would handle the request by testing if ReadAsync would succeed + if (resourceTemplate.IsTemplated) + { + // This is a simplified check - a more robust implementation would match the URI pattern + // For now, we'll let the actual handler attempt the match + request.MatchedPrimitive = resourceTemplate; + break; + } + } + } + } + + return await handler(request, cancellationToken).ConfigureAwait(false); + }); + subscribeHandler = BuildFilterPipeline(subscribeHandler, options.Filters.SubscribeToResourcesFilters); + unsubscribeHandler = BuildFilterPipeline(unsubscribeHandler, options.Filters.UnsubscribeFromResourcesFilters); + ServerCapabilities.Resources.ListResourcesHandler = listResourcesHandler; ServerCapabilities.Resources.ListResourceTemplatesHandler = listResourceTemplatesHandler; ServerCapabilities.Resources.ReadResourceHandler = readResourceHandler; @@ -390,8 +414,7 @@ await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(fals var originalGetPromptHandler = getPromptHandler; getPromptHandler = (request, cancellationToken) => { - if (request.Params is not null && - prompts.TryGetPrimitive(request.Params.Name, out var prompt)) + if (request.MatchedPrimitive is McpServerPrompt prompt) { return prompt.GetAsync(request, cancellationToken); } @@ -402,6 +425,20 @@ await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(fals listChanged = true; } + listPromptsHandler = BuildFilterPipeline(listPromptsHandler, options.Filters.ListPromptsFilters); + getPromptHandler = BuildFilterPipeline(getPromptHandler, options.Filters.GetPromptFilters, handler => + (request, cancellationToken) => + { + // Initial handler that sets MatchedPrimitive + if (request.Params?.Name is { } promptName && prompts is not null && + prompts.TryGetPrimitive(promptName, out var prompt)) + { + request.MatchedPrimitive = prompt; + } + + return handler(request, cancellationToken); + }); + ServerCapabilities.Prompts.ListPromptsHandler = listPromptsHandler; ServerCapabilities.Prompts.GetPromptHandler = getPromptHandler; ServerCapabilities.Prompts.PromptCollection = prompts; @@ -458,8 +495,7 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) var originalCallToolHandler = callToolHandler; callToolHandler = (request, cancellationToken) => { - if (request.Params is not null && - tools.TryGetPrimitive(request.Params.Name, out var tool)) + if (request.MatchedPrimitive is McpServerTool tool) { return tool.InvokeAsync(request, cancellationToken); } @@ -470,6 +506,51 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) listChanged = true; } + listToolsHandler = BuildFilterPipeline(listToolsHandler, options.Filters.ListToolsFilters); + callToolHandler = BuildFilterPipeline(callToolHandler, options.Filters.CallToolFilters, handler => + (request, cancellationToken) => + { + // Initial handler that sets MatchedPrimitive + if (request.Params?.Name is { } toolName && tools is not null && + tools.TryGetPrimitive(toolName, out var tool)) + { + request.MatchedPrimitive = tool; + } + + return handler(request, cancellationToken); + }, handler => + async (request, cancellationToken) => + { + // Final handler that provides exception handling only for tool execution + // Only wrap tool execution in try-catch, not tool resolution + if (request.MatchedPrimitive is McpServerTool) + { + try + { + return await handler(request, cancellationToken).ConfigureAwait(false); + } + catch (Exception e) when (e is not OperationCanceledException) + { + ToolCallError(request.Params?.Name ?? string.Empty, e); + + string errorMessage = e is McpException ? + $"An error occurred invoking '{request.Params?.Name}': {e.Message}" : + $"An error occurred invoking '{request.Params?.Name}'."; + + return new() + { + IsError = true, + Content = [new TextContentBlock { Text = errorMessage }], + }; + } + } + else + { + // For unmatched tools, let exceptions bubble up as protocol errors + return await handler(request, cancellationToken).ConfigureAwait(false); + } + }); + ServerCapabilities.Tools.ListToolsHandler = listToolsHandler; ServerCapabilities.Tools.CallToolHandler = callToolHandler; ServerCapabilities.Tools.ToolCollection = tools; @@ -493,12 +574,18 @@ private void ConfigureLogging(McpServerOptions options) // We don't require that the handler be provided, as we always store the provided log level to the server. var setLoggingLevelHandler = options.Capabilities?.Logging?.SetLoggingLevelHandler; + // Apply filters to the handler + if (setLoggingLevelHandler is not null) + { + setLoggingLevelHandler = BuildFilterPipeline(setLoggingLevelHandler, options.Filters.SetLoggingLevelFilters); + } + ServerCapabilities.Logging = new(); ServerCapabilities.Logging.SetLoggingLevelHandler = setLoggingLevelHandler; RequestHandlers.Set( RequestMethods.LoggingSetLevel, - (request, destinationTransport, cancellationToken) => + (request, jsonRpcRequest, cancellationToken) => { // Store the provided level. if (request is not null) @@ -514,7 +601,7 @@ private void ConfigureLogging(McpServerOptions options) // If a handler was provided, now delegate to it. if (setLoggingLevelHandler is not null) { - return InvokeHandlerAsync(setLoggingLevelHandler, request, destinationTransport, cancellationToken); + return InvokeHandlerAsync(setLoggingLevelHandler, request, jsonRpcRequest, cancellationToken); } // Otherwise, consider it handled. @@ -527,23 +614,24 @@ private void ConfigureLogging(McpServerOptions options) private ValueTask InvokeHandlerAsync( Func, CancellationToken, ValueTask> handler, TParams? args, - ITransport? destinationTransport = null, + JsonRpcRequest jsonRpcRequest, CancellationToken cancellationToken = default) { return _servicesScopePerRequest ? - InvokeScopedAsync(handler, args, cancellationToken) : - handler(new(new DestinationBoundMcpServer(this, destinationTransport)) { Params = args }, cancellationToken); + InvokeScopedAsync(handler, args, jsonRpcRequest, cancellationToken) : + handler(new(new DestinationBoundMcpServer(this, jsonRpcRequest.Context?.RelatedTransport), jsonRpcRequest) { Params = args }, cancellationToken); async ValueTask InvokeScopedAsync( Func, CancellationToken, ValueTask> handler, TParams? args, + JsonRpcRequest jsonRpcRequest, CancellationToken cancellationToken) { var scope = Services?.GetService()?.CreateAsyncScope(); try { return await handler( - new RequestContext(new DestinationBoundMcpServer(this, destinationTransport)) + new RequestContext(new DestinationBoundMcpServer(this, jsonRpcRequest.Context?.RelatedTransport), jsonRpcRequest) { Services = scope?.ServiceProvider ?? Services, Params = args @@ -566,12 +654,33 @@ private void SetHandler( JsonTypeInfo requestTypeInfo, JsonTypeInfo responseTypeInfo) { - RequestHandlers.Set(method, - (request, destinationTransport, cancellationToken) => - InvokeHandlerAsync(handler, request, destinationTransport, cancellationToken), + RequestHandlers.Set(method, + (request, jsonRpcRequest, cancellationToken) => + InvokeHandlerAsync(handler, request, jsonRpcRequest, cancellationToken), requestTypeInfo, responseTypeInfo); } + private static THandler BuildFilterPipeline( + THandler baseHandler, List> filters, + Func? initialHandler = null, + Func? finalHandler = null) + { + THandler current = baseHandler; + if (finalHandler is not null) + { + current = finalHandler(current); + } + for (int i = filters.Count - 1; i >= 0; i--) + { + current = filters[i](current); + } + if (initialHandler is not null) + { + current = initialHandler(current); + } + return current; + } + private void UpdateEndpointNameWithClientInfo() { if (ClientInfo is null) @@ -594,4 +703,7 @@ internal static LoggingLevel ToLoggingLevel(LogLevel level) => LogLevel.Critical => Protocol.LoggingLevel.Critical, _ => Protocol.LoggingLevel.Emergency, }; + + [LoggerMessage(Level = LogLevel.Error, Message = "\"{ToolName}\" threw an unhandled exception.")] + private partial void ToolCallError(string toolName, Exception exception); } diff --git a/src/ModelContextProtocol.Core/Server/McpServerFilters.cs b/src/ModelContextProtocol.Core/Server/McpServerFilters.cs new file mode 100644 index 000000000..d15154dd0 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/McpServerFilters.cs @@ -0,0 +1,161 @@ +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Server; + +/// +/// Provides filter collections for MCP server handlers. +/// +/// +/// This class contains collections of filters that can be applied to various MCP server handlers. +/// This allows for middleware-style composition where filters can perform actions before and after the inner handler. +/// +public sealed class McpServerFilters +{ + /// + /// Gets the filters for the list tools handler pipeline. + /// + /// + /// + /// These filters wrap handlers that return a list of available tools when requested by a client. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. It supports pagination through the cursor mechanism, + /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more tools. + /// + /// + /// These filters work alongside any tools defined in the collection. + /// Tools from both sources will be combined when returning results to clients. + /// + /// + public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> ListToolsFilters { get; } = new(); + + /// + /// Gets the filters for the call tool handler pipeline. + /// + /// + /// These filters wrap handlers that are invoked when a client makes a call to a tool that isn't found in the collection. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to execute the requested tool and return appropriate results. + /// + public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> CallToolFilters { get; } = new(); + + /// + /// Gets the filters for the list prompts handler pipeline. + /// + /// + /// + /// These filters wrap handlers that return a list of available prompts when requested by a client. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. It supports pagination through the cursor mechanism, + /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more prompts. + /// + /// + /// These filters work alongside any prompts defined in the collection. + /// Prompts from both sources will be combined when returning results to clients. + /// + /// + public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> ListPromptsFilters { get; } = new(); + + /// + /// Gets the filters for the get prompt handler pipeline. + /// + /// + /// These filters wrap handlers that are invoked when a client requests details for a specific prompt that isn't found in the collection. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to fetch or generate the requested prompt and return appropriate results. + /// + public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> GetPromptFilters { get; } = new(); + + /// + /// Gets the filters for the list resource templates handler pipeline. + /// + /// + /// These filters wrap handlers that return a list of available resource templates when requested by a client. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. It supports pagination through the cursor mechanism, + /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more resource templates. + /// + public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> ListResourceTemplatesFilters { get; } = new(); + + /// + /// Gets the filters for the list resources handler pipeline. + /// + /// + /// These filters wrap handlers that return a list of available resources when requested by a client. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. It supports pagination through the cursor mechanism, + /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more resources. + /// + public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> ListResourcesFilters { get; } = new(); + + /// + /// Gets the filters for the read resource handler pipeline. + /// + /// + /// These filters wrap handlers that are invoked when a client requests the content of a specific resource identified by its URI. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to locate and retrieve the requested resource. + /// + public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> ReadResourceFilters { get; } = new(); + + /// + /// Gets the filters for the complete handler pipeline. + /// + /// + /// These filters wrap handlers that provide auto-completion suggestions for prompt arguments or resource references in the Model Context Protocol. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. The handler processes auto-completion requests, returning a list of suggestions based on the + /// reference type and current argument value. + /// + public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> CompleteFilters { get; } = new(); + + /// + /// Gets the filters for the subscribe to resources handler pipeline. + /// + /// + /// + /// These filters wrap handlers that are invoked when a client wants to receive notifications about changes to specific resources or resource patterns. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to register the client's interest in the specified resources + /// and set up the necessary infrastructure to send notifications when those resources change. + /// + /// + /// After a successful subscription, the server should send resource change notifications to the client + /// whenever a relevant resource is created, updated, or deleted. + /// + /// + public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> SubscribeToResourcesFilters { get; } = new(); + + /// + /// Gets the filters for the unsubscribe from resources handler pipeline. + /// + /// + /// + /// These filters wrap handlers that are invoked when a client wants to stop receiving notifications about previously subscribed resources. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to remove the client's subscriptions to the specified resources + /// and clean up any associated resources. + /// + /// + /// After a successful unsubscription, the server should no longer send resource change notifications + /// to the client for the specified resources. + /// + /// + public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> UnsubscribeFromResourcesFilters { get; } = new(); + + /// + /// Gets the filters for the set logging level handler pipeline. + /// + /// + /// + /// These filters wrap handlers that process requests from clients. When set, it enables + /// clients to control which log messages they receive by specifying a minimum severity threshold. + /// The filters can modify, log, or perform additional operations on requests and responses for + /// requests. + /// + /// + /// After handling a level change request, the server typically begins sending log messages + /// at or above the specified level to the client as notifications/message notifications. + /// + /// + public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> SetLoggingLevelFilters { get; } = new(); +} diff --git a/src/ModelContextProtocol.Core/Server/McpServerOptions.cs b/src/ModelContextProtocol.Core/Server/McpServerOptions.cs index 8c50a9b55..1c981b77f 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerOptions.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerOptions.cs @@ -79,4 +79,14 @@ public sealed class McpServerOptions /// /// public Implementation? KnownClientInfo { get; set; } + + /// + /// Gets the filter collections for MCP server handlers. + /// + /// + /// This property provides access to filter collections that can be used to modify the behavior + /// of various MCP server handlers. Filters are applied in reverse order, so the last filter + /// added will be the outermost (first to execute). + /// + public McpServerFilters Filters { get; } = new(); } diff --git a/src/ModelContextProtocol.Core/Server/McpServerPrompt.cs b/src/ModelContextProtocol.Core/Server/McpServerPrompt.cs index 68874df3e..746278791 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerPrompt.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerPrompt.cs @@ -20,7 +20,7 @@ namespace ModelContextProtocol.Server; /// /// /// Most commonly, instances are created using the static methods. -/// These methods enable creating an for a method, specified via a or +/// These methods enable creating an for a method, specified via a or /// , and are what are used implicitly by WithPromptsFromAssembly and WithPrompts. The methods /// create instances capable of working with a large variety of .NET method signatures, automatically handling /// how parameters are marshaled into the method from the JSON received from the MCP client, and how the return value is marshaled back @@ -61,15 +61,15 @@ namespace ModelContextProtocol.Server; /// /// /// -/// When the is constructed, it may be passed an via +/// When the is constructed, it may be passed an via /// . Any parameter that can be satisfied by that -/// according to will be resolved from the provided to +/// according to will be resolved from the provided to /// rather than from the argument collection. /// /// /// /// -/// Any parameter attributed with will similarly be resolved from the +/// Any parameter attributed with will similarly be resolved from the /// provided to rather than from the argument collection. /// /// @@ -80,7 +80,7 @@ namespace ModelContextProtocol.Server; /// /// /// In general, the data supplied via the 's dictionary is passed along from the caller and -/// should thus be considered unvalidated and untrusted. To provide validated and trusted data to the invocation of the prompt, consider having +/// should thus be considered unvalidated and untrusted. To provide validated and trusted data to the invocation of the prompt, consider having /// the prompt be an instance method, referring to data stored in the instance, or using an instance or parameters resolved from the /// to provide data to the method. /// @@ -128,6 +128,15 @@ protected McpServerPrompt() /// public abstract Prompt ProtocolPrompt { get; } + /// + /// Gets the metadata for this prompt instance. + /// + /// + /// Contains attributes from the associated MethodInfo and declaring class (if any), + /// with class-level attributes appearing before method-level attributes. + /// + public abstract IReadOnlyList Metadata { get; } + /// /// Gets the prompt, rendering it with the provided request parameters and returning the prompt result. /// @@ -170,7 +179,7 @@ public static McpServerPrompt Create( /// is . /// is an instance method but is . public static McpServerPrompt Create( - MethodInfo method, + MethodInfo method, object? target = null, McpServerPromptCreateOptions? options = null) => AIFunctionMcpServerPrompt.Create(method, target, options); diff --git a/src/ModelContextProtocol.Core/Server/McpServerPromptCreateOptions.cs b/src/ModelContextProtocol.Core/Server/McpServerPromptCreateOptions.cs index 95d712ffd..1853b0f1a 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerPromptCreateOptions.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerPromptCreateOptions.cs @@ -68,6 +68,15 @@ public sealed class McpServerPromptCreateOptions /// public AIJsonSchemaCreateOptions? SchemaCreateOptions { get; set; } + /// + /// Gets or sets the metadata associated with the prompt. + /// + /// + /// Metadata includes information such as the attributes extracted from the method and its declaring class. + /// If not provided, metadata will be automatically generated for methods created via reflection. + /// + public IReadOnlyList? Metadata { get; set; } + /// /// Creates a shallow clone of the current instance. /// @@ -80,5 +89,6 @@ internal McpServerPromptCreateOptions Clone() => Description = Description, SerializerOptions = SerializerOptions, SchemaCreateOptions = SchemaCreateOptions, + Metadata = Metadata, }; } diff --git a/src/ModelContextProtocol.Core/Server/McpServerResource.cs b/src/ModelContextProtocol.Core/Server/McpServerResource.cs index 8e42d3e1c..9508cda0a 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerResource.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerResource.cs @@ -11,13 +11,13 @@ namespace ModelContextProtocol.Server; /// /// /// is an abstract base class that represents an MCP resource for use in the server (as opposed -/// to or , which provide the protocol representations of a resource). Instances of +/// to or , which provide the protocol representations of a resource). Instances of /// can be added into a to be picked up automatically when /// is used to create an , or added into a . /// /// /// Most commonly, instances are created using the static methods. -/// These methods enable creating an for a method, specified via a or +/// These methods enable creating an for a method, specified via a or /// , and are what are used implicitly by WithResourcesFromAssembly and /// . The methods /// create instances capable of working with a large variety of .NET method signatures, automatically handling @@ -62,15 +62,15 @@ namespace ModelContextProtocol.Server; /// /// /// -/// When the is constructed, it may be passed an via +/// When the is constructed, it may be passed an via /// . Any parameter that can be satisfied by that -/// according to will be resolved from the provided to the +/// according to will be resolved from the provided to the /// resource invocation rather than from the argument collection. /// /// /// /// -/// Any parameter attributed with will similarly be resolved from the +/// Any parameter attributed with will similarly be resolved from the /// provided to the resource invocation rather than from the argument collection. /// /// @@ -149,6 +149,15 @@ protected McpServerResource() /// public virtual Resource? ProtocolResource => ProtocolResourceTemplate.AsResource(); + /// + /// Gets the metadata for this resource instance. + /// + /// + /// Contains attributes from the associated MethodInfo and declaring class (if any), + /// with class-level attributes appearing before method-level attributes. + /// + public abstract IReadOnlyList Metadata { get; } + /// /// Gets the resource, rendering it with the provided request parameters and returning the resource result. /// @@ -192,7 +201,7 @@ public static McpServerResource Create( /// is . /// is an instance method but is . public static McpServerResource Create( - MethodInfo method, + MethodInfo method, object? target = null, McpServerResourceCreateOptions? options = null) => AIFunctionMcpServerResource.Create(method, target, options); diff --git a/src/ModelContextProtocol.Core/Server/McpServerResourceCreateOptions.cs b/src/ModelContextProtocol.Core/Server/McpServerResourceCreateOptions.cs index 24051a7ff..2d6b66b32 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerResourceCreateOptions.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerResourceCreateOptions.cs @@ -83,6 +83,15 @@ public sealed class McpServerResourceCreateOptions /// public AIJsonSchemaCreateOptions? SchemaCreateOptions { get; set; } + /// + /// Gets or sets the metadata associated with the resource. + /// + /// + /// Metadata includes information such as attributes extracted from the method and its declaring class. + /// If not provided, metadata will be automatically generated for methods created via reflection. + /// + public IReadOnlyList? Metadata { get; set; } + /// /// Creates a shallow clone of the current instance. /// @@ -97,5 +106,6 @@ internal McpServerResourceCreateOptions Clone() => MimeType = MimeType, SerializerOptions = SerializerOptions, SchemaCreateOptions = SchemaCreateOptions, + Metadata = Metadata, }; } diff --git a/src/ModelContextProtocol.Core/Server/McpServerTool.cs b/src/ModelContextProtocol.Core/Server/McpServerTool.cs index e3958271b..baddf88f8 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerTool.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerTool.cs @@ -20,7 +20,7 @@ namespace ModelContextProtocol.Server; /// /// /// Most commonly, instances are created using the static methods. -/// These methods enable creating an for a method, specified via a or +/// These methods enable creating an for a method, specified via a or /// , and are what are used implicitly by WithToolsFromAssembly and WithTools. The methods /// create instances capable of working with a large variety of .NET method signatures, automatically handling /// how parameters are marshaled into the method from the JSON received from the MCP client, and how the return value is marshaled back @@ -56,22 +56,22 @@ namespace ModelContextProtocol.Server; /// /// parameters accepting values /// are not included in the JSON schema and are bound to an instance manufactured -/// to forward progress notifications from the tool to the client. If the client included a in their request, +/// to forward progress notifications from the tool to the client. If the client included a in their request, /// progress reports issued to this instance will propagate to the client as notifications with /// that token. If the client did not include a , the instance will ignore any progress reports issued to it. /// /// /// /// -/// When the is constructed, it may be passed an via +/// When the is constructed, it may be passed an via /// . Any parameter that can be satisfied by that -/// according to will not be included in the generated JSON schema and will be resolved +/// according to will not be included in the generated JSON schema and will be resolved /// from the provided to rather than from the argument collection. /// /// /// /// -/// Any parameter attributed with will similarly be resolved from the +/// Any parameter attributed with will similarly be resolved from the /// provided to rather than from the argument /// collection, and will not be included in the generated JSON schema. /// @@ -79,13 +79,13 @@ namespace ModelContextProtocol.Server; /// /// /// -/// All other parameters are deserialized from the s in the dictionary, -/// using the supplied in , or if none was provided, +/// All other parameters are deserialized from the s in the dictionary, +/// using the supplied in , or if none was provided, /// using . /// /// /// In general, the data supplied via the 's dictionary is passed along from the caller and -/// should thus be considered unvalidated and untrusted. To provide validated and trusted data to the invocation of the tool, consider having +/// should thus be considered unvalidated and untrusted. To provide validated and trusted data to the invocation of the tool, consider having /// the tool be an instance method, referring to data stored in the instance, or using an instance or parameters resolved from the /// to provide data to the method. /// @@ -141,6 +141,15 @@ protected McpServerTool() /// Gets the protocol type for this instance. public abstract Tool ProtocolTool { get; } + /// + /// Gets the metadata for this tool instance. + /// + /// + /// Contains attributes from the associated MethodInfo and declaring class (if any), + /// with class-level attributes appearing before method-level attributes. + /// + public abstract IReadOnlyList Metadata { get; } + /// Invokes the . /// The request information resulting in the invocation of this tool. /// The to monitor for cancellation requests. The default is . @@ -172,7 +181,7 @@ public static McpServerTool Create( /// is . /// is an instance method but is . public static McpServerTool Create( - MethodInfo method, + MethodInfo method, object? target = null, McpServerToolCreateOptions? options = null) => AIFunctionMcpServerTool.Create(method, target, options); diff --git a/src/ModelContextProtocol.Core/Server/McpServerToolCreateOptions.cs b/src/ModelContextProtocol.Core/Server/McpServerToolCreateOptions.cs index bdb4ecb8d..d18af8c02 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerToolCreateOptions.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerToolCreateOptions.cs @@ -80,7 +80,7 @@ public sealed class McpServerToolCreateOptions public bool? Destructive { get; set; } /// - /// Gets or sets whether calling the tool repeatedly with the same arguments + /// Gets or sets whether calling the tool repeatedly with the same arguments /// will have no additional effect on its environment. /// /// @@ -155,6 +155,15 @@ public sealed class McpServerToolCreateOptions /// public AIJsonSchemaCreateOptions? SchemaCreateOptions { get; set; } + /// + /// Gets or sets the metadata associated with the tool. + /// + /// + /// Metadata includes information such as attributes extracted from the method and its declaring class. + /// If not provided, metadata will be automatically generated for methods created via reflection. + /// + public IReadOnlyList? Metadata { get; set; } + /// /// Creates a shallow clone of the current instance. /// @@ -172,5 +181,6 @@ internal McpServerToolCreateOptions Clone() => UseStructuredContent = UseStructuredContent, SerializerOptions = SerializerOptions, SchemaCreateOptions = SchemaCreateOptions, + Metadata = Metadata, }; } diff --git a/src/ModelContextProtocol.Core/Server/RequestContext.cs b/src/ModelContextProtocol.Core/Server/RequestContext.cs index b0ea9d993..8af9f666f 100644 --- a/src/ModelContextProtocol.Core/Server/RequestContext.cs +++ b/src/ModelContextProtocol.Core/Server/RequestContext.cs @@ -1,3 +1,6 @@ +using System.Security.Claims; +using ModelContextProtocol.Protocol; + namespace ModelContextProtocol.Server; /// @@ -15,19 +18,23 @@ public sealed class RequestContext private IMcpServer _server; /// - /// Initializes a new instance of the class with the specified server. + /// Initializes a new instance of the class with the specified server and JSON-RPC request. /// /// The server with which this instance is associated. - public RequestContext(IMcpServer server) + /// The JSON-RPC request associated with this context. + public RequestContext(IMcpServer server, JsonRpcRequest jsonRpcRequest) { Throw.IfNull(server); + Throw.IfNull(jsonRpcRequest); _server = server; + JsonRpcRequest = jsonRpcRequest; Services = server.Services; + User = jsonRpcRequest.Context?.User; } /// Gets or sets the server with which this instance is associated. - public IMcpServer Server + public IMcpServer Server { get => _server; set @@ -46,6 +53,23 @@ public IMcpServer Server /// public IServiceProvider? Services { get; set; } + /// Gets or sets the user associated with this request. + public ClaimsPrincipal? User { get; set; } + /// Gets or sets the parameters associated with this request. public TParams? Params { get; set; } + + /// + /// Gets or sets the primitive that matched the request. + /// + public IMcpServerPrimitive? MatchedPrimitive { get; set; } + + /// + /// Gets the JSON-RPC request associated with this context. + /// + /// + /// This property provides access to the complete JSON-RPC request that initiated this handler invocation, + /// including the method name, parameters, request ID, and associated transport and user information. + /// + public JsonRpcRequest JsonRpcRequest { get; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/AugmentedServiceProvider.cs b/src/ModelContextProtocol.Core/Server/RequestServiceProvider.cs similarity index 69% rename from src/ModelContextProtocol.Core/Server/AugmentedServiceProvider.cs rename to src/ModelContextProtocol.Core/Server/RequestServiceProvider.cs index 3372072fe..38af614c2 100644 --- a/src/ModelContextProtocol.Core/Server/AugmentedServiceProvider.cs +++ b/src/ModelContextProtocol.Core/Server/RequestServiceProvider.cs @@ -1,16 +1,17 @@ using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Protocol; +using System.Security.Claims; namespace ModelContextProtocol.Server; /// Augments a service provider with additional request-related services. -internal sealed class RequestServiceProvider( - RequestContext request, IServiceProvider? innerServices) : - IServiceProvider, IKeyedServiceProvider, - IServiceProviderIsService, IServiceProviderIsKeyedService, +internal sealed class RequestServiceProvider(RequestContext request) : + IServiceProvider, IKeyedServiceProvider, IServiceProviderIsService, IServiceProviderIsKeyedService, IDisposable, IAsyncDisposable where TRequestParams : RequestParams { + private readonly IServiceProvider? _innerServices = request.Services; + /// Gets the request associated with this instance. public RequestContext Request => request; @@ -18,7 +19,8 @@ internal sealed class RequestServiceProvider( public static bool IsAugmentedWith(Type serviceType) => serviceType == typeof(RequestContext) || serviceType == typeof(IMcpServer) || - serviceType == typeof(IProgress); + serviceType == typeof(IProgress) || + serviceType == typeof(ClaimsPrincipal); /// public object? GetService(Type serviceType) => @@ -26,22 +28,23 @@ public static bool IsAugmentedWith(Type serviceType) => serviceType == typeof(IMcpServer) ? request.Server : serviceType == typeof(IProgress) ? (request.Params?.ProgressToken is { } progressToken ? new TokenProgress(request.Server, progressToken) : NullProgress.Instance) : - innerServices?.GetService(serviceType); + serviceType == typeof(ClaimsPrincipal) ? request.User : + _innerServices?.GetService(serviceType); /// public bool IsService(Type serviceType) => IsAugmentedWith(serviceType) || - (innerServices as IServiceProviderIsService)?.IsService(serviceType) is true; + (_innerServices as IServiceProviderIsService)?.IsService(serviceType) is true; /// public bool IsKeyedService(Type serviceType, object? serviceKey) => (serviceKey is null && IsService(serviceType)) || - (innerServices as IServiceProviderIsKeyedService)?.IsKeyedService(serviceType, serviceKey) is true; + (_innerServices as IServiceProviderIsKeyedService)?.IsKeyedService(serviceType, serviceKey) is true; /// public object? GetKeyedService(Type serviceType, object? serviceKey) => serviceKey is null ? GetService(serviceType) : - (innerServices as IKeyedServiceProvider)?.GetKeyedService(serviceType, serviceKey); + (_innerServices as IKeyedServiceProvider)?.GetKeyedService(serviceType, serviceKey); /// public object GetRequiredKeyedService(Type serviceType, object? serviceKey) => @@ -50,9 +53,9 @@ public object GetRequiredKeyedService(Type serviceType, object? serviceKey) => /// public void Dispose() => - (innerServices as IDisposable)?.Dispose(); + (_innerServices as IDisposable)?.Dispose(); /// public ValueTask DisposeAsync() => - innerServices is IAsyncDisposable asyncDisposable ? asyncDisposable.DisposeAsync() : default; + _innerServices is IAsyncDisposable asyncDisposable ? asyncDisposable.DisposeAsync() : default; } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs b/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs index 438421f28..8941e4ed6 100644 --- a/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs +++ b/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs @@ -1,4 +1,5 @@ using ModelContextProtocol.Protocol; +using System.Security.Claims; using System.Threading.Channels; namespace ModelContextProtocol.Server; @@ -9,7 +10,7 @@ namespace ModelContextProtocol.Server; /// /// /// This transport provides one-way communication from server to client using the SSE protocol over HTTP, -/// while receiving client messages through a separate mechanism. It writes messages as +/// while receiving client messages through a separate mechanism. It writes messages as /// SSE events to a response stream, typically associated with an HTTP response. /// /// @@ -41,7 +42,7 @@ public sealed class SseResponseStreamTransport(Stream sseResponseStream, string? /// /// The to monitor for cancellation requests. The default is . /// A task representing the send loop that writes JSON-RPC messages to the SSE response stream. - public async Task RunAsync(CancellationToken cancellationToken) + public async Task RunAsync(CancellationToken cancellationToken = default) { _isConnected = true; await _sseWriter.WriteAllAsync(sseResponseStream, cancellationToken).ConfigureAwait(false); @@ -64,6 +65,7 @@ public async ValueTask DisposeAsync() /// public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { + Throw.IfNull(message); await _sseWriter.SendMessageAsync(message, cancellationToken).ConfigureAwait(false); } @@ -76,8 +78,8 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can /// Thrown when there is an attempt to process a message before calling . /// /// - /// This method is the entry point for processing client-to-server communication in the SSE transport model. - /// While the SSE protocol itself is unidirectional (server to client), this method allows bidirectional + /// This method is the entry point for processing client-to-server communication in the SSE transport model. + /// While the SSE protocol itself is unidirectional (server to client), this method allows bidirectional /// communication by handling HTTP POST requests sent to the message endpoint. /// /// @@ -85,11 +87,11 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can /// process the message and make it available to the MCP server via the channel. /// /// - /// This method validates that the transport is connected before processing the message, ensuring proper - /// sequencing of operations in the transport lifecycle. + /// If an authenticated sent the message, that can be included in the . + /// No other part of the context should be set. /// /// - public async Task OnMessageReceivedAsync(JsonRpcMessage message, CancellationToken cancellationToken) + public async Task OnMessageReceivedAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { Throw.IfNull(message); diff --git a/src/ModelContextProtocol.Core/Server/SseWriter.cs b/src/ModelContextProtocol.Core/Server/SseWriter.cs index 18571e2c9..4fb7feafe 100644 --- a/src/ModelContextProtocol.Core/Server/SseWriter.cs +++ b/src/ModelContextProtocol.Core/Server/SseWriter.cs @@ -26,6 +26,8 @@ internal sealed class SseWriter(string? messageEndpoint = null, BoundedChannelOp public Task WriteAllAsync(Stream sseResponseStream, CancellationToken cancellationToken) { + Throw.IfNull(sseResponseStream); + // When messageEndpoint is set, the very first SSE event isn't really an IJsonRpcMessage, but there's no API to write a single // item of a different type, so we fib and special-case the "endpoint" event type in the formatter. if (messageEndpoint is not null && !_messages.Writer.TryWrite(new SseItem(null, "endpoint"))) diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs index 9d225caa8..1992939de 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs @@ -1,7 +1,9 @@ using ModelContextProtocol.Protocol; +using System.Diagnostics; using System.IO.Pipelines; using System.Net.ServerSentEvents; using System.Runtime.CompilerServices; +using System.Security.Claims; using System.Text.Json; using System.Threading.Channels; @@ -9,14 +11,14 @@ namespace ModelContextProtocol.Server; /// /// Handles processing the request/response body pairs for the Streamable HTTP transport. -/// This is typically used via . +/// This is typically used via . /// -internal sealed class StreamableHttpPostTransport(StreamableHttpServerTransport parentTransport, IDuplexPipe httpBodies) : ITransport +internal sealed class StreamableHttpPostTransport(StreamableHttpServerTransport parentTransport, Stream responseStream) : ITransport { private readonly SseWriter _sseWriter = new(); private RequestId _pendingRequest; - public ChannelReader MessageReader => throw new NotSupportedException("JsonRpcMessage.RelatedTransport should only be used for sending messages."); + public ChannelReader MessageReader => throw new NotSupportedException("JsonRpcMessage.Context.RelatedTransport should only be used for sending messages."); string? ITransport.SessionId => parentTransport.SessionId; @@ -25,11 +27,31 @@ internal sealed class StreamableHttpPostTransport(StreamableHttpServerTransport /// False, if nothing was written because the request body did not contain any messages to respond to. /// The HTTP application should typically respond with an empty "202 Accepted" response in this scenario. /// - public async ValueTask RunAsync(CancellationToken cancellationToken) + public async ValueTask HandlePostAsync(JsonRpcMessage message, CancellationToken cancellationToken) { - var message = await JsonSerializer.DeserializeAsync(httpBodies.Input.AsStream(), - McpJsonUtilities.JsonContext.Default.JsonRpcMessage, cancellationToken).ConfigureAwait(false); - await OnMessageReceivedAsync(message, cancellationToken).ConfigureAwait(false); + Debug.Assert(_pendingRequest.Id is null); + + if (message is JsonRpcRequest request) + { + _pendingRequest = request.Id; + + // Invoke the initialize request callback if applicable. + if (parentTransport.OnInitRequestReceived is { } onInitRequest && request.Method == RequestMethods.Initialize) + { + var initializeRequest = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.JsonContext.Default.InitializeRequestParams); + await onInitRequest(initializeRequest).ConfigureAwait(false); + } + } + + message.Context ??= new JsonRpcMessageContext(); + message.Context.RelatedTransport = this; + + if (parentTransport.FlowExecutionContextFromRequests) + { + message.Context.ExecutionContext = ExecutionContext.Capture(); + } + + await parentTransport.MessageWriter.WriteAsync(message, cancellationToken).ConfigureAwait(false); if (_pendingRequest.Id is null) { @@ -37,12 +59,14 @@ public async ValueTask RunAsync(CancellationToken cancellationToken) } _sseWriter.MessageFilter = StopOnFinalResponseFilter; - await _sseWriter.WriteAllAsync(httpBodies.Output.AsStream(), cancellationToken).ConfigureAwait(false); + await _sseWriter.WriteAllAsync(responseStream, cancellationToken).ConfigureAwait(false); return true; } public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { + Throw.IfNull(message); + if (parentTransport.Stateless && message is JsonRpcRequest) { throw new InvalidOperationException("Server to client requests are not supported in stateless mode."); @@ -69,33 +93,4 @@ public async ValueTask DisposeAsync() } } } - - private async ValueTask OnMessageReceivedAsync(JsonRpcMessage? message, CancellationToken cancellationToken) - { - if (message is null) - { - throw new InvalidOperationException("Received invalid null message."); - } - - if (message is JsonRpcRequest request) - { - _pendingRequest = request.Id; - - // Invoke the initialize request callback if applicable. - if (parentTransport.OnInitRequestReceived is { } onInitRequest && request.Method == RequestMethods.Initialize) - { - var initializeRequest = JsonSerializer.Deserialize(request.Params, McpJsonUtilities.JsonContext.Default.InitializeRequestParams); - await onInitRequest(initializeRequest).ConfigureAwait(false); - } - } - - message.RelatedTransport = this; - - if (parentTransport.FlowExecutionContextFromRequests) - { - message.ExecutionContext = ExecutionContext.Capture(); - } - - await parentTransport.MessageWriter.WriteAsync(message, cancellationToken).ConfigureAwait(false); - } } diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs index b63c8a651..57283e9a2 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs @@ -1,5 +1,6 @@ using ModelContextProtocol.Protocol; using System.IO.Pipelines; +using System.Security.Claims; using System.Threading.Channels; namespace ModelContextProtocol.Server; @@ -49,8 +50,8 @@ public sealed class StreamableHttpServerTransport : ITransport public bool Stateless { get; init; } /// - /// Gets a value indicating whether the execution context should flow from the calls to - /// to the corresponding emitted by the . + /// Gets a value indicating whether the execution context should flow from the calls to + /// to the corresponding property contained in the instances returned by the . /// /// /// Defaults to . @@ -75,8 +76,10 @@ public sealed class StreamableHttpServerTransport : ITransport /// The response stream to write MCP JSON-RPC messages as SSE events to. /// The to monitor for cancellation requests. The default is . /// A task representing the send loop that writes JSON-RPC messages to the SSE response stream. - public async Task HandleGetRequest(Stream sseResponseStream, CancellationToken cancellationToken) + public async Task HandleGetRequest(Stream sseResponseStream, CancellationToken cancellationToken = default) { + Throw.IfNull(sseResponseStream); + if (Stateless) { throw new InvalidOperationException("GET requests are not supported in stateless mode."); @@ -96,23 +99,33 @@ public async Task HandleGetRequest(Stream sseResponseStream, CancellationToken c /// and other correlated messages are sent back to the client directly in response /// to the that initiated the message. /// - /// The duplex pipe facilitates the reading and writing of HTTP request and response data. - /// This token allows for the operation to be canceled if needed. + /// The JSON-RPC message received from the client via the POST request body. + /// This token allows for the operation to be canceled if needed. The default is . + /// The POST response body to write MCP JSON-RPC messages to. /// /// True, if data was written to the response body. /// False, if nothing was written because the request body did not contain any messages to respond to. /// The HTTP application should typically respond with an empty "202 Accepted" response in this scenario. /// - public async Task HandlePostRequest(IDuplexPipe httpBodies, CancellationToken cancellationToken) + /// + /// If 's an authenticated sent the message, that can be included in the . + /// No other part of the context should be set. + /// + public async Task HandlePostRequest(JsonRpcMessage message, Stream responseStream, CancellationToken cancellationToken = default) { + Throw.IfNull(message); + Throw.IfNull(responseStream); + using var postCts = CancellationTokenSource.CreateLinkedTokenSource(_disposeCts.Token, cancellationToken); - await using var postTransport = new StreamableHttpPostTransport(this, httpBodies); - return await postTransport.RunAsync(postCts.Token).ConfigureAwait(false); + await using var postTransport = new StreamableHttpPostTransport(this, responseStream); + return await postTransport.HandlePostAsync(message, postCts.Token).ConfigureAwait(false); } /// public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default) { + Throw.IfNull(message); + if (Stateless) { throw new InvalidOperationException("Unsolicited server to client messages are not supported in stateless mode."); @@ -126,6 +139,7 @@ public async ValueTask DisposeAsync() { try { + _incomingChannel.Writer.TryComplete(); await _disposeCts.CancelAsync(); } finally diff --git a/src/ModelContextProtocol/McpServerBuilderExtensions.cs b/src/ModelContextProtocol/McpServerBuilderExtensions.cs index d925b24f6..db1b029de 100644 --- a/src/ModelContextProtocol/McpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol/McpServerBuilderExtensions.cs @@ -707,6 +707,278 @@ public static IMcpServerBuilder WithSetLoggingLevelHandler(this IMcpServerBuilde } #endregion + #region Filters + /// + /// Adds a filter to the list resource templates handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that return a list of available resource templates when requested by a client. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. It supports pagination through the cursor mechanism, + /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more resource templates. + /// + /// + public static IMcpServerBuilder AddListResourceTemplatesFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.ListResourceTemplatesFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the list tools handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that return a list of available tools when requested by a client. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. It supports pagination through the cursor mechanism, + /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more tools. + /// + /// + /// This filter works alongside any tools defined in the collection. + /// Tools from both sources will be combined when returning results to clients. + /// + /// + public static IMcpServerBuilder AddListToolsFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.ListToolsFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the call tool handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that are invoked when a client makes a call to a tool that isn't found in the collection. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to execute the requested tool and return appropriate results. + /// + /// + public static IMcpServerBuilder AddCallToolFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.CallToolFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the list prompts handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that return a list of available prompts when requested by a client. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. It supports pagination through the cursor mechanism, + /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more prompts. + /// + /// + /// This filter works alongside any prompts defined in the collection. + /// Prompts from both sources will be combined when returning results to clients. + /// + /// + public static IMcpServerBuilder AddListPromptsFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.ListPromptsFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the get prompt handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that are invoked when a client requests details for a specific prompt that isn't found in the collection. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to fetch or generate the requested prompt and return appropriate results. + /// + /// + public static IMcpServerBuilder AddGetPromptFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.GetPromptFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the list resources handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that return a list of available resources when requested by a client. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. It supports pagination through the cursor mechanism, + /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more resources. + /// + /// + public static IMcpServerBuilder AddListResourcesFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.ListResourcesFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the read resource handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that are invoked when a client requests the content of a specific resource identified by its URI. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to locate and retrieve the requested resource. + /// + /// + public static IMcpServerBuilder AddReadResourceFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.ReadResourceFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the complete handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that provide auto-completion suggestions for prompt arguments or resource references in the Model Context Protocol. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. The handler processes auto-completion requests, returning a list of suggestions based on the + /// reference type and current argument value. + /// + /// + public static IMcpServerBuilder AddCompleteFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.CompleteFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the subscribe to resources handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that are invoked when a client wants to receive notifications about changes to specific resources or resource patterns. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to register the client's interest in the specified resources + /// and set up the necessary infrastructure to send notifications when those resources change. + /// + /// + /// After a successful subscription, the server should send resource change notifications to the client + /// whenever a relevant resource is created, updated, or deleted. + /// + /// + public static IMcpServerBuilder AddSubscribeToResourcesFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.SubscribeToResourcesFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the unsubscribe from resources handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that are invoked when a client wants to stop receiving notifications about previously subscribed resources. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. The handler should implement logic to remove the client's subscriptions to the specified resources + /// and clean up any associated resources. + /// + /// + /// After a successful unsubscription, the server should no longer send resource change notifications + /// to the client for the specified resources. + /// + /// + public static IMcpServerBuilder AddUnsubscribeFromResourcesFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.UnsubscribeFromResourcesFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to the set logging level handler pipeline. + /// + /// The builder instance. + /// The filter function that wraps the handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter wraps handlers that process requests from clients. When set, it enables + /// clients to control which log messages they receive by specifying a minimum severity threshold. + /// The filter can modify, log, or perform additional operations on requests and responses for + /// requests. + /// + /// + /// After handling a level change request, the server typically begins sending log messages + /// at or above the specified level to the client as notifications/message notifications. + /// + /// + public static IMcpServerBuilder AddSetLoggingLevelFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.SetLoggingLevelFilters.Add(filter)); + return builder; + } + #endregion + #region Transports /// /// Adds a server transport that uses standard input (stdin) and standard output (stdout) for communication. diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs new file mode 100644 index 000000000..8c173d890 --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs @@ -0,0 +1,374 @@ +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.AspNetCore.Tests.Utils; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using System.ComponentModel; +using System.Security.Claims; + +namespace ModelContextProtocol.AspNetCore.Tests; + +/// +/// Tests for MCP authorization functionality with [Authorize], [AllowAnonymous] and role-based authorization. +/// +public class AuthorizeAttributeTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper) +{ + private async Task ConnectAsync() + { + await using var transport = new SseClientTransport(new SseClientTransportOptions + { + Endpoint = new("http://localhost:5000"), + }, HttpClient, LoggerFactory); + + return await McpClientFactory.CreateAsync(transport, cancellationToken: TestContext.Current.CancellationToken, loggerFactory: LoggerFactory); + } + + [Fact] + public async Task Authorize_Tool_RequiresAuthentication() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools()); + + var client = await ConnectAsync(); + var result = await client.CallToolAsync( + "authorized_tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + // Should return error because tool requires authorization but user is anonymous + Assert.True(result.IsError ?? false); + var content = Assert.Single(result.Content.OfType()); + Assert.Equal("Access forbidden: This tool requires authorization.", content.Text); + } + + [Fact] + public async Task ClassLevelAuthorize_Tool_RequiresAuthentication() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools()); + + var client = await ConnectAsync(); + var result = await client.CallToolAsync( + "anonymous_tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + Assert.False(result.IsError ?? false); + var content = Assert.Single(result.Content.OfType()); + Assert.Equal("Anonymous: test", content.Text); + } + + [Fact] + public async Task AllowAnonymous_Tool_AllowsAnonymousAccess() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools()); + + var client = await ConnectAsync(); + var result = await client.CallToolAsync( + "anonymous_tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + Assert.False(result.IsError ?? false); + var content = Assert.Single(result.Content.OfType()); + Assert.Equal("Anonymous: test", content.Text); + } + + [Fact] + public async Task Authorize_Tool_AllowsAuthenticatedUser() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools(), "TestUser"); + + var client = await ConnectAsync(); + var result = await client.CallToolAsync( + "authorized_tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + Assert.False(result.IsError ?? false); + var content = Assert.Single(result.Content.OfType()); + Assert.Equal("Authorized: test", content.Text); + } + + [Fact] + public async Task AuthorizeWithRoles_Tool_RequiresAdminRole() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools(), "TestUser", "User"); + + var client = await ConnectAsync(); + var result = await client.CallToolAsync( + "admin_tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + // Should return error because tool requires Admin role but user only has User role + Assert.True(result.IsError ?? false); + var content = Assert.Single(result.Content.OfType()); + Assert.Equal("Access forbidden: This tool requires authorization.", content.Text); + } + + [Fact] + public async Task AuthorizeWithRoles_Tool_AllowsAdminUser() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools(), "AdminUser", "Admin"); + + var client = await ConnectAsync(); + var result = await client.CallToolAsync( + "admin_tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + Assert.False(result.IsError ?? false); + var content = Assert.Single(result.Content.OfType()); + Assert.Equal("Admin: test", content.Text); + } + + [Fact] + public async Task ListTools_Anonymous_OnlyReturnsAnonymousTools() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools()); + + var client = await ConnectAsync(); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + Assert.Single(tools); + Assert.Equal("anonymous_tool", tools[0].Name); + } + + [Fact] + public async Task ListTools_AuthenticatedUser_ReturnsAuthorizedTools() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools(), "TestUser"); + + var client = await ConnectAsync(); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + // Authenticated user should see anonymous and basic authorized tools, but not admin-only tools + Assert.Equal(2, tools.Count); + var toolNames = tools.Select(t => t.Name).OrderBy(n => n).ToList(); + Assert.Equal(["anonymous_tool", "authorized_tool"], toolNames); + } + + [Fact] + public async Task ListTools_AdminUser_ReturnsAllTools() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools(), "AdminUser", "Admin"); + + var client = await ConnectAsync(); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + // Admin user should see all tools + Assert.Equal(3, tools.Count); + var toolNames = tools.Select(t => t.Name).OrderBy(n => n).ToList(); + Assert.Equal(["admin_tool", "anonymous_tool", "authorized_tool"], toolNames); + } + + [Fact] + public async Task ListTools_UserRole_DoesNotReturnAdminTools() + { + await using var app = await StartServerWithAuth(builder => builder.WithTools(), "TestUser", "User"); + + var client = await ConnectAsync(); + var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + // User with User role should not see admin-only tools + Assert.Equal(2, tools.Count); + var toolNames = tools.Select(t => t.Name).OrderBy(n => n).ToList(); + Assert.Equal(["anonymous_tool", "authorized_tool"], toolNames); + } + + [Fact] + public async Task Authorize_Prompt_RequiresAuthentication() + { + await using var app = await StartServerWithAuth(builder => builder.WithPrompts()); + + var client = await ConnectAsync(); + + var exception = await Assert.ThrowsAsync(async () => + await client.GetPromptAsync( + "authorized_prompt", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Request failed (remote): Access forbidden: This prompt requires authorization.", exception.Message); + Assert.Equal(McpErrorCode.InvalidRequest, exception.ErrorCode); + } + + [Fact] + public async Task Authorize_Prompt_AllowsAuthenticatedUser() + { + await using var app = await StartServerWithAuth(builder => builder.WithPrompts(), "TestUser"); + + var client = await ConnectAsync(); + var result = await client.GetPromptAsync( + "authorized_prompt", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken); + + var message = Assert.Single(result.Messages); + Assert.Equal(Role.User, message.Role); + var content = Assert.IsType(message.Content); + Assert.Equal("Authorized prompt: test", content.Text); + } + + [Fact] + public async Task ListPrompts_Anonymous_OnlyReturnsAnonymousPrompts() + { + await using var app = await StartServerWithAuth(builder => builder.WithPrompts()); + + var client = await ConnectAsync(); + var prompts = await client.ListPromptsAsync(cancellationToken: TestContext.Current.CancellationToken); + + // Anonymous user should only see prompts marked with [AllowAnonymous] + Assert.Single(prompts); + Assert.Equal("anonymous_prompt", prompts[0].Name); + } + + [Fact] + public async Task Authorize_Resource_RequiresAuthentication() + { + await using var app = await StartServerWithAuth(builder => builder.WithResources()); + + var client = await ConnectAsync(); + + var exception = await Assert.ThrowsAsync(async () => + await client.ReadResourceAsync( + "resource://authorized", + cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Request failed (remote): Access forbidden: This resource requires authorization.", exception.Message); + Assert.Equal(McpErrorCode.InvalidRequest, exception.ErrorCode); + } + + [Fact] + public async Task Authorize_Resource_AllowsAuthenticatedUser() + { + await using var app = await StartServerWithAuth(builder => builder.WithResources(), "TestUser"); + + var client = await ConnectAsync(); + var result = await client.ReadResourceAsync( + "resource://authorized", + cancellationToken: TestContext.Current.CancellationToken); + + var content = Assert.Single(result.Contents.OfType()); + Assert.Equal("Authorized resource content", content.Text); + } + + [Fact] + public async Task ListResources_Anonymous_OnlyReturnsAnonymousResources() + { + await using var app = await StartServerWithAuth(builder => builder.WithResources()); + + var client = await ConnectAsync(); + var resources = await client.ListResourcesAsync(cancellationToken: TestContext.Current.CancellationToken); + + Assert.Single(resources); + Assert.Equal("resource://anonymous", resources[0].Uri); + } + + private async Task StartServerWithAuth(Action configure, string? userName = null, params string[] roles) + { + var builder = Builder.Services.AddMcpServer().WithHttpTransport(); + configure(builder); + Builder.Services.AddAuthorization(); + + var app = Builder.Build(); + + if (userName is not null) + { + app.Use(next => + { + return async context => + { + context.User = CreateUser(userName, roles); + await next(context); + }; + }); + } + + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + return app; + } + + private ClaimsPrincipal CreateUser(string name, params string[] roles) + => new ClaimsPrincipal(new ClaimsIdentity( + [new Claim("name", name), new Claim(ClaimTypes.NameIdentifier, name), ..roles.Select(role => new Claim("role", role))], + "TestAuthType", "name", "role")); + + [McpServerToolType] + private class AuthorizationTestTools + { + [McpServerTool, Description("A tool that allows anonymous access.")] + public static string AnonymousTool(string message) + { + return $"Anonymous: {message}"; + } + + [McpServerTool, Description("A tool that requires authorization.")] + [Authorize] + public static string AuthorizedTool(string message) + { + return $"Authorized: {message}"; + } + + [McpServerTool, Description("A tool that requires Admin role.")] + [Authorize(Roles = "Admin")] + public static string AdminTool(string message) + { + return $"Admin: {message}"; + } + } + + [McpServerToolType] + [Authorize] + private class AllowAnonymousTestTools + { + [McpServerTool, Description("A tool that allows anonymous access.")] + [AllowAnonymous] + public static string AnonymousTool(string message) + { + return $"Anonymous: {message}"; + } + + [McpServerTool, Description("A tool that requires authorization.")] + public static string AuthorizedTool(string message) + { + return $"Authorized: {message}"; + } + } + + [McpServerPromptType] + private class AuthorizationTestPrompts + { + [McpServerPrompt, Description("A prompt that allows anonymous access.")] + public static string AnonymousPrompt(string message) + { + return $"Anonymous prompt: {message}"; + } + + [McpServerPrompt, Description("A prompt that requires authorization.")] + [Authorize] + public static string AuthorizedPrompt(string message) + { + return $"Authorized prompt: {message}"; + } + } + + [McpServerResourceType] + private class AuthorizationTestResources + { + [McpServerResource(UriTemplate = "resource://anonymous"), Description("A resource that allows anonymous access.")] + public static string AnonymousResource() + { + return "Anonymous resource content"; + } + + [McpServerResource(UriTemplate = "resource://authorized"), Description("A resource that requires authorization.")] + [Authorize] + public static string AuthorizedResource() + { + return "Authorized resource content"; + } + } +} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs index f31621307..728304070 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpSseTests.cs @@ -13,7 +13,7 @@ public class MapMcpSseTests(ITestOutputHelper outputHelper) : MapMcpTests(output [InlineData("/mcp/secondary")] public async Task Allows_Customizing_Route(string pattern) { - Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless); + Builder.Services.AddMcpServer().WithHttpTransport(); await using var app = Builder.Build(); app.MapMcp(pattern); diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs index 4d0d73562..0d867c8f0 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/MapMcpTests.cs @@ -111,6 +111,35 @@ public async Task Messages_FromNewUser_AreRejected() Assert.Equal(HttpStatusCode.Forbidden, httpRequestException.StatusCode); } + [Fact] + public async Task ClaimsPrincipal_CanBeInjectedIntoToolMethod() + { + Builder.Services.AddMcpServer().WithHttpTransport(ConfigureStateless).WithTools(); + Builder.Services.AddHttpContextAccessor(); + + await using var app = Builder.Build(); + + app.Use(next => async context => + { + context.User = CreateUser("TestUser"); + await next(context); + }); + + app.MapMcp(); + + await app.StartAsync(TestContext.Current.CancellationToken); + + await using var client = await ConnectAsync(); + + var response = await client.CallToolAsync( + "echo_claims_principal", + new Dictionary() { ["message"] = "Hello world!" }, + cancellationToken: TestContext.Current.CancellationToken); + + var content = Assert.Single(response.Content.OfType()); + Assert.Equal("TestUser: Hello world!", content.Text); + } + [Fact] public async Task Sampling_DoesNotCloseStream_Prematurely() { @@ -200,6 +229,17 @@ public string EchoWithUserName(string message) } } + [McpServerToolType] + protected class ClaimsPrincipalTools + { + [McpServerTool, Description("Echoes the input back to the client with the user name from ClaimsPrincipal.")] + public string EchoClaimsPrincipal(ClaimsPrincipal? user, string message) + { + var userName = user?.Identity?.Name ?? "anonymous"; + return $"{userName}: {message}"; + } + } + [McpServerToolType] private class SamplingRegressionTools { diff --git a/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs b/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs index ec1c85107..d9b699b98 100644 --- a/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs +++ b/tests/ModelContextProtocol.Tests/ClientServerTestBase.cs @@ -20,7 +20,8 @@ public ClientServerTestBase(ITestOutputHelper testOutputHelper) : base(testOutputHelper) { ServiceCollection sc = new(); - sc.AddSingleton(LoggerFactory); + sc.AddLogging(); + sc.AddSingleton(XunitLoggerProvider); _builder = sc .AddMcpServer() .WithStreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream()); diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsFilterTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsFilterTests.cs new file mode 100644 index 000000000..6a7d0044d --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsFilterTests.cs @@ -0,0 +1,314 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; + +namespace ModelContextProtocol.Tests.Configuration; + +public class McpServerBuilderExtensionsFilterTests : ClientServerTestBase +{ + public McpServerBuilderExtensionsFilterTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper) + { + } + + private MockLoggerProvider _mockLoggerProvider = new(); + + private static ILogger GetLogger(IServiceProvider? services, string categoryName) + { + var loggerFactory = services?.GetRequiredService() ?? throw new InvalidOperationException("LoggerFactory not available"); + return loggerFactory.CreateLogger(categoryName); + } + + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) + { + mcpServerBuilder + .AddListResourceTemplatesFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "ListResourceTemplatesFilter"); + logger.LogInformation("ListResourceTemplatesFilter executed"); + return await next(request, cancellationToken); + }) + .AddListToolsFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "ListToolsFilter"); + logger.LogInformation("ListToolsFilter executed"); + return await next(request, cancellationToken); + }) + .AddListToolsFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "ListToolsOrder1"); + logger.LogInformation("ListToolsOrder1 before"); + var result = await next(request, cancellationToken); + logger.LogInformation("ListToolsOrder1 after"); + return result; + }) + .AddListToolsFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "ListToolsOrder2"); + logger.LogInformation("ListToolsOrder2 before"); + var result = await next(request, cancellationToken); + logger.LogInformation("ListToolsOrder2 after"); + return result; + }) + .AddCallToolFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "CallToolFilter"); + var primitiveId = request.MatchedPrimitive?.Id ?? "unknown"; + logger.LogInformation($"CallToolFilter executed for tool: {primitiveId}"); + return await next(request, cancellationToken); + }) + .AddListPromptsFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "ListPromptsFilter"); + logger.LogInformation("ListPromptsFilter executed"); + return await next(request, cancellationToken); + }) + .AddGetPromptFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "GetPromptFilter"); + var primitiveId = request.MatchedPrimitive?.Id ?? "unknown"; + logger.LogInformation($"GetPromptFilter executed for prompt: {primitiveId}"); + return await next(request, cancellationToken); + }) + .AddListResourcesFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "ListResourcesFilter"); + logger.LogInformation("ListResourcesFilter executed"); + return await next(request, cancellationToken); + }) + .AddReadResourceFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "ReadResourceFilter"); + var primitiveId = request.MatchedPrimitive?.Id ?? "unknown"; + logger.LogInformation($"ReadResourceFilter executed for resource: {primitiveId}"); + return await next(request, cancellationToken); + }) + .AddCompleteFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "CompleteFilter"); + logger.LogInformation("CompleteFilter executed"); + return await next(request, cancellationToken); + }) + .AddSubscribeToResourcesFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "SubscribeToResourcesFilter"); + logger.LogInformation("SubscribeToResourcesFilter executed"); + return await next(request, cancellationToken); + }) + .AddUnsubscribeFromResourcesFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "UnsubscribeFromResourcesFilter"); + logger.LogInformation("UnsubscribeFromResourcesFilter executed"); + return await next(request, cancellationToken); + }) + .AddSetLoggingLevelFilter((next) => async (request, cancellationToken) => + { + var logger = GetLogger(request.Services, "SetLoggingLevelFilter"); + logger.LogInformation("SetLoggingLevelFilter executed"); + return await next(request, cancellationToken); + }) + .WithTools() + .WithPrompts() + .WithResources() + .WithSetLoggingLevelHandler(async (request, cancellationToken) => new EmptyResult()) + .WithListResourceTemplatesHandler(async (request, cancellationToken) => new ListResourceTemplatesResult + { + ResourceTemplates = [new() { Name = "test", UriTemplate = "test://resource/{id}" }] + }) + .WithCompleteHandler(async (request, cancellationToken) => new CompleteResult + { + Completion = new() { Values = ["test"] } + }); + + services.AddSingleton(_mockLoggerProvider); + } + + [Fact] + public async Task AddListResourceTemplatesFilter_Logs_When_ListResourceTemplates_Called() + { + await using IMcpClient client = await CreateMcpClientForServer(); + + await client.ListResourceTemplatesAsync(cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "ListResourceTemplatesFilter executed"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("ListResourceTemplatesFilter", logMessage.Category); + } + + [Fact] + public async Task AddListToolsFilter_Logs_When_ListTools_Called() + { + await using IMcpClient client = await CreateMcpClientForServer(); + + await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "ListToolsFilter executed"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("ListToolsFilter", logMessage.Category); + } + + [Fact] + public async Task AddCallToolFilter_Logs_When_CallTool_Called() + { + await using IMcpClient client = await CreateMcpClientForServer(); + + await client.CallToolAsync("test_tool_method", cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "CallToolFilter executed for tool: test_tool_method"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("CallToolFilter", logMessage.Category); + } + + [Fact] + public async Task AddListPromptsFilter_Logs_When_ListPrompts_Called() + { + await using IMcpClient client = await CreateMcpClientForServer(); + + await client.ListPromptsAsync(cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "ListPromptsFilter executed"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("ListPromptsFilter", logMessage.Category); + } + + [Fact] + public async Task AddGetPromptFilter_Logs_When_GetPrompt_Called() + { + await using IMcpClient client = await CreateMcpClientForServer(); + + await client.GetPromptAsync("test_prompt_method", cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "GetPromptFilter executed for prompt: test_prompt_method"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("GetPromptFilter", logMessage.Category); + } + + [Fact] + public async Task AddListResourcesFilter_Logs_When_ListResources_Called() + { + await using IMcpClient client = await CreateMcpClientForServer(); + + await client.ListResourcesAsync(cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "ListResourcesFilter executed"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("ListResourcesFilter", logMessage.Category); + } + + [Fact] + public async Task AddReadResourceFilter_Logs_When_ReadResource_Called() + { + await using IMcpClient client = await CreateMcpClientForServer(); + + await client.ReadResourceAsync("test://resource/123", cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "ReadResourceFilter executed for resource: test://resource/{id}"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("ReadResourceFilter", logMessage.Category); + } + + [Fact] + public async Task AddCompleteFilter_Logs_When_Complete_Called() + { + await using IMcpClient client = await CreateMcpClientForServer(); + + var reference = new PromptReference { Name = "test_prompt_method" }; + await client.CompleteAsync(reference, "argument", "value", cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "CompleteFilter executed"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("CompleteFilter", logMessage.Category); + } + + [Fact] + public async Task AddSubscribeToResourcesFilter_Logs_When_SubscribeToResources_Called() + { + await using IMcpClient client = await CreateMcpClientForServer(); + + await client.SubscribeToResourceAsync("test://resource/123", cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "SubscribeToResourcesFilter executed"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("SubscribeToResourcesFilter", logMessage.Category); + } + + [Fact] + public async Task AddUnsubscribeFromResourcesFilter_Logs_When_UnsubscribeFromResources_Called() + { + await using IMcpClient client = await CreateMcpClientForServer(); + + await client.UnsubscribeFromResourceAsync("test://resource/123", cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "UnsubscribeFromResourcesFilter executed"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("UnsubscribeFromResourcesFilter", logMessage.Category); + } + + [Fact] + public async Task AddSetLoggingLevelFilter_Logs_When_SetLoggingLevel_Called() + { + await using IMcpClient client = await CreateMcpClientForServer(); + + await client.SetLoggingLevel(LoggingLevel.Info, cancellationToken: TestContext.Current.CancellationToken); + + var logMessage = Assert.Single(_mockLoggerProvider.LogMessages, m => m.Message == "SetLoggingLevelFilter executed"); + Assert.Equal(LogLevel.Information, logMessage.LogLevel); + Assert.Equal("SetLoggingLevelFilter", logMessage.Category); + } + + [Fact] + public async Task AddListToolsFilter_Multiple_Filters_Log_In_Expected_Order() + { + await using IMcpClient client = await CreateMcpClientForServer(); + + await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + var logMessages = _mockLoggerProvider.LogMessages + .Where(m => m.Category.StartsWith("ListToolsOrder")) + .Select(m => m.Message); + + Assert.Collection(logMessages, + m => Assert.Equal("ListToolsOrder1 before", m), + m => Assert.Equal("ListToolsOrder2 before", m), + m => Assert.Equal("ListToolsOrder2 after", m), + m => Assert.Equal("ListToolsOrder1 after", m) + ); + } + + [McpServerToolType] + public sealed class TestTool + { + [McpServerTool] + public static string TestToolMethod() + { + return "test result"; + } + } + + [McpServerPromptType] + public sealed class TestPrompt + { + [McpServerPrompt] + public static Task TestPromptMethod() + { + return Task.FromResult(new GetPromptResult + { + Description = "Test prompt", + Messages = [new() { Role = Role.User, Content = new TextContentBlock { Text = "Test" } }] + }); + } + } + + [McpServerResourceType] + public sealed class TestResource + { + [McpServerResource(UriTemplate = "test://resource/{id}")] + public static string TestResourceMethod(string id) + { + return $"Test resource for ID: {id}"; + } + } +} diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 35f833d50..82a2b6b6b 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -5,6 +5,7 @@ using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; using System.Collections.Concurrent; using System.ComponentModel; using System.IO.Pipelines; @@ -22,6 +23,8 @@ public McpServerBuilderExtensionsToolsTests(ITestOutputHelper testOutputHelper) { } + private MockLoggerProvider _mockLoggerProvider = new(); + protected override void ConfigureServices(ServiceCollection services, IMcpServerBuilder mcpServerBuilder) { mcpServerBuilder @@ -107,6 +110,7 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer .WithTools(serializerOptions: BuilderToolsJsonContext.Default.Options); services.AddSingleton(new ObjectWithId()); + services.AddSingleton(_mockLoggerProvider); } [Fact] @@ -155,8 +159,8 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T await using (var client = await McpClientFactory.CreateAsync( new StreamClientTransport( - serverInput: stdinPipe.Writer.AsStream(), - serverOutput: stdoutPipe.Reader.AsStream(), + serverInput: stdinPipe.Writer.AsStream(), + serverOutput: stdoutPipe.Reader.AsStream(), LoggerFactory), loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)) @@ -230,7 +234,7 @@ public async Task Can_Call_Registered_Tool() var result = await client.CallToolAsync( "echo", - new Dictionary() { ["message"] = "Peter" }, + new Dictionary() { ["message"] = "Peter" }, cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(result); @@ -351,14 +355,14 @@ public async Task Can_Call_Registered_Tool_With_Instance_Method() string random1 = parts[0][0]; string random2 = parts[1][0]; Assert.NotEqual(random1, random2); - + string id1 = parts[0][1]; string id2 = parts[1][1]; Assert.Equal(id1, id2); } [Fact] - public async Task Returns_IsError_Content_When_Tool_Fails() + public async Task Returns_IsError_Content_And_Logs_Error_When_Tool_Fails() { await using IMcpClient client = await CreateMcpClientForServer(); @@ -370,6 +374,11 @@ public async Task Returns_IsError_Content_When_Tool_Fails() Assert.NotNull(result.Content); Assert.NotEmpty(result.Content); Assert.Contains("An error occurred", (result.Content[0] as TextContentBlock)?.Text); + + var errorLog = Assert.Single(_mockLoggerProvider.LogMessages, m => m.LogLevel == LogLevel.Error); + Assert.Equal($"\"throw_exception\" threw an unhandled exception.", errorLog.Message); + Assert.IsType(errorLog.Exception); + Assert.Equal("Test error", errorLog.Exception.Message); } [Fact] diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs index 39e9b72ff..307e086a3 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs @@ -1,4 +1,4 @@ -using Microsoft.Extensions.AI; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Primitives; using ModelContextProtocol.Protocol; @@ -15,6 +15,16 @@ namespace ModelContextProtocol.Tests.Server; public class McpServerPromptTests { + private static JsonRpcRequest CreateTestJsonRpcRequest() + { + return new JsonRpcRequest + { + Id = new RequestId("test-id"), + Method = "test/method", + Params = null + }; + } + public McpServerPromptTests() { #if !NET @@ -46,7 +56,7 @@ public async Task SupportsIMcpServer() Assert.DoesNotContain("server", prompt.ProtocolPrompt.Arguments?.Select(a => a.Name) ?? []); var result = await prompt.GetAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.NotNull(result.Messages); @@ -75,7 +85,7 @@ public async Task SupportsCtorInjection() }, new() { Services = services }); var result = await prompt.GetAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.NotNull(result.Messages); @@ -125,11 +135,11 @@ public async Task SupportsServiceFromDI() Assert.DoesNotContain("actualMyService", prompt.ProtocolPrompt.Arguments?.Select(a => a.Name) ?? []); await Assert.ThrowsAnyAsync(async () => await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken)); var result = await prompt.GetAsync( - new RequestContext(new Mock().Object) { Services = services }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Services = services }, TestContext.Current.CancellationToken); Assert.Equal("Hello", Assert.IsType(result.Messages[0].Content).Text); } @@ -150,7 +160,7 @@ public async Task SupportsOptionalServiceFromDI() }, new() { Services = services }); var result = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("Hello", Assert.IsType(result.Messages[0].Content).Text); } @@ -163,7 +173,7 @@ public async Task SupportsDisposingInstantiatedDisposableTargets() _ => new DisposablePromptType()); var result = await prompt1.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("disposals:1", Assert.IsType(result.Messages[0].Content).Text); } @@ -176,7 +186,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableTargets() _ => new AsyncDisposablePromptType()); var result = await prompt1.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("asyncDisposals:1", Assert.IsType(result.Messages[0].Content).Text); } @@ -189,7 +199,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableAndDisposable _ => new AsyncDisposableAndDisposablePromptType()); var result = await prompt1.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("disposals:0, asyncDisposals:1", Assert.IsType(result.Messages[0].Content).Text); } @@ -205,7 +215,7 @@ public async Task CanReturnGetPromptResult() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Same(expected, actual); @@ -222,7 +232,7 @@ public async Task CanReturnText() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -248,7 +258,7 @@ public async Task CanReturnPromptMessage() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -260,7 +270,7 @@ public async Task CanReturnPromptMessage() [Fact] public async Task CanReturnPromptMessages() { - IList expected = + IList expected = [ new() { @@ -280,7 +290,7 @@ public async Task CanReturnPromptMessages() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -307,7 +317,7 @@ public async Task CanReturnChatMessage() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -339,7 +349,7 @@ public async Task CanReturnChatMessages() }); var actual = await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(actual); @@ -360,7 +370,7 @@ public async Task ThrowsForNullReturn() }); await Assert.ThrowsAsync(async () => await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken)); } @@ -373,7 +383,7 @@ public async Task ThrowsForUnexpectedTypeReturn() }); await Assert.ThrowsAsync(async () => await prompt.GetAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken)); } diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs index 011c4f2b6..df0b65372 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerResourceTests.cs @@ -1,4 +1,4 @@ -using Microsoft.Extensions.AI; +using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; @@ -11,6 +11,16 @@ namespace ModelContextProtocol.Tests.Server; public partial class McpServerResourceTests { + private static JsonRpcRequest CreateTestJsonRpcRequest() + { + return new JsonRpcRequest + { + Id = new RequestId("test-id"), + Method = "test/method", + Params = null + }; + } + public McpServerResourceTests() { #if !NET @@ -138,7 +148,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create(() => "42", new() { Name = Name }); Assert.Equal("resource://mcp/Hello", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = "resource://mcp/Hello" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); @@ -146,7 +156,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((IMcpServer server) => "42", new() { Name = Name }); Assert.Equal("resource://mcp/Hello", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = "resource://mcp/Hello" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); @@ -154,7 +164,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((string arg1) => arg1, new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?arg1}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?arg1=wOrLd" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?arg1=wOrLd" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("wOrLd", ((TextResourceContents)result.Contents[0]).Text); @@ -162,7 +172,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((string arg1, string? arg2 = null) => arg1 + arg2, new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?arg1,arg2}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?arg1=wo&arg2=rld" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?arg1=wo&arg2=rld" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("world", ((TextResourceContents)result.Contents[0]).Text); @@ -170,7 +180,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((object a1, bool a2, char a3, byte a4, sbyte a5) => a1.ToString() + a2 + a3 + a4 + a5, new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=hi&a2=true&a3=s&a4=12&a5=34" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1=hi&a2=true&a3=s&a4=12&a5=34" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("hiTrues1234", ((TextResourceContents)result.Contents[0]).Text); @@ -178,7 +188,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((ushort a1, short a2, uint a3, int a4, ulong a5) => (a1 + a2 + a3 + a4 + (long)a5).ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=10&a2=20&a3=30&a4=40&a5=50" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1=10&a2=20&a3=30&a4=40&a5=50" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("150", ((TextResourceContents)result.Contents[0]).Text); @@ -186,7 +196,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((long a1, float a2, double a3, decimal a4, TimeSpan a5) => a5.ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=1&a2=2&a3=3&a4=4&a5=5" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1=1&a2=2&a3=3&a4=4&a5=5" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("5.00:00:00", ((TextResourceContents)result.Contents[0]).Text); @@ -194,7 +204,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((DateTime a1, DateTimeOffset a2, Uri a3, Guid a4, Version a5) => a4.ToString("N") + a5, new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1={DateTime.UtcNow:r}&a2={DateTimeOffset.UtcNow:r}&a3=http%3A%2F%2Ftest&a4=14e5f43d-0d41-47d6-8207-8249cf669e41&a5=1.2.3.4" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1={DateTime.UtcNow:r}&a2={DateTimeOffset.UtcNow:r}&a3=http%3A%2F%2Ftest&a4=14e5f43d-0d41-47d6-8207-8249cf669e41&a5=1.2.3.4" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("14e5f43d0d4147d682078249cf669e411.2.3.4", ((TextResourceContents)result.Contents[0]).Text); @@ -203,7 +213,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((Half a2, Int128 a3, UInt128 a4, IntPtr a5) => (a3 + (Int128)a4 + a5).ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a2=1.0&a3=3&a4=4&a5=5" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a2=1.0&a3=3&a4=4&a5=5" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("12", ((TextResourceContents)result.Contents[0]).Text); @@ -211,7 +221,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((UIntPtr a1, DateOnly a2, TimeOnly a3) => a1.ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=123&a2=0001-02-03&a3=01%3A02%3A03" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1=123&a2=0001-02-03&a3=01%3A02%3A03" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("123", ((TextResourceContents)result.Contents[0]).Text); @@ -220,7 +230,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((bool? a2, char? a3, byte? a4, sbyte? a5) => a2?.ToString() + a3 + a4 + a5, new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a2=true&a3=s&a4=12&a5=34" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a2=true&a3=s&a4=12&a5=34" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("Trues1234", ((TextResourceContents)result.Contents[0]).Text); @@ -228,7 +238,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((ushort? a1, short? a2, uint? a3, int? a4, ulong? a5) => (a1 + a2 + a3 + a4 + (long?)a5).ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=10&a2=20&a3=30&a4=40&a5=50" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1=10&a2=20&a3=30&a4=40&a5=50" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("150", ((TextResourceContents)result.Contents[0]).Text); @@ -236,7 +246,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((long? a1, float? a2, double? a3, decimal? a4, TimeSpan? a5) => a5?.ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=1&a2=2&a3=3&a4=4&a5=5" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1=1&a2=2&a3=3&a4=4&a5=5" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("5.00:00:00", ((TextResourceContents)result.Contents[0]).Text); @@ -244,7 +254,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((DateTime? a1, DateTimeOffset? a2, Guid? a4) => a4?.ToString("N"), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a4}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1={DateTime.UtcNow:r}&a2={DateTimeOffset.UtcNow:r}&a4=14e5f43d-0d41-47d6-8207-8249cf669e41" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1={DateTime.UtcNow:r}&a2={DateTimeOffset.UtcNow:r}&a4=14e5f43d-0d41-47d6-8207-8249cf669e41" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("14e5f43d0d4147d682078249cf669e41", ((TextResourceContents)result.Contents[0]).Text); @@ -253,7 +263,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((Half? a2, Int128? a3, UInt128? a4, IntPtr? a5) => (a3 + (Int128?)a4 + a5).ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a2,a3,a4,a5}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a2=1.0&a3=3&a4=4&a5=5" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a2=1.0&a3=3&a4=4&a5=5" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("12", ((TextResourceContents)result.Contents[0]).Text); @@ -261,7 +271,7 @@ public async Task UriTemplate_CreatedFromParameters_LotsOfTypesSupported() t = McpServerResource.Create((UIntPtr? a1, DateOnly? a2, TimeOnly? a3) => a1?.ToString(), new() { Name = Name }); Assert.Equal($"resource://mcp/Hello{{?a1,a2,a3}}", t.ProtocolResourceTemplate.UriTemplate); result = await t.ReadAsync( - new RequestContext(server) { Params = new() { Uri = $"resource://mcp/Hello?a1=123&a2=0001-02-03&a3=01%3A02%3A03" } }, + new RequestContext(server, CreateTestJsonRpcRequest()) { Params = new() { Uri = $"resource://mcp/Hello?a1=123&a2=0001-02-03&a3=01%3A02%3A03" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("123", ((TextResourceContents)result.Contents[0]).Text); @@ -277,7 +287,7 @@ public async Task UriTemplate_NonMatchingUri_ReturnsNull(string uri) McpServerResource t = McpServerResource.Create((string arg1) => arg1, new() { Name = "Hello" }); Assert.Equal("resource://mcp/Hello{?arg1}", t.ProtocolResourceTemplate.UriTemplate); Assert.Null(await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = uri } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = uri } }, TestContext.Current.CancellationToken)); } @@ -288,7 +298,7 @@ public async Task UriTemplate_IsHostCaseInsensitive(string actualUri, string que { McpServerResource t = McpServerResource.Create(() => "resource", new() { UriTemplate = actualUri }); Assert.NotNull(await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = queriedUri } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = queriedUri } }, TestContext.Current.CancellationToken)); } @@ -317,7 +327,7 @@ public async Task UriTemplate_MissingParameter_Throws(string uri) McpServerResource t = McpServerResource.Create((string arg1, int arg2) => arg1, new() { Name = "Hello" }); Assert.Equal("resource://mcp/Hello{?arg1,arg2}", t.ProtocolResourceTemplate.UriTemplate); await Assert.ThrowsAsync(async () => await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = uri } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = uri } }, TestContext.Current.CancellationToken)); } @@ -330,25 +340,25 @@ public async Task UriTemplate_MissingOptionalParameter_Succeeds() ReadResourceResult? result; result = await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Hello" } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("", ((TextResourceContents)result.Contents[0]).Text); result = await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Hello?arg1=first" } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello?arg1=first" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("first", ((TextResourceContents)result.Contents[0]).Text); result = await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Hello?arg2=42" } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello?arg2=42" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); result = await t.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Hello?arg1=first&arg2=42" } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Hello?arg1=first&arg2=42" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("first42", ((TextResourceContents)result.Contents[0]).Text); @@ -366,7 +376,7 @@ public async Task SupportsIMcpServer() }, new() { Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); @@ -393,7 +403,7 @@ public async Task SupportsCtorInjection() }, new() { Services = services }); var result = await tool.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "https://something" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "https://something" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.NotNull(result.Contents); @@ -470,11 +480,11 @@ public async Task SupportsServiceFromDI(ServiceLifetime injectedArgumentLifetime Mock mockServer = new(); await Assert.ThrowsAnyAsync(async () => await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken)); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Services = services, Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Services = services, Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); @@ -496,7 +506,7 @@ public async Task SupportsOptionalServiceFromDI() }, new() { Services = services, Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("42", ((TextResourceContents)result.Contents[0]).Text); @@ -512,7 +522,7 @@ public async Task SupportsDisposingInstantiatedDisposableTargets() _ => new DisposableResourceType()); var result = await resource1.ReadAsync( - new RequestContext(new Mock().Object) { Params = new() { Uri = "test://static/resource/instanceMethod" } }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "test://static/resource/instanceMethod" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal("0", ((TextResourceContents)result.Contents[0]).Text); @@ -530,7 +540,7 @@ public async Task CanReturnReadResult() return new ReadResourceResult { Contents = new List { new TextResourceContents { Text = "hello" } } }; }, new() { Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Single(result.Contents); @@ -547,7 +557,7 @@ public async Task CanReturnResourceContents() return new TextResourceContents { Text = "hello" }; }, new() { Name = "Test", SerializerOptions = JsonContext6.Default.Options }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Single(result.Contents); @@ -568,7 +578,7 @@ public async Task CanReturnCollectionOfResourceContents() ]; }, new() { Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal(2, result.Contents.Count); @@ -586,7 +596,7 @@ public async Task CanReturnString() return "42"; }, new() { Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Single(result.Contents); @@ -603,7 +613,7 @@ public async Task CanReturnCollectionOfStrings() return new List { "42", "43" }; }, new() { Name = "Test", SerializerOptions = JsonContext6.Default.Options }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal(2, result.Contents.Count); @@ -621,7 +631,7 @@ public async Task CanReturnDataContent() return new DataContent(new byte[] { 0, 1, 2 }, "application/octet-stream"); }, new() { Name = "Test" }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Single(result.Contents); @@ -643,7 +653,7 @@ public async Task CanReturnCollectionOfAIContent() }; }, new() { Name = "Test", SerializerOptions = JsonContext6.Default.Options }); var result = await resource.ReadAsync( - new RequestContext(mockServer.Object) { Params = new() { Uri = "resource://mcp/Test" } }, + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new() { Uri = "resource://mcp/Test" } }, TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.Equal(2, result.Contents.Count); diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index f961eef34..ca2ab7835 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -1,10 +1,8 @@ -using Json.Schema; +using Json.Schema; using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; -using ModelContextProtocol.Tests.Utils; using Moq; using System.Reflection; using System.Runtime.InteropServices; @@ -18,6 +16,16 @@ namespace ModelContextProtocol.Tests.Server; public partial class McpServerToolTests { + private static JsonRpcRequest CreateTestJsonRpcRequest() + { + return new JsonRpcRequest + { + Id = new RequestId("test-id"), + Method = "test/method", + Params = null + }; + } + public McpServerToolTests() { #if !NET @@ -53,7 +61,7 @@ public async Task SupportsIMcpServer() Assert.DoesNotContain("server", JsonSerializer.Serialize(tool.ProtocolTool.InputSchema, McpJsonUtilities.DefaultOptions)); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("42", (result.Content[0] as TextContentBlock)?.Text); } @@ -79,7 +87,7 @@ public async Task SupportsCtorInjection() }, new() { Services = services }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.NotNull(result); Assert.NotNull(result.Content); @@ -156,13 +164,14 @@ public async Task SupportsServiceFromDI(ServiceLifetime injectedArgumentLifetime Mock mockServer = new(); - var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), - TestContext.Current.CancellationToken); - Assert.True(result.IsError); + var ex = await Assert.ThrowsAsync(async () => await tool.InvokeAsync( + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), + TestContext.Current.CancellationToken)); - result = await tool.InvokeAsync( - new RequestContext(mockServer.Object) { Services = services }, + mockServer.SetupGet(s => s.Services).Returns(services); + + var result = await tool.InvokeAsync( + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Services = services }, TestContext.Current.CancellationToken); Assert.Equal("42", (result.Content[0] as TextContentBlock)?.Text); } @@ -183,7 +192,7 @@ public async Task SupportsOptionalServiceFromDI() }, new() { Services = services }); var result = await tool.InvokeAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("42", (result.Content[0] as TextContentBlock)?.Text); } @@ -198,7 +207,7 @@ public async Task SupportsDisposingInstantiatedDisposableTargets() options); var result = await tool1.InvokeAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("""{"disposals":1}""", (result.Content[0] as TextContentBlock)?.Text); } @@ -213,7 +222,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableTargets() options); var result = await tool1.InvokeAsync( - new RequestContext(new Mock().Object), + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal("""{"asyncDisposals":1}""", (result.Content[0] as TextContentBlock)?.Text); } @@ -232,7 +241,7 @@ public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableAndDisposable options); var result = await tool1.InvokeAsync( - new RequestContext(new Mock().Object) { Services = services }, + new RequestContext(new Mock().Object, CreateTestJsonRpcRequest()) { Services = services }, TestContext.Current.CancellationToken); Assert.Equal("""{"asyncDisposals":1,"disposals":0}""", (result.Content[0] as TextContentBlock)?.Text); } @@ -253,7 +262,7 @@ public async Task CanReturnCollectionOfAIContent() }, new() { SerializerOptions = JsonContext2.Default.Options }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal(3, result.Content.Count); @@ -287,7 +296,7 @@ public async Task CanReturnSingleAIContent(string data, string type) }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Single(result.Content); @@ -323,7 +332,7 @@ public async Task CanReturnNullAIContent() return (string?)null; }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Empty(result.Content); } @@ -338,7 +347,7 @@ public async Task CanReturnString() return "42"; }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Single(result.Content); Assert.Equal("42", Assert.IsType(result.Content[0]).Text); @@ -354,7 +363,7 @@ public async Task CanReturnCollectionOfStrings() return new List { "42", "43" }; }, new() { SerializerOptions = JsonContext2.Default.Options }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Single(result.Content); Assert.Equal("""["42","43"]""", Assert.IsType(result.Content[0]).Text); @@ -370,7 +379,7 @@ public async Task CanReturnMcpContent() return new TextContentBlock { Text = "42" }; }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Single(result.Content); Assert.Equal("42", Assert.IsType(result.Content[0]).Text); @@ -386,12 +395,12 @@ public async Task CanReturnCollectionOfMcpContent() Assert.Same(mockServer.Object, server); return (IList) [ - new TextContentBlock { Text = "42" }, - new ImageContentBlock { Data = "1234", MimeType = "image/png" } + new TextContentBlock { Text = "42" }, + new ImageContentBlock { Data = "1234", MimeType = "image/png" } ]; }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Equal(2, result.Content.Count); Assert.Equal("42", Assert.IsType(result.Content[0]).Text); @@ -414,7 +423,7 @@ public async Task CanReturnCallToolResult() return response; }); var result = await tool.InvokeAsync( - new RequestContext(mockServer.Object), + new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()), TestContext.Current.CancellationToken); Assert.Same(response, result); @@ -447,45 +456,6 @@ public async Task SupportsSchemaCreateOptions() ); } - [Fact] - public async Task ToolCallError_LogsErrorMessage() - { - // Arrange - var mockLoggerProvider = new MockLoggerProvider(); - var loggerFactory = new LoggerFactory(new[] { mockLoggerProvider }); - var services = new ServiceCollection(); - services.AddSingleton(loggerFactory); - var serviceProvider = services.BuildServiceProvider(); - - var toolName = "tool-that-throws"; - var exceptionMessage = "Test exception message"; - - McpServerTool tool = McpServerTool.Create(() => - { - throw new InvalidOperationException(exceptionMessage); - }, new() { Name = toolName, Services = serviceProvider }); - - var mockServer = new Mock(); - var request = new RequestContext(mockServer.Object) - { - Params = new CallToolRequestParams { Name = toolName }, - Services = serviceProvider - }; - - // Act - var result = await tool.InvokeAsync(request, TestContext.Current.CancellationToken); - - // Assert - Assert.True(result.IsError); - Assert.Single(result.Content); - Assert.Equal($"An error occurred invoking '{toolName}'.", Assert.IsType(result.Content[0]).Text); - - var errorLog = Assert.Single(mockLoggerProvider.LogMessages, m => m.LogLevel == LogLevel.Error); - Assert.Equal($"\"{toolName}\" threw an unhandled exception.", errorLog.Message); - Assert.IsType(errorLog.Exception); - Assert.Equal(exceptionMessage, errorLog.Exception.Message); - } - [Theory] [MemberData(nameof(StructuredOutput_ReturnsExpectedSchema_Inputs))] public async Task StructuredOutput_Enabled_ReturnsExpectedSchema(T value) @@ -493,7 +463,7 @@ public async Task StructuredOutput_Enabled_ReturnsExpectedSchema(T value) JsonSerializerOptions options = new() { TypeInfoResolver = new DefaultJsonTypeInfoResolver() }; McpServerTool tool = McpServerTool.Create(() => value, new() { Name = "tool", UseStructuredContent = true, SerializerOptions = options }); var mockServer = new Mock(); - var request = new RequestContext(mockServer.Object) + var request = new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new CallToolRequestParams { Name = "tool" }, }; @@ -511,7 +481,7 @@ public async Task StructuredOutput_Enabled_VoidReturningTools_ReturnsExpectedSch { McpServerTool tool = McpServerTool.Create(() => { }); var mockServer = new Mock(); - var request = new RequestContext(mockServer.Object) + var request = new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new CallToolRequestParams { Name = "tool" }, }; @@ -522,7 +492,7 @@ public async Task StructuredOutput_Enabled_VoidReturningTools_ReturnsExpectedSch Assert.Null(result.StructuredContent); tool = McpServerTool.Create(() => Task.CompletedTask); - request = new RequestContext(mockServer.Object) + request = new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new CallToolRequestParams { Name = "tool" }, }; @@ -533,7 +503,7 @@ public async Task StructuredOutput_Enabled_VoidReturningTools_ReturnsExpectedSch Assert.Null(result.StructuredContent); tool = McpServerTool.Create(() => default(ValueTask)); - request = new RequestContext(mockServer.Object) + request = new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new CallToolRequestParams { Name = "tool" }, }; @@ -551,7 +521,7 @@ public async Task StructuredOutput_Disabled_ReturnsExpectedSchema(T value) JsonSerializerOptions options = new() { TypeInfoResolver = new DefaultJsonTypeInfoResolver() }; McpServerTool tool = McpServerTool.Create(() => value, new() { UseStructuredContent = false, SerializerOptions = options }); var mockServer = new Mock(); - var request = new RequestContext(mockServer.Object) + var request = new RequestContext(mockServer.Object, CreateTestJsonRpcRequest()) { Params = new CallToolRequestParams { Name = "tool" }, }; @@ -592,7 +562,7 @@ public static IEnumerable StructuredOutput_ReturnsExpectedSchema_Input yield return new object[] { new() }; yield return new object[] { new List { "item1", "item2" } }; yield return new object[] { new Dictionary { ["key1"] = 1, ["key2"] = 2 } }; - yield return new object[] { new Person("John", 27) }; + yield return new object[] { new Person("John", 27) }; } private sealed class MyService; From 76f72970e4cea3990ed9f6aeb38fd2dd9ef5a9c5 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Tue, 26 Aug 2025 08:17:43 -0700 Subject: [PATCH 2/7] Fix failing Windows build - filters.md cleanup --- docs/concepts/filters.md | 31 ++++++++----------- .../StreamableHttpHandler.cs | 6 ++-- 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/docs/concepts/filters.md b/docs/concepts/filters.md index 9a8a43265..27462e176 100644 --- a/docs/concepts/filters.md +++ b/docs/concepts/filters.md @@ -7,10 +7,6 @@ uid: filters # MCP Server Handler Filters -This document describes the filter functionality in the MCP Server, which allows you to add middleware-style filters to handler pipelines. - -## Overview - For each handler type in the MCP Server, there are corresponding `AddXXXFilter` methods in `McpServerBuilderExtensions.cs` that allow you to add filters to the handler pipeline. The filters are stored in `McpServerOptions.Filters` and applied during server configuration. ## Available Filter Methods @@ -173,12 +169,12 @@ You can apply authorization at the class level, which affects all tools in the c ```csharp [McpServerToolType] [Authorize] // All tools require authentication -public class AdminTools +public class RestrictedTools { - [McpServerTool, Description("Admin-only tool")] - public static string AdminOperation() + [McpServerTool, Description("Restricted tool accessible to authenticated users")] + public static string RestrictedOperation() { - return "Admin operation completed"; + return "Restricted operation completed"; } [McpServerTool, Description("Public tool accessible to anonymous users")] @@ -211,23 +207,22 @@ To use authorization features, you must configure authentication and authorizati ```csharp var builder = WebApplication.CreateBuilder(args); -// Add authentication builder.Services.AddAuthentication("Bearer") - .AddJwtBearer("Bearer", options => { /* JWT configuration */ }); - -// Add authorization (required for [Authorize] attributes to work) + .AddJwtBearer(options => { /* JWT configuration */ }) + .AddMcp(options => { /* Resource metadata configuration */ }); builder.Services.AddAuthorization(); -// Add MCP server builder.Services.AddMcpServer() - .WithTools(); + .WithHttpTransport() + .WithTools() + .AddCallToolFilter(next => async (context, cancellationToken) => + { + // Custom call tool logic + return await next(context, cancellationToken); + }); var app = builder.Build(); -// Use authentication and authorization middleware -app.UseAuthentication(); -app.UseAuthorization(); - app.MapMcp(); app.Run(); ``` diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 8e3a72eb0..d3db7e964 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -279,9 +279,11 @@ internal static string MakeNewSessionId() // Implementation for reading a JSON-RPC message from the request body var message = await context.Request.ReadFromJsonAsync(s_messageTypeInfo, context.RequestAborted); - if (context.User?.Identity?.IsAuthenticated ?? false) + if (context.User?.Identity?.IsAuthenticated ?? false && message is not null) { - message?.Context = new() + // We get weird CS0131 errors only on the Windows build GitHub Action if we use "message?.Context = ..." + // https://productionresultssa0.blob.core.windows.net/actions-results/f2218319-0fdd-473b-891d-06e5a4a0f826/workflow-job-run-98901492-cf7c-5406-85d9-0f7057e0516f/logs/job/job-logs.txt?rsct=text%2Fplain&se=2025-08-26T16%3A06%3A31Z&sig=RvEQo6DgrpDUW9mnbgDvf6FVDAAoHKzk9rsDdcPxOhw%3D&ske=2025-08-27T03%3A39%3A43Z&skoid=ca7593d4-ee42-46cd-af88-8b886a2f84eb&sks=b&skt=2025-08-26T15%3A39%3A43Z&sktid=398a6654-997b-47e9-b12b-9515b896b4de&skv=2025-05-05&sp=r&spr=https&sr=b&st=2025-08-26T15%3A56%3A26Z&sv=2025-05-05 + message!.Context = new() { User = context.User, }; From 9b5786debb39055f5d29977bac6a75f351486b2e Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Mon, 8 Sep 2025 09:18:20 -0700 Subject: [PATCH 3/7] Use a delegate type for handlers and filters --- .../Protocol/CompletionsCapability.cs | 4 +- .../Protocol/LoggingCapability.cs | 2 +- .../Protocol/PromptsCapability.cs | 16 +++---- .../Protocol/ResourcesCapability.cs | 14 +++--- .../Protocol/ToolsCapability.cs | 14 +++--- .../RequestHandlers.cs | 14 +++--- .../Server/McpRequestFilter.cs | 11 +++++ .../Server/McpRequestHandler.cs | 13 ++++++ .../Server/McpServer.cs | 23 +++++----- .../Server/McpServerFilters.cs | 22 ++++----- .../McpServerBuilderExtensions.cs | 46 +++++++++---------- src/ModelContextProtocol/McpServerHandlers.cs | 24 +++++----- .../Program.cs | 2 +- .../McpServerBuilderExtensionsHandlerTests.cs | 20 ++++---- .../McpServerBuilderExtensionsPromptsTests.cs | 2 +- ...cpServerBuilderExtensionsResourcesTests.cs | 2 +- .../McpServerBuilderExtensionsToolsTests.cs | 2 +- 17 files changed, 128 insertions(+), 103 deletions(-) create mode 100644 src/ModelContextProtocol.Core/Server/McpRequestFilter.cs create mode 100644 src/ModelContextProtocol.Core/Server/McpRequestHandler.cs diff --git a/src/ModelContextProtocol.Core/Protocol/CompletionsCapability.cs b/src/ModelContextProtocol.Core/Protocol/CompletionsCapability.cs index f411c2975..86363351a 100644 --- a/src/ModelContextProtocol.Core/Protocol/CompletionsCapability.cs +++ b/src/ModelContextProtocol.Core/Protocol/CompletionsCapability.cs @@ -9,7 +9,7 @@ namespace ModelContextProtocol.Protocol; /// /// /// -/// When enabled, this capability allows a Model Context Protocol server to provide +/// When enabled, this capability allows a Model Context Protocol server to provide /// auto-completion suggestions. This capability is advertised to clients during the initialize handshake. /// /// @@ -33,5 +33,5 @@ public sealed class CompletionsCapability /// and should return appropriate completion suggestions. /// [JsonIgnore] - public Func, CancellationToken, ValueTask>? CompleteHandler { get; set; } + public McpRequestHandler? CompleteHandler { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Protocol/LoggingCapability.cs b/src/ModelContextProtocol.Core/Protocol/LoggingCapability.cs index ab43fb066..07803c1ac 100644 --- a/src/ModelContextProtocol.Core/Protocol/LoggingCapability.cs +++ b/src/ModelContextProtocol.Core/Protocol/LoggingCapability.cs @@ -18,5 +18,5 @@ public sealed class LoggingCapability /// Gets or sets the handler for set logging level requests from clients. /// [JsonIgnore] - public Func, CancellationToken, ValueTask>? SetLoggingLevelHandler { get; set; } + public McpRequestHandler? SetLoggingLevelHandler { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Protocol/PromptsCapability.cs b/src/ModelContextProtocol.Core/Protocol/PromptsCapability.cs index 8fad1c0e0..fdfa3d43c 100644 --- a/src/ModelContextProtocol.Core/Protocol/PromptsCapability.cs +++ b/src/ModelContextProtocol.Core/Protocol/PromptsCapability.cs @@ -22,10 +22,10 @@ public sealed class PromptsCapability /// Gets or sets whether this server supports notifications for changes to the prompt list. /// /// - /// When set to , the server will send notifications using - /// when prompts are added, + /// When set to , the server will send notifications using + /// when prompts are added, /// removed, or modified. Clients can register handlers for these notifications to - /// refresh their prompt cache. This capability enables clients to stay synchronized with server-side changes + /// refresh their prompt cache. This capability enables clients to stay synchronized with server-side changes /// to available prompts. /// [JsonPropertyName("listChanged")] @@ -40,15 +40,15 @@ public sealed class PromptsCapability /// along with any prompts defined in . /// [JsonIgnore] - public Func, CancellationToken, ValueTask>? ListPromptsHandler { get; set; } + public McpRequestHandler? ListPromptsHandler { get; set; } /// /// Gets or sets the handler for requests. /// /// /// - /// This handler is invoked when a client requests details for a specific prompt by name and provides arguments - /// for the prompt if needed. The handler receives the request context containing the prompt name and any arguments, + /// This handler is invoked when a client requests details for a specific prompt by name and provides arguments + /// for the prompt if needed. The handler receives the request context containing the prompt name and any arguments, /// and should return a with the prompt messages and other details. /// /// @@ -57,7 +57,7 @@ public sealed class PromptsCapability /// /// [JsonIgnore] - public Func, CancellationToken, ValueTask>? GetPromptHandler { get; set; } + public McpRequestHandler? GetPromptHandler { get; set; } /// /// Gets or sets a collection of prompts that will be served by the server. @@ -69,7 +69,7 @@ public sealed class PromptsCapability /// when those are provided: /// /// - /// - For requests: The server returns all prompts from this collection + /// - For requests: The server returns all prompts from this collection /// plus any additional prompts provided by the if it's set. /// /// diff --git a/src/ModelContextProtocol.Core/Protocol/ResourcesCapability.cs b/src/ModelContextProtocol.Core/Protocol/ResourcesCapability.cs index f6486488b..b5336b207 100644 --- a/src/ModelContextProtocol.Core/Protocol/ResourcesCapability.cs +++ b/src/ModelContextProtocol.Core/Protocol/ResourcesCapability.cs @@ -21,8 +21,8 @@ public sealed class ResourcesCapability /// Gets or sets whether this server supports notifications for changes to the resource list. /// /// - /// When set to , the server will send notifications using - /// when resources are added, + /// When set to , the server will send notifications using + /// when resources are added, /// removed, or modified. Clients can register handlers for these notifications to /// refresh their resource cache. /// @@ -39,7 +39,7 @@ public sealed class ResourcesCapability /// allowing clients to discover available resource types and their access patterns. /// [JsonIgnore] - public Func, CancellationToken, ValueTask>? ListResourceTemplatesHandler { get; set; } + public McpRequestHandler? ListResourceTemplatesHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -49,7 +49,7 @@ public sealed class ResourcesCapability /// The implementation should return a with the matching resources. /// [JsonIgnore] - public Func, CancellationToken, ValueTask>? ListResourcesHandler { get; set; } + public McpRequestHandler? ListResourcesHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -61,7 +61,7 @@ public sealed class ResourcesCapability /// its contents in a ReadResourceResult object. /// [JsonIgnore] - public Func, CancellationToken, ValueTask>? ReadResourceHandler { get; set; } + public McpRequestHandler? ReadResourceHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -74,7 +74,7 @@ public sealed class ResourcesCapability /// requiring polling. /// [JsonIgnore] - public Func, CancellationToken, ValueTask>? SubscribeToResourcesHandler { get; set; } + public McpRequestHandler? SubscribeToResourcesHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -85,7 +85,7 @@ public sealed class ResourcesCapability /// about the specified resource. /// [JsonIgnore] - public Func, CancellationToken, ValueTask>? UnsubscribeFromResourcesHandler { get; set; } + public McpRequestHandler? UnsubscribeFromResourcesHandler { get; set; } /// /// Gets or sets a collection of resources served by the server. diff --git a/src/ModelContextProtocol.Core/Protocol/ToolsCapability.cs b/src/ModelContextProtocol.Core/Protocol/ToolsCapability.cs index 5a3bec5ca..0ea955314 100644 --- a/src/ModelContextProtocol.Core/Protocol/ToolsCapability.cs +++ b/src/ModelContextProtocol.Core/Protocol/ToolsCapability.cs @@ -13,10 +13,10 @@ public sealed class ToolsCapability /// Gets or sets whether this server supports notifications for changes to the tool list. /// /// - /// When set to , the server will send notifications using - /// when tools are added, + /// When set to , the server will send notifications using + /// when tools are added, /// removed, or modified. Clients can register handlers for these notifications to - /// refresh their tool cache. This capability enables clients to stay synchronized with server-side + /// refresh their tool cache. This capability enables clients to stay synchronized with server-side /// changes to available tools. /// [JsonPropertyName("listChanged")] @@ -33,19 +33,19 @@ public sealed class ToolsCapability /// and the tools from the collection will be combined to form the complete list of available tools. /// [JsonIgnore] - public Func, CancellationToken, ValueTask>? ListToolsHandler { get; set; } + public McpRequestHandler? ListToolsHandler { get; set; } /// /// Gets or sets the handler for requests. /// /// /// This handler is invoked when a client makes a call to a tool that isn't found in the . - /// The handler should implement logic to execute the requested tool and return appropriate results. - /// It receives a containing information about the tool + /// The handler should implement logic to execute the requested tool and return appropriate results. + /// It receives a containing information about the tool /// being called and its arguments, and should return a with the execution results. /// [JsonIgnore] - public Func, CancellationToken, ValueTask>? CallToolHandler { get; set; } + public McpRequestHandler? CallToolHandler { get; set; } /// /// Gets or sets a collection of tools served by the server. diff --git a/src/ModelContextProtocol.Core/RequestHandlers.cs b/src/ModelContextProtocol.Core/RequestHandlers.cs index 0c2b54fa5..fd95751d9 100644 --- a/src/ModelContextProtocol.Core/RequestHandlers.cs +++ b/src/ModelContextProtocol.Core/RequestHandlers.cs @@ -10,8 +10,8 @@ internal sealed class RequestHandlers : Dictionary /// Registers a handler for incoming requests of a specific method in the MCP protocol. /// - /// Type of request payload that will be deserialized from incoming JSON - /// Type of response payload that will be serialized to JSON (not full RPC response) + /// Type of request payload that will be deserialized from incoming JSON + /// Type of response payload that will be serialized to JSON (not full RPC response) /// Method identifier to register for (e.g., "tools/list", "logging/setLevel") /// Handler function to be called when a request with the specified method identifier is received /// The JSON contract governing request parameter deserialization @@ -27,11 +27,11 @@ internal sealed class RequestHandlers : Dictionary /// - public void Set( + public void Set( string method, - Func> handler, - JsonTypeInfo requestTypeInfo, - JsonTypeInfo responseTypeInfo) + Func> handler, + JsonTypeInfo requestTypeInfo, + JsonTypeInfo responseTypeInfo) { Throw.IfNull(method); Throw.IfNull(handler); @@ -40,7 +40,7 @@ public void Set( this[method] = async (request, cancellationToken) => { - TRequest? typedRequest = JsonSerializer.Deserialize(request.Params, requestTypeInfo); + TParams? typedRequest = JsonSerializer.Deserialize(request.Params, requestTypeInfo); object? result = await handler(typedRequest, request, cancellationToken).ConfigureAwait(false); return JsonSerializer.SerializeToNode(result, responseTypeInfo); }; diff --git a/src/ModelContextProtocol.Core/Server/McpRequestFilter.cs b/src/ModelContextProtocol.Core/Server/McpRequestFilter.cs new file mode 100644 index 000000000..bc1cabc45 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/McpRequestFilter.cs @@ -0,0 +1,11 @@ +namespace ModelContextProtocol.Server; + +/// +/// Delegate type for applying filters to incoming MCP requests with specific parameter and result types. +/// +/// The type of the parameters sent with the request. +/// The type of the response returned by the handler. +/// The next request handler in the pipeline. +/// The next request handler wrapped with the filter. +public delegate McpRequestHandler McpRequestFilter( + McpRequestHandler next); \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/McpRequestHandler.cs b/src/ModelContextProtocol.Core/Server/McpRequestHandler.cs new file mode 100644 index 000000000..651e070e5 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/McpRequestHandler.cs @@ -0,0 +1,13 @@ +namespace ModelContextProtocol.Server; + +/// +/// Delegate type for handling incoming MCP requests with specific parameter and result types. +/// +/// The type of the parameters sent with the request. +/// The type of the response returned by the handler. +/// The request context containing the parameters and other metadata. +/// A cancellation token to cancel the operation. +/// A task representing the asynchronous operation, with the result of the handler. +public delegate ValueTask McpRequestHandler( + RequestContext request, + CancellationToken cancellationToken); \ No newline at end of file diff --git a/src/ModelContextProtocol.Core/Server/McpServer.cs b/src/ModelContextProtocol.Core/Server/McpServer.cs index 0056b1ae0..8785302a3 100644 --- a/src/ModelContextProtocol.Core/Server/McpServer.cs +++ b/src/ModelContextProtocol.Core/Server/McpServer.cs @@ -612,7 +612,7 @@ private void ConfigureLogging(McpServerOptions options) } private ValueTask InvokeHandlerAsync( - Func, CancellationToken, ValueTask> handler, + McpRequestHandler handler, TParams? args, JsonRpcRequest jsonRpcRequest, CancellationToken cancellationToken = default) @@ -622,7 +622,7 @@ private ValueTask InvokeHandlerAsync( handler(new(new DestinationBoundMcpServer(this, jsonRpcRequest.Context?.RelatedTransport), jsonRpcRequest) { Params = args }, cancellationToken); async ValueTask InvokeScopedAsync( - Func, CancellationToken, ValueTask> handler, + McpRequestHandler handler, TParams? args, JsonRpcRequest jsonRpcRequest, CancellationToken cancellationToken) @@ -648,11 +648,11 @@ async ValueTask InvokeScopedAsync( } } - private void SetHandler( + private void SetHandler( string method, - Func, CancellationToken, ValueTask> handler, - JsonTypeInfo requestTypeInfo, - JsonTypeInfo responseTypeInfo) + McpRequestHandler handler, + JsonTypeInfo requestTypeInfo, + JsonTypeInfo responseTypeInfo) { RequestHandlers.Set(method, (request, jsonRpcRequest, cancellationToken) => @@ -660,12 +660,13 @@ private void SetHandler( requestTypeInfo, responseTypeInfo); } - private static THandler BuildFilterPipeline( - THandler baseHandler, List> filters, - Func? initialHandler = null, - Func? finalHandler = null) + private static McpRequestHandler BuildFilterPipeline( + McpRequestHandler baseHandler, + List> filters, + McpRequestFilter? initialHandler = null, + McpRequestFilter? finalHandler = null) { - THandler current = baseHandler; + var current = baseHandler; if (finalHandler is not null) { current = finalHandler(current); diff --git a/src/ModelContextProtocol.Core/Server/McpServerFilters.cs b/src/ModelContextProtocol.Core/Server/McpServerFilters.cs index d15154dd0..e38421bc1 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerFilters.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerFilters.cs @@ -26,7 +26,7 @@ public sealed class McpServerFilters /// Tools from both sources will be combined when returning results to clients. /// /// - public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> ListToolsFilters { get; } = new(); + public List> ListToolsFilters { get; } = new(); /// /// Gets the filters for the call tool handler pipeline. @@ -36,7 +36,7 @@ public sealed class McpServerFilters /// The filters can modify, log, or perform additional operations on requests and responses for /// requests. The handler should implement logic to execute the requested tool and return appropriate results. /// - public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> CallToolFilters { get; } = new(); + public List> CallToolFilters { get; } = new(); /// /// Gets the filters for the list prompts handler pipeline. @@ -53,7 +53,7 @@ public sealed class McpServerFilters /// Prompts from both sources will be combined when returning results to clients. /// /// - public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> ListPromptsFilters { get; } = new(); + public List> ListPromptsFilters { get; } = new(); /// /// Gets the filters for the get prompt handler pipeline. @@ -63,7 +63,7 @@ public sealed class McpServerFilters /// The filters can modify, log, or perform additional operations on requests and responses for /// requests. The handler should implement logic to fetch or generate the requested prompt and return appropriate results. /// - public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> GetPromptFilters { get; } = new(); + public List> GetPromptFilters { get; } = new(); /// /// Gets the filters for the list resource templates handler pipeline. @@ -74,7 +74,7 @@ public sealed class McpServerFilters /// requests. It supports pagination through the cursor mechanism, /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more resource templates. /// - public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> ListResourceTemplatesFilters { get; } = new(); + public List> ListResourceTemplatesFilters { get; } = new(); /// /// Gets the filters for the list resources handler pipeline. @@ -85,7 +85,7 @@ public sealed class McpServerFilters /// requests. It supports pagination through the cursor mechanism, /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more resources. /// - public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> ListResourcesFilters { get; } = new(); + public List> ListResourcesFilters { get; } = new(); /// /// Gets the filters for the read resource handler pipeline. @@ -95,7 +95,7 @@ public sealed class McpServerFilters /// The filters can modify, log, or perform additional operations on requests and responses for /// requests. The handler should implement logic to locate and retrieve the requested resource. /// - public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> ReadResourceFilters { get; } = new(); + public List> ReadResourceFilters { get; } = new(); /// /// Gets the filters for the complete handler pipeline. @@ -106,7 +106,7 @@ public sealed class McpServerFilters /// requests. The handler processes auto-completion requests, returning a list of suggestions based on the /// reference type and current argument value. /// - public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> CompleteFilters { get; } = new(); + public List> CompleteFilters { get; } = new(); /// /// Gets the filters for the subscribe to resources handler pipeline. @@ -123,7 +123,7 @@ public sealed class McpServerFilters /// whenever a relevant resource is created, updated, or deleted. /// /// - public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> SubscribeToResourcesFilters { get; } = new(); + public List> SubscribeToResourcesFilters { get; } = new(); /// /// Gets the filters for the unsubscribe from resources handler pipeline. @@ -140,7 +140,7 @@ public sealed class McpServerFilters /// to the client for the specified resources. /// /// - public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> UnsubscribeFromResourcesFilters { get; } = new(); + public List> UnsubscribeFromResourcesFilters { get; } = new(); /// /// Gets the filters for the set logging level handler pipeline. @@ -157,5 +157,5 @@ public sealed class McpServerFilters /// at or above the specified level to the client as notifications/message notifications. /// /// - public List, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>>> SetLoggingLevelFilters { get; } = new(); + public List> SetLoggingLevelFilters { get; } = new(); } diff --git a/src/ModelContextProtocol/McpServerBuilderExtensions.cs b/src/ModelContextProtocol/McpServerBuilderExtensions.cs index 90b1cb0b5..8e59d9640 100644 --- a/src/ModelContextProtocol/McpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol/McpServerBuilderExtensions.cs @@ -91,7 +91,7 @@ public static partial class McpServerBuilderExtensions if (toolMethod.GetCustomAttribute() is not null) { builder.Services.AddSingleton(services => McpServerTool.Create( - toolMethod, + toolMethod, toolMethod.IsStatic ? null : target, new() { Services = services, SerializerOptions = serializerOptions })); } @@ -585,7 +585,7 @@ where t.GetCustomAttribute() is not null /// resource system where templates define the URI patterns and the read handler provides the actual content. /// /// - public static IMcpServerBuilder WithListResourceTemplatesHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) + public static IMcpServerBuilder WithListResourceTemplatesHandler(this IMcpServerBuilder builder, McpRequestHandler handler) { Throw.IfNull(builder); @@ -618,7 +618,7 @@ public static IMcpServerBuilder WithListResourceTemplatesHandler(this IMcpServer /// executes them when invoked by clients. /// /// - public static IMcpServerBuilder WithListToolsHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) + public static IMcpServerBuilder WithListToolsHandler(this IMcpServerBuilder builder, McpRequestHandler handler) { Throw.IfNull(builder); @@ -638,7 +638,7 @@ public static IMcpServerBuilder WithListToolsHandler(this IMcpServerBuilder buil /// This method is typically paired with to provide a complete tools implementation, /// where advertises available tools and this handler executes them. /// - public static IMcpServerBuilder WithCallToolHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) + public static IMcpServerBuilder WithCallToolHandler(this IMcpServerBuilder builder, McpRequestHandler handler) { Throw.IfNull(builder); @@ -671,7 +671,7 @@ public static IMcpServerBuilder WithCallToolHandler(this IMcpServerBuilder build /// produces them when invoked by clients. /// /// - public static IMcpServerBuilder WithListPromptsHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) + public static IMcpServerBuilder WithListPromptsHandler(this IMcpServerBuilder builder, McpRequestHandler handler) { Throw.IfNull(builder); @@ -686,7 +686,7 @@ public static IMcpServerBuilder WithListPromptsHandler(this IMcpServerBuilder bu /// The handler function that processes prompt requests. /// The builder provided in . /// is . - public static IMcpServerBuilder WithGetPromptHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) + public static IMcpServerBuilder WithGetPromptHandler(this IMcpServerBuilder builder, McpRequestHandler handler) { Throw.IfNull(builder); @@ -707,7 +707,7 @@ public static IMcpServerBuilder WithGetPromptHandler(this IMcpServerBuilder buil /// where this handler advertises available resources and the read handler provides their content when requested. /// /// - public static IMcpServerBuilder WithListResourcesHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) + public static IMcpServerBuilder WithListResourcesHandler(this IMcpServerBuilder builder, McpRequestHandler handler) { Throw.IfNull(builder); @@ -726,7 +726,7 @@ public static IMcpServerBuilder WithListResourcesHandler(this IMcpServerBuilder /// This handler is typically paired with to provide a complete resources implementation, /// where the list handler advertises available resources and the read handler provides their content when requested. /// - public static IMcpServerBuilder WithReadResourceHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) + public static IMcpServerBuilder WithReadResourceHandler(this IMcpServerBuilder builder, McpRequestHandler handler) { Throw.IfNull(builder); @@ -745,7 +745,7 @@ public static IMcpServerBuilder WithReadResourceHandler(this IMcpServerBuilder b /// The completion handler is invoked when clients request suggestions for argument values. /// This enables auto-complete functionality for both prompt arguments and resource references. /// - public static IMcpServerBuilder WithCompleteHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) + public static IMcpServerBuilder WithCompleteHandler(this IMcpServerBuilder builder, McpRequestHandler handler) { Throw.IfNull(builder); @@ -775,7 +775,7 @@ public static IMcpServerBuilder WithCompleteHandler(this IMcpServerBuilder build /// resources and to send appropriate notifications through the connection when resources change. /// /// - public static IMcpServerBuilder WithSubscribeToResourcesHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) + public static IMcpServerBuilder WithSubscribeToResourcesHandler(this IMcpServerBuilder builder, McpRequestHandler handler) { Throw.IfNull(builder); @@ -805,7 +805,7 @@ public static IMcpServerBuilder WithSubscribeToResourcesHandler(this IMcpServerB /// to the specified resource. /// /// - public static IMcpServerBuilder WithUnsubscribeFromResourcesHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) + public static IMcpServerBuilder WithUnsubscribeFromResourcesHandler(this IMcpServerBuilder builder, McpRequestHandler handler) { Throw.IfNull(builder); @@ -832,7 +832,7 @@ public static IMcpServerBuilder WithUnsubscribeFromResourcesHandler(this IMcpSer /// most recently set level. /// /// - public static IMcpServerBuilder WithSetLoggingLevelHandler(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask> handler) + public static IMcpServerBuilder WithSetLoggingLevelHandler(this IMcpServerBuilder builder, McpRequestHandler handler) { Throw.IfNull(builder); @@ -857,7 +857,7 @@ public static IMcpServerBuilder WithSetLoggingLevelHandler(this IMcpServerBuilde /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more resource templates. /// /// - public static IMcpServerBuilder AddListResourceTemplatesFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + public static IMcpServerBuilder AddListResourceTemplatesFilter(this IMcpServerBuilder builder, McpRequestFilter filter) { Throw.IfNull(builder); @@ -884,7 +884,7 @@ public static IMcpServerBuilder AddListResourceTemplatesFilter(this IMcpServerBu /// Tools from both sources will be combined when returning results to clients. /// /// - public static IMcpServerBuilder AddListToolsFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + public static IMcpServerBuilder AddListToolsFilter(this IMcpServerBuilder builder, McpRequestFilter filter) { Throw.IfNull(builder); @@ -906,7 +906,7 @@ public static IMcpServerBuilder AddListToolsFilter(this IMcpServerBuilder builde /// requests. The handler should implement logic to execute the requested tool and return appropriate results. /// /// - public static IMcpServerBuilder AddCallToolFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + public static IMcpServerBuilder AddCallToolFilter(this IMcpServerBuilder builder, McpRequestFilter filter) { Throw.IfNull(builder); @@ -933,7 +933,7 @@ public static IMcpServerBuilder AddCallToolFilter(this IMcpServerBuilder builder /// Prompts from both sources will be combined when returning results to clients. /// /// - public static IMcpServerBuilder AddListPromptsFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + public static IMcpServerBuilder AddListPromptsFilter(this IMcpServerBuilder builder, McpRequestFilter filter) { Throw.IfNull(builder); @@ -955,7 +955,7 @@ public static IMcpServerBuilder AddListPromptsFilter(this IMcpServerBuilder buil /// requests. The handler should implement logic to fetch or generate the requested prompt and return appropriate results. /// /// - public static IMcpServerBuilder AddGetPromptFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + public static IMcpServerBuilder AddGetPromptFilter(this IMcpServerBuilder builder, McpRequestFilter filter) { Throw.IfNull(builder); @@ -978,7 +978,7 @@ public static IMcpServerBuilder AddGetPromptFilter(this IMcpServerBuilder builde /// where the client can make repeated calls with the cursor returned by the previous call to retrieve more resources. /// /// - public static IMcpServerBuilder AddListResourcesFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + public static IMcpServerBuilder AddListResourcesFilter(this IMcpServerBuilder builder, McpRequestFilter filter) { Throw.IfNull(builder); @@ -1000,7 +1000,7 @@ public static IMcpServerBuilder AddListResourcesFilter(this IMcpServerBuilder bu /// requests. The handler should implement logic to locate and retrieve the requested resource. /// /// - public static IMcpServerBuilder AddReadResourceFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + public static IMcpServerBuilder AddReadResourceFilter(this IMcpServerBuilder builder, McpRequestFilter filter) { Throw.IfNull(builder); @@ -1023,7 +1023,7 @@ public static IMcpServerBuilder AddReadResourceFilter(this IMcpServerBuilder bui /// reference type and current argument value. /// /// - public static IMcpServerBuilder AddCompleteFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + public static IMcpServerBuilder AddCompleteFilter(this IMcpServerBuilder builder, McpRequestFilter filter) { Throw.IfNull(builder); @@ -1050,7 +1050,7 @@ public static IMcpServerBuilder AddCompleteFilter(this IMcpServerBuilder builder /// whenever a relevant resource is created, updated, or deleted. /// /// - public static IMcpServerBuilder AddSubscribeToResourcesFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + public static IMcpServerBuilder AddSubscribeToResourcesFilter(this IMcpServerBuilder builder, McpRequestFilter filter) { Throw.IfNull(builder); @@ -1077,7 +1077,7 @@ public static IMcpServerBuilder AddSubscribeToResourcesFilter(this IMcpServerBui /// to the client for the specified resources. /// /// - public static IMcpServerBuilder AddUnsubscribeFromResourcesFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + public static IMcpServerBuilder AddUnsubscribeFromResourcesFilter(this IMcpServerBuilder builder, McpRequestFilter filter) { Throw.IfNull(builder); @@ -1104,7 +1104,7 @@ public static IMcpServerBuilder AddUnsubscribeFromResourcesFilter(this IMcpServe /// at or above the specified level to the client as notifications/message notifications. /// /// - public static IMcpServerBuilder AddSetLoggingLevelFilter(this IMcpServerBuilder builder, Func, CancellationToken, ValueTask>, Func, CancellationToken, ValueTask>> filter) + public static IMcpServerBuilder AddSetLoggingLevelFilter(this IMcpServerBuilder builder, McpRequestFilter filter) { Throw.IfNull(builder); diff --git a/src/ModelContextProtocol/McpServerHandlers.cs b/src/ModelContextProtocol/McpServerHandlers.cs index a07c81b54..34504e928 100644 --- a/src/ModelContextProtocol/McpServerHandlers.cs +++ b/src/ModelContextProtocol/McpServerHandlers.cs @@ -40,7 +40,7 @@ public sealed class McpServerHandlers /// Tools from both sources will be combined when returning results to clients. /// /// - public Func, CancellationToken, ValueTask>? ListToolsHandler { get; set; } + public McpRequestHandler? ListToolsHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -49,7 +49,7 @@ public sealed class McpServerHandlers /// This handler is invoked when a client makes a call to a tool that isn't found in the collection. /// The handler should implement logic to execute the requested tool and return appropriate results. /// - public Func, CancellationToken, ValueTask>? CallToolHandler { get; set; } + public McpRequestHandler? CallToolHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -65,7 +65,7 @@ public sealed class McpServerHandlers /// Prompts from both sources will be combined when returning results to clients. /// /// - public Func, CancellationToken, ValueTask>? ListPromptsHandler { get; set; } + public McpRequestHandler? ListPromptsHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -74,7 +74,7 @@ public sealed class McpServerHandlers /// This handler is invoked when a client requests details for a specific prompt that isn't found in the collection. /// The handler should implement logic to fetch or generate the requested prompt and return appropriate results. /// - public Func, CancellationToken, ValueTask>? GetPromptHandler { get; set; } + public McpRequestHandler? GetPromptHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -84,7 +84,7 @@ public sealed class McpServerHandlers /// It supports pagination through the cursor mechanism, where the client can make /// repeated calls with the cursor returned by the previous call to retrieve more resource templates. /// - public Func, CancellationToken, ValueTask>? ListResourceTemplatesHandler { get; set; } + public McpRequestHandler? ListResourceTemplatesHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -94,7 +94,7 @@ public sealed class McpServerHandlers /// It supports pagination through the cursor mechanism, where the client can make /// repeated calls with the cursor returned by the previous call to retrieve more resources. /// - public Func, CancellationToken, ValueTask>? ListResourcesHandler { get; set; } + public McpRequestHandler? ListResourcesHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -103,17 +103,17 @@ public sealed class McpServerHandlers /// This handler is invoked when a client requests the content of a specific resource identified by its URI. /// The handler should implement logic to locate and retrieve the requested resource. /// - public Func, CancellationToken, ValueTask>? ReadResourceHandler { get; set; } + public McpRequestHandler? ReadResourceHandler { get; set; } /// /// Gets or sets the handler for requests. /// /// /// This handler provides auto-completion suggestions for prompt arguments or resource references in the Model Context Protocol. - /// The handler processes auto-completion requests, returning a list of suggestions based on the + /// The handler processes auto-completion requests, returning a list of suggestions based on the /// reference type and current argument value. /// - public Func, CancellationToken, ValueTask>? CompleteHandler { get; set; } + public McpRequestHandler? CompleteHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -129,7 +129,7 @@ public sealed class McpServerHandlers /// whenever a relevant resource is created, updated, or deleted. /// /// - public Func, CancellationToken, ValueTask>? SubscribeToResourcesHandler { get; set; } + public McpRequestHandler? SubscribeToResourcesHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -145,7 +145,7 @@ public sealed class McpServerHandlers /// to the client for the specified resources. /// /// - public Func, CancellationToken, ValueTask>? UnsubscribeFromResourcesHandler { get; set; } + public McpRequestHandler? UnsubscribeFromResourcesHandler { get; set; } /// /// Gets or sets the handler for requests. @@ -160,7 +160,7 @@ public sealed class McpServerHandlers /// at or above the specified level to the client as notifications/message notifications. /// /// - public Func, CancellationToken, ValueTask>? SetLoggingLevelHandler { get; set; } + public McpRequestHandler? SetLoggingLevelHandler { get; set; } /// /// Overwrite any handlers in McpServerOptions with non-null handlers from this instance. diff --git a/tests/ModelContextProtocol.TestServer/Program.cs b/tests/ModelContextProtocol.TestServer/Program.cs index 0bc4134fa..ddd3701fd 100644 --- a/tests/ModelContextProtocol.TestServer/Program.cs +++ b/tests/ModelContextProtocol.TestServer/Program.cs @@ -491,7 +491,7 @@ private static CompletionsCapability ConfigureCompletions() {"temperature", ["0", "0.5", "0.7", "1.0"]}, }; - Func, CancellationToken, ValueTask> handler = async (request, cancellationToken) => + McpRequestHandler handler = async (request, cancellationToken) => { string[]? values; switch (request.Params?.Ref) diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsHandlerTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsHandlerTests.cs index c446eb5da..adae22f24 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsHandlerTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsHandlerTests.cs @@ -21,7 +21,7 @@ public McpServerBuilderExtensionsHandlerTests() [Fact] public void WithListToolsHandler_Sets_Handler() { - Func, CancellationToken, ValueTask> handler = async (context, token) => new ListToolsResult(); + McpRequestHandler handler = async (context, token) => new ListToolsResult(); _builder.Object.WithListToolsHandler(handler); @@ -34,7 +34,7 @@ public void WithListToolsHandler_Sets_Handler() [Fact] public void WithCallToolHandler_Sets_Handler() { - Func, CancellationToken, ValueTask> handler = async (context, token) => new CallToolResult(); + McpRequestHandler handler = async (context, token) => new CallToolResult(); _builder.Object.WithCallToolHandler(handler); @@ -47,7 +47,7 @@ public void WithCallToolHandler_Sets_Handler() [Fact] public void WithListPromptsHandler_Sets_Handler() { - Func, CancellationToken, ValueTask> handler = async (context, token) => new ListPromptsResult(); + McpRequestHandler handler = async (context, token) => new ListPromptsResult(); _builder.Object.WithListPromptsHandler(handler); @@ -60,7 +60,7 @@ public void WithListPromptsHandler_Sets_Handler() [Fact] public void WithGetPromptHandler_Sets_Handler() { - Func, CancellationToken, ValueTask> handler = async (context, token) => new GetPromptResult(); + McpRequestHandler handler = async (context, token) => new GetPromptResult(); _builder.Object.WithGetPromptHandler(handler); @@ -73,7 +73,7 @@ public void WithGetPromptHandler_Sets_Handler() [Fact] public void WithListResourceTemplatesHandler_Sets_Handler() { - Func, CancellationToken, ValueTask> handler = async (context, token) => new ListResourceTemplatesResult(); + McpRequestHandler handler = async (context, token) => new ListResourceTemplatesResult(); _builder.Object.WithListResourceTemplatesHandler(handler); @@ -86,7 +86,7 @@ public void WithListResourceTemplatesHandler_Sets_Handler() [Fact] public void WithListResourcesHandler_Sets_Handler() { - Func, CancellationToken, ValueTask> handler = async (context, token) => new ListResourcesResult(); + McpRequestHandler handler = async (context, token) => new ListResourcesResult(); _builder.Object.WithListResourcesHandler(handler); @@ -99,7 +99,7 @@ public void WithListResourcesHandler_Sets_Handler() [Fact] public void WithReadResourceHandler_Sets_Handler() { - Func, CancellationToken, ValueTask> handler = async (context, token) => new ReadResourceResult(); + McpRequestHandler handler = async (context, token) => new ReadResourceResult(); _builder.Object.WithReadResourceHandler(handler); @@ -112,7 +112,7 @@ public void WithReadResourceHandler_Sets_Handler() [Fact] public void WithCompleteHandler_Sets_Handler() { - Func, CancellationToken, ValueTask> handler = async (context, token) => new CompleteResult(); + McpRequestHandler handler = async (context, token) => new CompleteResult(); _builder.Object.WithCompleteHandler(handler); @@ -125,7 +125,7 @@ public void WithCompleteHandler_Sets_Handler() [Fact] public void WithSubscribeToResourcesHandler_Sets_Handler() { - Func, CancellationToken, ValueTask> handler = async (context, token) => new EmptyResult(); + McpRequestHandler handler = async (context, token) => new EmptyResult(); _builder.Object.WithSubscribeToResourcesHandler(handler); @@ -138,7 +138,7 @@ public void WithSubscribeToResourcesHandler_Sets_Handler() [Fact] public void WithUnsubscribeFromResourcesHandler_Sets_Handler() { - Func, CancellationToken, ValueTask> handler = async (context, token) => new EmptyResult(); + McpRequestHandler handler = async (context, token) => new EmptyResult(); _builder.Object.WithUnsubscribeFromResourcesHandler(handler); diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs index d697b9791..38ef9ab5d 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs @@ -238,7 +238,7 @@ public async Task WithPrompts_TargetInstance_UsesTarget() sc.AddMcpServer().WithPrompts(target); McpServerPrompt prompt = sc.BuildServiceProvider().GetServices().First(t => t.ProtocolPrompt.Name == "returns_string"); - var result = await prompt.GetAsync(new RequestContext(new Mock().Object) + var result = await prompt.GetAsync(new RequestContext(new Mock().Object, new JsonRpcRequest { Method = "test", Id = new RequestId("1") }) { Params = new GetPromptRequestParams { diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs index e6b177f5a..c95fd7671 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsResourcesTests.cs @@ -265,7 +265,7 @@ public async Task WithResources_TargetInstance_UsesTarget() sc.AddMcpServer().WithResources(target); McpServerResource resource = sc.BuildServiceProvider().GetServices().First(t => t.ProtocolResource?.Name == "returns_string"); - var result = await resource.ReadAsync(new RequestContext(new Mock().Object) + var result = await resource.ReadAsync(new RequestContext(new Mock().Object, new JsonRpcRequest { Method = "test", Id = new RequestId("1") }) { Params = new() { diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 70dde7980..6313480f3 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -525,7 +525,7 @@ public async Task WithTools_TargetInstance_UsesTarget() sc.AddMcpServer().WithTools(target, BuilderToolsJsonContext.Default.Options); McpServerTool tool = sc.BuildServiceProvider().GetServices().First(t => t.ProtocolTool.Name == "get_ctor_parameter"); - var result = await tool.InvokeAsync(new RequestContext(new Mock().Object), TestContext.Current.CancellationToken); + var result = await tool.InvokeAsync(new RequestContext(new Mock().Object, new JsonRpcRequest { Method = "test", Id = new RequestId("1") }), TestContext.Current.CancellationToken); Assert.Equal(target.GetCtorParameter(), (result.Content[0] as TextContentBlock)?.Text); } From d29040fbff6a4a3e564211c6ed313f6b147e5f31 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Tue, 9 Sep 2025 07:18:37 -0700 Subject: [PATCH 4/7] Add AddAuthorizationFilters --- docs/concepts/filters.md | 54 +++++- .../AuthorizationFilterSetup.cs | 164 +++++++++++++++-- .../HttpMcpServerBuilderExtensions.cs | 27 ++- .../Server/RequestContext.cs | 17 ++ .../AuthorizeAttributeTests.cs | 165 +++++++++++++++++- 5 files changed, 410 insertions(+), 17 deletions(-) diff --git a/docs/concepts/filters.md b/docs/concepts/filters.md index 27462e176..38c095647 100644 --- a/docs/concepts/filters.md +++ b/docs/concepts/filters.md @@ -121,7 +121,20 @@ Execution flow: `filter1 -> filter2 -> filter3 -> baseHandler -> filter3 -> filt ## Built-in Authorization Filters -When using the ASP.NET Core integration (`ModelContextProtocol.AspNetCore`), authorization filters are automatically configured to support `[Authorize]` and `[AllowAnonymous]` attributes on MCP server tools, prompts, and resources. +When using the ASP.NET Core integration (`ModelContextProtocol.AspNetCore`), you can add authorization filters to support `[Authorize]` and `[AllowAnonymous]` attributes on MCP server tools, prompts, and resources by calling `AddAuthorizationFilters()` on your MCP server builder. + +### Enabling Authorization Filters + +To enable authorization support, call `AddAuthorizationFilters()` when configuring your MCP server: + +```csharp +services.AddMcpServer() + .WithHttpTransport() + .AddAuthorizationFilters() // Enable authorization filter support + .WithTools(); +``` + +**Important**: You should always call `AddAuthorizationFilters()` when using ASP.NET Core integration if you want to use authorization attributes like `[Authorize]` on your MCP server tools, prompts, or resources. ### Authorization Attributes Support @@ -200,9 +213,45 @@ For individual operations, the filters return authorization errors when access i - **Prompts**: Throws an `McpException` with "Access forbidden" message - **Resources**: Throws an `McpException` with "Access forbidden" message +### Filter Execution Order and Authorization + +Authorization filters are applied automatically when you call `AddAuthorizationFilters()`. These filters run at a specific point in the filter pipeline, which means: + +**Filters added before authorization filters** can see: +- Unauthorized requests for operations before they are rejected by the authorization filters +- Complete listings for unauthorized primitives before they are filtered out by the authorization filters + +**Filters added after authorization filters** will only see: +- Authorized requests that passed authorization checks +- Filtered listings containing only authorized primitives + +This allows you to implement logging, metrics, or other cross-cutting concerns that need to see all requests, while still maintaining proper authorization: + +```csharp +services.AddMcpServer() + .WithHttpTransport() + .AddListToolsFilter(next => async (context, cancellationToken) => + { + // This filter runs BEFORE authorization - sees all tools + Console.WriteLine("Request for tools list - will see all tools"); + var result = await next(context, cancellationToken); + Console.WriteLine($"Returning {result.Tools?.Count ?? 0} tools after authorization"); + return result; + }) + .AddAuthorizationFilters() // Authorization filtering happens here + .AddListToolsFilter(next => async (context, cancellationToken) => + { + // This filter runs AFTER authorization - only sees authorized tools + var result = await next(context, cancellationToken); + Console.WriteLine($"Post-auth filter sees {result.Tools?.Count ?? 0} authorized tools"); + return result; + }) + .WithTools(); +``` + ### Setup Requirements -To use authorization features, you must configure authentication and authorization in your ASP.NET Core application: +To use authorization features, you must configure authentication and authorization in your ASP.NET Core application and call `AddAuthorizationFilters()`: ```csharp var builder = WebApplication.CreateBuilder(args); @@ -214,6 +263,7 @@ builder.Services.AddAuthorization(); builder.Services.AddMcpServer() .WithHttpTransport() + .AddAuthorizationFilters() // Required for authorization support .WithTools() .AddCallToolFilter(next => async (context, cancellationToken) => { diff --git a/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs b/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs index 7d2c30f28..0eacfdd16 100644 --- a/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs +++ b/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs @@ -1,3 +1,4 @@ +using System.Diagnostics.CodeAnalysis; using System.Security.Claims; using Microsoft.AspNetCore.Authorization; using Microsoft.Extensions.DependencyInjection; @@ -10,8 +11,10 @@ namespace ModelContextProtocol.AspNetCore; /// /// Evaluates authorization policies from endpoint metadata. /// -internal sealed class AuthorizationFilterSetup(IAuthorizationPolicyProvider? policyProvider = null) : IConfigureOptions +internal sealed class AuthorizationFilterSetup(IAuthorizationPolicyProvider? policyProvider = null) : IConfigureOptions, IPostConfigureOptions { + private static readonly string AuthorizationFilterInvokedKey = "ModelContextProtocol.AspNetCore.AuthorizationFilter.Invoked"; + public void Configure(McpServerOptions options) { ConfigureListToolsFilter(options); @@ -25,10 +28,25 @@ public void Configure(McpServerOptions options) ConfigureGetPromptFilter(options); } + public void PostConfigure(string? name, McpServerOptions options) + { + CheckListToolsFilter(options); + CheckCallToolFilter(options); + + CheckListResourcesFilter(options); + CheckListResourceTemplatesFilter(options); + CheckReadResourceFilter(options); + + CheckListPromptsFilter(options); + CheckGetPromptFilter(options); + } + private void ConfigureListToolsFilter(McpServerOptions options) { options.Filters.ListToolsFilters.Add(next => async (context, cancellationToken) => { + context.Items[AuthorizationFilterInvokedKey] = true; + var result = await next(context, cancellationToken); await FilterAuthorizedItemsAsync( result.Tools, static tool => tool.McpServerTool, @@ -37,6 +55,22 @@ await FilterAuthorizedItemsAsync( }); } + private void CheckListToolsFilter(McpServerOptions options) + { + options.Filters.ListToolsFilters.Add(next => async (context, cancellationToken) => + { + var result = await next(context, cancellationToken); + + if (HasAuthorizationMetadata(result.Tools.Select(static tool => tool.McpServerTool)) + && !context.Items.ContainsKey(AuthorizationFilterInvokedKey)) + { + throw new InvalidOperationException("Authorization filter was not invoked for tools/list operation, but authorization metadata was found on the tools. Ensure that AddAuthorizationFilters() is called on the IMcpServerBuilder to configure authorization filters."); + } + + return result; + }); + } + private void ConfigureCallToolFilter(McpServerOptions options) { options.Filters.CallToolFilters.Add(next => async (context, cancellationToken) => @@ -51,6 +85,22 @@ private void ConfigureCallToolFilter(McpServerOptions options) }; } + context.Items[AuthorizationFilterInvokedKey] = true; + + return await next(context, cancellationToken); + }); + } + + private void CheckCallToolFilter(McpServerOptions options) + { + options.Filters.CallToolFilters.Add(next => async (context, cancellationToken) => + { + if (HasAuthorizationMetadata(context.MatchedPrimitive) + && !context.Items.ContainsKey(AuthorizationFilterInvokedKey)) + { + throw new InvalidOperationException("Authorization filter was not invoked for tools/call operation, but authorization metadata was found on the tool. Ensure that AddAuthorizationFilters() is called on the IMcpServerBuilder to configure authorization filters."); + } + return await next(context, cancellationToken); }); } @@ -59,6 +109,8 @@ private void ConfigureListResourcesFilter(McpServerOptions options) { options.Filters.ListResourcesFilters.Add(next => async (context, cancellationToken) => { + context.Items[AuthorizationFilterInvokedKey] = true; + var result = await next(context, cancellationToken); await FilterAuthorizedItemsAsync( result.Resources, static resource => resource.McpServerResource, @@ -67,10 +119,28 @@ await FilterAuthorizedItemsAsync( }); } + private void CheckListResourcesFilter(McpServerOptions options) + { + options.Filters.ListResourcesFilters.Add(next => async (context, cancellationToken) => + { + var result = await next(context, cancellationToken); + + if (HasAuthorizationMetadata(result.Resources.Select(static resource => resource.McpServerResource)) + && !context.Items.ContainsKey(AuthorizationFilterInvokedKey)) + { + throw new InvalidOperationException("Authorization filter was not invoked for resources/list operation, but authorization metadata was found on the resources. Ensure that AddAuthorizationFilters() is called on the IMcpServerBuilder to configure authorization filters."); + } + + return result; + }); + } + private void ConfigureListResourceTemplatesFilter(McpServerOptions options) { options.Filters.ListResourceTemplatesFilters.Add(next => async (context, cancellationToken) => { + context.Items[AuthorizationFilterInvokedKey] = true; + var result = await next(context, cancellationToken); await FilterAuthorizedItemsAsync( result.ResourceTemplates, static resourceTemplate => resourceTemplate.McpServerResource, @@ -79,10 +149,28 @@ await FilterAuthorizedItemsAsync( }); } + private void CheckListResourceTemplatesFilter(McpServerOptions options) + { + options.Filters.ListResourceTemplatesFilters.Add(next => async (context, cancellationToken) => + { + var result = await next(context, cancellationToken); + + if (HasAuthorizationMetadata(result.ResourceTemplates.Select(static resourceTemplate => resourceTemplate.McpServerResource)) + && !context.Items.ContainsKey(AuthorizationFilterInvokedKey)) + { + throw new InvalidOperationException("Authorization filter was not invoked for resources/templates/list operation, but authorization metadata was found on the resource templates. Ensure that AddAuthorizationFilters() is called on the IMcpServerBuilder to configure authorization filters."); + } + + return result; + }); + } + private void ConfigureReadResourceFilter(McpServerOptions options) { options.Filters.ReadResourceFilters.Add(next => async (context, cancellationToken) => { + context.Items[AuthorizationFilterInvokedKey] = true; + var authResult = await GetAuthorizationResultAsync(context.User, context.MatchedPrimitive, context.Services, context); if (!authResult.Succeeded) { @@ -93,10 +181,26 @@ private void ConfigureReadResourceFilter(McpServerOptions options) }); } + private void CheckReadResourceFilter(McpServerOptions options) + { + options.Filters.ReadResourceFilters.Add(next => async (context, cancellationToken) => + { + if (HasAuthorizationMetadata(context.MatchedPrimitive) + && !context.Items.ContainsKey(AuthorizationFilterInvokedKey)) + { + throw new InvalidOperationException("Authorization filter was not invoked for resources/read operation, but authorization metadata was found on the resource. Ensure that AddAuthorizationFilters() is called on the IMcpServerBuilder to configure authorization filters."); + } + + return await next(context, cancellationToken); + }); + } + private void ConfigureListPromptsFilter(McpServerOptions options) { options.Filters.ListPromptsFilters.Add(next => async (context, cancellationToken) => { + context.Items[AuthorizationFilterInvokedKey] = true; + var result = await next(context, cancellationToken); await FilterAuthorizedItemsAsync( result.Prompts, static prompt => prompt.McpServerPrompt, @@ -105,10 +209,28 @@ await FilterAuthorizedItemsAsync( }); } + private void CheckListPromptsFilter(McpServerOptions options) + { + options.Filters.ListPromptsFilters.Add(next => async (context, cancellationToken) => + { + var result = await next(context, cancellationToken); + + if (HasAuthorizationMetadata(result.Prompts.Select(static prompt => prompt.McpServerPrompt)) + && !context.Items.ContainsKey(AuthorizationFilterInvokedKey)) + { + throw new InvalidOperationException("Authorization filter was not invoked for prompts/list operation, but authorization metadata was found on the prompts. Ensure that AddAuthorizationFilters() is called on the IMcpServerBuilder to configure authorization filters."); + } + + return result; + }); + } + private void ConfigureGetPromptFilter(McpServerOptions options) { options.Filters.GetPromptFilters.Add(next => async (context, cancellationToken) => { + context.Items[AuthorizationFilterInvokedKey] = true; + var authResult = await GetAuthorizationResultAsync(context.User, context.MatchedPrimitive, context.Services, context); if (!authResult.Succeeded) { @@ -119,6 +241,20 @@ private void ConfigureGetPromptFilter(McpServerOptions options) }); } + private void CheckGetPromptFilter(McpServerOptions options) + { + options.Filters.GetPromptFilters.Add(next => async (context, cancellationToken) => + { + if (HasAuthorizationMetadata(context.MatchedPrimitive) + && !context.Items.ContainsKey(AuthorizationFilterInvokedKey)) + { + throw new InvalidOperationException("Authorization filter was not invoked for prompts/get operation, but authorization metadata was found on the prompt. Ensure that AddAuthorizationFilters() is called on the IMcpServerBuilder to configure authorization filters."); + } + + return await next(context, cancellationToken); + }); + } + /// /// Filters a collection of items based on authorization policies in their metadata. /// For list operations where we need to filter results by authorization. @@ -141,16 +277,7 @@ private async ValueTask FilterAuthorizedItemsAsync(IList items, Func GetAuthorizationResultAsync( ClaimsPrincipal? user, IMcpServerPrimitive? primitive, IServiceProvider? requestServices, object context) { - // If no primitive was found for this request or there is IAllowAnonymous metadata anywhere on the class or method, - // the request should go through as normal. - if (primitive is null || primitive.Metadata.Any(static m => m is IAllowAnonymous)) - { - return AuthorizationResult.Success(); - } - - // There are no [Authorize] style attributes applied to the method or containing class. Any fallback policies - // have already been enforced at the HTTP request level by the ASP.NET Core authorization middleware. - if (!primitive.Metadata.Any(static m => m is IAuthorizeData or AuthorizationPolicy or IAuthorizationRequirementData)) + if (!HasAuthorizationMetadata(primitive)) { return AuthorizationResult.Success(); } @@ -219,4 +346,19 @@ private async ValueTask GetAuthorizationResultAsync( ? reqPolicyBuilder.Build() : AuthorizationPolicy.Combine(policy, reqPolicyBuilder.Build()); } + + private static bool HasAuthorizationMetadata([NotNullWhen(true)] IMcpServerPrimitive? primitive) + { + // If no primitive was found for this request or there is IAllowAnonymous metadata anywhere on the class or method, + // the request should go through as normal. + if (primitive is null || primitive.Metadata.Any(static m => m is IAllowAnonymous)) + { + return false; + } + + return primitive.Metadata.Any(static m => m is IAuthorizeData or AuthorizationPolicy or IAuthorizationRequirementData); + } + + private static bool HasAuthorizationMetadata(IEnumerable primitives) + => primitives.Any(HasAuthorizationMetadata); } \ No newline at end of file diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs index 70835a83d..fbceab4b1 100644 --- a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs @@ -1,3 +1,4 @@ +using Microsoft.AspNetCore.Authorization; using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Options; using ModelContextProtocol.AspNetCore; @@ -30,8 +31,7 @@ public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder builder.Services.AddHostedService(); builder.Services.AddDataProtection(); - // Register authorization filter setup for automatic filter configuration - builder.Services.TryAddEnumerable(ServiceDescriptor.Singleton, AuthorizationFilterSetup>()); + builder.Services.TryAddEnumerable(ServiceDescriptor.Transient, AuthorizationFilterSetup>()); if (configureOptions is not null) { @@ -40,4 +40,27 @@ public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder return builder; } + + /// + /// Adds authorization filters to support + /// on MCP server tools, prompts, and resources. This method should always be called when using + /// ASP.NET Core integration to ensure proper authorization support. + /// + /// The builder instance. + /// The builder provided in . + /// is . + /// + /// This method automatically configures authorization filters for all MCP server handlers. These filters respect + /// authorization attributes such as + /// and . + /// + public static IMcpServerBuilder AddAuthorizationFilters(this IMcpServerBuilder builder) + { + ArgumentNullException.ThrowIfNull(builder); + + // Allow the authorization filters to get added multiple times in case other middleware changes the matched primitive. + builder.Services.AddTransient, AuthorizationFilterSetup>(); + + return builder; + } } diff --git a/src/ModelContextProtocol.Core/Server/RequestContext.cs b/src/ModelContextProtocol.Core/Server/RequestContext.cs index 8af9f666f..1141a815a 100644 --- a/src/ModelContextProtocol.Core/Server/RequestContext.cs +++ b/src/ModelContextProtocol.Core/Server/RequestContext.cs @@ -17,6 +17,8 @@ public sealed class RequestContext /// The server with which this instance is associated. private IMcpServer _server; + private IDictionary? _items; + /// /// Initializes a new instance of the class with the specified server and JSON-RPC request. /// @@ -44,6 +46,21 @@ public IMcpServer Server } } + /// + /// Gets or sets a key/value collection that can be used to share data within the scope of this request. + /// + public IDictionary Items + { + get + { + return _items ??= new Dictionary(); + } + set + { + _items = value; + } + } + /// Gets or sets the services associated with this request. /// /// This may not be the same instance stored in diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs index 8c173d890..1e1b60df8 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs @@ -1,10 +1,12 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Builder; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using ModelContextProtocol.AspNetCore.Tests.Utils; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; using System.ComponentModel; using System.Security.Claims; @@ -15,6 +17,8 @@ namespace ModelContextProtocol.AspNetCore.Tests; /// public class AuthorizeAttributeTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper) { + private readonly MockLoggerProvider _mockLoggerProvider = new(); + private async Task ConnectAsync() { await using var transport = new SseClientTransport(new SseClientTransportOptions @@ -266,11 +270,147 @@ public async Task ListResources_Anonymous_OnlyReturnsAnonymousResources() Assert.Equal("resource://anonymous", resources[0].Uri); } + [Fact] + public async Task ListTools_WithoutAuthFilters_ThrowsInvalidOperationException() + { + _mockLoggerProvider.LogMessages.Clear(); + await using var app = await StartServerWithoutAuthFilters(builder => builder.WithTools()); + var client = await ConnectAsync(); + + var exception = await Assert.ThrowsAsync(async () => + await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Request failed (remote): An error occurred.", exception.Message); + Assert.Contains(_mockLoggerProvider.LogMessages, log => + log.LogLevel == LogLevel.Warning && + log.Exception is InvalidOperationException && + log.Exception.Message.Contains("Authorization filter was not invoked for tools/list operation") && + log.Exception.Message.Contains("Ensure that AddAuthorizationFilters() is called")); + } + + [Fact] + public async Task CallTool_WithoutAuthFilters_ThrowsInvalidOperationException() + { + _mockLoggerProvider.LogMessages.Clear(); + await using var app = await StartServerWithoutAuthFilters(builder => builder.WithTools()); + var client = await ConnectAsync(); + + var exception = await Assert.ThrowsAsync(async () => + await client.CallToolAsync( + "authorized_tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Request failed (remote): An error occurred.", exception.Message); + Assert.Contains(_mockLoggerProvider.LogMessages, log => + log.LogLevel == LogLevel.Warning && + log.Exception is InvalidOperationException && + log.Exception.Message.Contains("Authorization filter was not invoked for tools/call operation") && + log.Exception.Message.Contains("Ensure that AddAuthorizationFilters() is called")); + } + + [Fact] + public async Task ListPrompts_WithoutAuthFilters_ThrowsInvalidOperationException() + { + _mockLoggerProvider.LogMessages.Clear(); + await using var app = await StartServerWithoutAuthFilters(builder => builder.WithPrompts()); + var client = await ConnectAsync(); + + var exception = await Assert.ThrowsAsync(async () => + await client.ListPromptsAsync(cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Request failed (remote): An error occurred.", exception.Message); + Assert.Contains(_mockLoggerProvider.LogMessages, log => + log.LogLevel == LogLevel.Warning && + log.Exception is InvalidOperationException && + log.Exception.Message.Contains("Authorization filter was not invoked for prompts/list operation") && + log.Exception.Message.Contains("Ensure that AddAuthorizationFilters() is called")); + } + + [Fact] + public async Task GetPrompt_WithoutAuthFilters_ThrowsInvalidOperationException() + { + _mockLoggerProvider.LogMessages.Clear(); + await using var app = await StartServerWithoutAuthFilters(builder => builder.WithPrompts()); + var client = await ConnectAsync(); + + var exception = await Assert.ThrowsAsync(async () => + await client.GetPromptAsync( + "authorized_prompt", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Request failed (remote): An error occurred.", exception.Message); + Assert.Contains(_mockLoggerProvider.LogMessages, log => + log.LogLevel == LogLevel.Warning && + log.Exception is InvalidOperationException && + log.Exception.Message.Contains("Authorization filter was not invoked for prompts/get operation") && + log.Exception.Message.Contains("Ensure that AddAuthorizationFilters() is called")); + } + + [Fact] + public async Task ListResources_WithoutAuthFilters_ThrowsInvalidOperationException() + { + _mockLoggerProvider.LogMessages.Clear(); + await using var app = await StartServerWithoutAuthFilters(builder => builder.WithResources()); + var client = await ConnectAsync(); + + var exception = await Assert.ThrowsAsync(async () => + await client.ListResourcesAsync(cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Request failed (remote): An error occurred.", exception.Message); + Assert.Contains(_mockLoggerProvider.LogMessages, log => + log.LogLevel == LogLevel.Warning && + log.Exception is InvalidOperationException && + log.Exception.Message.Contains("Authorization filter was not invoked for resources/list operation") && + log.Exception.Message.Contains("Ensure that AddAuthorizationFilters() is called")); + } + + [Fact] + public async Task ReadResource_WithoutAuthFilters_ThrowsInvalidOperationException() + { + _mockLoggerProvider.LogMessages.Clear(); + await using var app = await StartServerWithoutAuthFilters(builder => builder.WithResources()); + var client = await ConnectAsync(); + + var exception = await Assert.ThrowsAsync(async () => + await client.ReadResourceAsync( + "resource://authorized", + cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Request failed (remote): An error occurred.", exception.Message); + Assert.Contains(_mockLoggerProvider.LogMessages, log => + log.LogLevel == LogLevel.Warning && + log.Exception is InvalidOperationException && + log.Exception.Message.Contains("Authorization filter was not invoked for resources/read operation") && + log.Exception.Message.Contains("Ensure that AddAuthorizationFilters() is called")); + } + + [Fact] + public async Task ListResourceTemplates_WithoutAuthFilters_ThrowsInvalidOperationException() + { + _mockLoggerProvider.LogMessages.Clear(); + await using var app = await StartServerWithoutAuthFilters(builder => builder.WithResources()); + var client = await ConnectAsync(); + + var exception = await Assert.ThrowsAsync(async () => + await client.ListResourceTemplatesAsync(cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Request failed (remote): An error occurred.", exception.Message); + Assert.Contains(_mockLoggerProvider.LogMessages, log => + log.LogLevel == LogLevel.Warning && + log.Exception is InvalidOperationException && + log.Exception.Message.Contains("Authorization filter was not invoked for resources/templates/list operation") && + log.Exception.Message.Contains("Ensure that AddAuthorizationFilters() is called")); + } + private async Task StartServerWithAuth(Action configure, string? userName = null, params string[] roles) { - var builder = Builder.Services.AddMcpServer().WithHttpTransport(); - configure(builder); + var mcpServerBuilder = Builder.Services.AddMcpServer().WithHttpTransport().AddAuthorizationFilters(); + configure(mcpServerBuilder); + Builder.Services.AddAuthorization(); + Builder.Services.AddSingleton(_mockLoggerProvider); var app = Builder.Build(); @@ -291,6 +431,20 @@ private async Task StartServerWithAuth(Action return app; } + private async Task StartServerWithoutAuthFilters(Action configure) + { + var mcpServerBuilder = Builder.Services.AddMcpServer().WithHttpTransport(); // No AddAuthorizationFilters() call + configure(mcpServerBuilder); + + Builder.Services.AddAuthorization(); + Builder.Services.AddSingleton(_mockLoggerProvider); + + var app = Builder.Build(); + app.MapMcp(); + await app.StartAsync(TestContext.Current.CancellationToken); + return app; + } + private ClaimsPrincipal CreateUser(string name, params string[] roles) => new ClaimsPrincipal(new ClaimsIdentity( [new Claim("name", name), new Claim(ClaimTypes.NameIdentifier, name), ..roles.Select(role => new Claim("role", role))], @@ -370,5 +524,12 @@ public static string AuthorizedResource() { return "Authorized resource content"; } + + [McpServerResource(UriTemplate = "resource://authorized/{id}"), Description("A resource template that requires authorization.")] + [Authorize] + public static string AuthorizedResourceWithTemplate(string id) + { + return "Authorized resource content"; + } } } From f9e88ee098cd6379b1685d57f449266f69233e29 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Tue, 9 Sep 2025 08:26:49 -0700 Subject: [PATCH 5/7] Update tool call authorization failures to throw an McpException --- .../AuthorizationFilterSetup.cs | 6 +--- .../AuthorizeAttributeTests.cs | 32 +++++++++---------- 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs b/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs index 0eacfdd16..f901b88a9 100644 --- a/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs +++ b/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs @@ -78,11 +78,7 @@ private void ConfigureCallToolFilter(McpServerOptions options) var authResult = await GetAuthorizationResultAsync(context.User, context.MatchedPrimitive, context.Services, context); if (!authResult.Succeeded) { - return new CallToolResult - { - Content = [new TextContentBlock { Text = "Access forbidden: This tool requires authorization." }], - IsError = true - }; + throw new McpException("Access forbidden: This tool requires authorization.", McpErrorCode.InvalidRequest); } context.Items[AuthorizationFilterInvokedKey] = true; diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs index 1e1b60df8..914284584 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/AuthorizeAttributeTests.cs @@ -35,15 +35,15 @@ public async Task Authorize_Tool_RequiresAuthentication() await using var app = await StartServerWithAuth(builder => builder.WithTools()); var client = await ConnectAsync(); - var result = await client.CallToolAsync( - "authorized_tool", - new Dictionary { ["message"] = "test" }, - cancellationToken: TestContext.Current.CancellationToken); - // Should return error because tool requires authorization but user is anonymous - Assert.True(result.IsError ?? false); - var content = Assert.Single(result.Content.OfType()); - Assert.Equal("Access forbidden: This tool requires authorization.", content.Text); + var exception = await Assert.ThrowsAsync(async () => + await client.CallToolAsync( + "authorized_tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Request failed (remote): Access forbidden: This tool requires authorization.", exception.Message); + Assert.Equal(McpErrorCode.InvalidRequest, exception.ErrorCode); } [Fact] @@ -100,15 +100,15 @@ public async Task AuthorizeWithRoles_Tool_RequiresAdminRole() await using var app = await StartServerWithAuth(builder => builder.WithTools(), "TestUser", "User"); var client = await ConnectAsync(); - var result = await client.CallToolAsync( - "admin_tool", - new Dictionary { ["message"] = "test" }, - cancellationToken: TestContext.Current.CancellationToken); - // Should return error because tool requires Admin role but user only has User role - Assert.True(result.IsError ?? false); - var content = Assert.Single(result.Content.OfType()); - Assert.Equal("Access forbidden: This tool requires authorization.", content.Text); + var exception = await Assert.ThrowsAsync(async () => + await client.CallToolAsync( + "admin_tool", + new Dictionary { ["message"] = "test" }, + cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Equal("Request failed (remote): Access forbidden: This tool requires authorization.", exception.Message); + Assert.Equal(McpErrorCode.InvalidRequest, exception.ErrorCode); } [Fact] From 263313e10bba34326a975c076582682dbf5104a3 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Tue, 9 Sep 2025 08:54:12 -0700 Subject: [PATCH 6/7] Remove invalid link from comment --- src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index d3db7e964..58d3ace02 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -282,7 +282,6 @@ internal static string MakeNewSessionId() if (context.User?.Identity?.IsAuthenticated ?? false && message is not null) { // We get weird CS0131 errors only on the Windows build GitHub Action if we use "message?.Context = ..." - // https://productionresultssa0.blob.core.windows.net/actions-results/f2218319-0fdd-473b-891d-06e5a4a0f826/workflow-job-run-98901492-cf7c-5406-85d9-0f7057e0516f/logs/job/job-logs.txt?rsct=text%2Fplain&se=2025-08-26T16%3A06%3A31Z&sig=RvEQo6DgrpDUW9mnbgDvf6FVDAAoHKzk9rsDdcPxOhw%3D&ske=2025-08-27T03%3A39%3A43Z&skoid=ca7593d4-ee42-46cd-af88-8b886a2f84eb&sks=b&skt=2025-08-26T15%3A39%3A43Z&sktid=398a6654-997b-47e9-b12b-9515b896b4de&skv=2025-05-05&sp=r&spr=https&sr=b&st=2025-08-26T15%3A56%3A26Z&sv=2025-05-05 message!.Context = new() { User = context.User, From d3186fd91751c55cf7e176e079071830170cbfaf Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Wed, 10 Sep 2025 09:46:45 -0700 Subject: [PATCH 7/7] Address PR feedback - Update filters.md to use DI and logging - Update filters.md Mention that uncaught McpExceptions get turned into JSON-RPC errors - Added newlines to McpServer between blocks - Remove TODO from AuthorizationFilterSetup now that an issue has been filed --- docs/concepts/filters.md | 38 ++++++++++++------- .../AuthorizationFilterSetup.cs | 1 - .../StreamableHttpHandler.cs | 5 +-- .../Server/McpServer.cs | 4 ++ 4 files changed, 30 insertions(+), 18 deletions(-) diff --git a/docs/concepts/filters.md b/docs/concepts/filters.md index 38c095647..3e081d123 100644 --- a/docs/concepts/filters.md +++ b/docs/concepts/filters.md @@ -38,13 +38,15 @@ services.AddMcpServer() }) .AddListToolsFilter(next => async (context, cancellationToken) => { + var logger = context.Services?.GetService>(); + // Pre-processing logic - Console.WriteLine("Before handler execution"); + logger?.LogInformation("Before handler execution"); var result = await next(context, cancellationToken); // Post-processing logic - Console.WriteLine("After handler execution"); + logger?.LogInformation("After handler execution"); return result; }); ``` @@ -67,9 +69,11 @@ Execution flow: `filter1 -> filter2 -> filter3 -> baseHandler -> filter3 -> filt ```csharp .AddListToolsFilter(next => async (context, cancellationToken) => { - Console.WriteLine($"Processing request from {context.Meta.ProgressToken}"); + var logger = context.Services?.GetService>(); + + logger?.LogInformation($"Processing request from {context.Meta.ProgressToken}"); var result = await next(context, cancellationToken); - Console.WriteLine($"Returning {result.Tools?.Count ?? 0} tools"); + logger?.LogInformation($"Returning {result.Tools?.Count ?? 0} tools"); return result; }); ``` @@ -97,10 +101,12 @@ Execution flow: `filter1 -> filter2 -> filter3 -> baseHandler -> filter3 -> filt ```csharp .AddListToolsFilter(next => async (context, cancellationToken) => { + var logger = context.Services?.GetService>(); + var stopwatch = Stopwatch.StartNew(); var result = await next(context, cancellationToken); stopwatch.Stop(); - Console.WriteLine($"Handler took {stopwatch.ElapsedMilliseconds}ms"); + logger?.LogInformation($"Handler took {stopwatch.ElapsedMilliseconds}ms"); return result; }); ``` @@ -109,9 +115,13 @@ Execution flow: `filter1 -> filter2 -> filter3 -> baseHandler -> filter3 -> filt ```csharp .AddListResourcesFilter(next => async (context, cancellationToken) => { + var cache = context.Services!.GetRequiredService(); + var cacheKey = $"resources:{context.Params.Cursor}"; if (cache.TryGetValue(cacheKey, out var cached)) - return cached; + { + return (ListResourcesResult)cached; + } var result = await next(context, cancellationToken); cache.Set(cacheKey, result, TimeSpan.FromMinutes(5)); @@ -207,11 +217,7 @@ The authorization filters work differently for list operations versus individual For list operations, the filters automatically remove unauthorized items from the results. Users only see tools, prompts, or resources they have permission to access. #### Individual Operations (CallTool, GetPrompt, ReadResource) -For individual operations, the filters return authorization errors when access is denied: - -- **Tools**: Returns a `CallToolResult` with `IsError = true` and an error message -- **Prompts**: Throws an `McpException` with "Access forbidden" message -- **Resources**: Throws an `McpException` with "Access forbidden" message +For individual operations, the filters throw an `McpException` with "Access forbidden" message. These get turned into JSON-RPC errors if uncaught by middleware. ### Filter Execution Order and Authorization @@ -232,18 +238,22 @@ services.AddMcpServer() .WithHttpTransport() .AddListToolsFilter(next => async (context, cancellationToken) => { + var logger = context.Services?.GetService>(); + // This filter runs BEFORE authorization - sees all tools - Console.WriteLine("Request for tools list - will see all tools"); + logger?.LogInformation("Request for tools list - will see all tools"); var result = await next(context, cancellationToken); - Console.WriteLine($"Returning {result.Tools?.Count ?? 0} tools after authorization"); + logger?.LogInformation($"Returning {result.Tools?.Count ?? 0} tools after authorization"); return result; }) .AddAuthorizationFilters() // Authorization filtering happens here .AddListToolsFilter(next => async (context, cancellationToken) => { + var logger = context.Services?.GetService>(); + // This filter runs AFTER authorization - only sees authorized tools var result = await next(context, cancellationToken); - Console.WriteLine($"Post-auth filter sees {result.Tools?.Count ?? 0} authorized tools"); + logger?.LogInformation($"Post-auth filter sees {result.Tools?.Count ?? 0} authorized tools"); return result; }) .WithTools(); diff --git a/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs b/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs index f901b88a9..bd0ceabeb 100644 --- a/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs +++ b/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs @@ -283,7 +283,6 @@ private async ValueTask GetAuthorizationResultAsync( throw new InvalidOperationException($"You must call AddAuthorization() because an authorization related attribute was found on {primitive.Id}"); } - // TODO: Cache policy lookup. We would probably use a singleton (not-static) ConditionalWeakTable. var policy = await CombineAsync(policyProvider, primitive.Metadata); if (policy is null) { diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index 58d3ace02..a31c6fb75 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -279,10 +279,9 @@ internal static string MakeNewSessionId() // Implementation for reading a JSON-RPC message from the request body var message = await context.Request.ReadFromJsonAsync(s_messageTypeInfo, context.RequestAborted); - if (context.User?.Identity?.IsAuthenticated ?? false && message is not null) + if (context.User?.Identity?.IsAuthenticated == true && message is not null) { - // We get weird CS0131 errors only on the Windows build GitHub Action if we use "message?.Context = ..." - message!.Context = new() + message.Context = new() { User = context.User, }; diff --git a/src/ModelContextProtocol.Core/Server/McpServer.cs b/src/ModelContextProtocol.Core/Server/McpServer.cs index 8785302a3..6e15e2465 100644 --- a/src/ModelContextProtocol.Core/Server/McpServer.cs +++ b/src/ModelContextProtocol.Core/Server/McpServer.cs @@ -667,18 +667,22 @@ private static McpRequestHandler BuildFilterPipeline? finalHandler = null) { var current = baseHandler; + if (finalHandler is not null) { current = finalHandler(current); } + for (int i = filters.Count - 1; i >= 0; i--) { current = filters[i](current); } + if (initialHandler is not null) { current = initialHandler(current); } + return current; }