From 8aec37803a184f93e5eb0bcdeaf884d6dcdd5aee Mon Sep 17 00:00:00 2001 From: Erin Power <xampprocky@gmail.com> Date: Mon, 14 Oct 2024 18:06:32 +0200 Subject: [PATCH] perf: Use papaya over dashmap --- Cargo.lock | 118 ++++++++++ Cargo.toml | 2 + src/collections/ttl.rs | 382 +++++-------------------------- src/components/proxy/sessions.rs | 28 ++- src/config.rs | 44 ++-- src/filters/local_rate_limit.rs | 39 ++-- src/net/cluster.rs | 167 ++++++-------- src/net/endpoint/address.rs | 2 +- src/net/phoenix.rs | 7 +- 9 files changed, 305 insertions(+), 484 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ee5554877a..4e23580dc8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -187,6 +187,16 @@ dependencies = [ "syn 2.0.72", ] +[[package]] +name = "atomic-wait" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a55b94919229f2c42292fd71ffa4b75e83193bffdd77b1e858cd55fd2d0b0ea8" +dependencies = [ + "libc", + "windows-sys 0.42.0", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -2019,6 +2029,17 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "papaya" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d17fbf29d99ed1d2a1fecdb37d08898790965c85fd2634ba4023ab9710089059" +dependencies = [ + "atomic-wait", + "seize", + "serde", +] + [[package]] name = "parking" version = "2.2.0" @@ -2370,6 +2391,7 @@ dependencies = [ "notify", "num_cpus", "once_cell", + "papaya", "parking_lot", "pprof2", "pretty_assertions", @@ -2383,6 +2405,7 @@ dependencies = [ "regex", "schemars", "seahash", + "seize", "serde", "serde_json", "serde_regex", @@ -2421,6 +2444,38 @@ dependencies = [ "syn 2.0.72", ] +[[package]] +name = "quilkin-profiling" +version = "0.10.0-dev" +dependencies = [ + "arc-swap", + "async-stream", + "cached", + "enum-map", + "eyre", + "fixedstr", + "futures", + "once_cell", + "parking_lot", + "prometheus", + "prost", + "prost-types", + "quilkin-proto", + "rand", + "schemars", + "serde", + "serde_json", + "thiserror", + "tokio", + "tokio-stream", + "tonic", + "tracing", + "tracing-futures", + "tryhard", + "url", + "uuid", +] + [[package]] name = "quilkin-proto" version = "0.10.0-dev" @@ -2816,6 +2871,12 @@ dependencies = [ "libc", ] +[[package]] +name = "seize" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d659fa6f19e82a52ab8d3fff3c380bd8cc16462eaea411395618a38760eb85bc" + [[package]] name = "serde" version = "1.0.205" @@ -3728,6 +3789,21 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-sys" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + [[package]] name = "windows-sys" version = "0.48.0" @@ -3786,6 +3862,12 @@ dependencies = [ "windows_x86_64_msvc 0.52.6", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" @@ -3798,6 +3880,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -3810,6 +3898,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -3828,6 +3922,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -3840,6 +3940,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -3852,6 +3958,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" @@ -3864,6 +3976,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" diff --git a/Cargo.toml b/Cargo.toml index 532dd400c4..9baf6596b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -150,6 +150,8 @@ cfg-if = "1.0.0" libflate = "2.0.0" form_urlencoded = "1.2.1" gxhash = "3.4.1" +papaya = { version = "0.1.3", features = ["serde"] } +seize = "0.4.5" [dependencies.hyper-util] version = "0.1" diff --git a/src/collections/ttl.rs b/src/collections/ttl.rs index 526b147e91..7dadc5b315 100644 --- a/src/collections/ttl.rs +++ b/src/collections/ttl.rs @@ -14,9 +14,6 @@ * limitations under the License. */ -use dashmap::mapref::entry::Entry as DashMapEntry; -use dashmap::mapref::one::{Ref, RefMut}; -use dashmap::DashMap; use tracing::warn; use std::hash::Hash; @@ -25,7 +22,7 @@ use std::sync::Arc; use std::time::Duration; use tokio::sync::oneshot::{channel, Receiver, Sender}; -pub use dashmap::try_result::TryResult; +type HashMap<K, V, S = gxhash::GxBuildHasher> = papaya::HashMap<K, V, S>; // Clippy isn't recognizing that these imports are used conditionally. #[allow(unused_imports)] @@ -92,7 +89,7 @@ impl<V> std::ops::Deref for Value<V> { /// Map contains the hash map implementation. struct Map<K, V> { - inner: DashMap<K, Value<V>>, + inner: HashMap<K, Value<V>, gxhash::GxBuildHasher>, ttl: Duration, clock: Clock, shutdown_tx: Option<Sender<()>>, @@ -132,15 +129,19 @@ where V: Send + Sync + 'static, { pub fn new(ttl: Duration, poll_interval: Duration) -> Self { - Self::initialize(DashMap::new(), ttl, poll_interval) + Self::initialize(<_>::default(), ttl, poll_interval) } #[allow(dead_code)] pub fn with_capacity(ttl: Duration, poll_interval: Duration, capacity: usize) -> Self { - Self::initialize(DashMap::with_capacity(capacity), ttl, poll_interval) + Self::initialize( + HashMap::with_capacity_and_hasher(capacity, <_>::default()), + ttl, + poll_interval, + ) } - fn initialize(inner: DashMap<K, Value<V>>, ttl: Duration, poll_interval: Duration) -> Self { + fn initialize(inner: HashMap<K, Value<V>>, ttl: Duration, poll_interval: Duration) -> Self { let (shutdown_tx, shutdown_rx) = channel(); let map = TtlMap(Arc::new(Map { inner, @@ -165,41 +166,38 @@ where } } -#[allow(dead_code)] impl<K, V> TtlMap<K, V> where K: Hash + Eq + Send + Sync + 'static, - V: Send + Sync, + V: Send + Sync + Clone, { /// Returns a reference to value corresponding to key. - pub fn get(&self, key: &K) -> Option<Ref<K, Value<V>>> { - let value = self.0.inner.get(key); - if let Some(ref value) = value { - value.update_expiration(self.0.ttl) + pub fn get(&self, key: &K) -> Option<V> { + let pin = self.0.inner.pin(); + let value = pin.get(key); + if let Some(value) = value { + value.update_expiration(self.0.ttl); } - value + value.map(|value| value.value.clone()) } +} +impl<K, V> TtlMap<K, V> +where + K: Hash + Eq + Send + Sync + 'static, + V: Send + Sync, +{ /// Returns a reference to value corresponding to key. - pub fn try_get(&self, key: &K) -> TryResult<Ref<K, Value<V>>> { - let value = self.0.inner.try_get(key); - if let TryResult::Present(ref value) = value { - value.update_expiration(self.0.ttl) - } - - value - } - - /// Returns a mutable reference to value corresponding to key. - /// The value will be reset to expire at the configured TTL after the time of retrieval. - pub fn get_mut(&self, key: &K) -> Option<RefMut<K, Value<V>>> { - let value = self.0.inner.get_mut(key); - if let Some(ref value) = value { + pub fn get_by_ref<F>(&self, key: &K, and_then: impl FnOnce(&V) -> F) -> Option<F> { + let pin = self.0.inner.pin(); + let value = pin.get(key); + if let Some(value) = value { value.update_expiration(self.0.ttl); + Some((and_then)(value)) + } else { + None } - - value } /// Returns the number of entries currently in the map. @@ -219,41 +217,22 @@ where /// Returns true if the map contains a value for the specified key. pub fn contains_key(&self, key: &K) -> bool { - self.0.inner.contains_key(key) + self.0.inner.pin().contains_key(key) } /// Inserts a key-value pair into the map. /// The value will be set to expire at the configured TTL after the time of insertion. /// If a previous value existed for this key, that value is returned. - pub fn insert(&self, key: K, value: V) -> Option<V> { + pub fn insert(&self, key: K, value: V) { self.0 .inner - .insert(key, Value::new(value, self.0.ttl, self.0.clock.clone())) - .map(|value| value.value) + .pin() + .insert(key, Value::new(value, self.0.ttl, self.0.clock.clone())); } /// Removes a key-value pair from the map. pub fn remove(&self, key: K) -> bool { - self.0.inner.remove(&key).is_some() - } - - /// Returns an entry for in-place updates of the specified key-value pair. - /// Note: This acquires a write lock on the map's shard that corresponds - /// to the entry. - pub fn entry(&self, key: K) -> Entry<K, Value<V>> { - let ttl = self.0.ttl; - match self.0.inner.entry(key) { - inner @ DashMapEntry::Occupied(_) => Entry::Occupied(OccupiedEntry { - inner, - ttl, - clock: self.0.clock.clone(), - }), - inner @ DashMapEntry::Vacant(_) => Entry::Vacant(VacantEntry { - inner, - ttl, - clock: self.0.clock.clone(), - }), - } + self.0.inner.pin().remove(&key).is_some() } } @@ -283,87 +262,6 @@ where } } -/// A view into an occupied entry in the map. -pub struct OccupiedEntry<'a, K, V> { - inner: DashMapEntry<'a, K, V>, - ttl: Duration, - clock: Clock, -} - -/// A view into a vacant entry in the map. -pub struct VacantEntry<'a, K, V> { - inner: DashMapEntry<'a, K, V>, - ttl: Duration, - clock: Clock, -} - -/// A view into an entry in the map. -/// It may either be [`VacantEntry`] or [`OccupiedEntry`] -pub enum Entry<'a, K, V> { - Occupied(OccupiedEntry<'a, K, V>), - Vacant(VacantEntry<'a, K, V>), -} - -impl<'a, K, V> OccupiedEntry<'a, K, Value<V>> -where - K: Eq + Hash, -{ - /// Returns a reference to the entry's value. - /// The value will be reset to expire at the configured TTL after the time of retrieval. - pub fn get(&self) -> &Value<V> { - match &self.inner { - DashMapEntry::Occupied(entry) => { - let value = entry.get(); - value.update_expiration(self.ttl); - value - } - _ => unreachable!("BUG: entry type should be occupied"), - } - } - - #[allow(dead_code)] - /// Returns a mutable reference to the entry's value. - /// The value will be reset to expire at the configured TTL after the time of retrieval. - pub fn get_mut(&mut self) -> &mut Value<V> { - match &mut self.inner { - DashMapEntry::Occupied(entry) => { - let value = entry.get_mut(); - value.update_expiration(self.ttl); - value - } - _ => unreachable!("BUG: entry type should be occupied"), - } - } - - #[allow(dead_code)] - /// Replace the entry's value with a new value, returning the old value. - /// The value will be set to expire at the configured TTL after the time of insertion. - pub fn insert(&mut self, value: V) -> Value<V> { - match &mut self.inner { - DashMapEntry::Occupied(entry) => { - entry.insert(Value::new(value, self.ttl, self.clock.clone())) - } - _ => unreachable!("BUG: entry type should be occupied"), - } - } -} - -impl<'a, K, V> VacantEntry<'a, K, Value<V>> -where - K: Eq + Hash, -{ - /// Set an entry's value. - /// The value will be set to expire at the configured TTL after the time of insertion. - pub fn insert(self, value: V) -> RefMut<'a, K, Value<V>> { - match self.inner { - DashMapEntry::Vacant(entry) => { - entry.insert(Value::new(value, self.ttl, self.clock.clone())) - } - _ => unreachable!("BUG: entry type should be vacant"), - } - } -} - fn spawn_cleanup_task<K, V>( map: Arc<Map<K, V>>, poll_interval: Duration, @@ -401,21 +299,13 @@ where return; }; - // Take a read lock first and check if there is at least 1 item to remove. - let has_expired_keys = map - .inner + let pin = map.inner.pin(); + let expired_keys = pin .iter() - .filter(|entry| entry.value().expiration_secs() <= now_secs) - .take(1) - .next() - .is_some(); - - // If we have work to do then, take a write lock. - if has_expired_keys { - // Go over the whole map in case anything expired - // since acquiring the write lock. - map.inner - .retain(|_, value| value.expiration_secs() > now_secs); + .filter(|(_, value)| value.expiration_secs() <= now_secs); + + for (key, _) in expired_keys { + map.inner.pin().remove(key); } } @@ -512,8 +402,8 @@ mod tests { map.insert(one.clone(), 1); map.insert(two.clone(), 2); - assert_eq!(map.get(&one).unwrap().value, 1); - assert_eq!(map.get(&two).unwrap().value, 2); + assert_eq!(map.get(&one).unwrap(), 1); + assert_eq!(map.get(&two).unwrap(), 2); } #[tokio::test] @@ -527,15 +417,17 @@ mod tests { Duration::from_secs(10), Duration::from_millis(10), ); - map.insert(one.clone(), 1); - let exp1 = map.get(&one).unwrap().expiration_secs(); + map.insert(one.clone(), 1); + let exp1 = map.0.inner.pin().get(&one).unwrap().expiration_secs(); time::advance(Duration::from_secs(2)).await; - let exp2 = map.get(&one).unwrap().expiration_secs(); + let _ = map.get(&one).unwrap(); + let exp2 = map.0.inner.pin().get(&one).unwrap().expiration_secs(); time::advance(Duration::from_secs(3)).await; - let exp3 = map.get(&one).unwrap().expiration_secs(); + let _ = map.get(&one).unwrap(); + let exp3 = map.0.inner.pin().get(&one).unwrap().expiration_secs(); assert!(exp1 < exp2); assert_eq!(2, exp2 - exp1); @@ -560,177 +452,6 @@ mod tests { assert!(map.contains_key(&two)); } - #[tokio::test] - async fn entry_occupied_insert_and_get() { - let (one, _) = address_pair(); - - let map = TtlMap::<EndpointAddress, usize>::new( - Duration::from_secs(10), - Duration::from_millis(10), - ); - map.insert(one.clone(), 1); - - match map.entry(one.clone()) { - Entry::Occupied(mut entry) => { - assert_eq!(entry.get().value, 1); - entry.insert(5); - } - _ => unreachable!("expected occupied entry"), - } - - assert_eq!(map.get(&one).unwrap().value, 5); - } - - #[tokio::test] - async fn entry_occupied_get_mut() { - let (one, _) = address_pair(); - - let map = TtlMap::<EndpointAddress, usize>::new( - Duration::from_secs(10), - Duration::from_millis(10), - ); - map.insert(one.clone(), 1); - - match map.entry(one.clone()) { - Entry::Occupied(mut entry) => { - entry.get_mut().value = 5; - } - _ => unreachable!("expected occupied entry"), - } - - assert_eq!(map.get(&one).unwrap().value, 5); - } - - #[tokio::test] - async fn entry_vacant_insert() { - let (one, _) = address_pair(); - - let map = TtlMap::<EndpointAddress, usize>::new( - Duration::from_secs(10), - Duration::from_millis(10), - ); - - match map.entry(one.clone()) { - Entry::Vacant(entry) => { - let mut e = entry.insert(1); - assert_eq!(e.value, 1); - e.value = 5; - } - _ => unreachable!("expected occupied entry"), - } - - assert_eq!(map.get(&one).unwrap().value, 5); - } - - #[tokio::test] - async fn entry_occupied_get_expiration() { - // Test that when we get a value via OccupiedEntry, we update its expiration. - time::pause(); - - let (one, _) = address_pair(); - - let map = TtlMap::<EndpointAddress, usize>::new( - Duration::from_secs(10), - Duration::from_millis(10), - ); - map.insert(one.clone(), 1); - - let exp1 = map.get(&one).unwrap().expiration_secs(); - - time::advance(Duration::from_secs(2)).await; - - let exp2 = match map.entry(one.clone()) { - Entry::Occupied(entry) => entry.get().expiration_secs(), - _ => unreachable!("expected occupied entry"), - }; - - assert!(exp1 < exp2); - assert_eq!(2, exp2 - exp1); - } - - #[tokio::test] - async fn entry_occupied_get_mut_expiration() { - // Test that when we get_mut a value via OccupiedEntry, we update its expiration. - time::pause(); - - let (one, _) = address_pair(); - - let map = TtlMap::<EndpointAddress, usize>::new( - Duration::from_secs(10), - Duration::from_millis(10), - ); - map.insert(one.clone(), 1); - - let exp1 = map.get(&one).unwrap().expiration_secs(); - - time::advance(Duration::from_secs(2)).await; - - let exp2 = match map.entry(one) { - Entry::Occupied(mut entry) => entry.get_mut().expiration_secs(), - _ => unreachable!("expected occupied entry"), - }; - - assert!(exp1 < exp2); - assert_eq!(2, exp2 - exp1); - } - - #[tokio::test] - async fn entry_occupied_insert_expiration() { - // Test that when we replace a value via OccupiedEntry, we update its expiration. - time::pause(); - - let (one, _) = address_pair(); - - let map = TtlMap::<EndpointAddress, usize>::new( - Duration::from_secs(10), - Duration::from_millis(10), - ); - map.insert(one.clone(), 1); - - let exp1 = map.get(&one).unwrap().expiration_secs(); - - time::advance(Duration::from_secs(2)).await; - - let old_exp1 = match map.entry(one.clone()) { - Entry::Occupied(mut entry) => entry.insert(9).expiration_secs(), - _ => unreachable!("expected occupied entry"), - }; - - let exp2 = map.get(&one).unwrap().expiration_secs(); - - assert_eq!(exp1, old_exp1); - assert!(exp1 < exp2); - assert_eq!(2, exp2 - exp1); - } - - #[tokio::test] - async fn entry_occupied_vacant_expiration() { - // Test that when we insert a value via VacantEntry, we update its expiration. - time::pause(); - - let (one, _) = address_pair(); - - let map = TtlMap::<EndpointAddress, usize>::new( - Duration::from_secs(10), - Duration::from_millis(10), - ); - - let exp1 = match map.entry(one.clone()) { - Entry::Vacant(entry) => entry.insert(9).expiration_secs(), - _ => unreachable!("expected vacant entry"), - }; - - time::advance(Duration::from_secs(2)).await; - - let exp2 = map.get(&one).unwrap().expiration_secs(); - - // Initial expiration should be set at our configured ttl. - assert_eq!(10, exp1); - - assert!(exp1 < exp2); - assert_eq!(2, exp2 - exp1); - } - #[tokio::test] async fn expiration_ttl() { // Test that when we expire entries at our configured ttl. @@ -741,10 +462,9 @@ mod tests { let ttl = Duration::from_secs(12); let map = TtlMap::<EndpointAddress, usize>::new(ttl, Duration::from_millis(10)); - let exp = match map.entry(one) { - Entry::Vacant(entry) => entry.insert(9).expiration_secs(), - _ => unreachable!("expected vacant entry"), - }; + assert!(map.0.inner.pin().get(&one).is_none()); + map.insert(one.clone(), 9); + let exp = map.0.inner.pin().get(&one).unwrap().expiration_secs(); // Check that it expires at our configured TTL. assert_eq!(12, exp); diff --git a/src/components/proxy/sessions.rs b/src/components/proxy/sessions.rs index 0fcd6a2563..7f30b59a2d 100644 --- a/src/components/proxy/sessions.rs +++ b/src/components/proxy/sessions.rs @@ -220,12 +220,14 @@ impl SessionPool { ) -> Result<(Option<MetricsIpNetEntry>, UpstreamSender), super::PipelineError> { tracing::trace!(source=%key.source, dest=%key.dest, "SessionPool::get"); // If we already have a session for the key pairing, return that session. - if let Some(entry) = self.session_map.get(&key) { + if let Some((asn_info, upstream_sender)) = self.session_map.get_by_ref(&key, |value| { + ( + value.asn_info.as_ref().map(MetricsIpNetEntry::from), + value.upstream_sender.clone(), + ) + }) { tracing::trace!("returning existing session"); - return Ok(( - entry.asn_info.as_ref().map(MetricsIpNetEntry::from), - entry.upstream_sender.clone(), - )); + return Ok((asn_info, upstream_sender)); } // If there's a socket_set available, it means there are sockets @@ -629,8 +631,12 @@ mod tests { let _socket1 = pool.get(key1).unwrap(); let _socket2 = pool.get(key2).unwrap(); assert_ne!( - pool.session_map.get(&key1).unwrap().socket_port, - pool.session_map.get(&key2).unwrap().socket_port + pool.session_map + .get_by_ref(&key1, |v| v.socket_port) + .unwrap(), + pool.session_map + .get_by_ref(&key2, |v| v.socket_port) + .unwrap() ); assert!(pool.drop_session(key1).await); @@ -655,8 +661,12 @@ mod tests { let _socket2 = pool.get(key2).unwrap(); assert_eq!( - pool.session_map.get(&key1).unwrap().socket_port, - pool.session_map.get(&key2).unwrap().socket_port + pool.session_map + .get_by_ref(&key1, |v| v.socket_port) + .unwrap(), + pool.session_map + .get_by_ref(&key2, |v| v.socket_port) + .unwrap() ); } diff --git a/src/config.rs b/src/config.rs index 7064792c8e..e99e3a3afc 100644 --- a/src/config.rs +++ b/src/config.rs @@ -297,10 +297,10 @@ impl Config { }); } DatacenterConfig::NonAgent { datacenters } => { - for entry in datacenters.read().iter() { - let host = entry.key().to_string(); - let qcmp_port = entry.qcmp_port; - let version = format!("{}-{qcmp_port}", entry.icao_code); + for (key, value) in datacenters.read().pin().iter() { + let host = key.to_string(); + let qcmp_port = value.qcmp_port; + let version = format!("{}-{qcmp_port}", value.icao_code); if client_state.version_matches(&host, &version) { continue; @@ -309,7 +309,7 @@ impl Config { let resource = crate::xds::Resource::Datacenter( crate::net::cluster::proto::Datacenter { qcmp_port: qcmp_port as _, - icao_code: entry.icao_code.to_string(), + icao_code: value.icao_code.to_string(), host: host.clone(), }, ); @@ -330,7 +330,7 @@ impl Config { let Ok(addr) = key.parse() else { continue; }; - if dc.get(&addr).is_none() { + if dc.pin().get(&addr).is_none() { removed.insert(key.clone()); } } @@ -366,8 +366,8 @@ impl Config { }; if client_state.subscribed.is_empty() { - for cluster in self.clusters.read().iter() { - push(cluster.key(), cluster.value())?; + for (key, value) in self.clusters.read().pin().iter() { + push(key, value)?; } } else { for locality in client_state.subscribed.iter().filter_map(|name| { @@ -377,8 +377,8 @@ impl Config { name.parse().ok().map(Some) } }) { - if let Some(cluster) = self.clusters.read().get(&locality) { - push(cluster.key(), cluster.value())?; + if let Some(value) = self.clusters.read().pin().get(&locality) { + push(&locality, value)?; } } }; @@ -387,7 +387,7 @@ impl Config { // is when ClusterMap::update_unlocated_endpoints is called to move the None // locality endpoints to another one, so we just detect that case manually if client_state.versions.contains_key("") - && self.clusters.read().get(&None).is_none() + && self.clusters.read().pin().get(&None).is_none() { removed.insert("".into()); } @@ -593,16 +593,15 @@ impl Config { #[derive(Default, Debug, Deserialize, Serialize)] pub struct DatacenterMap { - map: dashmap::DashMap<IpAddr, Datacenter>, + map: papaya::HashMap<IpAddr, Datacenter, gxhash::GxBuildHasher>, version: AtomicU64, } impl DatacenterMap { #[inline] - pub fn insert(&self, ip: IpAddr, datacenter: Datacenter) -> Option<Datacenter> { - let old = self.map.insert(ip, datacenter); + pub fn insert(&self, ip: IpAddr, datacenter: Datacenter) { + self.map.pin().insert(ip, datacenter); self.version.fetch_add(1, Relaxed); - old } #[inline] @@ -621,13 +620,10 @@ impl DatacenterMap { } #[inline] - pub fn get(&self, key: &IpAddr) -> Option<dashmap::mapref::one::Ref<IpAddr, Datacenter>> { - self.map.get(key) - } - - #[inline] - pub fn iter(&self) -> dashmap::iter::Iter<IpAddr, Datacenter> { - self.map.iter() + pub fn pin( + &self, + ) -> papaya::HashMapRef<IpAddr, Datacenter, gxhash::GxBuildHasher, seize::LocalGuard> { + self.map.pin() } } @@ -676,8 +672,8 @@ impl PartialEq for DatacenterMap { return false; } - for a in self.iter() { - match rhs.get(a.key()).filter(|b| *a.value() == **b) { + for (key, value) in self.pin().iter() { + match rhs.pin().get(key).filter(|b| *value == **b) { Some(_) => {} None => return false, } diff --git a/src/filters/local_rate_limit.rs b/src/filters/local_rate_limit.rs index 05087c70d8..cd4fa51984 100644 --- a/src/filters/local_rate_limit.rs +++ b/src/filters/local_rate_limit.rs @@ -20,11 +20,7 @@ use std::time::Duration; use serde::{Deserialize, Serialize}; -use crate::{ - collections::ttl::{Entry, TtlMap}, - filters::prelude::*, - net::endpoint::EndpointAddress, -}; +use crate::{collections::ttl::TtlMap, filters::prelude::*, net::endpoint::EndpointAddress}; use crate::generated::quilkin::filters::local_rate_limit::v1alpha1 as proto; @@ -50,8 +46,8 @@ const SESSION_EXPIRY_POLL_INTERVAL: Duration = Duration::from_secs(60); /// number of packet handling workers). #[derive(Debug)] struct Bucket { - counter: Arc<AtomicUsize>, - window_start_time_secs: Arc<AtomicU64>, + counter: AtomicUsize, + window_start_time_secs: AtomicU64, } /// A filter that implements rate limiting on packets based on the token-bucket @@ -61,7 +57,7 @@ struct Bucket { /// flow through the filter untouched. pub struct LocalRateLimit { /// Tracks rate limiting state per source address. - state: TtlMap<EndpointAddress, Bucket>, + state: TtlMap<EndpointAddress, Arc<Bucket>>, /// Filter configuration. config: Config, } @@ -95,10 +91,10 @@ impl LocalRateLimit { } if let Some(bucket) = self.state.get(address) { - let prev_count = bucket.value.counter.fetch_add(1, Ordering::Relaxed); + let prev_count = bucket.counter.fetch_add(1, Ordering::Relaxed); let now_secs = self.state.now_relative_secs(); - let window_start_secs = bucket.value.window_start_time_secs.load(Ordering::Relaxed); + let window_start_secs = bucket.window_start_time_secs.load(Ordering::Relaxed); let elapsed_secs = now_secs - window_start_secs; let start_new_window = elapsed_secs > self.config.period as u64; @@ -115,9 +111,8 @@ impl LocalRateLimit { if start_new_window { // Current time window has ended, so we can reset the counter and // start a new time window instead. - bucket.value.counter.store(1, Ordering::Relaxed); + bucket.counter.store(1, Ordering::Relaxed); bucket - .value .window_start_time_secs .store(now_secs, Ordering::Relaxed); } @@ -125,21 +120,23 @@ impl LocalRateLimit { return true; } - match self.state.entry(address.clone()) { - Entry::Occupied(entry) => { + match self.state.get(address) { + Some(value) => { // It is possible that some other task has added the item since we // checked for it. If so, only increment the counter - no need to // update the window start time since the window has just started. - let bucket = entry.get(); - bucket.value.counter.fetch_add(1, Ordering::Relaxed); + value.counter.fetch_add(1, Ordering::Relaxed); } - Entry::Vacant(entry) => { + None => { // New entry, set both the time stamp and let now_secs = self.state.now_relative_secs(); - entry.insert(Bucket { - counter: Arc::new(AtomicUsize::new(1)), - window_start_time_secs: Arc::new(AtomicU64::new(now_secs)), - }); + self.state.insert( + address.clone(), + Arc::new(Bucket { + counter: AtomicUsize::new(1), + window_start_time_secs: AtomicU64::new(now_secs), + }), + ); } }; diff --git a/src/net/cluster.rs b/src/net/cluster.rs index 9cc3bfbc35..487020d4d6 100644 --- a/src/net/cluster.rs +++ b/src/net/cluster.rs @@ -20,8 +20,8 @@ use std::{ sync::atomic::{AtomicU64, AtomicUsize, Ordering::Relaxed}, }; -use dashmap::DashMap; use once_cell::sync::Lazy; +use papaya::HashMap; use serde::{Deserialize, Serialize}; use crate::net::endpoint::{Endpoint, EndpointAddress, Locality}; @@ -259,16 +259,12 @@ impl EndpointSet { /// Represents a full snapshot of all clusters. pub struct ClusterMap<S = gxhash::GxBuildHasher> { - map: DashMap<Option<Locality>, EndpointSet, S>, - token_map: DashMap<u64, Vec<EndpointAddress>>, + map: papaya::HashMap<Option<Locality>, EndpointSet, S>, + token_map: papaya::HashMap<u64, Vec<EndpointAddress>, S>, num_endpoints: AtomicUsize, version: AtomicU64, } -type DashMapRef<'inner, S> = dashmap::mapref::one::Ref<'inner, Option<Locality>, EndpointSet, S>; -type DashMapRefMut<'inner, S> = - dashmap::mapref::one::RefMut<'inner, Option<Locality>, EndpointSet, S>; - impl ClusterMap { pub fn new() -> Self { Self::default() @@ -294,7 +290,7 @@ where { pub fn benchmarking(capacity: usize, hasher: S) -> Self { Self { - map: DashMap::with_capacity_and_hasher(capacity, hasher), + map: papaya::HashMap::with_capacity_and_hasher(capacity, hasher), ..Self::default() } } @@ -306,8 +302,8 @@ where pub fn apply(&self, locality: Option<Locality>, cluster: EndpointSet) { let new_len = cluster.len(); - if let Some(mut current) = self.map.get_mut(&locality) { - let current = current.value_mut(); + if let Some(current) = self.map.pin().get(&locality) { + let mut current = current.clone(); let (old_len, token_map_diff) = current.replace(cluster); @@ -317,22 +313,24 @@ where self.num_endpoints.fetch_sub(old_len - new_len, Relaxed); } + self.map.pin().insert(locality, current); self.version.fetch_add(1, Relaxed); for (token_hash, addrs) in token_map_diff { if let Some(addrs) = addrs { - self.token_map.insert(token_hash, addrs); + self.token_map.pin().insert(token_hash, addrs); } else { - self.token_map.remove(&token_hash); + self.token_map.pin().remove(&token_hash); } } } else { for (token_hash, addrs) in &cluster.token_map { self.token_map + .pin() .insert(*token_hash, addrs.iter().cloned().collect()); } - self.map.insert(locality, cluster); + self.map.pin().insert(locality, cluster); self.num_endpoints.fetch_add(new_len, Relaxed); self.version.fetch_add(1, Relaxed); } @@ -348,20 +346,9 @@ where self.map.is_empty() } - pub fn get(&self, key: &Option<Locality>) -> Option<DashMapRef<S>> { - self.map.get(key) - } - - pub fn get_mut(&self, key: &Option<Locality>) -> Option<DashMapRefMut<S>> { - self.map.get_mut(key) - } - - pub fn get_default(&self) -> Option<DashMapRef<S>> { - self.get(&None) - } - - pub fn get_default_mut(&self) -> Option<DashMapRefMut<S>> { - self.get_mut(&None) + #[inline] + pub fn pin(&self) -> papaya::HashMapRef<Option<Locality>, EndpointSet, S, seize::LocalGuard> { + self.map.pin() } #[inline] @@ -371,11 +358,12 @@ where #[inline] pub fn remove_endpoint(&self, needle: &Endpoint) -> bool { - for mut entry in self.map.iter_mut() { - let set = entry.value_mut(); - - if set.endpoints.remove(needle) { - set.update(); + for (key, value) in self.map.pin().iter() { + if value.endpoints.contains(needle) { + let mut value = value.clone(); + value.endpoints.remove(needle); + value.update(); + self.map.pin().insert(key.clone(), value); self.num_endpoints.fetch_sub(1, Relaxed); self.version.fetch_add(1, Relaxed); return true; @@ -387,45 +375,33 @@ where #[inline] pub fn remove_endpoint_if(&self, closure: impl Fn(&Endpoint) -> bool) -> bool { - for mut entry in self.map.iter_mut() { - let set = entry.value_mut(); - if let Some(endpoint) = set + for (key, value) in self.map.pin().iter() { + if let Some(endpoint) = value .endpoints .iter() .find(|endpoint| (closure)(endpoint)) .cloned() { - // This will always be true, but.... - let removed = set.endpoints.remove(&endpoint); - if removed { - set.update(); - self.num_endpoints.fetch_sub(1, Relaxed); - self.version.fetch_add(1, Relaxed); - } - return removed; + let mut value = value.clone(); + value.endpoints.remove(&endpoint); + value.update(); + self.map.pin().insert(key.clone(), value); + self.num_endpoints.fetch_sub(1, Relaxed); + self.version.fetch_add(1, Relaxed); + return true; } } false } - #[inline] - pub fn iter(&self) -> dashmap::iter::Iter<Option<Locality>, EndpointSet, S> { - self.map.iter() - } - - pub fn entry( - &self, - key: Option<Locality>, - ) -> dashmap::mapref::entry::Entry<Option<Locality>, EndpointSet, S> { - self.map.entry(key) - } - #[inline] pub fn replace(&self, locality: Option<Locality>, endpoint: Endpoint) -> Option<Endpoint> { - if let Some(mut set) = self.map.get_mut(&locality) { + if let Some(set) = self.map.pin().get(&locality) { + let mut set = set.clone(); let replaced = set.endpoints.replace(endpoint); set.update(); + self.map.pin().insert(locality, set); self.version.fetch_add(1, Relaxed); if replaced.is_none() { @@ -443,16 +419,16 @@ where pub fn endpoints(&self) -> Vec<Endpoint> { let mut endpoints = Vec::with_capacity(self.num_of_endpoints()); - for set in self.map.iter() { - endpoints.extend(set.value().endpoints.iter().cloned()); + for (_, value) in self.map.pin().iter() { + endpoints.extend(value.endpoints.iter().cloned()); } endpoints } pub fn nth_endpoint(&self, mut index: usize) -> Option<Endpoint> { - for set in self.iter() { - let set = &set.value().endpoints; + for (_, value) in self.map.pin().iter() { + let set = &value.endpoints; if index < set.len() { return set.iter().nth(index).cloned(); } else { @@ -466,8 +442,8 @@ where pub fn filter_endpoints(&self, f: impl Fn(&Endpoint) -> bool) -> Vec<Endpoint> { let mut endpoints = Vec::new(); - for set in self.iter() { - for endpoint in set.endpoints.iter().filter(|e| (f)(e)) { + for (_, value) in self.map.pin().iter() { + for endpoint in value.endpoints.iter().filter(|e| (f)(e)) { endpoints.push(endpoint.clone()); } } @@ -487,29 +463,28 @@ where #[inline] pub fn update_unlocated_endpoints(&self, locality: Locality) { - if let Some((_, set)) = self.map.remove(&None) { + if let Some(set) = self.map.pin().remove(&None).cloned() { self.version.fetch_add(1, Relaxed); - if let Some(replaced) = self.map.insert(Some(locality), set) { + if let Some(replaced) = self.map.pin().insert(Some(locality), set) { self.num_endpoints.fetch_sub(replaced.len(), Relaxed); } } } #[inline] - pub fn remove_locality(&self, locality: &Option<Locality>) -> Option<EndpointSet> { - let ret = self.map.remove(locality).map(|(_k, v)| v); - if let Some(ret) = &ret { + pub fn remove_locality(&self, locality: &Option<Locality>) { + if let Some(ret) = self.map.pin().remove(locality) { self.version.fetch_add(1, Relaxed); self.num_endpoints.fetch_sub(ret.len(), Relaxed); } - - ret } pub fn addresses_for_token(&self, token: Token) -> Vec<EndpointAddress> { self.token_map + .pin() .get(&token.0) - .map_or(Vec::new(), |addrs| addrs.value().to_vec()) + .cloned() + .unwrap_or_default() } } @@ -547,7 +522,7 @@ where { fn default() -> Self { Self { - map: <DashMap<Option<Locality>, EndpointSet, S>>::default(), + map: <HashMap<Option<Locality>, EndpointSet, S>>::default(), token_map: Default::default(), version: <_>::default(), num_endpoints: <_>::default(), @@ -568,10 +543,12 @@ where S: Default + std::hash::BuildHasher + Clone, { fn eq(&self, rhs: &Self) -> bool { - for a in self.iter() { + for (key, value) in self.map.pin().iter() { match rhs - .get(a.key()) - .filter(|b| a.value().endpoints == b.endpoints) + .map + .pin() + .get(key) + .filter(|b| value.endpoints == b.endpoints) { Some(_) => {} None => return false, @@ -651,10 +628,9 @@ impl Serialize for ClusterMap { S: serde::Serializer, { self.map + .pin() .iter() - .map(|entry| { - EndpointWithLocality::from((entry.key().clone(), entry.value().endpoints.clone())) - }) + .map(|(key, value)| EndpointWithLocality::from((key.clone(), value.endpoints.clone()))) .collect::<Vec<_>>() .serialize(ser) } @@ -665,7 +641,7 @@ where S: Default + std::hash::BuildHasher + Clone, { fn from(cmd: ClusterMapDeser) -> Self { - let map = DashMap::from_iter(cmd.endpoints.into_iter().map( + let map = HashMap::from_iter(cmd.endpoints.into_iter().map( |EndpointWithLocality { locality, endpoints, @@ -676,17 +652,19 @@ where } } -impl<S> From<DashMap<Option<Locality>, EndpointSet, S>> for ClusterMap<S> +impl<S> From<HashMap<Option<Locality>, EndpointSet, S>> for ClusterMap<S> where S: Default + std::hash::BuildHasher + Clone, { - fn from(map: DashMap<Option<Locality>, EndpointSet, S>) -> Self { - let num_endpoints = AtomicUsize::new(map.iter().map(|kv| kv.value().len()).sum()); + fn from(map: HashMap<Option<Locality>, EndpointSet, S>) -> Self { + let num_endpoints = AtomicUsize::new(map.pin().iter().map(|(_, value)| value.len()).sum()); - let token_map = DashMap::<u64, Vec<EndpointAddress>>::default(); - for es in &map { - for (token_hash, addrs) in &es.value().token_map { - token_map.insert(*token_hash, addrs.iter().cloned().collect()); + let token_map = HashMap::<u64, Vec<EndpointAddress>, S>::default(); + for value in map.pin().values() { + for (token_hash, addrs) in &value.token_map { + token_map + .pin() + .insert(*token_hash, addrs.iter().cloned().collect()); } } @@ -727,13 +705,15 @@ mod tests { cluster1.insert(Some(nl1.clone()), [endpoint.clone()].into()); cluster1.insert(Some(de1.clone()), [endpoint.clone()].into()); - assert_eq!(cluster1.get(&Some(nl1.clone())).unwrap().len(), 1); + assert_eq!(cluster1.pin().get(&Some(nl1.clone())).unwrap().len(), 1); assert!(cluster1 + .pin() .get(&Some(nl1.clone())) .unwrap() .contains(&endpoint)); - assert_eq!(cluster1.get(&Some(de1.clone())).unwrap().len(), 1); + assert_eq!(cluster1.pin().get(&Some(de1.clone())).unwrap().len(), 1); assert!(cluster1 + .pin() .get(&Some(de1.clone())) .unwrap() .contains(&endpoint)); @@ -742,16 +722,13 @@ mod tests { cluster1.insert(Some(de1.clone()), [endpoint.clone()].into()); - assert_eq!(cluster1.get(&Some(nl1.clone())).unwrap().len(), 1); - assert_eq!(cluster1.get(&Some(de1.clone())).unwrap().len(), 1); - assert!(cluster1 - .get(&Some(de1.clone())) - .unwrap() - .contains(&endpoint)); + assert_eq!(cluster1.pin().get(&Some(nl1.clone())).unwrap().len(), 1); + assert_eq!(cluster1.pin().get(&Some(de1.clone())).unwrap().len(), 1); + assert!(dbg!(cluster1.pin().get(&Some(de1.clone())).unwrap()).contains(&endpoint)); cluster1.insert(Some(de1.clone()), <_>::default()); - assert_eq!(cluster1.get(&Some(nl1.clone())).unwrap().len(), 1); - assert!(cluster1.get(&Some(de1.clone())).unwrap().is_empty()); + assert_eq!(cluster1.pin().get(&Some(nl1.clone())).unwrap().len(), 1); + assert!(cluster1.pin().get(&Some(de1.clone())).unwrap().is_empty()); } } diff --git a/src/net/endpoint/address.rs b/src/net/endpoint/address.rs index 2948b0a249..51a0864a61 100644 --- a/src/net/endpoint/address.rs +++ b/src/net/endpoint/address.rs @@ -71,7 +71,7 @@ impl EndpointAddress { Lazy::new(<_>::default); match CACHE.get(name) { - Some(ip) => **ip, + Some(ip) => ip, None => { let handle = tokio::runtime::Handle::current(); let set = handle diff --git a/src/net/phoenix.rs b/src/net/phoenix.rs index 4e8bc63b65..f3a8a79ea6 100644 --- a/src/net/phoenix.rs +++ b/src/net/phoenix.rs @@ -446,9 +446,10 @@ impl<M: Measurement + 'static> Phoenix<M> { let crate::config::DatacenterConfig::NonAgent { datacenters } = &config.datacenter else { unreachable!("this shouldn't be called by an agent") }; - for entry in datacenters.write().iter() { - let addr = (*entry.key(), entry.value().qcmp_port).into(); - self.add_node_if_not_exists(addr, entry.value().icao_code); + + for (key, value) in datacenters.read().pin().iter() { + let addr = (*key, value.qcmp_port).into(); + self.add_node_if_not_exists(addr, value.icao_code); } } }