From 26adeab6869662b6bd2b7fce7c32e90274df6baa Mon Sep 17 00:00:00 2001 From: Nicola Papale Date: Tue, 5 Jul 2022 16:35:51 +0200 Subject: [PATCH] Fix size_hint for partially consumed QueryIter Instead of returning the total count of elements in the `QueryIter` in `size_hint`, we return the count of remaining elements in it. This Fixes #5149. This is also true of `QueryCombinationIter`. - https://github.com/bevyengine/bevy/issues/5149 - https://github.com/bevyengine/bevy/pull/5148 --- crates/bevy_ecs/src/query/iter.rs | 70 +++++++++++++++---------------- crates/bevy_ecs/src/query/mod.rs | 22 +++++++--- 2 files changed, 51 insertions(+), 41 deletions(-) diff --git a/crates/bevy_ecs/src/query/iter.rs b/crates/bevy_ecs/src/query/iter.rs index 498e04fe4245df..233e55938187e2 100644 --- a/crates/bevy_ecs/src/query/iter.rs +++ b/crates/bevy_ecs/src/query/iter.rs @@ -62,13 +62,7 @@ where } fn size_hint(&self) -> (usize, Option) { - let max_size = self - .query_state - .matched_archetype_ids - .iter() - .map(|id| self.archetypes[*id].len()) - .sum(); - + let max_size = self.cursor.remaining(self.tables, self.archetypes); let archetype_query = F::Fetch::IS_ARCHETYPAL && QF::IS_ARCHETYPAL; let min_size = if archetype_query { max_size } else { 0 }; (min_size, Some(max_size)) @@ -264,11 +258,16 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery, const K: usize> QueryCombinationIter< return None; } - // first, iterate from last to first until next item is found + // TODO: can speed up the following code using `cursor.remaining()` instead of `next_item.is_none()` + // when Q::Fetch::IS_ARCHETYPAL && F::Fetch::IS_ARCHETYPAL + // + // let `i` be the index of `c`, the last cursor in `self.cursors` that + // returns `K-i` or more elements. + // Make cursor in index `j` for all `j` in `[i, K)` a copy of `c` advanced `j-i+1` times. + // If no such `c` exists, return `None` 'outer: for i in (0..K).rev() { match self.cursors[i].next(self.tables, self.archetypes, self.query_state) { Some(_) => { - // walk forward up to last element, propagating cursor state forward for j in (i + 1)..K { self.cursors[j] = self.cursors[j - 1].clone(); match self.cursors[j].next(self.tables, self.archetypes, self.query_state) { @@ -329,36 +328,29 @@ where } fn size_hint(&self) -> (usize, Option) { - if K == 0 { - return (0, Some(0)); - } - - let max_size: usize = self - .query_state - .matched_archetype_ids - .iter() - .map(|id| self.archetypes[*id].len()) - .sum(); - - if max_size < K { - return (0, Some(0)); - } - if max_size == K { - return (1, Some(1)); - } - // binomial coefficient: (n ; k) = n! / k!(n-k)! = (n*n-1*...*n-k+1) / k! // See https://en.wikipedia.org/wiki/Binomial_coefficient // See https://blog.plover.com/math/choose.html for implementation // It was chosen to reduce overflow potential. fn choose(n: usize, k: usize) -> Option { + if k > n || n == 0 { + return Some(0); + } + let k = k.min(n - k); let ks = 1..=k; - let ns = (n - k + 1..=n).rev(); + let ns = (n + 1 - k..=n).rev(); ks.zip(ns) .try_fold(1_usize, |acc, (k, n)| Some(acc.checked_mul(n)? / k)) } - let smallest = K.min(max_size - K); - let max_combinations = choose(max_size, smallest); + // sum_i=0..k choose(cursors[i].remaining, k-i) + let max_combinations = self + .cursors + .iter() + .enumerate() + .try_fold(0, |acc, (i, cursor)| { + let n = cursor.remaining(self.tables, self.archetypes); + Some(acc + choose(n, K - i)?) + }); let archetype_query = F::Fetch::IS_ARCHETYPAL && Q::Fetch::IS_ARCHETYPAL; let known_max = max_combinations.unwrap_or(usize::MAX); @@ -373,11 +365,7 @@ where F: WorldQuery + ArchetypeFilter, { fn len(&self) -> usize { - self.query_state - .matched_archetype_ids - .iter() - .map(|id| self.archetypes[*id].len()) - .sum() + self.size_hint().0 } } @@ -497,6 +485,18 @@ where } } + /// How many values will this cursor return? + fn remaining(&self, tables: &'w Tables, archetypes: &'w Archetypes) -> usize { + let remaining_matched: usize = if Self::IS_DENSE { + let ids = self.table_id_iter.clone(); + ids.map(|id| tables[*id].len()).sum() + } else { + let ids = self.archetype_id_iter.clone(); + ids.map(|id| archetypes[*id].len()).sum() + }; + remaining_matched + self.current_len - self.current_index + } + // NOTE: If you are changing query iteration code, remember to update the following places, where relevant: // QueryIterationCursor, QueryState::for_each_unchecked_manual, QueryState::par_for_each_unchecked_manual /// # Safety diff --git a/crates/bevy_ecs/src/query/mod.rs b/crates/bevy_ecs/src/query/mod.rs index 2a6da06c314f1c..7f5cafce9068ca 100644 --- a/crates/bevy_ecs/src/query/mod.rs +++ b/crates/bevy_ecs/src/query/mod.rs @@ -73,9 +73,13 @@ mod tests { for<'w> QueryFetch<'w, F>: Clone, { let mut query = world.query_filtered::(); - let iter = query.iter_combinations::(world); let query_type = type_name::>(); - assert_all_sizes_iterator_equal(iter, expected_size, query_type); + let iter = query.iter_combinations::(world); + assert_all_sizes_iterator_equal(iter, expected_size, 0, query_type); + let iter = query.iter_combinations::(world); + assert_all_sizes_iterator_equal(iter, expected_size, 1, query_type); + let iter = query.iter_combinations::(world); + assert_all_sizes_iterator_equal(iter, expected_size, 5, query_type); } fn assert_all_sizes_equal(world: &mut World, expected_size: usize) where @@ -85,9 +89,10 @@ mod tests { for<'w> QueryFetch<'w, F>: Clone, { let mut query = world.query_filtered::(); - let iter = query.iter(world); let query_type = type_name::>(); - assert_all_sizes_iterator_equal(iter, expected_size, query_type); + assert_all_sizes_iterator_equal(query.iter(world), expected_size, 0, query_type); + assert_all_sizes_iterator_equal(query.iter(world), expected_size, 1, query_type); + assert_all_sizes_iterator_equal(query.iter(world), expected_size, 5, query_type); let expected = expected_size; assert_combination::(world, choose(expected, 0)); @@ -95,13 +100,18 @@ mod tests { assert_combination::(world, choose(expected, 2)); assert_combination::(world, choose(expected, 5)); assert_combination::(world, choose(expected, 43)); - assert_combination::(world, choose(expected, 128)); + assert_combination::(world, choose(expected, 64)); } fn assert_all_sizes_iterator_equal( - iterator: impl ExactSizeIterator, + mut iterator: impl ExactSizeIterator, expected_size: usize, + skip: usize, query_type: &'static str, ) { + let expected_size = expected_size.saturating_sub(skip); + for _ in 0..skip { + iterator.next(); + } let size_hint_0 = iterator.size_hint().0; let size_hint_1 = iterator.size_hint().1; let len = iterator.len();