diff --git a/src/Polly.RateLimiting/RateLimiterResilienceStrategy.cs b/src/Polly.RateLimiting/RateLimiterResilienceStrategy.cs index 3a284d12864..2cbc0ba2a78 100644 --- a/src/Polly.RateLimiting/RateLimiterResilienceStrategy.cs +++ b/src/Polly.RateLimiting/RateLimiterResilienceStrategy.cs @@ -8,7 +8,7 @@ internal sealed class RateLimiterResilienceStrategy : ResilienceStrategy private readonly ResilienceStrategyTelemetry _telemetry; public RateLimiterResilienceStrategy( - RateLimiter limiter, + ResilienceRateLimiter limiter, Func? onRejected, ResilienceStrategyTelemetry telemetry) { @@ -18,7 +18,7 @@ public RateLimiterResilienceStrategy( _telemetry = telemetry; } - public RateLimiter Limiter { get; } + public ResilienceRateLimiter Limiter { get; } public Func? OnLeaseRejected { get; } @@ -27,9 +27,7 @@ protected override async ValueTask> ExecuteCoreAsync( return builder.AddRateLimiter(new RateLimiterStrategyOptions { - RateLimiter = limiter, + RateLimiter = ResilienceRateLimiter.Create(limiter), }); } @@ -104,7 +104,7 @@ public static TBuilder AddRateLimiter( context => { return new RateLimiterResilienceStrategy( - options.RateLimiter ?? new ConcurrencyLimiter(options.DefaultRateLimiterOptions), + options.RateLimiter ?? ResilienceRateLimiter.Create(new ConcurrencyLimiter(options.DefaultRateLimiterOptions)), options.OnRejected, context.Telemetry); }, diff --git a/src/Polly.RateLimiting/RateLimiterStrategyOptions.cs b/src/Polly.RateLimiting/RateLimiterStrategyOptions.cs index 75492c25ba6..b235287afbb 100644 --- a/src/Polly.RateLimiting/RateLimiterStrategyOptions.cs +++ b/src/Polly.RateLimiting/RateLimiterStrategyOptions.cs @@ -46,5 +46,5 @@ public class RateLimiterStrategyOptions : ResilienceStrategyOptions /// Defaults to . If this property is , /// then the strategy will use a created using . /// - public RateLimiter? RateLimiter { get; set; } + public ResilienceRateLimiter? RateLimiter { get; set; } } diff --git a/src/Polly.RateLimiting/ResilienceRateLimiter.cs b/src/Polly.RateLimiting/ResilienceRateLimiter.cs new file mode 100644 index 00000000000..629694a58c2 --- /dev/null +++ b/src/Polly.RateLimiting/ResilienceRateLimiter.cs @@ -0,0 +1,45 @@ +using System.Threading.RateLimiting; + +namespace Polly.RateLimiting; + +/// +/// This class is just a simple adapter for the built-in limiters in the System.Threading.RateLimiting namespace. +/// +public sealed class ResilienceRateLimiter +{ + private ResilienceRateLimiter(RateLimiter? limiter, PartitionedRateLimiter? partitionedLimiter) + { + Limiter = limiter; + PartitionedLimiter = partitionedLimiter; + } + + /// + /// Creates an instance of from . + /// + /// The rate limiter instance. + /// An instance of . + public static ResilienceRateLimiter Create(RateLimiter rateLimiter) => new(Guard.NotNull(rateLimiter), null); + + /// + /// Creates an instance of from partitioned . + /// + /// The rate limiter instance. + /// An instance of . + public static ResilienceRateLimiter Create(PartitionedRateLimiter rateLimiter) => new(null, Guard.NotNull(rateLimiter)); + + internal RateLimiter? Limiter { get; } + + internal PartitionedRateLimiter? PartitionedLimiter { get; } + + internal ValueTask AcquireAsync(ResilienceContext context) + { + if (PartitionedLimiter is not null) + { + return PartitionedLimiter.AcquireAsync(context, permitCount: 1, context.CancellationToken); + } + else + { + return Limiter!.AcquireAsync(permitCount: 1, context.CancellationToken); + } + } +} diff --git a/test/Polly.RateLimiting.Tests/RateLimiterResilienceStrategyBuilderExtensionsTests.cs b/test/Polly.RateLimiting.Tests/RateLimiterResilienceStrategyBuilderExtensionsTests.cs index 3ee05e12505..2321cb4b1dc 100644 --- a/test/Polly.RateLimiting.Tests/RateLimiterResilienceStrategyBuilderExtensionsTests.cs +++ b/test/Polly.RateLimiting.Tests/RateLimiterResilienceStrategyBuilderExtensionsTests.cs @@ -73,14 +73,20 @@ public void AddRateLimiter_AllExtensions_Ok() [Fact] public void AddRateLimiter_Ok() { - new ResilienceStrategyBuilder().AddRateLimiter(new RateLimiterStrategyOptions + using var limiter = new ConcurrencyLimiter(new ConcurrencyLimiterOptions { - RateLimiter = new ConcurrencyLimiter(new ConcurrencyLimiterOptions + QueueLimit = 10, + PermitLimit = 10 + }); + + new ResilienceStrategyBuilder() + .AddRateLimiter(new RateLimiterStrategyOptions { - QueueLimit = 10, - PermitLimit = 10 + RateLimiter = ResilienceRateLimiter.Create(limiter) }) - }).Build().Should().BeOfType(); + .Build() + .Should() + .BeOfType(); } [Fact] @@ -117,7 +123,7 @@ public void AddRateLimiter_Options_Ok() var strategy = new ResilienceStrategyBuilder() .AddRateLimiter(new RateLimiterStrategyOptions { - RateLimiter = Mock.Of() + RateLimiter = ResilienceRateLimiter.Create(Mock.Of()) }) .Build(); @@ -141,13 +147,13 @@ private static void AssertRateLimiter(ResilienceStrategyBuilder builder, bo strategy.OnLeaseRejected.Should().BeNull(); } - assertLimiter?.Invoke(strategy.Limiter); + assertLimiter?.Invoke(strategy.Limiter.Limiter!); } private static void AssertConcurrencyLimiter(ResilienceStrategyBuilder builder, bool hasEvents) { var strategy = GetResilienceStrategy(builder.Build()); - strategy.Limiter.Should().BeOfType(); + strategy.Limiter.Limiter.Should().BeOfType(); if (hasEvents) { diff --git a/test/Polly.RateLimiting.Tests/RateLimiterResilienceStrategyTests.cs b/test/Polly.RateLimiting.Tests/RateLimiterResilienceStrategyTests.cs index e2c4cfe3bc5..583cf49d2f8 100644 --- a/test/Polly.RateLimiting.Tests/RateLimiterResilienceStrategyTests.cs +++ b/test/Polly.RateLimiting.Tests/RateLimiterResilienceStrategyTests.cs @@ -104,7 +104,7 @@ private RateLimiterResilienceStrategy Create() return (RateLimiterResilienceStrategy)builder .AddRateLimiter(new RateLimiterStrategyOptions { - RateLimiter = _limiter.Object, + RateLimiter = ResilienceRateLimiter.Create(_limiter.Object), OnRejected = _event }) .Build(); diff --git a/test/Polly.RateLimiting.Tests/ResilienceRateLimiterTests.cs b/test/Polly.RateLimiting.Tests/ResilienceRateLimiterTests.cs new file mode 100644 index 00000000000..33ee03b619d --- /dev/null +++ b/test/Polly.RateLimiting.Tests/ResilienceRateLimiterTests.cs @@ -0,0 +1,38 @@ +using System.Threading.RateLimiting; +using System.Threading.Tasks; +using Moq; +using Moq.Protected; + +namespace Polly.RateLimiting.Tests; + +public class ResilienceRateLimiterTests +{ + [Fact] + public async Task Create_RateLimiter_Ok() + { + var lease = Mock.Of(); + var limiterMock = new Mock(MockBehavior.Strict); + limiterMock.Protected().Setup>("AcquireAsyncCore", 1, default(CancellationToken)).ReturnsAsync(lease); + + var limiter = ResilienceRateLimiter.Create(limiterMock.Object); + + (await limiter.AcquireAsync(ResilienceContext.Get())).Should().Be(lease); + limiter.Limiter.Should().NotBeNull(); + limiterMock.VerifyAll(); + } + + [Fact] + public async Task Create_PartitionedRateLimiter_Ok() + { + var context = ResilienceContext.Get(); + var lease = Mock.Of(); + var limiterMock = new Mock>(MockBehavior.Strict); + limiterMock.Protected().Setup>("AcquireAsyncCore", context, 1, default(CancellationToken)).ReturnsAsync(lease); + + var limiter = ResilienceRateLimiter.Create(limiterMock.Object); + + (await limiter.AcquireAsync(context)).Should().Be(lease); + limiter.PartitionedLimiter.Should().NotBeNull(); + limiterMock.VerifyAll(); + } +}