diff --git a/CHANGELOG.md b/CHANGELOG.md index ad6c7913e1..e5d4147a1e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -150,6 +150,7 @@ By @atlv24 in [#5383](https://github.com/gfx-rs/wgpu/pull/5383) - Ensure render pipelines have at least 1 target. By @ErichDonGubler in [#5715](https://github.com/gfx-rs/wgpu/pull/5715) - `wgpu::ComputePass` now internally takes ownership of `QuerySet` for both `wgpu::ComputePassTimestampWrites` as well as timestamp writes and statistics query, fixing crashes when destroying `QuerySet` before ending the pass. By @wumpf in [#5671](https://github.com/gfx-rs/wgpu/pull/5671) +- Validate resources passed during compute pass recording for mismatching device. By @wumpf in [#5779](https://github.com/gfx-rs/wgpu/pull/5779) #### Metal diff --git a/wgpu-core/src/command/compute.rs b/wgpu-core/src/command/compute.rs index 30284002e3..acbff0a030 100644 --- a/wgpu-core/src/command/compute.rs +++ b/wgpu-core/src/command/compute.rs @@ -53,6 +53,11 @@ pub struct ComputePass { // Resource binding dedupe state. current_bind_groups: BindGroupStateChange, current_pipeline: StateChange, + + /// The device that this pass is associated with. + /// + /// Used for quick validation during recording. + device_id: id::DeviceId, } impl ComputePass { @@ -63,6 +68,10 @@ impl ComputePass { timestamp_writes, } = desc; + let device_id = parent + .as_ref() + .map_or(id::DeviceId::dummy(0), |p| p.device.as_info().id()); + Self { base: Some(BasePass::new(label)), parent, @@ -70,6 +79,8 @@ impl ComputePass { current_bind_groups: BindGroupStateChange::new(), current_pipeline: StateChange::new(), + + device_id, } } @@ -350,6 +361,13 @@ impl Global { ); }; + if query_set.device.as_info().id() != cmd_buf.device.as_info().id() { + return ( + ComputePass::new(None, arc_desc), + Some(CommandEncoderError::WrongDeviceForTimestampWritesQuerySet), + ); + } + Some(ArcComputePassTimestampWrites { query_set, beginning_of_pass_write_index: tw.beginning_of_pass_write_index, @@ -976,6 +994,10 @@ impl Global { .map_err(|_| ComputePassErrorInner::InvalidBindGroup(index)) .map_pass_err(scope)?; + if bind_group.device.as_info().id() != pass.device_id { + return Err(DeviceError::WrongDevice).map_pass_err(scope); + } + base.commands.push(ArcComputeCommand::SetBindGroup { index, num_dynamic_offsets: offsets.len(), @@ -993,8 +1015,9 @@ impl Global { let redundant = pass.current_pipeline.set_and_check_redundant(pipeline_id); let scope = PassErrorScope::SetPipelineCompute(pipeline_id); - let base = pass.base_mut(scope)?; + let device_id = pass.device_id; + let base = pass.base_mut(scope)?; if redundant { // Do redundant early-out **after** checking whether the pass is ended or not. return Ok(()); @@ -1008,6 +1031,10 @@ impl Global { .map_err(|_| ComputePassErrorInner::InvalidPipeline(pipeline_id)) .map_pass_err(scope)?; + if pipeline.device.as_info().id() != device_id { + return Err(DeviceError::WrongDevice).map_pass_err(scope); + } + base.commands.push(ArcComputeCommand::SetPipeline(pipeline)); Ok(()) @@ -1081,6 +1108,7 @@ impl Global { indirect: true, pipeline: pass.current_pipeline.last_state, }; + let device_id = pass.device_id; let base = pass.base_mut(scope)?; let buffer = hub @@ -1090,6 +1118,10 @@ impl Global { .map_err(|_| ComputePassErrorInner::InvalidBuffer(buffer_id)) .map_pass_err(scope)?; + if buffer.device.as_info().id() != device_id { + return Err(DeviceError::WrongDevice).map_pass_err(scope); + } + base.commands .push(ArcComputeCommand::::DispatchIndirect { buffer, offset }); @@ -1153,6 +1185,7 @@ impl Global { query_index: u32, ) -> Result<(), ComputePassError> { let scope = PassErrorScope::WriteTimestamp; + let device_id = pass.device_id; let base = pass.base_mut(scope)?; let hub = A::hub(self); @@ -1163,6 +1196,10 @@ impl Global { .map_err(|_| ComputePassErrorInner::InvalidQuerySet(query_set_id)) .map_pass_err(scope)?; + if query_set.device.as_info().id() != device_id { + return Err(DeviceError::WrongDevice).map_pass_err(scope); + } + base.commands.push(ArcComputeCommand::WriteTimestamp { query_set, query_index, @@ -1178,6 +1215,7 @@ impl Global { query_index: u32, ) -> Result<(), ComputePassError> { let scope = PassErrorScope::BeginPipelineStatisticsQuery; + let device_id = pass.device_id; let base = pass.base_mut(scope)?; let hub = A::hub(self); @@ -1188,6 +1226,10 @@ impl Global { .map_err(|_| ComputePassErrorInner::InvalidQuerySet(query_set_id)) .map_pass_err(scope)?; + if query_set.device.as_info().id() != device_id { + return Err(DeviceError::WrongDevice).map_pass_err(scope); + } + base.commands .push(ArcComputeCommand::BeginPipelineStatisticsQuery { query_set, diff --git a/wgpu-core/src/command/mod.rs b/wgpu-core/src/command/mod.rs index 13731f22c6..874e207a27 100644 --- a/wgpu-core/src/command/mod.rs +++ b/wgpu-core/src/command/mod.rs @@ -633,8 +633,11 @@ pub enum CommandEncoderError { Device(#[from] DeviceError), #[error("Command encoder is locked by a previously created render/compute pass. Before recording any new commands, the pass must be ended.")] Locked, + #[error("QuerySet provided for pass timestamp writes is invalid.")] InvalidTimestampWritesQuerySetId, + #[error("QuerySet provided for pass timestamp writes that was created by a different device.")] + WrongDeviceForTimestampWritesQuerySet, } impl Global {