diff --git a/crates/bevy_render/src/pipeline/pipeline_layout.rs b/crates/bevy_render/src/pipeline/pipeline_layout.rs index d372787fe991a..52600348ba750 100644 --- a/crates/bevy_render/src/pipeline/pipeline_layout.rs +++ b/crates/bevy_render/src/pipeline/pipeline_layout.rs @@ -25,16 +25,21 @@ impl PipelineLayout { for shader_binding in shader_bind_group.bindings.iter() { if let Some(binding) = bind_group .bindings - .iter() + .iter_mut() .find(|binding| binding.index == shader_binding.index) { - if binding != shader_binding { + binding.shader_stage |= shader_binding.shader_stage; + if binding.bind_type != shader_binding.bind_type + || binding.name != shader_binding.name + || binding.index != shader_binding.index + { panic!("Binding {} in BindGroup {} does not match across all shader types: {:?} {:?}", binding.index, bind_group.index, binding, shader_binding); } } else { bind_group.bindings.push(shader_binding.clone()); } } + bind_group.update_id(); } None => { bind_groups.insert(shader_bind_group.index, shader_bind_group.clone()); diff --git a/crates/bevy_render/src/shader/shader_reflect.rs b/crates/bevy_render/src/shader/shader_reflect.rs index 86bc07c611832..6ed6b7c31a9d9 100644 --- a/crates/bevy_render/src/shader/shader_reflect.rs +++ b/crates/bevy_render/src/shader/shader_reflect.rs @@ -9,7 +9,8 @@ use bevy_core::AsBytes; use spirv_reflect::{ types::{ ReflectDescriptorBinding, ReflectDescriptorSet, ReflectDescriptorType, ReflectDimension, - ReflectInterfaceVariable, ReflectTypeDescription, ReflectTypeFlags, + ReflectInterfaceVariable, ReflectShaderStageFlags, ReflectTypeDescription, + ReflectTypeFlags, }, ShaderModule, }; @@ -30,9 +31,10 @@ impl ShaderLayout { match ShaderModule::load_u8_data(spirv_data.as_bytes()) { Ok(ref mut module) => { let entry_point_name = module.get_entry_point_name(); + let shader_stage = module.get_shader_stage(); let mut bind_groups = Vec::new(); for descriptor_set in module.enumerate_descriptor_sets(None).unwrap() { - let bind_group = reflect_bind_group(&descriptor_set); + let bind_group = reflect_bind_group(&descriptor_set, shader_stage); bind_groups.push(bind_group); } @@ -150,10 +152,13 @@ fn reflect_vertex_attribute_descriptor( } } -fn reflect_bind_group(descriptor_set: &ReflectDescriptorSet) -> BindGroupDescriptor { +fn reflect_bind_group( + descriptor_set: &ReflectDescriptorSet, + shader_stage: ReflectShaderStageFlags, +) -> BindGroupDescriptor { let mut bindings = Vec::new(); for descriptor_binding in descriptor_set.bindings.iter() { - let binding = reflect_binding(descriptor_binding); + let binding = reflect_binding(descriptor_binding, shader_stage); bindings.push(binding); } @@ -170,7 +175,10 @@ fn reflect_dimension(type_description: &ReflectTypeDescription) -> TextureViewDi } } -fn reflect_binding(binding: &ReflectDescriptorBinding) -> BindingDescriptor { +fn reflect_binding( + binding: &ReflectDescriptorBinding, + shader_stage: ReflectShaderStageFlags, +) -> BindingDescriptor { let type_description = binding.type_description.as_ref().unwrap(); let (name, bind_type) = match binding.descriptor_type { ReflectDescriptorType::UniformBuffer => ( @@ -200,12 +208,24 @@ fn reflect_binding(binding: &ReflectDescriptorBinding) -> BindingDescriptor { _ => panic!("unsupported bind type {:?}", binding.descriptor_type), }; + let mut shader_stage = match shader_stage { + ReflectShaderStageFlags::COMPUTE => BindingShaderStage::COMPUTE, + ReflectShaderStageFlags::VERTEX => BindingShaderStage::VERTEX, + ReflectShaderStageFlags::FRAGMENT => BindingShaderStage::FRAGMENT, + _ => panic!("Only one specified shader stage is supported."), + }; + + let name = name.to_string(); + + if name == "Camera" { + shader_stage = BindingShaderStage::VERTEX | BindingShaderStage::FRAGMENT; + } + BindingDescriptor { index: binding.binding, bind_type, - name: name.to_string(), - // TODO: We should be able to detect which shader program the binding is being used in.. - shader_stage: BindingShaderStage::VERTEX | BindingShaderStage::FRAGMENT, + name, + shader_stage, } } @@ -429,7 +449,7 @@ mod tests { dimension: TextureViewDimension::D2, component_type: TextureComponentType::Float, }, - shader_stage: BindingShaderStage::VERTEX | BindingShaderStage::FRAGMENT, + shader_stage: BindingShaderStage::VERTEX, }] ), ]