Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Extract parallel queue abstraction #7348

Merged
merged 13 commits into from
Feb 19, 2024
1 change: 0 additions & 1 deletion crates/bevy_ecs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ bevy_ecs_macros = { path = "macros", version = "0.9.0" }

async-channel = "1.4"
event-listener = "2.5"
thread_local = "1.1.4"
fixedbitset = "0.4.2"
fxhash = "0.2"
downcast-rs = "1.2"
Expand Down
22 changes: 7 additions & 15 deletions crates/bevy_ecs/src/system/commands/parallel_scope.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use std::cell::Cell;

use thread_local::ThreadLocal;
use bevy_utils::Parallel;

use crate::{
entity::Entities,
Expand All @@ -14,7 +12,7 @@ use super::{CommandQueue, Commands};
#[doc(hidden)]
#[derive(Default)]
pub struct ParallelCommandsState {
thread_local_storage: ThreadLocal<Cell<CommandQueue>>,
thread_queues: Parallel<CommandQueue>,
}

/// An alternative to [`Commands`] that can be used in parallel contexts, such as those in [`Query::par_iter`](crate::system::Query::par_iter)
Expand Down Expand Up @@ -62,8 +60,8 @@ unsafe impl SystemParam for ParallelCommands<'_, '_> {
let _system_span =
bevy_utils::tracing::info_span!("system_commands", name = _system_meta.name())
.entered();
for cq in &mut state.thread_local_storage {
cq.get_mut().apply(world);
for cq in state.thread_queues.iter_mut() {
cq.apply(world);
}
}

Expand All @@ -82,16 +80,10 @@ unsafe impl SystemParam for ParallelCommands<'_, '_> {

impl<'w, 's> ParallelCommands<'w, 's> {
pub fn command_scope<R>(&self, f: impl FnOnce(Commands) -> R) -> R {
let store = &self.state.thread_local_storage;
let command_queue_cell = store.get_or_default();
let mut command_queue = command_queue_cell.take();

let r = f(Commands::new_from_entities(
let mut command_queue = self.state.thread_queues.get();
f(Commands::new_from_entities(
&mut command_queue,
self.entities,
));

command_queue_cell.set(command_queue);
r
))
}
}
19 changes: 5 additions & 14 deletions crates/bevy_render/src/view/visibility/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ use bevy_reflect::Reflect;
use bevy_reflect::{std_traits::ReflectDefault, FromReflect};
use bevy_transform::components::GlobalTransform;
use bevy_transform::TransformSystem;
use std::cell::Cell;
use thread_local::ThreadLocal;
use bevy_utils::Parallel;

use crate::{
camera::{
Expand Down Expand Up @@ -356,7 +355,7 @@ fn propagate_recursive(
/// [`ComputedVisibility`] of all entities, and for each view also compute the [`VisibleEntities`]
/// for that view.
pub fn check_visibility(
mut thread_queues: Local<ThreadLocal<Cell<Vec<Entity>>>>,
mut thread_queues: Local<Parallel<Vec<Entity>>>,
mut view_query: Query<(&mut VisibleEntities, &Frustum, Option<&RenderLayers>), With<Camera>>,
mut visible_aabb_query: Query<(
Entity,
Expand Down Expand Up @@ -413,10 +412,7 @@ pub fn check_visibility(
}

computed_visibility.set_visible_in_view();
let cell = thread_queues.get_or_default();
let mut queue = cell.take();
queue.push(entity);
cell.set(queue);
thread_queues.get().push(entity);
},
);

Expand All @@ -434,16 +430,11 @@ pub fn check_visibility(
}

computed_visibility.set_visible_in_view();
let cell = thread_queues.get_or_default();
let mut queue = cell.take();
queue.push(entity);
cell.set(queue);
thread_queues.get().push(entity);
},
);

for cell in thread_queues.iter_mut() {
visible_entities.entities.append(cell.get_mut());
}
thread_queues.drain_into(&mut visible_entities.entities);
}
}

Expand Down
1 change: 1 addition & 0 deletions crates/bevy_utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ uuid = { version = "1.1", features = ["v4", "serde"] }
hashbrown = { version = "0.12", features = ["serde"] }
petgraph = "0.6"
thiserror = "1.0"
thread_local = "1.0"

[target.'cfg(target_arch = "wasm32")'.dependencies]
getrandom = {version = "0.2.0", features = ["js"]}
2 changes: 2 additions & 0 deletions crates/bevy_utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ pub mod syncunsafecell;

mod default;
mod float_ord;
mod parallel_queue;

pub use ahash::AHasher;
pub use default::default;
pub use float_ord::*;
pub use hashbrown;
pub use instant::{Duration, Instant};
pub use parallel_queue::*;
pub use petgraph;
pub use thiserror;
pub use tracing;
Expand Down
92 changes: 92 additions & 0 deletions crates/bevy_utils/src/parallel_queue.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use core::{
cell::Cell,
ops::{Deref, DerefMut, Drop},
};
use thread_local::ThreadLocal;

/// A cohesive set of thread-local values of a given type.
///
/// Mutable references can be fetched if `T: Default` via [`Parallel::get`].
#[derive(Default)]
pub struct Parallel<T: Send> {
locals: ThreadLocal<Cell<T>>,
}

impl<T: Send> Parallel<T> {
/// Gets a mutable iterator over all of the per-thread queues.
pub fn iter_mut(&mut self) -> impl Iterator<Item = &'_ mut T> {
self.locals.iter_mut().map(|cell| cell.get_mut())
}

/// Clears all of the stored thread local values.
pub fn clear(&mut self) {
self.locals.clear();
}
}

impl<T: Default + Send> Parallel<T> {
/// Takes the thread-local value and replaces it with the default.
#[inline]
pub fn get(&self) -> ParRef<'_, T> {
james7132 marked this conversation as resolved.
Show resolved Hide resolved
let cell = self.locals.get_or_default();
let value = cell.take();
ParRef { cell, value }
}
}

impl<T, I> Parallel<I>
where
I: IntoIterator<Item = T> + Default + Send + 'static,
{
/// Collect all enqueued items from all threads and them into one
james7132 marked this conversation as resolved.
Show resolved Hide resolved
james7132 marked this conversation as resolved.
Show resolved Hide resolved
pub fn drain<B>(&mut self) -> B
where
B: FromIterator<T>,
{
self.locals
.iter_mut()
.flat_map(|item| item.take().into_iter())
james7132 marked this conversation as resolved.
Show resolved Hide resolved
.collect()
}
}

impl<T: Send> Parallel<Vec<T>> {
/// Collect all enqueued items from all threads and them into one
pub fn drain_into(&mut self, out: &mut Vec<T>) {
out.clear();
let size = self
.locals
.iter_mut()
.map(|queue| queue.get_mut().len())
.sum();
out.reserve(size);
for queue in self.locals.iter_mut() {
out.append(queue.get_mut());
}
}
james7132 marked this conversation as resolved.
Show resolved Hide resolved
}

/// A retrieved thread-local reference to a value in [`Parallel`].
pub struct ParRef<'a, T: Default> {
cell: &'a Cell<T>,
value: T,
}

impl<'a, T: Default> Deref for ParRef<'a, T> {
type Target = T;
fn deref(&self) -> &T {
&self.value
}
}

impl<'a, T: Default> DerefMut for ParRef<'a, T> {
fn deref_mut(&mut self) -> &mut T {
&mut self.value
}
}

impl<'a, T: Default> Drop for ParRef<'a, T> {
fn drop(&mut self) {
self.cell.set(core::mem::take(&mut self.value));
}
}