diff --git a/src/Middleware/RateLimiting/src/Features/IRateLimiterContextFeature.cs b/src/Middleware/RateLimiting/src/Features/IRateLimiterContextFeature.cs new file mode 100644 index 000000000000..2e40f8957c61 --- /dev/null +++ b/src/Middleware/RateLimiting/src/Features/IRateLimiterContextFeature.cs @@ -0,0 +1,12 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.Connections.Features; + +namespace Microsoft.AspNetCore.RateLimiting.Features; + +public interface IRateLimiterContextFeature +{ + RateLimiterContext Context { get; } +} + diff --git a/src/Middleware/RateLimiting/src/Features/RateLimiterContext.cs b/src/Middleware/RateLimiting/src/Features/RateLimiterContext.cs new file mode 100644 index 000000000000..c2e4f0c19935 --- /dev/null +++ b/src/Middleware/RateLimiting/src/Features/RateLimiterContext.cs @@ -0,0 +1,18 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading.RateLimiting; +using Microsoft.AspNetCore.Http; + +namespace Microsoft.AspNetCore.RateLimiting.Features; + +public class RateLimiterContext +{ + public required HttpContext HttpContext { get; set; } + + public required RateLimitLease Lease { get; set; } + + public required PartitionedRateLimiter? GlobalLimiter { get; set; } + + public required PartitionedRateLimiter EndpointLimiter { get; set; } +} diff --git a/src/Middleware/RateLimiting/src/Features/RateLimiterContextFeature.cs b/src/Middleware/RateLimiting/src/Features/RateLimiterContextFeature.cs new file mode 100644 index 000000000000..9bbafbd7556a --- /dev/null +++ b/src/Middleware/RateLimiting/src/Features/RateLimiterContextFeature.cs @@ -0,0 +1,22 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.RateLimiting.Features; + +public class RateLimiterContextFeature : IRateLimiterContextFeature +{ + private readonly RateLimiterContext _context; + + public RateLimiterContextFeature(RateLimiterContext context) + { + _context = context; + } + + public RateLimiterContext Context => _context; +} diff --git a/src/Middleware/RateLimiting/src/PublicAPI.Unshipped.txt b/src/Middleware/RateLimiting/src/PublicAPI.Unshipped.txt index 9e18f7083b21..b55b41a92b91 100644 --- a/src/Middleware/RateLimiting/src/PublicAPI.Unshipped.txt +++ b/src/Middleware/RateLimiting/src/PublicAPI.Unshipped.txt @@ -6,6 +6,17 @@ Microsoft.AspNetCore.RateLimiting.DisableRateLimitingAttribute.DisableRateLimiti Microsoft.AspNetCore.RateLimiting.EnableRateLimitingAttribute Microsoft.AspNetCore.RateLimiting.EnableRateLimitingAttribute.EnableRateLimitingAttribute(string! policyName) -> void Microsoft.AspNetCore.RateLimiting.EnableRateLimitingAttribute.PolicyName.get -> string? +Microsoft.AspNetCore.RateLimiting.Features.IRateLimiterContextFeature +Microsoft.AspNetCore.RateLimiting.Features.IRateLimiterContextFeature.Context.get -> Microsoft.AspNetCore.RateLimiting.Features.RateLimiterContext! +Microsoft.AspNetCore.RateLimiting.Features.RateLimiterContext +Microsoft.AspNetCore.RateLimiting.Features.RateLimiterContext.EndpointLimiter.get -> System.Threading.RateLimiting.PartitionedRateLimiter! +Microsoft.AspNetCore.RateLimiting.Features.RateLimiterContext.GlobalLimiter.get -> System.Threading.RateLimiting.PartitionedRateLimiter? +Microsoft.AspNetCore.RateLimiting.Features.RateLimiterContext.HttpContext.get -> Microsoft.AspNetCore.Http.HttpContext! +Microsoft.AspNetCore.RateLimiting.Features.RateLimiterContext.HttpContext.set -> void +Microsoft.AspNetCore.RateLimiting.Features.RateLimiterContext.Lease.get -> System.Threading.RateLimiting.RateLimitLease! +Microsoft.AspNetCore.RateLimiting.Features.RateLimiterContextFeature +Microsoft.AspNetCore.RateLimiting.Features.RateLimiterContextFeature.Context.get -> Microsoft.AspNetCore.RateLimiting.Features.RateLimiterContext! +Microsoft.AspNetCore.RateLimiting.Features.RateLimiterContextFeature.RateLimiterContextFeature(Microsoft.AspNetCore.RateLimiting.Features.RateLimiterContext! context) -> void Microsoft.AspNetCore.RateLimiting.IRateLimiterPolicy Microsoft.AspNetCore.RateLimiting.IRateLimiterPolicy.GetPartition(Microsoft.AspNetCore.Http.HttpContext! httpContext) -> System.Threading.RateLimiting.RateLimitPartition Microsoft.AspNetCore.RateLimiting.IRateLimiterPolicy.OnRejected.get -> System.Func? diff --git a/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs b/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs index d743f77feea6..f45b74ebdbaf 100644 --- a/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs +++ b/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs @@ -3,6 +3,7 @@ using System.Threading.RateLimiting; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.RateLimiting.Features; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; @@ -75,9 +76,25 @@ public Task Invoke(HttpContext context) return InvokeInternal(context, enableRateLimitingAttribute); } + private void AddRateLimiterHttpContextFeature(HttpContext context, LeaseContext lease) + { + var rlContext = new RateLimiterContext() + { + EndpointLimiter = _endpointLimiter, + GlobalLimiter = _globalLimiter, + HttpContext = context, + Lease = lease.Lease + }; + + context.Features.Set(new RateLimiterContextFeature(rlContext)); + } + private async Task InvokeInternal(HttpContext context, EnableRateLimitingAttribute? enableRateLimitingAttribute) { using var leaseContext = await TryAcquireAsync(context); + + AddRateLimiterHttpContextFeature(context, leaseContext); + if (leaseContext.Lease?.IsAcquired == true) { await _next(context); @@ -253,4 +270,4 @@ private static partial class RateLimiterLog [LoggerMessage(3, LogLevel.Debug, "The request was canceled.", EventName = "RequestCanceled")] internal static partial void RequestCanceled(ILogger logger); } -} \ No newline at end of file +} diff --git a/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs b/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs index 24c6ccd93990..b56530442a74 100644 --- a/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs +++ b/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs @@ -4,6 +4,7 @@ using System.Threading.RateLimiting; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.RateLimiting.Features; using Microsoft.AspNetCore.Testing; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -606,6 +607,112 @@ public async Task MultipleEndpointPolicies_LastOneWins() Assert.Equal(StatusCodes.Status403Forbidden, context.Response.StatusCode); } + [Fact] + public async Task RateLimiterContextFeature_PropertiesInitialized() + { + // Arrange + const string policy = "test_policy"; + var options = CreateOptionsAccessor(); + + options.Value.GlobalLimiter = PartitionedRateLimiter.Create(context => + { + return RateLimitPartition.GetConcurrencyLimiter("globalLimiter", key => new ConcurrencyLimiterOptions + { + PermitLimit = 10, + QueueProcessingOrder = QueueProcessingOrder.NewestFirst, + QueueLimit = 5 + }); + }); + options.Value.AddConcurrencyLimiter(policy, configure => + { + configure.QueueLimit = 10; + configure.PermitLimit = 10; + }); + + var middleware = CreateTestRateLimitingMiddleware(options); + var context = new DefaultHttpContext(); + var endpoint = CreateEndpointWithRateLimitPolicy(policy); + context.SetEndpoint(endpoint); + + //Act + await middleware.Invoke(context).DefaultTimeout(); + + // Assert + var rlContext = context.Features.Get(); + + Assert.NotNull(rlContext); + Assert.NotNull(rlContext.Context); + Assert.NotNull(rlContext.Context.GlobalLimiter); + Assert.NotNull(rlContext.Context.EndpointLimiter); + Assert.NotNull(rlContext.Context.HttpContext); + } + + [Fact] + public async Task RateLimiterContextFeature_GetsCorrectGlobalRateLimiterInfo_WithSingleRequest() + { + // Arrange + var options = CreateOptionsAccessor(); + + options.Value.GlobalLimiter = PartitionedRateLimiter.Create(context => + { + return RateLimitPartition.GetConcurrencyLimiter("globalLimiter", key => new ConcurrencyLimiterOptions + { + PermitLimit = 10, + QueueProcessingOrder = QueueProcessingOrder.NewestFirst, + QueueLimit = 5 + }); + }); + + var middleware = CreateTestRateLimitingMiddleware(options); + var context = new DefaultHttpContext(); + + //Act + await middleware.Invoke(context).DefaultTimeout(); + + // Assert + var rlContext = context.Features.Get(); + var statistics = rlContext.Context.GlobalLimiter.GetStatistics(context); + + Assert.NotNull(statistics); + Assert.True(rlContext.Context.Lease.IsAcquired); + Assert.Equal(1, statistics.TotalSuccessfulLeases); + Assert.Equal(0, statistics.TotalFailedLeases); + Assert.Equal(10, statistics.CurrentAvailablePermits); + } + + [Fact] + public async Task GetsCorrectEndpointRateLimiterInfo_With3Requests() + { + // Arrange + const string policy = "test_policy"; + var options = CreateOptionsAccessor(); + + options.Value.AddConcurrencyLimiter(policy, configure => + { + configure.QueueLimit = 10; + configure.PermitLimit = 10; + }); + + var middleware = CreateTestRateLimitingMiddleware(options); + + var context = new DefaultHttpContext(); + var endpoint = CreateEndpointWithRateLimitPolicy(policy); + context.SetEndpoint(endpoint); + + // Act + await middleware.Invoke(context).DefaultTimeout(); + await middleware.Invoke(context).DefaultTimeout(); + await middleware.Invoke(context).DefaultTimeout(); + + // Assert + var rlContext = context.Features.Get(); + var statistics = rlContext.Context.EndpointLimiter.GetStatistics(context); + + Assert.NotNull(statistics); + Assert.True(rlContext.Context.Lease.IsAcquired); + Assert.True(statistics.TotalSuccessfulLeases == 3); + } + private Endpoint CreateEndpointWithRateLimitPolicy(IRateLimiterPolicy policy) { var endpointBuilder = new TestEndpointBuilder();