diff --git a/misk-rate-limiting-bucket4j-redis/src/test/kotlin/misk/ratelimiting/bucket4j/redis/RedisRateLimiterTest.kt b/misk-rate-limiting-bucket4j-redis/src/test/kotlin/misk/ratelimiting/bucket4j/redis/RedisRateLimiterTest.kt index b7a5438d2bb..38b8d5a08c9 100644 --- a/misk-rate-limiting-bucket4j-redis/src/test/kotlin/misk/ratelimiting/bucket4j/redis/RedisRateLimiterTest.kt +++ b/misk-rate-limiting-bucket4j-redis/src/test/kotlin/misk/ratelimiting/bucket4j/redis/RedisRateLimiterTest.kt @@ -19,6 +19,7 @@ import redis.clients.jedis.ConnectionPoolConfig import wisp.deployment.TESTING import wisp.ratelimiting.RateLimiter import wisp.ratelimiting.testing.TestRateLimitConfig +import wisp.ratelimiting.testing.TestRateLimitConfigRefillGreedily @MiskTest(startService = true) class RedisRateLimiterTest { @@ -100,6 +101,70 @@ class RedisRateLimiterTest { assertThat(counter).isEqualTo(10) } + @Test + fun `test bucket refilled at the end of the interval after consuming all tokens`() { + val increment = TestRateLimitConfig.refillPeriod.dividedBy(5) + repeat(5) { + val result = rateLimiter.consumeToken(KEY, TestRateLimitConfig) + assertThat(result.didConsume).isTrue() + assertThat(result.remaining).isEqualTo(TestRateLimitConfig.capacity - 1 - it) + fakeClock.add(increment) + } + + assertThat(rateLimiter.availableTokens(KEY, TestRateLimitConfig)).isEqualTo(5L) + } + + @Test + fun `test bucket refilled at the end of the interval after consuming some tokens`() { + val increment = TestRateLimitConfig.refillPeriod.dividedBy(5) + repeat(3) { + val result = rateLimiter.consumeToken(KEY, TestRateLimitConfig) + assertThat(result.didConsume).isTrue() + assertThat(result.remaining).isEqualTo(TestRateLimitConfig.capacity - 1 - it) + fakeClock.add(increment) + } + + assertThat(rateLimiter.availableTokens(KEY, TestRateLimitConfig)).isEqualTo(2L) + fakeClock.add(increment) + assertThat(rateLimiter.availableTokens(KEY, TestRateLimitConfig)).isEqualTo(2L) + fakeClock.add(increment) // the clock now has past the end of the interval + assertThat(rateLimiter.availableTokens(KEY, TestRateLimitConfig)).isEqualTo(5L) + } + + @Test + fun `test bucket refilled continuously after each increment`() { + repeat(5) { + val result = rateLimiter.consumeToken(KEY, TestRateLimitConfigRefillGreedily) + assertThat(result.didConsume).isTrue() + assertThat(result.remaining).isEqualTo(TestRateLimitConfigRefillGreedily.capacity - 1 - it) + } + assertThat(rateLimiter.availableTokens(KEY, TestRateLimitConfigRefillGreedily)).isEqualTo(0L) + assertThat(rateLimiter.consumeToken(KEY, TestRateLimitConfigRefillGreedily).didConsume).isFalse() + + val increment = TestRateLimitConfigRefillGreedily.refillPeriod.dividedBy(5) + repeat(5) { + // One token is added back after each increment + fakeClock.add(increment) + assertThat(rateLimiter.availableTokens(KEY, TestRateLimitConfigRefillGreedily)).isEqualTo(it + 1L) + } + } + + @Test + fun `test bucket refilled continuously`() { + val increment = TestRateLimitConfigRefillGreedily.refillPeriod.dividedBy(5) + repeat(5) { + val result = rateLimiter.consumeToken(KEY, TestRateLimitConfigRefillGreedily) + assertThat(result.didConsume).isTrue() + assertThat(result.remaining).isEqualTo(TestRateLimitConfigRefillGreedily.capacity - 1) + + assertThat(rateLimiter.availableTokens(KEY, TestRateLimitConfigRefillGreedily)).isEqualTo(4L) + fakeClock.add(increment) + assertThat(rateLimiter.availableTokens(KEY, TestRateLimitConfigRefillGreedily)).isEqualTo(5L) + } + + assertThat(rateLimiter.availableTokens(KEY, TestRateLimitConfigRefillGreedily)).isEqualTo(5L) + } + companion object { private const val KEY = "test_key" } diff --git a/wisp/wisp-rate-limiting/api/wisp-rate-limiting.api b/wisp/wisp-rate-limiting/api/wisp-rate-limiting.api index 08c5913d157..45c7eaf8128 100644 --- a/wisp/wisp-rate-limiting/api/wisp-rate-limiting.api +++ b/wisp/wisp-rate-limiting/api/wisp-rate-limiting.api @@ -1,12 +1,22 @@ +public final class wisp/ratelimiting/RateLimitBucketRefillStrategy : java/lang/Enum { + public static final field GREEDY Lwisp/ratelimiting/RateLimitBucketRefillStrategy; + public static final field INTERVAL Lwisp/ratelimiting/RateLimitBucketRefillStrategy; + public static fun getEntries ()Lkotlin/enums/EnumEntries; + public static fun valueOf (Ljava/lang/String;)Lwisp/ratelimiting/RateLimitBucketRefillStrategy; + public static fun values ()[Lwisp/ratelimiting/RateLimitBucketRefillStrategy; +} + public abstract interface class wisp/ratelimiting/RateLimitConfiguration { public abstract fun getCapacity ()J public abstract fun getName ()Ljava/lang/String; public abstract fun getRefillAmount ()J public abstract fun getRefillPeriod ()Ljava/time/Duration; + public abstract fun getRefillStrategy ()Lwisp/ratelimiting/RateLimitBucketRefillStrategy; public abstract fun getVersion ()Ljava/lang/Long; } public final class wisp/ratelimiting/RateLimitConfiguration$DefaultImpls { + public static fun getRefillStrategy (Lwisp/ratelimiting/RateLimitConfiguration;)Lwisp/ratelimiting/RateLimitBucketRefillStrategy; public static fun getVersion (Lwisp/ratelimiting/RateLimitConfiguration;)Ljava/lang/Long; } @@ -124,6 +134,17 @@ public final class wisp/ratelimiting/testing/TestRateLimitConfig : wisp/ratelimi public fun getName ()Ljava/lang/String; public fun getRefillAmount ()J public fun getRefillPeriod ()Ljava/time/Duration; + public fun getRefillStrategy ()Lwisp/ratelimiting/RateLimitBucketRefillStrategy; + public fun getVersion ()Ljava/lang/Long; +} + +public final class wisp/ratelimiting/testing/TestRateLimitConfigRefillGreedily : wisp/ratelimiting/RateLimitConfiguration { + public static final field INSTANCE Lwisp/ratelimiting/testing/TestRateLimitConfigRefillGreedily; + public fun getCapacity ()J + public fun getName ()Ljava/lang/String; + public fun getRefillAmount ()J + public fun getRefillPeriod ()Ljava/time/Duration; + public fun getRefillStrategy ()Lwisp/ratelimiting/RateLimitBucketRefillStrategy; public fun getVersion ()Ljava/lang/Long; } diff --git a/wisp/wisp-rate-limiting/bucket4j/src/main/kotlin/wisp/ratelimiting/bucket4j/Bucket4jRateLimiter.kt b/wisp/wisp-rate-limiting/bucket4j/src/main/kotlin/wisp/ratelimiting/bucket4j/Bucket4jRateLimiter.kt index 90ab4a10ff6..aa5700a5912 100644 --- a/wisp/wisp-rate-limiting/bucket4j/src/main/kotlin/wisp/ratelimiting/bucket4j/Bucket4jRateLimiter.kt +++ b/wisp/wisp-rate-limiting/bucket4j/src/main/kotlin/wisp/ratelimiting/bucket4j/Bucket4jRateLimiter.kt @@ -9,6 +9,7 @@ import io.github.bucket4j.distributed.BucketProxy import io.github.bucket4j.distributed.proxy.ProxyManager import io.micrometer.core.instrument.MeterRegistry import io.micrometer.core.instrument.Metrics +import wisp.ratelimiting.RateLimitBucketRefillStrategy import wisp.ratelimiting.RateLimitConfiguration import wisp.ratelimiting.RateLimiter import wisp.ratelimiting.RateLimiterMetrics @@ -115,10 +116,19 @@ class Bucket4jRateLimiter @JvmOverloads constructor( } private fun RateLimitConfiguration.toBandwidth(): Bandwidth { - return Bandwidth.builder() - .capacity(capacity) - .refillIntervally(refillAmount, refillPeriod) - .initialTokens(capacity) - .build() + return if (refillStrategy == RateLimitBucketRefillStrategy.GREEDY) { + Bandwidth.builder() + .capacity(capacity) + .refillGreedy(refillAmount, refillPeriod) + .initialTokens(capacity) + .build() + } + else { + Bandwidth.builder() + .capacity(capacity) + .refillIntervally(refillAmount, refillPeriod) + .initialTokens(capacity) + .build() + } } } diff --git a/wisp/wisp-rate-limiting/src/main/kotlin/wisp/ratelimiting/RateLimitBucketRefillStrategy.kt b/wisp/wisp-rate-limiting/src/main/kotlin/wisp/ratelimiting/RateLimitBucketRefillStrategy.kt new file mode 100644 index 00000000000..fcae22ba51b --- /dev/null +++ b/wisp/wisp-rate-limiting/src/main/kotlin/wisp/ratelimiting/RateLimitBucketRefillStrategy.kt @@ -0,0 +1,13 @@ +package wisp.ratelimiting + +enum class RateLimitBucketRefillStrategy { + /* + * The bucket will be filled continuously at the specified rate + */ + GREEDY, + /* + * The bucket will be topped off at the end of the interval, + * no matter when the last token was consumed. + */ + INTERVAL +} diff --git a/wisp/wisp-rate-limiting/src/main/kotlin/wisp/ratelimiting/RateLimitConfiguration.kt b/wisp/wisp-rate-limiting/src/main/kotlin/wisp/ratelimiting/RateLimitConfiguration.kt index d4c0aba5d45..2b3ee27fbe1 100644 --- a/wisp/wisp-rate-limiting/src/main/kotlin/wisp/ratelimiting/RateLimitConfiguration.kt +++ b/wisp/wisp-rate-limiting/src/main/kotlin/wisp/ratelimiting/RateLimitConfiguration.kt @@ -34,4 +34,7 @@ interface RateLimitConfiguration { */ val version: Long? get() = null // returns null to be backward compatible + + val refillStrategy: RateLimitBucketRefillStrategy + get() = RateLimitBucketRefillStrategy.INTERVAL } diff --git a/wisp/wisp-rate-limiting/src/testFixtures/kotlin/wisp/ratelimiting/testing/TestRateLimitConfig.kt b/wisp/wisp-rate-limiting/src/testFixtures/kotlin/wisp/ratelimiting/testing/TestRateLimitConfig.kt index 3e61b3953be..a20eff783d7 100644 --- a/wisp/wisp-rate-limiting/src/testFixtures/kotlin/wisp/ratelimiting/testing/TestRateLimitConfig.kt +++ b/wisp/wisp-rate-limiting/src/testFixtures/kotlin/wisp/ratelimiting/testing/TestRateLimitConfig.kt @@ -1,5 +1,6 @@ package wisp.ratelimiting.testing +import wisp.ratelimiting.RateLimitBucketRefillStrategy import wisp.ratelimiting.RateLimitConfiguration import java.time.Duration @@ -12,3 +13,14 @@ object TestRateLimitConfig : RateLimitConfiguration { override val refillAmount = BUCKET_CAPACITY override val refillPeriod: Duration = REFILL_DURATION } + +object TestRateLimitConfigRefillGreedily : RateLimitConfiguration { + private const val BUCKET_CAPACITY = 5L + private val REFILL_DURATION: Duration = Duration.ofSeconds(30L) + + override val capacity = BUCKET_CAPACITY + override val name = "test_configuration_refill_greedily" + override val refillAmount = BUCKET_CAPACITY + override val refillPeriod: Duration = REFILL_DURATION + override val refillStrategy: RateLimitBucketRefillStrategy = RateLimitBucketRefillStrategy.GREEDY +}