Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Threading.RateLimiting;
using Microsoft.AspNetCore.Http;

namespace Microsoft.AspNetCore.RateLimiting;

internal sealed class DefaultRateLimiterStatisticsFeature : IRateLimiterStatisticsFeature
{
private readonly PartitionedRateLimiter<HttpContext>? _globalLimiter;
private readonly PartitionedRateLimiter<HttpContext> _endpointLimiter;
private readonly HttpContext _httpContext;

public DefaultRateLimiterStatisticsFeature(
PartitionedRateLimiter<HttpContext>? globalLimiter,
PartitionedRateLimiter<HttpContext> endpointLimiter,
HttpContext context)
{
_globalLimiter = globalLimiter;
_endpointLimiter = endpointLimiter;
_httpContext = context;
}

public RateLimiterStatistics? GetEndpointStatistics() => _endpointLimiter.GetStatistics(_httpContext);

public RateLimiterStatistics? GetGlobalStatistics() => _globalLimiter?.GetStatistics(_httpContext);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Threading.RateLimiting;
using Microsoft.AspNetCore.Http;

namespace Microsoft.AspNetCore.RateLimiting;

/// <summary>
/// An Interface which is used to represent <see cref="RateLimiter"/> statistics methods for global and endpoint limiters.
/// Obtained via <see cref="HttpContext.Features"/>.
/// </summary>
/// <remarks>
/// Requires <see cref="RateLimiterOptions.TrackStatistics"/> to be true.
/// </remarks>
public interface IRateLimiterStatisticsFeature
{
/// <summary>
/// Method to fetch <see cref="RateLimiterStatistics"/> for the global <see cref="PartitionedRateLimiter"/>
/// </summary>
/// <returns><see cref="RateLimiterStatistics"/> for the global <see cref="PartitionedRateLimiter"/>.</returns>
RateLimiterStatistics? GetGlobalStatistics();
/// <summary>
/// Method to fetch <see cref="RateLimiterStatistics"/> for the endpoints <see cref="PartitionedRateLimiter"/>
/// </summary>
/// <returns><see cref="RateLimiterStatistics"/> for the endpoints <see cref="PartitionedRateLimiter"/>.</returns>
RateLimiterStatistics? GetEndpointStatistics();

}
5 changes: 5 additions & 0 deletions src/Middleware/RateLimiting/src/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
#nullable enable
Microsoft.AspNetCore.RateLimiting.IRateLimiterStatisticsFeature
Microsoft.AspNetCore.RateLimiting.IRateLimiterStatisticsFeature.GetEndpointStatistics() -> System.Threading.RateLimiting.RateLimiterStatistics?
Microsoft.AspNetCore.RateLimiting.IRateLimiterStatisticsFeature.GetGlobalStatistics() -> System.Threading.RateLimiting.RateLimiterStatistics?
Microsoft.AspNetCore.RateLimiting.RateLimiterOptions.TrackStatistics.get -> bool
Microsoft.AspNetCore.RateLimiting.RateLimiterOptions.TrackStatistics.set -> void
8 changes: 8 additions & 0 deletions src/Middleware/RateLimiting/src/RateLimiterOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ public sealed class RateLimiterOptions
/// </remarks>
public int RejectionStatusCode { get; set; } = StatusCodes.Status503ServiceUnavailable;

/// <summary>
/// Gets or sets whether to track global and endpoint <see cref="RateLimiterStatistics"/>.
/// </summary>
/// <remarks>
/// If enabled, adds <see cref="IRateLimiterStatisticsFeature"/> to <see cref="HttpContext.Features"/>.
/// </remarks>
public bool TrackStatistics { get; set; }

/// <summary>
/// Adds a new rate limiting policy with the given <paramref name="policyName"/>
/// </summary>
Expand Down
16 changes: 14 additions & 2 deletions src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ internal sealed partial class RateLimitingMiddleware
private readonly PartitionedRateLimiter<HttpContext>? _globalLimiter;
private readonly PartitionedRateLimiter<HttpContext> _endpointLimiter;
private readonly int _rejectionStatusCode;
private readonly bool _trackStatistics;
private readonly Dictionary<string, DefaultRateLimiterPolicy> _policyMap;
private readonly DefaultKeyType _defaultPolicyKey = new DefaultKeyType("__defaultPolicy", new PolicyNameKey { PolicyName = "__defaultPolicyKey" });

Expand All @@ -39,6 +40,7 @@ public RateLimitingMiddleware(RequestDelegate next, ILogger<RateLimitingMiddlewa
_logger = logger;
_defaultOnRejected = options.Value.OnRejected;
_rejectionStatusCode = options.Value.RejectionStatusCode;
_trackStatistics = options.Value.TrackStatistics;
_policyMap = new Dictionary<string, DefaultRateLimiterPolicy>(options.Value.PolicyMap);

// Activate policies passed to AddPolicy<TPartitionKey, TPolicy>
Expand All @@ -49,7 +51,6 @@ public RateLimitingMiddleware(RequestDelegate next, ILogger<RateLimitingMiddlewa

_globalLimiter = options.Value.GlobalLimiter;
_endpointLimiter = CreateEndpointLimiter();

}

// TODO - EventSource?
Expand Down Expand Up @@ -78,6 +79,12 @@ public Task Invoke(HttpContext context)
private async Task InvokeInternal(HttpContext context, EnableRateLimitingAttribute? enableRateLimitingAttribute)
{
using var leaseContext = await TryAcquireAsync(context);

if (_trackStatistics)
{
AddRateLimiterStatisticsFeature(context);
}

if (leaseContext.Lease?.IsAcquired == true)
{
await _next(context);
Expand Down Expand Up @@ -242,6 +249,11 @@ private PartitionedRateLimiter<HttpContext> CreateEndpointLimiter()
}, new DefaultKeyTypeEqualityComparer());
}

private void AddRateLimiterStatisticsFeature(HttpContext context)
{
context.Features.Set<IRateLimiterStatisticsFeature>(new DefaultRateLimiterStatisticsFeature(_globalLimiter, _endpointLimiter, context));
Copy link
Copy Markdown
Member

@BrennanConroy BrennanConroy Jan 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like we can allocate this once and reuse it

}

private static partial class RateLimiterLog
{
[LoggerMessage(1, LogLevel.Debug, "Rate limits exceeded, rejecting this request.", EventName = "RequestRejectedLimitsExceeded")]
Expand All @@ -253,4 +265,4 @@ private static partial class RateLimiterLog
[LoggerMessage(3, LogLevel.Debug, "The request was canceled.", EventName = "RequestCanceled")]
internal static partial void RequestCanceled(ILogger logger);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,75 @@ public async Task MultipleEndpointPolicies_LastOneWins()
Assert.Equal(StatusCodes.Status403Forbidden, context.Response.StatusCode);
}

[Fact]
public async Task StatisticsFeature_NotTracked()
{
// Arrange
var options = CreateOptionsAccessor();
var policy = new TestRateLimiterPolicy("myKey1", 404, false);

var middleware = CreateTestRateLimitingMiddleware(options);

var context = new DefaultHttpContext();
var endpoint = CreateEndpointWithRateLimitPolicy(policy);
context.SetEndpoint(endpoint);

// Act
await middleware.Invoke(context).DefaultTimeout();

// Assert
Assert.Null(context.Features.Get<IRateLimiterStatisticsFeature>());
}

[Fact]
public async Task StatisticsFeature_SuccessfullyTracked()
{
// Arrange
var options = CreateOptionsAccessor(trackStatistics: true);

var policy = new TestRateLimiterPolicy("myKey1", 404, false);

var middleware = CreateTestRateLimitingMiddleware(options);

var context = new DefaultHttpContext();
var endpoint = CreateEndpointWithRateLimitPolicy(policy);
context.SetEndpoint(endpoint);

// Act
await middleware.Invoke(context).DefaultTimeout();

// Assert
Assert.NotNull(context.Features.Get<IRateLimiterStatisticsFeature>());
}

[Fact]
public async Task StatisticsFeature_GetsStatistics_ForGlobalAndEndpointLimiter()
{
// Arrange
var options = CreateOptionsAccessor(trackStatistics: true);

var globalStatistics = new RateLimiterStatistics();
var endpointStatistics = new RateLimiterStatistics();

var policy = new TestRateLimiterPolicy("myKey1", 404, true, endpointStatistics);
options.Value.GlobalLimiter = new TestPartitionedRateLimiter<HttpContext>(new TestRateLimiter(false), globalStatistics);

var middleware = CreateTestRateLimitingMiddleware(options);

var context = new DefaultHttpContext();
var endpoint = CreateEndpointWithRateLimitPolicy(policy);
context.SetEndpoint(endpoint);

// Act
await middleware.Invoke(context).DefaultTimeout();

// Assert
var statisticsFeature = context.Features.Get<IRateLimiterStatisticsFeature>();

Assert.Equal(endpointStatistics, statisticsFeature.GetEndpointStatistics());
Assert.Equal(globalStatistics, statisticsFeature.GetGlobalStatistics());
}

private Endpoint CreateEndpointWithRateLimitPolicy<TPartitionKey>(IRateLimiterPolicy<TPartitionKey> policy)
{
var endpointBuilder = new TestEndpointBuilder();
Expand Down Expand Up @@ -639,5 +708,8 @@ private RateLimitingMiddleware CreateTestRateLimitingMiddleware(IOptions<RateLim
options,
serviceProvider ?? Mock.Of<IServiceProvider>());

private IOptions<RateLimiterOptions> CreateOptionsAccessor() => Options.Create(new RateLimiterOptions());
private IOptions<RateLimiterOptions> CreateOptionsAccessor(bool trackStatistics = false) => Options.Create(new RateLimiterOptions()
{
TrackStatistics = trackStatistics
});
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ namespace Microsoft.AspNetCore.RateLimiting;
internal class TestPartitionedRateLimiter<TResource> : PartitionedRateLimiter<TResource>
{
private List<RateLimiter> limiters = new List<RateLimiter>();
private RateLimiterStatistics _statistics;

public TestPartitionedRateLimiter() { }

public TestPartitionedRateLimiter(RateLimiter limiter)
public TestPartitionedRateLimiter(RateLimiter limiter, RateLimiterStatistics statistics = null)
{
limiters.Add(limiter);
_statistics = statistics;
}

public void AddLimiter(RateLimiter limiter)
Expand All @@ -28,7 +30,7 @@ public void AddLimiter(RateLimiter limiter)

public override RateLimiterStatistics GetStatistics(TResource resourceID)
{
throw new NotImplementedException();
return _statistics;
}

protected override RateLimitLease AttemptAcquireCore(TResource resourceID, int permitCount)
Expand Down
6 changes: 4 additions & 2 deletions src/Middleware/RateLimiting/test/TestRateLimiter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,19 @@ namespace Microsoft.AspNetCore.RateLimiting;
internal class TestRateLimiter : RateLimiter
{
private readonly bool _alwaysAccept;
private RateLimiterStatistics _statistics;

public TestRateLimiter(bool alwaysAccept)
public TestRateLimiter(bool alwaysAccept, RateLimiterStatistics statistics = null)
{
_alwaysAccept = alwaysAccept;
_statistics = statistics;
}

public override TimeSpan? IdleDuration => throw new NotImplementedException();

public override RateLimiterStatistics GetStatistics()
{
throw new NotImplementedException();
return _statistics;
}

protected override RateLimitLease AttemptAcquireCore(int permitCount)
Expand Down
6 changes: 4 additions & 2 deletions src/Middleware/RateLimiting/test/TestRateLimiterPolicy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ internal class TestRateLimiterPolicy : IRateLimiterPolicy<string>
private readonly string _key;
private readonly bool _alwaysAccept;
private readonly Func<OnRejectedContext, CancellationToken, ValueTask> _onRejected;
private readonly RateLimiterStatistics _statistics;

public TestRateLimiterPolicy(string key, int statusCode, bool alwaysAccept)
public TestRateLimiterPolicy(string key, int statusCode, bool alwaysAccept, RateLimiterStatistics statistics = null)
{
_key = key;
_alwaysAccept = alwaysAccept;
_statistics = statistics;

_onRejected = (context, token) =>
{
Expand All @@ -29,7 +31,7 @@ public RateLimitPartition<string> GetPartition(HttpContext httpContext)
{
return RateLimitPartition.Get<string>(_key, (key =>
{
return new TestRateLimiter(_alwaysAccept);
return new TestRateLimiter(_alwaysAccept, _statistics);
}));
}
}