Skip to content

Commit

Permalink
Fix size_hint for partially consumed QueryIter
Browse files Browse the repository at this point in the history
Instead of returning the total count of elements in the `QueryIter` in
`size_hint`, we return the count of remaining elements in it. This
Fixes bevyengine#5149. This is also true of `QueryCombinationIter`.

- bevyengine#5149
- bevyengine#5148
  • Loading branch information
nicopap committed Sep 1, 2022
1 parent 9e34c74 commit f8787f8
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 41 deletions.
70 changes: 35 additions & 35 deletions crates/bevy_ecs/src/query/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,7 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery> Iterator for QueryIter<'w, 's, Q, F>
}

fn size_hint(&self) -> (usize, Option<usize>) {
let max_size = self
.query_state
.matched_archetype_ids
.iter()
.map(|id| self.archetypes[*id].len())
.sum();

let max_size = self.cursor.remaining(self.tables, self.archetypes);
let archetype_query = Q::IS_ARCHETYPAL && F::IS_ARCHETYPAL;
let min_size = if archetype_query { max_size } else { 0 };
(min_size, Some(max_size))
Expand Down Expand Up @@ -333,11 +327,16 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery, const K: usize> QueryCombinationIter<
return None;
}

// first, iterate from last to first until next item is found
// TODO: can speed up the following code using `cursor.remaining()` instead of `next_item.is_none()`
// when Q::IS_ARCHETYPAL && F::IS_ARCHETYPAL
//
// let `i` be the index of `c`, the last cursor in `self.cursors` that
// returns `K-i` or more elements.
// Make cursor in index `j` for all `j` in `[i, K)` a copy of `c` advanced `j-i+1` times.
// If no such `c` exists, return `None`
'outer: for i in (0..K).rev() {
match self.cursors[i].next(self.tables, self.archetypes, self.query_state) {
Some(_) => {
// walk forward up to last element, propagating cursor state forward
for j in (i + 1)..K {
self.cursors[j] = self.cursors[j - 1].clone();
match self.cursors[j].next(self.tables, self.archetypes, self.query_state) {
Expand Down Expand Up @@ -398,36 +397,29 @@ where
}

fn size_hint(&self) -> (usize, Option<usize>) {
if K == 0 {
return (0, Some(0));
}

let max_size: usize = self
.query_state
.matched_archetype_ids
.iter()
.map(|id| self.archetypes[*id].len())
.sum();

if max_size < K {
return (0, Some(0));
}
if max_size == K {
return (1, Some(1));
}

// binomial coefficient: (n ; k) = n! / k!(n-k)! = (n*n-1*...*n-k+1) / k!
// See https://en.wikipedia.org/wiki/Binomial_coefficient
// See https://blog.plover.com/math/choose.html for implementation
// It was chosen to reduce overflow potential.
fn choose(n: usize, k: usize) -> Option<usize> {
if k > n || n == 0 {
return Some(0);
}
let k = k.min(n - k);
let ks = 1..=k;
let ns = (n - k + 1..=n).rev();
let ns = (n + 1 - k..=n).rev();
ks.zip(ns)
.try_fold(1_usize, |acc, (k, n)| Some(acc.checked_mul(n)? / k))
}
let smallest = K.min(max_size - K);
let max_combinations = choose(max_size, smallest);
// sum_i=0..k choose(cursors[i].remaining, k-i)
let max_combinations = self
.cursors
.iter()
.enumerate()
.try_fold(0, |acc, (i, cursor)| {
let n = cursor.remaining(self.tables, self.archetypes);
Some(acc + choose(n, K - i)?)
});

let archetype_query = F::IS_ARCHETYPAL && Q::IS_ARCHETYPAL;
let known_max = max_combinations.unwrap_or(usize::MAX);
Expand All @@ -441,11 +433,7 @@ where
F: ArchetypeFilter,
{
fn len(&self) -> usize {
self.query_state
.matched_archetype_ids
.iter()
.map(|id| self.archetypes[*id].len())
.sum()
self.size_hint().0
}
}

Expand Down Expand Up @@ -562,6 +550,18 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery> QueryIterationCursor<'w, 's, Q, F> {
}
}

/// How many values will this cursor return?
fn remaining(&self, tables: &'w Tables, archetypes: &'w Archetypes) -> usize {
let remaining_matched: usize = if Self::IS_DENSE {
let ids = self.table_id_iter.clone();
ids.map(|id| tables[*id].len()).sum()
} else {
let ids = self.archetype_id_iter.clone();
ids.map(|id| archetypes[*id].len()).sum()
};
remaining_matched + self.current_len - self.current_index
}

// NOTE: If you are changing query iteration code, remember to update the following places, where relevant:
// QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::for_each_unchecked_manual, QueryState::par_for_each_unchecked_manual
/// # Safety
Expand Down
22 changes: 16 additions & 6 deletions crates/bevy_ecs/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,13 @@ mod tests {
for<'w> QueryFetch<'w, F::ReadOnly>: Clone,
{
let mut query = world.query_filtered::<Q, F>();
let iter = query.iter_combinations::<K>(world);
let query_type = type_name::<QueryCombinationIter<Q, F, K>>();
assert_all_sizes_iterator_equal(iter, expected_size, query_type);
let iter = query.iter_combinations::<K>(world);
assert_all_sizes_iterator_equal(iter, expected_size, 0, query_type);
let iter = query.iter_combinations::<K>(world);
assert_all_sizes_iterator_equal(iter, expected_size, 1, query_type);
let iter = query.iter_combinations::<K>(world);
assert_all_sizes_iterator_equal(iter, expected_size, 5, query_type);
}
fn assert_all_sizes_equal<Q, F>(world: &mut World, expected_size: usize)
where
Expand All @@ -87,23 +91,29 @@ mod tests {
for<'w> QueryFetch<'w, F::ReadOnly>: Clone,
{
let mut query = world.query_filtered::<Q, F>();
let iter = query.iter(world);
let query_type = type_name::<QueryState<Q, F>>();
assert_all_sizes_iterator_equal(iter, expected_size, query_type);
assert_all_sizes_iterator_equal(query.iter(world), expected_size, 0, query_type);
assert_all_sizes_iterator_equal(query.iter(world), expected_size, 1, query_type);
assert_all_sizes_iterator_equal(query.iter(world), expected_size, 5, query_type);

let expected = expected_size;
assert_combination::<Q, F, 0>(world, choose(expected, 0));
assert_combination::<Q, F, 1>(world, choose(expected, 1));
assert_combination::<Q, F, 2>(world, choose(expected, 2));
assert_combination::<Q, F, 5>(world, choose(expected, 5));
assert_combination::<Q, F, 43>(world, choose(expected, 43));
assert_combination::<Q, F, 128>(world, choose(expected, 128));
assert_combination::<Q, F, 64>(world, choose(expected, 64));
}
fn assert_all_sizes_iterator_equal(
iterator: impl ExactSizeIterator,
mut iterator: impl ExactSizeIterator,
expected_size: usize,
skip: usize,
query_type: &'static str,
) {
let expected_size = expected_size.saturating_sub(skip);
for _ in 0..skip {
iterator.next();
}
let size_hint_0 = iterator.size_hint().0;
let size_hint_1 = iterator.size_hint().1;
let len = iterator.len();
Expand Down

0 comments on commit f8787f8

Please sign in to comment.