Skip to content

Commit

Permalink
Rollup merge of rust-lang#58577 - ssomers:btreeset_intersection, r=Ko…
Browse files Browse the repository at this point in the history
…drAus

improve worst-case performance of BTreeSet intersection

Major performance boost when comparing tiny and huge sets. Probably with controversial changes and I sure have questions:

- Names and places of functions and types
- How many comments to write where
- Why does rustc tell me to `ref mut` and `ref` matches on the iterator, while the book says ref is old school.
- (Why) do I have to write out the clone like that (`#[derive(Clone)]` doesn't work)
- Am I allowed to use `#[derive(Debug)]` there at all?
- I'd like to test function `are_proportionate_for_intersection` in test_intersection (or another test case next to it) itself, but I think the private function is inaccessible there. liballoc has other test cases not in the tests directory.

PS I don't list these questions to start a discussion here, just to inspire reviewers and to remember myself.
  • Loading branch information
Centril authored Feb 22, 2019
2 parents 52fd337 + 4fd0cc1 commit ef1faf4
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 15 deletions.
1 change: 1 addition & 0 deletions src/liballoc/benches/btree/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
mod map;
mod set;
88 changes: 88 additions & 0 deletions src/liballoc/benches/btree/set.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
use std::collections::BTreeSet;

use rand::{thread_rng, Rng};
use test::{black_box, Bencher};

fn random(n1: u32, n2: u32) -> [BTreeSet<usize>; 2] {
let mut rng = thread_rng();
let mut set1 = BTreeSet::new();
let mut set2 = BTreeSet::new();
for _ in 0..n1 {
let i = rng.gen::<usize>();
set1.insert(i);
}
for _ in 0..n2 {
let i = rng.gen::<usize>();
set2.insert(i);
}
[set1, set2]
}

fn staggered(n1: u32, n2: u32) -> [BTreeSet<u32>; 2] {
let mut even = BTreeSet::new();
let mut odd = BTreeSet::new();
for i in 0..n1 {
even.insert(i * 2);
}
for i in 0..n2 {
odd.insert(i * 2 + 1);
}
[even, odd]
}

fn neg_vs_pos(n1: u32, n2: u32) -> [BTreeSet<i32>; 2] {
let mut neg = BTreeSet::new();
let mut pos = BTreeSet::new();
for i in -(n1 as i32)..=-1 {
neg.insert(i);
}
for i in 1..=(n2 as i32) {
pos.insert(i);
}
[neg, pos]
}

fn pos_vs_neg(n1: u32, n2: u32) -> [BTreeSet<i32>; 2] {
let mut neg = BTreeSet::new();
let mut pos = BTreeSet::new();
for i in -(n1 as i32)..=-1 {
neg.insert(i);
}
for i in 1..=(n2 as i32) {
pos.insert(i);
}
[pos, neg]
}

macro_rules! set_intersection_bench {
($name: ident, $sets: expr) => {
#[bench]
pub fn $name(b: &mut Bencher) {
// setup
let sets = $sets;

// measure
b.iter(|| {
let x = sets[0].intersection(&sets[1]).count();
black_box(x);
})
}
};
}

set_intersection_bench! {intersect_random_100, random(100, 100)}
set_intersection_bench! {intersect_random_10k, random(10_000, 10_000)}
set_intersection_bench! {intersect_random_10_vs_10k, random(10, 10_000)}
set_intersection_bench! {intersect_random_10k_vs_10, random(10_000, 10)}
set_intersection_bench! {intersect_staggered_100, staggered(100, 100)}
set_intersection_bench! {intersect_staggered_10k, staggered(10_000, 10_000)}
set_intersection_bench! {intersect_staggered_10_vs_10k, staggered(10, 10_000)}
set_intersection_bench! {intersect_staggered_10k_vs_10, staggered(10_000, 10)}
set_intersection_bench! {intersect_neg_vs_pos_100, neg_vs_pos(100, 100)}
set_intersection_bench! {intersect_neg_vs_pos_10k, neg_vs_pos(10_000, 10_000)}
set_intersection_bench! {intersect_neg_vs_pos_10_vs_10k,neg_vs_pos(10, 10_000)}
set_intersection_bench! {intersect_neg_vs_pos_10k_vs_10,neg_vs_pos(10_000, 10)}
set_intersection_bench! {intersect_pos_vs_neg_100, pos_vs_neg(100, 100)}
set_intersection_bench! {intersect_pos_vs_neg_10k, pos_vs_neg(10_000, 10_000)}
set_intersection_bench! {intersect_pos_vs_neg_10_vs_10k,pos_vs_neg(10, 10_000)}
set_intersection_bench! {intersect_pos_vs_neg_10k_vs_10,pos_vs_neg(10_000, 10)}
102 changes: 87 additions & 15 deletions src/liballoc/collections/btree/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,22 @@ impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
}
}

/// Whether the sizes of two sets are roughly the same order of magnitude.
///
/// If they are, or if either set is empty, then their intersection
/// is efficiently calculated by iterating both sets jointly.
/// If they aren't, then it is more scalable to iterate over the small set
/// and find matches in the large set (except if the largest element in
/// the small set hardly surpasses the smallest element in the large set).
fn are_proportionate_for_intersection(len1: usize, len2: usize) -> bool {
let (small, large) = if len1 <= len2 {
(len1, len2)
} else {
(len2, len1)
};
(large >> 7) <= small
}

/// A lazy iterator producing elements in the intersection of `BTreeSet`s.
///
/// This `struct` is created by the [`intersection`] method on [`BTreeSet`].
Expand All @@ -165,7 +181,13 @@ impl<T: fmt::Debug> fmt::Debug for SymmetricDifference<'_, T> {
#[stable(feature = "rust1", since = "1.0.0")]
pub struct Intersection<'a, T: 'a> {
a: Peekable<Iter<'a, T>>,
b: Peekable<Iter<'a, T>>,
b: IntersectionOther<'a, T>,
}

#[derive(Debug)]
enum IntersectionOther<'a, T> {
Stitch(Peekable<Iter<'a, T>>),
Search(&'a BTreeSet<T>),
}

#[stable(feature = "collection_debug", since = "1.17.0")]
Expand Down Expand Up @@ -326,9 +348,21 @@ impl<T: Ord> BTreeSet<T> {
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
pub fn intersection<'a>(&'a self, other: &'a BTreeSet<T>) -> Intersection<'a, T> {
Intersection {
a: self.iter().peekable(),
b: other.iter().peekable(),
if are_proportionate_for_intersection(self.len(), other.len()) {
Intersection {
a: self.iter().peekable(),
b: IntersectionOther::Stitch(other.iter().peekable()),
}
} else if self.len() <= other.len() {
Intersection {
a: self.iter().peekable(),
b: IntersectionOther::Search(&other),
}
} else {
Intersection {
a: other.iter().peekable(),
b: IntersectionOther::Search(&self),
}
}
}

Expand Down Expand Up @@ -1069,6 +1103,14 @@ impl<'a, T: Ord> Iterator for SymmetricDifference<'a, T> {
#[stable(feature = "fused", since = "1.26.0")]
impl<T: Ord> FusedIterator for SymmetricDifference<'_, T> {}

impl<'a, T> Clone for IntersectionOther<'a, T> {
fn clone(&self) -> IntersectionOther<'a, T> {
match self {
IntersectionOther::Stitch(ref iter) => IntersectionOther::Stitch(iter.clone()),
IntersectionOther::Search(set) => IntersectionOther::Search(set),
}
}
}
#[stable(feature = "rust1", since = "1.0.0")]
impl<T> Clone for Intersection<'_, T> {
fn clone(&self) -> Self {
Expand All @@ -1083,24 +1125,36 @@ impl<'a, T: Ord> Iterator for Intersection<'a, T> {
type Item = &'a T;

fn next(&mut self) -> Option<&'a T> {
loop {
match Ord::cmp(self.a.peek()?, self.b.peek()?) {
Less => {
self.a.next();
}
Equal => {
self.b.next();
return self.a.next();
match self.b {
IntersectionOther::Stitch(ref mut self_b) => loop {
match Ord::cmp(self.a.peek()?, self_b.peek()?) {
Less => {
self.a.next();
}
Equal => {
self_b.next();
return self.a.next();
}
Greater => {
self_b.next();
}
}
Greater => {
self.b.next();
}
IntersectionOther::Search(set) => loop {
let e = self.a.next()?;
if set.contains(&e) {
return Some(e);
}
}
}
}

fn size_hint(&self) -> (usize, Option<usize>) {
(0, Some(min(self.a.len(), self.b.len())))
let b_len = match self.b {
IntersectionOther::Stitch(ref iter) => iter.len(),
IntersectionOther::Search(set) => set.len(),
};
(0, Some(min(self.a.len(), b_len)))
}
}

Expand Down Expand Up @@ -1140,3 +1194,21 @@ impl<'a, T: Ord> Iterator for Union<'a, T> {

#[stable(feature = "fused", since = "1.26.0")]
impl<T: Ord> FusedIterator for Union<'_, T> {}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_are_proportionate_for_intersection() {
assert!(are_proportionate_for_intersection(0, 0));
assert!(are_proportionate_for_intersection(0, 127));
assert!(!are_proportionate_for_intersection(0, 128));
assert!(are_proportionate_for_intersection(1, 255));
assert!(!are_proportionate_for_intersection(1, 256));
assert!(are_proportionate_for_intersection(127, 0));
assert!(!are_proportionate_for_intersection(128, 0));
assert!(are_proportionate_for_intersection(255, 1));
assert!(!are_proportionate_for_intersection(256, 1));
}
}
13 changes: 13 additions & 0 deletions src/liballoc/tests/btree/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,19 @@ fn test_intersection() {
check_intersection(&[11, 1, 3, 77, 103, 5, -5],
&[2, 11, 77, -9, -42, 5, 3],
&[3, 5, 11, 77]);

let mut large = [0i32; 512];
for i in 0..512 {
large[i] = i as i32
}
check_intersection(&large[..], &[], &[]);
check_intersection(&large[..], &[-1], &[]);
check_intersection(&large[..], &[42], &[42]);
check_intersection(&large[..], &[4, 2], &[2, 4]);
check_intersection(&[], &large[..], &[]);
check_intersection(&[-1], &large[..], &[]);
check_intersection(&[42], &large[..], &[42]);
check_intersection(&[4, 2], &large[..], &[2, 4]);
}

#[test]
Expand Down

0 comments on commit ef1faf4

Please sign in to comment.