Skip to content

Commit

Permalink
Remove lifetime constraints from wgpu::ComputePass methods (#5570)
Browse files Browse the repository at this point in the history
* basic test setup

* remove lifetime and drop resources on test - test fails now just as expected

* compute pass recording is now hub dependent (needs gfx_select)

* compute pass recording now bumps reference count of uses resources directly on recording

TODO:
* bind groups don't work because the Binder gets an id only
* wgpu level error handling is missing

* simplify compute pass state flush, compute pass execution no longer needs to lock bind_group storage

* wgpu sided error handling

* make ComputePass hal dependent, removing command cast hack. Introduce DynComputePass on wgpu side

* remove stray repr(C)

* changelog entry

* fix deno issues -> move DynComputePass into wgc

* split out resources setup from test
  • Loading branch information
Wumpf authored May 14, 2024
1 parent 00456cf commit 77a83fb
Show file tree
Hide file tree
Showing 11 changed files with 613 additions and 190 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ Bottom level categories:

### Major Changes

#### Remove lifetime bounds on `wgpu::ComputePass`

TODO(wumpf): This is still work in progress. Should write a bit more about it. Also will very likely extend to `wgpu::RenderPass` before release.

`wgpu::ComputePass` recording methods (e.g. `wgpu::ComputePass:set_render_pipeline`) no longer impose a lifetime constraint passed in resources.

By @wumpf in [#5569](https://github.com/gfx-rs/wgpu/pull/5569).

#### Querying shader compilation errors

Wgpu now supports querying [shader compilation info](https://www.w3.org/TR/webgpu/#dom-gpushadermodule-getcompilationinfo).
Expand Down
5 changes: 3 additions & 2 deletions deno_webgpu/command_encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,13 +254,14 @@ pub fn op_webgpu_command_encoder_begin_compute_pass(
None
};

let instance = state.borrow::<super::Instance>();
let command_encoder = &command_encoder_resource.1;
let descriptor = wgpu_core::command::ComputePassDescriptor {
label: Some(label),
timestamp_writes: timestamp_writes.as_ref(),
};

let compute_pass =
wgpu_core::command::ComputePass::new(command_encoder_resource.1, &descriptor);
let compute_pass = gfx_select!(command_encoder => instance.command_encoder_create_compute_pass_dyn(*command_encoder, &descriptor));

let rid = state
.resource_table
Expand Down
67 changes: 29 additions & 38 deletions deno_webgpu/compute_pass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ use std::cell::RefCell;

use super::error::WebGpuResult;

pub(crate) struct WebGpuComputePass(pub(crate) RefCell<wgpu_core::command::ComputePass>);
pub(crate) struct WebGpuComputePass(
pub(crate) RefCell<Box<dyn wgpu_core::command::DynComputePass>>,
);
impl Resource for WebGpuComputePass {
fn name(&self) -> Cow<str> {
"webGPUComputePass".into()
Expand All @@ -31,10 +33,10 @@ pub fn op_webgpu_compute_pass_set_pipeline(
.resource_table
.get::<WebGpuComputePass>(compute_pass_rid)?;

wgpu_core::command::compute_commands::wgpu_compute_pass_set_pipeline(
&mut compute_pass_resource.0.borrow_mut(),
compute_pipeline_resource.1,
);
compute_pass_resource
.0
.borrow_mut()
.set_pipeline(state.borrow(), compute_pipeline_resource.1)?;

Ok(WebGpuResult::empty())
}
Expand All @@ -52,12 +54,10 @@ pub fn op_webgpu_compute_pass_dispatch_workgroups(
.resource_table
.get::<WebGpuComputePass>(compute_pass_rid)?;

wgpu_core::command::compute_commands::wgpu_compute_pass_dispatch_workgroups(
&mut compute_pass_resource.0.borrow_mut(),
x,
y,
z,
);
compute_pass_resource
.0
.borrow_mut()
.dispatch_workgroups(state.borrow(), x, y, z);

Ok(WebGpuResult::empty())
}
Expand All @@ -77,11 +77,10 @@ pub fn op_webgpu_compute_pass_dispatch_workgroups_indirect(
.resource_table
.get::<WebGpuComputePass>(compute_pass_rid)?;

wgpu_core::command::compute_commands::wgpu_compute_pass_dispatch_workgroups_indirect(
&mut compute_pass_resource.0.borrow_mut(),
buffer_resource.1,
indirect_offset,
);
compute_pass_resource
.0
.borrow_mut()
.dispatch_workgroups_indirect(state.borrow(), buffer_resource.1, indirect_offset)?;

Ok(WebGpuResult::empty())
}
Expand All @@ -90,24 +89,15 @@ pub fn op_webgpu_compute_pass_dispatch_workgroups_indirect(
#[serde]
pub fn op_webgpu_compute_pass_end(
state: &mut OpState,
#[smi] command_encoder_rid: ResourceId,
#[smi] compute_pass_rid: ResourceId,
) -> Result<WebGpuResult, AnyError> {
let command_encoder_resource =
state
.resource_table
.get::<super::command_encoder::WebGpuCommandEncoder>(command_encoder_rid)?;
let command_encoder = command_encoder_resource.1;
let compute_pass_resource = state
.resource_table
.take::<WebGpuComputePass>(compute_pass_rid)?;
let compute_pass = &compute_pass_resource.0.borrow();
let instance = state.borrow::<super::Instance>();

gfx_ok!(command_encoder => instance.command_encoder_run_compute_pass(
command_encoder,
compute_pass
))
compute_pass_resource.0.borrow_mut().run(state.borrow())?;

Ok(WebGpuResult::empty())
}

#[op2]
Expand Down Expand Up @@ -137,12 +127,12 @@ pub fn op_webgpu_compute_pass_set_bind_group(

let dynamic_offsets_data: &[u32] = &dynamic_offsets_data[start..start + len];

wgpu_core::command::compute_commands::wgpu_compute_pass_set_bind_group(
&mut compute_pass_resource.0.borrow_mut(),
compute_pass_resource.0.borrow_mut().set_bind_group(
state.borrow(),
index,
bind_group_resource.1,
dynamic_offsets_data,
);
)?;

Ok(WebGpuResult::empty())
}
Expand All @@ -158,8 +148,8 @@ pub fn op_webgpu_compute_pass_push_debug_group(
.resource_table
.get::<WebGpuComputePass>(compute_pass_rid)?;

wgpu_core::command::compute_commands::wgpu_compute_pass_push_debug_group(
&mut compute_pass_resource.0.borrow_mut(),
compute_pass_resource.0.borrow_mut().push_debug_group(
state.borrow(),
group_label,
0, // wgpu#975
);
Expand All @@ -177,9 +167,10 @@ pub fn op_webgpu_compute_pass_pop_debug_group(
.resource_table
.get::<WebGpuComputePass>(compute_pass_rid)?;

wgpu_core::command::compute_commands::wgpu_compute_pass_pop_debug_group(
&mut compute_pass_resource.0.borrow_mut(),
);
compute_pass_resource
.0
.borrow_mut()
.pop_debug_group(state.borrow());

Ok(WebGpuResult::empty())
}
Expand All @@ -195,8 +186,8 @@ pub fn op_webgpu_compute_pass_insert_debug_marker(
.resource_table
.get::<WebGpuComputePass>(compute_pass_rid)?;

wgpu_core::command::compute_commands::wgpu_compute_pass_insert_debug_marker(
&mut compute_pass_resource.0.borrow_mut(),
compute_pass_resource.0.borrow_mut().insert_debug_marker(
state.borrow(),
marker_label,
0, // wgpu#975
);
Expand Down
174 changes: 174 additions & 0 deletions tests/tests/compute_pass_resource_ownership.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
//! Tests that compute passes take ownership of resources that are associated with.
//! I.e. once a resource is passed in to a compute pass, it can be dropped.
//!
//! TODO: Test doesn't check on timestamp writes & pipeline statistics queries yet.
//! (Not important as long as they are lifetime constrained to the command encoder,
//! but once we lift this constraint, we should add tests for this as well!)
//! TODO: Also should test resource ownership for:
//! * write_timestamp
//! * begin_pipeline_statistics_query

use std::num::NonZeroU64;

use wgpu::util::DeviceExt as _;
use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters, TestingContext};

const SHADER_SRC: &str = "
@group(0) @binding(0)
var<storage, read_write> buffer: array<vec4f>;
@compute @workgroup_size(1, 1, 1) fn main() {
buffer[0] *= 2.0;
}
";

#[gpu_test]
static COMPUTE_PASS_RESOURCE_OWNERSHIP: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(TestParameters::default().test_features_limits())
.run_async(compute_pass_resource_ownership);

async fn compute_pass_resource_ownership(ctx: TestingContext) {
let ResourceSetup {
gpu_buffer,
cpu_buffer,
buffer_size,
indirect_buffer,
bind_group,
pipeline,
} = resource_setup(&ctx);

let mut encoder = ctx
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("encoder"),
});

{
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("compute_pass"),
timestamp_writes: None, // TODO: See description above, we should test this as well once we lift the lifetime bound.
});
cpass.set_pipeline(&pipeline);
cpass.set_bind_group(0, &bind_group, &[]);
cpass.dispatch_workgroups_indirect(&indirect_buffer, 0);

// Now drop all resources we set. Then do a device poll to make sure the resources are really not dropped too early, no matter what.
drop(pipeline);
drop(bind_group);
drop(indirect_buffer);
ctx.async_poll(wgpu::Maintain::wait())
.await
.panic_on_timeout();
}

// Ensure that the compute pass still executed normally.
encoder.copy_buffer_to_buffer(&gpu_buffer, 0, &cpu_buffer, 0, buffer_size);
ctx.queue.submit([encoder.finish()]);
cpu_buffer.slice(..).map_async(wgpu::MapMode::Read, |_| ());
ctx.async_poll(wgpu::Maintain::wait())
.await
.panic_on_timeout();

let data = cpu_buffer.slice(..).get_mapped_range();

let floats: &[f32] = bytemuck::cast_slice(&data);
assert_eq!(floats, [2.0, 4.0, 6.0, 8.0]);
}

// Setup ------------------------------------------------------------

struct ResourceSetup {
gpu_buffer: wgpu::Buffer,
cpu_buffer: wgpu::Buffer,
buffer_size: u64,

indirect_buffer: wgpu::Buffer,
bind_group: wgpu::BindGroup,
pipeline: wgpu::ComputePipeline,
}

fn resource_setup(ctx: &TestingContext) -> ResourceSetup {
let sm = ctx
.device
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("shader"),
source: wgpu::ShaderSource::Wgsl(SHADER_SRC.into()),
});

let buffer_size = 4 * std::mem::size_of::<f32>() as u64;

let bgl = ctx
.device
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("bind_group_layout"),
entries: &[wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: NonZeroU64::new(buffer_size),
},
count: None,
}],
});

let gpu_buffer = ctx
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("gpu_buffer"),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
contents: bytemuck::bytes_of(&[1.0_f32, 2.0, 3.0, 4.0]),
});

let cpu_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("cpu_buffer"),
size: buffer_size,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});

let indirect_buffer = ctx
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("gpu_buffer"),
usage: wgpu::BufferUsages::INDIRECT,
contents: wgpu::util::DispatchIndirectArgs { x: 1, y: 1, z: 1 }.as_bytes(),
});

let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("bind_group"),
layout: &bgl,
entries: &[wgpu::BindGroupEntry {
binding: 0,
resource: gpu_buffer.as_entire_binding(),
}],
});

let pipeline_layout = ctx
.device
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("pipeline_layout"),
bind_group_layouts: &[&bgl],
push_constant_ranges: &[],
});

let pipeline = ctx
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("pipeline"),
layout: Some(&pipeline_layout),
module: &sm,
entry_point: "main",
compilation_options: Default::default(),
});

ResourceSetup {
gpu_buffer,
cpu_buffer,
buffer_size,
indirect_buffer,
bind_group,
pipeline,
}
}
1 change: 1 addition & 0 deletions tests/tests/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ mod buffer;
mod buffer_copy;
mod buffer_usages;
mod clear_texture;
mod compute_pass_resource_ownership;
mod create_surface_error;
mod device;
mod encoder;
Expand Down
5 changes: 2 additions & 3 deletions wgpu-core/src/command/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use crate::{
binding_model::{BindGroup, LateMinBufferBindingSizeMismatch, PipelineLayout},
device::SHADER_STAGE_COUNT,
hal_api::HalApi,
id::BindGroupId,
pipeline::LateSizedBufferGroup,
resource::Resource,
};
Expand Down Expand Up @@ -359,11 +358,11 @@ impl<A: HalApi> Binder<A> {
&self.payloads[bind_range]
}

pub(super) fn list_active(&self) -> impl Iterator<Item = BindGroupId> + '_ {
pub(super) fn list_active<'a>(&'a self) -> impl Iterator<Item = &'a Arc<BindGroup<A>>> + '_ {
let payloads = &self.payloads;
self.manager
.list_active()
.map(move |index| payloads[index].group.as_ref().unwrap().as_info().id())
.map(move |index| payloads[index].group.as_ref().unwrap())
}

pub(super) fn invalid_mask(&self) -> BindGroupMask {
Expand Down
Loading

0 comments on commit 77a83fb

Please sign in to comment.