From 0082137061ec9ff14b25e258673c08e9714343e1 Mon Sep 17 00:00:00 2001
From: Martin Tomka <martintomka@microsoft.com>
Date: Tue, 4 Jul 2023 06:55:09 +0200
Subject: [PATCH] Add support for partitioned rate limiter

---
 .../RateLimiterResilienceStrategy.cs          |  8 ++--
 ...iterResilienceStrategyBuilderExtensions.cs |  4 +-
 .../RateLimiterStrategyOptions.cs             |  2 +-
 .../ResilienceRateLimiter.cs                  | 45 +++++++++++++++++++
 ...esilienceStrategyBuilderExtensionsTests.cs | 22 +++++----
 .../RateLimiterResilienceStrategyTests.cs     |  2 +-
 .../ResilienceRateLimiterTests.cs             | 38 ++++++++++++++++
 7 files changed, 104 insertions(+), 17 deletions(-)
 create mode 100644 src/Polly.RateLimiting/ResilienceRateLimiter.cs
 create mode 100644 test/Polly.RateLimiting.Tests/ResilienceRateLimiterTests.cs

diff --git a/src/Polly.RateLimiting/RateLimiterResilienceStrategy.cs b/src/Polly.RateLimiting/RateLimiterResilienceStrategy.cs
index 3a284d12864..2cbc0ba2a78 100644
--- a/src/Polly.RateLimiting/RateLimiterResilienceStrategy.cs
+++ b/src/Polly.RateLimiting/RateLimiterResilienceStrategy.cs
@@ -8,7 +8,7 @@ internal sealed class RateLimiterResilienceStrategy : ResilienceStrategy
     private readonly ResilienceStrategyTelemetry _telemetry;
 
     public RateLimiterResilienceStrategy(
-        RateLimiter limiter,
+        ResilienceRateLimiter limiter,
         Func<OnRateLimiterRejectedArguments, ValueTask>? onRejected,
         ResilienceStrategyTelemetry telemetry)
     {
@@ -18,7 +18,7 @@ public RateLimiterResilienceStrategy(
         _telemetry = telemetry;
     }
 
-    public RateLimiter Limiter { get; }
+    public ResilienceRateLimiter Limiter { get; }
 
     public Func<OnRateLimiterRejectedArguments, ValueTask>? OnLeaseRejected { get; }
 
@@ -27,9 +27,7 @@ protected override async ValueTask<Outcome<TResult>> ExecuteCoreAsync<TResult, T
         ResilienceContext context,
         TState state)
     {
-        using var lease = await Limiter.AcquireAsync(
-            permitCount: 1,
-            context.CancellationToken).ConfigureAwait(context.ContinueOnCapturedContext);
+        using var lease = await Limiter.AcquireAsync(context).ConfigureAwait(context.ContinueOnCapturedContext);
 
         if (lease.IsAcquired)
         {
diff --git a/src/Polly.RateLimiting/RateLimiterResilienceStrategyBuilderExtensions.cs b/src/Polly.RateLimiting/RateLimiterResilienceStrategyBuilderExtensions.cs
index 6fb9bd756eb..a51f12a8c17 100644
--- a/src/Polly.RateLimiting/RateLimiterResilienceStrategyBuilderExtensions.cs
+++ b/src/Polly.RateLimiting/RateLimiterResilienceStrategyBuilderExtensions.cs
@@ -78,7 +78,7 @@ public static TBuilder AddRateLimiter<TBuilder>(
 
         return builder.AddRateLimiter(new RateLimiterStrategyOptions
         {
-            RateLimiter = limiter,
+            RateLimiter = ResilienceRateLimiter.Create(limiter),
         });
     }
 
@@ -104,7 +104,7 @@ public static TBuilder AddRateLimiter<TBuilder>(
             context =>
             {
                 return new RateLimiterResilienceStrategy(
-                    options.RateLimiter ?? new ConcurrencyLimiter(options.DefaultRateLimiterOptions),
+                    options.RateLimiter ?? ResilienceRateLimiter.Create(new ConcurrencyLimiter(options.DefaultRateLimiterOptions)),
                     options.OnRejected,
                     context.Telemetry);
             },
diff --git a/src/Polly.RateLimiting/RateLimiterStrategyOptions.cs b/src/Polly.RateLimiting/RateLimiterStrategyOptions.cs
index 75492c25ba6..b235287afbb 100644
--- a/src/Polly.RateLimiting/RateLimiterStrategyOptions.cs
+++ b/src/Polly.RateLimiting/RateLimiterStrategyOptions.cs
@@ -46,5 +46,5 @@ public class RateLimiterStrategyOptions : ResilienceStrategyOptions
     /// Defaults to <see langword="null"/>. If this property is <see langword="null"/>,
     /// then the strategy will use a <see cref="ConcurrencyLimiter"/> created using <see cref="DefaultRateLimiterOptions"/>.
     /// </remarks>
-    public RateLimiter? RateLimiter { get; set; }
+    public ResilienceRateLimiter? RateLimiter { get; set; }
 }
diff --git a/src/Polly.RateLimiting/ResilienceRateLimiter.cs b/src/Polly.RateLimiting/ResilienceRateLimiter.cs
new file mode 100644
index 00000000000..629694a58c2
--- /dev/null
+++ b/src/Polly.RateLimiting/ResilienceRateLimiter.cs
@@ -0,0 +1,45 @@
+using System.Threading.RateLimiting;
+
+namespace Polly.RateLimiting;
+
+/// <summary>
+/// This class is just a simple adapter for the built-in limiters in the <c>System.Threading.RateLimiting</c> namespace.
+/// </summary>
+public sealed class ResilienceRateLimiter
+{
+    private ResilienceRateLimiter(RateLimiter? limiter, PartitionedRateLimiter<ResilienceContext>? partitionedLimiter)
+    {
+        Limiter = limiter;
+        PartitionedLimiter = partitionedLimiter;
+    }
+
+    /// <summary>
+    /// Creates an instance of <see cref="ResilienceRateLimiter"/> from <paramref name="rateLimiter"/>.
+    /// </summary>
+    /// <param name="rateLimiter">The rate limiter instance.</param>
+    /// <returns>An instance of <see cref="ResilienceRateLimiter"/>.</returns>
+    public static ResilienceRateLimiter Create(RateLimiter rateLimiter) => new(Guard.NotNull(rateLimiter), null);
+
+    /// <summary>
+    /// Creates an instance of <see cref="ResilienceRateLimiter"/> from partitioned <paramref name="rateLimiter"/>.
+    /// </summary>
+    /// <param name="rateLimiter">The rate limiter instance.</param>
+    /// <returns>An instance of <see cref="ResilienceRateLimiter"/>.</returns>
+    public static ResilienceRateLimiter Create(PartitionedRateLimiter<ResilienceContext> rateLimiter) => new(null, Guard.NotNull(rateLimiter));
+
+    internal RateLimiter? Limiter { get; }
+
+    internal PartitionedRateLimiter<ResilienceContext>? PartitionedLimiter { get; }
+
+    internal ValueTask<RateLimitLease> AcquireAsync(ResilienceContext context)
+    {
+        if (PartitionedLimiter is not null)
+        {
+            return PartitionedLimiter.AcquireAsync(context, permitCount: 1, context.CancellationToken);
+        }
+        else
+        {
+            return Limiter!.AcquireAsync(permitCount: 1, context.CancellationToken);
+        }
+    }
+}
diff --git a/test/Polly.RateLimiting.Tests/RateLimiterResilienceStrategyBuilderExtensionsTests.cs b/test/Polly.RateLimiting.Tests/RateLimiterResilienceStrategyBuilderExtensionsTests.cs
index 3ee05e12505..2321cb4b1dc 100644
--- a/test/Polly.RateLimiting.Tests/RateLimiterResilienceStrategyBuilderExtensionsTests.cs
+++ b/test/Polly.RateLimiting.Tests/RateLimiterResilienceStrategyBuilderExtensionsTests.cs
@@ -73,14 +73,20 @@ public void AddRateLimiter_AllExtensions_Ok()
     [Fact]
     public void AddRateLimiter_Ok()
     {
-        new ResilienceStrategyBuilder().AddRateLimiter(new RateLimiterStrategyOptions
+        using var limiter = new ConcurrencyLimiter(new ConcurrencyLimiterOptions
         {
-            RateLimiter = new ConcurrencyLimiter(new ConcurrencyLimiterOptions
+            QueueLimit = 10,
+            PermitLimit = 10
+        });
+
+        new ResilienceStrategyBuilder()
+            .AddRateLimiter(new RateLimiterStrategyOptions
             {
-                QueueLimit = 10,
-                PermitLimit = 10
+                RateLimiter = ResilienceRateLimiter.Create(limiter)
             })
-        }).Build().Should().BeOfType<RateLimiterResilienceStrategy>();
+            .Build()
+            .Should()
+            .BeOfType<RateLimiterResilienceStrategy>();
     }
 
     [Fact]
@@ -117,7 +123,7 @@ public void AddRateLimiter_Options_Ok()
         var strategy = new ResilienceStrategyBuilder()
             .AddRateLimiter(new RateLimiterStrategyOptions
             {
-                RateLimiter = Mock.Of<RateLimiter>()
+                RateLimiter = ResilienceRateLimiter.Create(Mock.Of<RateLimiter>())
             })
             .Build();
 
@@ -141,13 +147,13 @@ private static void AssertRateLimiter(ResilienceStrategyBuilder<int> builder, bo
             strategy.OnLeaseRejected.Should().BeNull();
         }
 
-        assertLimiter?.Invoke(strategy.Limiter);
+        assertLimiter?.Invoke(strategy.Limiter.Limiter!);
     }
 
     private static void AssertConcurrencyLimiter(ResilienceStrategyBuilder<int> builder, bool hasEvents)
     {
         var strategy = GetResilienceStrategy(builder.Build());
-        strategy.Limiter.Should().BeOfType<ConcurrencyLimiter>();
+        strategy.Limiter.Limiter.Should().BeOfType<ConcurrencyLimiter>();
 
         if (hasEvents)
         {
diff --git a/test/Polly.RateLimiting.Tests/RateLimiterResilienceStrategyTests.cs b/test/Polly.RateLimiting.Tests/RateLimiterResilienceStrategyTests.cs
index e2c4cfe3bc5..583cf49d2f8 100644
--- a/test/Polly.RateLimiting.Tests/RateLimiterResilienceStrategyTests.cs
+++ b/test/Polly.RateLimiting.Tests/RateLimiterResilienceStrategyTests.cs
@@ -104,7 +104,7 @@ private RateLimiterResilienceStrategy Create()
         return (RateLimiterResilienceStrategy)builder
             .AddRateLimiter(new RateLimiterStrategyOptions
             {
-                RateLimiter = _limiter.Object,
+                RateLimiter = ResilienceRateLimiter.Create(_limiter.Object),
                 OnRejected = _event
             })
             .Build();
diff --git a/test/Polly.RateLimiting.Tests/ResilienceRateLimiterTests.cs b/test/Polly.RateLimiting.Tests/ResilienceRateLimiterTests.cs
new file mode 100644
index 00000000000..33ee03b619d
--- /dev/null
+++ b/test/Polly.RateLimiting.Tests/ResilienceRateLimiterTests.cs
@@ -0,0 +1,38 @@
+using System.Threading.RateLimiting;
+using System.Threading.Tasks;
+using Moq;
+using Moq.Protected;
+
+namespace Polly.RateLimiting.Tests;
+
+public class ResilienceRateLimiterTests
+{
+    [Fact]
+    public async Task Create_RateLimiter_Ok()
+    {
+        var lease = Mock.Of<RateLimitLease>();
+        var limiterMock = new Mock<RateLimiter>(MockBehavior.Strict);
+        limiterMock.Protected().Setup<ValueTask<RateLimitLease>>("AcquireAsyncCore", 1, default(CancellationToken)).ReturnsAsync(lease);
+
+        var limiter = ResilienceRateLimiter.Create(limiterMock.Object);
+
+        (await limiter.AcquireAsync(ResilienceContext.Get())).Should().Be(lease);
+        limiter.Limiter.Should().NotBeNull();
+        limiterMock.VerifyAll();
+    }
+
+    [Fact]
+    public async Task Create_PartitionedRateLimiter_Ok()
+    {
+        var context = ResilienceContext.Get();
+        var lease = Mock.Of<RateLimitLease>();
+        var limiterMock = new Mock<PartitionedRateLimiter<ResilienceContext>>(MockBehavior.Strict);
+        limiterMock.Protected().Setup<ValueTask<RateLimitLease>>("AcquireAsyncCore", context, 1, default(CancellationToken)).ReturnsAsync(lease);
+
+        var limiter = ResilienceRateLimiter.Create(limiterMock.Object);
+
+        (await limiter.AcquireAsync(context)).Should().Be(lease);
+        limiter.PartitionedLimiter.Should().NotBeNull();
+        limiterMock.VerifyAll();
+    }
+}