Skip to content

Commit

Permalink
Fix size_hint for partially consumed QueryIter
Browse files Browse the repository at this point in the history
Instead of returning the total count of elements in the `QueryIter` in
`size_hint`, we return the count of remaining elements in it. This
Fixes bevyengine#5149. This is also true of `QueryCombinationIter`.

- bevyengine#5149
- bevyengine#5148
  • Loading branch information
nicopap committed Jul 5, 2022
1 parent 5b5013d commit c94a1b3
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 34 deletions.
76 changes: 42 additions & 34 deletions crates/bevy_ecs/src/query/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,7 @@ where
}

fn size_hint(&self) -> (usize, Option<usize>) {
let max_size = self
.query_state
.matched_archetype_ids
.iter()
.map(|id| self.archetypes[*id].len())
.sum();

let max_size = self.cursor.remaining(self.tables, self.archetypes);
let archetype_query = F::Fetch::IS_ARCHETYPAL && QF::IS_ARCHETYPAL;
let min_size = if archetype_query { max_size } else { 0 };
(min_size, Some(max_size))
Expand Down Expand Up @@ -264,11 +258,16 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery, const K: usize> QueryCombinationIter<
return None;
}

// first, iterate from last to first until next item is found
// TODO: can speed up the following code using `cursor.remaining()` instead of `next_item.is_none()`
// when Q::Fetch::IS_ARCHETYPAL && F::Fetch::IS_ARCHETYPAL
//
// let `i` be the index of `c`, the last cursor in `self.cursors` that
// returns `K-i` or more elements.
// Make cursor in index `j` for all `j` in `[i, K)` a copy of `c` advanced `j-i+1` times.
// If no such `c` exists, return `None`
'outer: for i in (0..K).rev() {
match self.cursors[i].next(self.tables, self.archetypes, self.query_state) {
Some(_) => {
// walk forward up to last element, propagating cursor state forward
for j in (i + 1)..K {
self.cursors[j] = self.cursors[j - 1].clone();
match self.cursors[j].next(self.tables, self.archetypes, self.query_state) {
Expand Down Expand Up @@ -329,31 +328,32 @@ where
}

fn size_hint(&self) -> (usize, Option<usize>) {
if K == 0 {
return (0, Some(0));
// binomial coefficient: (n ; k) = n! / k!(n-k)! = (n*n-1*...*n-k+1) / k!
// See https://en.wikipedia.org/wiki/Binomial_coefficient
// See https://blog.plover.com/math/choose.html for implementation
// It was chosen to reduce overflow potential.
fn choose(n: usize, k: usize) -> Option<usize> {
if k > n {
return Some(0);
}
let ks = 1..=k;
let ns = (n + 1 - k..=n).rev();
ks.zip(ns)
.try_fold(1_usize, |acc, (k, n)| Some(acc.checked_mul(n)? / k))
}

let max_size: usize = self
.query_state
.matched_archetype_ids
// sum_i=0..k choose(cursors[i].remaining, k-i)
let max_combinations = self
.cursors
.iter()
.map(|id| self.archetypes[*id].len())
.sum();

if max_size < K {
return (0, Some(0));
}

// n! / k!(n-k)! = (n*n-1*...*n-k+1) / k!
let max_combinations = (0..K)
.try_fold(1usize, |n, i| n.checked_mul(max_size - i))
.map(|n| {
let k_factorial: usize = (1..=K).product();
n / k_factorial
.enumerate()
.try_fold(0, |acc, (i, cursor)| {
let n = cursor.remaining(self.tables, self.archetypes);
Some(acc + choose(n, K - i)?)
});

let archetype_query = F::Fetch::IS_ARCHETYPAL && Q::Fetch::IS_ARCHETYPAL;
let min_combinations = if archetype_query { max_size } else { 0 };
let known_max = max_combinations.unwrap_or(usize::MAX);
let min_combinations = if archetype_query { known_max } else { 0 };
(min_combinations, max_combinations)
}
}
Expand All @@ -364,11 +364,7 @@ where
F: WorldQuery + ArchetypeFilter,
{
fn len(&self) -> usize {
self.query_state
.matched_archetype_ids
.iter()
.map(|id| self.archetypes[*id].len())
.sum()
self.size_hint().0
}
}

Expand Down Expand Up @@ -473,6 +469,18 @@ where
}
}

/// How many values will this cursor return?
fn remaining(&self, tables: &'w Tables, archetypes: &'w Archetypes) -> usize {
let remaining_matched: usize = if Self::IS_DENSE {
let ids = self.table_id_iter.clone();
ids.map(|id| tables[*id].len()).sum()
} else {
let ids = self.archetype_id_iter.clone();
ids.map(|id| archetypes[*id].len()).sum()
};
remaining_matched + self.current_len - self.current_index
}

// NOTE: If you are changing query iteration code, remember to update the following places, where relevant:
// QueryIterationCursor, QueryState::for_each_unchecked_manual, QueryState::par_for_each_unchecked_manual
/// # Safety
Expand Down
38 changes: 38 additions & 0 deletions crates/bevy_ecs/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,18 +144,30 @@ mod tests {
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut iterator = values.iter(&world);
let _ = iterator.next();
assert_eq!(iterator.len(), n - 1);

let mut values = world.query_filtered::<&A, Or<(With<B>, Without<C>)>>();
let n = 7;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut iterator = values.iter(&world);
let _ = iterator.next();
assert_eq!(iterator.len(), n - 1);

let mut values = world.query_filtered::<&A, Or<(Without<B>, With<C>)>>();
let n = 8;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut iterator = values.iter(&world);
let _ = iterator.next();
assert_eq!(iterator.len(), n - 1);

let mut values = world.query_filtered::<&A, Or<(Without<B>, Without<C>)>>();
let n = 9;
assert_eq!(values.iter(&world).size_hint().0, n);
Expand All @@ -169,6 +181,12 @@ mod tests {
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut iterator = values.iter(&world);
let _ = iterator.next();
assert_eq!(iterator.len(), 0);
let _ = iterator.next();
assert_eq!(iterator.len(), 0);

let mut values = world.query_filtered::<&A, Or<(Or<(With<B>, With<C>)>, With<D>)>>();
let n = 6;
assert_eq!(values.iter(&world).size_hint().0, n);
Expand Down Expand Up @@ -218,6 +236,18 @@ mod tests {
assert_eq!(a_query.iter_combinations::<128>(w).count(), 0);
assert_eq!(a_query.iter_combinations::<128>(w).size_hint().1, Some(0));

let mut combination = a_query.iter_combinations::<2>(w);
let mut expected = 6;
for _ in 0..6 {
let _ = combination.next();
expected -= 1;
assert_eq!(combination.size_hint().1, Some(expected));
}

let mut combination = a_query.iter_combinations::<4>(w);
let _ = combination.next();
assert_eq!(combination.size_hint().1, Some(0));

let values: Vec<[&A; 2]> = world.query::<&A>().iter_combinations(&world).collect();
assert_eq!(
values,
Expand Down Expand Up @@ -299,6 +329,10 @@ mod tests {
assert_eq!(a_with_b.iter_combinations::<128>(w).count(), 0);
assert_eq!(a_with_b.iter_combinations::<128>(w).size_hint().1, Some(0));

let mut combination = a_with_b.iter_combinations::<1>(w);
_ = combination.next();
assert_eq!(combination.size_hint().1, Some(0));

let mut a_wout_b = world.query_filtered::<&A, Without<B>>();
let w = &world;
assert_eq!(a_wout_b.iter_combinations::<0>(w).count(), 0);
Expand All @@ -316,6 +350,10 @@ mod tests {
assert_eq!(a_wout_b.iter_combinations::<128>(w).count(), 0);
assert_eq!(a_wout_b.iter_combinations::<128>(w).size_hint().1, Some(0));

let mut combination = a_wout_b.iter_combinations::<2>(w);
_ = combination.next();
assert_eq!(combination.size_hint().1, Some(2));

let values: HashSet<[&A; 2]> = a_wout_b.iter_combinations(&world).collect();
assert_eq!(
values,
Expand Down

0 comments on commit c94a1b3

Please sign in to comment.