diff --git a/src/Middleware/RateLimiting/src/DefaultRateLimiterStatisticsFeature.cs b/src/Middleware/RateLimiting/src/DefaultRateLimiterStatisticsFeature.cs new file mode 100644 index 000000000000..9902ac8d73ba --- /dev/null +++ b/src/Middleware/RateLimiting/src/DefaultRateLimiterStatisticsFeature.cs @@ -0,0 +1,28 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading.RateLimiting; +using Microsoft.AspNetCore.Http; + +namespace Microsoft.AspNetCore.RateLimiting; + +internal sealed class DefaultRateLimiterStatisticsFeature : IRateLimiterStatisticsFeature +{ + private readonly PartitionedRateLimiter? _globalLimiter; + private readonly PartitionedRateLimiter _endpointLimiter; + private readonly HttpContext _httpContext; + + public DefaultRateLimiterStatisticsFeature( + PartitionedRateLimiter? globalLimiter, + PartitionedRateLimiter endpointLimiter, + HttpContext context) + { + _globalLimiter = globalLimiter; + _endpointLimiter = endpointLimiter; + _httpContext = context; + } + + public RateLimiterStatistics? GetEndpointStatistics() => _endpointLimiter.GetStatistics(_httpContext); + + public RateLimiterStatistics? GetGlobalStatistics() => _globalLimiter?.GetStatistics(_httpContext); +} diff --git a/src/Middleware/RateLimiting/src/IRateLimiterStatisticsFeature.cs b/src/Middleware/RateLimiting/src/IRateLimiterStatisticsFeature.cs new file mode 100644 index 000000000000..c79753554ecd --- /dev/null +++ b/src/Middleware/RateLimiting/src/IRateLimiterStatisticsFeature.cs @@ -0,0 +1,29 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading.RateLimiting; +using Microsoft.AspNetCore.Http; + +namespace Microsoft.AspNetCore.RateLimiting; + +/// +/// An Interface which is used to represent statistics methods for global and endpoint limiters. +/// Obtained via . +/// +/// +/// Requires to be true. +/// +public interface IRateLimiterStatisticsFeature +{ + /// + /// Method to fetch for the global + /// + /// for the global . + RateLimiterStatistics? GetGlobalStatistics(); + /// + /// Method to fetch for the endpoints + /// + /// for the endpoints . + RateLimiterStatistics? GetEndpointStatistics(); + +} diff --git a/src/Middleware/RateLimiting/src/PublicAPI.Unshipped.txt b/src/Middleware/RateLimiting/src/PublicAPI.Unshipped.txt index 7dc5c58110bf..63f0e01cea3f 100644 --- a/src/Middleware/RateLimiting/src/PublicAPI.Unshipped.txt +++ b/src/Middleware/RateLimiting/src/PublicAPI.Unshipped.txt @@ -1 +1,6 @@ #nullable enable +Microsoft.AspNetCore.RateLimiting.IRateLimiterStatisticsFeature +Microsoft.AspNetCore.RateLimiting.IRateLimiterStatisticsFeature.GetEndpointStatistics() -> System.Threading.RateLimiting.RateLimiterStatistics? +Microsoft.AspNetCore.RateLimiting.IRateLimiterStatisticsFeature.GetGlobalStatistics() -> System.Threading.RateLimiting.RateLimiterStatistics? +Microsoft.AspNetCore.RateLimiting.RateLimiterOptions.TrackStatistics.get -> bool +Microsoft.AspNetCore.RateLimiting.RateLimiterOptions.TrackStatistics.set -> void diff --git a/src/Middleware/RateLimiting/src/RateLimiterOptions.cs b/src/Middleware/RateLimiting/src/RateLimiterOptions.cs index 9872e1ca788a..8aa1b75ca84c 100644 --- a/src/Middleware/RateLimiting/src/RateLimiterOptions.cs +++ b/src/Middleware/RateLimiting/src/RateLimiterOptions.cs @@ -40,6 +40,14 @@ public sealed class RateLimiterOptions /// public int RejectionStatusCode { get; set; } = StatusCodes.Status503ServiceUnavailable; + /// + /// Gets or sets whether to track global and endpoint . + /// + /// + /// If enabled, adds to . + /// + public bool TrackStatistics { get; set; } + /// /// Adds a new rate limiting policy with the given /// diff --git a/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs b/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs index d743f77feea6..b32eee04c289 100644 --- a/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs +++ b/src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs @@ -19,6 +19,7 @@ internal sealed partial class RateLimitingMiddleware private readonly PartitionedRateLimiter? _globalLimiter; private readonly PartitionedRateLimiter _endpointLimiter; private readonly int _rejectionStatusCode; + private readonly bool _trackStatistics; private readonly Dictionary _policyMap; private readonly DefaultKeyType _defaultPolicyKey = new DefaultKeyType("__defaultPolicy", new PolicyNameKey { PolicyName = "__defaultPolicyKey" }); @@ -39,6 +40,7 @@ public RateLimitingMiddleware(RequestDelegate next, ILogger(options.Value.PolicyMap); // Activate policies passed to AddPolicy @@ -49,7 +51,6 @@ public RateLimitingMiddleware(RequestDelegate next, ILogger CreateEndpointLimiter() }, new DefaultKeyTypeEqualityComparer()); } + private void AddRateLimiterStatisticsFeature(HttpContext context) + { + context.Features.Set(new DefaultRateLimiterStatisticsFeature(_globalLimiter, _endpointLimiter, context)); + } + private static partial class RateLimiterLog { [LoggerMessage(1, LogLevel.Debug, "Rate limits exceeded, rejecting this request.", EventName = "RequestRejectedLimitsExceeded")] @@ -253,4 +265,4 @@ private static partial class RateLimiterLog [LoggerMessage(3, LogLevel.Debug, "The request was canceled.", EventName = "RequestCanceled")] internal static partial void RequestCanceled(ILogger logger); } -} \ No newline at end of file +} diff --git a/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs b/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs index 24c6ccd93990..0cf985fdbd4f 100644 --- a/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs +++ b/src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs @@ -606,6 +606,75 @@ public async Task MultipleEndpointPolicies_LastOneWins() Assert.Equal(StatusCodes.Status403Forbidden, context.Response.StatusCode); } + [Fact] + public async Task StatisticsFeature_NotTracked() + { + // Arrange + var options = CreateOptionsAccessor(); + var policy = new TestRateLimiterPolicy("myKey1", 404, false); + + var middleware = CreateTestRateLimitingMiddleware(options); + + var context = new DefaultHttpContext(); + var endpoint = CreateEndpointWithRateLimitPolicy(policy); + context.SetEndpoint(endpoint); + + // Act + await middleware.Invoke(context).DefaultTimeout(); + + // Assert + Assert.Null(context.Features.Get()); + } + + [Fact] + public async Task StatisticsFeature_SuccessfullyTracked() + { + // Arrange + var options = CreateOptionsAccessor(trackStatistics: true); + + var policy = new TestRateLimiterPolicy("myKey1", 404, false); + + var middleware = CreateTestRateLimitingMiddleware(options); + + var context = new DefaultHttpContext(); + var endpoint = CreateEndpointWithRateLimitPolicy(policy); + context.SetEndpoint(endpoint); + + // Act + await middleware.Invoke(context).DefaultTimeout(); + + // Assert + Assert.NotNull(context.Features.Get()); + } + + [Fact] + public async Task StatisticsFeature_GetsStatistics_ForGlobalAndEndpointLimiter() + { + // Arrange + var options = CreateOptionsAccessor(trackStatistics: true); + + var globalStatistics = new RateLimiterStatistics(); + var endpointStatistics = new RateLimiterStatistics(); + + var policy = new TestRateLimiterPolicy("myKey1", 404, true, endpointStatistics); + options.Value.GlobalLimiter = new TestPartitionedRateLimiter(new TestRateLimiter(false), globalStatistics); + + var middleware = CreateTestRateLimitingMiddleware(options); + + var context = new DefaultHttpContext(); + var endpoint = CreateEndpointWithRateLimitPolicy(policy); + context.SetEndpoint(endpoint); + + // Act + await middleware.Invoke(context).DefaultTimeout(); + + // Assert + var statisticsFeature = context.Features.Get(); + + Assert.Equal(endpointStatistics, statisticsFeature.GetEndpointStatistics()); + Assert.Equal(globalStatistics, statisticsFeature.GetGlobalStatistics()); + } + private Endpoint CreateEndpointWithRateLimitPolicy(IRateLimiterPolicy policy) { var endpointBuilder = new TestEndpointBuilder(); @@ -639,5 +708,8 @@ private RateLimitingMiddleware CreateTestRateLimitingMiddleware(IOptions()); - private IOptions CreateOptionsAccessor() => Options.Create(new RateLimiterOptions()); + private IOptions CreateOptionsAccessor(bool trackStatistics = false) => Options.Create(new RateLimiterOptions() + { + TrackStatistics = trackStatistics + }); } diff --git a/src/Middleware/RateLimiting/test/TestPartitionedRateLimiter.cs b/src/Middleware/RateLimiting/test/TestPartitionedRateLimiter.cs index fa19f79779b7..791cd82870a3 100644 --- a/src/Middleware/RateLimiting/test/TestPartitionedRateLimiter.cs +++ b/src/Middleware/RateLimiting/test/TestPartitionedRateLimiter.cs @@ -13,12 +13,14 @@ namespace Microsoft.AspNetCore.RateLimiting; internal class TestPartitionedRateLimiter : PartitionedRateLimiter { private List limiters = new List(); + private RateLimiterStatistics _statistics; public TestPartitionedRateLimiter() { } - public TestPartitionedRateLimiter(RateLimiter limiter) + public TestPartitionedRateLimiter(RateLimiter limiter, RateLimiterStatistics statistics = null) { limiters.Add(limiter); + _statistics = statistics; } public void AddLimiter(RateLimiter limiter) @@ -28,7 +30,7 @@ public void AddLimiter(RateLimiter limiter) public override RateLimiterStatistics GetStatistics(TResource resourceID) { - throw new NotImplementedException(); + return _statistics; } protected override RateLimitLease AttemptAcquireCore(TResource resourceID, int permitCount) diff --git a/src/Middleware/RateLimiting/test/TestRateLimiter.cs b/src/Middleware/RateLimiting/test/TestRateLimiter.cs index 2ec0ace3ae66..0643a0bb0699 100644 --- a/src/Middleware/RateLimiting/test/TestRateLimiter.cs +++ b/src/Middleware/RateLimiting/test/TestRateLimiter.cs @@ -8,17 +8,19 @@ namespace Microsoft.AspNetCore.RateLimiting; internal class TestRateLimiter : RateLimiter { private readonly bool _alwaysAccept; + private RateLimiterStatistics _statistics; - public TestRateLimiter(bool alwaysAccept) + public TestRateLimiter(bool alwaysAccept, RateLimiterStatistics statistics = null) { _alwaysAccept = alwaysAccept; + _statistics = statistics; } public override TimeSpan? IdleDuration => throw new NotImplementedException(); public override RateLimiterStatistics GetStatistics() { - throw new NotImplementedException(); + return _statistics; } protected override RateLimitLease AttemptAcquireCore(int permitCount) diff --git a/src/Middleware/RateLimiting/test/TestRateLimiterPolicy.cs b/src/Middleware/RateLimiting/test/TestRateLimiterPolicy.cs index 49a7cd7845a7..ce80b044b25f 100644 --- a/src/Middleware/RateLimiting/test/TestRateLimiterPolicy.cs +++ b/src/Middleware/RateLimiting/test/TestRateLimiterPolicy.cs @@ -10,11 +10,13 @@ internal class TestRateLimiterPolicy : IRateLimiterPolicy private readonly string _key; private readonly bool _alwaysAccept; private readonly Func _onRejected; + private readonly RateLimiterStatistics _statistics; - public TestRateLimiterPolicy(string key, int statusCode, bool alwaysAccept) + public TestRateLimiterPolicy(string key, int statusCode, bool alwaysAccept, RateLimiterStatistics statistics = null) { _key = key; _alwaysAccept = alwaysAccept; + _statistics = statistics; _onRejected = (context, token) => { @@ -29,7 +31,7 @@ public RateLimitPartition GetPartition(HttpContext httpContext) { return RateLimitPartition.Get(_key, (key => { - return new TestRateLimiter(_alwaysAccept); + return new TestRateLimiter(_alwaysAccept, _statistics); })); } }