Skip to content

Commit

Permalink
Add ExactSizeIterator implementation for QueryCombinatonIter (#5148)
Browse files Browse the repository at this point in the history
Following #5124 I decided to add the `ExactSizeIterator` impl for `QueryCombinationIter`.

Also:
- Clean up the tests for `size_hint` and `len` for both the normal `QueryIter` and `QueryCombinationIter`.
- Add tests to `QueryCombinationIter` when it shouldn't be `ExactSizeIterator`

---

## Changelog

- Added `ExactSizeIterator` implementation for `QueryCombinatonIter`
  • Loading branch information
nicopap committed Jul 13, 2022
1 parent 56d69c1 commit d0b98ca
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 174 deletions.
40 changes: 32 additions & 8 deletions crates/bevy_ecs/src/query/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,17 +343,26 @@ where
if max_size < K {
return (0, Some(0));
}
if max_size == K {
return (1, Some(1));
}

// n! / k!(n-k)! = (n*n-1*...*n-k+1) / k!
let max_combinations = (0..K)
.try_fold(1usize, |n, i| n.checked_mul(max_size - i))
.map(|n| {
let k_factorial: usize = (1..=K).product();
n / k_factorial
});
// binomial coefficient: (n ; k) = n! / k!(n-k)! = (n*n-1*...*n-k+1) / k!
// See https://en.wikipedia.org/wiki/Binomial_coefficient
// See https://blog.plover.com/math/choose.html for implementation
// It was chosen to reduce overflow potential.
fn choose(n: usize, k: usize) -> Option<usize> {
let ks = 1..=k;
let ns = (n - k + 1..=n).rev();
ks.zip(ns)
.try_fold(1_usize, |acc, (k, n)| Some(acc.checked_mul(n)? / k))
}
let smallest = K.min(max_size - K);
let max_combinations = choose(max_size, smallest);

let archetype_query = F::Fetch::IS_ARCHETYPAL && Q::Fetch::IS_ARCHETYPAL;
let min_combinations = if archetype_query { max_size } else { 0 };
let known_max = max_combinations.unwrap_or(usize::MAX);
let min_combinations = if archetype_query { known_max } else { 0 };
(min_combinations, max_combinations)
}
}
Expand All @@ -372,6 +381,21 @@ where
}
}

impl<'w, 's, Q: ReadOnlyWorldQuery, F: ReadOnlyWorldQuery + ArchetypeFilter, const K: usize>
ExactSizeIterator for QueryCombinationIter<'w, 's, Q, F, K>
where
QueryFetch<'w, Q>: Clone,
QueryFetch<'w, F>: Clone,
{
/// Returns the exact length of the iterator.
///
/// **NOTE**: When the iterator length overflows `usize`, this will
/// return `usize::MAX`.
fn len(&self) -> usize {
self.size_hint().0
}
}

// This is correct as [`QueryCombinationIter`] always returns `None` once exhausted.
impl<'w, 's, Q: ReadOnlyWorldQuery, F: ReadOnlyWorldQuery, const K: usize> FusedIterator
for QueryCombinationIter<'w, 's, Q, F, K>
Expand Down
267 changes: 101 additions & 166 deletions crates/bevy_ecs/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ pub(crate) unsafe fn debug_checked_unreachable() -> ! {
#[cfg(test)]
mod tests {
use super::WorldQuery;
use crate::prelude::{AnyOf, Entity, Or, With, Without};
use crate::prelude::{AnyOf, Entity, Or, QueryState, With, Without};
use crate::query::{ArchetypeFilter, QueryCombinationIter, QueryFetch, ReadOnlyWorldQuery};
use crate::system::{IntoSystem, Query, System};
use crate::{self as bevy_ecs, component::Component, world::World};
use std::any::type_name;
use std::collections::HashSet;

#[derive(Component, Debug, Hash, Eq, PartialEq, Clone, Copy)]
Expand Down Expand Up @@ -54,24 +56,81 @@ mod tests {
}

#[test]
fn query_filtered_len() {
fn query_filtered_exactsizeiterator_len() {
fn choose(n: usize, k: usize) -> usize {
if n == 0 || k == 0 || n < k {
return 0;
}
let ks = 1..=k;
let ns = (n - k + 1..=n).rev();
ks.zip(ns).fold(1, |acc, (k, n)| acc * n / k)
}
fn assert_combination<Q, F, const K: usize>(world: &mut World, expected_size: usize)
where
Q: ReadOnlyWorldQuery,
F: ReadOnlyWorldQuery + ArchetypeFilter,
for<'w> QueryFetch<'w, Q>: Clone,
for<'w> QueryFetch<'w, F>: Clone,
{
let mut query = world.query_filtered::<Q, F>();
let iter = query.iter_combinations::<K>(world);
let query_type = type_name::<QueryCombinationIter<Q, F, K>>();
assert_all_sizes_iterator_equal(iter, expected_size, query_type);
}
fn assert_all_sizes_equal<Q, F>(world: &mut World, expected_size: usize)
where
Q: ReadOnlyWorldQuery,
F: ReadOnlyWorldQuery + ArchetypeFilter,
for<'w> QueryFetch<'w, Q>: Clone,
for<'w> QueryFetch<'w, F>: Clone,
{
let mut query = world.query_filtered::<Q, F>();
let iter = query.iter(world);
let query_type = type_name::<QueryState<Q, F>>();
assert_all_sizes_iterator_equal(iter, expected_size, query_type);

let expected = expected_size;
assert_combination::<Q, F, 0>(world, choose(expected, 0));
assert_combination::<Q, F, 1>(world, choose(expected, 1));
assert_combination::<Q, F, 2>(world, choose(expected, 2));
assert_combination::<Q, F, 5>(world, choose(expected, 5));
assert_combination::<Q, F, 43>(world, choose(expected, 43));
assert_combination::<Q, F, 128>(world, choose(expected, 128));
}
fn assert_all_sizes_iterator_equal(
iterator: impl ExactSizeIterator,
expected_size: usize,
query_type: &'static str,
) {
let size_hint_0 = iterator.size_hint().0;
let size_hint_1 = iterator.size_hint().1;
let len = iterator.len();
// `count` tests that not only it is the expected value, but also
// the value is accurate to what the query returns.
let count = iterator.count();
// This will show up when one of the asserts in this function fails
println!(
r#"query declared sizes:
for query: {query_type}
expected: {expected_size}
len(): {len}
size_hint().0: {size_hint_0}
size_hint().1: {size_hint_1:?}
count(): {count}"#
);
assert_eq!(len, expected_size);
assert_eq!(size_hint_0, expected_size);
assert_eq!(size_hint_1, Some(expected_size));
assert_eq!(count, expected_size);
}

let mut world = World::new();
world.spawn().insert_bundle((A(1), B(1)));
world.spawn().insert_bundle((A(2),));
world.spawn().insert_bundle((A(3),));

let mut values = world.query_filtered::<&A, With<B>>();
let n = 1;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut values = world.query_filtered::<&A, Without<B>>();
let n = 2;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
assert_all_sizes_equal::<&A, With<B>>(&mut world, 1);
assert_all_sizes_equal::<&A, Without<B>>(&mut world, 2);

let mut world = World::new();
world.spawn().insert_bundle((A(1), B(1), C(1)));
Expand All @@ -86,110 +145,37 @@ mod tests {
world.spawn().insert_bundle((A(10),));

// With/Without for B and C
let mut values = world.query_filtered::<&A, With<B>>();
let n = 3;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut values = world.query_filtered::<&A, With<C>>();
let n = 4;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut values = world.query_filtered::<&A, Without<B>>();
let n = 7;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut values = world.query_filtered::<&A, Without<C>>();
let n = 6;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
assert_all_sizes_equal::<&A, With<B>>(&mut world, 3);
assert_all_sizes_equal::<&A, With<C>>(&mut world, 4);
assert_all_sizes_equal::<&A, Without<B>>(&mut world, 7);
assert_all_sizes_equal::<&A, Without<C>>(&mut world, 6);

// With/Without (And) combinations
let mut values = world.query_filtered::<&A, (With<B>, With<C>)>();
let n = 1;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut values = world.query_filtered::<&A, (With<B>, Without<C>)>();
let n = 2;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut values = world.query_filtered::<&A, (Without<B>, With<C>)>();
let n = 3;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut values = world.query_filtered::<&A, (Without<B>, Without<C>)>();
let n = 4;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
assert_all_sizes_equal::<&A, (With<B>, With<C>)>(&mut world, 1);
assert_all_sizes_equal::<&A, (With<B>, Without<C>)>(&mut world, 2);
assert_all_sizes_equal::<&A, (Without<B>, With<C>)>(&mut world, 3);
assert_all_sizes_equal::<&A, (Without<B>, Without<C>)>(&mut world, 4);

// With/Without Or<()> combinations
let mut values = world.query_filtered::<&A, Or<(With<B>, With<C>)>>();
let n = 6;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut values = world.query_filtered::<&A, Or<(With<B>, Without<C>)>>();
let n = 7;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut values = world.query_filtered::<&A, Or<(Without<B>, With<C>)>>();
let n = 8;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut values = world.query_filtered::<&A, Or<(Without<B>, Without<C>)>>();
let n = 9;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);

let mut values = world.query_filtered::<&A, (Or<(With<B>,)>, Or<(With<C>,)>)>();
let n = 1;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut values = world.query_filtered::<&A, Or<(Or<(With<B>, With<C>)>, With<D>)>>();
let n = 6;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);

world.spawn().insert_bundle((A(11), D(11)));

let mut values = world.query_filtered::<&A, Or<(Or<(With<B>, With<C>)>, With<D>)>>();
let n = 7;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
let mut values = world.query_filtered::<&A, Or<(Or<(With<B>, With<C>)>, Without<D>)>>();
let n = 10;
assert_eq!(values.iter(&world).size_hint().0, n);
assert_eq!(values.iter(&world).size_hint().1.unwrap(), n);
assert_eq!(values.iter(&world).len(), n);
assert_eq!(values.iter(&world).count(), n);
assert_all_sizes_equal::<&A, Or<(With<B>, With<C>)>>(&mut world, 6);
assert_all_sizes_equal::<&A, Or<(With<B>, Without<C>)>>(&mut world, 7);
assert_all_sizes_equal::<&A, Or<(Without<B>, With<C>)>>(&mut world, 8);
assert_all_sizes_equal::<&A, Or<(Without<B>, Without<C>)>>(&mut world, 9);
assert_all_sizes_equal::<&A, (Or<(With<B>,)>, Or<(With<C>,)>)>(&mut world, 1);
assert_all_sizes_equal::<&A, Or<(Or<(With<B>, With<C>)>, With<D>)>>(&mut world, 6);

for i in 11..14 {
world.spawn().insert_bundle((A(i), D(i)));
}

assert_all_sizes_equal::<&A, Or<(Or<(With<B>, With<C>)>, With<D>)>>(&mut world, 9);
assert_all_sizes_equal::<&A, Or<(Or<(With<B>, With<C>)>, Without<D>)>>(&mut world, 10);

// a fair amount of entities
for i in 14..20 {
world.spawn().insert_bundle((C(i), D(i)));
}
assert_all_sizes_equal::<Entity, (With<C>, With<D>)>(&mut world, 6);
}

#[test]
Expand All @@ -201,23 +187,6 @@ mod tests {
world.spawn().insert_bundle((A(3),));
world.spawn().insert_bundle((A(4),));

let mut a_query = world.query::<&A>();
let w = &world;
assert_eq!(a_query.iter_combinations::<0>(w).count(), 0);
assert_eq!(a_query.iter_combinations::<0>(w).size_hint().1, Some(0));
assert_eq!(a_query.iter_combinations::<1>(w).count(), 4);
assert_eq!(a_query.iter_combinations::<1>(w).size_hint().1, Some(4));
assert_eq!(a_query.iter_combinations::<2>(w).count(), 6);
assert_eq!(a_query.iter_combinations::<2>(w).size_hint().1, Some(6));
assert_eq!(a_query.iter_combinations::<3>(w).count(), 4);
assert_eq!(a_query.iter_combinations::<3>(w).size_hint().1, Some(4));
assert_eq!(a_query.iter_combinations::<4>(w).count(), 1);
assert_eq!(a_query.iter_combinations::<4>(w).size_hint().1, Some(1));
assert_eq!(a_query.iter_combinations::<5>(w).count(), 0);
assert_eq!(a_query.iter_combinations::<5>(w).size_hint().1, Some(0));
assert_eq!(a_query.iter_combinations::<128>(w).count(), 0);
assert_eq!(a_query.iter_combinations::<128>(w).size_hint().1, Some(0));

let values: Vec<[&A; 2]> = world.query::<&A>().iter_combinations(&world).collect();
assert_eq!(
values,
Expand All @@ -230,8 +199,7 @@ mod tests {
[&A(3), &A(4)],
]
);
let size = a_query.iter_combinations::<3>(&world).size_hint();
assert_eq!(size.1, Some(4));
let mut a_query = world.query::<&A>();
let values: Vec<[&A; 3]> = a_query.iter_combinations(&world).collect();
assert_eq!(
values,
Expand Down Expand Up @@ -282,40 +250,7 @@ mod tests {
world.spawn().insert_bundle((A(3),));
world.spawn().insert_bundle((A(4),));

let mut a_with_b = world.query_filtered::<&A, With<B>>();
let w = &world;
assert_eq!(a_with_b.iter_combinations::<0>(w).count(), 0);
assert_eq!(a_with_b.iter_combinations::<0>(w).size_hint().1, Some(0));
assert_eq!(a_with_b.iter_combinations::<1>(w).count(), 1);
assert_eq!(a_with_b.iter_combinations::<1>(w).size_hint().1, Some(1));
assert_eq!(a_with_b.iter_combinations::<2>(w).count(), 0);
assert_eq!(a_with_b.iter_combinations::<2>(w).size_hint().1, Some(0));
assert_eq!(a_with_b.iter_combinations::<3>(w).count(), 0);
assert_eq!(a_with_b.iter_combinations::<3>(w).size_hint().1, Some(0));
assert_eq!(a_with_b.iter_combinations::<4>(w).count(), 0);
assert_eq!(a_with_b.iter_combinations::<4>(w).size_hint().1, Some(0));
assert_eq!(a_with_b.iter_combinations::<5>(w).count(), 0);
assert_eq!(a_with_b.iter_combinations::<5>(w).size_hint().1, Some(0));
assert_eq!(a_with_b.iter_combinations::<128>(w).count(), 0);
assert_eq!(a_with_b.iter_combinations::<128>(w).size_hint().1, Some(0));

let mut a_wout_b = world.query_filtered::<&A, Without<B>>();
let w = &world;
assert_eq!(a_wout_b.iter_combinations::<0>(w).count(), 0);
assert_eq!(a_wout_b.iter_combinations::<0>(w).size_hint().1, Some(0));
assert_eq!(a_wout_b.iter_combinations::<1>(w).count(), 3);
assert_eq!(a_wout_b.iter_combinations::<1>(w).size_hint().1, Some(3));
assert_eq!(a_wout_b.iter_combinations::<2>(w).count(), 3);
assert_eq!(a_wout_b.iter_combinations::<2>(w).size_hint().1, Some(3));
assert_eq!(a_wout_b.iter_combinations::<3>(w).count(), 1);
assert_eq!(a_wout_b.iter_combinations::<3>(w).size_hint().1, Some(1));
assert_eq!(a_wout_b.iter_combinations::<4>(w).count(), 0);
assert_eq!(a_wout_b.iter_combinations::<4>(w).size_hint().1, Some(0));
assert_eq!(a_wout_b.iter_combinations::<5>(w).count(), 0);
assert_eq!(a_wout_b.iter_combinations::<5>(w).size_hint().1, Some(0));
assert_eq!(a_wout_b.iter_combinations::<128>(w).count(), 0);
assert_eq!(a_wout_b.iter_combinations::<128>(w).size_hint().1, Some(0));

let values: HashSet<[&A; 2]> = a_wout_b.iter_combinations(&world).collect();
assert_eq!(
values,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use bevy_ecs::prelude::*;

#[derive(Component)]
struct Foo;
#[derive(Component)]
struct Bar;

fn on_changed(query: Query<&Foo, Or<(Changed<Foo>, With<Bar>)>>) {
// this should fail to compile
is_exact_size_iterator(query.iter_combinations::<2>());
}

fn on_added(query: Query<&Foo, (Added<Foo>, Without<Bar>)>) {
// this should fail to compile
is_exact_size_iterator(query.iter_combinations::<2>());
}

fn is_exact_size_iterator<T: ExactSizeIterator>(_iter: T) {}

fn main() {}
Loading

0 comments on commit d0b98ca

Please sign in to comment.