From 47ad9a1b0fb0ffbe19f0a05f6c128b9b39c84ebd Mon Sep 17 00:00:00 2001 From: behzad nouri Date: Mon, 16 Dec 2024 19:05:15 +0000 Subject: [PATCH] removes over-allocation of tree nodes in WeightedShuffle (#4126) Nodes without children are never accessed and don't need to be allocated. --- gossip/Cargo.toml | 3 ++ gossip/benches/weighted_shuffle.rs | 18 ++++++++++ gossip/src/weighted_shuffle.rs | 55 +++++++++++++++++++----------- 3 files changed, 56 insertions(+), 20 deletions(-) diff --git a/gossip/Cargo.toml b/gossip/Cargo.toml index 9d2a0adbe5b39e..fc2ee43522dd5c 100644 --- a/gossip/Cargo.toml +++ b/gossip/Cargo.toml @@ -96,6 +96,9 @@ name = "crds_gossip_pull" [[bench]] name = "crds_shards" +[[bench]] +name = "weighted_shuffle" + [[bin]] name = "solana-gossip" path = "src/main.rs" diff --git a/gossip/benches/weighted_shuffle.rs b/gossip/benches/weighted_shuffle.rs index 09615c57bbca15..7744c2f938b1eb 100644 --- a/gossip/benches/weighted_shuffle.rs +++ b/gossip/benches/weighted_shuffle.rs @@ -25,6 +25,24 @@ fn bench_weighted_shuffle_new(bencher: &mut Bencher) { #[bench] fn bench_weighted_shuffle_shuffle(bencher: &mut Bencher) { + let mut seed = [0u8; 32]; + let mut rng = rand::thread_rng(); + let weights = make_weights(&mut rng); + let weighted_shuffle = WeightedShuffle::new("", &weights); + bencher.iter(|| { + rng.fill(&mut seed[..]); + let mut rng = ChaChaRng::from_seed(seed); + weighted_shuffle + .clone() + .shuffle(&mut rng) + .for_each(|index| { + std::hint::black_box(index); + }); + }); +} + +#[bench] +fn bench_weighted_shuffle_collect(bencher: &mut Bencher) { let mut seed = [0u8; 32]; let mut rng = rand::thread_rng(); let weights = make_weights(&mut rng); diff --git a/gossip/src/weighted_shuffle.rs b/gossip/src/weighted_shuffle.rs index 656615449b2a79..9d607001d95a39 100644 --- a/gossip/src/weighted_shuffle.rs +++ b/gossip/src/weighted_shuffle.rs @@ -26,7 +26,11 @@ const BIT_MASK: usize = FANOUT - 1; /// non-zero weighted indices. #[derive(Clone)] pub struct WeightedShuffle { + // Number of "internal" nodes of the tree. + num_nodes: usize, // Underlying array implementing the tree. + // Nodes without children are never accessed and don't need to be + // allocated, so tree.len() < num_nodes. // tree[i][j] is the sum of all weights in the j'th sub-tree of node i. tree: Vec<[T; FANOUT - 1]>, // Current sum of all weights, excluding already sampled ones. @@ -43,11 +47,13 @@ where /// they are treated as zero. pub fn new(name: &'static str, weights: &[T]) -> Self { let zero = ::default(); - let mut tree = vec![[zero; FANOUT - 1]; get_tree_size(weights.len())]; + let (num_nodes, size) = get_num_nodes_and_tree_size(weights.len()); + debug_assert!(size <= num_nodes); + let mut tree = vec![[zero; FANOUT - 1]; size]; let mut sum = zero; let mut zeros = Vec::default(); - let mut num_negative = 0; - let mut num_overflow = 0; + let mut num_negative: usize = 0; + let mut num_overflow: usize = 0; for (k, &weight) in weights.iter().enumerate() { #[allow(clippy::neg_cmp_op_on_partial_ord)] // weight < zero does not work for NaNs. @@ -70,7 +76,7 @@ where }; // Traverse the tree from the leaf node upwards to the root, // updating the sub-tree sums along the way. - let mut index = tree.len() + k; // leaf node + let mut index = num_nodes + k; // leaf node while index != 0 { let offset = index & BIT_MASK; index = (index - 1) >> BIT_SHIFT; // parent node @@ -86,6 +92,7 @@ where datapoint_error!("weighted-shuffle-overflow", (name, num_overflow, i64)); } Self { + num_nodes, tree, weight: sum, zeros, @@ -103,7 +110,7 @@ where self.weight -= weight; // Traverse the tree from the leaf node upwards to the root, // updating the sub-tree sums along the way. - let mut index = self.tree.len() + k; // leaf node + let mut index = self.num_nodes + k; // leaf node while index != 0 { let offset = index & BIT_MASK; index = (index - 1) >> BIT_SHIFT; // parent node @@ -127,7 +134,7 @@ where 'outer: while index < self.tree.len() { for (j, &node) in self.tree[index].iter().enumerate() { if val < node { - // Traverse to the j+1 subtree of self.tree[index]. + // Traverse to the j'th subtree of self.tree[index]. weight = node; index = (index << BIT_SHIFT) + j + 1; continue 'outer; @@ -140,14 +147,14 @@ where // Traverse to the right-most subtree of self.tree[index]. index = (index << BIT_SHIFT) + FANOUT; } - (index - self.tree.len(), weight) + (index - self.num_nodes, weight) } pub fn remove_index(&mut self, k: usize) { // Traverse the tree from the leaf node upwards to the root, while // maintaining the sum of weights of subtrees *not* containing the leaf // node. - let mut index = self.tree.len() + k; // leaf node + let mut index = self.num_nodes + k; // leaf node let mut weight = ::default(); // zero while index != 0 { let offset = index & BIT_MASK; @@ -223,16 +230,18 @@ where } } -// Maps number of items to the "internal" size of the tree +// Maps number of items to the number of "internal" nodes of the tree // which "implicitly" holds those items on the leaves. -fn get_tree_size(count: usize) -> usize { - let mut size = if count == 1 { 1 } else { 0 }; - let mut nodes = 1; - while nodes < count { +// Nodes without children are never accessed and don't need to be +// allocated, so the tree size is the second smaller number. +fn get_num_nodes_and_tree_size(count: usize) -> (/*num_nodes:*/ usize, /*tree_size:*/ usize) { + let mut size: usize = 0; + let mut nodes: usize = 1; + while nodes * FANOUT < count { size += nodes; nodes *= FANOUT; } - size + (size + nodes, size + (count + FANOUT - 1) / FANOUT) } #[cfg(test)] @@ -278,19 +287,25 @@ mod tests { } #[test] - fn test_get_tree_size() { - assert_eq!(get_tree_size(0), 0); + fn test_get_num_nodes_and_tree_size() { + assert_eq!(get_num_nodes_and_tree_size(0), (1, 0)); for count in 1..=16 { - assert_eq!(get_tree_size(count), 1); + assert_eq!(get_num_nodes_and_tree_size(count), (1, 1)); } + let num_nodes = 1 + 16; for count in 17..=256 { - assert_eq!(get_tree_size(count), 1 + 16); + let tree_size = 1 + (count + 15) / 16; + assert_eq!(get_num_nodes_and_tree_size(count), (num_nodes, tree_size)); } + let num_nodes = 1 + 16 + 16 * 16; for count in 257..=4096 { - assert_eq!(get_tree_size(count), 1 + 16 + 16 * 16); + let tree_size = 1 + 16 + (count + 15) / 16; + assert_eq!(get_num_nodes_and_tree_size(count), (num_nodes, tree_size)); } + let num_nodes = 1 + 16 + 16 * 16 + 16 * 16 * 16; for count in 4097..=65536 { - assert_eq!(get_tree_size(count), 1 + 16 + 16 * 16 + 16 * 16 * 16); + let tree_size = 1 + 16 + 16 * 16 + (count + 15) / 16; + assert_eq!(get_num_nodes_and_tree_size(count), (num_nodes, tree_size)); } }