Skip to content

Commit

Permalink
removes over-allocation of tree nodes in WeightedShuffle (#4126)
Browse files Browse the repository at this point in the history
Nodes without children are never accessed and don't need to be
allocated.
  • Loading branch information
behzadnouri authored Dec 16, 2024
1 parent 9e59baa commit 47ad9a1
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 20 deletions.
3 changes: 3 additions & 0 deletions gossip/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ name = "crds_gossip_pull"
[[bench]]
name = "crds_shards"

[[bench]]
name = "weighted_shuffle"

[[bin]]
name = "solana-gossip"
path = "src/main.rs"
Expand Down
18 changes: 18 additions & 0 deletions gossip/benches/weighted_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,24 @@ fn bench_weighted_shuffle_new(bencher: &mut Bencher) {

#[bench]
fn bench_weighted_shuffle_shuffle(bencher: &mut Bencher) {
    // The weights and the shuffle structure are built once, outside the
    // closure handed to the bencher; each iteration only clones and drains.
    let mut entropy = rand::thread_rng();
    let weights = make_weights(&mut entropy);
    let template = WeightedShuffle::new("", &weights);
    let mut seed = [0u8; 32];
    bencher.iter(|| {
        // Draw a fresh seed on every iteration so each run drives the
        // shuffle with a differently seeded ChaCha generator.
        entropy.fill(&mut seed[..]);
        let mut chacha = ChaChaRng::from_seed(seed);
        // Consume the shuffle lazily (no collection into a Vec); black_box
        // keeps the yielded indices from being optimized away.
        for index in template.clone().shuffle(&mut chacha) {
            std::hint::black_box(index);
        }
    });
}

#[bench]
fn bench_weighted_shuffle_collect(bencher: &mut Bencher) {
let mut seed = [0u8; 32];
let mut rng = rand::thread_rng();
let weights = make_weights(&mut rng);
Expand Down
55 changes: 35 additions & 20 deletions gossip/src/weighted_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ const BIT_MASK: usize = FANOUT - 1;
/// non-zero weighted indices.
#[derive(Clone)]
pub struct WeightedShuffle<T> {
// Number of "internal" nodes of the tree.
num_nodes: usize,
// Underlying array implementing the tree.
// Nodes without children are never accessed and don't need to be
// allocated, so tree.len() <= num_nodes.
// tree[i][j] is the sum of all weights in the j'th sub-tree of node i.
tree: Vec<[T; FANOUT - 1]>,
// Current sum of all weights, excluding already sampled ones.
Expand All @@ -43,11 +47,13 @@ where
/// they are treated as zero.
pub fn new(name: &'static str, weights: &[T]) -> Self {
let zero = <T as Default>::default();
let mut tree = vec![[zero; FANOUT - 1]; get_tree_size(weights.len())];
let (num_nodes, size) = get_num_nodes_and_tree_size(weights.len());
debug_assert!(size <= num_nodes);
let mut tree = vec![[zero; FANOUT - 1]; size];
let mut sum = zero;
let mut zeros = Vec::default();
let mut num_negative = 0;
let mut num_overflow = 0;
let mut num_negative: usize = 0;
let mut num_overflow: usize = 0;
for (k, &weight) in weights.iter().enumerate() {
#[allow(clippy::neg_cmp_op_on_partial_ord)]
// weight < zero does not work for NaNs.
Expand All @@ -70,7 +76,7 @@ where
};
// Traverse the tree from the leaf node upwards to the root,
// updating the sub-tree sums along the way.
let mut index = tree.len() + k; // leaf node
let mut index = num_nodes + k; // leaf node
while index != 0 {
let offset = index & BIT_MASK;
index = (index - 1) >> BIT_SHIFT; // parent node
Expand All @@ -86,6 +92,7 @@ where
datapoint_error!("weighted-shuffle-overflow", (name, num_overflow, i64));
}
Self {
num_nodes,
tree,
weight: sum,
zeros,
Expand All @@ -103,7 +110,7 @@ where
self.weight -= weight;
// Traverse the tree from the leaf node upwards to the root,
// updating the sub-tree sums along the way.
let mut index = self.tree.len() + k; // leaf node
let mut index = self.num_nodes + k; // leaf node
while index != 0 {
let offset = index & BIT_MASK;
index = (index - 1) >> BIT_SHIFT; // parent node
Expand All @@ -127,7 +134,7 @@ where
'outer: while index < self.tree.len() {
for (j, &node) in self.tree[index].iter().enumerate() {
if val < node {
// Traverse to the j+1 subtree of self.tree[index].
// Traverse to the j'th subtree of self.tree[index].
weight = node;
index = (index << BIT_SHIFT) + j + 1;
continue 'outer;
Expand All @@ -140,14 +147,14 @@ where
// Traverse to the right-most subtree of self.tree[index].
index = (index << BIT_SHIFT) + FANOUT;
}
(index - self.tree.len(), weight)
(index - self.num_nodes, weight)
}

pub fn remove_index(&mut self, k: usize) {
// Traverse the tree from the leaf node upwards to the root, while
// maintaining the sum of weights of subtrees *not* containing the leaf
// node.
let mut index = self.tree.len() + k; // leaf node
let mut index = self.num_nodes + k; // leaf node
let mut weight = <T as Default>::default(); // zero
while index != 0 {
let offset = index & BIT_MASK;
Expand Down Expand Up @@ -223,16 +230,18 @@ where
}
}

// Maps number of items to the "internal" size of the tree
// Maps number of items to the number of "internal" nodes of the tree
// which "implicitly" holds those items on the leaves.
fn get_tree_size(count: usize) -> usize {
let mut size = if count == 1 { 1 } else { 0 };
let mut nodes = 1;
while nodes < count {
// Nodes without children are never accessed and don't need to be
// allocated, so the tree size (the second value returned) is at most the
// number of nodes (the first value returned).
fn get_num_nodes_and_tree_size(count: usize) -> (/*num_nodes:*/ usize, /*tree_size:*/ usize) {
let mut size: usize = 0;
let mut nodes: usize = 1;
while nodes * FANOUT < count {
size += nodes;
nodes *= FANOUT;
}
size
(size + nodes, size + (count + FANOUT - 1) / FANOUT)
}

#[cfg(test)]
Expand Down Expand Up @@ -278,19 +287,25 @@ mod tests {
}

#[test]
fn test_get_tree_size() {
assert_eq!(get_tree_size(0), 0);
fn test_get_num_nodes_and_tree_size() {
assert_eq!(get_num_nodes_and_tree_size(0), (1, 0));
for count in 1..=16 {
assert_eq!(get_tree_size(count), 1);
assert_eq!(get_num_nodes_and_tree_size(count), (1, 1));
}
let num_nodes = 1 + 16;
for count in 17..=256 {
assert_eq!(get_tree_size(count), 1 + 16);
let tree_size = 1 + (count + 15) / 16;
assert_eq!(get_num_nodes_and_tree_size(count), (num_nodes, tree_size));
}
let num_nodes = 1 + 16 + 16 * 16;
for count in 257..=4096 {
assert_eq!(get_tree_size(count), 1 + 16 + 16 * 16);
let tree_size = 1 + 16 + (count + 15) / 16;
assert_eq!(get_num_nodes_and_tree_size(count), (num_nodes, tree_size));
}
let num_nodes = 1 + 16 + 16 * 16 + 16 * 16 * 16;
for count in 4097..=65536 {
assert_eq!(get_tree_size(count), 1 + 16 + 16 * 16 + 16 * 16 * 16);
let tree_size = 1 + 16 + 16 * 16 + (count + 15) / 16;
assert_eq!(get_num_nodes_and_tree_size(count), (num_nodes, tree_size));
}
}

Expand Down

0 comments on commit 47ad9a1

Please sign in to comment.