Skip to content

Commit

Permalink
Implement collecting for par_for_each
Browse files Browse the repository at this point in the history
  • Loading branch information
TheRawMeatball committed May 3, 2021
1 parent afaf4ad commit 082ff31
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 106 deletions.
4 changes: 2 additions & 2 deletions crates/bevy_ecs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ mod tests {
let e4 = world.spawn().insert_bundle((4, true)).id();
let e5 = world.spawn().insert_bundle((5, true)).id();
let results = Arc::new(Mutex::new(Vec::new()));
world
let _ = world
.query::<(Entity, &i32)>()
.par_for_each(&world, &task_pool, 2, |(e, &i)| results.lock().push((e, i)));
results.lock().sort();
Expand All @@ -285,7 +285,7 @@ mod tests {
let e4 = world.spawn().insert_bundle((4, true)).id();
let e5 = world.spawn().insert_bundle((5, true)).id();
let results = Arc::new(Mutex::new(Vec::new()));
world
let _ = world
.query::<(Entity, &i32)>()
.par_for_each(&world, &task_pool, 2, |(e, &i)| results.lock().push((e, i)));
results.lock().sort();
Expand Down
195 changes: 101 additions & 94 deletions crates/bevy_ecs/src/query/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,47 +280,44 @@ where
}

#[inline]
pub fn par_for_each<'w>(
pub fn par_for_each<'w, R: Send + Sync + 'static>(
&mut self,
world: &'w World,
task_pool: &TaskPool,
batch_size: usize,
func: impl Fn(<Q::Fetch as Fetch<'w>>::Item) + Send + Sync + Clone,
) where
func: impl Fn(<Q::Fetch as Fetch<'w>>::Item) -> R + Send + Sync + Clone,
) -> impl Iterator<Item = R>
where
Q::Fetch: ReadOnlyFetch,
{
// SAFETY: query is read only
unsafe {
self.par_for_each_unchecked(world, task_pool, batch_size, func);
}
unsafe { self.par_for_each_unchecked(world, task_pool, batch_size, func) }
}

#[inline]
pub fn par_for_each_mut<'w>(
pub fn par_for_each_mut<'w, R: Send + Sync + 'static>(
&mut self,
world: &'w mut World,
task_pool: &TaskPool,
batch_size: usize,
func: impl Fn(<Q::Fetch as Fetch<'w>>::Item) + Send + Sync + Clone,
) {
func: impl Fn(<Q::Fetch as Fetch<'w>>::Item) -> R + Send + Sync + Clone,
) -> impl Iterator<Item = R> {
// SAFETY: query has unique world access
unsafe {
self.par_for_each_unchecked(world, task_pool, batch_size, func);
}
unsafe { self.par_for_each_unchecked(world, task_pool, batch_size, func) }
}

/// # Safety
///
/// This does not check for mutable query correctness. To be safe, make sure mutable queries
/// have unique access to the components they query.
#[inline]
pub unsafe fn par_for_each_unchecked<'w>(
pub unsafe fn par_for_each_unchecked<'w, R: Send + Sync + 'static>(
&mut self,
world: &'w World,
task_pool: &TaskPool,
batch_size: usize,
func: impl Fn(<Q::Fetch as Fetch<'w>>::Item) + Send + Sync + Clone,
) {
func: impl Fn(<Q::Fetch as Fetch<'w>>::Item) -> R + Send + Sync + Clone,
) -> impl Iterator<Item = R> {
self.validate_world_and_update_archetypes(world);
self.par_for_each_unchecked_manual(
world,
Expand All @@ -329,7 +326,7 @@ where
func,
world.last_change_tick(),
world.read_change_tick(),
);
)
}

/// # Safety
Expand Down Expand Up @@ -388,95 +385,105 @@ where
/// have unique access to the components they query.
/// This does not validate that `world.id()` matches `self.world_id`. Calling this on a `world`
/// with a mismatched WorldId is unsound.
pub unsafe fn par_for_each_unchecked_manual<'w, 's>(
pub unsafe fn par_for_each_unchecked_manual<'w, 's, R: Send + Sync + 'static>(
&'s self,
world: &'w World,
task_pool: &TaskPool,
batch_size: usize,
func: impl Fn(<Q::Fetch as Fetch<'w>>::Item) + Send + Sync + Clone,
func: impl Fn(<Q::Fetch as Fetch<'w>>::Item) -> R + Send + Sync + Clone,
last_change_tick: u32,
change_tick: u32,
) {
task_pool.scope(|scope| {
let fetch =
<Q::Fetch as Fetch>::init(world, &self.fetch_state, last_change_tick, change_tick);
let filter =
<F::Fetch as Fetch>::init(world, &self.filter_state, last_change_tick, change_tick);
) -> impl Iterator<Item = R> {
task_pool
.scope(|scope| {
let fetch = <Q::Fetch as Fetch>::init(
world,
&self.fetch_state,
last_change_tick,
change_tick,
);
let filter = <F::Fetch as Fetch>::init(
world,
&self.filter_state,
last_change_tick,
change_tick,
);

if fetch.is_dense() && filter.is_dense() {
let tables = &world.storages().tables;
for table_id in self.matched_table_ids.iter() {
let table = &tables[*table_id];
let mut offset = 0;
while offset < table.len() {
let func = func.clone();
scope.spawn(async move {
let mut fetch = <Q::Fetch as Fetch>::init(
world,
&self.fetch_state,
last_change_tick,
change_tick,
);
let mut filter = <F::Fetch as Fetch>::init(
world,
&self.filter_state,
last_change_tick,
change_tick,
);
let tables = &world.storages().tables;
let table = &tables[*table_id];
fetch.set_table(&self.fetch_state, table);
filter.set_table(&self.filter_state, table);
let len = batch_size.min(table.len() - offset);
for table_index in offset..offset + len {
if !filter.table_filter_fetch(table_index) {
continue;
}
let item = fetch.table_fetch(table_index);
func(item);
}
});
offset += batch_size;
if fetch.is_dense() && filter.is_dense() {
let tables = &world.storages().tables;
for table_id in self.matched_table_ids.iter() {
let table = &tables[*table_id];
let mut offset = 0;
while offset < table.len() {
let func = func.clone();
scope.spawn(async move {
let mut fetch = <Q::Fetch as Fetch>::init(
world,
&self.fetch_state,
last_change_tick,
change_tick,
);
let mut filter = <F::Fetch as Fetch>::init(
world,
&self.filter_state,
last_change_tick,
change_tick,
);
let tables = &world.storages().tables;
let table = &tables[*table_id];
fetch.set_table(&self.fetch_state, table);
filter.set_table(&self.filter_state, table);
let len = batch_size.min(table.len() - offset);
(offset..offset + len)
.filter(|&table_index| filter.table_filter_fetch(table_index))
.map(|table_index| fetch.table_fetch(table_index))
.map(func)
.collect::<Vec<_>>()
});
offset += batch_size;
}
}
}
} else {
let archetypes = &world.archetypes;
for archetype_id in self.matched_archetype_ids.iter() {
let mut offset = 0;
let archetype = &archetypes[*archetype_id];
while offset < archetype.len() {
let func = func.clone();
scope.spawn(async move {
let mut fetch = <Q::Fetch as Fetch>::init(
world,
&self.fetch_state,
last_change_tick,
change_tick,
);
let mut filter = <F::Fetch as Fetch>::init(
world,
&self.filter_state,
last_change_tick,
change_tick,
);
let tables = &world.storages().tables;
let archetype = &world.archetypes[*archetype_id];
fetch.set_archetype(&self.fetch_state, archetype, tables);
filter.set_archetype(&self.filter_state, archetype, tables);
} else {
let archetypes = &world.archetypes;
for archetype_id in self.matched_archetype_ids.iter() {
let mut offset = 0;
let archetype = &archetypes[*archetype_id];
while offset < archetype.len() {
let func = func.clone();
scope.spawn(async move {
let mut fetch = <Q::Fetch as Fetch>::init(
world,
&self.fetch_state,
last_change_tick,
change_tick,
);
let mut filter = <F::Fetch as Fetch>::init(
world,
&self.filter_state,
last_change_tick,
change_tick,
);
let tables = &world.storages().tables;
let archetype = &world.archetypes[*archetype_id];
fetch.set_archetype(&self.fetch_state, archetype, tables);
filter.set_archetype(&self.filter_state, archetype, tables);

let len = batch_size.min(archetype.len() - offset);
for archetype_index in offset..offset + len {
if !filter.archetype_filter_fetch(archetype_index) {
continue;
}
func(fetch.archetype_fetch(archetype_index));
}
});
offset += batch_size;
let len = batch_size.min(archetype.len() - offset);
(offset..offset + len)
.filter(|&archetype_index| {
filter.archetype_filter_fetch(archetype_index)
})
.map(|archetype_index| fetch.archetype_fetch(archetype_index))
.map(func)
.collect::<Vec<_>>()
});
offset += batch_size;
}
}
}
}
});
})
.into_iter()
.flatten()
}
}

Expand Down
17 changes: 9 additions & 8 deletions crates/bevy_ecs/src/system/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,12 +224,13 @@ where
/// This can only be called for read-only queries, see [`Self::par_for_each_mut`] for
/// write-queries.
#[inline]
pub fn par_for_each(
pub fn par_for_each<R: Send + Sync + 'static>(
&self,
task_pool: &TaskPool,
batch_size: usize,
f: impl Fn(<Q::Fetch as Fetch<'w>>::Item) + Send + Sync + Clone,
) where
f: impl Fn(<Q::Fetch as Fetch<'w>>::Item) -> R + Send + Sync + Clone,
) -> impl Iterator<Item = R>
where
Q::Fetch: ReadOnlyFetch,
{
// SAFE: system runs without conflicts with other systems. same-system queries have runtime
Expand All @@ -243,17 +244,17 @@ where
self.last_change_tick,
self.change_tick,
)
};
}
}

/// Runs `f` on each query result in parallel using the given task pool.
#[inline]
pub fn par_for_each_mut(
pub fn par_for_each_mut<R: Send + Sync + 'static>(
&mut self,
task_pool: &TaskPool,
batch_size: usize,
f: impl Fn(<Q::Fetch as Fetch<'w>>::Item) + Send + Sync + Clone,
) {
f: impl Fn(<Q::Fetch as Fetch<'w>>::Item) -> R + Send + Sync + Clone,
) -> impl Iterator<Item = R> {
// SAFE: system runs without conflicts with other systems. same-system queries have runtime
// borrow checks when they conflict
unsafe {
Expand All @@ -265,7 +266,7 @@ where
self.last_change_tick,
self.change_tick,
)
};
}
}

/// Gets the query result for the given [`Entity`].
Expand Down
4 changes: 2 additions & 2 deletions examples/ecs/parallel_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ fn move_system(pool: Res<ComputeTaskPool>, mut sprites: Query<(&mut Transform, &
// elements will not typically be faster than just using a normal Iterator.
// See the ParallelIterator documentation for more information on when
// to use or not use ParallelIterator over a normal Iterator.
sprites.par_for_each_mut(&pool, 32, |(mut transform, velocity)| {
let _ = sprites.par_for_each_mut(&pool, 32, |(mut transform, velocity)| {
transform.translation += velocity.0.extend(0.0);
});
}
Expand All @@ -52,7 +52,7 @@ fn bounce_system(
let right = width / 2.0;
let bottom = height / -2.0;
let top = height / 2.0;
sprites
let _ = sprites
// Batch size of 32 is chosen to limit the overhead of
// ParallelIterator, since negating a vector is very inexpensive.
.par_for_each_mut(&pool, 32, |(transform, mut v)| {
Expand Down

0 comments on commit 082ff31

Please sign in to comment.