Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for PartitionedRateLimiter #1383

Merged
merged 1 commit into from
Jul 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/Polly.RateLimiting/RateLimiterResilienceStrategy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ internal sealed class RateLimiterResilienceStrategy : ResilienceStrategy
private readonly ResilienceStrategyTelemetry _telemetry;

public RateLimiterResilienceStrategy(
RateLimiter limiter,
ResilienceRateLimiter limiter,
Func<OnRateLimiterRejectedArguments, ValueTask>? onRejected,
ResilienceStrategyTelemetry telemetry)
{
Expand All @@ -18,7 +18,7 @@ public RateLimiterResilienceStrategy(
_telemetry = telemetry;
}

public RateLimiter Limiter { get; }
public ResilienceRateLimiter Limiter { get; }

public Func<OnRateLimiterRejectedArguments, ValueTask>? OnLeaseRejected { get; }

Expand All @@ -27,9 +27,7 @@ protected override async ValueTask<Outcome<TResult>> ExecuteCoreAsync<TResult, T
ResilienceContext context,
TState state)
{
using var lease = await Limiter.AcquireAsync(
permitCount: 1,
context.CancellationToken).ConfigureAwait(context.ContinueOnCapturedContext);
using var lease = await Limiter.AcquireAsync(context).ConfigureAwait(context.ContinueOnCapturedContext);

if (lease.IsAcquired)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public static TBuilder AddRateLimiter<TBuilder>(

return builder.AddRateLimiter(new RateLimiterStrategyOptions
{
RateLimiter = limiter,
RateLimiter = ResilienceRateLimiter.Create(limiter),
});
}

Expand All @@ -104,7 +104,7 @@ public static TBuilder AddRateLimiter<TBuilder>(
context =>
{
return new RateLimiterResilienceStrategy(
options.RateLimiter ?? new ConcurrencyLimiter(options.DefaultRateLimiterOptions),
options.RateLimiter ?? ResilienceRateLimiter.Create(new ConcurrencyLimiter(options.DefaultRateLimiterOptions)),
options.OnRejected,
context.Telemetry);
},
Expand Down
2 changes: 1 addition & 1 deletion src/Polly.RateLimiting/RateLimiterStrategyOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,5 @@ public class RateLimiterStrategyOptions : ResilienceStrategyOptions
/// Defaults to <see langword="null"/>. If this property is <see langword="null"/>,
/// then the strategy will use a <see cref="ConcurrencyLimiter"/> created using <see cref="DefaultRateLimiterOptions"/>.
/// </remarks>
public RateLimiter? RateLimiter { get; set; }
public ResilienceRateLimiter? RateLimiter { get; set; }
}
45 changes: 45 additions & 0 deletions src/Polly.RateLimiting/ResilienceRateLimiter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
using System.Threading.RateLimiting;

namespace Polly.RateLimiting;

/// <summary>
/// This class is just a simple adapter for the built-in limiters in the <c>System.Threading.RateLimiting</c> namespace.
/// </summary>
public sealed class ResilienceRateLimiter
{
private ResilienceRateLimiter(RateLimiter? limiter, PartitionedRateLimiter<ResilienceContext>? partitionedLimiter)
{
Limiter = limiter;
PartitionedLimiter = partitionedLimiter;
}

/// <summary>
/// Creates an instance of <see cref="ResilienceRateLimiter"/> from <paramref name="rateLimiter"/>.
/// </summary>
/// <param name="rateLimiter">The rate limiter instance.</param>
/// <returns>An instance of <see cref="ResilienceRateLimiter"/>.</returns>
public static ResilienceRateLimiter Create(RateLimiter rateLimiter) => new(Guard.NotNull(rateLimiter), null);

/// <summary>
/// Creates an instance of <see cref="ResilienceRateLimiter"/> from partitioned <paramref name="rateLimiter"/>.
/// </summary>
/// <param name="rateLimiter">The rate limiter instance.</param>
/// <returns>An instance of <see cref="ResilienceRateLimiter"/>.</returns>
public static ResilienceRateLimiter Create(PartitionedRateLimiter<ResilienceContext> rateLimiter) => new(null, Guard.NotNull(rateLimiter));

internal RateLimiter? Limiter { get; }

internal PartitionedRateLimiter<ResilienceContext>? PartitionedLimiter { get; }

internal ValueTask<RateLimitLease> AcquireAsync(ResilienceContext context)
{
if (PartitionedLimiter is not null)
{
return PartitionedLimiter.AcquireAsync(context, permitCount: 1, context.CancellationToken);
}
else
{
return Limiter!.AcquireAsync(permitCount: 1, context.CancellationToken);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,20 @@ public void AddRateLimiter_AllExtensions_Ok()
[Fact]
public void AddRateLimiter_Ok()
{
new ResilienceStrategyBuilder().AddRateLimiter(new RateLimiterStrategyOptions
using var limiter = new ConcurrencyLimiter(new ConcurrencyLimiterOptions
{
RateLimiter = new ConcurrencyLimiter(new ConcurrencyLimiterOptions
QueueLimit = 10,
PermitLimit = 10
});

new ResilienceStrategyBuilder()
.AddRateLimiter(new RateLimiterStrategyOptions
{
QueueLimit = 10,
PermitLimit = 10
RateLimiter = ResilienceRateLimiter.Create(limiter)
})
}).Build().Should().BeOfType<RateLimiterResilienceStrategy>();
.Build()
.Should()
.BeOfType<RateLimiterResilienceStrategy>();
}

[Fact]
Expand Down Expand Up @@ -117,7 +123,7 @@ public void AddRateLimiter_Options_Ok()
var strategy = new ResilienceStrategyBuilder()
.AddRateLimiter(new RateLimiterStrategyOptions
{
RateLimiter = Mock.Of<RateLimiter>()
RateLimiter = ResilienceRateLimiter.Create(Mock.Of<RateLimiter>())
})
.Build();

Expand All @@ -141,13 +147,13 @@ private static void AssertRateLimiter(ResilienceStrategyBuilder<int> builder, bo
strategy.OnLeaseRejected.Should().BeNull();
}

assertLimiter?.Invoke(strategy.Limiter);
assertLimiter?.Invoke(strategy.Limiter.Limiter!);
}

private static void AssertConcurrencyLimiter(ResilienceStrategyBuilder<int> builder, bool hasEvents)
{
var strategy = GetResilienceStrategy(builder.Build());
strategy.Limiter.Should().BeOfType<ConcurrencyLimiter>();
strategy.Limiter.Limiter.Should().BeOfType<ConcurrencyLimiter>();

if (hasEvents)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ private RateLimiterResilienceStrategy Create()
return (RateLimiterResilienceStrategy)builder
.AddRateLimiter(new RateLimiterStrategyOptions
{
RateLimiter = _limiter.Object,
RateLimiter = ResilienceRateLimiter.Create(_limiter.Object),
OnRejected = _event
})
.Build();
Expand Down
38 changes: 38 additions & 0 deletions test/Polly.RateLimiting.Tests/ResilienceRateLimiterTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
using System.Threading.RateLimiting;
using System.Threading.Tasks;
using Moq;
using Moq.Protected;

namespace Polly.RateLimiting.Tests;

public class ResilienceRateLimiterTests
{
[Fact]
public async Task Create_RateLimiter_Ok()
{
var lease = Mock.Of<RateLimitLease>();
var limiterMock = new Mock<RateLimiter>(MockBehavior.Strict);
limiterMock.Protected().Setup<ValueTask<RateLimitLease>>("AcquireAsyncCore", 1, default(CancellationToken)).ReturnsAsync(lease);

var limiter = ResilienceRateLimiter.Create(limiterMock.Object);

(await limiter.AcquireAsync(ResilienceContext.Get())).Should().Be(lease);
limiter.Limiter.Should().NotBeNull();
limiterMock.VerifyAll();
}

[Fact]
public async Task Create_PartitionedRateLimiter_Ok()
{
var context = ResilienceContext.Get();
var lease = Mock.Of<RateLimitLease>();
var limiterMock = new Mock<PartitionedRateLimiter<ResilienceContext>>(MockBehavior.Strict);
limiterMock.Protected().Setup<ValueTask<RateLimitLease>>("AcquireAsyncCore", context, 1, default(CancellationToken)).ReturnsAsync(lease);

var limiter = ResilienceRateLimiter.Create(limiterMock.Object);

(await limiter.AcquireAsync(context)).Should().Be(lease);
limiter.PartitionedLimiter.Should().NotBeNull();
limiterMock.VerifyAll();
}
}