Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.AspNetCore.Connections.Features;

namespace Microsoft.AspNetCore.RateLimiting.Features;

public interface IRateLimiterContextFeature
Copy link
Copy Markdown
Contributor Author

@MadL1me MadL1me Dec 17, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: there is no xml docs, cause i'm not sure about the API

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New APIs need to be approved. This needs an API proposal.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new APIs also need to be added to https://github.com/dotnet/aspnetcore/blob/main/src/Middleware/RateLimiting/src/PublicAPI.Unshipped.txt. There's a suggestion hint in VS to do this for you.

{
RateLimiterContext Context { get; }
}

18 changes: 18 additions & 0 deletions src/Middleware/RateLimiting/src/Features/RateLimiterContext.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Threading.RateLimiting;
using Microsoft.AspNetCore.Http;

namespace Microsoft.AspNetCore.RateLimiting.Features;

public class RateLimiterContext
{
public required HttpContext HttpContext { get; set; }

public required RateLimitLease Lease { get; set; }

public required PartitionedRateLimiter<HttpContext>? GlobalLimiter { get; set; }

public required PartitionedRateLimiter<HttpContext> EndpointLimiter { get; set; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Microsoft.AspNetCore.RateLimiting.Features;

public class RateLimiterContextFeature : IRateLimiterContextFeature
{
private readonly RateLimiterContext _context;

public RateLimiterContextFeature(RateLimiterContext context)
{
_context = context;
}

public RateLimiterContext Context => _context;
}
11 changes: 11 additions & 0 deletions src/Middleware/RateLimiting/src/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@ Microsoft.AspNetCore.RateLimiting.DisableRateLimitingAttribute.DisableRateLimiti
Microsoft.AspNetCore.RateLimiting.EnableRateLimitingAttribute
Microsoft.AspNetCore.RateLimiting.EnableRateLimitingAttribute.EnableRateLimitingAttribute(string! policyName) -> void
Microsoft.AspNetCore.RateLimiting.EnableRateLimitingAttribute.PolicyName.get -> string?
Microsoft.AspNetCore.RateLimiting.Features.IRateLimiterContextFeature
Microsoft.AspNetCore.RateLimiting.Features.IRateLimiterContextFeature.Context.get -> Microsoft.AspNetCore.RateLimiting.Features.RateLimiterContext!
Microsoft.AspNetCore.RateLimiting.Features.RateLimiterContext
Microsoft.AspNetCore.RateLimiting.Features.RateLimiterContext.EndpointLimiter.get -> System.Threading.RateLimiting.PartitionedRateLimiter<Microsoft.AspNetCore.Http.HttpContext!>!
Microsoft.AspNetCore.RateLimiting.Features.RateLimiterContext.GlobalLimiter.get -> System.Threading.RateLimiting.PartitionedRateLimiter<Microsoft.AspNetCore.Http.HttpContext!>?
Microsoft.AspNetCore.RateLimiting.Features.RateLimiterContext.HttpContext.get -> Microsoft.AspNetCore.Http.HttpContext!
Microsoft.AspNetCore.RateLimiting.Features.RateLimiterContext.HttpContext.set -> void
Microsoft.AspNetCore.RateLimiting.Features.RateLimiterContext.Lease.get -> System.Threading.RateLimiting.RateLimitLease!
Microsoft.AspNetCore.RateLimiting.Features.RateLimiterContextFeature
Microsoft.AspNetCore.RateLimiting.Features.RateLimiterContextFeature.Context.get -> Microsoft.AspNetCore.RateLimiting.Features.RateLimiterContext!
Microsoft.AspNetCore.RateLimiting.Features.RateLimiterContextFeature.RateLimiterContextFeature(Microsoft.AspNetCore.RateLimiting.Features.RateLimiterContext! context) -> void
Microsoft.AspNetCore.RateLimiting.IRateLimiterPolicy<TPartitionKey>
Microsoft.AspNetCore.RateLimiting.IRateLimiterPolicy<TPartitionKey>.GetPartition(Microsoft.AspNetCore.Http.HttpContext! httpContext) -> System.Threading.RateLimiting.RateLimitPartition<TPartitionKey>
Microsoft.AspNetCore.RateLimiting.IRateLimiterPolicy<TPartitionKey>.OnRejected.get -> System.Func<Microsoft.AspNetCore.RateLimiting.OnRejectedContext!, System.Threading.CancellationToken, System.Threading.Tasks.ValueTask>?
Expand Down
19 changes: 18 additions & 1 deletion src/Middleware/RateLimiting/src/RateLimitingMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System.Threading.RateLimiting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.RateLimiting.Features;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

Expand Down Expand Up @@ -75,9 +76,25 @@ public Task Invoke(HttpContext context)
return InvokeInternal(context, enableRateLimitingAttribute);
}

private void AddRateLimiterHttpContextFeature(HttpContext context, LeaseContext lease)
{
var rlContext = new RateLimiterContext()
{
EndpointLimiter = _endpointLimiter,
GlobalLimiter = _globalLimiter,
HttpContext = context,
Lease = lease.Lease
};

context.Features.Set<IRateLimiterContextFeature>(new RateLimiterContextFeature(rlContext));
}

private async Task InvokeInternal(HttpContext context, EnableRateLimitingAttribute? enableRateLimitingAttribute)
{
using var leaseContext = await TryAcquireAsync(context);

AddRateLimiterHttpContextFeature(context, leaseContext);

if (leaseContext.Lease?.IsAcquired == true)
{
await _next(context);
Expand Down Expand Up @@ -253,4 +270,4 @@ private static partial class RateLimiterLog
[LoggerMessage(3, LogLevel.Debug, "The request was canceled.", EventName = "RequestCanceled")]
internal static partial void RequestCanceled(ILogger logger);
}
}
}
107 changes: 107 additions & 0 deletions src/Middleware/RateLimiting/test/RateLimitingMiddlewareTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Threading.RateLimiting;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.RateLimiting.Features;
using Microsoft.AspNetCore.Testing;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
Expand Down Expand Up @@ -606,6 +607,112 @@ public async Task MultipleEndpointPolicies_LastOneWins()
Assert.Equal(StatusCodes.Status403Forbidden, context.Response.StatusCode);
}

[Fact]
public async Task RateLimiterContextFeature_PropertiesInitialized()
{
// Arrange
const string policy = "test_policy";
var options = CreateOptionsAccessor();

options.Value.GlobalLimiter = PartitionedRateLimiter.Create<HttpContext, string>(context =>
{
return RateLimitPartition.GetConcurrencyLimiter<string>("globalLimiter", key => new ConcurrencyLimiterOptions
{
PermitLimit = 10,
QueueProcessingOrder = QueueProcessingOrder.NewestFirst,
QueueLimit = 5
});
});
options.Value.AddConcurrencyLimiter(policy, configure =>
{
configure.QueueLimit = 10;
configure.PermitLimit = 10;
});

var middleware = CreateTestRateLimitingMiddleware(options);
var context = new DefaultHttpContext();
var endpoint = CreateEndpointWithRateLimitPolicy(policy);
context.SetEndpoint(endpoint);

//Act
await middleware.Invoke(context).DefaultTimeout();

// Assert
var rlContext = context.Features.Get<IRateLimiterContextFeature>();

Assert.NotNull(rlContext);
Assert.NotNull(rlContext.Context);
Assert.NotNull(rlContext.Context.GlobalLimiter);
Assert.NotNull(rlContext.Context.EndpointLimiter);
Assert.NotNull(rlContext.Context.HttpContext);
}

[Fact]
public async Task RateLimiterContextFeature_GetsCorrectGlobalRateLimiterInfo_WithSingleRequest()
{
// Arrange
var options = CreateOptionsAccessor();

options.Value.GlobalLimiter = PartitionedRateLimiter.Create<HttpContext, string>(context =>
{
return RateLimitPartition.GetConcurrencyLimiter<string>("globalLimiter", key => new ConcurrencyLimiterOptions
{
PermitLimit = 10,
QueueProcessingOrder = QueueProcessingOrder.NewestFirst,
QueueLimit = 5
});
});

var middleware = CreateTestRateLimitingMiddleware(options);
var context = new DefaultHttpContext();

//Act
await middleware.Invoke(context).DefaultTimeout();

// Assert
var rlContext = context.Features.Get<IRateLimiterContextFeature>();
var statistics = rlContext.Context.GlobalLimiter.GetStatistics(context);

Assert.NotNull(statistics);
Assert.True(rlContext.Context.Lease.IsAcquired);
Assert.Equal(1, statistics.TotalSuccessfulLeases);
Assert.Equal(0, statistics.TotalFailedLeases);
Assert.Equal(10, statistics.CurrentAvailablePermits);
}

[Fact]
public async Task GetsCorrectEndpointRateLimiterInfo_With3Requests()
{
// Arrange
const string policy = "test_policy";
var options = CreateOptionsAccessor();

options.Value.AddConcurrencyLimiter(policy, configure =>
{
configure.QueueLimit = 10;
configure.PermitLimit = 10;
});

var middleware = CreateTestRateLimitingMiddleware(options);

var context = new DefaultHttpContext();
var endpoint = CreateEndpointWithRateLimitPolicy(policy);
context.SetEndpoint(endpoint);

// Act
await middleware.Invoke(context).DefaultTimeout();
await middleware.Invoke(context).DefaultTimeout();
await middleware.Invoke(context).DefaultTimeout();

// Assert
var rlContext = context.Features.Get<IRateLimiterContextFeature>();
var statistics = rlContext.Context.EndpointLimiter.GetStatistics(context);

Assert.NotNull(statistics);
Assert.True(rlContext.Context.Lease.IsAcquired);
Assert.True(statistics.TotalSuccessfulLeases == 3);
}

private Endpoint CreateEndpointWithRateLimitPolicy<TPartitionKey>(IRateLimiterPolicy<TPartitionKey> policy)
{
var endpointBuilder = new TestEndpointBuilder();
Expand Down