From 0082137061ec9ff14b25e258673c08e9714343e1 Mon Sep 17 00:00:00 2001 From: Martin Tomka <martintomka@microsoft.com> Date: Tue, 4 Jul 2023 06:55:09 +0200 Subject: [PATCH] Add support for partitioned rate limiter --- .../RateLimiterResilienceStrategy.cs | 8 ++-- ...iterResilienceStrategyBuilderExtensions.cs | 4 +- .../RateLimiterStrategyOptions.cs | 2 +- .../ResilienceRateLimiter.cs | 45 +++++++++++++++++++ ...esilienceStrategyBuilderExtensionsTests.cs | 22 +++++---- .../RateLimiterResilienceStrategyTests.cs | 2 +- .../ResilienceRateLimiterTests.cs | 38 ++++++++++++++++ 7 files changed, 104 insertions(+), 17 deletions(-) create mode 100644 src/Polly.RateLimiting/ResilienceRateLimiter.cs create mode 100644 test/Polly.RateLimiting.Tests/ResilienceRateLimiterTests.cs diff --git a/src/Polly.RateLimiting/RateLimiterResilienceStrategy.cs b/src/Polly.RateLimiting/RateLimiterResilienceStrategy.cs index 3a284d12864..2cbc0ba2a78 100644 --- a/src/Polly.RateLimiting/RateLimiterResilienceStrategy.cs +++ b/src/Polly.RateLimiting/RateLimiterResilienceStrategy.cs @@ -8,7 +8,7 @@ internal sealed class RateLimiterResilienceStrategy : ResilienceStrategy private readonly ResilienceStrategyTelemetry _telemetry; public RateLimiterResilienceStrategy( - RateLimiter limiter, + ResilienceRateLimiter limiter, Func<OnRateLimiterRejectedArguments, ValueTask>? onRejected, ResilienceStrategyTelemetry telemetry) { @@ -18,7 +18,7 @@ public RateLimiterResilienceStrategy( _telemetry = telemetry; } - public RateLimiter Limiter { get; } + public ResilienceRateLimiter Limiter { get; } public Func<OnRateLimiterRejectedArguments, ValueTask>? OnLeaseRejected { get; } @@ -27,9 +27,7 @@ protected override async ValueTask<Outcome<TResult>> ExecuteCoreAsync<TResult, T ResilienceContext context, TState state) { - using var lease = await Limiter.AcquireAsync( - permitCount: 1, - context.CancellationToken).ConfigureAwait(context.ContinueOnCapturedContext); + using var lease = await Limiter.AcquireAsync(context).ConfigureAwait(context.ContinueOnCapturedContext); if (lease.IsAcquired) { diff --git a/src/Polly.RateLimiting/RateLimiterResilienceStrategyBuilderExtensions.cs b/src/Polly.RateLimiting/RateLimiterResilienceStrategyBuilderExtensions.cs index 6fb9bd756eb..a51f12a8c17 100644 --- a/src/Polly.RateLimiting/RateLimiterResilienceStrategyBuilderExtensions.cs +++ b/src/Polly.RateLimiting/RateLimiterResilienceStrategyBuilderExtensions.cs @@ -78,7 +78,7 @@ public static TBuilder AddRateLimiter<TBuilder>( return builder.AddRateLimiter(new RateLimiterStrategyOptions { - RateLimiter = limiter, + RateLimiter = ResilienceRateLimiter.Create(limiter), }); } @@ -104,7 +104,7 @@ public static TBuilder AddRateLimiter<TBuilder>( context => { return new RateLimiterResilienceStrategy( - options.RateLimiter ?? new ConcurrencyLimiter(options.DefaultRateLimiterOptions), + options.RateLimiter ?? ResilienceRateLimiter.Create(new ConcurrencyLimiter(options.DefaultRateLimiterOptions)), options.OnRejected, context.Telemetry); }, diff --git a/src/Polly.RateLimiting/RateLimiterStrategyOptions.cs b/src/Polly.RateLimiting/RateLimiterStrategyOptions.cs index 75492c25ba6..b235287afbb 100644 --- a/src/Polly.RateLimiting/RateLimiterStrategyOptions.cs +++ b/src/Polly.RateLimiting/RateLimiterStrategyOptions.cs @@ -46,5 +46,5 @@ public class RateLimiterStrategyOptions : ResilienceStrategyOptions /// Defaults to <see langword="null"/>. If this property is <see langword="null"/>, /// then the strategy will use a <see cref="ConcurrencyLimiter"/> created using <see cref="DefaultRateLimiterOptions"/>. /// </remarks> - public RateLimiter? RateLimiter { get; set; } + public ResilienceRateLimiter? RateLimiter { get; set; } } diff --git a/src/Polly.RateLimiting/ResilienceRateLimiter.cs b/src/Polly.RateLimiting/ResilienceRateLimiter.cs new file mode 100644 index 00000000000..629694a58c2 --- /dev/null +++ b/src/Polly.RateLimiting/ResilienceRateLimiter.cs @@ -0,0 +1,45 @@ +using System.Threading.RateLimiting; + +namespace Polly.RateLimiting; + +/// <summary> +/// This class is just a simple adapter for the built-in limiters in the <c>System.Threading.RateLimiting</c> namespace. +/// </summary> +public sealed class ResilienceRateLimiter +{ + private ResilienceRateLimiter(RateLimiter? limiter, PartitionedRateLimiter<ResilienceContext>? partitionedLimiter) + { + Limiter = limiter; + PartitionedLimiter = partitionedLimiter; + } + + /// <summary> + /// Creates an instance of <see cref="ResilienceRateLimiter"/> from <paramref name="rateLimiter"/>. + /// </summary> + /// <param name="rateLimiter">The rate limiter instance.</param> + /// <returns>An instance of <see cref="ResilienceRateLimiter"/>.</returns> + public static ResilienceRateLimiter Create(RateLimiter rateLimiter) => new(Guard.NotNull(rateLimiter), null); + + /// <summary> + /// Creates an instance of <see cref="ResilienceRateLimiter"/> from partitioned <paramref name="rateLimiter"/>. + /// </summary> + /// <param name="rateLimiter">The rate limiter instance.</param> + /// <returns>An instance of <see cref="ResilienceRateLimiter"/>.</returns> + public static ResilienceRateLimiter Create(PartitionedRateLimiter<ResilienceContext> rateLimiter) => new(null, Guard.NotNull(rateLimiter)); + + internal RateLimiter? Limiter { get; } + + internal PartitionedRateLimiter<ResilienceContext>? PartitionedLimiter { get; } + + internal ValueTask<RateLimitLease> AcquireAsync(ResilienceContext context) + { + if (PartitionedLimiter is not null) + { + return PartitionedLimiter.AcquireAsync(context, permitCount: 1, context.CancellationToken); + } + else + { + return Limiter!.AcquireAsync(permitCount: 1, context.CancellationToken); + } + } +} diff --git a/test/Polly.RateLimiting.Tests/RateLimiterResilienceStrategyBuilderExtensionsTests.cs b/test/Polly.RateLimiting.Tests/RateLimiterResilienceStrategyBuilderExtensionsTests.cs index 3ee05e12505..2321cb4b1dc 100644 --- a/test/Polly.RateLimiting.Tests/RateLimiterResilienceStrategyBuilderExtensionsTests.cs +++ b/test/Polly.RateLimiting.Tests/RateLimiterResilienceStrategyBuilderExtensionsTests.cs @@ -73,14 +73,20 @@ public void AddRateLimiter_AllExtensions_Ok() [Fact] public void AddRateLimiter_Ok() { - new ResilienceStrategyBuilder().AddRateLimiter(new RateLimiterStrategyOptions + using var limiter = new ConcurrencyLimiter(new ConcurrencyLimiterOptions { - RateLimiter = new ConcurrencyLimiter(new ConcurrencyLimiterOptions + QueueLimit = 10, + PermitLimit = 10 + }); + + new ResilienceStrategyBuilder() + .AddRateLimiter(new RateLimiterStrategyOptions { - QueueLimit = 10, - PermitLimit = 10 + RateLimiter = ResilienceRateLimiter.Create(limiter) }) - }).Build().Should().BeOfType<RateLimiterResilienceStrategy>(); + .Build() + .Should() + .BeOfType<RateLimiterResilienceStrategy>(); } [Fact] @@ -117,7 +123,7 @@ public void AddRateLimiter_Options_Ok() var strategy = new ResilienceStrategyBuilder() .AddRateLimiter(new RateLimiterStrategyOptions { - RateLimiter = Mock.Of<RateLimiter>() + RateLimiter = ResilienceRateLimiter.Create(Mock.Of<RateLimiter>()) }) .Build(); @@ -141,13 +147,13 @@ private static void AssertRateLimiter(ResilienceStrategyBuilder<int> builder, bo strategy.OnLeaseRejected.Should().BeNull(); } - assertLimiter?.Invoke(strategy.Limiter); + assertLimiter?.Invoke(strategy.Limiter.Limiter!); } private static void AssertConcurrencyLimiter(ResilienceStrategyBuilder<int> builder, bool hasEvents) { var strategy = GetResilienceStrategy(builder.Build()); - strategy.Limiter.Should().BeOfType<ConcurrencyLimiter>(); + strategy.Limiter.Limiter.Should().BeOfType<ConcurrencyLimiter>(); if (hasEvents) { diff --git a/test/Polly.RateLimiting.Tests/RateLimiterResilienceStrategyTests.cs b/test/Polly.RateLimiting.Tests/RateLimiterResilienceStrategyTests.cs index e2c4cfe3bc5..583cf49d2f8 100644 --- a/test/Polly.RateLimiting.Tests/RateLimiterResilienceStrategyTests.cs +++ b/test/Polly.RateLimiting.Tests/RateLimiterResilienceStrategyTests.cs @@ -104,7 +104,7 @@ private RateLimiterResilienceStrategy Create() return (RateLimiterResilienceStrategy)builder .AddRateLimiter(new RateLimiterStrategyOptions { - RateLimiter = _limiter.Object, + RateLimiter = ResilienceRateLimiter.Create(_limiter.Object), OnRejected = _event }) .Build(); diff --git a/test/Polly.RateLimiting.Tests/ResilienceRateLimiterTests.cs b/test/Polly.RateLimiting.Tests/ResilienceRateLimiterTests.cs new file mode 100644 index 00000000000..33ee03b619d --- /dev/null +++ b/test/Polly.RateLimiting.Tests/ResilienceRateLimiterTests.cs @@ -0,0 +1,38 @@ +using System.Threading.RateLimiting; +using System.Threading.Tasks; +using Moq; +using Moq.Protected; + +namespace Polly.RateLimiting.Tests; + +public class ResilienceRateLimiterTests +{ + [Fact] + public async Task Create_RateLimiter_Ok() + { + var lease = Mock.Of<RateLimitLease>(); + var limiterMock = new Mock<RateLimiter>(MockBehavior.Strict); + limiterMock.Protected().Setup<ValueTask<RateLimitLease>>("AcquireAsyncCore", 1, default(CancellationToken)).ReturnsAsync(lease); + + var limiter = ResilienceRateLimiter.Create(limiterMock.Object); + + (await limiter.AcquireAsync(ResilienceContext.Get())).Should().Be(lease); + limiter.Limiter.Should().NotBeNull(); + limiterMock.VerifyAll(); + } + + [Fact] + public async Task Create_PartitionedRateLimiter_Ok() + { + var context = ResilienceContext.Get(); + var lease = Mock.Of<RateLimitLease>(); + var limiterMock = new Mock<PartitionedRateLimiter<ResilienceContext>>(MockBehavior.Strict); + limiterMock.Protected().Setup<ValueTask<RateLimitLease>>("AcquireAsyncCore", context, 1, default(CancellationToken)).ReturnsAsync(lease); + + var limiter = ResilienceRateLimiter.Create(limiterMock.Object); + + (await limiter.AcquireAsync(context)).Should().Be(lease); + limiter.PartitionedLimiter.Should().NotBeNull(); + limiterMock.VerifyAll(); + } +}