From 57df0e02206861e32b4ae1fa03594a8ec722afaa Mon Sep 17 00:00:00 2001 From: atlas dostal Date: Wed, 24 Apr 2024 20:20:20 -0700 Subject: [PATCH] Extract a naga pub create_validator function for use in Bevy --- wgpu-core/src/device/resource.rs | 140 ++++++++++++++++--------------- 1 file changed, 74 insertions(+), 66 deletions(-) diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index 01def56ce7..3fb7aca3cb 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -1474,9 +1474,78 @@ impl Device { }; } - use naga::valid::Capabilities as Caps; profiling::scope!("naga::validate"); + let debug_source = + if self.instance_flags.contains(wgt::InstanceFlags::DEBUG) && !source.is_empty() { + Some(hal::DebugSource { + file_name: Cow::Owned( + desc.label + .as_ref() + .map_or("shader".to_string(), |l| l.to_string()), + ), + source_code: Cow::Owned(source.clone()), + }) + } else { + None + }; + + let info = self + .create_validator(naga::valid::ValidationFlags::all()) + .validate(&module) + .map_err(|inner| { + pipeline::CreateShaderModuleError::Validation(pipeline::ShaderError { + source, + label: desc.label.as_ref().map(|l| l.to_string()), + inner: Box::new(inner), + }) + })?; + + let interface = + validation::Interface::new(&module, &info, self.limits.clone(), self.features); + let hal_shader = hal::ShaderInput::Naga(hal::NagaShader { + module, + info, + debug_source, + }); + let hal_desc = hal::ShaderModuleDescriptor { + label: desc.label.to_hal(self.instance_flags), + runtime_checks: desc.shader_bound_checks.runtime_checks(), + }; + let raw = match unsafe { + self.raw + .as_ref() + .unwrap() + .create_shader_module(&hal_desc, hal_shader) + } { + Ok(raw) => raw, + Err(error) => { + return Err(match error { + hal::ShaderError::Device(error) => { + pipeline::CreateShaderModuleError::Device(error.into()) + } + hal::ShaderError::Compilation(ref msg) => { + log::error!("Shader error: {}", msg); + pipeline::CreateShaderModuleError::Generation + } + }) + } + }; + + Ok(pipeline::ShaderModule { + raw: Some(raw), + device: self.clone(), + interface: Some(interface), + info: ResourceInfo::new(desc.label.borrow_or_default(), None), + label: desc.label.borrow_or_default().to_string(), + }) + } + /// Create a validator with the given validation flags. + pub fn create_validator( + self: &Arc, + flags: naga::valid::ValidationFlags, + ) -> naga::valid::Validator { + use naga::valid::Capabilities as Caps; let mut caps = Caps::empty(); caps.set( Caps::PUSH_CONSTANT, @@ -1554,20 +1623,6 @@ impl Device { self.features.intersects(wgt::Features::SUBGROUP_BARRIER), ); - let debug_source = - if self.instance_flags.contains(wgt::InstanceFlags::DEBUG) && !source.is_empty() { - Some(hal::DebugSource { - file_name: Cow::Owned( - desc.label - .as_ref() - .map_or("shader".to_string(), |l| l.to_string()), - ), - source_code: Cow::Owned(source.clone()), - }) - } else { - None - }; - let mut subgroup_stages = naga::valid::ShaderStages::empty(); subgroup_stages.set( naga::valid::ShaderStages::COMPUTE | naga::valid::ShaderStages::FRAGMENT, @@ -1584,57 +1639,10 @@ impl Device { } else { naga::valid::SubgroupOperationSet::empty() }; - - let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), caps) - .subgroup_stages(subgroup_stages) - .subgroup_operations(subgroup_operations) - .validate(&module) - .map_err(|inner| { - pipeline::CreateShaderModuleError::Validation(pipeline::ShaderError { - source, - label: desc.label.as_ref().map(|l| l.to_string()), - inner: Box::new(inner), - }) - })?; - - let interface = - validation::Interface::new(&module, &info, self.limits.clone(), self.features); - let hal_shader = hal::ShaderInput::Naga(hal::NagaShader { - module, - info, - debug_source, - }); - let hal_desc = hal::ShaderModuleDescriptor { - label: desc.label.to_hal(self.instance_flags), - runtime_checks: desc.shader_bound_checks.runtime_checks(), - }; - let raw = match unsafe { - self.raw - .as_ref() - .unwrap() - .create_shader_module(&hal_desc, hal_shader) - } { - Ok(raw) => raw, - Err(error) => { - return Err(match error { - hal::ShaderError::Device(error) => { - pipeline::CreateShaderModuleError::Device(error.into()) - } - hal::ShaderError::Compilation(ref msg) => { - log::error!("Shader error: {}", msg); - pipeline::CreateShaderModuleError::Generation - } - }) - } - }; - - Ok(pipeline::ShaderModule { - raw: Some(raw), - device: self.clone(), - interface: Some(interface), - info: ResourceInfo::new(desc.label.borrow_or_default(), None), - label: desc.label.borrow_or_default().to_string(), - }) + let mut validator = naga::valid::Validator::new(flags, caps); + validator.subgroup_stages(subgroup_stages); + validator.subgroup_operations(subgroup_operations); + validator } #[allow(unused_unsafe)]