From 89dc3d3dd11477c06db703fb24ffc5c6adfd399a Mon Sep 17 00:00:00 2001 From: hedon954 <171725713@qq.com> Date: Mon, 9 Sep 2024 18:01:09 +0800 Subject: [PATCH] feat(rl): add `sliding window count` rate limiter --- devkit-rl/Cargo.toml | 8 ++ .../benches/sliding_window_count_bench.rs | 16 +++ devkit-rl/benches/sliding_window_log_bench.rs | 16 +++ devkit-rl/src/lib.rs | 2 + devkit-rl/src/sliding_window_count.rs | 116 ++++++++++++++++++ 5 files changed, 158 insertions(+) create mode 100644 devkit-rl/benches/sliding_window_count_bench.rs create mode 100644 devkit-rl/benches/sliding_window_log_bench.rs create mode 100644 devkit-rl/src/sliding_window_count.rs diff --git a/devkit-rl/Cargo.toml b/devkit-rl/Cargo.toml index 6f6d09d..c4611a8 100644 --- a/devkit-rl/Cargo.toml +++ b/devkit-rl/Cargo.toml @@ -16,6 +16,14 @@ harness = false name = "fixed_window_bench" harness = false +[[bench]] +name = "sliding_window_log_bench" +harness = false + +[[bench]] +name = "sliding_window_count_bench" +harness = false + [dependencies] chrono = "0.4.38" criterion = { workspace = true } diff --git a/devkit-rl/benches/sliding_window_count_bench.rs b/devkit-rl/benches/sliding_window_count_bench.rs new file mode 100644 index 0000000..3da063c --- /dev/null +++ b/devkit-rl/benches/sliding_window_count_bench.rs @@ -0,0 +1,16 @@ +use std::time::Duration; + +use criterion::{criterion_group, criterion_main, Criterion}; +use devkit_rl::SlidingWindowCount; + +fn sliding_window_count_benchmark(c: &mut Criterion) { + let tb = SlidingWindowCount::new(10, Duration::from_millis(1), 10); + c.bench_function("sliding_window_count", |b| { + b.iter(|| { + tb.allow(); + }) + }); +} + +criterion_group!(benches, sliding_window_count_benchmark); +criterion_main!(benches); diff --git a/devkit-rl/benches/sliding_window_log_bench.rs b/devkit-rl/benches/sliding_window_log_bench.rs new file mode 100644 index 0000000..8a39df1 --- /dev/null +++ b/devkit-rl/benches/sliding_window_log_bench.rs @@ -0,0 +1,16 @@ +use std::time::Duration; + +use criterion::{criterion_group, criterion_main, Criterion}; +use devkit_rl::SlidingWindowLog; + +fn sliding_window_log_benchmark(c: &mut Criterion) { + let tb = SlidingWindowLog::new(10, Some(Duration::from_millis(1))); + c.bench_function("sliding_window_log", |b| { + b.iter(|| { + tb.allow(); + }) + }); +} + +criterion_group!(benches, sliding_window_log_benchmark); +criterion_main!(benches); diff --git a/devkit-rl/src/lib.rs b/devkit-rl/src/lib.rs index c12eed7..47b383d 100644 --- a/devkit-rl/src/lib.rs +++ b/devkit-rl/src/lib.rs @@ -1,9 +1,11 @@ mod fixed_window; mod leaky_bucket; +mod sliding_window_count; mod sliding_window_log; mod token_bucket; pub use fixed_window::FixedWindow; pub use leaky_bucket::LeakyBucket; +pub use sliding_window_count::SlidingWindowCount; pub use sliding_window_log::SlidingWindowLog; pub use token_bucket::TokenBucket; diff --git a/devkit-rl/src/sliding_window_count.rs b/devkit-rl/src/sliding_window_count.rs new file mode 100644 index 0000000..5da7183 --- /dev/null +++ b/devkit-rl/src/sliding_window_count.rs @@ -0,0 +1,116 @@ +use std::{ + sync::{Arc, Mutex}, + time::{Duration, Instant}, +}; + +#[derive(Debug, Clone)] +pub struct SlidingWindowCount { + inner: Arc>, +} + +#[derive(Debug)] +struct SlidingWindowCountInner { + buckets: Vec, + win_size: u64, + bucket_interval: Duration, + last_update: Instant, + last_index: usize, +} + +impl SlidingWindowCount { + pub fn new(win_size: u64, interval: Duration, bucket_count: u64) -> Self { + Self { + inner: Arc::new(Mutex::new(SlidingWindowCountInner { + buckets: vec![0; bucket_count as usize], + win_size, + bucket_interval: interval.div_f64(bucket_count as f64), + last_update: Instant::now(), + last_index: 0, + })), + } + } + + pub fn allow(&self) -> bool { + self.allow_n(1) + } + + pub fn allow_n(&self, n: u64) -> bool { + let mut inner = self.inner.lock().unwrap(); + + inner.update_buckets(); + + if inner.total_count() + n <= inner.win_size { + inner.add_requests(n); + true + } else { + false + } + } +} + +impl SlidingWindowCountInner { + fn update_buckets(&mut self) { + let now = Instant::now(); + + let bucket_passed = self.bucket_passed(now); + + for i in 0..bucket_passed { + let idx = (i + self.last_index) % self.buckets.len(); + self.buckets[idx] = 0; + } + + self.last_index = (self.last_index + bucket_passed) % self.buckets.len(); + self.last_update = now; + } + + fn bucket_passed(&self, now: Instant) -> usize { + let elapsed = now - self.last_update; + let count = elapsed.div_duration_f64(self.bucket_interval) as usize; + if count > self.buckets.len() { + self.buckets.len() + } else { + count + } + } + + fn total_count(&self) -> u64 { + self.buckets.iter().sum() + } + + fn add_requests(&mut self, n: u64) { + self.buckets[self.last_index] += n; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sliding_window_count_should_work() { + const SIZE: u64 = 20; + const BUCKET_COUNT: u64 = 10; + const WINDOW_INTERVAL: Duration = Duration::from_millis(BUCKET_COUNT); + + let swc = SlidingWindowCount::new(SIZE, WINDOW_INTERVAL, BUCKET_COUNT); + + // first 20 requests should be allowed + for _ in 0..SIZE { + assert!(swc.allow()); + } + + // in current window, no more token should be allowed + assert!(!swc.allow()); + assert_eq!(SIZE, swc.inner.lock().unwrap().total_count()); + + // sleep for 1/2 interval, some older tokens would be removed, + // new should be allowed. + std::thread::sleep(WINDOW_INTERVAL / 2); + assert!(swc.allow()); + + // sleep for a long time, all buckets should be cleared. + std::thread::sleep(WINDOW_INTERVAL * 2); + assert!(swc.allow()); + assert_eq!(1, swc.inner.lock().unwrap().total_count()); + } +}