diff --git a/src/DefaultBuilder/src/WebApplicationBuilder.cs b/src/DefaultBuilder/src/WebApplicationBuilder.cs index dde62750af8d..a6753d867034 100644 --- a/src/DefaultBuilder/src/WebApplicationBuilder.cs +++ b/src/DefaultBuilder/src/WebApplicationBuilder.cs @@ -18,6 +18,7 @@ namespace Microsoft.AspNetCore.Builder public sealed class WebApplicationBuilder { private const string EndpointRouteBuilderKey = "__EndpointRouteBuilder"; + private const string WebApplicationBuilderKey = "__WebApplicationBuilder"; private readonly HostBuilder _hostBuilder = new(); private readonly BootstrapHostBuilder _bootstrapHostBuilder; @@ -170,6 +171,8 @@ public WebApplication Build() ((IConfigurationBuilder)Configuration).Sources.Clear(); Configuration.AddConfiguration(_builtApplication.Configuration); + _builtApplication.Properties[WebApplicationBuilderKey] = true; + // Mark the service collection as read-only to prevent future modifications _services.IsReadOnly = true; @@ -206,6 +209,10 @@ private void ConfigureApplication(WebHostBuilderContext context, IApplicationBui // An implicitly created IEndpointRouteBuilder was addeded to app.Properties by the UseRouting() call above. targetRouteBuilder = GetEndpointRouteBuilder(app)!; + + // Copy the endpoint route builder to the built application + _builtApplication.Properties[EndpointRouteBuilderKey] = targetRouteBuilder; + implicitRouting = true; } diff --git a/src/Http/Routing/src/Builder/EndpointRoutingApplicationBuilderExtensions.cs b/src/Http/Routing/src/Builder/EndpointRoutingApplicationBuilderExtensions.cs index 07466e2d16a2..8d823da32dc3 100644 --- a/src/Http/Routing/src/Builder/EndpointRoutingApplicationBuilderExtensions.cs +++ b/src/Http/Routing/src/Builder/EndpointRoutingApplicationBuilderExtensions.cs @@ -14,7 +14,7 @@ namespace Microsoft.AspNetCore.Builder /// public static class EndpointRoutingApplicationBuilderExtensions { - private const string EndpointRouteBuilder = "__EndpointRouteBuilder"; + internal const string EndpointRouteBuilder = "__EndpointRouteBuilder"; /// /// Adds a middleware to the specified . @@ -42,10 +42,48 @@ public static IApplicationBuilder UseRouting(this IApplicationBuilder builder) throw new ArgumentNullException(nameof(builder)); } + return UseRouting(builder, true); + } + + /// + /// Adds a middleware to the specified . + /// + /// The to add the middleware to. + /// Whether a new should be created. + /// A reference to this instance after the operation has completed. + /// + /// + /// A call to must be followed by a call to + /// for the same + /// instance. + /// + /// + /// The defines a point in the middleware pipeline where routing decisions are + /// made, and an is associated with the . The + /// defines a point in the middleware pipeline where the current is executed. Middleware between + /// the and may observe or change the + /// associated with the . + /// + /// + public static IApplicationBuilder UseRouting(this IApplicationBuilder builder, bool overrideEndpointRouteBuilder) + { + if (builder == null) + { + throw new ArgumentNullException(nameof(builder)); + } + VerifyRoutingServicesAreRegistered(builder); - var endpointRouteBuilder = new DefaultEndpointRouteBuilder(builder); - builder.Properties[EndpointRouteBuilder] = endpointRouteBuilder; + IEndpointRouteBuilder endpointRouteBuilder; + if (overrideEndpointRouteBuilder || !builder.Properties.TryGetValue(EndpointRouteBuilder, out var routeBuilder)) + { + endpointRouteBuilder = new DefaultEndpointRouteBuilder(builder); + builder.Properties[EndpointRouteBuilder] = endpointRouteBuilder; + } + else + { + endpointRouteBuilder = (IEndpointRouteBuilder)routeBuilder!; + } return builder.UseMiddleware(endpointRouteBuilder); } diff --git a/src/Http/Routing/src/PublicAPI.Unshipped.txt b/src/Http/Routing/src/PublicAPI.Unshipped.txt index 99913da6a3da..f397faa061c6 100644 --- a/src/Http/Routing/src/PublicAPI.Unshipped.txt +++ b/src/Http/Routing/src/PublicAPI.Unshipped.txt @@ -14,6 +14,7 @@ Microsoft.AspNetCore.Routing.IDataTokensMetadata.DataTokens.get -> System.Collec Microsoft.AspNetCore.Routing.IRouteNameMetadata.RouteName.get -> string? Microsoft.AspNetCore.Routing.RouteNameMetadata.RouteName.get -> string? Microsoft.AspNetCore.Routing.RouteNameMetadata.RouteNameMetadata(string? routeName) -> void +static Microsoft.AspNetCore.Builder.EndpointRoutingApplicationBuilderExtensions.UseRouting(this Microsoft.AspNetCore.Builder.IApplicationBuilder! builder, bool overrideEndpointRouteBuilder) -> Microsoft.AspNetCore.Builder.IApplicationBuilder! static Microsoft.AspNetCore.Builder.MinimalActionEndpointRouteBuilderExtensions.Map(this Microsoft.AspNetCore.Routing.IEndpointRouteBuilder! endpoints, Microsoft.AspNetCore.Routing.Patterns.RoutePattern! pattern, System.Delegate! action) -> Microsoft.AspNetCore.Builder.MinimalActionEndpointConventionBuilder! static Microsoft.AspNetCore.Builder.MinimalActionEndpointRouteBuilderExtensions.Map(this Microsoft.AspNetCore.Routing.IEndpointRouteBuilder! endpoints, string! pattern, System.Delegate! action) -> Microsoft.AspNetCore.Builder.MinimalActionEndpointConventionBuilder! static Microsoft.AspNetCore.Builder.MinimalActionEndpointRouteBuilderExtensions.MapDelete(this Microsoft.AspNetCore.Routing.IEndpointRouteBuilder! endpoints, string! pattern, System.Delegate! action) -> Microsoft.AspNetCore.Builder.MinimalActionEndpointConventionBuilder! diff --git a/src/Http/Routing/test/UnitTests/Builder/EndpointRoutingApplicationBuilderExtensionsTest.cs b/src/Http/Routing/test/UnitTests/Builder/EndpointRoutingApplicationBuilderExtensionsTest.cs index 4b9d875d0a1c..28248c95fab3 100644 --- a/src/Http/Routing/test/UnitTests/Builder/EndpointRoutingApplicationBuilderExtensionsTest.cs +++ b/src/Http/Routing/test/UnitTests/Builder/EndpointRoutingApplicationBuilderExtensionsTest.cs @@ -107,6 +107,72 @@ public async Task UseRouting_ServicesRegistered_Match_DoesNotSetsFeature() Assert.Same(endpoint, httpContext.GetEndpoint()); } + [Fact] + public void UseRouting_Default_CreatesEndpointRouteBuilder() + { + // Arrange + var services = CreateServices(); + var app = new ApplicationBuilder(services); + + // Assert + Assert.False(app.Properties.ContainsKey(EndpointRoutingApplicationBuilderExtensions.EndpointRouteBuilder)); + + // Act + app.UseRouting(); + + // Assert + Assert.NotNull(app.Properties[EndpointRoutingApplicationBuilderExtensions.EndpointRouteBuilder]); + } + + [Fact] + public void UseRouting_Default_OverridesEndpointRouteBuilder() + { + // Arrange + var services = CreateServices(); + var app = new ApplicationBuilder(services); + var endpointRouteBuilder = new DefaultEndpointRouteBuilder(app); + app.Properties[EndpointRoutingApplicationBuilderExtensions.EndpointRouteBuilder] = endpointRouteBuilder; + + // Act + app.UseRouting(); + + // Assert + Assert.NotSame(endpointRouteBuilder, app.Properties[EndpointRoutingApplicationBuilderExtensions.EndpointRouteBuilder]); + } + + [Fact] + public void UseRouting_OverrideEndpointRouteBuilderFalse_CreatesEndpointRouteBuilderIfNotFound() + { + // Arrange + var services = CreateServices(); + var app = new ApplicationBuilder(services); + + // Assert + Assert.False(app.Properties.ContainsKey(EndpointRoutingApplicationBuilderExtensions.EndpointRouteBuilder)); + + // Act + app.UseRouting(overrideEndpointRouteBuilder: false); + + // Assert + Assert.NotNull(app.Properties[EndpointRoutingApplicationBuilderExtensions.EndpointRouteBuilder]); + } + + [Fact] + public void UseRouting_OverrideEndpointRouteBuilderFalse_UsesExistingEndpointRouteBuilderIfFound() + { + // Arrange + var services = CreateServices(); + var app = new ApplicationBuilder(services); + var endpointRouteBuilder = new DefaultEndpointRouteBuilder(app); + app.Properties[EndpointRoutingApplicationBuilderExtensions.EndpointRouteBuilder] = endpointRouteBuilder; + + // Act + app.UseRouting(overrideEndpointRouteBuilder: false); + + // Assert + Assert.Same(endpointRouteBuilder, app.Properties[EndpointRoutingApplicationBuilderExtensions.EndpointRouteBuilder]); + } + [Fact] public void UseEndpoint_WithoutEndpointRoutingMiddleware_Throws() { diff --git a/src/Middleware/Diagnostics/src/ExceptionHandler/ExceptionHandlerExtensions.cs b/src/Middleware/Diagnostics/src/ExceptionHandler/ExceptionHandlerExtensions.cs index 996cd4703ef9..7dc1a787aa69 100644 --- a/src/Middleware/Diagnostics/src/ExceptionHandler/ExceptionHandlerExtensions.cs +++ b/src/Middleware/Diagnostics/src/ExceptionHandler/ExceptionHandlerExtensions.cs @@ -2,8 +2,12 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Diagnostics; using Microsoft.AspNetCore.Diagnostics; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; namespace Microsoft.AspNetCore.Builder @@ -17,6 +21,9 @@ public static class ExceptionHandlerExtensions /// Adds a middleware to the pipeline that will catch exceptions, log them, and re-execute the request in an alternate pipeline. /// The request will not be re-executed if the response has already started. /// + /// + /// This overload requires you to configure options with . + /// /// /// public static IApplicationBuilder UseExceptionHandler(this IApplicationBuilder app) @@ -95,7 +102,33 @@ public static IApplicationBuilder UseExceptionHandler(this IApplicationBuilder a throw new ArgumentNullException(nameof(options)); } - return app.UseMiddleware(Options.Create(options)); + // UseRouting called before this middleware or Minimal + if (app.Properties.ContainsKey("__EndpointRouteBuilder") || app.Properties.ContainsKey("__WebApplicationBuilder")) + { + return app.Use(next => + { + var loggerFactory = app.ApplicationServices.GetRequiredService(); + var diagnosticListener = app.ApplicationServices.GetRequiredService(); + + if (!string.IsNullOrEmpty(options.ExceptionHandlingPath) && options.ExceptionHandler is null) + { + // start a new middleware pipeline + var builder = app.New(); + // use the old routing pipeline if it exists so we preserve all the routes and matching logic + builder.UseRouting(overrideEndpointRouteBuilder: false); + // apply the next middleware + builder.Run(next); + // store the pipeline for the error case + options.ExceptionHandler = builder.Build(); + } + + return new ExceptionHandlerMiddleware(next, loggerFactory, Options.Create(options), diagnosticListener).Invoke; + }); + } + else + { + return app.UseMiddleware(Options.Create(options)); + } } } } diff --git a/src/Middleware/Diagnostics/test/UnitTests/ExceptionHandlerTest.cs b/src/Middleware/Diagnostics/test/UnitTests/ExceptionHandlerTest.cs index 3e65cf7a4a39..12e5ec09908d 100644 --- a/src/Middleware/Diagnostics/test/UnitTests/ExceptionHandlerTest.cs +++ b/src/Middleware/Diagnostics/test/UnitTests/ExceptionHandlerTest.cs @@ -657,5 +657,45 @@ public async Task ExceptionHandler_CanReturn404Responses_WhenAllowed() && w.EventId == 4 && w.Message == "No exception handler was found, rethrowing original exception."); } + + [Fact] + public async Task ExceptionHandler_RerunsRoutingOnError_WhenConfiguredWithExceptionHandlingPath() + { + using var host = new HostBuilder() + .ConfigureWebHost(webHostBuilder => + { + webHostBuilder + .UseTestServer() + .ConfigureServices(services => + { + services.AddRouting(); + }) + .Configure(app => + { + app.UseRouting(); + + app.UseExceptionHandler("/error"); + + app.UseEndpoints(endpoints => + { + endpoints.MapGet("/error", context => + { + return context.Response.WriteAsync("Exception handled"); + }); + endpoints.MapGet("/throw", context => throw new InvalidOperationException("Something bad happened.")); + }); + }); + }).Build(); + + await host.StartAsync(); + + using (var server = host.GetTestServer()) + { + var client = server.CreateClient(); + var response = await client.GetAsync("throw"); + Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode); + Assert.Equal("Exception handled", await response.Content.ReadAsStringAsync()); + } + } } } diff --git a/src/Middleware/Rewrite/src/Microsoft.AspNetCore.Rewrite.csproj b/src/Middleware/Rewrite/src/Microsoft.AspNetCore.Rewrite.csproj index 7ea6d0556c2c..f2d2f2233038 100644 --- a/src/Middleware/Rewrite/src/Microsoft.AspNetCore.Rewrite.csproj +++ b/src/Middleware/Rewrite/src/Microsoft.AspNetCore.Rewrite.csproj @@ -20,6 +20,7 @@ + diff --git a/src/Middleware/Rewrite/src/RewriteBuilderExtensions.cs b/src/Middleware/Rewrite/src/RewriteBuilderExtensions.cs index 874fc0f68531..877dce552d42 100644 --- a/src/Middleware/Rewrite/src/RewriteBuilderExtensions.cs +++ b/src/Middleware/Rewrite/src/RewriteBuilderExtensions.cs @@ -45,6 +45,25 @@ public static IApplicationBuilder UseRewriter(this IApplicationBuilder app, Rewr throw new ArgumentNullException(nameof(options)); } + // UseRouting called before this middleware or Minimal + if (app.Properties.ContainsKey("__EndpointRouteBuilder") || app.Properties.ContainsKey("__WebApplicationBuilder")) + { + return app.Use(next => + { + // start a new middleware pipeline + var sub = app.New(); + // insert the rewrite middleware before routing so any path changes will be matched correctly + sub.UseMiddleware(Options.Create(options)); + // use the old routing pipeline if it exists so we preserve all the routes and matching logic + sub.UseRouting(overrideEndpointRouteBuilder: false); + // apply the next middleware + sub.Run(next); + // return the modified middleware + var nextWithRewriteAndRouting = sub.Build(); + return nextWithRewriteAndRouting.Invoke; + }); + } + // put middleware in pipeline return app.UseMiddleware(Options.Create(options)); } diff --git a/src/Middleware/Rewrite/src/RewriteMiddleware.cs b/src/Middleware/Rewrite/src/RewriteMiddleware.cs index bfb28fc0fe06..ca303c81da92 100644 --- a/src/Middleware/Rewrite/src/RewriteMiddleware.cs +++ b/src/Middleware/Rewrite/src/RewriteMiddleware.cs @@ -73,6 +73,12 @@ public Task Invoke(HttpContext context) Result = RuleResult.ContinueRules }; + // TODO: only do this if a rule applies + if (context.GetEndpoint() is not null) + { + context.SetEndpoint(null); + } + foreach (var rule in _options.Rules) { rule.ApplyRule(rewriteContext); diff --git a/src/Middleware/Rewrite/test/MiddlewareTests.cs b/src/Middleware/Rewrite/test/MiddlewareTests.cs index cd38a0a00963..85fa1935b769 100644 --- a/src/Middleware/Rewrite/test/MiddlewareTests.cs +++ b/src/Middleware/Rewrite/test/MiddlewareTests.cs @@ -680,5 +680,46 @@ public async Task CheckRedirectToWwwWithStatusCodeInWhitelistedDomains(int statu Assert.Equal(statusCode, (int)response.StatusCode); } + [Fact] + public async Task Rewrite_RerunsRouting_WhenConfiguredAfterRouting() + { + var options = new RewriteOptions().AddRewrite("(.*)", "http://example.com/g", skipRemainingRules: false); + using var host = new HostBuilder() + .ConfigureWebHost(webHostBuilder => + { + webHostBuilder + .UseTestServer() + .ConfigureServices(services => + { + services.AddRouting(); + }) + .Configure(app => + { + app.UseRouting(); + app.UseRewriter(options); + + app.UseEndpoints(endpoints => + { + endpoints.MapGet("/foo", context => context.Response.WriteAsync( + "bad")); + + endpoints.MapGet("/g", context => context.Response.WriteAsync( + context.Request.Scheme + + "://" + + context.Request.Host + + context.Request.Path + + context.Request.QueryString)); + }); + }); + }).Build(); + + await host.StartAsync(); + + var server = host.GetTestServer(); + + var response = await server.CreateClient().GetStringAsync("foo"); + + Assert.Equal("http://example.com/g", response); + } } }