From 813185c62839b1c168138a3a599bd23404ce8b59 Mon Sep 17 00:00:00 2001 From: Alex Snaps Date: Sat, 11 May 2024 10:41:15 -0400 Subject: [PATCH 01/10] Base `Cr(dt)CounterValue` --- .../src/storage/atomic_expiring_value.rs | 34 +++ .../storage/distributed/cr_counter_value.rs | 209 ++++++++++++++++++ limitador/src/storage/distributed/mod.rs | 1 + limitador/src/storage/mod.rs | 1 + 4 files changed, 245 insertions(+) create mode 100644 limitador/src/storage/distributed/cr_counter_value.rs create mode 100644 limitador/src/storage/distributed/mod.rs diff --git a/limitador/src/storage/atomic_expiring_value.rs b/limitador/src/storage/atomic_expiring_value.rs index f80eaa22..b16ec8f0 100644 --- a/limitador/src/storage/atomic_expiring_value.rs +++ b/limitador/src/storage/atomic_expiring_value.rs @@ -102,6 +102,40 @@ impl AtomicExpiryTime { } false } + + pub fn merge(&self, other: Self) { + let mut other = other; + loop { + let now = SystemTime::now(); + other = match self.merge_at(other, now) { + Ok(_) => return, + Err(other) => other, + }; + } + } + + pub fn merge_at(&self, other: Self, when: SystemTime) -> Result<(), Self> { + let other_exp = other.expiry.load(Ordering::SeqCst); + let expiry = self.expiry.load(Ordering::SeqCst); + if other_exp < expiry && other_exp > Self::since_epoch(when) { + // if our expiry changed, some thread observed the time window as elapsed... + // `other` can't be in the future anymore! Safely ignoring the failure scenario + return match self.expiry.compare_exchange( + expiry, + other_exp, + Ordering::SeqCst, + Ordering::SeqCst, + ) { + Ok(_) => Ok(()), + Err(_) => Err(other), + }; + } + Ok(()) + } + + pub fn into_inner(self) -> SystemTime { + SystemTime::UNIX_EPOCH + Duration::from_micros(self.expiry.load(Ordering::SeqCst)) + } } impl Clone for AtomicExpiryTime { diff --git a/limitador/src/storage/distributed/cr_counter_value.rs b/limitador/src/storage/distributed/cr_counter_value.rs new file mode 100644 index 00000000..33fa589d --- /dev/null +++ b/limitador/src/storage/distributed/cr_counter_value.rs @@ -0,0 +1,209 @@ +use crate::storage::atomic_expiring_value::AtomicExpiryTime; +use std::collections::btree_map::Entry; +use std::collections::BTreeMap; +use std::ops::Not; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::RwLock; +use std::time::{Duration, SystemTime}; + +struct CrCounterValue { + ourselves: A, + value: AtomicU64, + others: RwLock>, + expiry: AtomicExpiryTime, +} + +impl CrCounterValue { + pub fn new(actor: A, time_window: Duration) -> Self { + Self { + ourselves: actor, + value: Default::default(), + others: RwLock::default(), + expiry: AtomicExpiryTime::from_now(time_window), + } + } + + pub fn read(&self) -> u64 { + self.read_at(SystemTime::now()) + } + + pub fn read_at(&self, when: SystemTime) -> u64 { + if self.expiry.expired_at(when) { + 0 + } else { + let guard = self.others.read().unwrap(); + let others: u64 = guard.values().sum(); + others + self.value.load(Ordering::Relaxed) + } + } + + pub fn inc(&self, increment: u64, time_window: Duration) { + self.inc_at(increment, time_window, SystemTime::now()) + } + + pub fn inc_at(&self, increment: u64, time_window: Duration, when: SystemTime) { + if self + .expiry + .update_if_expired(time_window.as_micros() as u64, when) + { + self.value.store(increment, Ordering::SeqCst); + } else { + self.value.fetch_add(increment, Ordering::SeqCst); + } + } + + pub fn inc_actor(&self, actor: A, increment: u64, time_window: Duration) { + self.inc_actor_at(actor, increment, time_window, SystemTime::now()); + } + + pub fn inc_actor_at(&self, actor: A, increment: u64, time_window: Duration, when: SystemTime) { + if actor == self.ourselves { + self.inc_at(increment, time_window, when); + } else { + let mut guard = self.others.write().unwrap(); + if self + .expiry + .update_if_expired(time_window.as_micros() as u64, when) + { + guard.insert(actor, increment); + } else { + *guard.entry(actor).or_insert(0) += increment; + } + } + } + + pub fn merge(&self, other: Self) { + self.merge_at(other, SystemTime::now()); + } + + pub fn merge_at(&self, other: Self, when: SystemTime) { + if self.expiry.expired_at(when).not() && other.expiry.expired_at(when).not() { + let (expiry, other_values) = other.into_inner(); + let _ = self.expiry.merge_at(AtomicExpiryTime::new(expiry), when); + let ourselves = self.value.load(Ordering::SeqCst); + let mut others = self.others.write().unwrap(); + for (actor, other_value) in other_values { + if actor == self.ourselves { + if other_value > ourselves { + self.value + .fetch_add(other_value - ourselves, Ordering::SeqCst); + } + } else { + match others.entry(actor) { + Entry::Vacant(entry) => { + entry.insert(other_value); + } + Entry::Occupied(mut known) => { + let local = known.get_mut(); + if other_value > *local { + *local = other_value; + } + } + } + } + } + } + } + + fn into_inner(self) -> (SystemTime, BTreeMap) { + let Self { + ourselves, + value, + others, + expiry, + } = self; + let mut map = others.into_inner().unwrap(); + map.insert(ourselves, value.into_inner()); + (expiry.into_inner(), map) + } +} + +#[cfg(test)] +mod tests { + use crate::storage::distributed::cr_counter_value::CrCounterValue; + use std::time::{Duration, SystemTime}; + + #[test] + fn local_increments_are_readable() { + let window = Duration::from_secs(1); + let a = CrCounterValue::new('A', window); + a.inc(3, window); + assert_eq!(3, a.read()); + a.inc(2, window); + assert_eq!(5, a.read()); + } + + #[test] + fn local_increments_expire() { + let window = Duration::from_secs(1); + let a = CrCounterValue::new('A', window); + let now = SystemTime::now(); + a.inc_at(3, window, now); + assert_eq!(3, a.read()); + a.inc_at(2, window, now + window); + assert_eq!(2, a.read()); + } + + #[test] + fn other_increments_are_readable() { + let window = Duration::from_secs(1); + let a = CrCounterValue::new('A', window); + a.inc_actor('B', 3, window); + assert_eq!(3, a.read()); + a.inc_actor('B', 2, window); + assert_eq!(5, a.read()); + } + + #[test] + fn other_increments_expire() { + let window = Duration::from_secs(1); + let a = CrCounterValue::new('A', window); + let now = SystemTime::now(); + a.inc_actor_at('B', 3, window, now); + assert_eq!(3, a.read()); + a.inc_actor_at('B', 2, window, now + window); + assert_eq!(2, a.read()); + } + + #[test] + fn merges() { + let window = Duration::from_secs(1); + { + let a = CrCounterValue::new('A', window); + let b = CrCounterValue::new('B', window); + a.inc(3, window); + b.inc(2, window); + a.merge(b); + assert_eq!(a.read(), 5); + } + + { + let a = CrCounterValue::new('A', window); + let b = CrCounterValue::new('B', window); + a.inc(3, window); + b.inc(2, window); + b.merge(a); + assert_eq!(b.read(), 5); + } + + { + let a = CrCounterValue::new('A', window); + let b = CrCounterValue::new('B', window); + a.inc(3, window); + b.inc(2, window); + b.inc_actor('A', 2, window); // older value! + b.merge(a); // merges the 3 + assert_eq!(b.read(), 5); + } + + { + let a = CrCounterValue::new('A', window); + let b = CrCounterValue::new('B', window); + a.inc(3, window); + b.inc(2, window); + b.inc_actor('A', 5, window); // newer value! + b.merge(a); // ignores the 3 and keeps its own 5 for a + assert_eq!(b.read(), 7); + } + } +} diff --git a/limitador/src/storage/distributed/mod.rs b/limitador/src/storage/distributed/mod.rs new file mode 100644 index 00000000..6d0e009e --- /dev/null +++ b/limitador/src/storage/distributed/mod.rs @@ -0,0 +1 @@ +mod cr_counter_value; diff --git a/limitador/src/storage/mod.rs b/limitador/src/storage/mod.rs index 4db70278..12e2d55e 100644 --- a/limitador/src/storage/mod.rs +++ b/limitador/src/storage/mod.rs @@ -8,6 +8,7 @@ use thiserror::Error; #[cfg(feature = "disk_storage")] pub mod disk; +pub mod distributed; pub mod in_memory; #[cfg(feature = "redis_storage")] From 91708687e4681b198f3041fb6a74baa6cf0aaec9 Mon Sep 17 00:00:00 2001 From: Alex Snaps Date: Sun, 12 May 2024 09:48:14 -0400 Subject: [PATCH 02/10] Tests --- .../src/storage/atomic_expiring_value.rs | 11 ++ .../storage/distributed/cr_counter_value.rs | 124 ++++++++++++------ 2 files changed, 97 insertions(+), 38 deletions(-) diff --git a/limitador/src/storage/atomic_expiring_value.rs b/limitador/src/storage/atomic_expiring_value.rs index b16ec8f0..c193bedc 100644 --- a/limitador/src/storage/atomic_expiring_value.rs +++ b/limitador/src/storage/atomic_expiring_value.rs @@ -103,6 +103,7 @@ impl AtomicExpiryTime { false } + #[allow(dead_code)] pub fn merge(&self, other: Self) { let mut other = other; loop { @@ -134,6 +135,10 @@ impl AtomicExpiryTime { } pub fn into_inner(self) -> SystemTime { + self.expires_at() + } + + pub fn expires_at(&self) -> SystemTime { SystemTime::UNIX_EPOCH + Duration::from_micros(self.expiry.load(Ordering::SeqCst)) } } @@ -164,6 +169,12 @@ impl Clone for AtomicExpiringValue { } } +impl From for AtomicExpiryTime { + fn from(value: SystemTime) -> Self { + AtomicExpiryTime::new(value) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/limitador/src/storage/distributed/cr_counter_value.rs b/limitador/src/storage/distributed/cr_counter_value.rs index 33fa589d..704d8c85 100644 --- a/limitador/src/storage/distributed/cr_counter_value.rs +++ b/limitador/src/storage/distributed/cr_counter_value.rs @@ -1,7 +1,6 @@ use crate::storage::atomic_expiring_value::AtomicExpiryTime; use std::collections::btree_map::Entry; use std::collections::BTreeMap; -use std::ops::Not; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::RwLock; use std::time::{Duration, SystemTime}; @@ -13,6 +12,7 @@ struct CrCounterValue { expiry: AtomicExpiryTime, } +#[allow(dead_code)] impl CrCounterValue { pub fn new(actor: A, time_window: Duration) -> Self { Self { @@ -77,9 +77,12 @@ impl CrCounterValue { } pub fn merge_at(&self, other: Self, when: SystemTime) { - if self.expiry.expired_at(when).not() && other.expiry.expired_at(when).not() { - let (expiry, other_values) = other.into_inner(); - let _ = self.expiry.merge_at(AtomicExpiryTime::new(expiry), when); + let (expiry, other_values) = other.into_inner(); + if expiry > when { + let _ = self.expiry.merge_at(expiry.into(), when); + if self.expiry.expired_at(when) { + self.reset(expiry); + } let ourselves = self.value.load(Ordering::SeqCst); let mut others = self.others.write().unwrap(); for (actor, other_value) in other_values { @@ -116,6 +119,13 @@ impl CrCounterValue { map.insert(ourselves, value.into_inner()); (expiry.into_inner(), map) } + + fn reset(&self, expiry: SystemTime) { + let mut guard = self.others.write().unwrap(); + self.expiry.update(expiry); + self.value.store(0, Ordering::SeqCst); + guard.clear() + } } #[cfg(test)] @@ -168,42 +178,80 @@ mod tests { #[test] fn merges() { let window = Duration::from_secs(1); - { - let a = CrCounterValue::new('A', window); - let b = CrCounterValue::new('B', window); - a.inc(3, window); - b.inc(2, window); - a.merge(b); - assert_eq!(a.read(), 5); - } + let a = CrCounterValue::new('A', window); + let b = CrCounterValue::new('B', window); + a.inc(3, window); + b.inc(2, window); + a.merge(b); + assert_eq!(a.read(), 5); + } - { - let a = CrCounterValue::new('A', window); - let b = CrCounterValue::new('B', window); - a.inc(3, window); - b.inc(2, window); - b.merge(a); - assert_eq!(b.read(), 5); - } + #[test] + fn merges_symetric() { + let window = Duration::from_secs(1); + let a = CrCounterValue::new('A', window); + let b = CrCounterValue::new('B', window); + a.inc(3, window); + b.inc(2, window); + b.merge(a); + assert_eq!(b.read(), 5); + } - { - let a = CrCounterValue::new('A', window); - let b = CrCounterValue::new('B', window); - a.inc(3, window); - b.inc(2, window); - b.inc_actor('A', 2, window); // older value! - b.merge(a); // merges the 3 - assert_eq!(b.read(), 5); - } + #[test] + fn merges_overrides_with_larger_value() { + let window = Duration::from_secs(1); + let a = CrCounterValue::new('A', window); + let b = CrCounterValue::new('B', window); + a.inc(3, window); + b.inc(2, window); + b.inc_actor('A', 2, window); // older value! + b.merge(a); // merges the 3 + assert_eq!(b.read(), 5); + } - { - let a = CrCounterValue::new('A', window); - let b = CrCounterValue::new('B', window); - a.inc(3, window); - b.inc(2, window); - b.inc_actor('A', 5, window); // newer value! - b.merge(a); // ignores the 3 and keeps its own 5 for a - assert_eq!(b.read(), 7); - } + #[test] + fn merges_ignore_lesser_values() { + let window = Duration::from_secs(1); + let a = CrCounterValue::new('A', window); + let b = CrCounterValue::new('B', window); + a.inc(3, window); + b.inc(2, window); + b.inc_actor('A', 5, window); // newer value! + b.merge(a); // ignores the 3 and keeps its own 5 for a + assert_eq!(b.read(), 7); + } + + #[test] + fn merge_ignores_expired_sets() { + let window = Duration::from_secs(1); + let a = CrCounterValue::new('A', Duration::ZERO); + a.inc(3, Duration::ZERO); + let b = CrCounterValue::new('B', window); + b.inc(2, window); + b.merge(a); + assert_eq!(b.read(), 2); + } + + #[test] + fn merge_ignores_expired_sets_symmetric() { + let window = Duration::from_secs(1); + let a = CrCounterValue::new('A', Duration::ZERO); + a.inc(3, Duration::ZERO); + let b = CrCounterValue::new('B', window); + b.inc(2, window); + a.merge(b); + assert_eq!(a.read(), 2); + } + + #[test] + fn merge_uses_earliest_expiry() { + let later = Duration::from_secs(1); + let a = CrCounterValue::new('A', later); + let sooner = Duration::from_millis(200); + let b = CrCounterValue::new('B', sooner); + a.inc(3, later); + b.inc(2, later); + a.merge(b); + assert!(a.expiry.duration() < sooner); } } From 946dd38b70423cf90f43c42a2c4193da64146801 Mon Sep 17 00:00:00 2001 From: Alex Snaps Date: Mon, 13 May 2024 18:18:08 -0400 Subject: [PATCH 03/10] Created CrInMemoryStorage that uses the CrCounterValue --- .../storage/distributed/cr_counter_value.rs | 17 +- limitador/src/storage/distributed/mod.rs | 293 ++++++++++++++++++ limitador/tests/integration_tests.rs | 8 + 3 files changed, 317 insertions(+), 1 deletion(-) diff --git a/limitador/src/storage/distributed/cr_counter_value.rs b/limitador/src/storage/distributed/cr_counter_value.rs index 704d8c85..365c6bd0 100644 --- a/limitador/src/storage/distributed/cr_counter_value.rs +++ b/limitador/src/storage/distributed/cr_counter_value.rs @@ -5,7 +5,7 @@ use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::RwLock; use std::time::{Duration, SystemTime}; -struct CrCounterValue { +pub struct CrCounterValue { ourselves: A, value: AtomicU64, others: RwLock>, @@ -108,6 +108,10 @@ impl CrCounterValue { } } + pub fn ttl(&self) -> Duration { + self.expiry.duration() + } + fn into_inner(self) -> (SystemTime, BTreeMap) { let Self { ourselves, @@ -128,6 +132,17 @@ impl CrCounterValue { } } +impl Clone for CrCounterValue { + fn clone(&self) -> Self { + Self { + ourselves: self.ourselves.clone(), + value: AtomicU64::new(self.value.load(Ordering::SeqCst)), + others: RwLock::new(self.others.read().unwrap().clone()), + expiry: self.expiry.clone(), + } + } +} + #[cfg(test)] mod tests { use crate::storage::distributed::cr_counter_value::CrCounterValue; diff --git a/limitador/src/storage/distributed/mod.rs b/limitador/src/storage/distributed/mod.rs index 6d0e009e..9cca4922 100644 --- a/limitador/src/storage/distributed/mod.rs +++ b/limitador/src/storage/distributed/mod.rs @@ -1 +1,294 @@ mod cr_counter_value; + +use crate::counter::Counter; +use crate::limit::{Limit, Namespace}; +use crate::storage::distributed::cr_counter_value::CrCounterValue; +use crate::storage::{Authorization, CounterStorage, StorageErr}; +use moka::sync::Cache; +use std::collections::hash_map::Entry; +use std::collections::{HashMap, HashSet}; +use std::ops::Deref; +use std::sync::{Arc, RwLock}; +use std::time::{Duration, SystemTime}; + +type NamespacedLimitCounters = HashMap>; + +pub struct CrInMemoryStorage { + identifier: String, + limits_for_namespace: RwLock>>, + qualified_counters: Cache>>, +} + +impl CounterStorage for CrInMemoryStorage { + #[tracing::instrument(skip_all)] + fn is_within_limits(&self, counter: &Counter, delta: i64) -> Result { + let limits_by_namespace = self.limits_for_namespace.read().unwrap(); + + let mut value = 0; + + if counter.is_qualified() { + if let Some(counter) = self.qualified_counters.get(counter) { + value = counter.read(); + } + } else if let Some(limits) = limits_by_namespace.get(counter.limit().namespace()) { + if let Some(counter) = limits.get(counter.limit()) { + value = counter.read(); + } + } + + Ok(counter.max_value() as u64 >= value + (delta as u64)) + } + + #[tracing::instrument(skip_all)] + fn add_counter(&self, limit: &Limit) -> Result<(), StorageErr> { + if limit.variables().is_empty() { + let mut limits_by_namespace = self.limits_for_namespace.write().unwrap(); + limits_by_namespace + .entry(limit.namespace().clone()) + .or_default() + .entry(limit.clone()) + .or_insert(CrCounterValue::new( + self.identifier.clone(), + Duration::from_secs(limit.seconds()), + )); + } + Ok(()) + } + + #[tracing::instrument(skip_all)] + fn update_counter(&self, counter: &Counter, delta: i64) -> Result<(), StorageErr> { + let mut limits_by_namespace = self.limits_for_namespace.write().unwrap(); + let now = SystemTime::now(); + if counter.is_qualified() { + let value = match self.qualified_counters.get(counter) { + None => self.qualified_counters.get_with(counter.clone(), || { + Arc::new(CrCounterValue::new( + self.identifier.clone(), + Duration::from_secs(counter.seconds()), + )) + }), + Some(counter) => counter, + }; + value.inc_at(delta as u64, Duration::from_secs(counter.seconds()), now); + } else { + match limits_by_namespace.entry(counter.limit().namespace().clone()) { + Entry::Vacant(v) => { + let mut limits = HashMap::new(); + let duration = Duration::from_secs(counter.seconds()); + let counter_val = CrCounterValue::new(self.identifier.clone(), duration); + counter_val.inc_at(delta as u64, duration, now); + limits.insert(counter.limit().clone(), counter_val); + v.insert(limits); + } + Entry::Occupied(mut o) => match o.get_mut().entry(counter.limit().clone()) { + Entry::Vacant(v) => { + let duration = Duration::from_secs(counter.seconds()); + let counter_value = CrCounterValue::new(self.identifier.clone(), duration); + counter_value.inc_at(delta as u64, duration, now); + v.insert(counter_value); + } + Entry::Occupied(o) => { + o.get() + .inc_at(delta as u64, Duration::from_secs(counter.seconds()), now); + } + }, + } + } + Ok(()) + } + + #[tracing::instrument(skip_all)] + fn check_and_update( + &self, + counters: &mut Vec, + delta: i64, + load_counters: bool, + ) -> Result { + let limits_by_namespace = self.limits_for_namespace.write().unwrap(); + let mut first_limited = None; + let mut counter_values_to_update: Vec<(&CrCounterValue, u64)> = Vec::new(); + let mut qualified_counter_values_to_updated: Vec<(Arc>, u64)> = + Vec::new(); + let now = SystemTime::now(); + + let mut process_counter = + |counter: &mut Counter, value: i64, delta: i64| -> Option { + if load_counters { + let remaining = counter.max_value() - (value + delta); + counter.set_remaining(remaining); + if first_limited.is_none() && remaining < 0 { + first_limited = Some(Authorization::Limited( + counter.limit().name().map(|n| n.to_owned()), + )); + } + } + if !Self::counter_is_within_limits(counter, Some(&value), delta) { + return Some(Authorization::Limited( + counter.limit().name().map(|n| n.to_owned()), + )); + } + None + }; + + // Process simple counters + for counter in counters.iter_mut().filter(|c| !c.is_qualified()) { + let atomic_expiring_value: &CrCounterValue = limits_by_namespace + .get(counter.limit().namespace()) + .and_then(|limits| limits.get(counter.limit())) + .unwrap(); + + if let Some(limited) = + process_counter(counter, atomic_expiring_value.read() as i64, delta) + { + if !load_counters { + return Ok(limited); + } + } + counter_values_to_update.push((atomic_expiring_value, counter.seconds())); + } + + // Process qualified counters + for counter in counters.iter_mut().filter(|c| c.is_qualified()) { + let value = match self.qualified_counters.get(counter) { + None => self.qualified_counters.get_with(counter.clone(), || { + Arc::new(CrCounterValue::new( + self.identifier.clone(), + Duration::from_secs(counter.seconds()), + )) + }), + Some(counter) => counter, + }; + + if let Some(limited) = process_counter(counter, value.read() as i64, delta) { + if !load_counters { + return Ok(limited); + } + } + + qualified_counter_values_to_updated.push((value, counter.seconds())); + } + + if let Some(limited) = first_limited { + return Ok(limited); + } + + // Update counters + counter_values_to_update.iter().for_each(|(v, ttl)| { + v.inc_at(delta as u64, Duration::from_secs(*ttl), now); + }); + qualified_counter_values_to_updated + .iter() + .for_each(|(v, ttl)| { + v.inc_at(delta as u64, Duration::from_secs(*ttl), now); + }); + + Ok(Authorization::Ok) + } + + #[tracing::instrument(skip_all)] + fn get_counters(&self, limits: &HashSet) -> Result, StorageErr> { + let mut res = HashSet::new(); + + let namespaces: HashSet<&Namespace> = limits.iter().map(Limit::namespace).collect(); + let limits_by_namespace = self.limits_for_namespace.read().unwrap(); + + for namespace in namespaces { + if let Some(limits) = limits_by_namespace.get(namespace) { + for limit in limits.keys() { + if limits.contains_key(limit) { + for (counter, expiring_value) in self.counters_in_namespace(namespace) { + let mut counter_with_val = counter.clone(); + counter_with_val.set_remaining( + counter_with_val.max_value() - expiring_value.read() as i64, + ); + counter_with_val.set_expires_in(expiring_value.ttl()); + if counter_with_val.expires_in().unwrap() > Duration::ZERO { + res.insert(counter_with_val); + } + } + } + } + } + } + + for (counter, expiring_value) in self.qualified_counters.iter() { + if limits.contains(counter.limit()) { + let mut counter_with_val = counter.deref().clone(); + counter_with_val + .set_remaining(counter_with_val.max_value() - expiring_value.read() as i64); + counter_with_val.set_expires_in(expiring_value.ttl()); + if counter_with_val.expires_in().unwrap() > Duration::ZERO { + res.insert(counter_with_val); + } + } + } + + Ok(res) + } + + #[tracing::instrument(skip_all)] + fn delete_counters(&self, limits: HashSet) -> Result<(), StorageErr> { + for limit in limits { + self.delete_counters_of_limit(&limit); + } + Ok(()) + } + + #[tracing::instrument(skip_all)] + fn clear(&self) -> Result<(), StorageErr> { + self.limits_for_namespace.write().unwrap().clear(); + Ok(()) + } +} + +impl CrInMemoryStorage { + pub fn new(identifier: String, cache_size: u64) -> Self { + Self { + identifier, + limits_for_namespace: RwLock::new(HashMap::new()), + qualified_counters: Cache::new(cache_size), + } + } + + fn counters_in_namespace( + &self, + namespace: &Namespace, + ) -> HashMap> { + let mut res: HashMap> = HashMap::new(); + + if let Some(counters_by_limit) = self.limits_for_namespace.read().unwrap().get(namespace) { + for (limit, value) in counters_by_limit { + res.insert( + Counter::new(limit.clone(), HashMap::default()), + value.clone(), + ); + } + } + + for (counter, value) in self.qualified_counters.iter() { + if counter.namespace() == namespace { + res.insert(counter.deref().clone(), value.deref().clone()); + } + } + + res + } + + fn delete_counters_of_limit(&self, limit: &Limit) { + if let Some(counters_by_limit) = self + .limits_for_namespace + .write() + .unwrap() + .get_mut(limit.namespace()) + { + counters_by_limit.remove(limit); + } + } + + fn counter_is_within_limits(counter: &Counter, current_val: Option<&i64>, delta: i64) -> bool { + match current_val { + Some(current_val) => current_val + delta <= counter.max_value(), + None => counter.max_value() >= delta, + } + } +} diff --git a/limitador/tests/integration_tests.rs b/limitador/tests/integration_tests.rs index f14d8f95..623326fb 100644 --- a/limitador/tests/integration_tests.rs +++ b/limitador/tests/integration_tests.rs @@ -13,6 +13,13 @@ macro_rules! test_with_all_storage_impls { $function(&mut TestsLimiter::new_from_blocking_impl(rate_limiter)).await; } + #[tokio::test] + async fn [<$function _distributed_storage>]() { + let rate_limiter = + RateLimiter::new_with_storage(Box::new(CrInMemoryStorage::new("test_node".to_owned(), 10_000))); + $function(&mut TestsLimiter::new_from_blocking_impl(rate_limiter)).await; + } + #[tokio::test] async fn [<$function _disk_storage>]() { let dir = TempDir::new().expect("We should have a dir!"); @@ -90,6 +97,7 @@ mod test { use limitador::limit::Limit; use limitador::storage::disk::{DiskStorage, OptimizeFor}; use limitador::storage::in_memory::InMemoryStorage; + use limitador::storage::distributed::CrInMemoryStorage; use std::collections::{HashMap, HashSet}; use std::thread::sleep; use std::time::Duration; From 3b29edc9076e7d92a9a86c79d86389171f130e5f Mon Sep 17 00:00:00 2001 From: Alex Snaps Date: Tue, 14 May 2024 14:02:04 -0400 Subject: [PATCH 04/10] Basic UDP replication --- limitador-server/examples/limits.yaml | 4 +- limitador/src/lib.rs | 2 +- .../storage/distributed/cr_counter_value.rs | 22 +- limitador/src/storage/distributed/mod.rs | 189 +++++++++++++++--- limitador/src/storage/mod.rs | 12 +- limitador/tests/integration_tests.rs | 2 +- 6 files changed, 196 insertions(+), 35 deletions(-) diff --git a/limitador-server/examples/limits.yaml b/limitador-server/examples/limits.yaml index f0ea815b..afcb2b50 100644 --- a/limitador-server/examples/limits.yaml +++ b/limitador-server/examples/limits.yaml @@ -1,12 +1,10 @@ --- - namespace: test_namespace - max_value: 10 + max_value: 1000000 seconds: 60 conditions: - - "req.method == 'GET'" variables: - - user_id - namespace: test_namespace max_value: 5 diff --git a/limitador/src/lib.rs b/limitador/src/lib.rs index 59f07a67..77c2c4f5 100644 --- a/limitador/src/lib.rs +++ b/limitador/src/lib.rs @@ -197,7 +197,7 @@ use std::collections::{HashMap, HashSet}; use crate::counter::Counter; use crate::errors::LimitadorError; use crate::limit::{Limit, Namespace}; -use crate::storage::in_memory::InMemoryStorage; +use crate::storage::distributed::CrInMemoryStorage as InMemoryStorage; use crate::storage::{AsyncCounterStorage, AsyncStorage, Authorization, CounterStorage, Storage}; #[macro_use] diff --git a/limitador/src/storage/distributed/cr_counter_value.rs b/limitador/src/storage/distributed/cr_counter_value.rs index 365c6bd0..75fbaf61 100644 --- a/limitador/src/storage/distributed/cr_counter_value.rs +++ b/limitador/src/storage/distributed/cr_counter_value.rs @@ -5,6 +5,7 @@ use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::RwLock; use std::time::{Duration, SystemTime}; +#[derive(Debug)] pub struct CrCounterValue { ourselves: A, value: AtomicU64, @@ -94,7 +95,9 @@ impl CrCounterValue { } else { match others.entry(actor) { Entry::Vacant(entry) => { - entry.insert(other_value); + if other_value > 0 { + entry.insert(other_value); + } } Entry::Occupied(mut known) => { let local = known.get_mut(); @@ -112,7 +115,11 @@ impl CrCounterValue { self.expiry.duration() } - fn into_inner(self) -> (SystemTime, BTreeMap) { + pub fn expiry(&self) -> SystemTime { + self.expiry.expires_at() + } + + pub fn into_inner(self) -> (SystemTime, BTreeMap) { let Self { ourselves, value, @@ -143,6 +150,17 @@ impl Clone for CrCounterValue { } } +impl From<(SystemTime, BTreeMap)> for CrCounterValue { + fn from(value: (SystemTime, BTreeMap)) -> Self { + Self { + ourselves: A::default(), + value: Default::default(), + others: RwLock::new(value.1), + expiry: value.0.into(), + } + } +} + #[cfg(test)] mod tests { use crate::storage::distributed::cr_counter_value::CrCounterValue; diff --git a/limitador/src/storage/distributed/mod.rs b/limitador/src/storage/distributed/mod.rs index 9cca4922..70e20bf1 100644 --- a/limitador/src/storage/distributed/mod.rs +++ b/limitador/src/storage/distributed/mod.rs @@ -1,22 +1,30 @@ -mod cr_counter_value; +use std::collections::hash_map::Entry; +use std::collections::{BTreeMap, HashMap, HashSet}; +use std::net::ToSocketAddrs; +use std::ops::Deref; +use std::sync::{Arc, RwLock}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use moka::sync::Cache; +use serde::{Deserialize, Serialize}; +use tokio::net::UdpSocket; +use tokio::sync::mpsc; +use tokio::sync::mpsc::Sender; use crate::counter::Counter; use crate::limit::{Limit, Namespace}; use crate::storage::distributed::cr_counter_value::CrCounterValue; use crate::storage::{Authorization, CounterStorage, StorageErr}; -use moka::sync::Cache; -use std::collections::hash_map::Entry; -use std::collections::{HashMap, HashSet}; -use std::ops::Deref; -use std::sync::{Arc, RwLock}; -use std::time::{Duration, SystemTime}; + +mod cr_counter_value; type NamespacedLimitCounters = HashMap>; pub struct CrInMemoryStorage { identifier: String, - limits_for_namespace: RwLock>>, - qualified_counters: Cache>>, + sender: Sender, + limits_for_namespace: Arc>>>, + qualified_counters: Arc>>>, } impl CounterStorage for CrInMemoryStorage { @@ -69,14 +77,14 @@ impl CounterStorage for CrInMemoryStorage { }), Some(counter) => counter, }; - value.inc_at(delta as u64, Duration::from_secs(counter.seconds()), now); + self.increment_counter(counter.clone(), &value, delta as u64, now); } else { match limits_by_namespace.entry(counter.limit().namespace().clone()) { Entry::Vacant(v) => { let mut limits = HashMap::new(); let duration = Duration::from_secs(counter.seconds()); let counter_val = CrCounterValue::new(self.identifier.clone(), duration); - counter_val.inc_at(delta as u64, duration, now); + self.increment_counter(counter.clone(), &counter_val, delta as u64, now); limits.insert(counter.limit().clone(), counter_val); v.insert(limits); } @@ -84,12 +92,11 @@ impl CounterStorage for CrInMemoryStorage { Entry::Vacant(v) => { let duration = Duration::from_secs(counter.seconds()); let counter_value = CrCounterValue::new(self.identifier.clone(), duration); - counter_value.inc_at(delta as u64, duration, now); + self.increment_counter(counter.clone(), &counter_value, delta as u64, now); v.insert(counter_value); } Entry::Occupied(o) => { - o.get() - .inc_at(delta as u64, Duration::from_secs(counter.seconds()), now); + self.increment_counter(counter.clone(), o.get(), delta as u64, now); } }, } @@ -106,8 +113,8 @@ impl CounterStorage for CrInMemoryStorage { ) -> Result { let limits_by_namespace = self.limits_for_namespace.write().unwrap(); let mut first_limited = None; - let mut counter_values_to_update: Vec<(&CrCounterValue, u64)> = Vec::new(); - let mut qualified_counter_values_to_updated: Vec<(Arc>, u64)> = + let mut counter_values_to_update: Vec<(&CrCounterValue, Counter)> = Vec::new(); + let mut qualified_counter_values_to_updated: Vec<(Arc>, Counter)> = Vec::new(); let now = SystemTime::now(); @@ -144,7 +151,7 @@ impl CounterStorage for CrInMemoryStorage { return Ok(limited); } } - counter_values_to_update.push((atomic_expiring_value, counter.seconds())); + counter_values_to_update.push((atomic_expiring_value, counter.clone())); } // Process qualified counters @@ -165,7 +172,7 @@ impl CounterStorage for CrInMemoryStorage { } } - qualified_counter_values_to_updated.push((value, counter.seconds())); + qualified_counter_values_to_updated.push((value, counter.clone())); } if let Some(limited) = first_limited { @@ -173,13 +180,15 @@ impl CounterStorage for CrInMemoryStorage { } // Update counters - counter_values_to_update.iter().for_each(|(v, ttl)| { - v.inc_at(delta as u64, Duration::from_secs(*ttl), now); - }); + counter_values_to_update + .into_iter() + .for_each(|(v, counter)| { + self.increment_counter(counter, v, delta as u64, now); + }); qualified_counter_values_to_updated - .iter() - .for_each(|(v, ttl)| { - v.inc_at(delta as u64, Duration::from_secs(*ttl), now); + .into_iter() + .for_each(|(v, counter)| { + self.increment_counter(counter, v.deref(), delta as u64, now); }); Ok(Authorization::Ok) @@ -242,11 +251,73 @@ impl CounterStorage for CrInMemoryStorage { } impl CrInMemoryStorage { - pub fn new(identifier: String, cache_size: u64) -> Self { + pub fn new(identifier: String, cache_size: u64, local: String, broadcast: String) -> Self { + let (sender, mut rx) = mpsc::channel(1000); + + let local = local.to_socket_addrs().unwrap().next().unwrap(); + let remote = broadcast.clone(); + tokio::spawn(async move { + let sock = UdpSocket::bind(local).await.unwrap(); + sock.set_broadcast(true).unwrap(); + sock.connect(remote).await.unwrap(); + loop { + let message: CounterValueMessage = rx.recv().await.unwrap(); + let buf = postcard::to_stdvec(&message).unwrap(); + match sock.send(&buf).await { + Ok(len) => { + if len != buf.len() { + println!("Couldn't send complete message!"); + } + } + Err(err) => println!("Couldn't send update: {:?}", err), + }; + } + }); + + let limits_for_namespace = Arc::new(RwLock::new(HashMap::< + Namespace, + HashMap>, + >::new())); + let qualified_counters = Arc::new(Cache::new(cache_size)); + + { + let limits_for_namespace = limits_for_namespace.clone(); + tokio::spawn(async move { + let sock = UdpSocket::bind(broadcast).await.unwrap(); + sock.set_broadcast(true).unwrap(); + let mut buf = [0; 1024]; + loop { + let (len, addr) = sock.recv_from(&mut buf).await.unwrap(); + if addr != local { + match postcard::from_bytes::(&buf[..len]) { + Ok(message) => { + let CounterValueMessage { + counter_key, + expiry, + values, + } = message; + let counter = >::into(counter_key); + let counters = limits_for_namespace.read().unwrap(); + let limits = counters.get(counter.namespace()).unwrap(); + let value = limits.get(counter.limit()).unwrap(); + value.merge( + (UNIX_EPOCH + Duration::from_secs(expiry), values).into(), + ); + } + Err(err) => { + println!("Error from {} bytes: {:?} \n{:?}", len, err, &buf[..len]) + } + } + } + } + }); + } + Self { identifier, - limits_for_namespace: RwLock::new(HashMap::new()), - qualified_counters: Cache::new(cache_size), + sender, + limits_for_namespace, + qualified_counters, } } @@ -291,4 +362,68 @@ impl CrInMemoryStorage { None => counter.max_value() >= delta, } } + + fn increment_counter( + &self, + key: Counter, + counter: &CrCounterValue, + delta: u64, + when: SystemTime, + ) { + counter.inc_at(delta, Duration::from_secs(key.seconds()), when); + let sender = self.sender.clone(); + let counter = counter.clone(); + tokio::spawn(async move { + let (expiry, values) = counter.into_inner(); + let message = CounterValueMessage { + counter_key: key.into(), + expiry: expiry.duration_since(UNIX_EPOCH).unwrap().as_secs(), + values, + }; + sender.send(message).await + }); + } +} + +#[derive(Debug, Serialize, Deserialize)] +struct CounterValueMessage { + counter_key: CounterKey, + expiry: u64, + values: BTreeMap, +} + +#[derive(Debug, Serialize, Deserialize)] +struct CounterKey { + namespace: Namespace, + seconds: u64, + conditions: HashSet, + variables: HashSet, + vars: HashMap, +} + +impl From for CounterKey { + fn from(value: Counter) -> Self { + Self { + namespace: value.namespace().clone(), + seconds: value.seconds(), + variables: value.limit().variables(), + conditions: value.limit().conditions(), + vars: value.set_variables().clone(), + } + } +} + +impl From for Counter { + fn from(value: CounterKey) -> Self { + Self::new( + Limit::new( + value.namespace, + 0, + value.seconds, + value.conditions, + value.vars.keys(), + ), + value.vars, + ) + } } diff --git a/limitador/src/storage/mod.rs b/limitador/src/storage/mod.rs index 12e2d55e..e55eee02 100644 --- a/limitador/src/storage/mod.rs +++ b/limitador/src/storage/mod.rs @@ -3,6 +3,7 @@ use crate::limit::{Limit, Namespace}; use crate::InMemoryStorage; use async_trait::async_trait; use std::collections::{HashMap, HashSet}; +use std::env; use std::sync::RwLock; use thiserror::Error; @@ -35,9 +36,18 @@ pub struct AsyncStorage { impl Storage { pub fn new(cache_size: u64) -> Self { + let local = + env::var("LOCAL").expect("We need the env var LOCAL to be set to your local :port"); + let broadcast = env::var("BROADCAST") + .expect("We need the env var BROADCAST to be set to your broadcast :port"); Self { limits: RwLock::new(HashMap::new()), - counters: Box::new(InMemoryStorage::new(cache_size)), + counters: Box::new(InMemoryStorage::new( + local.to_owned(), + cache_size, + local, + broadcast, + )), } } diff --git a/limitador/tests/integration_tests.rs b/limitador/tests/integration_tests.rs index 623326fb..4b95a4e3 100644 --- a/limitador/tests/integration_tests.rs +++ b/limitador/tests/integration_tests.rs @@ -96,8 +96,8 @@ mod test { use crate::helpers::tests_limiter::*; use limitador::limit::Limit; use limitador::storage::disk::{DiskStorage, OptimizeFor}; - use limitador::storage::in_memory::InMemoryStorage; use limitador::storage::distributed::CrInMemoryStorage; + use limitador::storage::in_memory::InMemoryStorage; use std::collections::{HashMap, HashSet}; use std::thread::sleep; use std::time::Duration; From 4909a94a45b7d74deaac72dd0b99bf543a58f7f0 Mon Sep 17 00:00:00 2001 From: Alex Snaps Date: Tue, 14 May 2024 14:24:30 -0400 Subject: [PATCH 05/10] Fix off by 1_000_000 --- limitador/src/storage/distributed/cr_counter_value.rs | 2 +- limitador/tests/integration_tests.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/limitador/src/storage/distributed/cr_counter_value.rs b/limitador/src/storage/distributed/cr_counter_value.rs index 75fbaf61..2c38e9a7 100644 --- a/limitador/src/storage/distributed/cr_counter_value.rs +++ b/limitador/src/storage/distributed/cr_counter_value.rs @@ -45,7 +45,7 @@ impl CrCounterValue { pub fn inc_at(&self, increment: u64, time_window: Duration, when: SystemTime) { if self .expiry - .update_if_expired(time_window.as_micros() as u64, when) + .update_if_expired(time_window.as_secs(), when) { self.value.store(increment, Ordering::SeqCst); } else { diff --git a/limitador/tests/integration_tests.rs b/limitador/tests/integration_tests.rs index 4b95a4e3..90623d9f 100644 --- a/limitador/tests/integration_tests.rs +++ b/limitador/tests/integration_tests.rs @@ -16,7 +16,7 @@ macro_rules! test_with_all_storage_impls { #[tokio::test] async fn [<$function _distributed_storage>]() { let rate_limiter = - RateLimiter::new_with_storage(Box::new(CrInMemoryStorage::new("test_node".to_owned(), 10_000))); + RateLimiter::new_with_storage(Box::new(CrInMemoryStorage::new("test_node".to_owned(), 10_000, "127.0.0.1:19876".to_owned(), "127.0.0.255:19876".to_owned()))); $function(&mut TestsLimiter::new_from_blocking_impl(rate_limiter)).await; } From a70be8076914013fb4df1d3aee204b2b49a05d8d Mon Sep 17 00:00:00 2001 From: Alex Snaps Date: Tue, 14 May 2024 14:46:17 -0400 Subject: [PATCH 06/10] Wired distributed storage in --- limitador-server/src/config.rs | 9 +++ limitador-server/src/main.rs | 63 ++++++++++++++++++- limitador/src/lib.rs | 2 +- .../storage/distributed/cr_counter_value.rs | 5 +- limitador/src/storage/distributed/mod.rs | 40 ++++++------ limitador/src/storage/mod.rs | 14 +---- 6 files changed, 94 insertions(+), 39 deletions(-) diff --git a/limitador-server/src/config.rs b/limitador-server/src/config.rs index 949bf446..e71bc98b 100644 --- a/limitador-server/src/config.rs +++ b/limitador-server/src/config.rs @@ -140,6 +140,7 @@ pub enum StorageConfiguration { InMemory(InMemoryStorageConfiguration), Disk(DiskStorageConfiguration), Redis(RedisStorageConfiguration), + Distributed(DistributedStorageConfiguration), } #[derive(PartialEq, Eq, Debug)] @@ -147,6 +148,14 @@ pub struct InMemoryStorageConfiguration { pub cache_size: Option, } +#[derive(PartialEq, Eq, Debug)] +pub struct DistributedStorageConfiguration { + pub name: String, + pub cache_size: Option, + pub local: String, + pub broadcast: String, +} + #[derive(PartialEq, Eq, Debug)] pub struct DiskStorageConfiguration { pub path: String, diff --git a/limitador-server/src/main.rs b/limitador-server/src/main.rs index 8962d5ef..5aabb526 100644 --- a/limitador-server/src/main.rs +++ b/limitador-server/src/main.rs @@ -6,8 +6,9 @@ extern crate log; extern crate clap; use crate::config::{ - Configuration, DiskStorageConfiguration, InMemoryStorageConfiguration, - RedisStorageCacheConfiguration, RedisStorageConfiguration, StorageConfiguration, + Configuration, DiskStorageConfiguration, DistributedStorageConfiguration, + InMemoryStorageConfiguration, RedisStorageCacheConfiguration, RedisStorageConfiguration, + StorageConfiguration, }; use crate::envoy_rls::server::{run_envoy_rls_server, RateLimitHeaders}; use crate::http_api::server::run_http_server; @@ -23,6 +24,7 @@ use limitador::storage::redis::{ AsyncRedisStorage, CachedRedisStorage, CachedRedisStorageBuilder, DEFAULT_BATCH_SIZE, DEFAULT_FLUSHING_PERIOD_SEC, DEFAULT_MAX_CACHED_COUNTERS, DEFAULT_RESPONSE_TIMEOUT_MS, }; +use limitador::storage::DistributedInMemoryStorage; use limitador::storage::{AsyncCounterStorage, AsyncStorage, Storage}; use limitador::{ storage, AsyncRateLimiter, AsyncRateLimiterBuilder, RateLimiter, RateLimiterBuilder, @@ -83,6 +85,7 @@ impl Limiter { let rate_limiter = match config.storage { StorageConfiguration::Redis(cfg) => Self::redis_limiter(cfg).await, StorageConfiguration::InMemory(cfg) => Self::in_memory_limiter(cfg), + StorageConfiguration::Distributed(cfg) => Self::distributed_limiter(cfg), StorageConfiguration::Disk(cfg) => Self::disk_limiter(cfg), }; @@ -154,6 +157,19 @@ impl Limiter { Self::Blocking(rate_limiter_builder.build()) } + fn distributed_limiter(cfg: DistributedStorageConfiguration) -> Self { + let storage = DistributedInMemoryStorage::new( + cfg.name, + cfg.cache_size.or_else(guess_cache_size).unwrap(), + cfg.local, + cfg.broadcast, + ); + let rate_limiter_builder = + RateLimiterBuilder::with_storage(Storage::with_counter_storage(Box::new(storage))); + + Self::Blocking(rate_limiter_builder.build()) + } + pub async fn load_limits_from_file>( &self, path: &P, @@ -563,6 +579,41 @@ fn create_config() -> (Configuration, &'static str) { .display_order(6) .help("Timeout for Redis commands in milliseconds"), ), + ) + .subcommand( + Command::new("distributed") + .about("Replicates CRDT-based counters across multiple Limitador servers") + .display_order(5) + .arg( + Arg::new("NAME") + .action(ArgAction::Set) + .required(true) + .display_order(2) + .help("Unique name to identify this Limitador instance"), + ) + .arg( + Arg::new("LOCAL") + .action(ArgAction::Set) + .required(true) + .display_order(2) + .help("Local IP:PORT to send datagrams from"), + ) + .arg( + Arg::new("BROADCAST") + .action(ArgAction::Set) + .required(true) + .display_order(3) + .help("Broadcast IP:PORT to send datagrams to"), + ) + .arg( + Arg::new("CACHE_SIZE") + .long("cache") + .short('c') + .action(ArgAction::Set) + .value_parser(value_parser!(u64)) + .display_order(4) + .help("Sets the size of the cache for 'qualified counters'"), + ), ); let matches = cmdline.get_matches(); @@ -630,6 +681,14 @@ fn create_config() -> (Configuration, &'static str) { Some(("memory", sub)) => StorageConfiguration::InMemory(InMemoryStorageConfiguration { cache_size: sub.get_one::("CACHE_SIZE").copied(), }), + Some(("distributed", sub)) => { + StorageConfiguration::Distributed(DistributedStorageConfiguration { + name: sub.get_one::("NAME").unwrap().to_owned(), + local: sub.get_one::("LOCAL").unwrap().to_owned(), + broadcast: sub.get_one::("BROADCAST").unwrap().to_owned(), + cache_size: sub.get_one::("CACHE_SIZE").copied(), + }) + } None => storage_config_from_env(), _ => unreachable!("Some storage wasn't configured!"), }; diff --git a/limitador/src/lib.rs b/limitador/src/lib.rs index 77c2c4f5..59f07a67 100644 --- a/limitador/src/lib.rs +++ b/limitador/src/lib.rs @@ -197,7 +197,7 @@ use std::collections::{HashMap, HashSet}; use crate::counter::Counter; use crate::errors::LimitadorError; use crate::limit::{Limit, Namespace}; -use crate::storage::distributed::CrInMemoryStorage as InMemoryStorage; +use crate::storage::in_memory::InMemoryStorage; use crate::storage::{AsyncCounterStorage, AsyncStorage, Authorization, CounterStorage, Storage}; #[macro_use] diff --git a/limitador/src/storage/distributed/cr_counter_value.rs b/limitador/src/storage/distributed/cr_counter_value.rs index 2c38e9a7..eb6fc1fb 100644 --- a/limitador/src/storage/distributed/cr_counter_value.rs +++ b/limitador/src/storage/distributed/cr_counter_value.rs @@ -43,10 +43,7 @@ impl CrCounterValue { } pub fn inc_at(&self, increment: u64, time_window: Duration, when: SystemTime) { - if self - .expiry - .update_if_expired(time_window.as_secs(), when) - { + if self.expiry.update_if_expired(time_window.as_secs(), when) { self.value.store(increment, Ordering::SeqCst); } else { self.value.fetch_add(increment, Ordering::SeqCst); diff --git a/limitador/src/storage/distributed/mod.rs b/limitador/src/storage/distributed/mod.rs index 70e20bf1..12289db6 100644 --- a/limitador/src/storage/distributed/mod.rs +++ b/limitador/src/storage/distributed/mod.rs @@ -29,7 +29,7 @@ pub struct CrInMemoryStorage { impl CounterStorage for CrInMemoryStorage { #[tracing::instrument(skip_all)] - fn is_within_limits(&self, counter: &Counter, delta: i64) -> Result { + fn is_within_limits(&self, counter: &Counter, delta: u64) -> Result { let limits_by_namespace = self.limits_for_namespace.read().unwrap(); let mut value = 0; @@ -44,7 +44,7 @@ impl CounterStorage for CrInMemoryStorage { } } - Ok(counter.max_value() as u64 >= value + (delta as u64)) + Ok(counter.max_value() >= value + delta) } #[tracing::instrument(skip_all)] @@ -64,7 +64,7 @@ impl CounterStorage for CrInMemoryStorage { } #[tracing::instrument(skip_all)] - fn update_counter(&self, counter: &Counter, delta: i64) -> Result<(), StorageErr> { + fn update_counter(&self, counter: &Counter, delta: u64) -> Result<(), StorageErr> { let mut limits_by_namespace = self.limits_for_namespace.write().unwrap(); let now = SystemTime::now(); if counter.is_qualified() { @@ -77,14 +77,14 @@ impl CounterStorage for CrInMemoryStorage { }), Some(counter) => counter, }; - self.increment_counter(counter.clone(), &value, delta as u64, now); + self.increment_counter(counter.clone(), &value, delta, now); } else { match limits_by_namespace.entry(counter.limit().namespace().clone()) { Entry::Vacant(v) => { let mut limits = HashMap::new(); let duration = Duration::from_secs(counter.seconds()); let counter_val = CrCounterValue::new(self.identifier.clone(), duration); - self.increment_counter(counter.clone(), &counter_val, delta as u64, now); + self.increment_counter(counter.clone(), &counter_val, delta, now); limits.insert(counter.limit().clone(), counter_val); v.insert(limits); } @@ -92,11 +92,11 @@ impl CounterStorage for CrInMemoryStorage { Entry::Vacant(v) => { let duration = Duration::from_secs(counter.seconds()); let counter_value = CrCounterValue::new(self.identifier.clone(), duration); - self.increment_counter(counter.clone(), &counter_value, delta as u64, now); + self.increment_counter(counter.clone(), &counter_value, delta, now); v.insert(counter_value); } Entry::Occupied(o) => { - self.increment_counter(counter.clone(), o.get(), delta as u64, now); + self.increment_counter(counter.clone(), o.get(), delta, now); } }, } @@ -108,7 +108,7 @@ impl CounterStorage for CrInMemoryStorage { fn check_and_update( &self, counters: &mut Vec, - delta: i64, + delta: u64, load_counters: bool, ) -> Result { let limits_by_namespace = self.limits_for_namespace.write().unwrap(); @@ -119,11 +119,11 @@ impl CounterStorage for CrInMemoryStorage { let now = SystemTime::now(); let mut process_counter = - |counter: &mut Counter, value: i64, delta: i64| -> Option { + |counter: &mut Counter, value: u64, delta: u64| -> Option { if load_counters { - let remaining = counter.max_value() - (value + delta); - counter.set_remaining(remaining); - if first_limited.is_none() && remaining < 0 { + let remaining = counter.max_value().checked_sub(value + delta); + counter.set_remaining(remaining.unwrap_or(0)); + if first_limited.is_none() && remaining.is_none() { first_limited = Some(Authorization::Limited( counter.limit().name().map(|n| n.to_owned()), )); @@ -144,9 +144,7 @@ impl CounterStorage for CrInMemoryStorage { .and_then(|limits| limits.get(counter.limit())) .unwrap(); - if let Some(limited) = - process_counter(counter, atomic_expiring_value.read() as i64, delta) - { + if let Some(limited) = process_counter(counter, atomic_expiring_value.read(), delta) { if !load_counters { return Ok(limited); } @@ -166,7 +164,7 @@ impl CounterStorage for CrInMemoryStorage { Some(counter) => counter, }; - if let Some(limited) = process_counter(counter, value.read() as i64, delta) { + if let Some(limited) = process_counter(counter, value.read(), delta) { if !load_counters { return Ok(limited); } @@ -183,12 +181,12 @@ impl CounterStorage for CrInMemoryStorage { counter_values_to_update .into_iter() .for_each(|(v, counter)| { - self.increment_counter(counter, v, delta as u64, now); + self.increment_counter(counter, v, delta, now); }); qualified_counter_values_to_updated .into_iter() .for_each(|(v, counter)| { - self.increment_counter(counter, v.deref(), delta as u64, now); + self.increment_counter(counter, v.deref(), delta, now); }); Ok(Authorization::Ok) @@ -208,7 +206,7 @@ impl CounterStorage for CrInMemoryStorage { for (counter, expiring_value) in self.counters_in_namespace(namespace) { let mut counter_with_val = counter.clone(); counter_with_val.set_remaining( - counter_with_val.max_value() - expiring_value.read() as i64, + counter_with_val.max_value() - expiring_value.read(), ); counter_with_val.set_expires_in(expiring_value.ttl()); if counter_with_val.expires_in().unwrap() > Duration::ZERO { @@ -224,7 +222,7 @@ impl CounterStorage for CrInMemoryStorage { if limits.contains(counter.limit()) { let mut counter_with_val = counter.deref().clone(); counter_with_val - .set_remaining(counter_with_val.max_value() - expiring_value.read() as i64); + .set_remaining(counter_with_val.max_value() - expiring_value.read()); counter_with_val.set_expires_in(expiring_value.ttl()); if counter_with_val.expires_in().unwrap() > Duration::ZERO { res.insert(counter_with_val); @@ -356,7 +354,7 @@ impl CrInMemoryStorage { } } - fn counter_is_within_limits(counter: &Counter, current_val: Option<&i64>, delta: i64) -> bool { + fn counter_is_within_limits(counter: &Counter, current_val: Option<&u64>, delta: u64) -> bool { match current_val { Some(current_val) => current_val + delta <= counter.max_value(), None => counter.max_value() >= delta, diff --git a/limitador/src/storage/mod.rs b/limitador/src/storage/mod.rs index e55eee02..a1d607c8 100644 --- a/limitador/src/storage/mod.rs +++ b/limitador/src/storage/mod.rs @@ -3,7 +3,6 @@ use crate::limit::{Limit, Namespace}; use crate::InMemoryStorage; use async_trait::async_trait; use std::collections::{HashMap, HashSet}; -use std::env; use std::sync::RwLock; use thiserror::Error; @@ -12,6 +11,8 @@ pub mod disk; pub mod distributed; pub mod in_memory; +pub use crate::storage::distributed::CrInMemoryStorage as DistributedInMemoryStorage; + #[cfg(feature = "redis_storage")] pub mod redis; @@ -36,18 +37,9 @@ pub struct AsyncStorage { impl Storage { pub fn new(cache_size: u64) -> Self { - let local = - env::var("LOCAL").expect("We need the env var LOCAL to be set to your local :port"); - let broadcast = env::var("BROADCAST") - .expect("We need the env var BROADCAST to be set to your broadcast :port"); Self { limits: RwLock::new(HashMap::new()), - counters: Box::new(InMemoryStorage::new( - local.to_owned(), - cache_size, - local, - broadcast, - )), + counters: Box::new(InMemoryStorage::new(cache_size)), } } From fb6f0164aa92f8216581e058d24232275ba8a290 Mon Sep 17 00:00:00 2001 From: Alex Snaps Date: Tue, 14 May 2024 14:53:35 -0400 Subject: [PATCH 07/10] Ignore distributed storage when targetting wasm32 --- limitador/src/storage/atomic_expiring_value.rs | 2 ++ limitador/src/storage/mod.rs | 2 ++ 2 files changed, 4 insertions(+) diff --git a/limitador/src/storage/atomic_expiring_value.rs b/limitador/src/storage/atomic_expiring_value.rs index c193bedc..8b00c7bd 100644 --- a/limitador/src/storage/atomic_expiring_value.rs +++ b/limitador/src/storage/atomic_expiring_value.rs @@ -134,10 +134,12 @@ impl AtomicExpiryTime { Ok(()) } + #[allow(dead_code)] pub fn into_inner(self) -> SystemTime { self.expires_at() } + #[allow(dead_code)] pub fn expires_at(&self) -> SystemTime { SystemTime::UNIX_EPOCH + Duration::from_micros(self.expiry.load(Ordering::SeqCst)) } diff --git a/limitador/src/storage/mod.rs b/limitador/src/storage/mod.rs index a1d607c8..f51979b9 100644 --- a/limitador/src/storage/mod.rs +++ b/limitador/src/storage/mod.rs @@ -8,9 +8,11 @@ use thiserror::Error; #[cfg(feature = "disk_storage")] pub mod disk; +#[cfg(not(target_arch = "wasm32"))] pub mod distributed; pub mod in_memory; +#[cfg(not(target_arch = "wasm32"))] pub use crate::storage::distributed::CrInMemoryStorage as DistributedInMemoryStorage; #[cfg(feature = "redis_storage")] From 99b9e872b9d35f8e71241f019b39356311785e0d Mon Sep 17 00:00:00 2001 From: Alex Snaps Date: Wed, 15 May 2024 11:01:50 -0400 Subject: [PATCH 08/10] Accept replicated qualified counters --- limitador/src/storage/distributed/mod.rs | 25 +++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/limitador/src/storage/distributed/mod.rs b/limitador/src/storage/distributed/mod.rs index 12289db6..5732a322 100644 --- a/limitador/src/storage/distributed/mod.rs +++ b/limitador/src/storage/distributed/mod.rs @@ -276,10 +276,12 @@ impl CrInMemoryStorage { Namespace, HashMap>, >::new())); - let qualified_counters = Arc::new(Cache::new(cache_size)); + let qualified_counters: Arc>>> = + Arc::new(Cache::new(cache_size)); { let limits_for_namespace = limits_for_namespace.clone(); + let qualified_counters = qualified_counters.clone(); tokio::spawn(async move { let sock = UdpSocket::bind(broadcast).await.unwrap(); sock.set_broadcast(true).unwrap(); @@ -295,12 +297,21 @@ impl CrInMemoryStorage { values, } = message; let counter = >::into(counter_key); - let counters = limits_for_namespace.read().unwrap(); - let limits = counters.get(counter.namespace()).unwrap(); - let value = limits.get(counter.limit()).unwrap(); - value.merge( - (UNIX_EPOCH + Duration::from_secs(expiry), values).into(), - ); + if counter.is_qualified() { + if let Some(counter) = qualified_counters.get(&counter) { + counter.merge( + (UNIX_EPOCH + Duration::from_secs(expiry), values) + .into(), + ); + } + } else { + let counters = limits_for_namespace.read().unwrap(); + let limits = counters.get(counter.namespace()).unwrap(); + let value = limits.get(counter.limit()).unwrap(); + value.merge( + (UNIX_EPOCH + Duration::from_secs(expiry), values).into(), + ); + }; } Err(err) => { println!("Error from {} bytes: {:?} \n{:?}", len, err, &buf[..len]) From 020e46835896aa0e2160e95bcae7b66476d7f180 Mon Sep 17 00:00:00 2001 From: Alex Snaps Date: Wed, 15 May 2024 11:10:46 -0400 Subject: [PATCH 09/10] Add distributed_storage feature --- limitador-server/Cargo.toml | 3 + limitador-server/src/config.rs | 2 + limitador-server/src/main.rs | 83 +++++++++++++++------------- limitador/Cargo.toml | 1 + limitador/src/storage/mod.rs | 4 +- limitador/tests/integration_tests.rs | 2 + 6 files changed, 55 insertions(+), 40 deletions(-) diff --git a/limitador-server/Cargo.toml b/limitador-server/Cargo.toml index f7b2c146..4acdb31d 100644 --- a/limitador-server/Cargo.toml +++ b/limitador-server/Cargo.toml @@ -12,6 +12,9 @@ documentation = "https://kuadrant.io/docs/limitador" readme = "README.md" edition = "2021" +[features] +distributed_storage = ["limitador/distributed_storage"] + # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] diff --git a/limitador-server/src/config.rs b/limitador-server/src/config.rs index e71bc98b..dc4ef59c 100644 --- a/limitador-server/src/config.rs +++ b/limitador-server/src/config.rs @@ -140,6 +140,7 @@ pub enum StorageConfiguration { InMemory(InMemoryStorageConfiguration), Disk(DiskStorageConfiguration), Redis(RedisStorageConfiguration), + #[cfg(feature = "distributed_storage")] Distributed(DistributedStorageConfiguration), } @@ -149,6 +150,7 @@ pub struct InMemoryStorageConfiguration { } #[derive(PartialEq, Eq, Debug)] +#[cfg(feature = "distributed_storage")] pub struct DistributedStorageConfiguration { pub name: String, pub cache_size: Option, diff --git a/limitador-server/src/main.rs b/limitador-server/src/main.rs index 5aabb526..f04e4fa9 100644 --- a/limitador-server/src/main.rs +++ b/limitador-server/src/main.rs @@ -5,10 +5,11 @@ extern crate log; extern crate clap; +#[cfg(feature = "distributed_storage")] +use crate::config::DistributedStorageConfiguration; use crate::config::{ - Configuration, DiskStorageConfiguration, DistributedStorageConfiguration, - InMemoryStorageConfiguration, RedisStorageCacheConfiguration, RedisStorageConfiguration, - StorageConfiguration, + Configuration, DiskStorageConfiguration, InMemoryStorageConfiguration, + RedisStorageCacheConfiguration, RedisStorageConfiguration, StorageConfiguration, }; use crate::envoy_rls::server::{run_envoy_rls_server, RateLimitHeaders}; use crate::http_api::server::run_http_server; @@ -24,6 +25,7 @@ use limitador::storage::redis::{ AsyncRedisStorage, CachedRedisStorage, CachedRedisStorageBuilder, DEFAULT_BATCH_SIZE, DEFAULT_FLUSHING_PERIOD_SEC, DEFAULT_MAX_CACHED_COUNTERS, DEFAULT_RESPONSE_TIMEOUT_MS, }; +#[cfg(feature = "distributed_storage")] use limitador::storage::DistributedInMemoryStorage; use limitador::storage::{AsyncCounterStorage, AsyncStorage, Storage}; use limitador::{ @@ -85,6 +87,7 @@ impl Limiter { let rate_limiter = match config.storage { StorageConfiguration::Redis(cfg) => Self::redis_limiter(cfg).await, StorageConfiguration::InMemory(cfg) => Self::in_memory_limiter(cfg), + #[cfg(feature = "distributed_storage")] StorageConfiguration::Distributed(cfg) => Self::distributed_limiter(cfg), StorageConfiguration::Disk(cfg) => Self::disk_limiter(cfg), }; @@ -157,6 +160,7 @@ impl Limiter { Self::Blocking(rate_limiter_builder.build()) } + #[cfg(feature = "distributed_storage")] fn distributed_limiter(cfg: DistributedStorageConfiguration) -> Self { let storage = DistributedInMemoryStorage::new( cfg.name, @@ -579,43 +583,45 @@ fn create_config() -> (Configuration, &'static str) { .display_order(6) .help("Timeout for Redis commands in milliseconds"), ), - ) - .subcommand( - Command::new("distributed") - .about("Replicates CRDT-based counters across multiple Limitador servers") - .display_order(5) - .arg( - Arg::new("NAME") - .action(ArgAction::Set) - .required(true) - .display_order(2) - .help("Unique name to identify this Limitador instance"), - ) - .arg( - Arg::new("LOCAL") - .action(ArgAction::Set) - .required(true) - .display_order(2) - .help("Local IP:PORT to send datagrams from"), - ) - .arg( - Arg::new("BROADCAST") - .action(ArgAction::Set) - .required(true) - .display_order(3) - .help("Broadcast IP:PORT to send datagrams to"), - ) - .arg( - Arg::new("CACHE_SIZE") - .long("cache") - .short('c') - .action(ArgAction::Set) - .value_parser(value_parser!(u64)) - .display_order(4) - .help("Sets the size of the cache for 'qualified counters'"), - ), ); + #[cfg(feature = "distributed_storage")] + let cmdline = cmdline.subcommand( + Command::new("distributed") + .about("Replicates CRDT-based counters across multiple Limitador servers") + .display_order(5) + .arg( + Arg::new("NAME") + .action(ArgAction::Set) + .required(true) + .display_order(2) + .help("Unique name to identify this Limitador instance"), + ) + .arg( + Arg::new("LOCAL") + .action(ArgAction::Set) + .required(true) + .display_order(2) + .help("Local IP:PORT to send datagrams from"), + ) + .arg( + Arg::new("BROADCAST") + .action(ArgAction::Set) + .required(true) + .display_order(3) + .help("Broadcast IP:PORT to send datagrams to"), + ) + .arg( + Arg::new("CACHE_SIZE") + .long("cache") + .short('c') + .action(ArgAction::Set) + .value_parser(value_parser!(u64)) + .display_order(4) + .help("Sets the size of the cache for 'qualified counters'"), + ), + ); + let matches = cmdline.get_matches(); let limits_file = matches.get_one::("LIMITS_FILE").unwrap(); @@ -681,6 +687,7 @@ fn create_config() -> (Configuration, &'static str) { Some(("memory", sub)) => StorageConfiguration::InMemory(InMemoryStorageConfiguration { cache_size: sub.get_one::("CACHE_SIZE").copied(), }), + #[cfg(feature = "distributed_storage")] Some(("distributed", sub)) => { StorageConfiguration::Distributed(DistributedStorageConfiguration { name: sub.get_one::("NAME").unwrap().to_owned(), diff --git a/limitador/Cargo.toml b/limitador/Cargo.toml index 456b3309..8f0a681b 100644 --- a/limitador/Cargo.toml +++ b/limitador/Cargo.toml @@ -15,6 +15,7 @@ edition = "2021" [features] default = ["disk_storage", "redis_storage"] disk_storage = ["rocksdb"] +distributed_storage = [] redis_storage = ["redis", "r2d2", "tokio"] lenient_conditions = [] diff --git a/limitador/src/storage/mod.rs b/limitador/src/storage/mod.rs index f51979b9..22abd33a 100644 --- a/limitador/src/storage/mod.rs +++ b/limitador/src/storage/mod.rs @@ -8,11 +8,11 @@ use thiserror::Error; #[cfg(feature = "disk_storage")] pub mod disk; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(feature = "distributed_storage")] pub mod distributed; pub mod in_memory; -#[cfg(not(target_arch = "wasm32"))] +#[cfg(feature = "distributed_storage")] pub use crate::storage::distributed::CrInMemoryStorage as DistributedInMemoryStorage; #[cfg(feature = "redis_storage")] diff --git a/limitador/tests/integration_tests.rs b/limitador/tests/integration_tests.rs index 90623d9f..2b1e9afe 100644 --- a/limitador/tests/integration_tests.rs +++ b/limitador/tests/integration_tests.rs @@ -13,6 +13,7 @@ macro_rules! test_with_all_storage_impls { $function(&mut TestsLimiter::new_from_blocking_impl(rate_limiter)).await; } + #[cfg(feature = "distributed_storage")] #[tokio::test] async fn [<$function _distributed_storage>]() { let rate_limiter = @@ -96,6 +97,7 @@ mod test { use crate::helpers::tests_limiter::*; use limitador::limit::Limit; use limitador::storage::disk::{DiskStorage, OptimizeFor}; + #[cfg(feature = "distributed_storage")] use limitador::storage::distributed::CrInMemoryStorage; use limitador::storage::in_memory::InMemoryStorage; use std::collections::{HashMap, HashSet}; From 907dc582a2fb67fbad6932ced26cacef4bed6e83 Mon Sep 17 00:00:00 2001 From: Alex Snaps Date: Wed, 15 May 2024 13:54:33 -0400 Subject: [PATCH 10/10] Readded features in full_version --- limitador-server/build.rs | 9 +++++++++ limitador-server/src/main.rs | 5 +++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/limitador-server/build.rs b/limitador-server/build.rs index f7a0d811..9700c93b 100644 --- a/limitador-server/build.rs +++ b/limitador-server/build.rs @@ -6,6 +6,7 @@ use std::process::Command; fn main() -> Result<(), Box> { set_git_hash("LIMITADOR_GIT_HASH"); set_profile("LIMITADOR_PROFILE"); + set_features("LIMITADOR_FEATURES"); generate_protobuf() } @@ -31,6 +32,14 @@ fn set_profile(env: &str) { } } +fn set_features(env: &str) { + let mut features = vec![]; + if cfg!(feature = "distributed_storage") { + features.push("+distributed"); + } + println!("cargo:rustc-env={env}={features:?}"); +} + fn set_git_hash(env: &str) { let git_sha = Command::new("/usr/bin/git") .args(["rev-parse", "HEAD"]) diff --git a/limitador-server/src/main.rs b/limitador-server/src/main.rs index f04e4fa9..a4ced111 100644 --- a/limitador-server/src/main.rs +++ b/limitador-server/src/main.rs @@ -61,6 +61,7 @@ pub mod prometheus_metrics; const LIMITADOR_VERSION: &str = env!("CARGO_PKG_VERSION"); const LIMITADOR_PROFILE: &str = env!("LIMITADOR_PROFILE"); +const LIMITADOR_FEATURES: &str = env!("LIMITADOR_FEATURES"); const LIMITADOR_HEADER: &str = "Limitador Server"; #[derive(Error, Debug)] @@ -370,12 +371,12 @@ async fn main() -> Result<(), Box> { fn create_config() -> (Configuration, &'static str) { let full_version: &'static str = formatcp!( - "v{} ({}) {}", + "v{} ({}) {} {}", LIMITADOR_VERSION, env!("LIMITADOR_GIT_HASH"), + LIMITADOR_FEATURES, LIMITADOR_PROFILE, ); - // wire args based of defaults let limit_arg = Arg::new("LIMITS_FILE") .action(ArgAction::Set)