diff --git a/src/Http/Http.Abstractions/src/IRouteHandlerFilter.cs b/src/Http/Http.Abstractions/src/IRouteHandlerFilter.cs index 4d3e583eaa5d..854947c4b351 100644 --- a/src/Http/Http.Abstractions/src/IRouteHandlerFilter.cs +++ b/src/Http/Http.Abstractions/src/IRouteHandlerFilter.cs @@ -9,12 +9,12 @@ namespace Microsoft.AspNetCore.Http; public interface IRouteHandlerFilter { /// - /// Implements the core logic associated with the filter given a + /// Implements the core logic associated with the filter given a /// and the next filter to call in the pipeline. /// - /// The associated with the current request/response. + /// The associated with the current request/response. /// The next filter in the pipeline. /// An awaitable result of calling the handler and apply /// any modifications made by filters in the pipeline. - ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next); + ValueTask InvokeAsync(RouteHandlerInvocationContext context, RouteHandlerFilterDelegate next); } diff --git a/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt b/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt index 35605ea00afc..dc7700f1d4a5 100644 --- a/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt +++ b/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt @@ -1,18 +1,23 @@ #nullable enable *REMOVED*abstract Microsoft.AspNetCore.Http.HttpResponse.ContentType.get -> string! Microsoft.AspNetCore.Http.EndpointMetadataCollection.GetRequiredMetadata() -> T! -Microsoft.AspNetCore.Http.RouteHandlerFilterContext.RouteHandlerFilterContext(Microsoft.AspNetCore.Http.HttpContext! httpContext, params object![]! parameters) -> void -Microsoft.AspNetCore.Http.IRouteHandlerFilter.InvokeAsync(Microsoft.AspNetCore.Http.RouteHandlerFilterContext! context, System.Func>! next) -> System.Threading.Tasks.ValueTask +Microsoft.AspNetCore.Http.IRouteHandlerFilter.InvokeAsync(Microsoft.AspNetCore.Http.RouteHandlerInvocationContext! context, Microsoft.AspNetCore.Http.RouteHandlerFilterDelegate! next) -> System.Threading.Tasks.ValueTask Microsoft.AspNetCore.Http.Metadata.IFromFormMetadata Microsoft.AspNetCore.Http.Metadata.IFromFormMetadata.Name.get -> string? +Microsoft.AspNetCore.Http.RouteHandlerContext +Microsoft.AspNetCore.Http.RouteHandlerContext.EndpointMetadata.get -> Microsoft.AspNetCore.Http.EndpointMetadataCollection! +Microsoft.AspNetCore.Http.RouteHandlerContext.MethodInfo.get -> System.Reflection.MethodInfo! +Microsoft.AspNetCore.Http.RouteHandlerContext.RouteHandlerContext(System.Reflection.MethodInfo! methodInfo, Microsoft.AspNetCore.Http.EndpointMetadataCollection! endpointMetadata) -> void +Microsoft.AspNetCore.Http.RouteHandlerFilterDelegate +Microsoft.AspNetCore.Http.RouteHandlerInvocationContext +Microsoft.AspNetCore.Http.RouteHandlerInvocationContext.HttpContext.get -> Microsoft.AspNetCore.Http.HttpContext! +Microsoft.AspNetCore.Http.RouteHandlerInvocationContext.Parameters.get -> System.Collections.Generic.IList! +Microsoft.AspNetCore.Http.RouteHandlerInvocationContext.RouteHandlerInvocationContext(Microsoft.AspNetCore.Http.HttpContext! httpContext, params object![]! parameters) -> void Microsoft.AspNetCore.Routing.RouteValueDictionary.RouteValueDictionary(Microsoft.AspNetCore.Routing.RouteValueDictionary? dictionary) -> void Microsoft.AspNetCore.Routing.RouteValueDictionary.RouteValueDictionary(System.Collections.Generic.IEnumerable>? values) -> void Microsoft.AspNetCore.Routing.RouteValueDictionary.RouteValueDictionary(System.Collections.Generic.IEnumerable>? values) -> void abstract Microsoft.AspNetCore.Http.HttpResponse.ContentType.get -> string? Microsoft.AspNetCore.Http.Metadata.ISkipStatusCodePagesMetadata -Microsoft.AspNetCore.Http.RouteHandlerFilterContext -Microsoft.AspNetCore.Http.RouteHandlerFilterContext.HttpContext.get -> Microsoft.AspNetCore.Http.HttpContext! -Microsoft.AspNetCore.Http.RouteHandlerFilterContext.Parameters.get -> System.Collections.Generic.IList! Microsoft.AspNetCore.Http.IRouteHandlerFilter Microsoft.AspNetCore.Http.Metadata.IEndpointDescriptionMetadata Microsoft.AspNetCore.Http.Metadata.IEndpointDescriptionMetadata.Description.get -> string! diff --git a/src/Http/Http.Abstractions/src/RouteHandlerContext.cs b/src/Http/Http.Abstractions/src/RouteHandlerContext.cs new file mode 100644 index 000000000000..fdf54a3b9bb5 --- /dev/null +++ b/src/Http/Http.Abstractions/src/RouteHandlerContext.cs @@ -0,0 +1,34 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Reflection; + +namespace Microsoft.AspNetCore.Http; + +/// +/// Represents the information accessible via the route handler filter +/// API when the user is constructing a new route handler. +/// +public sealed class RouteHandlerContext +{ + /// + /// Creates a new instance of the . + /// + /// The associated with the route handler of the current request. + /// The associated with the endpoint the filter is targeting. + public RouteHandlerContext(MethodInfo methodInfo, EndpointMetadataCollection endpointMetadata) + { + MethodInfo = methodInfo; + EndpointMetadata = endpointMetadata; + } + + /// + /// The associated with the current route handler. + /// + public MethodInfo MethodInfo { get; } + + /// + /// The associated with the current endpoint. + /// + public EndpointMetadataCollection EndpointMetadata { get; } +} diff --git a/src/Http/Http.Abstractions/src/RouteHandlerFilterDelegate.cs b/src/Http/Http.Abstractions/src/RouteHandlerFilterDelegate.cs new file mode 100644 index 000000000000..afc443a33a7d --- /dev/null +++ b/src/Http/Http.Abstractions/src/RouteHandlerFilterDelegate.cs @@ -0,0 +1,13 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Http; + +/// +/// A delegate that is applied as a filter on a route handler. +/// +/// The associated with the current request. +/// +/// A result of calling the handler and applying any modifications made by filters in the pipeline. +/// +public delegate ValueTask RouteHandlerFilterDelegate(RouteHandlerInvocationContext context); diff --git a/src/Http/Http.Abstractions/src/RouteHandlerFilterContext.cs b/src/Http/Http.Abstractions/src/RouteHandlerInvocationContext.cs similarity index 77% rename from src/Http/Http.Abstractions/src/RouteHandlerFilterContext.cs rename to src/Http/Http.Abstractions/src/RouteHandlerInvocationContext.cs index 558d97cbd06b..d7cfe600a760 100644 --- a/src/Http/Http.Abstractions/src/RouteHandlerFilterContext.cs +++ b/src/Http/Http.Abstractions/src/RouteHandlerInvocationContext.cs @@ -7,14 +7,14 @@ namespace Microsoft.AspNetCore.Http; /// Provides an abstraction for wrapping the and parameters /// provided to a route handler. /// -public class RouteHandlerFilterContext +public sealed class RouteHandlerInvocationContext { /// - /// Creates a new instance of the for a given request. + /// Creates a new instance of the for a given request. /// /// The associated with the current request. /// A list of parameters provided in the current request. - public RouteHandlerFilterContext(HttpContext httpContext, params object[] parameters) + public RouteHandlerInvocationContext(HttpContext httpContext, params object[] parameters) { HttpContext = httpContext; Parameters = parameters; @@ -28,7 +28,7 @@ public RouteHandlerFilterContext(HttpContext httpContext, params object[] parame /// /// A list of parameters provided in the current request to the filter. /// - /// This list is not read-only to premit modifying of existing parameters by filters. + /// This list is not read-only to permit modifying of existing parameters by filters. /// /// public IList Parameters { get; } diff --git a/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt b/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt index c24492007419..f5825c4e8476 100644 --- a/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt +++ b/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt @@ -1,8 +1,8 @@ #nullable enable +Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions.RouteHandlerFilterFactories.get -> System.Collections.Generic.IReadOnlyList!>? +Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions.RouteHandlerFilterFactories.init -> void Microsoft.Extensions.DependencyInjection.RouteHandlerJsonServiceExtensions static Microsoft.Extensions.DependencyInjection.RouteHandlerJsonServiceExtensions.ConfigureRouteHandlerJsonOptions(this Microsoft.Extensions.DependencyInjection.IServiceCollection! services, System.Action! configureOptions) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! -Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions.RouteHandlerFilters.get -> System.Collections.Generic.IReadOnlyList? -Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions.RouteHandlerFilters.init -> void Microsoft.AspNetCore.Http.EndpointDescriptionAttribute Microsoft.AspNetCore.Http.EndpointDescriptionAttribute.EndpointDescriptionAttribute(string! description) -> void Microsoft.AspNetCore.Http.EndpointDescriptionAttribute.Description.get -> string! diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index 3807ea2c342d..c32a6a790147 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -79,13 +79,13 @@ public static partial class RequestDelegateFactory private static readonly BinaryExpression TempSourceStringNullExpr = Expression.Equal(TempSourceStringExpr, Expression.Constant(null)); private static readonly UnaryExpression TempSourceStringIsNotNullOrEmptyExpr = Expression.Not(Expression.Call(StringIsNullOrEmptyMethod, TempSourceStringExpr)); - private static readonly ConstructorInfo RouteHandlerFilterContextConstructor = typeof(RouteHandlerFilterContext).GetConstructor(new[] { typeof(HttpContext), typeof(object[]) })!; - private static readonly ParameterExpression FilterContextExpr = Expression.Parameter(typeof(RouteHandlerFilterContext), "context"); - private static readonly MemberExpression FilterContextParametersExpr = Expression.Property(FilterContextExpr, typeof(RouteHandlerFilterContext).GetProperty(nameof(RouteHandlerFilterContext.Parameters))!); - private static readonly MemberExpression FilterContextHttpContextExpr = Expression.Property(FilterContextExpr, typeof(RouteHandlerFilterContext).GetProperty(nameof(RouteHandlerFilterContext.HttpContext))!); + private static readonly ConstructorInfo RouteHandlerInvocationContextConstructor = typeof(RouteHandlerInvocationContext).GetConstructor(new[] { typeof(HttpContext), typeof(object[]) })!; + private static readonly ParameterExpression FilterContextExpr = Expression.Parameter(typeof(RouteHandlerInvocationContext), "context"); + private static readonly MemberExpression FilterContextParametersExpr = Expression.Property(FilterContextExpr, typeof(RouteHandlerInvocationContext).GetProperty(nameof(RouteHandlerInvocationContext.Parameters))!); + private static readonly MemberExpression FilterContextHttpContextExpr = Expression.Property(FilterContextExpr, typeof(RouteHandlerInvocationContext).GetProperty(nameof(RouteHandlerInvocationContext.HttpContext))!); private static readonly MemberExpression FilterContextHttpContextResponseExpr = Expression.Property(FilterContextHttpContextExpr, typeof(HttpContext).GetProperty(nameof(HttpContext.Response))!); private static readonly MemberExpression FilterContextHttpContextStatusCodeExpr = Expression.Property(FilterContextHttpContextResponseExpr, typeof(HttpResponse).GetProperty(nameof(HttpResponse.StatusCode))!); - private static readonly ParameterExpression InvokedFilterContextExpr = Expression.Parameter(typeof(RouteHandlerFilterContext), "filterContext"); + private static readonly ParameterExpression InvokedFilterContextExpr = Expression.Parameter(typeof(RouteHandlerInvocationContext), "filterContext"); private static readonly string[] DefaultAcceptsContentType = new[] { "application/json" }; private static readonly string[] FormFileContentType = new[] { "multipart/form-data" }; @@ -166,7 +166,7 @@ private static FactoryContext CreateFactoryContext(RequestDelegateFactoryOptions RouteParameters = options?.RouteParameterNames?.ToList(), ThrowOnBadRequest = options?.ThrowOnBadRequest ?? false, DisableInferredFromBody = options?.DisableInferBodyFromParameters ?? false, - Filters = options?.RouteHandlerFilters?.ToList() + Filters = options?.RouteHandlerFilterFactories?.ToList() }; private static Func CreateTargetableRequestDelegate(MethodInfo methodInfo, Expression? targetExpression, FactoryContext factoryContext) @@ -196,15 +196,15 @@ private static FactoryContext CreateFactoryContext(RequestDelegateFactoryOptions if (factoryContext.Filters is { Count: > 0 }) { var filterPipeline = CreateFilterPipeline(methodInfo, targetExpression, factoryContext); - Expression>> invokePipeline = (context) => filterPipeline(context); + Expression>> invokePipeline = (context) => filterPipeline(context); returnType = typeof(ValueTask); - // var filterContext = new RouteHandlerFilterContext(httpContext, new[] { (object)name_local, (object)int_local }); + // var filterContext = new RouteHandlerInvocationContext(httpContext, new[] { (object)name_local, (object)int_local }); // invokePipeline.Invoke(filterContext); factoryContext.MethodCall = Expression.Block( new[] { InvokedFilterContextExpr }, Expression.Assign( InvokedFilterContextExpr, - Expression.New(RouteHandlerFilterContextConstructor, + Expression.New(RouteHandlerInvocationContextConstructor, new Expression[] { HttpContextExpr, Expression.NewArrayInit(typeof(object), factoryContext.BoxedArgs) })), Expression.Invoke(invokePipeline, InvokedFilterContextExpr) ); @@ -222,13 +222,13 @@ private static FactoryContext CreateFactoryContext(RequestDelegateFactoryOptions return HandleRequestBodyAndCompileRequestDelegate(responseWritingMethodCall, factoryContext); } - private static Func> CreateFilterPipeline(MethodInfo methodInfo, Expression? target, FactoryContext factoryContext) + private static RouteHandlerFilterDelegate CreateFilterPipeline(MethodInfo methodInfo, Expression? target, FactoryContext factoryContext) { Debug.Assert(factoryContext.Filters is not null); // httpContext.Response.StatusCode >= 400 // ? Task.CompletedTask // : handler((string)context.Parameters[0], (int)context.Parameters[1]) - var filteredInvocation = Expression.Lambda>>( + var filteredInvocation = Expression.Lambda( Expression.Condition( Expression.GreaterThanOrEqual(FilterContextHttpContextStatusCodeExpr, Expression.Constant(400)), CompletedValueTaskExpr, @@ -240,12 +240,16 @@ target is null : Expression.Call(target, methodInfo, factoryContext.ContextArgAccess)) )), FilterContextExpr).Compile(); + var routeHandlerContext = new RouteHandlerContext( + methodInfo, + new EndpointMetadataCollection(factoryContext.Metadata)); for (var i = factoryContext.Filters.Count - 1; i >= 0; i--) { - var currentFilter = factoryContext.Filters![i]; + var currentFilterFactory = factoryContext.Filters[i]; var nextFilter = filteredInvocation; - filteredInvocation = (RouteHandlerFilterContext context) => currentFilter.InvokeAsync(context, nextFilter); + var currentFilter = currentFilterFactory(routeHandlerContext, nextFilter); + filteredInvocation = (RouteHandlerInvocationContext context) => currentFilter(context); } return filteredInvocation; @@ -264,7 +268,7 @@ private static Expression[] CreateArguments(ParameterInfo[]? parameters, Factory { args[i] = CreateArgument(parameters[i], factoryContext); // Register expressions containing the boxed and unboxed variants - // of the route handler's arguments for use in RouteHandlerFilterContext + // of the route handler's arguments for use in RouteHandlerInvocationContext // construction and route handler invocation. // (string)context.Parameters[0]; factoryContext.ContextArgAccess.Add( @@ -1693,7 +1697,7 @@ private class FactoryContext public List ContextArgAccess { get; } = new(); public Expression? MethodCall { get; set; } public List BoxedArgs { get; } = new(); - public List? Filters { get; init; } + public List>? Filters { get; init; } } private static class RequestDelegateFactoryConstants diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactoryOptions.cs b/src/Http/Http.Extensions/src/RequestDelegateFactoryOptions.cs index 870c2a06158e..70207f9c63d8 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactoryOptions.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactoryOptions.cs @@ -35,5 +35,5 @@ public sealed class RequestDelegateFactoryOptions /// /// The list of filters that must run in the pipeline for a given route handler. /// - public IReadOnlyList? RouteHandlerFilters { get; init; } + public IReadOnlyList>? RouteHandlerFilterFactories { get; init; } } diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs index 1af56e02660e..f597d15edc15 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs @@ -4216,7 +4216,14 @@ string HelloName(string name) // Act var factoryResult = RequestDelegateFactory.Create(HelloName, new RequestDelegateFactoryOptions() { - RouteHandlerFilters = new List() { new ModifyStringArgumentFilter() } + RouteHandlerFilterFactories = new List>() + { + (routeHandlerContext, next) => async (context) => + { + context.Parameters[0] = context.Parameters[0] != null ? $"{((string)context.Parameters[0]!)}Prefix" : "NULL"; + return await next(context); + } + } }); var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -4243,7 +4250,16 @@ string HelloName(string name) // Act var factoryResult = RequestDelegateFactory.Create(HelloName, new RequestDelegateFactoryOptions() { - RouteHandlerFilters = new List() { new ProvideCustomErrorMessageFilter() } + RouteHandlerFilterFactories = new List>() { + (routeHandlerContext, next) => async (context) => + { + if (context.HttpContext.Response.StatusCode == 400) + { + return Results.Problem("New response", statusCode: 400); + } + return await next(context); + } + } }); var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -4280,7 +4296,22 @@ string HelloName(string name, int age) // Act var factoryResult = RequestDelegateFactory.Create(HelloName, new RequestDelegateFactoryOptions() { - RouteHandlerFilters = new List() { new ModifyIntArgumentFilter(), new LogArgumentsFilter(Log) } + RouteHandlerFilterFactories = new List>() + { + (routeHandlerContext, next) => async (context) => + { + context.Parameters[1] = ((int)context.Parameters[1]!) + 2; + return await next(context); + }, + (routeHandlerContext, next) => async (context) => + { + foreach (var parameter in context.Parameters) + { + Log(parameter!.ToString() ?? "no arg"); + } + return await next(context); + } + } }); var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -4291,6 +4322,109 @@ string HelloName(string name, int age) Assert.Equal(2, loggerInvoked); } + [Fact] + public async Task RequestDelegateFactory_CanInvokeEndpointFilter_ThatUsesMethodInfo() + { + // Arrange + string HelloName(string name) + { + return $"Hello, {name}!."; + }; + + var httpContext = CreateHttpContext(); + + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + httpContext.Request.Query = new QueryCollection(new Dictionary + { + ["name"] = "TestName" + }); + + // Act + var factoryResult = RequestDelegateFactory.Create(HelloName, new RequestDelegateFactoryOptions() + { + RouteHandlerFilterFactories = new List>() + { + (routeHandlerContext, next) => + { + var parameters = routeHandlerContext.MethodInfo.GetParameters(); + var isInt = parameters.Length == 2 && parameters[1].ParameterType == typeof(int); + return async (context) => + { + if (isInt) + { + context.Parameters[1] = ((int)context.Parameters[1]!) + 2; + return await next(context); + } + return "Is not an int."; + }; + }, + } + }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); + + // Assert + var responseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal("Is not an int.", responseBody); + } + + [Fact] + public async Task RequestDelegateFactory_CanInvokeEndpointFilter_ThatUsesEndpointMetadata() + { + // Arrange + string HelloName(IFormFileCollection formFiles) + { + return $"Got {formFiles.Count} files."; + }; + + var fileContent = new StringContent("hello", Encoding.UTF8, "application/octet-stream"); + var form = new MultipartFormDataContent("some-boundary"); + form.Add(fileContent, "file", "file.txt"); + + var stream = new MemoryStream(); + await form.CopyToAsync(stream); + + stream.Seek(0, SeekOrigin.Begin); + + var httpContext = CreateHttpContext(); + httpContext.Request.Body = stream; + httpContext.Request.Headers["Content-Type"] = "multipart/form-data;boundary=some-boundary"; + httpContext.Features.Set(new RequestBodyDetectionFeature(true)); + + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + // Act + var factoryResult = RequestDelegateFactory.Create(HelloName, new RequestDelegateFactoryOptions() + { + RouteHandlerFilterFactories = new List>() + { + (routeHandlerContext, next) => + { + var acceptsMetadata = routeHandlerContext.EndpointMetadata.OfType(); + var contentType = acceptsMetadata.SingleOrDefault()?.ContentTypes.SingleOrDefault(); + + return async (context) => + { + if (contentType == "multipart/form-data") + { + return "I see you expect a form."; + } + return await next(context); + }; + }, + } + }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); + + // Assert + var responseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal("I see you expect a form.", responseBody); + } + [Fact] public async Task RequestDelegateFactory_CanInvokeSingleEndpointFilter_ThatModifiesBodyParameter() { @@ -4316,7 +4450,16 @@ string PrintTodo(Todo todo) // Act var factoryResult = RequestDelegateFactory.Create(PrintTodo, new RequestDelegateFactoryOptions() { - RouteHandlerFilters = new List() { new ModifyTodoArgumentFilter() } + RouteHandlerFilterFactories = new List>() + { + (routeHandlerContext, next) => async (context) => + { + Todo originalTodo = (Todo)context.Parameters[0]!; + originalTodo!.IsComplete = !originalTodo.IsComplete; + context.Parameters[0] = originalTodo; + return await next(context); + } + } }); var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -4348,7 +4491,18 @@ string HelloName(string name) // Act var factoryResult = RequestDelegateFactory.Create(HelloName, new RequestDelegateFactoryOptions() { - RouteHandlerFilters = new List() { new ModifyStringResultFilter() } + RouteHandlerFilterFactories = new List>() + { + (routeHandlerContext, next) => async (context) => + { + var previousResult = await next(context); + if (previousResult is string stringResult) + { + return stringResult.ToUpperInvariant(); + } + return previousResult; + } + } }); var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -4380,7 +4534,23 @@ string HelloName(string name) // Act var factoryResult = RequestDelegateFactory.Create(HelloName, new RequestDelegateFactoryOptions() { - RouteHandlerFilters = new List() { new ModifyStringResultFilter(), new ModifyStringArgumentFilter() } + RouteHandlerFilterFactories = new List>() + { + (routeHandlerContext, next) => async (context) => + { + var previousResult = await next(context); + if (previousResult is string stringResult) + { + return stringResult.ToUpperInvariant(); + } + return previousResult; + }, + (RouteHandlerContext, next) => async (context) => + { + context.Parameters[0] = context.Parameters[0] != null ? $"{((string)context.Parameters[0]!)}Prefix" : "NULL"; + return await next(context); + } + } }); var requestDelegate = factoryResult.RequestDelegate; await requestDelegate(httpContext); @@ -4749,78 +4919,6 @@ public TlsConnectionFeature(X509Certificate2 clientCertificate) throw new NotImplementedException(); } } - - private class ModifyStringArgumentFilter : IRouteHandlerFilter - { - public async ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next) - { - context.Parameters[0] = context.Parameters[0] != null ? $"{((string)context.Parameters[0]!)}Prefix" : "NULL"; - return await next(context); - } - } - - private class ModifyIntArgumentFilter : IRouteHandlerFilter - { - public async ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next) - { - context.Parameters[1] = ((int)context.Parameters[1]!) + 2; - return await next(context); - } - } - - private class ModifyTodoArgumentFilter : IRouteHandlerFilter - { - public async ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next) - { - Todo originalTodo = (Todo)context.Parameters[0]!; - originalTodo!.IsComplete = !originalTodo.IsComplete; - context.Parameters[0] = originalTodo; - return await next(context); - } - } - - private class ProvideCustomErrorMessageFilter : IRouteHandlerFilter - { - public async ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next) - { - if (context.HttpContext.Response.StatusCode == 400) - { - return Results.Problem("New response", statusCode: 400); - } - return await next(context); - } - } - - private class LogArgumentsFilter : IRouteHandlerFilter - { - private Action _logger; - - public LogArgumentsFilter(Action logger) - { - _logger = logger; - } - public async ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next) - { - foreach (var parameter in context.Parameters) - { - _logger(parameter!.ToString() ?? "no arg"); - } - return await next(context); - } - } - - private class ModifyStringResultFilter : IRouteHandlerFilter - { - public async ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next) - { - var previousResult = await next(context); - if (previousResult is string stringResult) - { - return stringResult.ToUpperInvariant(); - } - return previousResult; - } - } } internal static class TestExtensionResults diff --git a/src/Http/Routing/src/Builder/DelegateRouteHandlerFilter.cs b/src/Http/Routing/src/Builder/DelegateRouteHandlerFilter.cs deleted file mode 100644 index 155a9b5e4b40..000000000000 --- a/src/Http/Routing/src/Builder/DelegateRouteHandlerFilter.cs +++ /dev/null @@ -1,19 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace Microsoft.AspNetCore.Http; - -internal sealed class DelegateRouteHandlerFilter : IRouteHandlerFilter -{ - private readonly Func>, ValueTask> _routeHandlerFilter; - - internal DelegateRouteHandlerFilter(Func>, ValueTask> routeHandlerFilter) - { - _routeHandlerFilter = routeHandlerFilter; - } - - public ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next) - { - return _routeHandlerFilter(context, next); - } -} diff --git a/src/Http/Routing/src/Builder/EndpointRouteBuilderExtensions.cs b/src/Http/Routing/src/Builder/EndpointRouteBuilderExtensions.cs index 6ce2d6c2c7ea..d8990104cb12 100644 --- a/src/Http/Routing/src/Builder/EndpointRouteBuilderExtensions.cs +++ b/src/Http/Routing/src/Builder/EndpointRouteBuilderExtensions.cs @@ -523,7 +523,7 @@ private static RouteHandlerBuilder Map( RouteParameterNames = routeParams, ThrowOnBadRequest = routeHandlerOptions?.Value.ThrowOnBadRequest ?? false, DisableInferBodyFromParameters = disableInferBodyFromParameters, - RouteHandlerFilters = routeHandlerBuilder.RouteHandlerFilters + RouteHandlerFilterFactories = routeHandlerBuilder.RouteHandlerFilterFactories }; var filteredRequestDelegateResult = RequestDelegateFactory.Create(handler, options); // Add request delegate metadata diff --git a/src/Http/Routing/src/Builder/RouteHandlerBuilder.cs b/src/Http/Routing/src/Builder/RouteHandlerBuilder.cs index b42e22cc3d8d..735178ff4214 100644 --- a/src/Http/Routing/src/Builder/RouteHandlerBuilder.cs +++ b/src/Http/Routing/src/Builder/RouteHandlerBuilder.cs @@ -13,7 +13,7 @@ public sealed class RouteHandlerBuilder : IEndpointConventionBuilder private readonly IEnumerable? _endpointConventionBuilders; private readonly IEndpointConventionBuilder? _endpointConventionBuilder; - internal List RouteHandlerFilters { get; } = new(); + internal List> RouteHandlerFilterFactories { get; } = new(); /// /// Instantiates a new given a single diff --git a/src/Http/Routing/src/Builder/RouteHandlerFilterExtensions.cs b/src/Http/Routing/src/Builder/RouteHandlerFilterExtensions.cs index ffec088f3e73..fae8885ad24a 100644 --- a/src/Http/Routing/src/Builder/RouteHandlerFilterExtensions.cs +++ b/src/Http/Routing/src/Builder/RouteHandlerFilterExtensions.cs @@ -3,6 +3,7 @@ using System.Diagnostics.CodeAnalysis; using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; namespace Microsoft.AspNetCore.Http; @@ -19,7 +20,7 @@ public static class RouteHandlerFilterExtensions /// A that can be used to further customize the route handler. public static RouteHandlerBuilder AddFilter(this RouteHandlerBuilder builder, IRouteHandlerFilter filter) { - builder.RouteHandlerFilters.Add(filter); + builder.RouteHandlerFilterFactories.Add((routeHandlerContext, next) => (context) => filter.InvokeAsync(context, next)); return builder; } @@ -29,9 +30,14 @@ public static RouteHandlerBuilder AddFilter(this RouteHandlerBuilder builder, IR /// The type of the to register. /// The . /// A that can be used to further customize the route handler. - public static RouteHandlerBuilder AddFilter<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] TFilterType>(this RouteHandlerBuilder builder) where TFilterType : IRouteHandlerFilter, new() + public static RouteHandlerBuilder AddFilter<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] TFilterType>(this RouteHandlerBuilder builder) where TFilterType : IRouteHandlerFilter { - builder.RouteHandlerFilters.Add(new TFilterType()); + var filterFactory = ActivatorUtilities.CreateFactory(typeof(TFilterType), Type.EmptyTypes); + builder.RouteHandlerFilterFactories.Add((routeHandlerContext, next) => (context) => + { + var filter = (IRouteHandlerFilter)filterFactory.Invoke(context.HttpContext.RequestServices, Array.Empty()); + return filter.InvokeAsync(context, next); + }); return builder; } @@ -41,9 +47,21 @@ public static RouteHandlerBuilder AddFilter(this RouteHandlerBuilder builder, IR /// The . /// A representing the core logic of the filter. /// A that can be used to further customize the route handler. - public static RouteHandlerBuilder AddFilter(this RouteHandlerBuilder builder, Func>, ValueTask> routeHandlerFilter) + public static RouteHandlerBuilder AddFilter(this RouteHandlerBuilder builder, Func> routeHandlerFilter) { - builder.RouteHandlerFilters.Add(new DelegateRouteHandlerFilter(routeHandlerFilter)); + builder.RouteHandlerFilterFactories.Add((routeHandlerContext, next) => (context) => routeHandlerFilter(context, next)); + return builder; + } + + /// + /// Register a filter given a delegate representing the filter factory. + /// + /// The . + /// A representing the logic for constructing the filter. + /// A that can be used to further customize the route handler. + public static RouteHandlerBuilder AddFilter(this RouteHandlerBuilder builder, Func filterFactory) + { + builder.RouteHandlerFilterFactories.Add(filterFactory); return builder; } } diff --git a/src/Http/Routing/src/PublicAPI.Unshipped.txt b/src/Http/Routing/src/PublicAPI.Unshipped.txt index 4c2bc898a0b6..432fe7abd574 100644 --- a/src/Http/Routing/src/PublicAPI.Unshipped.txt +++ b/src/Http/Routing/src/PublicAPI.Unshipped.txt @@ -8,7 +8,8 @@ override Microsoft.AspNetCore.Routing.RouteValuesAddress.ToString() -> string? *REMOVED*~Microsoft.AspNetCore.Routing.DefaultInlineConstraintResolver.DefaultInlineConstraintResolver(Microsoft.Extensions.Options.IOptions! routeOptions, System.IServiceProvider! serviceProvider) -> void Microsoft.AspNetCore.Routing.DefaultInlineConstraintResolver.DefaultInlineConstraintResolver(Microsoft.Extensions.Options.IOptions! routeOptions, System.IServiceProvider! serviceProvider) -> void static Microsoft.AspNetCore.Http.RouteHandlerFilterExtensions.AddFilter(this Microsoft.AspNetCore.Builder.RouteHandlerBuilder! builder, Microsoft.AspNetCore.Http.IRouteHandlerFilter! filter) -> Microsoft.AspNetCore.Builder.RouteHandlerBuilder! -static Microsoft.AspNetCore.Http.RouteHandlerFilterExtensions.AddFilter(this Microsoft.AspNetCore.Builder.RouteHandlerBuilder! builder, System.Func>!, System.Threading.Tasks.ValueTask>! routeHandlerFilter) -> Microsoft.AspNetCore.Builder.RouteHandlerBuilder! +static Microsoft.AspNetCore.Http.RouteHandlerFilterExtensions.AddFilter(this Microsoft.AspNetCore.Builder.RouteHandlerBuilder! builder, System.Func! filterFactory) -> Microsoft.AspNetCore.Builder.RouteHandlerBuilder! +static Microsoft.AspNetCore.Http.RouteHandlerFilterExtensions.AddFilter(this Microsoft.AspNetCore.Builder.RouteHandlerBuilder! builder, System.Func>! routeHandlerFilter) -> Microsoft.AspNetCore.Builder.RouteHandlerBuilder! static Microsoft.AspNetCore.Http.RouteHandlerFilterExtensions.AddFilter(this Microsoft.AspNetCore.Builder.RouteHandlerBuilder! builder) -> Microsoft.AspNetCore.Builder.RouteHandlerBuilder! static Microsoft.AspNetCore.Http.OpenApiRouteHandlerBuilderExtensions.WithDescription(this Microsoft.AspNetCore.Builder.RouteHandlerBuilder! builder, string! description) -> Microsoft.AspNetCore.Builder.RouteHandlerBuilder! static Microsoft.AspNetCore.Http.OpenApiRouteHandlerBuilderExtensions.WithSummary(this Microsoft.AspNetCore.Builder.RouteHandlerBuilder! builder, string! summary) -> Microsoft.AspNetCore.Builder.RouteHandlerBuilder! diff --git a/src/Http/Routing/test/UnitTests/Builder/RouteHandlerEndpointRouteBuilderExtensionsTest.cs b/src/Http/Routing/test/UnitTests/Builder/RouteHandlerEndpointRouteBuilderExtensionsTest.cs index 1c3450159801..b9d4e586076f 100644 --- a/src/Http/Routing/test/UnitTests/Builder/RouteHandlerEndpointRouteBuilderExtensionsTest.cs +++ b/src/Http/Routing/test/UnitTests/Builder/RouteHandlerEndpointRouteBuilderExtensionsTest.cs @@ -6,13 +6,15 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Metadata; using Microsoft.AspNetCore.Routing; +using Microsoft.AspNetCore.Testing; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Microsoft.Extensions.Primitives; namespace Microsoft.AspNetCore.Builder; -public class RouteHandlerEndpointRouteBuilderExtensionsTest +public class RouteHandlerEndpointRouteBuilderExtensionsTest : LoggedTest { private ModelEndpointDataSource GetBuilderEndpointDataSource(IEndpointRouteBuilder endpointRouteBuilder) { @@ -847,6 +849,154 @@ public async Task MapMethod_DefaultsToNotThrowOnBadHttpRequestIfItCannotResolveR Assert.Equal(400, httpContext.Response.StatusCode); } + public static object[][] AddFiltersByClassData = +{ + new object[] { (Action)((RouteHandlerBuilder builder) => builder.AddFilter(new IncrementArgFilter())) }, + new object[] { (Action)((RouteHandlerBuilder builder) => builder.AddFilter()) } + }; + + [Theory] + [MemberData(nameof(AddFiltersByClassData))] + public async Task AddFilterMethods_CanRegisterFilterWithClassImplementation(Action addFilter) + { + var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(new ServiceCollection().BuildServiceProvider())); + + string PrintId(int id) => $"ID: {id}"; + var routeHandlerBuilder = builder.Map("/{id}", PrintId); + addFilter(routeHandlerBuilder); + + var dataSource = GetBuilderEndpointDataSource(builder); + // Trigger Endpoint build by calling getter. + var endpoint = Assert.Single(dataSource.Endpoints); + + var httpContext = new DefaultHttpContext(); + httpContext.Request.RouteValues["id"] = "2"; + var outStream = new MemoryStream(); + httpContext.Response.Body = outStream; + + await endpoint.RequestDelegate!(httpContext); + + // Assert; + var httpResponse = httpContext.Response; + httpResponse.Body.Seek(0, SeekOrigin.Begin); + var streamReader = new StreamReader(httpResponse.Body); + var body = streamReader.ReadToEndAsync().Result; + Assert.Equal(200, httpContext.Response.StatusCode); + Assert.Equal("ID: 3", body); + } + + public static object[][] AddFiltersByDelegateData + { + get + { + void WithFilter(RouteHandlerBuilder builder) => + builder.AddFilter(async (context, next) => + { + context.Parameters[0] = ((int)context.Parameters[0]!) + 1; + return await next(context); + }); + + void WithFilterFactory(RouteHandlerBuilder builder) => + builder.AddFilter((routeHandlerContext, next) => async (context) => + { + Assert.NotNull(routeHandlerContext.MethodInfo); + Assert.NotNull(routeHandlerContext.MethodInfo.DeclaringType); + Assert.Equal("RouteHandlerEndpointRouteBuilderExtensionsTest", routeHandlerContext.MethodInfo.DeclaringType?.Name); + context.Parameters[0] = ((int)context.Parameters[0]!) + 1; + return await next(context); + }); + + return new object[][] { + new object[] { (Action)WithFilter }, + new object[] { (Action)WithFilterFactory } + }; + } + } + + [Theory] + [MemberData(nameof(AddFiltersByDelegateData))] + public async Task AddFilterMethods_CanRegisterFilterWithDelegateImplementation(Action addFilter) + { + var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(new ServiceCollection().BuildServiceProvider())); + + string PrintId(int id) => $"ID: {id}"; + var routeHandlerBuilder = builder.Map("/{id}", PrintId); + addFilter(routeHandlerBuilder); + + var dataSource = GetBuilderEndpointDataSource(builder); + // Trigger Endpoint build by calling getter. + var endpoint = Assert.Single(dataSource.Endpoints); + + var httpContext = new DefaultHttpContext(); + httpContext.Request.RouteValues["id"] = "2"; + var outStream = new MemoryStream(); + httpContext.Response.Body = outStream; + + await endpoint.RequestDelegate!(httpContext); + + // Assert; + var httpResponse = httpContext.Response; + httpResponse.Body.Seek(0, SeekOrigin.Begin); + var streamReader = new StreamReader(httpResponse.Body); + var body = streamReader.ReadToEndAsync().Result; + Assert.Equal(200, httpContext.Response.StatusCode); + Assert.Equal("ID: 3", body); + } + + [Fact] + public async Task RequestDelegateFactory_CanInvokeEndpointFilter_ThatAccessesServices() + { + var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(new ServiceCollection().BuildServiceProvider())); + + string? PrintLogger(HttpContext context) => $"loggerErrorIsEnabled: {context.Items["loggerErrorIsEnabled"]}"; + var routeHandlerBuilder = builder.Map("/", PrintLogger); + routeHandlerBuilder.AddFilter(); + + var dataSource = GetBuilderEndpointDataSource(builder); + // Trigger Endpoint build by calling getter. + var endpoint = Assert.Single(dataSource.Endpoints); + + var httpContext = new DefaultHttpContext(); + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(LoggerFactory); + httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + var outStream = new MemoryStream(); + httpContext.Response.Body = outStream; + await endpoint.RequestDelegate!(httpContext); + + Assert.Equal(200, httpContext.Response.StatusCode); + var httpResponse = httpContext.Response; + httpResponse.Body.Seek(0, SeekOrigin.Begin); + var streamReader = new StreamReader(httpResponse.Body); + var body = streamReader.ReadToEndAsync().Result; + Assert.Equal("loggerErrorIsEnabled: True", body); + } + + class ServiceAccessingRouteHandlerFilter : IRouteHandlerFilter + { + private ILogger _logger; + + public ServiceAccessingRouteHandlerFilter(ILoggerFactory loggerFactory) + { + _logger = loggerFactory.CreateLogger(); + } + + public async ValueTask InvokeAsync(RouteHandlerInvocationContext context, RouteHandlerFilterDelegate next) + { + context.HttpContext.Items["loggerErrorIsEnabled"] = _logger.IsEnabled(LogLevel.Error); + return await next(context); + } + } + + class IncrementArgFilter : IRouteHandlerFilter + { + public async ValueTask InvokeAsync(RouteHandlerInvocationContext context, RouteHandlerFilterDelegate next) + { + context.Parameters[0] = ((int)context.Parameters[0]!) + 1; + return await next(context); + } + } + class FromRoute : Attribute, IFromRouteMetadata { public string? Name { get; set; }