From 886622c5e971c46eff2b55caf950ce385990366c Mon Sep 17 00:00:00 2001 From: James Liu Date: Mon, 30 May 2022 16:59:38 +0000 Subject: [PATCH] Remove task_pool parameter from par_for_each(_mut) (#4705) # Objective Fixes #3183. Requiring a `&TaskPool` parameter is sort of meaningless if the only correct one is to use the one provided by `Res` all the time. ## Solution Have `QueryState` save a clone of the `ComputeTaskPool` which is used for all `par_for_each` functions. ~~Adds a small overhead of the internal `Arc` clone as a part of the startup, but the ergonomics win should be well worth this hardly-noticable overhead.~~ Updated the docs to note that it will panic the task pool is not present as a resource. # Future Work If https://github.com/bevyengine/rfcs/pull/54 is approved, we can replace these resource lookups with a static function call instead to get the `ComputeTaskPool`. --- ## Changelog Removed: The `task_pool` parameter of `Query(State)::par_for_each(_mut)`. These calls will use the `World`'s `ComputeTaskPool` resource instead. ## Migration Guide The `task_pool` parameter for `Query(State)::par_for_each(_mut)` has been removed. Remove these parameters from all calls to these functions. Before: ```rust fn parallel_system( task_pool: Res, query: Query<&MyComponent>, ) { query.par_for_each(&task_pool, 32, |comp| { ... }); } ``` After: ```rust fn parallel_system(query: Query<&MyComponent>) { query.par_for_each(32, |comp| { ... }); } ``` If using `Query(State)` outside of a system run by the scheduler, you may need to manually configure and initialize a `ComputeTaskPool` as a resource in the `World`. --- .../bevy_ecs/ecs_bench_suite/heavy_compute.rs | 8 +- crates/bevy_ecs/src/lib.rs | 10 +- crates/bevy_ecs/src/query/state.rs | 216 ++++++++++-------- crates/bevy_ecs/src/system/query.rs | 24 +- examples/ecs/parallel_query.rs | 16 +- 5 files changed, 149 insertions(+), 125 deletions(-) diff --git a/benches/benches/bevy_ecs/ecs_bench_suite/heavy_compute.rs b/benches/benches/bevy_ecs/ecs_bench_suite/heavy_compute.rs index 4ddae1781ea1b6..71a8e6cc6a950d 100644 --- a/benches/benches/bevy_ecs/ecs_bench_suite/heavy_compute.rs +++ b/benches/benches/bevy_ecs/ecs_bench_suite/heavy_compute.rs @@ -1,5 +1,5 @@ use bevy_ecs::prelude::*; -use bevy_tasks::TaskPool; +use bevy_tasks::{ComputeTaskPool, TaskPool}; use glam::*; #[derive(Component, Copy, Clone)] @@ -29,8 +29,8 @@ impl Benchmark { ) })); - fn sys(task_pool: Res, mut query: Query<(&mut Position, &mut Transform)>) { - query.par_for_each_mut(&task_pool, 128, |(mut pos, mut mat)| { + fn sys(mut query: Query<(&mut Position, &mut Transform)>) { + query.par_for_each_mut(128, |(mut pos, mut mat)| { for _ in 0..100 { mat.0 = mat.0.inverse(); } @@ -39,7 +39,7 @@ impl Benchmark { }); } - world.insert_resource(TaskPool::default()); + world.insert_resource(ComputeTaskPool(TaskPool::default())); let mut system = IntoSystem::into_system(sys); system.initialize(&mut world); system.update_archetype_component_access(&world); diff --git a/crates/bevy_ecs/src/lib.rs b/crates/bevy_ecs/src/lib.rs index 3cc81f754015f7..30de6ec1a8170f 100644 --- a/crates/bevy_ecs/src/lib.rs +++ b/crates/bevy_ecs/src/lib.rs @@ -59,7 +59,7 @@ mod tests { query::{Added, ChangeTrackers, Changed, FilteredAccess, With, Without, WorldQuery}, world::{Mut, World}, }; - use bevy_tasks::TaskPool; + use bevy_tasks::{ComputeTaskPool, TaskPool}; use std::{ any::TypeId, sync::{ @@ -376,7 +376,7 @@ mod tests { #[test] fn par_for_each_dense() { let mut world = World::new(); - let task_pool = TaskPool::default(); + world.insert_resource(ComputeTaskPool(TaskPool::default())); let e1 = world.spawn().insert(A(1)).id(); let e2 = world.spawn().insert(A(2)).id(); let e3 = world.spawn().insert(A(3)).id(); @@ -385,7 +385,7 @@ mod tests { let results = Arc::new(Mutex::new(Vec::new())); world .query::<(Entity, &A)>() - .par_for_each(&world, &task_pool, 2, |(e, &A(i))| { + .par_for_each(&world, 2, |(e, &A(i))| { results.lock().unwrap().push((e, i)); }); results.lock().unwrap().sort(); @@ -398,8 +398,7 @@ mod tests { #[test] fn par_for_each_sparse() { let mut world = World::new(); - - let task_pool = TaskPool::default(); + world.insert_resource(ComputeTaskPool(TaskPool::default())); let e1 = world.spawn().insert(SparseStored(1)).id(); let e2 = world.spawn().insert(SparseStored(2)).id(); let e3 = world.spawn().insert(SparseStored(3)).id(); @@ -408,7 +407,6 @@ mod tests { let results = Arc::new(Mutex::new(Vec::new())); world.query::<(Entity, &SparseStored)>().par_for_each( &world, - &task_pool, 2, |(e, &SparseStored(i))| results.lock().unwrap().push((e, i)), ); diff --git a/crates/bevy_ecs/src/query/state.rs b/crates/bevy_ecs/src/query/state.rs index 4125f26584c3f2..b88407f69212fd 100644 --- a/crates/bevy_ecs/src/query/state.rs +++ b/crates/bevy_ecs/src/query/state.rs @@ -10,17 +10,18 @@ use crate::{ storage::TableId, world::{World, WorldId}, }; -use bevy_tasks::TaskPool; +use bevy_tasks::{ComputeTaskPool, TaskPool}; #[cfg(feature = "trace")] use bevy_utils::tracing::Instrument; use fixedbitset::FixedBitSet; -use std::fmt; +use std::{fmt, ops::Deref}; use super::{QueryFetch, QueryItem, ROQueryFetch, ROQueryItem}; /// Provides scoped access to a [`World`] state according to a given [`WorldQuery`] and query filter. pub struct QueryState { world_id: WorldId, + task_pool: Option, pub(crate) archetype_generation: ArchetypeGeneration, pub(crate) matched_tables: FixedBitSet, pub(crate) matched_archetypes: FixedBitSet, @@ -61,6 +62,9 @@ impl QueryState { let mut state = Self { world_id: world.id(), + task_pool: world + .get_resource::() + .map(|task_pool| task_pool.deref().clone()), archetype_generation: ArchetypeGeneration::initial(), matched_table_ids: Vec::new(), matched_archetype_ids: Vec::new(), @@ -689,15 +693,18 @@ impl QueryState { ); } - /// Runs `func` on each query result in parallel using the given `task_pool`. + /// Runs `func` on each query result in parallel. /// /// This can only be called for read-only queries, see [`Self::par_for_each_mut`] for /// write-queries. + /// + /// # Panics + /// The [`ComputeTaskPool`] resource must be added to the `World` before using this method. If using this from a query + /// that is being initialized and run from the ECS scheduler, this should never panic. #[inline] pub fn par_for_each<'w, FN: Fn(ROQueryItem<'w, Q>) + Send + Sync + Clone>( &mut self, world: &'w World, - task_pool: &TaskPool, batch_size: usize, func: FN, ) { @@ -706,7 +713,6 @@ impl QueryState { self.update_archetypes(world); self.par_for_each_unchecked_manual::, FN>( world, - task_pool, batch_size, func, world.last_change_tick(), @@ -715,12 +721,15 @@ impl QueryState { } } - /// Runs `func` on each query result in parallel using the given `task_pool`. + /// Runs `func` on each query result in parallel. + /// + /// # Panics + /// The [`ComputeTaskPool`] resource must be added to the `World` before using this method. If using this from a query + /// that is being initialized and run from the ECS scheduler, this should never panic. #[inline] pub fn par_for_each_mut<'w, FN: Fn(QueryItem<'w, Q>) + Send + Sync + Clone>( &mut self, world: &'w mut World, - task_pool: &TaskPool, batch_size: usize, func: FN, ) { @@ -729,7 +738,6 @@ impl QueryState { self.update_archetypes(world); self.par_for_each_unchecked_manual::, FN>( world, - task_pool, batch_size, func, world.last_change_tick(), @@ -738,10 +746,14 @@ impl QueryState { } } - /// Runs `func` on each query result in parallel using the given `task_pool`. + /// Runs `func` on each query result in parallel. /// /// This can only be called for read-only queries. /// + /// # Panics + /// [`ComputeTaskPool`] was not stored in the world at initialzation. If using this from a query + /// that is being initialized and run from the ECS scheduler, this should never panic. + /// /// # Safety /// /// This does not check for mutable query correctness. To be safe, make sure mutable queries @@ -750,14 +762,12 @@ impl QueryState { pub unsafe fn par_for_each_unchecked<'w, FN: Fn(QueryItem<'w, Q>) + Send + Sync + Clone>( &mut self, world: &'w World, - task_pool: &TaskPool, batch_size: usize, func: FN, ) { self.update_archetypes(world); self.par_for_each_unchecked_manual::, FN>( world, - task_pool, batch_size, func, world.last_change_tick(), @@ -833,6 +843,10 @@ impl QueryState { /// the current change tick are given. This is faster than the equivalent /// iter() method, but cannot be chained like a normal [`Iterator`]. /// + /// # Panics + /// [`ComputeTaskPool`] was not stored in the world at initialzation. If using this from a query + /// that is being initialized and run from the ECS scheduler, this should never panic. + /// /// # Safety /// /// This does not check for mutable query correctness. To be safe, make sure mutable queries @@ -846,7 +860,6 @@ impl QueryState { >( &self, world: &'w World, - task_pool: &TaskPool, batch_size: usize, func: FN, last_change_tick: u32, @@ -854,95 +867,106 @@ impl QueryState { ) { // NOTE: If you are changing query iteration code, remember to update the following places, where relevant: // QueryIter, QueryIterationCursor, QueryState::for_each_unchecked_manual, QueryState::par_for_each_unchecked_manual - task_pool.scope(|scope| { - if QF::IS_DENSE && >::IS_DENSE { - let tables = &world.storages().tables; - for table_id in &self.matched_table_ids { - let table = &tables[*table_id]; - let mut offset = 0; - while offset < table.len() { - let func = func.clone(); - let len = batch_size.min(table.len() - offset); - let task = async move { - let mut fetch = - QF::init(world, &self.fetch_state, last_change_tick, change_tick); - let mut filter = as Fetch>::init( - world, - &self.filter_state, - last_change_tick, - change_tick, - ); - let tables = &world.storages().tables; - let table = &tables[*table_id]; - fetch.set_table(&self.fetch_state, table); - filter.set_table(&self.filter_state, table); - for table_index in offset..offset + len { - if !filter.table_filter_fetch(table_index) { - continue; + self.task_pool + .as_ref() + .expect("Cannot iterate query in parallel. No ComputeTaskPool initialized.") + .scope(|scope| { + if QF::IS_DENSE && >::IS_DENSE { + let tables = &world.storages().tables; + for table_id in &self.matched_table_ids { + let table = &tables[*table_id]; + let mut offset = 0; + while offset < table.len() { + let func = func.clone(); + let len = batch_size.min(table.len() - offset); + let task = async move { + let mut fetch = QF::init( + world, + &self.fetch_state, + last_change_tick, + change_tick, + ); + let mut filter = as Fetch>::init( + world, + &self.filter_state, + last_change_tick, + change_tick, + ); + let tables = &world.storages().tables; + let table = &tables[*table_id]; + fetch.set_table(&self.fetch_state, table); + filter.set_table(&self.filter_state, table); + for table_index in offset..offset + len { + if !filter.table_filter_fetch(table_index) { + continue; + } + let item = fetch.table_fetch(table_index); + func(item); } - let item = fetch.table_fetch(table_index); - func(item); - } - }; - #[cfg(feature = "trace")] - let span = bevy_utils::tracing::info_span!( - "par_for_each", - query = std::any::type_name::(), - filter = std::any::type_name::(), - count = len, - ); - #[cfg(feature = "trace")] - let task = task.instrument(span); - scope.spawn(task); - offset += batch_size; - } - } - } else { - let archetypes = &world.archetypes; - for archetype_id in &self.matched_archetype_ids { - let mut offset = 0; - let archetype = &archetypes[*archetype_id]; - while offset < archetype.len() { - let func = func.clone(); - let len = batch_size.min(archetype.len() - offset); - let task = async move { - let mut fetch = - QF::init(world, &self.fetch_state, last_change_tick, change_tick); - let mut filter = as Fetch>::init( - world, - &self.filter_state, - last_change_tick, - change_tick, + }; + #[cfg(feature = "trace")] + let span = bevy_utils::tracing::info_span!( + "par_for_each", + query = std::any::type_name::(), + filter = std::any::type_name::(), + count = len, ); - let tables = &world.storages().tables; - let archetype = &world.archetypes[*archetype_id]; - fetch.set_archetype(&self.fetch_state, archetype, tables); - filter.set_archetype(&self.filter_state, archetype, tables); - - for archetype_index in offset..offset + len { - if !filter.archetype_filter_fetch(archetype_index) { - continue; + #[cfg(feature = "trace")] + let task = task.instrument(span); + scope.spawn(task); + offset += batch_size; + } + } + } else { + let archetypes = &world.archetypes; + for archetype_id in &self.matched_archetype_ids { + let mut offset = 0; + let archetype = &archetypes[*archetype_id]; + while offset < archetype.len() { + let func = func.clone(); + let len = batch_size.min(archetype.len() - offset); + let task = async move { + let mut fetch = QF::init( + world, + &self.fetch_state, + last_change_tick, + change_tick, + ); + let mut filter = as Fetch>::init( + world, + &self.filter_state, + last_change_tick, + change_tick, + ); + let tables = &world.storages().tables; + let archetype = &world.archetypes[*archetype_id]; + fetch.set_archetype(&self.fetch_state, archetype, tables); + filter.set_archetype(&self.filter_state, archetype, tables); + + for archetype_index in offset..offset + len { + if !filter.archetype_filter_fetch(archetype_index) { + continue; + } + func(fetch.archetype_fetch(archetype_index)); } - func(fetch.archetype_fetch(archetype_index)); - } - }; - - #[cfg(feature = "trace")] - let span = bevy_utils::tracing::info_span!( - "par_for_each", - query = std::any::type_name::(), - filter = std::any::type_name::(), - count = len, - ); - #[cfg(feature = "trace")] - let task = task.instrument(span); - - scope.spawn(task); - offset += batch_size; + }; + + #[cfg(feature = "trace")] + let span = bevy_utils::tracing::info_span!( + "par_for_each", + query = std::any::type_name::(), + filter = std::any::type_name::(), + count = len, + ); + #[cfg(feature = "trace")] + let task = task.instrument(span); + + scope.spawn(task); + offset += batch_size; + } } } - } - }); + }); } /// Returns a single immutable query result when there is exactly one entity matching diff --git a/crates/bevy_ecs/src/system/query.rs b/crates/bevy_ecs/src/system/query.rs index 2298f5636e1131..42e77e8461e0fb 100644 --- a/crates/bevy_ecs/src/system/query.rs +++ b/crates/bevy_ecs/src/system/query.rs @@ -7,7 +7,6 @@ use crate::{ }, world::{Mut, World}, }; -use bevy_tasks::TaskPool; use std::{any::TypeId, fmt::Debug}; /// Provides scoped access to components in a [`World`]. @@ -493,7 +492,7 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery> Query<'w, 's, Q, F> { }; } - /// Runs `f` on each query result in parallel using the given [`TaskPool`]. + /// Runs `f` on each query result in parallel using the [`World`]'s [`ComputeTaskPool`]. /// /// This can only be called for immutable data, see [`Self::par_for_each_mut`] for /// mutable access. @@ -502,7 +501,7 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery> Query<'w, 's, Q, F> { /// /// The items in the query get sorted into batches. /// Internally, this function spawns a group of futures that each take on a `batch_size` sized section of the items (or less if the division is not perfect). - /// Then, the tasks in the [`TaskPool`] work through these futures. + /// Then, the tasks in the [`ComputeTaskPool`] work through these futures. /// /// You can use this value to tune between maximum multithreading ability (many small batches) and minimum parallelization overhead (few big batches). /// Rule of thumb: If the function body is (mostly) computationally expensive but there are not many items, a small batch size (=more batches) may help to even out the load. @@ -510,13 +509,17 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery> Query<'w, 's, Q, F> { /// /// # Arguments /// - ///* `task_pool` - The [`TaskPool`] to use ///* `batch_size` - The number of batches to spawn ///* `f` - The function to run on each item in the query + /// + /// # Panics + /// The [`ComputeTaskPool`] resource must be added to the `World` before using this method. If using this from a query + /// that is being initialized and run from the ECS scheduler, this should never panic. + /// + /// [`ComputeTaskPool`]: bevy_tasks::prelude::ComputeTaskPool #[inline] pub fn par_for_each<'this>( &'this self, - task_pool: &TaskPool, batch_size: usize, f: impl Fn(ROQueryItem<'this, Q>) + Send + Sync + Clone, ) { @@ -526,7 +529,6 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery> Query<'w, 's, Q, F> { self.state .par_for_each_unchecked_manual::, _>( self.world, - task_pool, batch_size, f, self.last_change_tick, @@ -535,12 +537,17 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery> Query<'w, 's, Q, F> { }; } - /// Runs `f` on each query result in parallel using the given [`TaskPool`]. + /// Runs `f` on each query result in parallel using the [`World`]'s [`ComputeTaskPool`]. /// See [`Self::par_for_each`] for more details. + /// + /// # Panics + /// [`ComputeTaskPool`] was not stored in the world at initialzation. If using this from a query + /// that is being initialized and run from the ECS scheduler, this should never panic. + /// + /// [`ComputeTaskPool`]: bevy_tasks::prelude::ComputeTaskPool #[inline] pub fn par_for_each_mut<'a, FN: Fn(QueryItem<'a, Q>) + Send + Sync + Clone>( &'a mut self, - task_pool: &TaskPool, batch_size: usize, f: FN, ) { @@ -550,7 +557,6 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery> Query<'w, 's, Q, F> { self.state .par_for_each_unchecked_manual::, FN>( self.world, - task_pool, batch_size, f, self.last_change_tick, diff --git a/examples/ecs/parallel_query.rs b/examples/ecs/parallel_query.rs index f10cb1df148395..1df3e9bfc99c76 100644 --- a/examples/ecs/parallel_query.rs +++ b/examples/ecs/parallel_query.rs @@ -1,6 +1,6 @@ //! Illustrates parallel queries with `ParallelIterator`. -use bevy::{prelude::*, tasks::prelude::*}; +use bevy::prelude::*; use rand::random; #[derive(Component, Deref)] @@ -23,26 +23,22 @@ fn spawn_system(mut commands: Commands, asset_server: Res) { } // Move sprites according to their velocity -fn move_system(pool: Res, mut sprites: Query<(&mut Transform, &Velocity)>) { +fn move_system(mut sprites: Query<(&mut Transform, &Velocity)>) { // Compute the new location of each sprite in parallel on the // ComputeTaskPool using batches of 32 sprites // - // This example is only for demonstrative purposes. Using a + // This example is only for demonstrative purposes. Using a // ParallelIterator for an inexpensive operation like addition on only 128 // elements will not typically be faster than just using a normal Iterator. // See the ParallelIterator documentation for more information on when // to use or not use ParallelIterator over a normal Iterator. - sprites.par_for_each_mut(&pool, 32, |(mut transform, velocity)| { + sprites.par_for_each_mut(32, |(mut transform, velocity)| { transform.translation += velocity.extend(0.0); }); } // Bounce sprites outside the window -fn bounce_system( - pool: Res, - windows: Res, - mut sprites: Query<(&Transform, &mut Velocity)>, -) { +fn bounce_system(windows: Res, mut sprites: Query<(&Transform, &mut Velocity)>) { let window = windows.primary(); let width = window.width(); let height = window.height(); @@ -53,7 +49,7 @@ fn bounce_system( sprites // Batch size of 32 is chosen to limit the overhead of // ParallelIterator, since negating a vector is very inexpensive. - .par_for_each_mut(&pool, 32, |(transform, mut v)| { + .par_for_each_mut(32, |(transform, mut v)| { if !(left < transform.translation.x && transform.translation.x < right && bottom < transform.translation.y