Skip to content

Commit

Permalink
Implement bind group layout deduplication for all configurations
Browse files Browse the repository at this point in the history
Currently wgpu-core implements bind group layout deduplication only when it creates its own resource IDs. In other words, it works for wgpu but not in Firefox.
This PR bridges the gap by allowing an optional indirection in bind group layouts: each BGL may store an ID referring to its "deduplicated" BGL.
When referring to a BGL the rest of the code must make sure to follow the indirection. The exception is command buffer processing which is considered hot code and where we first validate against the provided BGL ID and only follow the indirection if the initial check failed.

The main pain point with this approach is the various places where wgpu-core manually updates reference counts: we have to be careful about following the indirection to track the right BGL.
  • Loading branch information
nical committed Jul 12, 2023
1 parent b3de786 commit 1af673a
Show file tree
Hide file tree
Showing 8 changed files with 165 additions and 48 deletions.
32 changes: 32 additions & 0 deletions wgpu-core/src/binding_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,8 @@ pub struct BindGroupLayoutDescriptor<'a> {

pub(crate) type BindEntryMap = FastHashMap<u32, wgt::BindGroupLayoutEntry>;

pub type BindGroupLayouts<A> = crate::storage::Storage<BindGroupLayout<A>, BindGroupLayoutId>;

/// Bind group layout.
///
/// The lifetime of BGLs is a bit special. They are only referenced on CPU
Expand All @@ -450,6 +452,12 @@ pub struct BindGroupLayout<A: hal::Api> {
pub(crate) device_id: Stored<DeviceId>,
pub(crate) multi_ref_count: MultiRefCount,
pub(crate) entries: BindEntryMap,
// When a layout is created and there already exists a compatible layout, the new
// layout keeps a reference to the older compatible one. In some places we
// substitute the bind group layout id with its compatible sibling.
// Since this substitution can come at a cost, it is skipped when wgpu-core generates
// its own resource IDs.
pub(crate) compatible_layout: Option<Valid<BindGroupLayoutId>>,
#[allow(unused)]
pub(crate) dynamic_count: usize,
pub(crate) count_validator: BindingTypeMaxCountValidator,
Expand All @@ -472,6 +480,30 @@ impl<A: hal::Api> Resource for BindGroupLayout<A> {
}
}

/// Look up a bind group layout by id, following the deduplication indirection.
///
/// If the bind group layout needs to be substituted with its compatible
/// equivalent, return the latter. Returns `None` when `id` does not resolve
/// to a live layout in `layouts`.
pub(crate) fn try_get_bind_group_layout<A: HalApi>(
    layouts: &BindGroupLayouts<A>,
    id: BindGroupLayoutId,
) -> Option<&BindGroupLayout<A>> {
    // `get` fails for invalid or unregistered ids; surface that as `None`.
    let layout = layouts.get(id).ok()?;
    if let Some(compat) = layout.compatible_layout {
        // This layout is a duplicate: hand back its deduplicated sibling.
        return Some(&layouts[compat]);
    }

    Some(layout)
}

/// Resolve `id` to the bind group layout that should actually be used,
/// following the deduplication indirection when one is present.
///
/// Returns the (possibly substituted) id together with a reference to the
/// corresponding layout. Panics if `id` is not present in `layouts`.
pub(crate) fn get_bind_group_layout<A: HalApi>(
    layouts: &BindGroupLayouts<A>,
    id: Valid<BindGroupLayoutId>,
) -> (Valid<BindGroupLayoutId>, &BindGroupLayout<A>) {
    let layout = &layouts[id];
    match layout.compatible_layout {
        // This layout is a duplicate: substitute its deduplicated sibling.
        Some(compat) => (compat, &layouts[compat]),
        // No indirection: the layout stands for itself.
        None => (id, layout),
    }
}

#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum CreatePipelineLayoutError {
Expand Down
35 changes: 27 additions & 8 deletions wgpu-core/src/command/bind.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::{
binding_model::{BindGroup, LateMinBufferBindingSizeMismatch, PipelineLayout},
binding_model::{
BindGroup, BindGroupLayouts, LateMinBufferBindingSizeMismatch, PipelineLayout,
},
device::SHADER_STAGE_COUNT,
hal_api::HalApi,
id::{BindGroupId, PipelineLayoutId, Valid},
Expand All @@ -13,7 +15,10 @@ use arrayvec::ArrayVec;
type BindGroupMask = u8;

mod compat {
use crate::id::{BindGroupLayoutId, Valid};
use crate::{
binding_model::BindGroupLayouts,
id::{BindGroupLayoutId, Valid},
};
use std::ops::Range;

#[derive(Debug, Default)]
Expand All @@ -27,8 +32,16 @@ mod compat {
self.assigned.is_some() && self.expected.is_some()
}

fn is_valid(&self) -> bool {
self.expected.is_none() || self.expected == self.assigned
fn is_valid<A: hal::Api>(&self, bind_group_layouts: &BindGroupLayouts<A>) -> bool {
if self.expected.is_none() || self.expected == self.assigned {
return true;
}

if let Some(id) = self.assigned {
return bind_group_layouts[id].compatible_layout == self.expected;
}

false
}
}

Expand Down Expand Up @@ -88,9 +101,12 @@ mod compat {
.filter_map(|(i, e)| if e.is_active() { Some(i) } else { None })
}

pub fn invalid_mask(&self) -> super::BindGroupMask {
pub fn invalid_mask<A: hal::Api>(
&self,
bind_group_layouts: &BindGroupLayouts<A>,
) -> super::BindGroupMask {
self.entries.iter().enumerate().fold(0, |mask, (i, entry)| {
if entry.is_valid() {
if entry.is_valid(bind_group_layouts) {
mask
} else {
mask | 1u8 << i
Expand Down Expand Up @@ -276,8 +292,11 @@ impl Binder {
.map(move |index| payloads[index].group_id.as_ref().unwrap().value)
}

pub(super) fn invalid_mask(&self) -> BindGroupMask {
self.manager.invalid_mask()
pub(super) fn invalid_mask<A: hal::Api>(
&self,
bind_group_layouts: &BindGroupLayouts<A>,
) -> BindGroupMask {
self.manager.invalid_mask(bind_group_layouts)
}

/// Scan active buffer bindings corresponding to layouts without `min_binding_size` specified.
Expand Down
16 changes: 11 additions & 5 deletions wgpu-core/src/command/compute.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::{
binding_model::{
BindError, BindGroup, LateMinBufferBindingSizeMismatch, PushConstantUploadError,
BindError, BindGroup, BindGroupLayouts, LateMinBufferBindingSizeMismatch,
PushConstantUploadError,
},
command::{
bind::Binder,
Expand Down Expand Up @@ -257,8 +258,8 @@ struct State<A: HalApi> {
}

impl<A: HalApi> State<A> {
fn is_ready(&self) -> Result<(), DispatchError> {
let bind_mask = self.binder.invalid_mask();
fn is_ready(&self, bind_group_layouts: &BindGroupLayouts<A>) -> Result<(), DispatchError> {
let bind_mask = self.binder.invalid_mask(bind_group_layouts);
if bind_mask != 0 {
//let (expected, provided) = self.binder.entries[index as usize].info();
return Err(DispatchError::IncompatibleBindGroup {
Expand Down Expand Up @@ -371,6 +372,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let (bind_group_guard, mut token) = hub.bind_groups.read(&mut token);
let (pipeline_guard, mut token) = hub.compute_pipelines.read(&mut token);
let (query_set_guard, mut token) = hub.query_sets.read(&mut token);
let (bind_group_layout_guard, mut token) = hub.bind_group_layouts.read(&mut token);
let (buffer_guard, mut token) = hub.buffers.read(&mut token);
let (texture_guard, _) = hub.textures.read(&mut token);

Expand Down Expand Up @@ -591,7 +593,9 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
pipeline: state.pipeline,
};

state.is_ready().map_pass_err(scope)?;
state
.is_ready(&*bind_group_layout_guard)
.map_pass_err(scope)?;
state
.flush_states(
raw,
Expand Down Expand Up @@ -628,7 +632,9 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
pipeline: state.pipeline,
};

state.is_ready().map_pass_err(scope)?;
state
.is_ready(&*bind_group_layout_guard)
.map_pass_err(scope)?;

device
.require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)
Expand Down
27 changes: 20 additions & 7 deletions wgpu-core/src/command/render.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
binding_model::BindError,
binding_model::{BindError, BindGroupLayouts},
command::{
self,
bind::Binder,
Expand Down Expand Up @@ -386,7 +386,11 @@ struct State {
}

impl State {
fn is_ready(&self, indexed: bool) -> Result<(), DrawError> {
fn is_ready<A: hal::Api>(
&self,
indexed: bool,
bind_group_layouts: &BindGroupLayouts<A>,
) -> Result<(), DrawError> {
// Determine how many vertex buffers have already been bound
let vertex_buffer_count = self.vertex.inputs.iter().take_while(|v| v.bound).count() as u32;
// Compare with the needed quantity
Expand All @@ -396,7 +400,7 @@ impl State {
});
}

let bind_mask = self.binder.invalid_mask();
let bind_mask = self.binder.invalid_mask(bind_group_layouts);
if bind_mask != 0 {
//let (expected, provided) = self.binder.entries[index as usize].info();
return Err(DrawError::IncompatibleBindGroup {
Expand Down Expand Up @@ -1252,6 +1256,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let (bind_group_guard, mut token) = hub.bind_groups.read(&mut token);
let (render_pipeline_guard, mut token) = hub.render_pipelines.read(&mut token);
let (query_set_guard, mut token) = hub.query_sets.read(&mut token);
let (bind_group_layout_guard, mut token) = hub.bind_group_layouts.read(&mut token);
let (buffer_guard, mut token) = hub.buffers.read(&mut token);
let (texture_guard, mut token) = hub.textures.read(&mut token);
let (view_guard, _) = hub.texture_views.read(&mut token);
Expand Down Expand Up @@ -1713,7 +1718,9 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
indirect: false,
pipeline: state.pipeline,
};
state.is_ready(indexed).map_pass_err(scope)?;
state
.is_ready::<A>(indexed, &bind_group_layout_guard)
.map_pass_err(scope)?;

let last_vertex = first_vertex + vertex_count;
let vertex_limit = state.vertex.vertex_limit;
Expand Down Expand Up @@ -1753,7 +1760,9 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
indirect: false,
pipeline: state.pipeline,
};
state.is_ready(indexed).map_pass_err(scope)?;
state
.is_ready::<A>(indexed, &*bind_group_layout_guard)
.map_pass_err(scope)?;

//TODO: validate that base_vertex + max_index() is
// within the provided range
Expand Down Expand Up @@ -1798,7 +1807,9 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
indirect: true,
pipeline: state.pipeline,
};
state.is_ready(indexed).map_pass_err(scope)?;
state
.is_ready::<A>(indexed, &*bind_group_layout_guard)
.map_pass_err(scope)?;

let stride = match indexed {
false => mem::size_of::<wgt::DrawIndirectArgs>(),
Expand Down Expand Up @@ -1870,7 +1881,9 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
indirect: true,
pipeline: state.pipeline,
};
state.is_ready(indexed).map_pass_err(scope)?;
state
.is_ready::<A>(indexed, &*bind_group_layout_guard)
.map_pass_err(scope)?;

let stride = match indexed {
false => mem::size_of::<wgt::DrawIndirectArgs>(),
Expand Down
57 changes: 42 additions & 15 deletions wgpu-core/src/device/global.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#[cfg(feature = "trace")]
use crate::device::trace;
use crate::{
binding_model, command, conv,
binding_model,
command, conv,
device::{life::WaitIdleError, map_buffer, queue, Device, DeviceError, HostMap},
global::Global,
hal_api::HalApi,
Expand Down Expand Up @@ -1069,19 +1070,30 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
}
}

// If there is an equivalent BGL, just bump the refcount and return it.
// This is only applicable for identity filters that are generating new IDs,
// so their inputs are `PhantomData` of size 0.
if mem::size_of::<Input<G, id::BindGroupLayoutId>>() == 0 {
let mut compatible_layout = None;
{
let (bgl_guard, _) = hub.bind_group_layouts.read(&mut token);
if let Some(id) =
Device::deduplicate_bind_group_layout(device_id, &entry_map, &*bgl_guard)
{
return (id, None);
// If there is an equivalent BGL, just bump the refcount and return it.
// This is only applicable to identity filters that generate their IDs,
// which use PhantomData as their identity.
// In practice this means:
// - wgpu users take this branch and return the existing
// id without using the indirection layer in BindGroupLayout.
// - Other users like gecko or the replay tool don't take
// the branch and instead rely on the indirection to use the
// proper bind group layout id.
if std::mem::size_of::<Input<G, id::BindGroupLayoutId>>() == 0 {
return (id, None);
}

compatible_layout = Some(id::Valid(id));
}
}

let layout = match device.create_bind_group_layout(
let mut layout = match device.create_bind_group_layout(
device_id,
desc.label.borrow_option(),
entry_map,
Expand All @@ -1090,7 +1102,10 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
Err(e) => break e,
};

layout.compatible_layout = compatible_layout;

let id = fid.assign(layout, &mut token);

return (id.0, None);
};

Expand Down Expand Up @@ -1235,16 +1250,28 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
.add(trace::Action::CreateBindGroup(fid.id(), desc.clone()));
}

let bind_group_layout = match bind_group_layout_guard.get(desc.layout) {
let mut bind_group_layout = match bind_group_layout_guard.get(desc.layout) {
Ok(layout) => layout,
Err(_) => break binding_model::CreateBindGroupError::InvalidLayout,
Err(..) => break binding_model::CreateBindGroupError::InvalidLayout,
};

let mut layout_id = id::Valid(desc.layout);
if let Some(id) = bind_group_layout.compatible_layout {
layout_id = id;
bind_group_layout = &bind_group_layout_guard[id];
}

let bind_group = match device.create_bind_group(
device_id,
bind_group_layout,
layout_id,
desc,
hub,
&mut token,
) {
Ok(bind_group) => bind_group,
Err(e) => break e,
};
let bind_group =
match device.create_bind_group(device_id, bind_group_layout, desc, hub, &mut token)
{
Ok(bind_group) => bind_group,
Err(e) => break e,
};
let ref_count = bind_group.life_guard.add_ref();

let id = fid.assign(bind_group, &mut token);
Expand Down
26 changes: 18 additions & 8 deletions wgpu-core/src/device/life.rs
Original file line number Diff line number Diff line change
Expand Up @@ -762,14 +762,24 @@ impl<A: HalApi> LifetimeTracker<A> {
//Note: nothing else can bump the refcount since the guard is locked exclusively
//Note: same BGL can appear multiple times in the list, but only the last
// encounter could drop the refcount to 0.
if guard[id].multi_ref_count.dec_and_check_empty() {
log::debug!("Bind group layout {:?} will be destroyed", id);
#[cfg(feature = "trace")]
if let Some(t) = trace {
t.lock().add(trace::Action::DestroyBindGroupLayout(id.0));
}
if let Some(lay) = hub.bind_group_layouts.unregister_locked(id.0, &mut *guard) {
self.free_resources.bind_group_layouts.push(lay.raw);
let mut bgl_to_check = Some(id);
while let Some(id) = bgl_to_check.take() {
let bgl = &guard[id];
if bgl.multi_ref_count.dec_and_check_empty() {
// If this layout points to a compatible one, go over the latter
// to decrement the ref count and potentially destroy it.
bgl_to_check = bgl.compatible_layout;

log::debug!("Bind group layout {:?} will be destroyed", id);
#[cfg(feature = "trace")]
if let Some(t) = trace {
t.lock().add(trace::Action::DestroyBindGroupLayout(id.0));
}
if let Some(lay) =
hub.bind_group_layouts.unregister_locked(id.0, &mut *guard)
{
self.free_resources.bind_group_layouts.push(lay.raw);
}
}
}
}
Expand Down
Loading

0 comments on commit 1af673a

Please sign in to comment.