diff --git a/include/envoy/common/token_bucket.h b/include/envoy/common/token_bucket.h index 59a21642213d..bf84168a27f5 100644 --- a/include/envoy/common/token_bucket.h +++ b/include/envoy/common/token_bucket.h @@ -18,16 +18,19 @@ class TokenBucket { virtual ~TokenBucket() {} /** - * @param tokens supplies the number of tokens to be consumed. Default is 1. - * @return true if bucket is not empty, otherwise it returns false. + * @param tokens supplies the number of tokens to be consumed. + * @param allow_partial supplies whether the token bucket will allow consumption of less tokens + * than asked for. If allow_partial is true, the bucket contains 5 tokens, + * and the caller asks for 3, the bucket will return 3 tokens. + * @return the number of tokens actually consumed. */ - virtual bool consume(uint64_t tokens = 1) PURE; + virtual uint64_t consume(uint64_t tokens, bool allow_partial) PURE; /** * @return returns the approximate time until a next token is available. Currently it * returns the upper bound on the amount of time until a next token is available. */ - virtual uint64_t nextTokenAvailableMs() PURE; + virtual std::chrono::milliseconds nextTokenAvailable() PURE; }; typedef std::unique_ptr TokenBucketPtr; diff --git a/source/common/common/token_bucket_impl.cc b/source/common/common/token_bucket_impl.cc index 539044d8a190..c92f1c7db01b 100644 --- a/source/common/common/token_bucket_impl.cc +++ b/source/common/common/token_bucket_impl.cc @@ -8,7 +8,7 @@ TokenBucketImpl::TokenBucketImpl(uint64_t max_tokens, TimeSource& time_source, d : max_tokens_(max_tokens), fill_rate_(std::abs(fill_rate)), tokens_(max_tokens), last_fill_(time_source.monotonicTime()), time_source_(time_source) {} -bool TokenBucketImpl::consume(uint64_t tokens) { +uint64_t TokenBucketImpl::consume(uint64_t tokens, bool allow_partial) { if (tokens_ < max_tokens_) { const auto time_now = time_source_.monotonicTime(); tokens_ = std::min((std::chrono::duration(time_now - last_fill_).count() * fill_rate_) + @@ -17,21 +17,25 @@ bool TokenBucketImpl::consume(uint64_t tokens) { last_fill_ = time_now; } + if (allow_partial) { + tokens = std::min(tokens, static_cast(std::floor(tokens_))); + } + if (tokens_ < tokens) { return false; } tokens_ -= tokens; - return true; + return tokens; } -uint64_t TokenBucketImpl::nextTokenAvailableMs() { +std::chrono::milliseconds TokenBucketImpl::nextTokenAvailable() { // If there are tokens available, return immediately. if (tokens_ >= 1) { - return 0; + return std::chrono::milliseconds(0); } // TODO(ramaraochavali): implement a more precise way that works for very low rate limits. - return (1 / fill_rate_) * 1000; + return std::chrono::milliseconds(static_cast(std::ceil((1 / fill_rate_) * 1000))); } } // namespace Envoy diff --git a/source/common/common/token_bucket_impl.h b/source/common/common/token_bucket_impl.h index 4176370a6068..7daa3fb8e79b 100644 --- a/source/common/common/token_bucket_impl.h +++ b/source/common/common/token_bucket_impl.h @@ -20,9 +20,9 @@ class TokenBucketImpl : public TokenBucket { */ explicit TokenBucketImpl(uint64_t max_tokens, TimeSource& time_source, double fill_rate = 1); - bool consume(uint64_t tokens = 1) override; - - uint64_t nextTokenAvailableMs() override; + // TokenBucket + uint64_t consume(uint64_t tokens, bool allow_partial) override; + std::chrono::milliseconds nextTokenAvailable() override; private: const double max_tokens_; diff --git a/source/common/config/grpc_stream.h b/source/common/config/grpc_stream.h index 2d03e4e88c62..7329f47f841e 100644 --- a/source/common/config/grpc_stream.h +++ b/source/common/config/grpc_stream.h @@ -126,14 +126,13 @@ class GrpcStream : public Grpc::TypedAsyncStreamCallbacks, } bool checkRateLimitAllowsDrain() { - if (!rate_limiting_enabled_ || limit_request_->consume()) { + if (!rate_limiting_enabled_ || limit_request_->consume(1, false)) { return true; } ASSERT(drain_request_timer_ != nullptr); control_plane_stats_.rate_limit_enforced_.inc(); // Enable the drain request timer. - drain_request_timer_->enableTimer( - std::chrono::milliseconds(limit_request_->nextTokenAvailableMs())); + drain_request_timer_->enableTimer(limit_request_->nextTokenAvailable()); return false; } diff --git a/test/common/common/token_bucket_impl_test.cc b/test/common/common/token_bucket_impl_test.cc index 5d4e811e473e..4a44acd84701 100644 --- a/test/common/common/token_bucket_impl_test.cc +++ b/test/common/common/token_bucket_impl_test.cc @@ -17,60 +17,72 @@ class TokenBucketImplTest : public testing::Test { TEST_F(TokenBucketImplTest, Initialization) { TokenBucketImpl token_bucket{1, time_system_, -1.0}; - EXPECT_TRUE(token_bucket.consume()); - EXPECT_FALSE(token_bucket.consume()); + EXPECT_EQ(1, token_bucket.consume(1, false)); + EXPECT_EQ(0, token_bucket.consume(1, false)); } // Verifies TokenBucket's maximum capacity. TEST_F(TokenBucketImplTest, MaxBucketSize) { TokenBucketImpl token_bucket{3, time_system_, 1}; - EXPECT_TRUE(token_bucket.consume(3)); + EXPECT_EQ(3, token_bucket.consume(3, false)); time_system_.setMonotonicTime(std::chrono::seconds(10)); - EXPECT_FALSE(token_bucket.consume(4)); - EXPECT_TRUE(token_bucket.consume(3)); + EXPECT_EQ(0, token_bucket.consume(4, false)); + EXPECT_EQ(3, token_bucket.consume(3, false)); } // Verifies that TokenBucket can consume tokens. TEST_F(TokenBucketImplTest, Consume) { TokenBucketImpl token_bucket{10, time_system_, 1}; - EXPECT_FALSE(token_bucket.consume(20)); - EXPECT_TRUE(token_bucket.consume(9)); + EXPECT_EQ(0, token_bucket.consume(20, false)); + EXPECT_EQ(9, token_bucket.consume(9, false)); - EXPECT_TRUE(token_bucket.consume()); + EXPECT_EQ(1, token_bucket.consume(1, false)); time_system_.setMonotonicTime(std::chrono::milliseconds(999)); - EXPECT_FALSE(token_bucket.consume()); + EXPECT_EQ(0, token_bucket.consume(1, false)); time_system_.setMonotonicTime(std::chrono::milliseconds(5999)); - EXPECT_FALSE(token_bucket.consume(6)); + EXPECT_EQ(0, token_bucket.consume(6, false)); time_system_.setMonotonicTime(std::chrono::milliseconds(6000)); - EXPECT_TRUE(token_bucket.consume(6)); - EXPECT_FALSE(token_bucket.consume()); + EXPECT_EQ(6, token_bucket.consume(6, false)); + EXPECT_EQ(0, token_bucket.consume(1, false)); } // Verifies that TokenBucket can refill tokens. TEST_F(TokenBucketImplTest, Refill) { TokenBucketImpl token_bucket{1, time_system_, 0.5}; - EXPECT_TRUE(token_bucket.consume()); + EXPECT_EQ(1, token_bucket.consume(1, false)); time_system_.setMonotonicTime(std::chrono::milliseconds(500)); - EXPECT_FALSE(token_bucket.consume()); + EXPECT_EQ(0, token_bucket.consume(1, false)); time_system_.setMonotonicTime(std::chrono::milliseconds(1500)); - EXPECT_FALSE(token_bucket.consume()); + EXPECT_EQ(0, token_bucket.consume(1, false)); time_system_.setMonotonicTime(std::chrono::milliseconds(2000)); - EXPECT_TRUE(token_bucket.consume()); + EXPECT_EQ(1, token_bucket.consume(1, false)); } TEST_F(TokenBucketImplTest, NextTokenAvailable) { TokenBucketImpl token_bucket{10, time_system_, 5}; - EXPECT_TRUE(token_bucket.consume(9)); - EXPECT_EQ(0, token_bucket.nextTokenAvailableMs()); - EXPECT_TRUE(token_bucket.consume()); - EXPECT_FALSE(token_bucket.consume()); - EXPECT_EQ(200, token_bucket.nextTokenAvailableMs()); + EXPECT_EQ(9, token_bucket.consume(9, false)); + EXPECT_EQ(std::chrono::milliseconds(0), token_bucket.nextTokenAvailable()); + EXPECT_EQ(1, token_bucket.consume(1, false)); + EXPECT_EQ(0, token_bucket.consume(1, false)); + EXPECT_EQ(std::chrono::milliseconds(200), token_bucket.nextTokenAvailable()); +} + +// Test partial consumption of tokens. +TEST_F(TokenBucketImplTest, PartialConsumption) { + TokenBucketImpl token_bucket{16, time_system_, 16}; + EXPECT_EQ(16, token_bucket.consume(18, true)); + EXPECT_EQ(std::chrono::milliseconds(63), token_bucket.nextTokenAvailable()); + time_system_.sleep(std::chrono::milliseconds(62)); + EXPECT_EQ(0, token_bucket.consume(1, true)); + time_system_.sleep(std::chrono::milliseconds(1)); + EXPECT_EQ(1, token_bucket.consume(2, true)); + EXPECT_EQ(std::chrono::milliseconds(63), token_bucket.nextTokenAvailable()); } } // namespace Envoy