From 5ec78ab2df2fec25d408b83078693a0568e0621c Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Mon, 13 Nov 2023 17:48:24 +0100 Subject: [PATCH 01/30] add pipeline constants plumbing --- deno_webgpu/pipeline.rs | 12 ++++++--- examples/src/boids/mod.rs | 3 +++ examples/src/bunnymark/mod.rs | 2 ++ examples/src/conservative_raster/mod.rs | 8 ++++++ examples/src/cube/mod.rs | 4 +++ examples/src/hello_compute/mod.rs | 1 + examples/src/hello_synchronization/mod.rs | 2 ++ examples/src/hello_triangle/mod.rs | 2 ++ examples/src/hello_workgroups/mod.rs | 1 + examples/src/mipmap/mod.rs | 4 +++ examples/src/msaa_line/mod.rs | 2 ++ examples/src/render_to_texture/mod.rs | 2 ++ examples/src/repeated_compute/mod.rs | 1 + examples/src/shadow/mod.rs | 3 +++ examples/src/skybox/mod.rs | 4 +++ examples/src/srgb_blend/mod.rs | 2 ++ examples/src/stencil_triangles/mod.rs | 4 +++ examples/src/storage_texture/mod.rs | 1 + examples/src/texture_arrays/mod.rs | 2 ++ examples/src/timestamp_queries/mod.rs | 3 +++ examples/src/uniform_values/mod.rs | 2 ++ examples/src/water/mod.rs | 4 +++ naga-cli/src/bin/naga.rs | 3 +++ naga/benches/criterion.rs | 5 +++- naga/src/back/glsl/mod.rs | 4 ++- naga/src/back/hlsl/mod.rs | 8 ++++++ naga/src/back/hlsl/writer.rs | 3 ++- naga/src/back/mod.rs | 9 +++++++ naga/src/back/msl/mod.rs | 4 ++- naga/src/back/spv/mod.rs | 4 ++- naga/tests/in/interface.param.ron | 1 + naga/tests/snapshots.rs | 6 ++++- player/tests/data/bind-group.ron | 1 + .../tests/data/pipeline-statistics-query.ron | 1 + player/tests/data/quad.ron | 2 ++ player/tests/data/zero-init-buffer.ron | 1 + .../tests/data/zero-init-texture-binding.ron | 1 + tests/src/image.rs | 1 + tests/tests/bgra8unorm_storage.rs | 1 + tests/tests/bind_group_layout_dedup.rs | 5 ++++ tests/tests/buffer.rs | 2 ++ tests/tests/device.rs | 4 +++ tests/tests/mem_leaks.rs | 6 +++-- tests/tests/nv12_texture/mod.rs | 2 ++ tests/tests/occlusion_query/mod.rs | 1 + tests/tests/partially_bounded_arrays/mod.rs | 1 + tests/tests/pipeline.rs | 1 + tests/tests/push_constants.rs | 1 + tests/tests/regression/issue_3349.rs | 2 ++ tests/tests/regression/issue_3457.rs | 4 +++ tests/tests/scissor_tests/mod.rs | 6 +++-- tests/tests/shader/mod.rs | 1 + tests/tests/shader/zero_init_workgroup_mem.rs | 2 ++ tests/tests/shader_primitive_index/mod.rs | 8 +++--- tests/tests/shader_view_format/mod.rs | 2 ++ tests/tests/vertex_indices/mod.rs | 7 ++++-- wgpu-core/src/device/resource.rs | 5 +++- wgpu-core/src/pipeline.rs | 8 ++++++ wgpu-hal/examples/halmark/main.rs | 3 +++ wgpu-hal/examples/ray-traced-triangle/main.rs | 1 + wgpu-hal/src/dx12/device.rs | 5 +++- wgpu-hal/src/gles/device.rs | 1 + wgpu-hal/src/lib.rs | 3 +++ wgpu-hal/src/metal/device.rs | 1 + wgpu-hal/src/vulkan/device.rs | 1 + wgpu/src/backend/wgpu_core.rs | 3 +++ wgpu/src/lib.rs | 25 +++++++++++++++++++ 67 files changed, 214 insertions(+), 21 deletions(-) diff --git a/deno_webgpu/pipeline.rs b/deno_webgpu/pipeline.rs index ab7cf42e7b..3031287607 100644 --- a/deno_webgpu/pipeline.rs +++ b/deno_webgpu/pipeline.rs @@ -8,6 +8,7 @@ use deno_core::ResourceId; use serde::Deserialize; use serde::Serialize; use std::borrow::Cow; +use std::collections::HashMap; use std::rc::Rc; use super::error::WebGpuError; @@ -75,7 +76,7 @@ pub enum GPUPipelineLayoutOrGPUAutoLayoutMode { pub struct GpuProgrammableStage { module: ResourceId, entry_point: Option, - // constants: HashMap + constants: HashMap, } #[op2] @@ -111,7 +112,7 @@ pub fn op_webgpu_create_compute_pipeline( stage: wgpu_core::pipeline::ProgrammableStageDescriptor { module: compute_shader_module_resource.1, entry_point: compute.entry_point.map(Cow::from), - // TODO(lucacasonato): support args.compute.constants + constants: Cow::Owned(compute.constants), }, }; let implicit_pipelines = match layout { @@ -279,6 +280,7 @@ impl<'a> From for wgpu_core::pipeline::VertexBufferLayout struct GpuVertexState { module: ResourceId, entry_point: String, + constants: HashMap, buffers: Vec>, } @@ -306,7 +308,7 @@ struct GpuFragmentState { targets: Vec>, module: u32, entry_point: String, - // TODO(lucacasonato): constants + constants: HashMap, } #[derive(Deserialize)] @@ -356,8 +358,9 @@ pub fn op_webgpu_create_render_pipeline( stage: wgpu_core::pipeline::ProgrammableStageDescriptor { module: fragment_shader_module_resource.1, entry_point: Some(Cow::from(fragment.entry_point)), + constants: Cow::Owned(fragment.constants), }, - targets: Cow::from(fragment.targets), + targets: Cow::Owned(fragment.targets), }) } else { None @@ -378,6 +381,7 @@ pub fn op_webgpu_create_render_pipeline( stage: wgpu_core::pipeline::ProgrammableStageDescriptor { module: vertex_shader_module_resource.1, entry_point: Some(Cow::Owned(args.vertex.entry_point)), + constants: Cow::Owned(args.vertex.constants), }, buffers: Cow::Owned(vertex_buffers), }, diff --git a/examples/src/boids/mod.rs b/examples/src/boids/mod.rs index b608394134..02846beeae 100644 --- a/examples/src/boids/mod.rs +++ b/examples/src/boids/mod.rs @@ -132,6 +132,7 @@ impl crate::framework::Example for Example { vertex: wgpu::VertexState { module: &draw_shader, entry_point: "main_vs", + constants: &Default::default(), buffers: &[ wgpu::VertexBufferLayout { array_stride: 4 * 4, @@ -148,6 +149,7 @@ impl crate::framework::Example for Example { fragment: Some(wgpu::FragmentState { module: &draw_shader, entry_point: "main_fs", + constants: &Default::default(), targets: &[Some(config.view_formats[0].into())], }), primitive: wgpu::PrimitiveState::default(), @@ -163,6 +165,7 @@ impl crate::framework::Example for Example { layout: Some(&compute_pipeline_layout), module: &compute_shader, entry_point: "main", + constants: &Default::default(), }); // buffer for the three 2d triangle vertices of each instance diff --git a/examples/src/bunnymark/mod.rs b/examples/src/bunnymark/mod.rs index c29da351ee..be09478071 100644 --- a/examples/src/bunnymark/mod.rs +++ b/examples/src/bunnymark/mod.rs @@ -203,11 +203,13 @@ impl crate::framework::Example for Example { vertex: wgpu::VertexState { module: &shader, entry_point: "vs_main", + constants: &Default::default(), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &shader, entry_point: "fs_main", + constants: &Default::default(), targets: &[Some(wgpu::ColorTargetState { format: config.view_formats[0], blend: Some(wgpu::BlendState::ALPHA_BLENDING), diff --git a/examples/src/conservative_raster/mod.rs b/examples/src/conservative_raster/mod.rs index ce2054caa0..12cdaa399d 100644 --- a/examples/src/conservative_raster/mod.rs +++ b/examples/src/conservative_raster/mod.rs @@ -97,11 +97,13 @@ impl crate::framework::Example for Example { vertex: wgpu::VertexState { module: &shader_triangle_and_lines, entry_point: "vs_main", + constants: &Default::default(), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &shader_triangle_and_lines, entry_point: "fs_main_red", + constants: &Default::default(), targets: &[Some(RENDER_TARGET_FORMAT.into())], }), primitive: wgpu::PrimitiveState { @@ -120,11 +122,13 @@ impl crate::framework::Example for Example { vertex: wgpu::VertexState { module: &shader_triangle_and_lines, entry_point: "vs_main", + constants: &Default::default(), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &shader_triangle_and_lines, entry_point: "fs_main_blue", + constants: &Default::default(), targets: &[Some(RENDER_TARGET_FORMAT.into())], }), primitive: wgpu::PrimitiveState::default(), @@ -144,11 +148,13 @@ impl crate::framework::Example for Example { vertex: wgpu::VertexState { module: &shader_triangle_and_lines, entry_point: "vs_main", + constants: &Default::default(), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &shader_triangle_and_lines, entry_point: "fs_main_white", + constants: &Default::default(), targets: &[Some(config.view_formats[0].into())], }), primitive: wgpu::PrimitiveState { @@ -205,11 +211,13 @@ impl crate::framework::Example for Example { vertex: wgpu::VertexState { module: &shader, entry_point: "vs_main", + constants: &Default::default(), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &shader, entry_point: "fs_main", + constants: &Default::default(), targets: &[Some(config.view_formats[0].into())], }), primitive: wgpu::PrimitiveState::default(), diff --git a/examples/src/cube/mod.rs b/examples/src/cube/mod.rs index d21aafe5de..d87193fcfe 100644 --- a/examples/src/cube/mod.rs +++ b/examples/src/cube/mod.rs @@ -244,11 +244,13 @@ impl crate::framework::Example for Example { vertex: wgpu::VertexState { module: &shader, entry_point: "vs_main", + constants: &Default::default(), buffers: &vertex_buffers, }, fragment: Some(wgpu::FragmentState { module: &shader, entry_point: "fs_main", + constants: &Default::default(), targets: &[Some(config.view_formats[0].into())], }), primitive: wgpu::PrimitiveState { @@ -270,11 +272,13 @@ impl crate::framework::Example for Example { vertex: wgpu::VertexState { module: &shader, entry_point: "vs_main", + constants: &Default::default(), buffers: &vertex_buffers, }, fragment: Some(wgpu::FragmentState { module: &shader, entry_point: "fs_wire", + constants: &Default::default(), targets: &[Some(wgpu::ColorTargetState { format: config.view_formats[0], blend: Some(wgpu::BlendState { diff --git a/examples/src/hello_compute/mod.rs b/examples/src/hello_compute/mod.rs index ef452bf023..63169662e0 100644 --- a/examples/src/hello_compute/mod.rs +++ b/examples/src/hello_compute/mod.rs @@ -109,6 +109,7 @@ async fn execute_gpu_inner( layout: None, module: &cs_module, entry_point: "main", + constants: &Default::default(), }); // Instantiates the bind group, once again specifying the binding of buffers. diff --git a/examples/src/hello_synchronization/mod.rs b/examples/src/hello_synchronization/mod.rs index c2a6fe8b26..7dc2e6c9c0 100644 --- a/examples/src/hello_synchronization/mod.rs +++ b/examples/src/hello_synchronization/mod.rs @@ -103,12 +103,14 @@ async fn execute( layout: Some(&pipeline_layout), module: &shaders_module, entry_point: "patient_main", + constants: &Default::default(), }); let hasty_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { label: None, layout: Some(&pipeline_layout), module: &shaders_module, entry_point: "hasty_main", + constants: &Default::default(), }); //---------------------------------------------------------- diff --git a/examples/src/hello_triangle/mod.rs b/examples/src/hello_triangle/mod.rs index faa1db8f8b..76b7a5a73d 100644 --- a/examples/src/hello_triangle/mod.rs +++ b/examples/src/hello_triangle/mod.rs @@ -60,10 +60,12 @@ async fn run(event_loop: EventLoop<()>, window: Window) { module: &shader, entry_point: "vs_main", buffers: &[], + constants: &Default::default(), }, fragment: Some(wgpu::FragmentState { module: &shader, entry_point: "fs_main", + constants: &Default::default(), targets: &[Some(swapchain_format.into())], }), primitive: wgpu::PrimitiveState::default(), diff --git a/examples/src/hello_workgroups/mod.rs b/examples/src/hello_workgroups/mod.rs index 3e5795048f..5fb0eff6b1 100644 --- a/examples/src/hello_workgroups/mod.rs +++ b/examples/src/hello_workgroups/mod.rs @@ -110,6 +110,7 @@ async fn run() { layout: Some(&pipeline_layout), module: &shader, entry_point: "main", + constants: &Default::default(), }); //---------------------------------------------------------- diff --git a/examples/src/mipmap/mod.rs b/examples/src/mipmap/mod.rs index 7551021024..fc40d5d884 100644 --- a/examples/src/mipmap/mod.rs +++ b/examples/src/mipmap/mod.rs @@ -93,11 +93,13 @@ impl Example { vertex: wgpu::VertexState { module: &shader, entry_point: "vs_main", + constants: &Default::default(), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &shader, entry_point: "fs_main", + constants: &Default::default(), targets: &[Some(TEXTURE_FORMAT.into())], }), primitive: wgpu::PrimitiveState { @@ -290,11 +292,13 @@ impl crate::framework::Example for Example { vertex: wgpu::VertexState { module: &shader, entry_point: "vs_main", + constants: &Default::default(), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &shader, entry_point: "fs_main", + constants: &Default::default(), targets: &[Some(config.view_formats[0].into())], }), primitive: wgpu::PrimitiveState { diff --git a/examples/src/msaa_line/mod.rs b/examples/src/msaa_line/mod.rs index 595bcbf17a..178968f47b 100644 --- a/examples/src/msaa_line/mod.rs +++ b/examples/src/msaa_line/mod.rs @@ -54,6 +54,7 @@ impl Example { vertex: wgpu::VertexState { module: shader, entry_point: "vs_main", + constants: &Default::default(), buffers: &[wgpu::VertexBufferLayout { array_stride: std::mem::size_of::() as wgpu::BufferAddress, step_mode: wgpu::VertexStepMode::Vertex, @@ -63,6 +64,7 @@ impl Example { fragment: Some(wgpu::FragmentState { module: shader, entry_point: "fs_main", + constants: &Default::default(), targets: &[Some(config.view_formats[0].into())], }), primitive: wgpu::PrimitiveState { diff --git a/examples/src/render_to_texture/mod.rs b/examples/src/render_to_texture/mod.rs index 96be26b0f9..0cb2cdea74 100644 --- a/examples/src/render_to_texture/mod.rs +++ b/examples/src/render_to_texture/mod.rs @@ -59,11 +59,13 @@ async fn run(_path: Option) { vertex: wgpu::VertexState { module: &shader, entry_point: "vs_main", + constants: &Default::default(), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &shader, entry_point: "fs_main", + constants: &Default::default(), targets: &[Some(wgpu::TextureFormat::Rgba8UnormSrgb.into())], }), primitive: wgpu::PrimitiveState::default(), diff --git a/examples/src/repeated_compute/mod.rs b/examples/src/repeated_compute/mod.rs index faed2467bd..0c47055191 100644 --- a/examples/src/repeated_compute/mod.rs +++ b/examples/src/repeated_compute/mod.rs @@ -245,6 +245,7 @@ impl WgpuContext { layout: Some(&pipeline_layout), module: &shader, entry_point: "main", + constants: &Default::default(), }); WgpuContext { diff --git a/examples/src/shadow/mod.rs b/examples/src/shadow/mod.rs index 485d0d78d6..d0a29cc8b0 100644 --- a/examples/src/shadow/mod.rs +++ b/examples/src/shadow/mod.rs @@ -500,6 +500,7 @@ impl crate::framework::Example for Example { vertex: wgpu::VertexState { module: &shader, entry_point: "vs_bake", + constants: &Default::default(), buffers: &[vb_desc.clone()], }, fragment: None, @@ -632,6 +633,7 @@ impl crate::framework::Example for Example { vertex: wgpu::VertexState { module: &shader, entry_point: "vs_main", + constants: &Default::default(), buffers: &[vb_desc], }, fragment: Some(wgpu::FragmentState { @@ -641,6 +643,7 @@ impl crate::framework::Example for Example { } else { "fs_main_without_storage" }, + constants: &Default::default(), targets: &[Some(config.view_formats[0].into())], }), primitive: wgpu::PrimitiveState { diff --git a/examples/src/skybox/mod.rs b/examples/src/skybox/mod.rs index bdb5e66142..443c9d41e0 100644 --- a/examples/src/skybox/mod.rs +++ b/examples/src/skybox/mod.rs @@ -199,11 +199,13 @@ impl crate::framework::Example for Example { vertex: wgpu::VertexState { module: &shader, entry_point: "vs_sky", + constants: &Default::default(), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &shader, entry_point: "fs_sky", + constants: &Default::default(), targets: &[Some(config.view_formats[0].into())], }), primitive: wgpu::PrimitiveState { @@ -226,6 +228,7 @@ impl crate::framework::Example for Example { vertex: wgpu::VertexState { module: &shader, entry_point: "vs_entity", + constants: &Default::default(), buffers: &[wgpu::VertexBufferLayout { array_stride: std::mem::size_of::() as wgpu::BufferAddress, step_mode: wgpu::VertexStepMode::Vertex, @@ -235,6 +238,7 @@ impl crate::framework::Example for Example { fragment: Some(wgpu::FragmentState { module: &shader, entry_point: "fs_entity", + constants: &Default::default(), targets: &[Some(config.view_formats[0].into())], }), primitive: wgpu::PrimitiveState { diff --git a/examples/src/srgb_blend/mod.rs b/examples/src/srgb_blend/mod.rs index d4021e6c5f..fdff310c31 100644 --- a/examples/src/srgb_blend/mod.rs +++ b/examples/src/srgb_blend/mod.rs @@ -131,11 +131,13 @@ impl crate::framework::Example for Example { vertex: wgpu::VertexState { module: &shader, entry_point: "vs_main", + constants: &Default::default(), buffers: &vertex_buffers, }, fragment: Some(wgpu::FragmentState { module: &shader, entry_point: "fs_main", + constants: &Default::default(), targets: &[Some(wgpu::ColorTargetState { format: config.view_formats[0], blend: Some(wgpu::BlendState::ALPHA_BLENDING), diff --git a/examples/src/stencil_triangles/mod.rs b/examples/src/stencil_triangles/mod.rs index bf645d3a34..07b8e3ec51 100644 --- a/examples/src/stencil_triangles/mod.rs +++ b/examples/src/stencil_triangles/mod.rs @@ -74,11 +74,13 @@ impl crate::framework::Example for Example { vertex: wgpu::VertexState { module: &shader, entry_point: "vs_main", + constants: &Default::default(), buffers: &vertex_buffers, }, fragment: Some(wgpu::FragmentState { module: &shader, entry_point: "fs_main", + constants: &Default::default(), targets: &[Some(wgpu::ColorTargetState { format: config.view_formats[0], blend: None, @@ -112,11 +114,13 @@ impl crate::framework::Example for Example { vertex: wgpu::VertexState { module: &shader, entry_point: "vs_main", + constants: &Default::default(), buffers: &vertex_buffers, }, fragment: Some(wgpu::FragmentState { module: &shader, entry_point: "fs_main", + constants: &Default::default(), targets: &[Some(config.view_formats[0].into())], }), primitive: Default::default(), diff --git a/examples/src/storage_texture/mod.rs b/examples/src/storage_texture/mod.rs index d4e207f3bc..f83f61967d 100644 --- a/examples/src/storage_texture/mod.rs +++ b/examples/src/storage_texture/mod.rs @@ -100,6 +100,7 @@ async fn run(_path: Option) { layout: Some(&pipeline_layout), module: &shader, entry_point: "main", + constants: &Default::default(), }); log::info!("Wgpu context set up."); diff --git a/examples/src/texture_arrays/mod.rs b/examples/src/texture_arrays/mod.rs index ccad759993..c786b0efee 100644 --- a/examples/src/texture_arrays/mod.rs +++ b/examples/src/texture_arrays/mod.rs @@ -321,6 +321,7 @@ impl crate::framework::Example for Example { vertex: wgpu::VertexState { module: &base_shader_module, entry_point: "vert_main", + constants: &Default::default(), buffers: &[wgpu::VertexBufferLayout { array_stride: vertex_size as wgpu::BufferAddress, step_mode: wgpu::VertexStepMode::Vertex, @@ -330,6 +331,7 @@ impl crate::framework::Example for Example { fragment: Some(wgpu::FragmentState { module: fragment_shader_module, entry_point: fragment_entry_point, + constants: &Default::default(), targets: &[Some(config.view_formats[0].into())], }), primitive: wgpu::PrimitiveState { diff --git a/examples/src/timestamp_queries/mod.rs b/examples/src/timestamp_queries/mod.rs index 4911af4136..58952c76c0 100644 --- a/examples/src/timestamp_queries/mod.rs +++ b/examples/src/timestamp_queries/mod.rs @@ -298,6 +298,7 @@ fn compute_pass( layout: None, module, entry_point: "main_cs", + constants: &Default::default(), }); let bind_group_layout = compute_pipeline.get_bind_group_layout(0); let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { @@ -352,11 +353,13 @@ fn render_pass( vertex: wgpu::VertexState { module, entry_point: "vs_main", + constants: &Default::default(), buffers: &[], }, fragment: Some(wgpu::FragmentState { module, entry_point: "fs_main", + constants: &Default::default(), targets: &[Some(format.into())], }), primitive: wgpu::PrimitiveState::default(), diff --git a/examples/src/uniform_values/mod.rs b/examples/src/uniform_values/mod.rs index 4a31ddc069..1ddee03e9f 100644 --- a/examples/src/uniform_values/mod.rs +++ b/examples/src/uniform_values/mod.rs @@ -179,11 +179,13 @@ impl WgpuContext { vertex: wgpu::VertexState { module: &shader, entry_point: "vs_main", + constants: &Default::default(), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &shader, entry_point: "fs_main", + constants: &Default::default(), targets: &[Some(swapchain_format.into())], }), primitive: wgpu::PrimitiveState::default(), diff --git a/examples/src/water/mod.rs b/examples/src/water/mod.rs index 7371e96155..0cd00aac54 100644 --- a/examples/src/water/mod.rs +++ b/examples/src/water/mod.rs @@ -512,6 +512,7 @@ impl crate::framework::Example for Example { vertex: wgpu::VertexState { module: &water_module, entry_point: "vs_main", + constants: &Default::default(), // Layout of our vertices. This should match the structs // which are uploaded to the GPU. This should also be // ensured by tagging on either a `#[repr(C)]` onto a @@ -527,6 +528,7 @@ impl crate::framework::Example for Example { fragment: Some(wgpu::FragmentState { module: &water_module, entry_point: "fs_main", + constants: &Default::default(), // Describes how the colour will be interpolated // and assigned to the output attachment. targets: &[Some(wgpu::ColorTargetState { @@ -581,6 +583,7 @@ impl crate::framework::Example for Example { vertex: wgpu::VertexState { module: &terrain_module, entry_point: "vs_main", + constants: &Default::default(), buffers: &[wgpu::VertexBufferLayout { array_stride: terrain_vertex_size as wgpu::BufferAddress, step_mode: wgpu::VertexStepMode::Vertex, @@ -590,6 +593,7 @@ impl crate::framework::Example for Example { fragment: Some(wgpu::FragmentState { module: &terrain_module, entry_point: "fs_main", + constants: &Default::default(), targets: &[Some(config.view_formats[0].into())], }), primitive: wgpu::PrimitiveState { diff --git a/naga-cli/src/bin/naga.rs b/naga-cli/src/bin/naga.rs index 29e5a5044b..a20611114b 100644 --- a/naga-cli/src/bin/naga.rs +++ b/naga-cli/src/bin/naga.rs @@ -593,6 +593,7 @@ fn write_output( pipeline_options_owned = spv::PipelineOptions { entry_point: name.clone(), shader_stage: module.entry_points[ep_index].stage, + constants: naga::back::PipelineConstants::default(), }; Some(&pipeline_options_owned) } @@ -633,6 +634,7 @@ fn write_output( _ => unreachable!(), }, multiview: None, + constants: naga::back::PipelineConstants::default(), }; let mut buffer = String::new(); @@ -668,6 +670,7 @@ fn write_output( "Generating hlsl output requires validation to \ succeed, and it failed in a previous step", ))?, + &hlsl::PipelineOptions::default(), ) .unwrap_pretty(); fs::write(output_path, buffer)?; diff --git a/naga/benches/criterion.rs b/naga/benches/criterion.rs index e57c58a847..420c9ee335 100644 --- a/naga/benches/criterion.rs +++ b/naga/benches/criterion.rs @@ -193,6 +193,7 @@ fn backends(c: &mut Criterion) { let pipeline_options = naga::back::spv::PipelineOptions { shader_stage: ep.stage, entry_point: ep.name.clone(), + constants: naga::back::PipelineConstants::default(), }; writer .write(module, info, Some(&pipeline_options), &None, &mut data) @@ -223,10 +224,11 @@ fn backends(c: &mut Criterion) { group.bench_function("hlsl", |b| { b.iter(|| { let options = naga::back::hlsl::Options::default(); + let pipeline_options = naga::back::hlsl::PipelineOptions::default(); let mut string = String::new(); for &(ref module, ref info) in inputs.iter() { let mut writer = naga::back::hlsl::Writer::new(&mut string, &options); - let _ = writer.write(module, info); // may fail on unimplemented things + let _ = writer.write(module, info, &pipeline_options); // may fail on unimplemented things string.clear(); } }); @@ -248,6 +250,7 @@ fn backends(c: &mut Criterion) { shader_stage: ep.stage, entry_point: ep.name.clone(), multiview: None, + constants: naga::back::PipelineConstants::default(), }; // might be `Err` if missing features diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 9bda594610..410b69eaf9 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -282,7 +282,7 @@ impl Default for Options { } /// A subset of options meant to be changed per pipeline. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct PipelineOptions { @@ -294,6 +294,8 @@ pub struct PipelineOptions { pub entry_point: String, /// How many views to render to, if doing multiview rendering. pub multiview: Option, + /// Pipeline constants. + pub constants: back::PipelineConstants, } #[derive(Debug)] diff --git a/naga/src/back/hlsl/mod.rs b/naga/src/back/hlsl/mod.rs index f37a223f47..37d26bf3b2 100644 --- a/naga/src/back/hlsl/mod.rs +++ b/naga/src/back/hlsl/mod.rs @@ -195,6 +195,14 @@ pub struct Options { pub zero_initialize_workgroup_memory: bool, } +#[derive(Clone, Debug, Default)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct PipelineOptions { + /// Pipeline constants. + pub constants: back::PipelineConstants, +} + impl Default for Options { fn default() -> Self { Options { diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 4ba856946b..b442637c43 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -1,7 +1,7 @@ use super::{ help::{WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess}, storage::StoreValue, - BackendResult, Error, Options, + BackendResult, Error, Options, PipelineOptions, }; use crate::{ back, @@ -167,6 +167,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { &mut self, module: &Module, module_info: &valid::ModuleInfo, + _pipeline_options: &PipelineOptions, ) -> Result { self.reset(module); diff --git a/naga/src/back/mod.rs b/naga/src/back/mod.rs index c8f091decb..61dc4a0601 100644 --- a/naga/src/back/mod.rs +++ b/naga/src/back/mod.rs @@ -26,6 +26,15 @@ pub const BAKE_PREFIX: &str = "_e"; /// Expressions that need baking. pub type NeedBakeExpressions = crate::FastHashSet>; +/// Specifies the values of pipeline-overridable constants in the shader module. +/// +/// If an `@id` attribute was specified on the declaration, +/// the key must be the pipeline constant ID as a decimal ASCII number; if not, +/// the key must be the constant's identifier name. +/// +/// The value may represent any of WGSL's concrete scalar types. +pub type PipelineConstants = std::collections::HashMap; + /// Indentation level. #[derive(Clone, Copy)] pub struct Level(pub usize); diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index 68e5b79906..7e05be29bd 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -221,7 +221,7 @@ impl Default for Options { } /// A subset of options that are meant to be changed per pipeline. -#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Default, Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct PipelineOptions { @@ -232,6 +232,8 @@ pub struct PipelineOptions { /// /// Enable this for vertex shaders with point primitive topologies. pub allow_and_force_point_size: bool, + /// Pipeline constants. + pub constants: crate::back::PipelineConstants, } impl Options { diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index eb29e3cd8b..087c49bccf 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -725,7 +725,7 @@ impl<'a> Default for Options<'a> { } // A subset of options meant to be changed per pipeline. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct PipelineOptions { @@ -735,6 +735,8 @@ pub struct PipelineOptions { /// /// If no entry point that matches is found while creating a [`Writer`], a error will be thrown. pub entry_point: String, + /// Pipeline constants. + pub constants: crate::back::PipelineConstants, } pub fn write_vec( diff --git a/naga/tests/in/interface.param.ron b/naga/tests/in/interface.param.ron index 4d85661767..19ed5e464c 100644 --- a/naga/tests/in/interface.param.ron +++ b/naga/tests/in/interface.param.ron @@ -27,5 +27,6 @@ ), msl_pipeline: ( allow_and_force_point_size: true, + constants: {}, ), ) diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index 198a4aa2db..f19077869d 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -428,6 +428,7 @@ fn write_output_spv( let pipeline_options = spv::PipelineOptions { entry_point: ep.name.clone(), shader_stage: ep.stage, + constants: naga::back::PipelineConstants::default(), }; write_output_spv_inner( input, @@ -516,6 +517,7 @@ fn write_output_glsl( shader_stage: stage, entry_point: ep_name.to_string(), multiview, + constants: naga::back::PipelineConstants::default(), }; let mut buffer = String::new(); @@ -548,7 +550,9 @@ fn write_output_hlsl( let mut buffer = String::new(); let mut writer = hlsl::Writer::new(&mut buffer, options); - let reflection_info = writer.write(module, info).expect("HLSL write failed"); + let reflection_info = writer + .write(module, info, &hlsl::PipelineOptions::default()) + .expect("HLSL write failed"); input.write_output_file("hlsl", "hlsl", buffer); diff --git a/player/tests/data/bind-group.ron b/player/tests/data/bind-group.ron index 471f921fe9..92415e4ff3 100644 --- a/player/tests/data/bind-group.ron +++ b/player/tests/data/bind-group.ron @@ -57,6 +57,7 @@ stage: ( module: Id(0, 1, Empty), entry_point: None, + constants: {}, ), ), ), diff --git a/player/tests/data/pipeline-statistics-query.ron b/player/tests/data/pipeline-statistics-query.ron index 8274e341f2..3c672f4e56 100644 --- a/player/tests/data/pipeline-statistics-query.ron +++ b/player/tests/data/pipeline-statistics-query.ron @@ -30,6 +30,7 @@ stage: ( module: Id(0, 1, Empty), entry_point: None, + constants: {}, ), ), ), diff --git a/player/tests/data/quad.ron b/player/tests/data/quad.ron index b7db1f8c24..9d6b4a25f6 100644 --- a/player/tests/data/quad.ron +++ b/player/tests/data/quad.ron @@ -58,6 +58,7 @@ stage: ( module: Id(0, 1, Empty), entry_point: None, + constants: {}, ), buffers: [], ), @@ -65,6 +66,7 @@ stage: ( module: Id(0, 1, Empty), entry_point: None, + constants: {}, ), targets: [ Some(( diff --git a/player/tests/data/zero-init-buffer.ron b/player/tests/data/zero-init-buffer.ron index be9a20d898..5697a2555e 100644 --- a/player/tests/data/zero-init-buffer.ron +++ b/player/tests/data/zero-init-buffer.ron @@ -134,6 +134,7 @@ stage: ( module: Id(0, 1, Empty), entry_point: None, + constants: {}, ), ), ), diff --git a/player/tests/data/zero-init-texture-binding.ron b/player/tests/data/zero-init-texture-binding.ron index 41a513f60f..340cb0cfa2 100644 --- a/player/tests/data/zero-init-texture-binding.ron +++ b/player/tests/data/zero-init-texture-binding.ron @@ -135,6 +135,7 @@ stage: ( module: Id(0, 1, Empty), entry_point: None, + constants: {}, ), ), ), diff --git a/tests/src/image.rs b/tests/src/image.rs index e1b9b07201..98310233c9 100644 --- a/tests/src/image.rs +++ b/tests/src/image.rs @@ -369,6 +369,7 @@ fn copy_via_compute( layout: Some(&pll), module: &sm, entry_point: "copy_texture_to_buffer", + constants: &Default::default(), }); { diff --git a/tests/tests/bgra8unorm_storage.rs b/tests/tests/bgra8unorm_storage.rs index b1ca3b8362..c3913e5df8 100644 --- a/tests/tests/bgra8unorm_storage.rs +++ b/tests/tests/bgra8unorm_storage.rs @@ -96,6 +96,7 @@ static BGRA8_UNORM_STORAGE: GpuTestConfiguration = GpuTestConfiguration::new() label: None, layout: Some(&pl), entry_point: "main", + constants: &Default::default(), module: &module, }); diff --git a/tests/tests/bind_group_layout_dedup.rs b/tests/tests/bind_group_layout_dedup.rs index 7ac30fb8fe..519cfbda29 100644 --- a/tests/tests/bind_group_layout_dedup.rs +++ b/tests/tests/bind_group_layout_dedup.rs @@ -90,6 +90,7 @@ async fn bgl_dedupe(ctx: TestingContext) { layout: Some(&pipeline_layout), module: &module, entry_point: "no_resources", + constants: &Default::default(), }; let pipeline = ctx.device.create_compute_pipeline(&desc); @@ -218,6 +219,7 @@ fn bgl_dedupe_with_dropped_user_handle(ctx: TestingContext) { layout: Some(&pipeline_layout), module: &module, entry_point: "no_resources", + constants: &Default::default(), }); let mut encoder = ctx.device.create_command_encoder(&Default::default()); @@ -263,6 +265,7 @@ fn bgl_dedupe_derived(ctx: TestingContext) { layout: None, module: &module, entry_point: "resources", + constants: &Default::default(), }); // We create two bind groups, pulling the bind_group_layout from the pipeline each time. @@ -333,6 +336,7 @@ fn separate_programs_have_incompatible_derived_bgls(ctx: TestingContext) { layout: None, module: &module, entry_point: "resources", + constants: &Default::default(), }; // Create two pipelines, creating a BG from the second. let pipeline1 = ctx.device.create_compute_pipeline(&desc); @@ -394,6 +398,7 @@ fn derived_bgls_incompatible_with_regular_bgls(ctx: TestingContext) { layout: None, module: &module, entry_point: "resources", + constants: &Default::default(), }); // Create a matching BGL diff --git a/tests/tests/buffer.rs b/tests/tests/buffer.rs index a5fcf3e595..1622995c35 100644 --- a/tests/tests/buffer.rs +++ b/tests/tests/buffer.rs @@ -224,6 +224,7 @@ static MINIMUM_BUFFER_BINDING_SIZE_LAYOUT: GpuTestConfiguration = GpuTestConfigu layout: Some(&pipeline_layout), module: &shader_module, entry_point: "main", + constants: &Default::default(), }); }); }); @@ -292,6 +293,7 @@ static MINIMUM_BUFFER_BINDING_SIZE_DISPATCH: GpuTestConfiguration = GpuTestConfi layout: Some(&pipeline_layout), module: &shader_module, entry_point: "main", + constants: &Default::default(), }); let buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor { diff --git a/tests/tests/device.rs b/tests/tests/device.rs index ff596d0918..82e3f71a1c 100644 --- a/tests/tests/device.rs +++ b/tests/tests/device.rs @@ -480,6 +480,7 @@ static DEVICE_DESTROY_THEN_MORE: GpuTestConfiguration = GpuTestConfiguration::ne vertex: wgpu::VertexState { module: &shader_module, entry_point: "", + constants: &Default::default(), buffers: &[], }, primitive: wgpu::PrimitiveState::default(), @@ -498,6 +499,7 @@ static DEVICE_DESTROY_THEN_MORE: GpuTestConfiguration = GpuTestConfiguration::ne layout: None, module: &shader_module, entry_point: "", + constants: &Default::default(), }); }); @@ -734,6 +736,7 @@ fn vs_main() -> @builtin(position) vec4 { fragment: Some(wgpu::FragmentState { module: &trivial_shaders_with_some_reversed_bindings, entry_point: "fs_main", + constants: &Default::default(), targets: &[Some(wgt::ColorTargetState { format: wgt::TextureFormat::Bgra8Unorm, blend: None, @@ -747,6 +750,7 @@ fn vs_main() -> @builtin(position) vec4 { vertex: wgpu::VertexState { module: &trivial_shaders_with_some_reversed_bindings, entry_point: "vs_main", + constants: &Default::default(), buffers: &[], }, primitive: wgt::PrimitiveState::default(), diff --git a/tests/tests/mem_leaks.rs b/tests/tests/mem_leaks.rs index 83fa2bbc11..949b4d96ce 100644 --- a/tests/tests/mem_leaks.rs +++ b/tests/tests/mem_leaks.rs @@ -95,15 +95,17 @@ async fn draw_test_with_reports( layout: Some(&ppl), vertex: wgpu::VertexState { buffers: &[], - entry_point: "vs_main_builtin", module: &shader, + entry_point: "vs_main_builtin", + constants: &Default::default(), }, primitive: wgpu::PrimitiveState::default(), depth_stencil: None, multisample: wgpu::MultisampleState::default(), fragment: Some(wgpu::FragmentState { - entry_point: "fs_main", module: &shader, + entry_point: "fs_main", + constants: &Default::default(), targets: &[Some(wgpu::ColorTargetState { format: wgpu::TextureFormat::Rgba8Unorm, blend: None, diff --git a/tests/tests/nv12_texture/mod.rs b/tests/tests/nv12_texture/mod.rs index aba12e02b6..0f4ba16f25 100644 --- a/tests/tests/nv12_texture/mod.rs +++ b/tests/tests/nv12_texture/mod.rs @@ -24,11 +24,13 @@ static NV12_TEXTURE_CREATION_SAMPLING: GpuTestConfiguration = GpuTestConfigurati vertex: wgpu::VertexState { module: &shader, entry_point: "vs_main", + constants: &Default::default(), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &shader, entry_point: "fs_main", + constants: &Default::default(), targets: &[Some(target_format.into())], }), primitive: wgpu::PrimitiveState { diff --git a/tests/tests/occlusion_query/mod.rs b/tests/tests/occlusion_query/mod.rs index 0c3f3072a5..2db035bfb2 100644 --- a/tests/tests/occlusion_query/mod.rs +++ b/tests/tests/occlusion_query/mod.rs @@ -37,6 +37,7 @@ static OCCLUSION_QUERY: GpuTestConfiguration = GpuTestConfiguration::new() vertex: wgpu::VertexState { module: &shader, entry_point: "vs_main", + constants: &Default::default(), buffers: &[], }, fragment: None, diff --git a/tests/tests/partially_bounded_arrays/mod.rs b/tests/tests/partially_bounded_arrays/mod.rs index a1350718f5..b93e900a9c 100644 --- a/tests/tests/partially_bounded_arrays/mod.rs +++ b/tests/tests/partially_bounded_arrays/mod.rs @@ -69,6 +69,7 @@ static PARTIALLY_BOUNDED_ARRAY: GpuTestConfiguration = GpuTestConfiguration::new layout: Some(&pipeline_layout), module: &cs_module, entry_point: "main", + constants: &Default::default(), }); let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { diff --git a/tests/tests/pipeline.rs b/tests/tests/pipeline.rs index 2350f10663..c8814e25f7 100644 --- a/tests/tests/pipeline.rs +++ b/tests/tests/pipeline.rs @@ -28,6 +28,7 @@ static PIPELINE_DEFAULT_LAYOUT_BAD_MODULE: GpuTestConfiguration = GpuTestConfigu layout: None, module: &module, entry_point: "doesn't exist", + constants: &Default::default(), }); pipeline.get_bind_group_layout(0); diff --git a/tests/tests/push_constants.rs b/tests/tests/push_constants.rs index 0e16b3df65..d1119476c3 100644 --- a/tests/tests/push_constants.rs +++ b/tests/tests/push_constants.rs @@ -103,6 +103,7 @@ async fn partial_update_test(ctx: TestingContext) { layout: Some(&pipeline_layout), module: &sm, entry_point: "main", + constants: &Default::default(), }); let mut encoder = ctx diff --git a/tests/tests/regression/issue_3349.rs b/tests/tests/regression/issue_3349.rs index 2d94d56920..93b91b9d7b 100644 --- a/tests/tests/regression/issue_3349.rs +++ b/tests/tests/regression/issue_3349.rs @@ -102,11 +102,13 @@ async fn multi_stage_data_binding_test(ctx: TestingContext) { vertex: wgpu::VertexState { module: &vs_sm, entry_point: "vs_main", + constants: &Default::default(), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: &fs_sm, entry_point: "fs_main", + constants: &Default::default(), targets: &[Some(wgpu::ColorTargetState { format: wgpu::TextureFormat::Rgba8Unorm, blend: None, diff --git a/tests/tests/regression/issue_3457.rs b/tests/tests/regression/issue_3457.rs index 12ace62e88..0fca44b0c9 100644 --- a/tests/tests/regression/issue_3457.rs +++ b/tests/tests/regression/issue_3457.rs @@ -52,6 +52,7 @@ static PASS_RESET_VERTEX_BUFFER: GpuTestConfiguration = vertex: VertexState { module: &module, entry_point: "double_buffer_vert", + constants: &Default::default(), buffers: &[ VertexBufferLayout { array_stride: 16, @@ -71,6 +72,7 @@ static PASS_RESET_VERTEX_BUFFER: GpuTestConfiguration = fragment: Some(FragmentState { module: &module, entry_point: "double_buffer_frag", + constants: &Default::default(), targets: &[Some(ColorTargetState { format: TextureFormat::Rgba8Unorm, blend: None, @@ -88,6 +90,7 @@ static PASS_RESET_VERTEX_BUFFER: GpuTestConfiguration = vertex: VertexState { module: &module, entry_point: "single_buffer_vert", + constants: &Default::default(), buffers: &[VertexBufferLayout { array_stride: 16, step_mode: VertexStepMode::Vertex, @@ -100,6 +103,7 @@ static PASS_RESET_VERTEX_BUFFER: GpuTestConfiguration = fragment: Some(FragmentState { module: &module, entry_point: "single_buffer_frag", + constants: &Default::default(), targets: &[Some(ColorTargetState { format: TextureFormat::Rgba8Unorm, blend: None, diff --git a/tests/tests/scissor_tests/mod.rs b/tests/tests/scissor_tests/mod.rs index 11b72ba7a4..efc658501d 100644 --- a/tests/tests/scissor_tests/mod.rs +++ b/tests/tests/scissor_tests/mod.rs @@ -42,16 +42,18 @@ async fn scissor_test_impl( label: Some("Pipeline"), layout: None, vertex: wgpu::VertexState { - entry_point: "vs_main", module: &shader, + entry_point: "vs_main", + constants: &Default::default(), buffers: &[], }, primitive: wgpu::PrimitiveState::default(), depth_stencil: None, multisample: wgpu::MultisampleState::default(), fragment: Some(wgpu::FragmentState { - entry_point: "fs_main", module: &shader, + entry_point: "fs_main", + constants: &Default::default(), targets: &[Some(wgpu::ColorTargetState { format: wgpu::TextureFormat::Rgba8Unorm, blend: None, diff --git a/tests/tests/shader/mod.rs b/tests/tests/shader/mod.rs index 1a981971f7..bb93c690e8 100644 --- a/tests/tests/shader/mod.rs +++ b/tests/tests/shader/mod.rs @@ -307,6 +307,7 @@ async fn shader_input_output_test( layout: Some(&pll), module: &sm, entry_point: "cs_main", + constants: &Default::default(), }); // -- Initializing data -- diff --git a/tests/tests/shader/zero_init_workgroup_mem.rs b/tests/tests/shader/zero_init_workgroup_mem.rs index 6774f1aac1..2bbcd87d90 100644 --- a/tests/tests/shader/zero_init_workgroup_mem.rs +++ b/tests/tests/shader/zero_init_workgroup_mem.rs @@ -87,6 +87,7 @@ static ZERO_INIT_WORKGROUP_MEMORY: GpuTestConfiguration = GpuTestConfiguration:: layout: Some(&pll), module: &sm, entry_point: "read", + constants: &Default::default(), }); let pipeline_write = ctx @@ -96,6 +97,7 @@ static ZERO_INIT_WORKGROUP_MEMORY: GpuTestConfiguration = GpuTestConfiguration:: layout: None, module: &sm, entry_point: "write", + constants: &Default::default(), }); // -- Initializing data -- diff --git a/tests/tests/shader_primitive_index/mod.rs b/tests/tests/shader_primitive_index/mod.rs index 096df9c0f7..fa6bbcfb53 100644 --- a/tests/tests/shader_primitive_index/mod.rs +++ b/tests/tests/shader_primitive_index/mod.rs @@ -120,6 +120,9 @@ async fn pulling_common( label: None, layout: None, vertex: wgpu::VertexState { + module: &shader, + entry_point: "vs_main", + constants: &Default::default(), buffers: &[wgpu::VertexBufferLayout { array_stride: 8, step_mode: wgpu::VertexStepMode::Vertex, @@ -129,15 +132,14 @@ async fn pulling_common( shader_location: 0, }], }], - entry_point: "vs_main", - module: &shader, }, primitive: wgpu::PrimitiveState::default(), depth_stencil: None, multisample: wgpu::MultisampleState::default(), fragment: Some(wgpu::FragmentState { - entry_point: "fs_main", module: &shader, + entry_point: "fs_main", + constants: &Default::default(), targets: &[Some(wgpu::ColorTargetState { format: wgpu::TextureFormat::Rgba8Unorm, blend: None, diff --git a/tests/tests/shader_view_format/mod.rs b/tests/tests/shader_view_format/mod.rs index 842388763b..60efa0130f 100644 --- a/tests/tests/shader_view_format/mod.rs +++ b/tests/tests/shader_view_format/mod.rs @@ -93,11 +93,13 @@ async fn reinterpret( vertex: wgpu::VertexState { module: shader, entry_point: "vs_main", + constants: &Default::default(), buffers: &[], }, fragment: Some(wgpu::FragmentState { module: shader, entry_point: "fs_main", + constants: &Default::default(), targets: &[Some(src_format.into())], }), primitive: wgpu::PrimitiveState { diff --git a/tests/tests/vertex_indices/mod.rs b/tests/tests/vertex_indices/mod.rs index e0a2cbae06..77e08489bf 100644 --- a/tests/tests/vertex_indices/mod.rs +++ b/tests/tests/vertex_indices/mod.rs @@ -272,20 +272,23 @@ async fn vertex_index_common(ctx: TestingContext) { push_constant_ranges: &[], }); + let constants = &Default::default(); let mut pipeline_desc = wgpu::RenderPipelineDescriptor { label: None, layout: Some(&ppl), vertex: wgpu::VertexState { buffers: &[], - entry_point: "vs_main_builtin", module: &shader, + entry_point: "vs_main_builtin", + constants, }, primitive: wgpu::PrimitiveState::default(), depth_stencil: None, multisample: wgpu::MultisampleState::default(), fragment: Some(wgpu::FragmentState { - entry_point: "fs_main", module: &shader, + entry_point: "fs_main", + constants, targets: &[Some(wgpu::ColorTargetState { format: wgpu::TextureFormat::Rgba8Unorm, blend: None, diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index 4892aecb75..c626d81937 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -2762,8 +2762,9 @@ impl Device { label: desc.label.to_hal(self.instance_flags), layout: pipeline_layout.raw(), stage: hal::ProgrammableStage { - entry_point: final_entry_point_name.as_ref(), module: shader_module.raw(), + entry_point: final_entry_point_name.as_ref(), + constants: desc.stage.constants.as_ref(), }, }; @@ -3178,6 +3179,7 @@ impl Device { hal::ProgrammableStage { module: vertex_shader_module.raw(), entry_point: &vertex_entry_point_name, + constants: stage_desc.constants.as_ref(), } }; @@ -3237,6 +3239,7 @@ impl Device { Some(hal::ProgrammableStage { module: shader_module.raw(), entry_point: &fragment_entry_point_name, + constants: fragment_state.stage.constants.as_ref(), }) } None => None, diff --git a/wgpu-core/src/pipeline.rs b/wgpu-core/src/pipeline.rs index 4a7651b327..b1689bd691 100644 --- a/wgpu-core/src/pipeline.rs +++ b/wgpu-core/src/pipeline.rs @@ -233,6 +233,14 @@ pub struct ProgrammableStageDescriptor<'a> { /// * If a single entry point associated with this stage must be in the shader, then proceed as /// if `Some(…)` was specified with that entry point's name. pub entry_point: Option>, + /// Specifies the values of pipeline-overridable constants in the shader module. + /// + /// If an `@id` attribute was specified on the declaration, + /// the key must be the pipeline constant ID as a decimal ASCII number; if not, + /// the key must be the constant's identifier name. + /// + /// The value may represent any of WGSL's concrete scalar types. + pub constants: Cow<'a, naga::back::PipelineConstants>, } /// Number of implicit bind groups derived at pipeline creation. diff --git a/wgpu-hal/examples/halmark/main.rs b/wgpu-hal/examples/halmark/main.rs index c238f299e7..29dfd49d28 100644 --- a/wgpu-hal/examples/halmark/main.rs +++ b/wgpu-hal/examples/halmark/main.rs @@ -245,17 +245,20 @@ impl Example { .unwrap() }; + let constants = naga::back::PipelineConstants::default(); let pipeline_desc = hal::RenderPipelineDescriptor { label: None, layout: &pipeline_layout, vertex_stage: hal::ProgrammableStage { module: &shader, entry_point: "vs_main", + constants: &constants, }, vertex_buffers: &[], fragment_stage: Some(hal::ProgrammableStage { module: &shader, entry_point: "fs_main", + constants: &constants, }), primitive: wgt::PrimitiveState { topology: wgt::PrimitiveTopology::TriangleStrip, diff --git a/wgpu-hal/examples/ray-traced-triangle/main.rs b/wgpu-hal/examples/ray-traced-triangle/main.rs index c05feae820..2ed2d64627 100644 --- a/wgpu-hal/examples/ray-traced-triangle/main.rs +++ b/wgpu-hal/examples/ray-traced-triangle/main.rs @@ -371,6 +371,7 @@ impl Example { stage: hal::ProgrammableStage { module: &shader_module, entry_point: "main", + constants: &Default::default(), }, }) } diff --git a/wgpu-hal/src/dx12/device.rs b/wgpu-hal/src/dx12/device.rs index 23bd409dc4..69a846d131 100644 --- a/wgpu-hal/src/dx12/device.rs +++ b/wgpu-hal/src/dx12/device.rs @@ -222,10 +222,13 @@ impl super::Device { //TODO: reuse the writer let mut source = String::new(); let mut writer = hlsl::Writer::new(&mut source, &layout.naga_options); + let pipeline_options = hlsl::PipelineOptions { + constants: stage.constants.to_owned(), + }; let reflection_info = { profiling::scope!("naga::back::hlsl::write"); writer - .write(module, &stage.module.naga.info) + .write(module, &stage.module.naga.info, &pipeline_options) .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}")))? }; diff --git a/wgpu-hal/src/gles/device.rs b/wgpu-hal/src/gles/device.rs index 50c07f3ff0..171c53a93e 100644 --- a/wgpu-hal/src/gles/device.rs +++ b/wgpu-hal/src/gles/device.rs @@ -218,6 +218,7 @@ impl super::Device { shader_stage: naga_stage, entry_point: stage.entry_point.to_string(), multiview: context.multiview, + constants: stage.constants.to_owned(), }; let shader = &stage.module.naga; diff --git a/wgpu-hal/src/lib.rs b/wgpu-hal/src/lib.rs index 79bd54e66e..a3f9ac5722 100644 --- a/wgpu-hal/src/lib.rs +++ b/wgpu-hal/src/lib.rs @@ -1318,6 +1318,8 @@ pub struct ProgrammableStage<'a, A: Api> { /// The name of the entry point in the compiled shader. There must be a function with this name /// in the shader. pub entry_point: &'a str, + /// Pipeline constants + pub constants: &'a naga::back::PipelineConstants, } // Rust gets confused about the impl requirements for `A` @@ -1326,6 +1328,7 @@ impl Clone for ProgrammableStage<'_, A> { Self { module: self.module, entry_point: self.entry_point, + constants: self.constants, } } } diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index 179429f5d7..3826909387 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -112,6 +112,7 @@ impl super::Device { metal::MTLPrimitiveTopologyClass::Point => true, _ => false, }, + constants: stage.constants.to_owned(), }; let (source, info) = naga::back::msl::write_string( diff --git a/wgpu-hal/src/vulkan/device.rs b/wgpu-hal/src/vulkan/device.rs index 70028cc700..2dcded2200 100644 --- a/wgpu-hal/src/vulkan/device.rs +++ b/wgpu-hal/src/vulkan/device.rs @@ -734,6 +734,7 @@ impl super::Device { let pipeline_options = naga::back::spv::PipelineOptions { entry_point: stage.entry_point.to_string(), shader_stage: naga_stage, + constants: stage.constants.to_owned(), }; let needs_temp_options = !runtime_checks || !binding_map.is_empty() diff --git a/wgpu/src/backend/wgpu_core.rs b/wgpu/src/backend/wgpu_core.rs index 98f1ca1de6..c73b0fbd1d 100644 --- a/wgpu/src/backend/wgpu_core.rs +++ b/wgpu/src/backend/wgpu_core.rs @@ -1143,6 +1143,7 @@ impl crate::Context for ContextWgpuCore { stage: pipe::ProgrammableStageDescriptor { module: desc.vertex.module.id.into(), entry_point: Some(Borrowed(desc.vertex.entry_point)), + constants: Borrowed(desc.vertex.constants), }, buffers: Borrowed(&vertex_buffers), }, @@ -1153,6 +1154,7 @@ impl crate::Context for ContextWgpuCore { stage: pipe::ProgrammableStageDescriptor { module: frag.module.id.into(), entry_point: Some(Borrowed(frag.entry_point)), + constants: Borrowed(frag.constants), }, targets: Borrowed(frag.targets), }), @@ -1201,6 +1203,7 @@ impl crate::Context for ContextWgpuCore { stage: pipe::ProgrammableStageDescriptor { module: desc.module.id.into(), entry_point: Some(Borrowed(desc.entry_point)), + constants: Borrowed(desc.constants), }, }; diff --git a/wgpu/src/lib.rs b/wgpu/src/lib.rs index a49c72a1ed..3bf77fd10b 100644 --- a/wgpu/src/lib.rs +++ b/wgpu/src/lib.rs @@ -28,6 +28,7 @@ use std::{ any::Any, borrow::Cow, cmp::Ordering, + collections::HashMap, error, fmt, future::Future, marker::PhantomData, @@ -1467,6 +1468,14 @@ pub struct VertexState<'a> { /// The name of the entry point in the compiled shader. There must be a function with this name /// in the shader. pub entry_point: &'a str, + /// Specifies the values of pipeline-overridable constants in the shader module. + /// + /// If an `@id` attribute was specified on the declaration, + /// the key must be the pipeline constant ID as a decimal ASCII number; if not, + /// the key must be the constant's identifier name. + /// + /// The value may represent any of WGSL's concrete scalar types. + pub constants: &'a HashMap, /// The format of any vertex buffers used with this pipeline. pub buffers: &'a [VertexBufferLayout<'a>], } @@ -1486,6 +1495,14 @@ pub struct FragmentState<'a> { /// The name of the entry point in the compiled shader. There must be a function with this name /// in the shader. pub entry_point: &'a str, + /// Specifies the values of pipeline-overridable constants in the shader module. + /// + /// If an `@id` attribute was specified on the declaration, + /// the key must be the pipeline constant ID as a decimal ASCII number; if not, + /// the key must be the constant's identifier name. + /// + /// The value may represent any of WGSL's concrete scalar types. + pub constants: &'a HashMap, /// The color state of the render targets. pub targets: &'a [Option], } @@ -1575,6 +1592,14 @@ pub struct ComputePipelineDescriptor<'a> { /// The name of the entry point in the compiled shader. There must be a function with this name /// and no return value in the shader. pub entry_point: &'a str, + /// Specifies the values of pipeline-overridable constants in the shader module. + /// + /// If an `@id` attribute was specified on the declaration, + /// the key must be the pipeline constant ID as a decimal ASCII number; if not, + /// the key must be the constant's identifier name. + /// + /// The value may represent any of WGSL's concrete scalar types. + pub constants: &'a HashMap, } #[cfg(send_sync)] static_assertions::assert_impl_all!(ComputePipelineDescriptor<'_>: Send, Sync); From 747fb30e86023f75b1f7ad13b0397c6c685c7a61 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Wed, 6 Dec 2023 16:07:15 -0800 Subject: [PATCH 02/30] [naga] Delete `Constant::override` and `Override`. --- naga/src/front/glsl/parser_tests.rs | 2 -- naga/src/front/glsl/variables.rs | 1 - naga/src/front/spv/mod.rs | 13 ----------- naga/src/front/wgsl/lower/mod.rs | 1 - naga/src/lib.rs | 22 +++-------------- naga/src/proc/constant_evaluator.rs | 9 ------- naga/src/proc/mod.rs | 8 ++----- naga/src/valid/expression.rs | 4 +--- naga/src/valid/handles.rs | 8 +------ naga/tests/out/ir/shadow.compact.ron | 19 --------------- naga/tests/out/ir/shadow.ron | 35 ---------------------------- 11 files changed, 7 insertions(+), 115 deletions(-) diff --git a/naga/src/front/glsl/parser_tests.rs b/naga/src/front/glsl/parser_tests.rs index 259052cd27..e6e2b2c853 100644 --- a/naga/src/front/glsl/parser_tests.rs +++ b/naga/src/front/glsl/parser_tests.rs @@ -557,7 +557,6 @@ fn constants() { constants.next().unwrap().1, &Constant { name: Some("a".to_owned()), - r#override: crate::Override::None, ty: ty_handle, init: init_handle } @@ -567,7 +566,6 @@ fn constants() { constants.next().unwrap().1, &Constant { name: Some("b".to_owned()), - r#override: crate::Override::None, ty: ty_handle, init: init_handle } diff --git a/naga/src/front/glsl/variables.rs b/naga/src/front/glsl/variables.rs index 9d2e7a0e7b..0725fbd94f 100644 --- a/naga/src/front/glsl/variables.rs +++ b/naga/src/front/glsl/variables.rs @@ -472,7 +472,6 @@ impl Frontend { let constant = Constant { name: name.clone(), - r#override: crate::Override::None, ty, init, }; diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index b793448597..df8ec9363b 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -196,7 +196,6 @@ struct Decoration { location: Option, desc_set: Option, desc_index: Option, - specialization: Option, storage_buffer: bool, offset: Option, array_stride: Option, @@ -216,11 +215,6 @@ impl Decoration { } } - fn specialization(&self) -> crate::Override { - self.specialization - .map_or(crate::Override::None, crate::Override::ByNameOrId) - } - const fn resource_binding(&self) -> Option { match *self { Decoration { @@ -756,9 +750,6 @@ impl> Frontend { spirv::Decoration::RowMajor => { dec.matrix_major = Some(Majority::Row); } - spirv::Decoration::SpecId => { - dec.specialization = Some(self.next()?); - } other => { log::warn!("Unknown decoration {:?}", other); for _ in base_words + 1..inst.wc { @@ -4931,7 +4922,6 @@ impl> Frontend { LookupConstant { handle: module.constants.append( crate::Constant { - r#override: decor.specialization(), name: decor.name, ty, init, @@ -4982,7 +4972,6 @@ impl> Frontend { LookupConstant { handle: module.constants.append( crate::Constant { - r#override: decor.specialization(), name: decor.name, ty, init, @@ -5017,7 +5006,6 @@ impl> Frontend { .append(crate::Expression::ZeroValue(ty), span); let handle = module.constants.append( crate::Constant { - r#override: decor.specialization(), name: decor.name, ty, init, @@ -5056,7 +5044,6 @@ impl> Frontend { LookupConstant { handle: module.constants.append( crate::Constant { - r#override: decor.specialization(), name: decor.name, ty, init, diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 2ca6c182b7..093c41e757 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -956,7 +956,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let handle = ctx.module.constants.append( crate::Constant { name: Some(c.name.name.to_string()), - r#override: crate::Override::None, ty, init, }, diff --git a/naga/src/lib.rs b/naga/src/lib.rs index 4b45174300..7623e2d704 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -175,7 +175,7 @@ tree. A Naga *constant expression* is one of the following [`Expression`] variants, whose operands (if any) are also constant expressions: - [`Literal`] -- [`Constant`], for [`Constant`s][const_type] whose [`override`] is [`None`] +- [`Constant`], for [`Constant`s][const_type] whose `override` is `None` - [`ZeroValue`], for fixed-size types - [`Compose`] - [`Access`] @@ -195,7 +195,7 @@ A constant expression can be evaluated at module translation time. A Naga *override expression* is the same as a [constant expression], except that it is also allowed to refer to [`Constant`s][const_type] -whose [`override`] is something other than [`None`]. +whose `override` is something other than `None`. An override expression can be evaluated at pipeline creation time. @@ -239,8 +239,6 @@ An override expression can be evaluated at pipeline creation time. [`As`]: Expression::As [const_type]: Constant -[`override`]: Constant::override -[`None`]: Override::None [constant expression]: index.html#constant-expressions */ @@ -892,17 +890,6 @@ pub enum Literal { AbstractFloat(f64), } -#[derive(Debug, PartialEq)] -#[cfg_attr(feature = "clone", derive(Clone))] -#[cfg_attr(feature = "serialize", derive(Serialize))] -#[cfg_attr(feature = "deserialize", derive(Deserialize))] -#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] -pub enum Override { - None, - ByName, - ByNameOrId(u32), -} - /// Constant value. #[derive(Debug, PartialEq)] #[cfg_attr(feature = "clone", derive(Clone))] @@ -911,7 +898,6 @@ pub enum Override { #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct Constant { pub name: Option, - pub r#override: Override, pub ty: Handle, /// The value of the constant. @@ -919,12 +905,10 @@ pub struct Constant { /// This [`Handle`] refers to [`Module::const_expressions`], not /// any [`Function::expressions`] arena. /// - /// If [`override`] is [`None`], then this must be a Naga + /// If `override` is `None`, then this must be a Naga /// [constant expression]. Otherwise, this may be a Naga /// [override expression] or [constant expression]. /// - /// [`override`]: Constant::override - /// [`None`]: Override::None /// [constant expression]: index.html#constant-expressions /// [override expression]: index.html#override-expressions pub init: Handle, diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 983af3718c..3a8d47325f 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -2059,7 +2059,6 @@ mod tests { let h = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: scalar_ty, init: const_expressions .append(Expression::Literal(Literal::I32(4)), Default::default()), @@ -2070,7 +2069,6 @@ mod tests { let h1 = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: scalar_ty, init: const_expressions .append(Expression::Literal(Literal::I32(8)), Default::default()), @@ -2081,7 +2079,6 @@ mod tests { let vec_h = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: vec_ty, init: const_expressions.append( Expression::Compose { @@ -2180,7 +2177,6 @@ mod tests { let h = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: scalar_ty, init: const_expressions .append(Expression::Literal(Literal::I32(4)), Default::default()), @@ -2267,7 +2263,6 @@ mod tests { let vec1 = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: vec_ty, init: const_expressions.append( Expression::Compose { @@ -2283,7 +2278,6 @@ mod tests { let vec2 = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: vec_ty, init: const_expressions.append( Expression::Compose { @@ -2299,7 +2293,6 @@ mod tests { let h = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: matrix_ty, init: const_expressions.append( Expression::Compose { @@ -2395,7 +2388,6 @@ mod tests { let h = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: i32_ty, init: const_expressions .append(Expression::Literal(Literal::I32(4)), Default::default()), @@ -2475,7 +2467,6 @@ mod tests { let h = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: i32_ty, init: const_expressions .append(Expression::Literal(Literal::I32(4)), Default::default()), diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 46cbb6c3b3..3823dbd1e1 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -553,13 +553,9 @@ impl crate::Expression { /// /// [`Access`]: crate::Expression::Access /// [`ResolveContext`]: crate::proc::ResolveContext - pub fn is_dynamic_index(&self, module: &crate::Module) -> bool { + pub const fn is_dynamic_index(&self) -> bool { match *self { - Self::Literal(_) | Self::ZeroValue(_) => false, - Self::Constant(handle) => { - let constant = &module.constants[handle]; - !matches!(constant.r#override, crate::Override::None) - } + Self::Literal(_) | Self::ZeroValue(_) | Self::Constant(_) => false, _ => true, } } diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 838ecc4e27..ec4e12993c 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -252,9 +252,7 @@ impl super::Validator { return Err(ExpressionError::InvalidIndexType(index)); } } - if dynamic_indexing_restricted - && function.expressions[index].is_dynamic_index(module) - { + if dynamic_indexing_restricted && function.expressions[index].is_dynamic_index() { return Err(ExpressionError::IndexMustBeConstant(base)); } diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index e482f293bb..1884c01303 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -76,12 +76,7 @@ impl super::Validator { |handle| Self::validate_expression_handle(handle, const_expressions); for (_handle, constant) in constants.iter() { - let &crate::Constant { - name: _, - r#override: _, - ty, - init, - } = constant; + let &crate::Constant { name: _, ty, init } = constant; validate_type(ty)?; validate_const_expr(init)?; } @@ -679,7 +674,6 @@ fn constant_deps() { let self_referential_const = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: i32_handle, init: fun_expr, }, diff --git a/naga/tests/out/ir/shadow.compact.ron b/naga/tests/out/ir/shadow.compact.ron index dc7b2eae78..4e65180691 100644 --- a/naga/tests/out/ir/shadow.compact.ron +++ b/naga/tests/out/ir/shadow.compact.ron @@ -159,115 +159,96 @@ constants: [ ( name: None, - override: None, ty: 1, init: 1, ), ( name: None, - override: None, ty: 1, init: 2, ), ( name: None, - override: None, ty: 1, init: 3, ), ( name: None, - override: None, ty: 1, init: 4, ), ( name: None, - override: None, ty: 1, init: 5, ), ( name: None, - override: None, ty: 2, init: 9, ), ( name: None, - override: None, ty: 3, init: 10, ), ( name: None, - override: None, ty: 3, init: 11, ), ( name: None, - override: None, ty: 3, init: 12, ), ( name: None, - override: None, ty: 7, init: 13, ), ( name: None, - override: None, ty: 7, init: 14, ), ( name: None, - override: None, ty: 7, init: 15, ), ( name: None, - override: None, ty: 7, init: 16, ), ( name: None, - override: None, ty: 7, init: 17, ), ( name: None, - override: None, ty: 7, init: 18, ), ( name: None, - override: None, ty: 7, init: 19, ), ( name: None, - override: None, ty: 7, init: 20, ), ( name: None, - override: None, ty: 7, init: 21, ), ( name: None, - override: None, ty: 7, init: 22, ), diff --git a/naga/tests/out/ir/shadow.ron b/naga/tests/out/ir/shadow.ron index 51bd3b264e..0b2310284a 100644 --- a/naga/tests/out/ir/shadow.ron +++ b/naga/tests/out/ir/shadow.ron @@ -282,211 +282,176 @@ constants: [ ( name: None, - override: None, ty: 1, init: 1, ), ( name: None, - override: None, ty: 1, init: 2, ), ( name: None, - override: None, ty: 1, init: 3, ), ( name: None, - override: None, ty: 1, init: 4, ), ( name: None, - override: None, ty: 1, init: 5, ), ( name: None, - override: None, ty: 2, init: 9, ), ( name: None, - override: None, ty: 3, init: 10, ), ( name: None, - override: None, ty: 3, init: 11, ), ( name: None, - override: None, ty: 3, init: 12, ), ( name: None, - override: None, ty: 1, init: 13, ), ( name: None, - override: None, ty: 9, init: 14, ), ( name: None, - override: None, ty: 9, init: 15, ), ( name: None, - override: None, ty: 9, init: 16, ), ( name: None, - override: None, ty: 9, init: 17, ), ( name: None, - override: None, ty: 9, init: 18, ), ( name: None, - override: None, ty: 9, init: 19, ), ( name: None, - override: None, ty: 9, init: 20, ), ( name: None, - override: None, ty: 9, init: 21, ), ( name: None, - override: None, ty: 9, init: 22, ), ( name: None, - override: None, ty: 9, init: 23, ), ( name: None, - override: None, ty: 9, init: 24, ), ( name: None, - override: None, ty: 9, init: 25, ), ( name: None, - override: None, ty: 9, init: 26, ), ( name: None, - override: None, ty: 9, init: 27, ), ( name: None, - override: None, ty: 9, init: 28, ), ( name: None, - override: None, ty: 9, init: 29, ), ( name: None, - override: None, ty: 9, init: 30, ), ( name: None, - override: None, ty: 9, init: 31, ), ( name: None, - override: None, ty: 9, init: 32, ), ( name: None, - override: None, ty: 9, init: 33, ), ( name: None, - override: None, ty: 9, init: 34, ), ( name: None, - override: None, ty: 9, init: 35, ), ( name: None, - override: None, ty: 9, init: 36, ), ( name: None, - override: None, ty: 9, init: 37, ), ( name: None, - override: None, ty: 9, init: 38, ), From e5b7df3d54020f5f8bfde8d4110e6fd558a72e93 Mon Sep 17 00:00:00 2001 From: Teodor Tanasoaia <28601907+teoxoy@users.noreply.github.com> Date: Thu, 7 Dec 2023 20:19:43 +0100 Subject: [PATCH 03/30] [wgsl-in] add support for override declarations (#4793) Co-authored-by: Jim Blandy --- naga/src/back/dot/mod.rs | 1 + naga/src/back/glsl/mod.rs | 1 + naga/src/back/hlsl/writer.rs | 3 + naga/src/back/msl/writer.rs | 3 + naga/src/back/spv/block.rs | 3 + naga/src/back/wgsl/writer.rs | 3 + naga/src/compact/expressions.rs | 9 +++ naga/src/compact/functions.rs | 2 + naga/src/compact/mod.rs | 19 ++++++ naga/src/front/spv/function.rs | 2 + naga/src/front/spv/mod.rs | 3 +- naga/src/front/wgsl/error.rs | 17 ++++-- naga/src/front/wgsl/index.rs | 1 + naga/src/front/wgsl/lower/mod.rs | 70 +++++++++++++++++++-- naga/src/front/wgsl/parse/ast.rs | 9 +++ naga/src/front/wgsl/parse/mod.rs | 30 +++++++++ naga/src/front/wgsl/to_wgsl.rs | 1 + naga/src/lib.rs | 37 +++++++---- naga/src/proc/constant_evaluator.rs | 23 ++++++- naga/src/proc/mod.rs | 2 + naga/src/proc/typifier.rs | 3 + naga/src/valid/analyzer.rs | 3 +- naga/src/valid/expression.rs | 2 +- naga/src/valid/handles.rs | 39 +++++++++++- naga/src/valid/mod.rs | 57 +++++++++++++++++ naga/tests/in/overrides.wgsl | 14 +++++ naga/tests/out/analysis/overrides.info.ron | 26 ++++++++ naga/tests/out/ir/access.compact.ron | 1 + naga/tests/out/ir/access.ron | 1 + naga/tests/out/ir/collatz.compact.ron | 1 + naga/tests/out/ir/collatz.ron | 1 + naga/tests/out/ir/overrides.compact.ron | 71 ++++++++++++++++++++++ naga/tests/out/ir/overrides.ron | 71 ++++++++++++++++++++++ naga/tests/out/ir/shadow.compact.ron | 1 + naga/tests/out/ir/shadow.ron | 1 + naga/tests/snapshots.rs | 8 +++ naga/tests/wgsl_errors.rs | 4 +- 37 files changed, 515 insertions(+), 28 deletions(-) create mode 100644 naga/tests/in/overrides.wgsl create mode 100644 naga/tests/out/analysis/overrides.info.ron create mode 100644 naga/tests/out/ir/overrides.compact.ron create mode 100644 naga/tests/out/ir/overrides.ron diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index 1556371df1..d128c855ca 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -404,6 +404,7 @@ fn write_function_expressions( let (label, color_id) = match *expression { E::Literal(_) => ("Literal".into(), 2), E::Constant(_) => ("Constant".into(), 2), + E::Override(_) => ("Override".into(), 2), E::ZeroValue(_) => ("ZeroValue".into(), 2), E::Compose { ref components, .. } => { payload = Some(Payload::Arguments(components)); diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 410b69eaf9..8241b04128 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2538,6 +2538,7 @@ impl<'a, W: Write> Writer<'a, W> { |writer, expr| writer.write_expr(expr, ctx), )?; } + Expression::Override(_) => return Err(Error::Custom("overrides are WIP".into())), // `Access` is applied to arrays, vectors and matrices and is written as indexing Expression::Access { base, index } => { self.write_expr(base, ctx)?; diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index b442637c43..657774d070 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2141,6 +2141,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { |writer, expr| writer.write_expr(module, expr, func_ctx), )?; } + Expression::Override(_) => { + return Err(Error::Unimplemented("overrides are WIP".into())) + } // All of the multiplication can be expressed as `mul`, // except vector * vector, which needs to use the "*" operator. Expression::Binary { diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 5227d8e7db..f8fa2c4da5 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -1431,6 +1431,9 @@ impl Writer { |writer, context, expr| writer.put_expression(expr, context, true), )?; } + crate::Expression::Override(_) => { + return Err(Error::FeatureNotImplemented("overrides are WIP".into())) + } crate::Expression::Access { base, .. } | crate::Expression::AccessIndex { base, .. } => { // This is an acceptable place to generate a `ReadZeroSkipWrite` check. diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 81f2fc10e0..dcec24d7d6 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -239,6 +239,9 @@ impl<'w> BlockContext<'w> { let init = self.ir_module.constants[handle].init; self.writer.constant_ids[init.index()] } + crate::Expression::Override(_) => { + return Err(Error::FeatureNotImplemented("overrides are WIP")) + } crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id), crate::Expression::Compose { ty, ref components } => { self.temp_list.clear(); diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 3039cbbbe4..607954706f 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -1199,6 +1199,9 @@ impl Writer { |writer, expr| writer.write_expr(module, expr, func_ctx), )?; } + Expression::Override(_) => { + return Err(Error::Unimplemented("overrides are WIP".into())) + } Expression::FunctionArgument(pos) => { let name_key = func_ctx.argument_key(pos); let name = &self.names[&name_key]; diff --git a/naga/src/compact/expressions.rs b/naga/src/compact/expressions.rs index 301bbe3240..21c4c9cdc2 100644 --- a/naga/src/compact/expressions.rs +++ b/naga/src/compact/expressions.rs @@ -3,6 +3,7 @@ use crate::arena::{Arena, Handle}; pub struct ExpressionTracer<'tracer> { pub constants: &'tracer Arena, + pub overrides: &'tracer Arena, /// The arena in which we are currently tracing expressions. pub expressions: &'tracer Arena, @@ -88,6 +89,11 @@ impl<'tracer> ExpressionTracer<'tracer> { None => self.expressions_used.insert(init), } } + Ex::Override(_) => { + // All overrides are considered used by definition. We mark + // their types and initialization expressions as used in + // `compact::compact`, so we have no more work to do here. + } Ex::ZeroValue(ty) => self.types_used.insert(ty), Ex::Compose { ty, ref components } => { self.types_used.insert(ty); @@ -219,6 +225,9 @@ impl ModuleMap { | Ex::CallResult(_) | Ex::RayQueryProceedResult => {} + // All overrides are retained, so their handles never change. + Ex::Override(_) => {} + // Expressions that contain handles that need to be adjusted. Ex::Constant(ref mut constant) => self.constants.adjust(constant), Ex::ZeroValue(ref mut ty) => self.types.adjust(ty), diff --git a/naga/src/compact/functions.rs b/naga/src/compact/functions.rs index b0d08c7e96..98a23acee0 100644 --- a/naga/src/compact/functions.rs +++ b/naga/src/compact/functions.rs @@ -4,6 +4,7 @@ use super::{FunctionMap, ModuleMap}; pub struct FunctionTracer<'a> { pub function: &'a crate::Function, pub constants: &'a crate::Arena, + pub overrides: &'a crate::Arena, pub types_used: &'a mut HandleSet, pub constants_used: &'a mut HandleSet, @@ -47,6 +48,7 @@ impl<'a> FunctionTracer<'a> { fn as_expression(&mut self) -> super::expressions::ExpressionTracer { super::expressions::ExpressionTracer { constants: self.constants, + overrides: self.overrides, expressions: &self.function.expressions, types_used: self.types_used, diff --git a/naga/src/compact/mod.rs b/naga/src/compact/mod.rs index b4e57ed5c9..2b49d34995 100644 --- a/naga/src/compact/mod.rs +++ b/naga/src/compact/mod.rs @@ -54,6 +54,14 @@ pub fn compact(module: &mut crate::Module) { } } + // We treat all overrides as used by definition. + for (_, override_) in module.overrides.iter() { + module_tracer.types_used.insert(override_.ty); + if let Some(init) = override_.init { + module_tracer.const_expressions_used.insert(init); + } + } + // We assume that all functions are used. // // Observe which types, constant expressions, constants, and @@ -158,6 +166,15 @@ pub fn compact(module: &mut crate::Module) { } }); + // Adjust override types and initializers. + log::trace!("adjusting overrides"); + for (_, override_) in module.overrides.iter_mut() { + module_map.types.adjust(&mut override_.ty); + if let Some(init) = override_.init.as_mut() { + module_map.const_expressions.adjust(init); + } + } + // Adjust global variables' types and initializers. log::trace!("adjusting global variables"); for (_, global) in module.global_variables.iter_mut() { @@ -235,6 +252,7 @@ impl<'module> ModuleTracer<'module> { expressions::ExpressionTracer { expressions: &self.module.const_expressions, constants: &self.module.constants, + overrides: &self.module.overrides, types_used: &mut self.types_used, constants_used: &mut self.constants_used, expressions_used: &mut self.const_expressions_used, @@ -249,6 +267,7 @@ impl<'module> ModuleTracer<'module> { FunctionTracer { function, constants: &self.module.constants, + overrides: &self.module.overrides, types_used: &mut self.types_used, constants_used: &mut self.constants_used, const_expressions_used: &mut self.const_expressions_used, diff --git a/naga/src/front/spv/function.rs b/naga/src/front/spv/function.rs index e81ecf5c9b..7fefef02a2 100644 --- a/naga/src/front/spv/function.rs +++ b/naga/src/front/spv/function.rs @@ -128,6 +128,7 @@ impl> super::Frontend { expressions: &mut fun.expressions, local_arena: &mut fun.local_variables, const_arena: &mut module.constants, + overrides: &mut module.overrides, const_expressions: &mut module.const_expressions, type_arena: &module.types, global_arena: &module.global_variables, @@ -581,6 +582,7 @@ impl<'function> BlockContext<'function> { crate::proc::GlobalCtx { types: self.type_arena, constants: self.const_arena, + overrides: self.overrides, const_expressions: self.const_expressions, } } diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index df8ec9363b..697cbb7b4e 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -531,6 +531,7 @@ struct BlockContext<'function> { local_arena: &'function mut Arena, /// Constants arena of the module being processed const_arena: &'function mut Arena, + overrides: &'function mut Arena, const_expressions: &'function mut Arena, /// Type arena of the module being processed type_arena: &'function UniqueArena, @@ -3934,7 +3935,7 @@ impl> Frontend { Op::TypeImage => self.parse_type_image(inst, &mut module), Op::TypeSampledImage => self.parse_type_sampled_image(inst), Op::TypeSampler => self.parse_type_sampler(inst, &mut module), - Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module), + Op::Constant => self.parse_constant(inst, &mut module), Op::ConstantComposite => self.parse_composite_constant(inst, &mut module), Op::ConstantNull | Op::Undef => self.parse_null_constant(inst, &mut module), Op::ConstantTrue => self.parse_bool_constant(inst, true, &mut module), diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index f0b55c70fd..24e6c9f8c5 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ -190,7 +190,7 @@ pub enum Error<'a> { expected: String, got: String, }, - MissingType(Span), + DeclMissingTypeAndInit(Span), MissingAttribute(&'static str, Span), InvalidAtomicPointer(Span), InvalidAtomicOperandType(Span), @@ -273,6 +273,7 @@ pub enum Error<'a> { span: Span, limit: u8, }, + PipelineConstantIDValue(Span), } impl<'a> Error<'a> { @@ -522,11 +523,11 @@ impl<'a> Error<'a> { notes: vec![], } } - Error::MissingType(name_span) => ParseError { - message: format!("variable `{}` needs a type", &source[name_span]), + Error::DeclMissingTypeAndInit(name_span) => ParseError { + message: format!("declaration of `{}` needs a type specifier or initializer", &source[name_span]), labels: vec![( name_span, - format!("definition of `{}`", &source[name_span]).into(), + "needs a type specifier or initializer".into(), )], notes: vec![], }, @@ -781,6 +782,14 @@ impl<'a> Error<'a> { format!("nesting limit is currently set to {limit}"), ], }, + Error::PipelineConstantIDValue(span) => ParseError { + message: "pipeline constant ID must be between 0 and 65535 inclusive".to_string(), + labels: vec![( + span, + "must be between 0 and 65535 inclusive".into(), + )], + notes: vec![], + }, } } } diff --git a/naga/src/front/wgsl/index.rs b/naga/src/front/wgsl/index.rs index a5524fe8f1..593405508f 100644 --- a/naga/src/front/wgsl/index.rs +++ b/naga/src/front/wgsl/index.rs @@ -187,6 +187,7 @@ const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> ast::Ident<'a> { ast::GlobalDeclKind::Fn(ref f) => f.name, ast::GlobalDeclKind::Var(ref v) => v.name, ast::GlobalDeclKind::Const(ref c) => c.name, + ast::GlobalDeclKind::Override(ref o) => o.name, ast::GlobalDeclKind::Struct(ref s) => s.name, ast::GlobalDeclKind::Type(ref t) => t.name, } diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 093c41e757..553633ff3f 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -786,6 +786,7 @@ enum LoweredGlobalDecl { Function(Handle), Var(Handle), Const(Handle), + Override(Handle), Type(Handle), EntryPoint, } @@ -965,6 +966,65 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ctx.globals .insert(c.name.name, LoweredGlobalDecl::Const(handle)); } + ast::GlobalDeclKind::Override(ref o) => { + let init = o + .init + .map(|init| self.expression(init, &mut ctx.as_const())) + .transpose()?; + let inferred_type = init + .map(|init| ctx.as_const().register_type(init)) + .transpose()?; + + let explicit_ty = + o.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx)) + .transpose()?; + + let id = + o.id.map(|id| self.const_u32(id, &mut ctx.as_const())) + .transpose()?; + + let id = if let Some((id, id_span)) = id { + Some( + u16::try_from(id) + .map_err(|_| Error::PipelineConstantIDValue(id_span))?, + ) + } else { + None + }; + + let ty = match (explicit_ty, inferred_type) { + (Some(explicit_ty), Some(inferred_type)) => { + if explicit_ty == inferred_type { + explicit_ty + } else { + let gctx = ctx.module.to_ctx(); + return Err(Error::InitializationTypeMismatch { + name: o.name.span, + expected: explicit_ty.to_wgsl(&gctx), + got: inferred_type.to_wgsl(&gctx), + }); + } + } + (Some(explicit_ty), None) => explicit_ty, + (None, Some(inferred_type)) => inferred_type, + (None, None) => { + return Err(Error::DeclMissingTypeAndInit(o.name.span)); + } + }; + + let handle = ctx.module.overrides.append( + crate::Override { + name: Some(o.name.name.to_string()), + id, + ty, + init, + }, + span, + ); + + ctx.globals + .insert(o.name.name, LoweredGlobalDecl::Override(handle)); + } ast::GlobalDeclKind::Struct(ref s) => { let handle = self.r#struct(s, span, &mut ctx)?; ctx.globals @@ -1202,7 +1262,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ty = explicit_ty; initializer = None; } - (None, None) => return Err(Error::MissingType(v.name.span)), + (None, None) => return Err(Error::DeclMissingTypeAndInit(v.name.span)), } let (const_initializer, initializer) = { @@ -1818,9 +1878,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { )?; Ok(Some(handle)) } - Some(&LoweredGlobalDecl::Const(_) | &LoweredGlobalDecl::Var(_)) => { - Err(Error::Unexpected(function.span, ExpectedToken::Function)) - } + Some( + &LoweredGlobalDecl::Const(_) + | &LoweredGlobalDecl::Override(_) + | &LoweredGlobalDecl::Var(_), + ) => Err(Error::Unexpected(function.span, ExpectedToken::Function)), Some(&LoweredGlobalDecl::EntryPoint) => Err(Error::CalledEntryPoint(function.span)), Some(&LoweredGlobalDecl::Function(function)) => { let arguments = arguments diff --git a/naga/src/front/wgsl/parse/ast.rs b/naga/src/front/wgsl/parse/ast.rs index dbaac523cb..ea8013ee7c 100644 --- a/naga/src/front/wgsl/parse/ast.rs +++ b/naga/src/front/wgsl/parse/ast.rs @@ -82,6 +82,7 @@ pub enum GlobalDeclKind<'a> { Fn(Function<'a>), Var(GlobalVariable<'a>), Const(Const<'a>), + Override(Override<'a>), Struct(Struct<'a>), Type(TypeAlias<'a>), } @@ -200,6 +201,14 @@ pub struct Const<'a> { pub init: Handle>, } +#[derive(Debug)] +pub struct Override<'a> { + pub name: Ident<'a>, + pub id: Option>>, + pub ty: Option>>, + pub init: Option>>, +} + /// The size of an [`Array`] or [`BindingArray`]. /// /// [`Array`]: Type::Array diff --git a/naga/src/front/wgsl/parse/mod.rs b/naga/src/front/wgsl/parse/mod.rs index 6724eb95f9..79ea1ae609 100644 --- a/naga/src/front/wgsl/parse/mod.rs +++ b/naga/src/front/wgsl/parse/mod.rs @@ -2180,6 +2180,7 @@ impl Parser { let mut early_depth_test = ParsedAttribute::default(); let (mut bind_index, mut bind_group) = (ParsedAttribute::default(), ParsedAttribute::default()); + let mut id = ParsedAttribute::default(); let mut dependencies = FastIndexSet::default(); let mut ctx = ExpressionContext { @@ -2203,6 +2204,11 @@ impl Parser { bind_group.set(self.general_expression(lexer, &mut ctx)?, name_span)?; lexer.expect(Token::Paren(')'))?; } + ("id", name_span) => { + lexer.expect(Token::Paren('('))?; + id.set(self.general_expression(lexer, &mut ctx)?, name_span)?; + lexer.expect(Token::Paren(')'))?; + } ("vertex", name_span) => { stage.set(crate::ShaderStage::Vertex, name_span)?; } @@ -2293,6 +2299,30 @@ impl Parser { Some(ast::GlobalDeclKind::Const(ast::Const { name, ty, init })) } + (Token::Word("override"), _) => { + let name = lexer.next_ident()?; + + let ty = if lexer.skip(Token::Separator(':')) { + Some(self.type_decl(lexer, &mut ctx)?) + } else { + None + }; + + let init = if lexer.skip(Token::Operation('=')) { + Some(self.general_expression(lexer, &mut ctx)?) + } else { + None + }; + + lexer.expect(Token::Separator(';'))?; + + Some(ast::GlobalDeclKind::Override(ast::Override { + name, + id: id.value, + ty, + init, + })) + } (Token::Word("var"), _) => { let mut var = self.variable_decl(lexer, &mut ctx)?; var.binding = binding.take(); diff --git a/naga/src/front/wgsl/to_wgsl.rs b/naga/src/front/wgsl/to_wgsl.rs index c8331ace09..ba6063ab46 100644 --- a/naga/src/front/wgsl/to_wgsl.rs +++ b/naga/src/front/wgsl/to_wgsl.rs @@ -226,6 +226,7 @@ mod tests { let gctx = crate::proc::GlobalCtx { types: &types, constants: &crate::Arena::new(), + overrides: &crate::Arena::new(), const_expressions: &crate::Arena::new(), }; let array = crate::TypeInner::Array { diff --git a/naga/src/lib.rs b/naga/src/lib.rs index 7623e2d704..e55d4bb280 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -175,7 +175,7 @@ tree. A Naga *constant expression* is one of the following [`Expression`] variants, whose operands (if any) are also constant expressions: - [`Literal`] -- [`Constant`], for [`Constant`s][const_type] whose `override` is `None` +- [`Constant`], for [`Constant`]s - [`ZeroValue`], for fixed-size types - [`Compose`] - [`Access`] @@ -194,8 +194,7 @@ A constant expression can be evaluated at module translation time. ## Override expressions A Naga *override expression* is the same as a [constant expression], -except that it is also allowed to refer to [`Constant`s][const_type] -whose `override` is something other than `None`. +except that it is also allowed to reference other [`Override`]s. An override expression can be evaluated at pipeline creation time. @@ -238,8 +237,6 @@ An override expression can be evaluated at pipeline creation time. [`Math`]: Expression::Math [`As`]: Expression::As -[const_type]: Constant - [constant expression]: index.html#constant-expressions */ @@ -890,6 +887,25 @@ pub enum Literal { AbstractFloat(f64), } +/// Pipeline-overridable constant. +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "clone", derive(Clone))] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub struct Override { + pub name: Option, + /// Pipeline Constant ID. + pub id: Option, + pub ty: Handle, + + /// The default value of the pipeline-overridable constant. + /// + /// This [`Handle`] refers to [`Module::const_expressions`], not + /// any [`Function::expressions`] arena. + pub init: Option>, +} + /// Constant value. #[derive(Debug, PartialEq)] #[cfg_attr(feature = "clone", derive(Clone))] @@ -904,13 +920,6 @@ pub struct Constant { /// /// This [`Handle`] refers to [`Module::const_expressions`], not /// any [`Function::expressions`] arena. - /// - /// If `override` is `None`, then this must be a Naga - /// [constant expression]. Otherwise, this may be a Naga - /// [override expression] or [constant expression]. - /// - /// [constant expression]: index.html#constant-expressions - /// [override expression]: index.html#override-expressions pub init: Handle, } @@ -1299,6 +1308,8 @@ pub enum Expression { Literal(Literal), /// Constant value. Constant(Handle), + /// Pipeline-overridable constant. + Override(Handle), /// Zero value of a type. ZeroValue(Handle), /// Composite expression. @@ -2053,6 +2064,8 @@ pub struct Module { pub special_types: SpecialTypes, /// Arena for the constants defined in this module. pub constants: Arena, + /// Arena for the pipeline-overridable constants defined in this module. + pub overrides: Arena, /// Arena for the global variables defined in this module. pub global_variables: Arena, /// [Constant expressions] and [override expressions] used by this module. diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 3a8d47325f..8a9da04d33 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -4,8 +4,8 @@ use arrayvec::ArrayVec; use crate::{ arena::{Arena, Handle, UniqueArena}, - ArraySize, BinaryOperator, Constant, Expression, Literal, ScalarKind, Span, Type, TypeInner, - UnaryOperator, + ArraySize, BinaryOperator, Constant, Expression, Literal, Override, ScalarKind, Span, Type, + TypeInner, UnaryOperator, }; /// A macro that allows dollar signs (`$`) to be emitted by other macros. Useful for generating @@ -291,6 +291,9 @@ pub struct ConstantEvaluator<'a> { /// The module's constant arena. constants: &'a Arena, + /// The module's override arena. + overrides: &'a Arena, + /// The arena to which we are contributing expressions. expressions: &'a mut Arena, @@ -456,6 +459,7 @@ impl<'a> ConstantEvaluator<'a> { behavior, types: &mut module.types, constants: &module.constants, + overrides: &module.overrides, expressions: &mut module.const_expressions, function_local_data: None, } @@ -515,6 +519,7 @@ impl<'a> ConstantEvaluator<'a> { behavior, types: &mut module.types, constants: &module.constants, + overrides: &module.overrides, expressions, function_local_data: Some(FunctionLocalData { const_expressions: &module.const_expressions, @@ -529,6 +534,7 @@ impl<'a> ConstantEvaluator<'a> { crate::proc::GlobalCtx { types: self.types, constants: self.constants, + overrides: self.overrides, const_expressions: match self.function_local_data { Some(ref data) => data.const_expressions, None => self.expressions, @@ -605,6 +611,9 @@ impl<'a> ConstantEvaluator<'a> { // This is mainly done to avoid having constants pointing to other constants. Ok(self.constants[c].init) } + Expression::Override(_) => Err(ConstantEvaluatorError::NotImplemented( + "overrides are WIP".into(), + )), Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => { self.register_evaluated_expr(expr.clone(), span) } @@ -2035,6 +2044,7 @@ mod tests { fn unary_op() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); + let overrides = Arena::new(); let mut const_expressions = Arena::new(); let scalar_ty = types.insert( @@ -2113,6 +2123,7 @@ mod tests { behavior: Behavior::Wgsl, types: &mut types, constants: &constants, + overrides: &overrides, expressions: &mut const_expressions, function_local_data: None, }; @@ -2164,6 +2175,7 @@ mod tests { fn cast() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); + let overrides = Arena::new(); let mut const_expressions = Arena::new(); let scalar_ty = types.insert( @@ -2196,6 +2208,7 @@ mod tests { behavior: Behavior::Wgsl, types: &mut types, constants: &constants, + overrides: &overrides, expressions: &mut const_expressions, function_local_data: None, }; @@ -2214,6 +2227,7 @@ mod tests { fn access() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); + let overrides = Arena::new(); let mut const_expressions = Arena::new(); let matrix_ty = types.insert( @@ -2311,6 +2325,7 @@ mod tests { behavior: Behavior::Wgsl, types: &mut types, constants: &constants, + overrides: &overrides, expressions: &mut const_expressions, function_local_data: None, }; @@ -2364,6 +2379,7 @@ mod tests { fn compose_of_constants() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); + let overrides = Arena::new(); let mut const_expressions = Arena::new(); let i32_ty = types.insert( @@ -2401,6 +2417,7 @@ mod tests { behavior: Behavior::Wgsl, types: &mut types, constants: &constants, + overrides: &overrides, expressions: &mut const_expressions, function_local_data: None, }; @@ -2443,6 +2460,7 @@ mod tests { fn splat_of_constant() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); + let overrides = Arena::new(); let mut const_expressions = Arena::new(); let i32_ty = types.insert( @@ -2480,6 +2498,7 @@ mod tests { behavior: Behavior::Wgsl, types: &mut types, constants: &constants, + overrides: &overrides, expressions: &mut const_expressions, function_local_data: None, }; diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 3823dbd1e1..ddb42a2c52 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -648,6 +648,7 @@ impl crate::Module { GlobalCtx { types: &self.types, constants: &self.constants, + overrides: &self.overrides, const_expressions: &self.const_expressions, } } @@ -663,6 +664,7 @@ pub(super) enum U32EvalError { pub struct GlobalCtx<'a> { pub types: &'a crate::UniqueArena, pub constants: &'a crate::Arena, + pub overrides: &'a crate::Arena, pub const_expressions: &'a crate::Arena, } diff --git a/naga/src/proc/typifier.rs b/naga/src/proc/typifier.rs index 9c4403445c..845b35cb4d 100644 --- a/naga/src/proc/typifier.rs +++ b/naga/src/proc/typifier.rs @@ -185,6 +185,7 @@ pub enum ResolveError { pub struct ResolveContext<'a> { pub constants: &'a Arena, + pub overrides: &'a Arena, pub types: &'a UniqueArena, pub special_types: &'a crate::SpecialTypes, pub global_vars: &'a Arena, @@ -202,6 +203,7 @@ impl<'a> ResolveContext<'a> { ) -> Self { Self { constants: &module.constants, + overrides: &module.overrides, types: &module.types, special_types: &module.special_types, global_vars: &module.global_variables, @@ -407,6 +409,7 @@ impl<'a> ResolveContext<'a> { }, crate::Expression::Literal(lit) => TypeResolution::Value(lit.ty_inner()), crate::Expression::Constant(h) => TypeResolution::Handle(self.constants[h].ty), + crate::Expression::Override(h) => TypeResolution::Handle(self.overrides[h].ty), crate::Expression::ZeroValue(ty) => TypeResolution::Handle(ty), crate::Expression::Compose { ty, .. } => TypeResolution::Handle(ty), crate::Expression::FunctionArgument(index) => { diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index 03fbc4089b..84f57f6c8a 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -574,7 +574,7 @@ impl FunctionInfo { non_uniform_result: self.add_ref(vector), requirements: UniformityRequirements::empty(), }, - E::Literal(_) | E::Constant(_) | E::ZeroValue(_) => Uniformity::new(), + E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(), E::Compose { ref components, .. } => { let non_uniform_result = components .iter() @@ -1186,6 +1186,7 @@ fn uniform_control_flow() { }; let resolve_context = ResolveContext { constants: &Arena::new(), + overrides: &Arena::new(), types: &type_arena, special_types: &crate::SpecialTypes::default(), global_vars: &global_var_arena, diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index ec4e12993c..7b259d69f9 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -345,7 +345,7 @@ impl super::Validator { self.validate_literal(literal)?; ShaderStages::all() } - E::Constant(_) | E::ZeroValue(_) => ShaderStages::all(), + E::Constant(_) | E::Override(_) | E::ZeroValue(_) => ShaderStages::all(), E::Compose { ref components, ty } => { validate_compose( ty, diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index 1884c01303..0643b1c9f5 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -31,6 +31,7 @@ impl super::Validator { pub(super) fn validate_module_handles(module: &crate::Module) -> Result<(), ValidationError> { let &crate::Module { ref constants, + ref overrides, ref entry_points, ref functions, ref global_variables, @@ -68,7 +69,7 @@ impl super::Validator { } for handle_and_expr in const_expressions.iter() { - Self::validate_const_expression_handles(handle_and_expr, constants, types)?; + Self::validate_const_expression_handles(handle_and_expr, constants, overrides, types)?; } let validate_type = |handle| Self::validate_type_handle(handle, types); @@ -81,6 +82,19 @@ impl super::Validator { validate_const_expr(init)?; } + for (_handle, override_) in overrides.iter() { + let &crate::Override { + name: _, + id: _, + ty, + init, + } = override_; + validate_type(ty)?; + if let Some(init_expr) = init { + validate_const_expr(init_expr)?; + } + } + for (_handle, global_variable) in global_variables.iter() { let &crate::GlobalVariable { name: _, @@ -135,6 +149,7 @@ impl super::Validator { Self::validate_expression_handles( handle_and_expr, constants, + overrides, const_expressions, types, local_variables, @@ -181,6 +196,13 @@ impl super::Validator { handle.check_valid_for(constants).map(|_| ()) } + fn validate_override_handle( + handle: Handle, + overrides: &Arena, + ) -> Result<(), InvalidHandleError> { + handle.check_valid_for(overrides).map(|_| ()) + } + fn validate_expression_handle( handle: Handle, expressions: &Arena, @@ -198,9 +220,11 @@ impl super::Validator { fn validate_const_expression_handles( (handle, expression): (Handle, &crate::Expression), constants: &Arena, + overrides: &Arena, types: &UniqueArena, ) -> Result<(), InvalidHandleError> { let validate_constant = |handle| Self::validate_constant_handle(handle, constants); + let validate_override = |handle| Self::validate_override_handle(handle, overrides); let validate_type = |handle| Self::validate_type_handle(handle, types); match *expression { @@ -209,6 +233,12 @@ impl super::Validator { validate_constant(constant)?; handle.check_dep(constants[constant].init)?; } + crate::Expression::Override(override_) => { + validate_override(override_)?; + if let Some(init) = overrides[override_].init { + handle.check_dep(init)?; + } + } crate::Expression::ZeroValue(ty) => { validate_type(ty)?; } @@ -225,6 +255,7 @@ impl super::Validator { fn validate_expression_handles( (handle, expression): (Handle, &crate::Expression), constants: &Arena, + overrides: &Arena, const_expressions: &Arena, types: &UniqueArena, local_variables: &Arena, @@ -234,6 +265,7 @@ impl super::Validator { current_function: Option>, ) -> Result<(), InvalidHandleError> { let validate_constant = |handle| Self::validate_constant_handle(handle, constants); + let validate_override = |handle| Self::validate_override_handle(handle, overrides); let validate_const_expr = |handle| Self::validate_expression_handle(handle, const_expressions); let validate_type = |handle| Self::validate_type_handle(handle, types); @@ -255,6 +287,9 @@ impl super::Validator { crate::Expression::Constant(constant) => { validate_constant(constant)?; } + crate::Expression::Override(override_) => { + validate_override(override_)?; + } crate::Expression::ZeroValue(ty) => { validate_type(ty)?; } @@ -659,6 +694,7 @@ fn constant_deps() { let mut const_exprs = Arena::new(); let mut fun_exprs = Arena::new(); let mut constants = Arena::new(); + let overrides = Arena::new(); let i32_handle = types.insert( Type { @@ -686,6 +722,7 @@ fn constant_deps() { assert!(super::Validator::validate_const_expression_handles( handle_and_expr, &constants, + &overrides, &types, ) .is_err()); diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index 5459434f33..d54079ac13 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -184,6 +184,16 @@ pub enum ConstantError { NonConstructibleType, } +#[derive(Clone, Debug, thiserror::Error)] +pub enum OverrideError { + #[error("The type doesn't match the override")] + InvalidType, + #[error("The type is not constructible")] + NonConstructibleType, + #[error("The type is not a scalar")] + TypeNotScalar, +} + #[derive(Clone, Debug, thiserror::Error)] pub enum ValidationError { #[error(transparent)] @@ -207,6 +217,12 @@ pub enum ValidationError { name: String, source: ConstantError, }, + #[error("Override {handle:?} '{name}' is invalid")] + Override { + handle: Handle, + name: String, + source: OverrideError, + }, #[error("Global variable {handle:?} '{name}' is invalid")] GlobalVariable { handle: Handle, @@ -329,6 +345,35 @@ impl Validator { Ok(()) } + fn validate_override( + &self, + handle: Handle, + gctx: crate::proc::GlobalCtx, + mod_info: &ModuleInfo, + ) -> Result<(), OverrideError> { + let o = &gctx.overrides[handle]; + + let type_info = &self.types[o.ty.index()]; + if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) { + return Err(OverrideError::NonConstructibleType); + } + + let decl_ty = &gctx.types[o.ty].inner; + match decl_ty { + &crate::TypeInner::Scalar(_) => {} + _ => return Err(OverrideError::TypeNotScalar), + } + + if let Some(init) = o.init { + let init_ty = mod_info[init].inner_with(gctx.types); + if !decl_ty.equivalent(init_ty, gctx.types) { + return Err(OverrideError::InvalidType); + } + } + + Ok(()) + } + /// Check the given module to be valid. pub fn validate( &mut self, @@ -406,6 +451,18 @@ impl Validator { .with_span_handle(handle, &module.constants) })? } + + for (handle, override_) in module.overrides.iter() { + self.validate_override(handle, module.to_ctx(), &mod_info) + .map_err(|source| { + ValidationError::Override { + handle, + name: override_.name.clone().unwrap_or_default(), + source, + } + .with_span_handle(handle, &module.overrides) + })? + } } for (var_handle, var) in module.global_variables.iter() { diff --git a/naga/tests/in/overrides.wgsl b/naga/tests/in/overrides.wgsl new file mode 100644 index 0000000000..803269a656 --- /dev/null +++ b/naga/tests/in/overrides.wgsl @@ -0,0 +1,14 @@ +@id(0) override has_point_light: bool = true; // Algorithmic control +@id(1200) override specular_param: f32 = 2.3; // Numeric control +@id(1300) override gain: f32; // Must be overridden + override width: f32 = 0.0; // Specified at the API level using + // the name "width". + override depth: f32; // Specified at the API level using + // the name "depth". + // Must be overridden. + // override height = 2 * depth; // The default value + // (if not set at the API level), + // depends on another + // overridable constant. + +override inferred_f32 = 2.718; diff --git a/naga/tests/out/analysis/overrides.info.ron b/naga/tests/out/analysis/overrides.info.ron new file mode 100644 index 0000000000..9ad1b3914e --- /dev/null +++ b/naga/tests/out/analysis/overrides.info.ron @@ -0,0 +1,26 @@ +( + type_flags: [ + ("DATA | SIZED | COPY | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), + ], + functions: [], + entry_points: [], + const_expression_types: [ + Value(Scalar(( + kind: Bool, + width: 1, + ))), + Value(Scalar(( + kind: Float, + width: 4, + ))), + Value(Scalar(( + kind: Float, + width: 4, + ))), + Value(Scalar(( + kind: Float, + width: 4, + ))), + ], +) \ No newline at end of file diff --git a/naga/tests/out/ir/access.compact.ron b/naga/tests/out/ir/access.compact.ron index 0670534e90..37ace5283f 100644 --- a/naga/tests/out/ir/access.compact.ron +++ b/naga/tests/out/ir/access.compact.ron @@ -324,6 +324,7 @@ predeclared_types: {}, ), constants: [], + overrides: [], global_variables: [ ( name: Some("global_const"), diff --git a/naga/tests/out/ir/access.ron b/naga/tests/out/ir/access.ron index 0670534e90..37ace5283f 100644 --- a/naga/tests/out/ir/access.ron +++ b/naga/tests/out/ir/access.ron @@ -324,6 +324,7 @@ predeclared_types: {}, ), constants: [], + overrides: [], global_variables: [ ( name: Some("global_const"), diff --git a/naga/tests/out/ir/collatz.compact.ron b/naga/tests/out/ir/collatz.compact.ron index cfc3bfa0ee..fe4af55c1b 100644 --- a/naga/tests/out/ir/collatz.compact.ron +++ b/naga/tests/out/ir/collatz.compact.ron @@ -46,6 +46,7 @@ predeclared_types: {}, ), constants: [], + overrides: [], global_variables: [ ( name: Some("v_indices"), diff --git a/naga/tests/out/ir/collatz.ron b/naga/tests/out/ir/collatz.ron index cfc3bfa0ee..fe4af55c1b 100644 --- a/naga/tests/out/ir/collatz.ron +++ b/naga/tests/out/ir/collatz.ron @@ -46,6 +46,7 @@ predeclared_types: {}, ), constants: [], + overrides: [], global_variables: [ ( name: Some("v_indices"), diff --git a/naga/tests/out/ir/overrides.compact.ron b/naga/tests/out/ir/overrides.compact.ron new file mode 100644 index 0000000000..5ac9ade6f6 --- /dev/null +++ b/naga/tests/out/ir/overrides.compact.ron @@ -0,0 +1,71 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Bool, + width: 1, + )), + ), + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + predeclared_types: {}, + ), + constants: [], + overrides: [ + ( + name: Some("has_point_light"), + id: Some(0), + ty: 1, + init: Some(1), + ), + ( + name: Some("specular_param"), + id: Some(1200), + ty: 2, + init: Some(2), + ), + ( + name: Some("gain"), + id: Some(1300), + ty: 2, + init: None, + ), + ( + name: Some("width"), + id: None, + ty: 2, + init: Some(3), + ), + ( + name: Some("depth"), + id: None, + ty: 2, + init: None, + ), + ( + name: Some("inferred_f32"), + id: None, + ty: 2, + init: Some(4), + ), + ], + global_variables: [], + const_expressions: [ + Literal(Bool(true)), + Literal(F32(2.3)), + Literal(F32(0.0)), + Literal(F32(2.718)), + ], + functions: [], + entry_points: [], +) \ No newline at end of file diff --git a/naga/tests/out/ir/overrides.ron b/naga/tests/out/ir/overrides.ron new file mode 100644 index 0000000000..5ac9ade6f6 --- /dev/null +++ b/naga/tests/out/ir/overrides.ron @@ -0,0 +1,71 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Bool, + width: 1, + )), + ), + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + predeclared_types: {}, + ), + constants: [], + overrides: [ + ( + name: Some("has_point_light"), + id: Some(0), + ty: 1, + init: Some(1), + ), + ( + name: Some("specular_param"), + id: Some(1200), + ty: 2, + init: Some(2), + ), + ( + name: Some("gain"), + id: Some(1300), + ty: 2, + init: None, + ), + ( + name: Some("width"), + id: None, + ty: 2, + init: Some(3), + ), + ( + name: Some("depth"), + id: None, + ty: 2, + init: None, + ), + ( + name: Some("inferred_f32"), + id: None, + ty: 2, + init: Some(4), + ), + ], + global_variables: [], + const_expressions: [ + Literal(Bool(true)), + Literal(F32(2.3)), + Literal(F32(0.0)), + Literal(F32(2.718)), + ], + functions: [], + entry_points: [], +) \ No newline at end of file diff --git a/naga/tests/out/ir/shadow.compact.ron b/naga/tests/out/ir/shadow.compact.ron index 4e65180691..fab0f1e2f6 100644 --- a/naga/tests/out/ir/shadow.compact.ron +++ b/naga/tests/out/ir/shadow.compact.ron @@ -253,6 +253,7 @@ init: 22, ), ], + overrides: [], global_variables: [ ( name: Some("t_shadow"), diff --git a/naga/tests/out/ir/shadow.ron b/naga/tests/out/ir/shadow.ron index 0b2310284a..9acbbdaadd 100644 --- a/naga/tests/out/ir/shadow.ron +++ b/naga/tests/out/ir/shadow.ron @@ -456,6 +456,7 @@ init: 38, ), ], + overrides: [], global_variables: [ ( name: Some("t_shadow"), diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index f19077869d..1d3734500d 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -815,6 +815,14 @@ fn convert_wgsl() { "int64", Targets::SPIRV | Targets::HLSL | Targets::WGSL | Targets::METAL, ), + ( + "overrides", + Targets::IR | Targets::ANALYSIS, // | Targets::SPIRV + // | Targets::METAL + // | Targets::GLSL + // | Targets::HLSL + // | Targets::WGSL, + ), ]; for &(name, targets) in inputs.iter() { diff --git a/naga/tests/wgsl_errors.rs b/naga/tests/wgsl_errors.rs index 74c273e33a..d6d1710f77 100644 --- a/naga/tests/wgsl_errors.rs +++ b/naga/tests/wgsl_errors.rs @@ -570,11 +570,11 @@ fn local_var_missing_type() { var x; } "#, - r#"error: variable `x` needs a type + r#"error: declaration of `x` needs a type specifier or initializer ┌─ wgsl:3:21 │ 3 │ var x; - │ ^ definition of `x` + │ ^ needs a type specifier or initializer "#, ); From da0a6c96e5042fe68f683809524c25e2d94cddd5 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Fri, 5 Jan 2024 15:25:26 +0100 Subject: [PATCH 04/30] remove naga's clone feature --- naga/Cargo.toml | 1 - naga/src/arena.rs | 4 ++-- naga/src/lib.rs | 24 +++++++----------------- wgpu-core/Cargo.toml | 1 - wgpu-hal/Cargo.toml | 1 - wgpu/Cargo.toml | 1 - 6 files changed, 9 insertions(+), 23 deletions(-) diff --git a/naga/Cargo.toml b/naga/Cargo.toml index a880b63126..df71664ea4 100644 --- a/naga/Cargo.toml +++ b/naga/Cargo.toml @@ -21,7 +21,6 @@ all-features = true [features] default = [] -clone = [] dot-out = [] glsl-in = ["pp-rs"] glsl-out = [] diff --git a/naga/src/arena.rs b/naga/src/arena.rs index c37538667f..4e5f5af6ec 100644 --- a/naga/src/arena.rs +++ b/naga/src/arena.rs @@ -239,7 +239,7 @@ impl Range { /// Adding new items to the arena produces a strongly-typed [`Handle`]. /// The arena can be indexed using the given handle to obtain /// a reference to the stored item. -#[cfg_attr(feature = "clone", derive(Clone))] +#[derive(Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "serialize", serde(transparent))] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] @@ -531,7 +531,7 @@ mod tests { /// /// `UniqueArena` is similar to [`Arena`]: If `Arena` is vector-like, /// `UniqueArena` is `HashSet`-like. -#[cfg_attr(feature = "clone", derive(Clone))] +#[derive(Clone)] pub struct UniqueArena { set: FastIndexSet, diff --git a/naga/src/lib.rs b/naga/src/lib.rs index e55d4bb280..671fcc97c6 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -34,9 +34,6 @@ with optional span info, representing a series of statements executed in order. `EntryPoint`s or `Function` is a `Block`, and `Statement` has a [`Block`][Statement::Block] variant. -If the `clone` feature is enabled, [`Arena`], [`UniqueArena`], [`Type`], [`TypeInner`], -[`Constant`], [`Function`], [`EntryPoint`] and [`Module`] can be cloned. - ## Arenas To improve translator performance and reduce memory usage, most structures are @@ -888,8 +885,7 @@ pub enum Literal { } /// Pipeline-overridable constant. -#[derive(Debug, PartialEq)] -#[cfg_attr(feature = "clone", derive(Clone))] +#[derive(Debug, PartialEq, Clone)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] @@ -907,8 +903,7 @@ pub struct Override { } /// Constant value. -#[derive(Debug, PartialEq)] -#[cfg_attr(feature = "clone", derive(Clone))] +#[derive(Debug, PartialEq, Clone)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] @@ -1908,8 +1903,7 @@ pub struct FunctionResult { } /// A function defined in the module. -#[derive(Debug, Default)] -#[cfg_attr(feature = "clone", derive(Clone))] +#[derive(Debug, Default, Clone)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] @@ -1973,8 +1967,7 @@ pub struct Function { /// [`Location`]: Binding::Location /// [`function`]: EntryPoint::function /// [`stage`]: EntryPoint::stage -#[derive(Debug)] -#[cfg_attr(feature = "clone", derive(Clone))] +#[derive(Debug, Clone)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] @@ -1998,8 +1991,7 @@ pub struct EntryPoint { /// These cannot be spelled in WGSL source. /// /// Stored in [`SpecialTypes::predeclared_types`] and created by [`Module::generate_predeclared_type`]. -#[derive(Debug, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "clone", derive(Clone))] +#[derive(Debug, PartialEq, Eq, Hash, Clone)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] @@ -2016,8 +2008,7 @@ pub enum PredeclaredType { } /// Set of special types that can be optionally generated by the frontends. -#[derive(Debug, Default)] -#[cfg_attr(feature = "clone", derive(Clone))] +#[derive(Debug, Default, Clone)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] @@ -2052,8 +2043,7 @@ pub struct SpecialTypes { /// Alternatively, you can load an existing shader using one of the [available front ends][front]. /// /// When finished, you can export modules using one of the [available backends][back]. -#[derive(Debug, Default)] -#[cfg_attr(feature = "clone", derive(Clone))] +#[derive(Debug, Default, Clone)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] diff --git a/wgpu-core/Cargo.toml b/wgpu-core/Cargo.toml index a0d3a5b612..c5149053c5 100644 --- a/wgpu-core/Cargo.toml +++ b/wgpu-core/Cargo.toml @@ -118,7 +118,6 @@ thiserror = "1" [dependencies.naga] path = "../naga" version = "0.19.0" -features = ["clone"] [dependencies.wgt] package = "wgpu-types" diff --git a/wgpu-hal/Cargo.toml b/wgpu-hal/Cargo.toml index ad1d0a974a..5851fdd76e 100644 --- a/wgpu-hal/Cargo.toml +++ b/wgpu-hal/Cargo.toml @@ -179,7 +179,6 @@ ndk-sys = { version = "0.5.0", optional = true } [dependencies.naga] path = "../naga" version = "0.19.0" -features = ["clone"] [build-dependencies] cfg_aliases.workspace = true diff --git a/wgpu/Cargo.toml b/wgpu/Cargo.toml index 43605f1f41..41da38ed3c 100644 --- a/wgpu/Cargo.toml +++ b/wgpu/Cargo.toml @@ -177,7 +177,6 @@ static_assertions.workspace = true [dependencies.naga] workspace = true -features = ["clone"] optional = true [build-dependencies] From 1a3c47b32dbc491f4f868784a6566adf688fab17 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Fri, 5 Jan 2024 14:42:07 +0100 Subject: [PATCH 05/30] [spv/msl/hlsl-out] support pipeline constant value replacements --- naga/src/arena.rs | 11 ++ naga/src/back/glsl/mod.rs | 6 + naga/src/back/hlsl/mod.rs | 2 + naga/src/back/hlsl/writer.rs | 6 +- naga/src/back/mod.rs | 8 + naga/src/back/msl/mod.rs | 2 + naga/src/back/msl/writer.rs | 4 + naga/src/back/pipeline_constants.rs | 213 +++++++++++++++++++++ naga/src/back/spv/mod.rs | 2 + naga/src/back/spv/writer.rs | 10 + naga/src/back/wgsl/writer.rs | 6 + naga/src/proc/constant_evaluator.rs | 1 + naga/src/valid/mod.rs | 15 +- naga/tests/in/overrides.param.ron | 11 ++ naga/tests/in/overrides.wgsl | 3 + naga/tests/out/analysis/overrides.info.ron | 17 +- naga/tests/out/hlsl/overrides.hlsl | 12 ++ naga/tests/out/hlsl/overrides.ron | 12 ++ naga/tests/out/ir/overrides.compact.ron | 22 ++- naga/tests/out/ir/overrides.ron | 22 ++- naga/tests/out/msl/overrides.msl | 17 ++ naga/tests/out/spv/overrides.main.spvasm | 25 +++ naga/tests/snapshots.rs | 50 ++++- wgpu-hal/src/vulkan/device.rs | 1 + 24 files changed, 463 insertions(+), 15 deletions(-) create mode 100644 naga/src/back/pipeline_constants.rs create mode 100644 naga/tests/in/overrides.param.ron create mode 100644 naga/tests/out/hlsl/overrides.hlsl create mode 100644 naga/tests/out/hlsl/overrides.ron create mode 100644 naga/tests/out/msl/overrides.msl create mode 100644 naga/tests/out/spv/overrides.main.spvasm diff --git a/naga/src/arena.rs b/naga/src/arena.rs index 4e5f5af6ec..184102757e 100644 --- a/naga/src/arena.rs +++ b/naga/src/arena.rs @@ -297,6 +297,17 @@ impl Arena { .map(|(i, v)| unsafe { (Handle::from_usize_unchecked(i), v) }) } + /// Drains the arena, returning an iterator over the items stored. + pub fn drain(&mut self) -> impl DoubleEndedIterator, T, Span)> { + let arena = std::mem::take(self); + arena + .data + .into_iter() + .zip(arena.span_info) + .enumerate() + .map(|(i, (v, span))| unsafe { (Handle::from_usize_unchecked(i), v, span) }) + } + /// Returns a iterator over the items stored in this arena, /// returning both the item's handle and a mutable reference to it. pub fn iter_mut(&mut self) -> impl DoubleEndedIterator, &mut T)> { diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 8241b04128..736a3b57b7 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -567,6 +567,12 @@ impl<'a, W: Write> Writer<'a, W> { pipeline_options: &'a PipelineOptions, policies: proc::BoundsCheckPolicies, ) -> Result { + if !module.overrides.is_empty() { + return Err(Error::Custom( + "Pipeline constants are not yet supported for this back-end".to_string(), + )); + } + // Check if the requested version is supported if !options.version.is_supported() { log::error!("Version {}", options.version); diff --git a/naga/src/back/hlsl/mod.rs b/naga/src/back/hlsl/mod.rs index 37d26bf3b2..588c91d69d 100644 --- a/naga/src/back/hlsl/mod.rs +++ b/naga/src/back/hlsl/mod.rs @@ -255,6 +255,8 @@ pub enum Error { Unimplemented(String), // TODO: Error used only during development #[error("{0}")] Custom(String), + #[error(transparent)] + PipelineConstant(#[from] back::pipeline_constants::PipelineConstantError), } #[derive(Default)] diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 657774d070..0db6489840 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -167,8 +167,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { &mut self, module: &Module, module_info: &valid::ModuleInfo, - _pipeline_options: &PipelineOptions, + pipeline_options: &PipelineOptions, ) -> Result { + let module = + back::pipeline_constants::process_overrides(module, &pipeline_options.constants)?; + let module = module.as_ref(); + self.reset(module); // Write special constants, if needed diff --git a/naga/src/back/mod.rs b/naga/src/back/mod.rs index 61dc4a0601..a95328d4fa 100644 --- a/naga/src/back/mod.rs +++ b/naga/src/back/mod.rs @@ -16,6 +16,14 @@ pub mod spv; #[cfg(feature = "wgsl-out")] pub mod wgsl; +#[cfg(any( + feature = "hlsl-out", + feature = "msl-out", + feature = "spv-out", + feature = "glsl-out" +))] +mod pipeline_constants; + /// Names of vector components. pub const COMPONENTS: &[char] = &['x', 'y', 'z', 'w']; /// Indent for backends. diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index 7e05be29bd..702b373cfc 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -143,6 +143,8 @@ pub enum Error { UnsupportedArrayOfType(Handle), #[error("ray tracing is not supported prior to MSL 2.3")] UnsupportedRayTracing, + #[error(transparent)] + PipelineConstant(#[from] crate::back::pipeline_constants::PipelineConstantError), } #[derive(Clone, Debug, PartialEq, thiserror::Error)] diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index f8fa2c4da5..36d8bc820b 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -3223,6 +3223,10 @@ impl Writer { options: &Options, pipeline_options: &PipelineOptions, ) -> Result { + let module = + back::pipeline_constants::process_overrides(module, &pipeline_options.constants)?; + let module = module.as_ref(); + self.names.clear(); self.namer.reset( module, diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs new file mode 100644 index 0000000000..5a3cad2a6d --- /dev/null +++ b/naga/src/back/pipeline_constants.rs @@ -0,0 +1,213 @@ +use super::PipelineConstants; +use crate::{Constant, Expression, Literal, Module, Scalar, Span, TypeInner}; +use std::borrow::Cow; +use thiserror::Error; + +#[derive(Error, Debug, Clone)] +#[cfg_attr(test, derive(PartialEq))] +pub enum PipelineConstantError { + #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")] + MissingValue(String), + #[error("Source f64 value needs to be finite (NaNs and Inifinites are not allowed) for number destinations")] + SrcNeedsToBeFinite, + #[error("Source f64 value doesn't fit in destination")] + DstRangeTooSmall, +} + +pub(super) fn process_overrides<'a>( + module: &'a Module, + pipeline_constants: &PipelineConstants, +) -> Result, PipelineConstantError> { + if module.overrides.is_empty() { + return Ok(Cow::Borrowed(module)); + } + + let mut module = module.clone(); + + for (_handle, override_, span) in module.overrides.drain() { + let key = if let Some(id) = override_.id { + Cow::Owned(id.to_string()) + } else if let Some(ref name) = override_.name { + Cow::Borrowed(name) + } else { + unreachable!(); + }; + let init = if let Some(value) = pipeline_constants.get::(&key) { + let literal = match module.types[override_.ty].inner { + TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?, + _ => unreachable!(), + }; + module + .const_expressions + .append(Expression::Literal(literal), Span::UNDEFINED) + } else if let Some(init) = override_.init { + init + } else { + return Err(PipelineConstantError::MissingValue(key.to_string())); + }; + let constant = Constant { + name: override_.name, + ty: override_.ty, + init, + }; + module.constants.append(constant, span); + } + + Ok(Cow::Owned(module)) +} + +fn map_value_to_literal(value: f64, scalar: Scalar) -> Result { + // note that in rust 0.0 == -0.0 + match scalar { + Scalar::BOOL => { + // https://webidl.spec.whatwg.org/#js-boolean + let value = value != 0.0 && !value.is_nan(); + Ok(Literal::Bool(value)) + } + Scalar::I32 => { + // https://webidl.spec.whatwg.org/#js-long + if !value.is_finite() { + return Err(PipelineConstantError::SrcNeedsToBeFinite); + } + + let value = value.trunc(); + if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) { + return Err(PipelineConstantError::DstRangeTooSmall); + } + + let value = value as i32; + Ok(Literal::I32(value)) + } + Scalar::U32 => { + // https://webidl.spec.whatwg.org/#js-unsigned-long + if !value.is_finite() { + return Err(PipelineConstantError::SrcNeedsToBeFinite); + } + + let value = value.trunc(); + if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) { + return Err(PipelineConstantError::DstRangeTooSmall); + } + + let value = value as u32; + Ok(Literal::U32(value)) + } + Scalar::F32 => { + // https://webidl.spec.whatwg.org/#js-float + if !value.is_finite() { + return Err(PipelineConstantError::SrcNeedsToBeFinite); + } + + let value = value as f32; + if !value.is_finite() { + return Err(PipelineConstantError::DstRangeTooSmall); + } + + Ok(Literal::F32(value)) + } + Scalar::F64 => { + // https://webidl.spec.whatwg.org/#js-double + if !value.is_finite() { + return Err(PipelineConstantError::SrcNeedsToBeFinite); + } + + Ok(Literal::F64(value)) + } + _ => unreachable!(), + } +} + +#[test] +fn test_map_value_to_literal() { + let bool_test_cases = [ + (0.0, false), + (-0.0, false), + (f64::NAN, false), + (1.0, true), + (f64::INFINITY, true), + (f64::NEG_INFINITY, true), + ]; + for (value, out) in bool_test_cases { + let res = Ok(Literal::Bool(out)); + assert_eq!(map_value_to_literal(value, Scalar::BOOL), res); + } + + for scalar in [Scalar::I32, Scalar::U32, Scalar::F32, Scalar::F64] { + for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] { + let res = Err(PipelineConstantError::SrcNeedsToBeFinite); + assert_eq!(map_value_to_literal(value, scalar), res); + } + } + + // i32 + assert_eq!( + map_value_to_literal(f64::from(i32::MIN), Scalar::I32), + Ok(Literal::I32(i32::MIN)) + ); + assert_eq!( + map_value_to_literal(f64::from(i32::MAX), Scalar::I32), + Ok(Literal::I32(i32::MAX)) + ); + assert_eq!( + map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32), + Err(PipelineConstantError::DstRangeTooSmall) + ); + assert_eq!( + map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32), + Err(PipelineConstantError::DstRangeTooSmall) + ); + + // u32 + assert_eq!( + map_value_to_literal(f64::from(u32::MIN), Scalar::U32), + Ok(Literal::U32(u32::MIN)) + ); + assert_eq!( + map_value_to_literal(f64::from(u32::MAX), Scalar::U32), + Ok(Literal::U32(u32::MAX)) + ); + assert_eq!( + map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32), + Err(PipelineConstantError::DstRangeTooSmall) + ); + assert_eq!( + map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32), + Err(PipelineConstantError::DstRangeTooSmall) + ); + + // f32 + assert_eq!( + map_value_to_literal(f64::from(f32::MIN), Scalar::F32), + Ok(Literal::F32(f32::MIN)) + ); + assert_eq!( + map_value_to_literal(f64::from(f32::MAX), Scalar::F32), + Ok(Literal::F32(f32::MAX)) + ); + assert_eq!( + map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32), + Ok(Literal::F32(f32::MIN)) + ); + assert_eq!( + map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32), + Ok(Literal::F32(f32::MAX)) + ); + assert_eq!( + map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32), + Err(PipelineConstantError::DstRangeTooSmall) + ); + assert_eq!( + map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32), + Err(PipelineConstantError::DstRangeTooSmall) + ); + + // f64 + assert_eq!( + map_value_to_literal(f64::MIN, Scalar::F64), + Ok(Literal::F64(f64::MIN)) + ); + assert_eq!( + map_value_to_literal(f64::MAX, Scalar::F64), + Ok(Literal::F64(f64::MAX)) + ); +} diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index 087c49bccf..3c0332d59d 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -70,6 +70,8 @@ pub enum Error { FeatureNotImplemented(&'static str), #[error("module is not validated properly: {0}")] Validation(&'static str), + #[error(transparent)] + PipelineConstant(#[from] crate::back::pipeline_constants::PipelineConstantError), } #[derive(Default)] diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index a5065e0623..975aa625d0 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -2029,6 +2029,16 @@ impl Writer { debug_info: &Option, words: &mut Vec, ) -> Result<(), Error> { + let ir_module = if let Some(pipeline_options) = pipeline_options { + crate::back::pipeline_constants::process_overrides( + ir_module, + &pipeline_options.constants, + )? + } else { + std::borrow::Cow::Borrowed(ir_module) + }; + let ir_module = ir_module.as_ref(); + self.reset(); // Try to find the entry point and corresponding index diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 607954706f..7ca689f482 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -106,6 +106,12 @@ impl Writer { } pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult { + if !module.overrides.is_empty() { + return Err(Error::Unimplemented( + "Pipeline constants are not yet supported for this back-end".to_string(), + )); + } + self.reset(module); // Save all ep result types diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 8a9da04d33..5617cc7709 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -359,6 +359,7 @@ impl ExpressionConstnessTracker { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum ConstantEvaluatorError { #[error("Constants cannot access function arguments")] FunctionArg, diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index d54079ac13..be11e8e390 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -186,6 +186,8 @@ pub enum ConstantError { #[derive(Clone, Debug, thiserror::Error)] pub enum OverrideError { + #[error("Override name and ID are missing")] + MissingNameAndID, #[error("The type doesn't match the override")] InvalidType, #[error("The type is not constructible")] @@ -353,6 +355,10 @@ impl Validator { ) -> Result<(), OverrideError> { let o = &gctx.overrides[handle]; + if o.name.is_none() && o.id.is_none() { + return Err(OverrideError::MissingNameAndID); + } + let type_info = &self.types[o.ty.index()]; if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) { return Err(OverrideError::NonConstructibleType); @@ -360,7 +366,14 @@ impl Validator { let decl_ty = &gctx.types[o.ty].inner; match decl_ty { - &crate::TypeInner::Scalar(_) => {} + &crate::TypeInner::Scalar(scalar) => match scalar { + crate::Scalar::BOOL + | crate::Scalar::I32 + | crate::Scalar::U32 + | crate::Scalar::F32 + | crate::Scalar::F64 => {} + _ => return Err(OverrideError::TypeNotScalar), + }, _ => return Err(OverrideError::TypeNotScalar), } diff --git a/naga/tests/in/overrides.param.ron b/naga/tests/in/overrides.param.ron new file mode 100644 index 0000000000..5c9b72d310 --- /dev/null +++ b/naga/tests/in/overrides.param.ron @@ -0,0 +1,11 @@ +( + spv: ( + version: (1, 0), + separate_entry_points: true, + ), + pipeline_constants: { + "0": NaN, + "1300": 1.1, + "depth": 2.3, + } +) diff --git a/naga/tests/in/overrides.wgsl b/naga/tests/in/overrides.wgsl index 803269a656..b498a8b527 100644 --- a/naga/tests/in/overrides.wgsl +++ b/naga/tests/in/overrides.wgsl @@ -12,3 +12,6 @@ // overridable constant. override inferred_f32 = 2.718; + +@compute @workgroup_size(1) +fn main() {} \ No newline at end of file diff --git a/naga/tests/out/analysis/overrides.info.ron b/naga/tests/out/analysis/overrides.info.ron index 9ad1b3914e..481c3eac99 100644 --- a/naga/tests/out/analysis/overrides.info.ron +++ b/naga/tests/out/analysis/overrides.info.ron @@ -4,7 +4,22 @@ ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), ], functions: [], - entry_points: [], + entry_points: [ + ( + flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + may_kill: false, + sampling_set: [], + global_uses: [], + expressions: [], + sampling: [], + dual_source_blending: false, + ), + ], const_expression_types: [ Value(Scalar(( kind: Bool, diff --git a/naga/tests/out/hlsl/overrides.hlsl b/naga/tests/out/hlsl/overrides.hlsl new file mode 100644 index 0000000000..63b13a5d2b --- /dev/null +++ b/naga/tests/out/hlsl/overrides.hlsl @@ -0,0 +1,12 @@ +static const bool has_point_light = false; +static const float specular_param = 2.3; +static const float gain = 1.1; +static const float width = 0.0; +static const float depth = 2.3; +static const float inferred_f32_ = 2.718; + +[numthreads(1, 1, 1)] +void main() +{ + return; +} diff --git a/naga/tests/out/hlsl/overrides.ron b/naga/tests/out/hlsl/overrides.ron new file mode 100644 index 0000000000..a07b03300b --- /dev/null +++ b/naga/tests/out/hlsl/overrides.ron @@ -0,0 +1,12 @@ +( + vertex:[ + ], + fragment:[ + ], + compute:[ + ( + entry_point:"main", + target_profile:"cs_5_1", + ), + ], +) diff --git a/naga/tests/out/ir/overrides.compact.ron b/naga/tests/out/ir/overrides.compact.ron index 5ac9ade6f6..af4b31eba9 100644 --- a/naga/tests/out/ir/overrides.compact.ron +++ b/naga/tests/out/ir/overrides.compact.ron @@ -67,5 +67,25 @@ Literal(F32(2.718)), ], functions: [], - entry_points: [], + entry_points: [ + ( + name: "main", + stage: Compute, + early_depth_test: None, + workgroup_size: (1, 1, 1), + function: ( + name: Some("main"), + arguments: [], + result: None, + local_variables: [], + expressions: [], + named_expressions: {}, + body: [ + Return( + value: None, + ), + ], + ), + ), + ], ) \ No newline at end of file diff --git a/naga/tests/out/ir/overrides.ron b/naga/tests/out/ir/overrides.ron index 5ac9ade6f6..af4b31eba9 100644 --- a/naga/tests/out/ir/overrides.ron +++ b/naga/tests/out/ir/overrides.ron @@ -67,5 +67,25 @@ Literal(F32(2.718)), ], functions: [], - entry_points: [], + entry_points: [ + ( + name: "main", + stage: Compute, + early_depth_test: None, + workgroup_size: (1, 1, 1), + function: ( + name: Some("main"), + arguments: [], + result: None, + local_variables: [], + expressions: [], + named_expressions: {}, + body: [ + Return( + value: None, + ), + ], + ), + ), + ], ) \ No newline at end of file diff --git a/naga/tests/out/msl/overrides.msl b/naga/tests/out/msl/overrides.msl new file mode 100644 index 0000000000..419edd8904 --- /dev/null +++ b/naga/tests/out/msl/overrides.msl @@ -0,0 +1,17 @@ +// language: metal1.0 +#include +#include + +using metal::uint; + +constant bool has_point_light = false; +constant float specular_param = 2.3; +constant float gain = 1.1; +constant float width = 0.0; +constant float depth = 2.3; +constant float inferred_f32_ = 2.718; + +kernel void main_( +) { + return; +} diff --git a/naga/tests/out/spv/overrides.main.spvasm b/naga/tests/out/spv/overrides.main.spvasm new file mode 100644 index 0000000000..7dfa6df3e5 --- /dev/null +++ b/naga/tests/out/spv/overrides.main.spvasm @@ -0,0 +1,25 @@ +; SPIR-V +; Version: 1.0 +; Generator: rspirv +; Bound: 15 +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %12 "main" +OpExecutionMode %12 LocalSize 1 1 1 +%2 = OpTypeVoid +%3 = OpTypeBool +%4 = OpTypeFloat 32 +%5 = OpConstantTrue %3 +%6 = OpConstant %4 2.3 +%7 = OpConstant %4 0.0 +%8 = OpConstant %4 2.718 +%9 = OpConstantFalse %3 +%10 = OpConstant %4 1.1 +%13 = OpTypeFunction %2 +%12 = OpFunction %2 None %13 +%11 = OpLabel +OpBranch %14 +%14 = OpLabel +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index 1d3734500d..e2f6dff25f 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -87,6 +87,17 @@ struct Parameters { #[cfg(all(feature = "deserialize", feature = "glsl-out"))] #[serde(default)] glsl_multiview: Option, + #[cfg(all( + feature = "deserialize", + any( + feature = "hlsl-out", + feature = "msl-out", + feature = "spv-out", + feature = "glsl-out" + ) + ))] + #[serde(default)] + pipeline_constants: naga::back::PipelineConstants, } /// Information about a shader input file. @@ -331,18 +342,25 @@ fn check_targets( debug_info, ¶ms.spv, params.bounds_check_policies, + ¶ms.pipeline_constants, ); } } #[cfg(all(feature = "deserialize", feature = "msl-out"))] { if targets.contains(Targets::METAL) { + if !params.msl_pipeline.constants.is_empty() { + panic!("Supply pipeline constants via pipeline_constants instead of msl_pipeline.constants!"); + } + let mut pipeline_options = params.msl_pipeline.clone(); + pipeline_options.constants = params.pipeline_constants.clone(); + write_output_msl( input, module, &info, ¶ms.msl, - ¶ms.msl_pipeline, + &pipeline_options, params.bounds_check_policies, ); } @@ -363,6 +381,7 @@ fn check_targets( ¶ms.glsl, params.bounds_check_policies, params.glsl_multiview, + ¶ms.pipeline_constants, ); } } @@ -377,7 +396,13 @@ fn check_targets( #[cfg(all(feature = "deserialize", feature = "hlsl-out"))] { if targets.contains(Targets::HLSL) { - write_output_hlsl(input, module, &info, ¶ms.hlsl); + write_output_hlsl( + input, + module, + &info, + ¶ms.hlsl, + ¶ms.pipeline_constants, + ); } } #[cfg(all(feature = "deserialize", feature = "wgsl-out"))] @@ -396,6 +421,7 @@ fn write_output_spv( debug_info: Option, params: &SpirvOutParameters, bounds_check_policies: naga::proc::BoundsCheckPolicies, + pipeline_constants: &naga::back::PipelineConstants, ) { use naga::back::spv; use rspirv::binary::Disassemble; @@ -428,7 +454,7 @@ fn write_output_spv( let pipeline_options = spv::PipelineOptions { entry_point: ep.name.clone(), shader_stage: ep.stage, - constants: naga::back::PipelineConstants::default(), + constants: pipeline_constants.clone(), }; write_output_spv_inner( input, @@ -508,6 +534,7 @@ fn write_output_glsl( options: &naga::back::glsl::Options, bounds_check_policies: naga::proc::BoundsCheckPolicies, multiview: Option, + pipeline_constants: &naga::back::PipelineConstants, ) { use naga::back::glsl; @@ -517,7 +544,7 @@ fn write_output_glsl( shader_stage: stage, entry_point: ep_name.to_string(), multiview, - constants: naga::back::PipelineConstants::default(), + constants: pipeline_constants.clone(), }; let mut buffer = String::new(); @@ -542,6 +569,7 @@ fn write_output_hlsl( module: &naga::Module, info: &naga::valid::ModuleInfo, options: &naga::back::hlsl::Options, + pipeline_constants: &naga::back::PipelineConstants, ) { use naga::back::hlsl; use std::fmt::Write as _; @@ -551,7 +579,13 @@ fn write_output_hlsl( let mut buffer = String::new(); let mut writer = hlsl::Writer::new(&mut buffer, options); let reflection_info = writer - .write(module, info, &hlsl::PipelineOptions::default()) + .write( + module, + info, + &hlsl::PipelineOptions { + constants: pipeline_constants.clone(), + }, + ) .expect("HLSL write failed"); input.write_output_file("hlsl", "hlsl", buffer); @@ -817,11 +851,7 @@ fn convert_wgsl() { ), ( "overrides", - Targets::IR | Targets::ANALYSIS, // | Targets::SPIRV - // | Targets::METAL - // | Targets::GLSL - // | Targets::HLSL - // | Targets::WGSL, + Targets::IR | Targets::ANALYSIS | Targets::SPIRV | Targets::METAL | Targets::HLSL, ), ]; diff --git a/wgpu-hal/src/vulkan/device.rs b/wgpu-hal/src/vulkan/device.rs index 2dcded2200..989ad60c72 100644 --- a/wgpu-hal/src/vulkan/device.rs +++ b/wgpu-hal/src/vulkan/device.rs @@ -1588,6 +1588,7 @@ impl crate::Device for super::Device { .shared .workarounds .contains(super::Workarounds::SEPARATE_ENTRY_POINTS) + || !naga_shader.module.overrides.is_empty() { return Ok(super::ShaderModule::Intermediate { naga_shader, From d340f9fe9ecf9b4a698a183a7f38b5f0508fadb9 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Tue, 9 Jan 2024 14:30:10 +0100 Subject: [PATCH 06/30] validate that override ids are unique --- naga/src/valid/mod.rs | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index be11e8e390..311279478c 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -174,6 +174,7 @@ pub struct Validator { switch_values: FastHashSet, valid_expression_list: Vec>, valid_expression_set: BitSet, + override_ids: FastHashSet, } #[derive(Clone, Debug, thiserror::Error)] @@ -188,6 +189,8 @@ pub enum ConstantError { pub enum OverrideError { #[error("Override name and ID are missing")] MissingNameAndID, + #[error("Override ID must be unique")] + DuplicateID, #[error("The type doesn't match the override")] InvalidType, #[error("The type is not constructible")] @@ -311,6 +314,7 @@ impl Validator { switch_values: FastHashSet::default(), valid_expression_list: Vec::new(), valid_expression_set: BitSet::new(), + override_ids: FastHashSet::default(), } } @@ -323,6 +327,7 @@ impl Validator { self.switch_values.clear(); self.valid_expression_list.clear(); self.valid_expression_set.clear(); + self.override_ids.clear(); } fn validate_constant( @@ -348,7 +353,7 @@ impl Validator { } fn validate_override( - &self, + &mut self, handle: Handle, gctx: crate::proc::GlobalCtx, mod_info: &ModuleInfo, @@ -359,6 +364,12 @@ impl Validator { return Err(OverrideError::MissingNameAndID); } + if let Some(id) = o.id { + if !self.override_ids.insert(id) { + return Err(OverrideError::DuplicateID); + } + } + let type_info = &self.types[o.ty.index()]; if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) { return Err(OverrideError::NonConstructibleType); From 120d0a0bcc42d807ef8db379a357d0ab91a47003 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Tue, 16 Jan 2024 14:47:14 +0100 Subject: [PATCH 07/30] [const-eval] refactor logic around `try_eval_and_append` --- naga/src/front/glsl/context.rs | 23 +---------- naga/src/front/wgsl/lower/mod.rs | 14 +------ naga/src/proc/constant_evaluator.rs | 61 ++++++++++++++++++++--------- 3 files changed, 46 insertions(+), 52 deletions(-) diff --git a/naga/src/front/glsl/context.rs b/naga/src/front/glsl/context.rs index f26c57965d..a3b4e0edde 100644 --- a/naga/src/front/glsl/context.rs +++ b/naga/src/front/glsl/context.rs @@ -260,29 +260,10 @@ impl<'a> Context<'a> { ) }; - let res = eval.try_eval_and_append(&expr, meta).map_err(|e| Error { + eval.try_eval_and_append(expr, meta).map_err(|e| Error { kind: e.into(), meta, - }); - - match res { - Ok(expr) => Ok(expr), - Err(e) => { - if self.is_const { - Err(e) - } else { - let needs_pre_emit = expr.needs_pre_emit(); - if needs_pre_emit { - self.body.extend(self.emitter.finish(&self.expressions)); - } - let h = self.expressions.append(expr, meta); - if needs_pre_emit { - self.emitter.start(&self.expressions); - } - Ok(h) - } - } - } + }) } /// Add variable to current scope diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 553633ff3f..29a87751ca 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -358,18 +358,8 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { span: Span, ) -> Result, Error<'source>> { let mut eval = self.as_const_evaluator(); - match eval.try_eval_and_append(&expr, span) { - Ok(expr) => Ok(expr), - - // `expr` is not a constant expression. This is fine as - // long as we're not building `Module::const_expressions`. - Err(err) => match self.expr_type { - ExpressionContextType::Runtime(ref mut rctx) => { - Ok(rctx.function.expressions.append(expr, span)) - } - ExpressionContextType::Constant => Err(Error::ConstantEvaluatorError(err, span)), - }, - } + eval.try_eval_and_append(expr, span) + .map_err(|e| Error::ConstantEvaluatorError(e, span)) } fn const_access(&self, handle: Handle) -> Option { diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 5617cc7709..a9c873afbc 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -587,9 +587,11 @@ impl<'a> ConstantEvaluator<'a> { /// [`ZeroValue`], and [`Swizzle`] expressions - to the expression arena /// `self` contributes to. /// - /// If `expr`'s value cannot be determined at compile time, return a an - /// error. If it's acceptable to evaluate `expr` at runtime, this error can - /// be ignored, and the caller can append `expr` to the arena itself. + /// If `expr`'s value cannot be determined at compile time, and `self` is + /// contributing to some function's expression arena, then append `expr` to + /// that arena unchanged (and thus unevaluated). Otherwise, `self` must be + /// contributing to the module's constant expression arena; since `expr`'s + /// value is not a constant, return an error. /// /// We only consider `expr` itself, without recursing into its operands. Its /// operands must all have been produced by prior calls to @@ -601,6 +603,22 @@ impl<'a> ConstantEvaluator<'a> { /// [`ZeroValue`]: Expression::ZeroValue /// [`Swizzle`]: Expression::Swizzle pub fn try_eval_and_append( + &mut self, + expr: Expression, + span: Span, + ) -> Result, ConstantEvaluatorError> { + let res = self.try_eval_and_append_impl(&expr, span); + if self.function_local_data.is_some() { + match res { + Ok(h) => Ok(h), + Err(_) => Ok(self.append_expr(expr, span, false)), + } + } else { + res + } + } + + fn try_eval_and_append_impl( &mut self, expr: &Expression, span: Span, @@ -1863,6 +1881,10 @@ impl<'a> ConstantEvaluator<'a> { crate::valid::check_literal_value(literal)?; } + Ok(self.append_expr(expr, span, true)) + } + + fn append_expr(&mut self, expr: Expression, span: Span, is_const: bool) -> Handle { if let Some(FunctionLocalData { ref mut emitter, ref mut block, @@ -1872,19 +1894,20 @@ impl<'a> ConstantEvaluator<'a> { { let is_running = emitter.is_running(); let needs_pre_emit = expr.needs_pre_emit(); - if is_running && needs_pre_emit { + let h = if is_running && needs_pre_emit { block.extend(emitter.finish(self.expressions)); let h = self.expressions.append(expr, span); emitter.start(self.expressions); - expression_constness.insert(h); - Ok(h) + h } else { - let h = self.expressions.append(expr, span); + self.expressions.append(expr, span) + }; + if is_const { expression_constness.insert(h); - Ok(h) } + h } else { - Ok(self.expressions.append(expr, span)) + self.expressions.append(expr, span) } } @@ -2130,13 +2153,13 @@ mod tests { }; let res1 = solver - .try_eval_and_append(&expr2, Default::default()) + .try_eval_and_append(expr2, Default::default()) .unwrap(); let res2 = solver - .try_eval_and_append(&expr3, Default::default()) + .try_eval_and_append(expr3, Default::default()) .unwrap(); let res3 = solver - .try_eval_and_append(&expr4, Default::default()) + .try_eval_and_append(expr4, Default::default()) .unwrap(); assert_eq!( @@ -2215,7 +2238,7 @@ mod tests { }; let res = solver - .try_eval_and_append(&root, Default::default()) + .try_eval_and_append(root, Default::default()) .unwrap(); assert_eq!( @@ -2334,7 +2357,7 @@ mod tests { let root1 = Expression::AccessIndex { base, index: 1 }; let res1 = solver - .try_eval_and_append(&root1, Default::default()) + .try_eval_and_append(root1, Default::default()) .unwrap(); let root2 = Expression::AccessIndex { @@ -2343,7 +2366,7 @@ mod tests { }; let res2 = solver - .try_eval_and_append(&root2, Default::default()) + .try_eval_and_append(root2, Default::default()) .unwrap(); match const_expressions[res1] { @@ -2425,7 +2448,7 @@ mod tests { let solved_compose = solver .try_eval_and_append( - &Expression::Compose { + Expression::Compose { ty: vec2_i32_ty, components: vec![h_expr, h_expr], }, @@ -2434,7 +2457,7 @@ mod tests { .unwrap(); let solved_negate = solver .try_eval_and_append( - &Expression::Unary { + Expression::Unary { op: UnaryOperator::Negate, expr: solved_compose, }, @@ -2506,7 +2529,7 @@ mod tests { let solved_compose = solver .try_eval_and_append( - &Expression::Splat { + Expression::Splat { size: VectorSize::Bi, value: h_expr, }, @@ -2515,7 +2538,7 @@ mod tests { .unwrap(); let solved_negate = solver .try_eval_and_append( - &Expression::Unary { + Expression::Unary { op: UnaryOperator::Negate, expr: solved_compose, }, From e9775c8567304fb786e684ed88d094ab6cbe3923 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Tue, 13 Feb 2024 14:51:03 +0100 Subject: [PATCH 08/30] [const-eval] fix evaluation of bool constuctors --- naga/src/proc/mod.rs | 4 ++-- .../out/glsl/constructors.main.Compute.glsl | 1 - naga/tests/out/hlsl/constructors.hlsl | 1 - naga/tests/out/msl/constructors.msl | 1 - naga/tests/out/spv/constructors.spvasm | 21 ++++++++++--------- naga/tests/out/wgsl/constructors.wgsl | 1 - 6 files changed, 13 insertions(+), 16 deletions(-) diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index ddb42a2c52..6dc677ff23 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -216,8 +216,8 @@ impl crate::Literal { (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)), (value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)), (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)), - (1, crate::ScalarKind::Bool, 4) => Some(Self::Bool(true)), - (0, crate::ScalarKind::Bool, 4) => Some(Self::Bool(false)), + (1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)), + (0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)), _ => None, } } diff --git a/naga/tests/out/glsl/constructors.main.Compute.glsl b/naga/tests/out/glsl/constructors.main.Compute.glsl index 4b4b0e71a4..c28401d0b4 100644 --- a/naga/tests/out/glsl/constructors.main.Compute.glsl +++ b/naga/tests/out/glsl/constructors.main.Compute.glsl @@ -31,7 +31,6 @@ void main() { uvec2 cit0_ = uvec2(0u); mat2x2 cit1_ = mat2x2(vec2(0.0), vec2(0.0)); int cit2_[4] = int[4](0, 1, 2, 3); - bool ic0_ = bool(false); uvec2 ic4_ = uvec2(0u, 0u); mat2x3 ic5_ = mat2x3(vec3(0.0, 0.0, 0.0), vec3(0.0, 0.0, 0.0)); } diff --git a/naga/tests/out/hlsl/constructors.hlsl b/naga/tests/out/hlsl/constructors.hlsl index 232494fa21..39f3137605 100644 --- a/naga/tests/out/hlsl/constructors.hlsl +++ b/naga/tests/out/hlsl/constructors.hlsl @@ -49,7 +49,6 @@ void main() uint2 cit0_ = (0u).xx; float2x2 cit1_ = float2x2((0.0).xx, (0.0).xx); int cit2_[4] = Constructarray4_int_(0, 1, 2, 3); - bool ic0_ = bool((bool)0); uint2 ic4_ = uint2(0u, 0u); float2x3 ic5_ = float2x3(float3(0.0, 0.0, 0.0), float3(0.0, 0.0, 0.0)); } diff --git a/naga/tests/out/msl/constructors.msl b/naga/tests/out/msl/constructors.msl index b29e2468b0..6733568a92 100644 --- a/naga/tests/out/msl/constructors.msl +++ b/naga/tests/out/msl/constructors.msl @@ -39,7 +39,6 @@ kernel void main_( metal::uint2 cit0_ = metal::uint2(0u); metal::float2x2 cit1_ = metal::float2x2(metal::float2(0.0), metal::float2(0.0)); type_11 cit2_ = type_11 {0, 1, 2, 3}; - bool ic0_ = static_cast(bool {}); metal::uint2 ic4_ = metal::uint2(0u, 0u); metal::float2x3 ic5_ = metal::float2x3(metal::float3(0.0, 0.0, 0.0), metal::float3(0.0, 0.0, 0.0)); } diff --git a/naga/tests/out/spv/constructors.spvasm b/naga/tests/out/spv/constructors.spvasm index 1a481aa95e..615a31dc1b 100644 --- a/naga/tests/out/spv/constructors.spvasm +++ b/naga/tests/out/spv/constructors.spvasm @@ -67,17 +67,18 @@ OpDecorate %17 ArrayStride 4 %56 = OpConstantComposite %14 %55 %55 %57 = OpConstantComposite %9 %21 %21 %58 = OpConstantComposite %8 %57 %57 -%59 = OpConstantComposite %14 %55 %55 -%60 = OpConstantComposite %7 %21 %21 %21 -%61 = OpConstantComposite %20 %60 %60 -%62 = OpConstantNull %20 -%64 = OpTypePointer Function %6 -%65 = OpConstantNull %6 +%59 = OpConstantFalse %13 +%60 = OpConstantComposite %14 %55 %55 +%61 = OpConstantComposite %7 %21 %21 %21 +%62 = OpConstantComposite %20 %61 %61 +%63 = OpConstantNull %20 +%65 = OpTypePointer Function %6 +%66 = OpConstantNull %6 %44 = OpFunction %2 None %45 %43 = OpLabel -%63 = OpVariable %64 Function %65 -OpBranch %66 -%66 = OpLabel -OpStore %63 %47 +%64 = OpVariable %65 Function %66 +OpBranch %67 +%67 = OpLabel +OpStore %64 %47 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/wgsl/constructors.wgsl b/naga/tests/out/wgsl/constructors.wgsl index a8f62dfecd..0e5eec734a 100644 --- a/naga/tests/out/wgsl/constructors.wgsl +++ b/naga/tests/out/wgsl/constructors.wgsl @@ -26,7 +26,6 @@ fn main() { let cit0_ = vec2(0u); let cit1_ = mat2x2(vec2(0f), vec2(0f)); let cit2_ = array(0i, 1i, 2i, 3i); - let ic0_ = bool(bool()); let ic4_ = vec2(0u, 0u); let ic5_ = mat2x3(vec3(0f, 0f, 0f), vec3(0f, 0f, 0f)); } From c48f01b29bcc69e36a27ac4029ab40e084d05be1 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Wed, 14 Feb 2024 15:17:07 +0100 Subject: [PATCH 09/30] implement override-expression evaluation for initializers of override declarations --- naga/src/arena.rs | 2 + naga/src/back/hlsl/mod.rs | 2 +- naga/src/back/hlsl/writer.rs | 9 +- naga/src/back/msl/mod.rs | 2 +- naga/src/back/msl/writer.rs | 6 +- naga/src/back/pipeline_constants.rs | 335 +++++++++++++++-- naga/src/back/spv/mod.rs | 40 +- naga/src/back/spv/writer.rs | 14 +- naga/src/front/glsl/context.rs | 23 +- naga/src/front/glsl/functions.rs | 10 + naga/src/front/glsl/parser.rs | 17 +- naga/src/front/glsl/parser/declarations.rs | 9 +- naga/src/front/glsl/parser/functions.rs | 7 +- naga/src/front/glsl/types.rs | 7 +- naga/src/front/wgsl/lower/mod.rs | 92 ++++- naga/src/proc/constant_evaluator.rs | 407 ++++++++++++++------- naga/src/proc/mod.rs | 2 +- naga/src/valid/analyzer.rs | 2 +- naga/src/valid/expression.rs | 21 +- naga/src/valid/function.rs | 20 +- naga/src/valid/handles.rs | 2 + naga/src/valid/interface.rs | 13 +- naga/src/valid/mod.rs | 41 ++- naga/src/valid/type.rs | 2 + naga/tests/in/overrides.wgsl | 2 +- naga/tests/out/analysis/overrides.info.ron | 6 + naga/tests/out/hlsl/overrides.hlsl | 1 + naga/tests/out/ir/overrides.compact.ron | 15 +- naga/tests/out/ir/overrides.ron | 15 +- naga/tests/out/msl/overrides.msl | 1 + naga/tests/out/spv/overrides.main.spvasm | 24 +- naga/tests/out/wgsl/quad_glsl.vert.wgsl | 4 +- 32 files changed, 910 insertions(+), 243 deletions(-) diff --git a/naga/src/arena.rs b/naga/src/arena.rs index 184102757e..740df85b86 100644 --- a/naga/src/arena.rs +++ b/naga/src/arena.rs @@ -122,6 +122,7 @@ impl Handle { serde(transparent) )] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +#[cfg_attr(test, derive(PartialEq))] pub struct Range { inner: ops::Range, #[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(skip))] @@ -140,6 +141,7 @@ impl Range { // NOTE: Keep this diagnostic in sync with that of [`BadHandle`]. #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] #[error("Handle range {range:?} of {kind} is either not present, or inaccessible yet")] pub struct BadRangeError { // This error is used for many `Handle` types, but there's no point in making this generic, so diff --git a/naga/src/back/hlsl/mod.rs b/naga/src/back/hlsl/mod.rs index 588c91d69d..d423b003ff 100644 --- a/naga/src/back/hlsl/mod.rs +++ b/naga/src/back/hlsl/mod.rs @@ -256,7 +256,7 @@ pub enum Error { #[error("{0}")] Custom(String), #[error(transparent)] - PipelineConstant(#[from] back::pipeline_constants::PipelineConstantError), + PipelineConstant(#[from] Box), } #[derive(Default)] diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 0db6489840..1abc6ceca0 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -169,9 +169,14 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { module_info: &valid::ModuleInfo, pipeline_options: &PipelineOptions, ) -> Result { - let module = - back::pipeline_constants::process_overrides(module, &pipeline_options.constants)?; + let (module, module_info) = back::pipeline_constants::process_overrides( + module, + module_info, + &pipeline_options.constants, + ) + .map_err(Box::new)?; let module = module.as_ref(); + let module_info = module_info.as_ref(); self.reset(module); diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index 702b373cfc..6ba8227a20 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -144,7 +144,7 @@ pub enum Error { #[error("ray tracing is not supported prior to MSL 2.3")] UnsupportedRayTracing, #[error(transparent)] - PipelineConstant(#[from] crate::back::pipeline_constants::PipelineConstantError), + PipelineConstant(#[from] Box), } #[derive(Clone, Debug, PartialEq, thiserror::Error)] diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 36d8bc820b..3c2a741cd4 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -3223,9 +3223,11 @@ impl Writer { options: &Options, pipeline_options: &PipelineOptions, ) -> Result { - let module = - back::pipeline_constants::process_overrides(module, &pipeline_options.constants)?; + let (module, info) = + back::pipeline_constants::process_overrides(module, info, &pipeline_options.constants) + .map_err(Box::new)?; let module = module.as_ref(); + let info = info.as_ref(); self.names.clear(); self.namer.reset( diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 5a3cad2a6d..6b2792dd28 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -1,6 +1,10 @@ use super::PipelineConstants; -use crate::{Constant, Expression, Literal, Module, Scalar, Span, TypeInner}; -use std::borrow::Cow; +use crate::{ + proc::{ConstantEvaluator, ConstantEvaluatorError}, + valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator}, + Constant, Expression, Handle, Literal, Module, Override, Scalar, Span, TypeInner, WithSpan, +}; +use std::{borrow::Cow, collections::HashSet}; use thiserror::Error; #[derive(Error, Debug, Clone)] @@ -12,48 +16,317 @@ pub enum PipelineConstantError { SrcNeedsToBeFinite, #[error("Source f64 value doesn't fit in destination")] DstRangeTooSmall, + #[error(transparent)] + ConstantEvaluatorError(#[from] ConstantEvaluatorError), + #[error(transparent)] + ValidationError(#[from] WithSpan), } pub(super) fn process_overrides<'a>( module: &'a Module, + module_info: &'a ModuleInfo, pipeline_constants: &PipelineConstants, -) -> Result, PipelineConstantError> { +) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> { if module.overrides.is_empty() { - return Ok(Cow::Borrowed(module)); + return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info))); } let mut module = module.clone(); + let mut override_map = Vec::with_capacity(module.overrides.len()); + let mut adjusted_const_expressions = Vec::with_capacity(module.const_expressions.len()); + let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len()); - for (_handle, override_, span) in module.overrides.drain() { - let key = if let Some(id) = override_.id { - Cow::Owned(id.to_string()) - } else if let Some(ref name) = override_.name { - Cow::Borrowed(name) - } else { - unreachable!(); - }; - let init = if let Some(value) = pipeline_constants.get::(&key) { - let literal = match module.types[override_.ty].inner { - TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?, - _ => unreachable!(), - }; - module - .const_expressions - .append(Expression::Literal(literal), Span::UNDEFINED) - } else if let Some(init) = override_.init { - init - } else { - return Err(PipelineConstantError::MissingValue(key.to_string())); - }; - let constant = Constant { - name: override_.name, - ty: override_.ty, - init, + let mut global_expression_kind_tracker = crate::proc::ExpressionConstnessTracker::new(); + + let mut override_iter = module.overrides.drain(); + + for (old_h, expr, span) in module.const_expressions.drain() { + let mut expr = match expr { + Expression::Override(h) => { + let c_h = if let Some(new_h) = override_map.get(h.index()) { + *new_h + } else { + let mut new_h = None; + for entry in override_iter.by_ref() { + let stop = entry.0 == h; + new_h = Some(process_override( + entry, + pipeline_constants, + &mut module, + &mut override_map, + &adjusted_const_expressions, + &mut adjusted_constant_initializers, + &mut global_expression_kind_tracker, + )?); + if stop { + break; + } + } + new_h.unwrap() + }; + Expression::Constant(c_h) + } + Expression::Constant(c_h) => { + adjusted_constant_initializers.insert(c_h); + module.constants[c_h].init = adjusted_const_expressions[c_h.index()]; + expr + } + expr => expr, }; - module.constants.append(constant, span); + let mut evaluator = ConstantEvaluator::for_wgsl_module( + &mut module, + &mut global_expression_kind_tracker, + false, + ); + adjust_expr(&adjusted_const_expressions, &mut expr); + let h = evaluator.try_eval_and_append(expr, span)?; + debug_assert_eq!(old_h.index(), adjusted_const_expressions.len()); + adjusted_const_expressions.push(h); + } + + for entry in override_iter { + process_override( + entry, + pipeline_constants, + &mut module, + &mut override_map, + &adjusted_const_expressions, + &mut adjusted_constant_initializers, + &mut global_expression_kind_tracker, + )?; + } + + for (_, c) in module + .constants + .iter_mut() + .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h)) + { + c.init = adjusted_const_expressions[c.init.index()]; + } + + for (_, v) in module.global_variables.iter_mut() { + if let Some(ref mut init) = v.init { + *init = adjusted_const_expressions[init.index()]; + } } - Ok(Cow::Owned(module)) + let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all()); + let module_info = validator.validate(&module)?; + + Ok((Cow::Owned(module), Cow::Owned(module_info))) +} + +fn process_override( + (old_h, override_, span): (Handle, Override, Span), + pipeline_constants: &PipelineConstants, + module: &mut Module, + override_map: &mut Vec>, + adjusted_const_expressions: &[Handle], + adjusted_constant_initializers: &mut HashSet>, + global_expression_kind_tracker: &mut crate::proc::ExpressionConstnessTracker, +) -> Result, PipelineConstantError> { + let key = if let Some(id) = override_.id { + Cow::Owned(id.to_string()) + } else if let Some(ref name) = override_.name { + Cow::Borrowed(name) + } else { + unreachable!(); + }; + let init = if let Some(value) = pipeline_constants.get::(&key) { + let literal = match module.types[override_.ty].inner { + TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?, + _ => unreachable!(), + }; + let expr = module + .const_expressions + .append(Expression::Literal(literal), Span::UNDEFINED); + global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const); + expr + } else if let Some(init) = override_.init { + adjusted_const_expressions[init.index()] + } else { + return Err(PipelineConstantError::MissingValue(key.to_string())); + }; + let constant = Constant { + name: override_.name, + ty: override_.ty, + init, + }; + let h = module.constants.append(constant, span); + debug_assert_eq!(old_h.index(), override_map.len()); + override_map.push(h); + adjusted_constant_initializers.insert(h); + Ok(h) +} + +fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { + let adjust = |expr: &mut Handle| { + *expr = new_pos[expr.index()]; + }; + match *expr { + Expression::Compose { + ref mut components, .. + } => { + for c in components.iter_mut() { + adjust(c); + } + } + Expression::Access { + ref mut base, + ref mut index, + } => { + adjust(base); + adjust(index); + } + Expression::AccessIndex { ref mut base, .. } => { + adjust(base); + } + Expression::Splat { ref mut value, .. } => { + adjust(value); + } + Expression::Swizzle { ref mut vector, .. } => { + adjust(vector); + } + Expression::Load { ref mut pointer } => { + adjust(pointer); + } + Expression::ImageSample { + ref mut image, + ref mut sampler, + ref mut coordinate, + ref mut array_index, + ref mut offset, + ref mut level, + ref mut depth_ref, + .. + } => { + adjust(image); + adjust(sampler); + adjust(coordinate); + if let Some(e) = array_index.as_mut() { + adjust(e); + } + if let Some(e) = offset.as_mut() { + adjust(e); + } + match *level { + crate::SampleLevel::Exact(ref mut expr) + | crate::SampleLevel::Bias(ref mut expr) => { + adjust(expr); + } + crate::SampleLevel::Gradient { + ref mut x, + ref mut y, + } => { + adjust(x); + adjust(y); + } + _ => {} + } + if let Some(e) = depth_ref.as_mut() { + adjust(e); + } + } + Expression::ImageLoad { + ref mut image, + ref mut coordinate, + ref mut array_index, + ref mut sample, + ref mut level, + } => { + adjust(image); + adjust(coordinate); + if let Some(e) = array_index.as_mut() { + adjust(e); + } + if let Some(e) = sample.as_mut() { + adjust(e); + } + if let Some(e) = level.as_mut() { + adjust(e); + } + } + Expression::ImageQuery { + ref mut image, + ref mut query, + } => { + adjust(image); + match *query { + crate::ImageQuery::Size { ref mut level } => { + if let Some(e) = level.as_mut() { + adjust(e); + } + } + _ => {} + } + } + Expression::Unary { ref mut expr, .. } => { + adjust(expr); + } + Expression::Binary { + ref mut left, + ref mut right, + .. + } => { + adjust(left); + adjust(right); + } + Expression::Select { + ref mut condition, + ref mut accept, + ref mut reject, + } => { + adjust(condition); + adjust(accept); + adjust(reject); + } + Expression::Derivative { ref mut expr, .. } => { + adjust(expr); + } + Expression::Relational { + ref mut argument, .. + } => { + adjust(argument); + } + Expression::Math { + ref mut arg, + ref mut arg1, + ref mut arg2, + ref mut arg3, + .. + } => { + adjust(arg); + if let Some(e) = arg1.as_mut() { + adjust(e); + } + if let Some(e) = arg2.as_mut() { + adjust(e); + } + if let Some(e) = arg3.as_mut() { + adjust(e); + } + } + Expression::As { ref mut expr, .. } => { + adjust(expr); + } + Expression::ArrayLength(ref mut expr) => { + adjust(expr); + } + Expression::RayQueryGetIntersection { ref mut query, .. } => { + adjust(query); + } + Expression::Literal(_) + | Expression::FunctionArgument(_) + | Expression::GlobalVariable(_) + | Expression::LocalVariable(_) + | Expression::CallResult(_) + | Expression::RayQueryProceedResult + | Expression::Constant(_) + | Expression::Override(_) + | Expression::ZeroValue(_) + | Expression::AtomicResult { .. } + | Expression::WorkGroupUniformLoadResult { .. } => {} + } } fn map_value_to_literal(value: f64, scalar: Scalar) -> Result { diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index 3c0332d59d..f1bbaecce1 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -71,7 +71,7 @@ pub enum Error { #[error("module is not validated properly: {0}")] Validation(&'static str), #[error(transparent)] - PipelineConstant(#[from] crate::back::pipeline_constants::PipelineConstantError), + PipelineConstant(#[from] Box), } #[derive(Default)] @@ -529,6 +529,42 @@ struct FunctionArgument { handle_id: Word, } +/// Tracks the expressions for which the backend emits the following instructions: +/// - OpConstantTrue +/// - OpConstantFalse +/// - OpConstant +/// - OpConstantComposite +/// - OpConstantNull +struct ExpressionConstnessTracker { + inner: bit_set::BitSet, +} + +impl ExpressionConstnessTracker { + fn from_arena(arena: &crate::Arena) -> Self { + let mut inner = bit_set::BitSet::new(); + for (handle, expr) in arena.iter() { + let insert = match *expr { + crate::Expression::Literal(_) + | crate::Expression::ZeroValue(_) + | crate::Expression::Constant(_) => true, + crate::Expression::Compose { ref components, .. } => { + components.iter().all(|h| inner.contains(h.index())) + } + crate::Expression::Splat { value, .. } => inner.contains(value.index()), + _ => false, + }; + if insert { + inner.insert(handle.index()); + } + } + Self { inner } + } + + fn is_const(&self, value: Handle) -> bool { + self.inner.contains(value.index()) + } +} + /// General information needed to emit SPIR-V for Naga statements. struct BlockContext<'w> { /// The writer handling the module to which this code belongs. @@ -554,7 +590,7 @@ struct BlockContext<'w> { temp_list: Vec, /// Tracks the constness of `Expression`s residing in `self.ir_function.expressions` - expression_constness: crate::proc::ExpressionConstnessTracker, + expression_constness: ExpressionConstnessTracker, } impl BlockContext<'_> { diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 975aa625d0..868fad7fa2 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -615,7 +615,7 @@ impl Writer { // Steal the Writer's temp list for a bit. temp_list: std::mem::take(&mut self.temp_list), writer: self, - expression_constness: crate::proc::ExpressionConstnessTracker::from_arena( + expression_constness: super::ExpressionConstnessTracker::from_arena( &ir_function.expressions, ), }; @@ -2029,15 +2029,21 @@ impl Writer { debug_info: &Option, words: &mut Vec, ) -> Result<(), Error> { - let ir_module = if let Some(pipeline_options) = pipeline_options { + let (ir_module, info) = if let Some(pipeline_options) = pipeline_options { crate::back::pipeline_constants::process_overrides( ir_module, + info, &pipeline_options.constants, - )? + ) + .map_err(Box::new)? } else { - std::borrow::Cow::Borrowed(ir_module) + ( + std::borrow::Cow::Borrowed(ir_module), + std::borrow::Cow::Borrowed(info), + ) }; let ir_module = ir_module.as_ref(); + let info = info.as_ref(); self.reset(); diff --git a/naga/src/front/glsl/context.rs b/naga/src/front/glsl/context.rs index a3b4e0edde..0c370cd5e5 100644 --- a/naga/src/front/glsl/context.rs +++ b/naga/src/front/glsl/context.rs @@ -77,12 +77,19 @@ pub struct Context<'a> { pub body: Block, pub module: &'a mut crate::Module, pub is_const: bool, - /// Tracks the constness of `Expression`s residing in `self.expressions` - pub expression_constness: crate::proc::ExpressionConstnessTracker, + /// Tracks the expression kind of `Expression`s residing in `self.expressions` + pub local_expression_kind_tracker: crate::proc::ExpressionConstnessTracker, + /// Tracks the expression kind of `Expression`s residing in `self.module.const_expressions` + pub global_expression_kind_tracker: &'a mut crate::proc::ExpressionConstnessTracker, } impl<'a> Context<'a> { - pub fn new(frontend: &Frontend, module: &'a mut crate::Module, is_const: bool) -> Result { + pub fn new( + frontend: &Frontend, + module: &'a mut crate::Module, + is_const: bool, + global_expression_kind_tracker: &'a mut crate::proc::ExpressionConstnessTracker, + ) -> Result { let mut this = Context { expressions: Arena::new(), locals: Arena::new(), @@ -101,7 +108,8 @@ impl<'a> Context<'a> { body: Block::new(), module, is_const: false, - expression_constness: crate::proc::ExpressionConstnessTracker::new(), + local_expression_kind_tracker: crate::proc::ExpressionConstnessTracker::new(), + global_expression_kind_tracker, }; this.emit_start(); @@ -249,12 +257,15 @@ impl<'a> Context<'a> { pub fn add_expression(&mut self, expr: Expression, meta: Span) -> Result> { let mut eval = if self.is_const { - crate::proc::ConstantEvaluator::for_glsl_module(self.module) + crate::proc::ConstantEvaluator::for_glsl_module( + self.module, + self.global_expression_kind_tracker, + ) } else { crate::proc::ConstantEvaluator::for_glsl_function( self.module, &mut self.expressions, - &mut self.expression_constness, + &mut self.local_expression_kind_tracker, &mut self.emitter, &mut self.body, ) diff --git a/naga/src/front/glsl/functions.rs b/naga/src/front/glsl/functions.rs index 01846eb814..fa1bbef56b 100644 --- a/naga/src/front/glsl/functions.rs +++ b/naga/src/front/glsl/functions.rs @@ -1236,6 +1236,8 @@ impl Frontend { let pointer = ctx .expressions .append(Expression::GlobalVariable(arg.handle), Default::default()); + ctx.local_expression_kind_tracker + .insert(pointer, crate::proc::ExpressionKind::Runtime); let ty = ctx.module.global_variables[arg.handle].ty; @@ -1256,6 +1258,8 @@ impl Frontend { let value = ctx .expressions .append(Expression::FunctionArgument(idx), Default::default()); + ctx.local_expression_kind_tracker + .insert(value, crate::proc::ExpressionKind::Runtime); ctx.body .push(Statement::Store { pointer, value }, Default::default()); }, @@ -1285,6 +1289,8 @@ impl Frontend { let pointer = ctx .expressions .append(Expression::GlobalVariable(arg.handle), Default::default()); + ctx.local_expression_kind_tracker + .insert(pointer, crate::proc::ExpressionKind::Runtime); let ty = ctx.module.global_variables[arg.handle].ty; @@ -1307,6 +1313,8 @@ impl Frontend { let load = ctx .expressions .append(Expression::Load { pointer }, Default::default()); + ctx.local_expression_kind_tracker + .insert(load, crate::proc::ExpressionKind::Runtime); ctx.body.push( Statement::Emit(ctx.expressions.range_from(len)), Default::default(), @@ -1329,6 +1337,8 @@ impl Frontend { let res = ctx .expressions .append(Expression::Compose { ty, components }, Default::default()); + ctx.local_expression_kind_tracker + .insert(res, crate::proc::ExpressionKind::Runtime); ctx.body.push( Statement::Emit(ctx.expressions.range_from(len)), Default::default(), diff --git a/naga/src/front/glsl/parser.rs b/naga/src/front/glsl/parser.rs index 851d2e1d79..d4eb39b39b 100644 --- a/naga/src/front/glsl/parser.rs +++ b/naga/src/front/glsl/parser.rs @@ -164,9 +164,15 @@ impl<'source> ParsingContext<'source> { pub fn parse(&mut self, frontend: &mut Frontend) -> Result { let mut module = Module::default(); + let mut global_expression_kind_tracker = crate::proc::ExpressionConstnessTracker::new(); // Body and expression arena for global initialization - let mut ctx = Context::new(frontend, &mut module, false)?; + let mut ctx = Context::new( + frontend, + &mut module, + false, + &mut global_expression_kind_tracker, + )?; while self.peek(frontend).is_some() { self.parse_external_declaration(frontend, &mut ctx)?; @@ -196,7 +202,11 @@ impl<'source> ParsingContext<'source> { frontend: &mut Frontend, ctx: &mut Context, ) -> Result<(u32, Span)> { - let (const_expr, meta) = self.parse_constant_expression(frontend, ctx.module)?; + let (const_expr, meta) = self.parse_constant_expression( + frontend, + ctx.module, + ctx.global_expression_kind_tracker, + )?; let res = ctx.module.to_ctx().eval_expr_to_u32(const_expr); @@ -219,8 +229,9 @@ impl<'source> ParsingContext<'source> { &mut self, frontend: &mut Frontend, module: &mut Module, + global_expression_kind_tracker: &mut crate::proc::ExpressionConstnessTracker, ) -> Result<(Handle, Span)> { - let mut ctx = Context::new(frontend, module, true)?; + let mut ctx = Context::new(frontend, module, true, global_expression_kind_tracker)?; let mut stmt_ctx = ctx.stmt_ctx(); let expr = self.parse_conditional(frontend, &mut ctx, &mut stmt_ctx, None)?; diff --git a/naga/src/front/glsl/parser/declarations.rs b/naga/src/front/glsl/parser/declarations.rs index f5e38fb016..2d253a378d 100644 --- a/naga/src/front/glsl/parser/declarations.rs +++ b/naga/src/front/glsl/parser/declarations.rs @@ -251,7 +251,7 @@ impl<'source> ParsingContext<'source> { init.and_then(|expr| ctx.ctx.lift_up_const_expression(expr).ok()); late_initializer = None; } else if let Some(init) = init { - if ctx.is_inside_loop || !ctx.ctx.expression_constness.is_const(init) { + if ctx.is_inside_loop || !ctx.ctx.local_expression_kind_tracker.is_const(init) { decl_initializer = None; late_initializer = Some(init); } else { @@ -326,7 +326,12 @@ impl<'source> ParsingContext<'source> { let result = ty.map(|ty| FunctionResult { ty, binding: None }); - let mut context = Context::new(frontend, ctx.module, false)?; + let mut context = Context::new( + frontend, + ctx.module, + false, + ctx.global_expression_kind_tracker, + )?; self.parse_function_args(frontend, &mut context)?; diff --git a/naga/src/front/glsl/parser/functions.rs b/naga/src/front/glsl/parser/functions.rs index d428d74761..6d3b9d7ba4 100644 --- a/naga/src/front/glsl/parser/functions.rs +++ b/naga/src/front/glsl/parser/functions.rs @@ -192,8 +192,11 @@ impl<'source> ParsingContext<'source> { TokenValue::Case => { self.bump(frontend)?; - let (const_expr, meta) = - self.parse_constant_expression(frontend, ctx.module)?; + let (const_expr, meta) = self.parse_constant_expression( + frontend, + ctx.module, + ctx.global_expression_kind_tracker, + )?; match ctx.module.const_expressions[const_expr] { Expression::Literal(Literal::I32(value)) => match uint { diff --git a/naga/src/front/glsl/types.rs b/naga/src/front/glsl/types.rs index e87d76fffc..8a04b23839 100644 --- a/naga/src/front/glsl/types.rs +++ b/naga/src/front/glsl/types.rs @@ -330,7 +330,7 @@ impl Context<'_> { expr: Handle, ) -> Result> { let meta = self.expressions.get_span(expr); - Ok(match self.expressions[expr] { + let h = match self.expressions[expr] { ref expr @ (Expression::Literal(_) | Expression::Constant(_) | Expression::ZeroValue(_)) => self.module.const_expressions.append(expr.clone(), meta), @@ -355,6 +355,9 @@ impl Context<'_> { meta, }) } - }) + }; + self.global_expression_kind_tracker + .insert(h, crate::proc::ExpressionKind::Const); + Ok(h) } } diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 29a87751ca..662e318f8b 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -86,6 +86,8 @@ pub struct GlobalContext<'source, 'temp, 'out> { module: &'out mut crate::Module, const_typifier: &'temp mut Typifier, + + global_expression_kind_tracker: &'temp mut crate::proc::ExpressionConstnessTracker, } impl<'source> GlobalContext<'source, '_, '_> { @@ -97,6 +99,19 @@ impl<'source> GlobalContext<'source, '_, '_> { module: self.module, const_typifier: self.const_typifier, expr_type: ExpressionContextType::Constant, + global_expression_kind_tracker: self.global_expression_kind_tracker, + } + } + + fn as_override(&mut self) -> ExpressionContext<'source, '_, '_> { + ExpressionContext { + ast_expressions: self.ast_expressions, + globals: self.globals, + types: self.types, + module: self.module, + const_typifier: self.const_typifier, + expr_type: ExpressionContextType::Override, + global_expression_kind_tracker: self.global_expression_kind_tracker, } } @@ -165,6 +180,7 @@ pub struct StatementContext<'source, 'temp, 'out> { /// we should consider them to be const. See the use of `force_non_const` in /// the code for lowering `let` bindings. expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker, + global_expression_kind_tracker: &'temp mut crate::proc::ExpressionConstnessTracker, } impl<'a, 'temp> StatementContext<'a, 'temp, '_> { @@ -181,6 +197,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> { types: self.types, ast_expressions: self.ast_expressions, const_typifier: self.const_typifier, + global_expression_kind_tracker: self.global_expression_kind_tracker, module: self.module, expr_type: ExpressionContextType::Runtime(RuntimeExpressionContext { local_table: self.local_table, @@ -200,6 +217,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> { types: self.types, module: self.module, const_typifier: self.const_typifier, + global_expression_kind_tracker: self.global_expression_kind_tracker, } } @@ -253,6 +271,14 @@ pub enum ExpressionContextType<'temp, 'out> { /// available in the [`ExpressionContext`], so this variant /// carries no further information. Constant, + + /// We are lowering to an override expression, to be included in the module's + /// constant expression arena. + /// + /// Everything override expressions are allowed to refer to is + /// available in the [`ExpressionContext`], so this variant + /// carries no further information. + Override, } /// State for lowering an [`ast::Expression`] to Naga IR. @@ -311,6 +337,7 @@ pub struct ExpressionContext<'source, 'temp, 'out> { /// /// [`module::const_expressions`]: crate::Module::const_expressions const_typifier: &'temp mut Typifier, + global_expression_kind_tracker: &'temp mut crate::proc::ExpressionConstnessTracker, /// Whether we are lowering a constant expression or a general /// runtime expression, and the data needed in each case. @@ -326,6 +353,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { const_typifier: self.const_typifier, module: self.module, expr_type: ExpressionContextType::Constant, + global_expression_kind_tracker: self.global_expression_kind_tracker, } } @@ -336,6 +364,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { types: self.types, module: self.module, const_typifier: self.const_typifier, + global_expression_kind_tracker: self.global_expression_kind_tracker, } } @@ -348,7 +377,16 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { rctx.emitter, rctx.block, ), - ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module(self.module), + ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module( + self.module, + self.global_expression_kind_tracker, + false, + ), + ExpressionContextType::Override => ConstantEvaluator::for_wgsl_module( + self.module, + self.global_expression_kind_tracker, + true, + ), } } @@ -375,20 +413,25 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { .ok() } ExpressionContextType::Constant => self.module.to_ctx().eval_expr_to_u32(handle).ok(), + ExpressionContextType::Override => None, } } fn get_expression_span(&self, handle: Handle) -> Span { match self.expr_type { ExpressionContextType::Runtime(ref ctx) => ctx.function.expressions.get_span(handle), - ExpressionContextType::Constant => self.module.const_expressions.get_span(handle), + ExpressionContextType::Constant | ExpressionContextType::Override => { + self.module.const_expressions.get_span(handle) + } } } fn typifier(&self) -> &Typifier { match self.expr_type { ExpressionContextType::Runtime(ref ctx) => ctx.typifier, - ExpressionContextType::Constant => self.const_typifier, + ExpressionContextType::Constant | ExpressionContextType::Override => { + self.const_typifier + } } } @@ -398,7 +441,9 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { ) -> Result<&mut RuntimeExpressionContext<'temp, 'out>, Error<'source>> { match self.expr_type { ExpressionContextType::Runtime(ref mut ctx) => Ok(ctx), - ExpressionContextType::Constant => Err(Error::UnexpectedOperationInConstContext(span)), + ExpressionContextType::Constant | ExpressionContextType::Override => { + Err(Error::UnexpectedOperationInConstContext(span)) + } } } @@ -435,7 +480,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { } // This means a `gather` operation appeared in a constant expression. // This error refers to the `gather` itself, not its "component" argument. - ExpressionContextType::Constant => { + ExpressionContextType::Constant | ExpressionContextType::Override => { Err(Error::UnexpectedOperationInConstContext(gather_span)) } } @@ -461,7 +506,9 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { // to also borrow self.module.types mutably below. let typifier = match self.expr_type { ExpressionContextType::Runtime(ref ctx) => ctx.typifier, - ExpressionContextType::Constant => &*self.const_typifier, + ExpressionContextType::Constant | ExpressionContextType::Override => { + &*self.const_typifier + } }; Ok(typifier.register_type(handle, &mut self.module.types)) } @@ -504,7 +551,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { typifier = &mut *ctx.typifier; expressions = &ctx.function.expressions; } - ExpressionContextType::Constant => { + ExpressionContextType::Constant | ExpressionContextType::Override => { resolve_ctx = ResolveContext::with_locals(self.module, &empty_arena, &[]); typifier = self.const_typifier; expressions = &self.module.const_expressions; @@ -600,14 +647,14 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { rctx.block .extend(rctx.emitter.finish(&rctx.function.expressions)); } - ExpressionContextType::Constant => {} + ExpressionContextType::Constant | ExpressionContextType::Override => {} } let result = self.append_expression(expression, span); match self.expr_type { ExpressionContextType::Runtime(ref mut rctx) => { rctx.emitter.start(&rctx.function.expressions); } - ExpressionContextType::Constant => {} + ExpressionContextType::Constant | ExpressionContextType::Override => {} } result } @@ -852,6 +899,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { types: &tu.types, module: &mut module, const_typifier: &mut Typifier::new(), + global_expression_kind_tracker: &mut crate::proc::ExpressionConstnessTracker::new(), }; for decl_handle in self.index.visit_ordered() { @@ -959,7 +1007,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ast::GlobalDeclKind::Override(ref o) => { let init = o .init - .map(|init| self.expression(init, &mut ctx.as_const())) + .map(|init| self.expression(init, &mut ctx.as_override())) .transpose()?; let inferred_type = init .map(|init| ctx.as_const().register_type(init)) @@ -1049,6 +1097,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let mut local_table = FastHashMap::default(); let mut expressions = Arena::new(); let mut named_expressions = FastIndexMap::default(); + let mut local_expression_kind_tracker = crate::proc::ExpressionConstnessTracker::new(); let arguments = f .arguments @@ -1060,6 +1109,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .append(crate::Expression::FunctionArgument(i as u32), arg.name.span); local_table.insert(arg.handle, Typed::Plain(expr)); named_expressions.insert(expr, (arg.name.name.to_string(), arg.name.span)); + local_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Runtime); Ok(crate::FunctionArgument { name: Some(arg.name.name.to_string()), @@ -1102,7 +1152,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { named_expressions: &mut named_expressions, types: ctx.types, module: ctx.module, - expression_constness: &mut crate::proc::ExpressionConstnessTracker::new(), + expression_constness: &mut local_expression_kind_tracker, + global_expression_kind_tracker: ctx.global_expression_kind_tracker, }; let mut body = self.block(&f.body, false, &mut stmt_ctx)?; ensure_block_returns(&mut body); @@ -1518,6 +1569,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .function .expressions .append(crate::Expression::Binary { op, left, right }, stmt.span); + rctx.expression_constness + .insert(left, crate::proc::ExpressionKind::Runtime); + rctx.expression_constness + .insert(value, crate::proc::ExpressionKind::Runtime); block.extend(emitter.finish(&ctx.function.expressions)); crate::Statement::Store { @@ -1611,7 +1666,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { LoweredGlobalDecl::Const(handle) => { Typed::Plain(crate::Expression::Constant(handle)) } - _ => { + LoweredGlobalDecl::Override(handle) => { + Typed::Plain(crate::Expression::Override(handle)) + } + LoweredGlobalDecl::Function(_) + | LoweredGlobalDecl::Type(_) + | LoweredGlobalDecl::EntryPoint => { return Err(Error::Unexpected(span, ExpectedToken::Variable)); } }; @@ -1886,9 +1946,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { rctx.block .extend(rctx.emitter.finish(&rctx.function.expressions)); let result = has_result.then(|| { - rctx.function + let result = rctx + .function .expressions - .append(crate::Expression::CallResult(function), span) + .append(crate::Expression::CallResult(function), span); + rctx.expression_constness + .insert(result, crate::proc::ExpressionKind::Runtime); + result }); rctx.emitter.start(&rctx.function.expressions); rctx.block.push( diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index a9c873afbc..6318a57c00 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -253,9 +253,9 @@ gen_component_wise_extractor! { } #[derive(Debug)] -enum Behavior { - Wgsl, - Glsl, +enum Behavior<'a> { + Wgsl(WgslRestrictions<'a>), + Glsl(GlslRestrictions<'a>), } /// A context for evaluating constant expressions. @@ -278,7 +278,7 @@ enum Behavior { #[derive(Debug)] pub struct ConstantEvaluator<'a> { /// Which language's evaluation rules we should follow. - behavior: Behavior, + behavior: Behavior<'a>, /// The module's type arena. /// @@ -297,65 +297,145 @@ pub struct ConstantEvaluator<'a> { /// The arena to which we are contributing expressions. expressions: &'a mut Arena, - /// When `self.expressions` refers to a function's local expression - /// arena, this needs to be populated - function_local_data: Option>, + /// Tracks the constness of expressions residing in [`Self::expressions`] + expression_kind_tracker: &'a mut ExpressionConstnessTracker, +} + +#[derive(Debug)] +enum WgslRestrictions<'a> { + /// - const-expressions will be evaluated and inserted in the arena + Const, + /// - const-expressions will be evaluated and inserted in the arena + /// - override-expressions will be inserted in the arena + Override, + /// - const-expressions will be evaluated and inserted in the arena + /// - override-expressions will be inserted in the arena + /// - runtime-expressions will be inserted in the arena + Runtime(FunctionLocalData<'a>), +} + +#[derive(Debug)] +enum GlslRestrictions<'a> { + /// - const-expressions will be evaluated and inserted in the arena + Const, + /// - const-expressions will be evaluated and inserted in the arena + /// - override-expressions will be inserted in the arena + /// - runtime-expressions will be inserted in the arena + Runtime(FunctionLocalData<'a>), } #[derive(Debug)] struct FunctionLocalData<'a> { /// Global constant expressions const_expressions: &'a Arena, - /// Tracks the constness of expressions residing in `ConstantEvaluator.expressions` - expression_constness: &'a mut ExpressionConstnessTracker, emitter: &'a mut super::Emitter, block: &'a mut crate::Block, } +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +pub enum ExpressionKind { + Const, + Override, + Runtime, +} + #[derive(Debug)] pub struct ExpressionConstnessTracker { - inner: bit_set::BitSet, + inner: Vec, } impl ExpressionConstnessTracker { - pub fn new() -> Self { - Self { - inner: bit_set::BitSet::new(), - } + pub const fn new() -> Self { + Self { inner: Vec::new() } } /// Forces the the expression to not be const pub fn force_non_const(&mut self, value: Handle) { - self.inner.remove(value.index()); + self.inner[value.index()] = ExpressionKind::Runtime; } - fn insert(&mut self, value: Handle) { - self.inner.insert(value.index()); + pub fn insert(&mut self, value: Handle, expr_type: ExpressionKind) { + assert_eq!(self.inner.len(), value.index()); + self.inner.push(expr_type); + } + pub fn is_const(&self, h: Handle) -> bool { + matches!(self.type_of(h), ExpressionKind::Const) } - pub fn is_const(&self, value: Handle) -> bool { - self.inner.contains(value.index()) + pub fn is_const_or_override(&self, h: Handle) -> bool { + matches!( + self.type_of(h), + ExpressionKind::Const | ExpressionKind::Override + ) + } + + fn type_of(&self, value: Handle) -> ExpressionKind { + self.inner[value.index()] } pub fn from_arena(arena: &Arena) -> Self { - let mut tracker = Self::new(); - for (handle, expr) in arena.iter() { - let insert = match *expr { - crate::Expression::Literal(_) - | crate::Expression::ZeroValue(_) - | crate::Expression::Constant(_) => true, - crate::Expression::Compose { ref components, .. } => { - components.iter().all(|h| tracker.is_const(*h)) - } - crate::Expression::Splat { value, .. } => tracker.is_const(value), - _ => false, - }; - if insert { - tracker.insert(handle); - } + let mut tracker = Self { + inner: Vec::with_capacity(arena.len()), + }; + for (_, expr) in arena.iter() { + tracker.inner.push(tracker.type_of_with_expr(expr)); } tracker } + + fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind { + match *expr { + Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => { + ExpressionKind::Const + } + Expression::Override(_) => ExpressionKind::Override, + Expression::Compose { ref components, .. } => { + let mut expr_type = ExpressionKind::Const; + for component in components { + expr_type = expr_type.max(self.type_of(*component)) + } + expr_type + } + Expression::Splat { value, .. } => self.type_of(value), + Expression::AccessIndex { base, .. } => self.type_of(base), + Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)), + Expression::Swizzle { vector, .. } => self.type_of(vector), + Expression::Unary { expr, .. } => self.type_of(expr), + Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)), + Expression::Math { + arg, + arg1, + arg2, + arg3, + .. + } => self + .type_of(arg) + .max( + arg1.map(|arg| self.type_of(arg)) + .unwrap_or(ExpressionKind::Const), + ) + .max( + arg2.map(|arg| self.type_of(arg)) + .unwrap_or(ExpressionKind::Const), + ) + .max( + arg3.map(|arg| self.type_of(arg)) + .unwrap_or(ExpressionKind::Const), + ), + Expression::As { expr, .. } => self.type_of(expr), + Expression::Select { + condition, + accept, + reject, + } => self + .type_of(condition) + .max(self.type_of(accept)) + .max(self.type_of(reject)), + Expression::Relational { argument, .. } => self.type_of(argument), + Expression::ArrayLength(expr) => self.type_of(expr), + _ => ExpressionKind::Runtime, + } + } } #[derive(Clone, Debug, thiserror::Error)] @@ -436,6 +516,12 @@ pub enum ConstantEvaluatorError { ShiftedMoreThan32Bits, #[error(transparent)] Literal(#[from] crate::valid::LiteralError), + #[error("Can't use pipeline-overridable constants in const-expressions")] + Override, + #[error("Unexpected runtime-expression")] + RuntimeExpr, + #[error("Unexpected override-expression")] + OverrideExpr, } impl<'a> ConstantEvaluator<'a> { @@ -443,26 +529,49 @@ impl<'a> ConstantEvaluator<'a> { /// constant expression arena. /// /// Report errors according to WGSL's rules for constant evaluation. - pub fn for_wgsl_module(module: &'a mut crate::Module) -> Self { - Self::for_module(Behavior::Wgsl, module) + pub fn for_wgsl_module( + module: &'a mut crate::Module, + global_expression_kind_tracker: &'a mut ExpressionConstnessTracker, + in_override_ctx: bool, + ) -> Self { + Self::for_module( + Behavior::Wgsl(if in_override_ctx { + WgslRestrictions::Override + } else { + WgslRestrictions::Const + }), + module, + global_expression_kind_tracker, + ) } /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s /// constant expression arena. /// /// Report errors according to GLSL's rules for constant evaluation. - pub fn for_glsl_module(module: &'a mut crate::Module) -> Self { - Self::for_module(Behavior::Glsl, module) + pub fn for_glsl_module( + module: &'a mut crate::Module, + global_expression_kind_tracker: &'a mut ExpressionConstnessTracker, + ) -> Self { + Self::for_module( + Behavior::Glsl(GlslRestrictions::Const), + module, + global_expression_kind_tracker, + ) } - fn for_module(behavior: Behavior, module: &'a mut crate::Module) -> Self { + fn for_module( + behavior: Behavior<'a>, + module: &'a mut crate::Module, + global_expression_kind_tracker: &'a mut ExpressionConstnessTracker, + ) -> Self { Self { behavior, types: &mut module.types, constants: &module.constants, overrides: &module.overrides, expressions: &mut module.const_expressions, - function_local_data: None, + expression_kind_tracker: global_expression_kind_tracker, } } @@ -473,18 +582,22 @@ impl<'a> ConstantEvaluator<'a> { pub fn for_wgsl_function( module: &'a mut crate::Module, expressions: &'a mut Arena, - expression_constness: &'a mut ExpressionConstnessTracker, + local_expression_kind_tracker: &'a mut ExpressionConstnessTracker, emitter: &'a mut super::Emitter, block: &'a mut crate::Block, ) -> Self { - Self::for_function( - Behavior::Wgsl, - module, + Self { + behavior: Behavior::Wgsl(WgslRestrictions::Runtime(FunctionLocalData { + const_expressions: &module.const_expressions, + emitter, + block, + })), + types: &mut module.types, + constants: &module.constants, + overrides: &module.overrides, expressions, - expression_constness, - emitter, - block, - ) + expression_kind_tracker: local_expression_kind_tracker, + } } /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s @@ -494,40 +607,21 @@ impl<'a> ConstantEvaluator<'a> { pub fn for_glsl_function( module: &'a mut crate::Module, expressions: &'a mut Arena, - expression_constness: &'a mut ExpressionConstnessTracker, - emitter: &'a mut super::Emitter, - block: &'a mut crate::Block, - ) -> Self { - Self::for_function( - Behavior::Glsl, - module, - expressions, - expression_constness, - emitter, - block, - ) - } - - fn for_function( - behavior: Behavior, - module: &'a mut crate::Module, - expressions: &'a mut Arena, - expression_constness: &'a mut ExpressionConstnessTracker, + local_expression_kind_tracker: &'a mut ExpressionConstnessTracker, emitter: &'a mut super::Emitter, block: &'a mut crate::Block, ) -> Self { Self { - behavior, + behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData { + const_expressions: &module.const_expressions, + emitter, + block, + })), types: &mut module.types, constants: &module.constants, overrides: &module.overrides, expressions, - function_local_data: Some(FunctionLocalData { - const_expressions: &module.const_expressions, - expression_constness, - emitter, - block, - }), + expression_kind_tracker: local_expression_kind_tracker, } } @@ -536,19 +630,17 @@ impl<'a> ConstantEvaluator<'a> { types: self.types, constants: self.constants, overrides: self.overrides, - const_expressions: match self.function_local_data { - Some(ref data) => data.const_expressions, + const_expressions: match self.function_local_data() { + Some(data) => data.const_expressions, None => self.expressions, }, } } fn check(&self, expr: Handle) -> Result<(), ConstantEvaluatorError> { - if let Some(ref function_local_data) = self.function_local_data { - if !function_local_data.expression_constness.is_const(expr) { - log::debug!("check: SubexpressionsAreNotConstant"); - return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant); - } + if !self.expression_kind_tracker.is_const(expr) { + log::debug!("check: SubexpressionsAreNotConstant"); + return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant); } Ok(()) } @@ -561,7 +653,7 @@ impl<'a> ConstantEvaluator<'a> { Expression::Constant(c) => { // Are we working in a function's expression arena, or the // module's constant expression arena? - if let Some(ref function_local_data) = self.function_local_data { + if let Some(function_local_data) = self.function_local_data() { // Deep-copy the constant's value into our arena. self.copy_from( self.constants[c].init, @@ -607,14 +699,56 @@ impl<'a> ConstantEvaluator<'a> { expr: Expression, span: Span, ) -> Result, ConstantEvaluatorError> { - let res = self.try_eval_and_append_impl(&expr, span); - if self.function_local_data.is_some() { - match res { - Ok(h) => Ok(h), - Err(_) => Ok(self.append_expr(expr, span, false)), - } - } else { - res + match ( + &self.behavior, + self.expression_kind_tracker.type_of_with_expr(&expr), + ) { + // avoid errors on unimplemented functionality if possible + ( + &Behavior::Wgsl(WgslRestrictions::Runtime(_)) + | &Behavior::Glsl(GlslRestrictions::Runtime(_)), + ExpressionKind::Const, + ) => match self.try_eval_and_append_impl(&expr, span) { + Err( + ConstantEvaluatorError::NotImplemented(_) + | ConstantEvaluatorError::InvalidBinaryOpArgs, + ) => Ok(self.append_expr(expr, span, ExpressionKind::Runtime)), + res => res, + }, + (_, ExpressionKind::Const) => self.try_eval_and_append_impl(&expr, span), + (&Behavior::Wgsl(WgslRestrictions::Const), ExpressionKind::Override) => { + Err(ConstantEvaluatorError::OverrideExpr) + } + ( + &Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)), + ExpressionKind::Override, + ) => Ok(self.append_expr(expr, span, ExpressionKind::Override)), + (&Behavior::Glsl(_), ExpressionKind::Override) => unreachable!(), + ( + &Behavior::Wgsl(WgslRestrictions::Runtime(_)) + | &Behavior::Glsl(GlslRestrictions::Runtime(_)), + ExpressionKind::Runtime, + ) => Ok(self.append_expr(expr, span, ExpressionKind::Runtime)), + (_, ExpressionKind::Runtime) => Err(ConstantEvaluatorError::RuntimeExpr), + } + } + + /// Is the [`Self::expressions`] arena the global module expression arena? + const fn is_global_arena(&self) -> bool { + matches!( + self.behavior, + Behavior::Wgsl(WgslRestrictions::Const | WgslRestrictions::Override) + | Behavior::Glsl(GlslRestrictions::Const) + ) + } + + const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> { + match self.behavior { + Behavior::Wgsl(WgslRestrictions::Runtime(ref function_local_data)) + | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => { + Some(function_local_data) + } + _ => None, } } @@ -625,14 +759,12 @@ impl<'a> ConstantEvaluator<'a> { ) -> Result, ConstantEvaluatorError> { log::trace!("try_eval_and_append: {:?}", expr); match *expr { - Expression::Constant(c) if self.function_local_data.is_none() => { + Expression::Constant(c) if self.is_global_arena() => { // "See through" the constant and use its initializer. // This is mainly done to avoid having constants pointing to other constants. Ok(self.constants[c].init) } - Expression::Override(_) => Err(ConstantEvaluatorError::NotImplemented( - "overrides are WIP".into(), - )), + Expression::Override(_) => Err(ConstantEvaluatorError::Override), Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => { self.register_evaluated_expr(expr.clone(), span) } @@ -713,8 +845,8 @@ impl<'a> ConstantEvaluator<'a> { format!("{fun:?} built-in function"), )), Expression::ArrayLength(expr) => match self.behavior { - Behavior::Wgsl => Err(ConstantEvaluatorError::ArrayLength), - Behavior::Glsl => { + Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength), + Behavior::Glsl(_) => { let expr = self.check_and_get(expr)?; self.array_length(expr, span) } @@ -1881,34 +2013,35 @@ impl<'a> ConstantEvaluator<'a> { crate::valid::check_literal_value(literal)?; } - Ok(self.append_expr(expr, span, true)) + Ok(self.append_expr(expr, span, ExpressionKind::Const)) } - fn append_expr(&mut self, expr: Expression, span: Span, is_const: bool) -> Handle { - if let Some(FunctionLocalData { - ref mut emitter, - ref mut block, - ref mut expression_constness, - .. - }) = self.function_local_data - { - let is_running = emitter.is_running(); - let needs_pre_emit = expr.needs_pre_emit(); - let h = if is_running && needs_pre_emit { - block.extend(emitter.finish(self.expressions)); - let h = self.expressions.append(expr, span); - emitter.start(self.expressions); - h - } else { - self.expressions.append(expr, span) - }; - if is_const { - expression_constness.insert(h); + fn append_expr( + &mut self, + expr: Expression, + span: Span, + expr_type: ExpressionKind, + ) -> Handle { + let h = match self.behavior { + Behavior::Wgsl(WgslRestrictions::Runtime(ref mut function_local_data)) + | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => { + let is_running = function_local_data.emitter.is_running(); + let needs_pre_emit = expr.needs_pre_emit(); + if is_running && needs_pre_emit { + function_local_data + .block + .extend(function_local_data.emitter.finish(self.expressions)); + let h = self.expressions.append(expr, span); + function_local_data.emitter.start(self.expressions); + h + } else { + self.expressions.append(expr, span) + } } - h - } else { - self.expressions.append(expr, span) - } + _ => self.expressions.append(expr, span), + }; + self.expression_kind_tracker.insert(h, expr_type); + h } fn resolve_type( @@ -2062,7 +2195,7 @@ mod tests { UniqueArena, VectorSize, }; - use super::{Behavior, ConstantEvaluator}; + use super::{Behavior, ConstantEvaluator, ExpressionConstnessTracker, WgslRestrictions}; #[test] fn unary_op() { @@ -2143,13 +2276,15 @@ mod tests { expr: expr1, }; + let expression_kind_tracker = + &mut ExpressionConstnessTracker::from_arena(&const_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, overrides: &overrides, expressions: &mut const_expressions, - function_local_data: None, + expression_kind_tracker, }; let res1 = solver @@ -2228,13 +2363,15 @@ mod tests { convert: Some(crate::BOOL_WIDTH), }; + let expression_kind_tracker = + &mut ExpressionConstnessTracker::from_arena(&const_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, overrides: &overrides, expressions: &mut const_expressions, - function_local_data: None, + expression_kind_tracker, }; let res = solver @@ -2345,13 +2482,15 @@ mod tests { let base = const_expressions.append(Expression::Constant(h), Default::default()); + let expression_kind_tracker = + &mut ExpressionConstnessTracker::from_arena(&const_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, overrides: &overrides, expressions: &mut const_expressions, - function_local_data: None, + expression_kind_tracker, }; let root1 = Expression::AccessIndex { base, index: 1 }; @@ -2437,13 +2576,15 @@ mod tests { let h_expr = const_expressions.append(Expression::Constant(h), Default::default()); + let expression_kind_tracker = + &mut ExpressionConstnessTracker::from_arena(&const_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, overrides: &overrides, expressions: &mut const_expressions, - function_local_data: None, + expression_kind_tracker, }; let solved_compose = solver @@ -2518,13 +2659,15 @@ mod tests { let h_expr = const_expressions.append(Expression::Constant(h), Default::default()); + let expression_kind_tracker = + &mut ExpressionConstnessTracker::from_arena(&const_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, overrides: &overrides, expressions: &mut const_expressions, - function_local_data: None, + expression_kind_tracker, }; let solved_compose = solver diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 6dc677ff23..2db956ee0e 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -11,7 +11,7 @@ mod terminator; mod typifier; pub use constant_evaluator::{ - ConstantEvaluator, ConstantEvaluatorError, ExpressionConstnessTracker, + ConstantEvaluator, ConstantEvaluatorError, ExpressionConstnessTracker, ExpressionKind, }; pub use emitter::Emitter; pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError}; diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index 84f57f6c8a..fbb4461e38 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -226,7 +226,7 @@ struct Sampling { sampler: GlobalOrArgument, } -#[derive(Debug)] +#[derive(Debug, Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct FunctionInfo { diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 7b259d69f9..79180a0711 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -90,6 +90,8 @@ pub enum ExpressionError { sampler: bool, has_ref: bool, }, + #[error("Sample offset must be a const-expression")] + InvalidSampleOffsetExprType, #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")] InvalidSampleOffset(crate::ImageDimension, Handle), #[error("Depth reference {0:?} is not a scalar float")] @@ -129,9 +131,10 @@ pub enum ExpressionError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum ConstExpressionError { - #[error("The expression is not a constant expression")] - NonConst, + #[error("The expression is not a constant or override expression")] + NonConstOrOverride, #[error(transparent)] Compose(#[from] super::ComposeError), #[error("Splatting {0:?} can't be done")] @@ -184,9 +187,14 @@ impl super::Validator { handle: Handle, gctx: crate::proc::GlobalCtx, mod_info: &ModuleInfo, + global_expr_kind: &crate::proc::ExpressionConstnessTracker, ) -> Result<(), ConstExpressionError> { use crate::Expression as E; + if !global_expr_kind.is_const_or_override(handle) { + return Err(super::ConstExpressionError::NonConstOrOverride); + } + match gctx.const_expressions[handle] { E::Literal(literal) => { self.validate_literal(literal)?; @@ -203,12 +211,14 @@ impl super::Validator { crate::TypeInner::Scalar { .. } => {} _ => return Err(super::ConstExpressionError::InvalidSplatType(value)), }, - _ => return Err(super::ConstExpressionError::NonConst), + // the constant evaluator will report errors about override-expressions + _ => {} } Ok(()) } + #[allow(clippy::too_many_arguments)] pub(super) fn validate_expression( &self, root: Handle, @@ -217,6 +227,7 @@ impl super::Validator { module: &crate::Module, info: &FunctionInfo, mod_info: &ModuleInfo, + global_expr_kind: &crate::proc::ExpressionConstnessTracker, ) -> Result { use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti}; @@ -462,6 +473,10 @@ impl super::Validator { // check constant offset if let Some(const_expr) = offset { + if !global_expr_kind.is_const(const_expr) { + return Err(ExpressionError::InvalidSampleOffsetExprType); + } + match *mod_info[const_expr].inner_with(&module.types) { Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {} Ti::Vector { diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index f0ca22cbda..dfb7fbc6ee 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -927,7 +927,7 @@ impl super::Validator { var: &crate::LocalVariable, gctx: crate::proc::GlobalCtx, fun_info: &FunctionInfo, - expression_constness: &crate::proc::ExpressionConstnessTracker, + local_expr_kind: &crate::proc::ExpressionConstnessTracker, ) -> Result<(), LocalVariableError> { log::debug!("var {:?}", var); let type_info = self @@ -945,7 +945,7 @@ impl super::Validator { return Err(LocalVariableError::InitializerType); } - if !expression_constness.is_const(init) { + if !local_expr_kind.is_const(init) { return Err(LocalVariableError::NonConstInitializer); } } @@ -959,14 +959,14 @@ impl super::Validator { module: &crate::Module, mod_info: &ModuleInfo, entry_point: bool, + global_expr_kind: &crate::proc::ExpressionConstnessTracker, ) -> Result> { let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?; - let expression_constness = - crate::proc::ExpressionConstnessTracker::from_arena(&fun.expressions); + let local_expr_kind = crate::proc::ExpressionConstnessTracker::from_arena(&fun.expressions); for (var_handle, var) in fun.local_variables.iter() { - self.validate_local_var(var, module.to_ctx(), &info, &expression_constness) + self.validate_local_var(var, module.to_ctx(), &info, &local_expr_kind) .map_err(|source| { FunctionError::LocalVariable { handle: var_handle, @@ -1032,7 +1032,15 @@ impl super::Validator { self.valid_expression_set.insert(handle.index()); } if self.flags.contains(super::ValidationFlags::EXPRESSIONS) { - match self.validate_expression(handle, expr, fun, module, &info, mod_info) { + match self.validate_expression( + handle, + expr, + fun, + module, + &info, + mod_info, + global_expr_kind, + ) { Ok(stages) => info.available_stages &= stages, Err(source) => { return Err(FunctionError::Expression { handle, source } diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index 0643b1c9f5..bcda98b294 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -592,6 +592,7 @@ impl From for ValidationError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum InvalidHandleError { #[error(transparent)] BadHandle(#[from] BadHandle), @@ -602,6 +603,7 @@ pub enum InvalidHandleError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] #[error( "{subject:?} of kind {subject_kind:?} depends on {depends_on:?} of kind {depends_on_kind}, \ which has not been processed yet" diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 84c8b09ddb..945af946bb 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -10,6 +10,7 @@ use bit_set::BitSet; const MAX_WORKGROUP_SIZE: u32 = 0x4000; #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum GlobalVariableError { #[error("Usage isn't compatible with address space {0:?}")] InvalidUsage(crate::AddressSpace), @@ -30,6 +31,8 @@ pub enum GlobalVariableError { Handle, #[source] Disalignment, ), + #[error("Initializer must be a const-expression")] + InitializerExprType, #[error("Initializer doesn't match the variable type")] InitializerType, #[error("Initializer can't be used with address space {0:?}")] @@ -39,6 +42,7 @@ pub enum GlobalVariableError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum VaryingError { #[error("The type {0:?} does not match the varying")] InvalidType(Handle), @@ -76,6 +80,7 @@ pub enum VaryingError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum EntryPointError { #[error("Multiple conflicting entry points")] Conflict, @@ -395,6 +400,7 @@ impl super::Validator { var: &crate::GlobalVariable, gctx: crate::proc::GlobalCtx, mod_info: &ModuleInfo, + global_expr_kind: &crate::proc::ExpressionConstnessTracker, ) -> Result<(), GlobalVariableError> { use super::TypeFlags; @@ -523,6 +529,10 @@ impl super::Validator { } } + if !global_expr_kind.is_const(init) { + return Err(GlobalVariableError::InitializerExprType); + } + let decl_ty = &gctx.types[var.ty].inner; let init_ty = mod_info[init].inner_with(gctx.types); if !decl_ty.equivalent(init_ty, gctx.types) { @@ -538,6 +548,7 @@ impl super::Validator { ep: &crate::EntryPoint, module: &crate::Module, mod_info: &ModuleInfo, + global_expr_kind: &crate::proc::ExpressionConstnessTracker, ) -> Result> { if ep.early_depth_test.is_some() { let required = Capabilities::EARLY_DEPTH_TEST; @@ -566,7 +577,7 @@ impl super::Validator { } let mut info = self - .validate_function(&ep.function, module, mod_info, true) + .validate_function(&ep.function, module, mod_info, true, global_expr_kind) .map_err(WithSpan::into_other)?; { diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index 311279478c..b4b2063775 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -12,7 +12,7 @@ mod r#type; use crate::{ arena::Handle, - proc::{LayoutError, Layouter, TypeResolution}, + proc::{ExpressionConstnessTracker, LayoutError, Layouter, TypeResolution}, FastHashSet, }; use bit_set::BitSet; @@ -131,7 +131,7 @@ bitflags::bitflags! { } } -#[derive(Debug)] +#[derive(Debug, Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct ModuleInfo { @@ -178,7 +178,10 @@ pub struct Validator { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum ConstantError { + #[error("Initializer must be a const-expression")] + InitializerExprType, #[error("The type doesn't match the constant")] InvalidType, #[error("The type is not constructible")] @@ -186,11 +189,14 @@ pub enum ConstantError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum OverrideError { #[error("Override name and ID are missing")] MissingNameAndID, #[error("Override ID must be unique")] DuplicateID, + #[error("Initializer must be a const-expression or override-expression")] + InitializerExprType, #[error("The type doesn't match the override")] InvalidType, #[error("The type is not constructible")] @@ -200,6 +206,7 @@ pub enum OverrideError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum ValidationError { #[error(transparent)] InvalidHandle(#[from] InvalidHandleError), @@ -335,6 +342,7 @@ impl Validator { handle: Handle, gctx: crate::proc::GlobalCtx, mod_info: &ModuleInfo, + global_expr_kind: &ExpressionConstnessTracker, ) -> Result<(), ConstantError> { let con = &gctx.constants[handle]; @@ -343,6 +351,10 @@ impl Validator { return Err(ConstantError::NonConstructibleType); } + if !global_expr_kind.is_const(con.init) { + return Err(ConstantError::InitializerExprType); + } + let decl_ty = &gctx.types[con.ty].inner; let init_ty = mod_info[con.init].inner_with(gctx.types); if !decl_ty.equivalent(init_ty, gctx.types) { @@ -455,17 +467,24 @@ impl Validator { } } + let global_expr_kind = ExpressionConstnessTracker::from_arena(&module.const_expressions); + if self.flags.contains(ValidationFlags::CONSTANTS) { for (handle, _) in module.const_expressions.iter() { - self.validate_const_expression(handle, module.to_ctx(), &mod_info) - .map_err(|source| { - ValidationError::ConstExpression { handle, source } - .with_span_handle(handle, &module.const_expressions) - })? + self.validate_const_expression( + handle, + module.to_ctx(), + &mod_info, + &global_expr_kind, + ) + .map_err(|source| { + ValidationError::ConstExpression { handle, source } + .with_span_handle(handle, &module.const_expressions) + })? } for (handle, constant) in module.constants.iter() { - self.validate_constant(handle, module.to_ctx(), &mod_info) + self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind) .map_err(|source| { ValidationError::Constant { handle, @@ -490,7 +509,7 @@ impl Validator { } for (var_handle, var) in module.global_variables.iter() { - self.validate_global_var(var, module.to_ctx(), &mod_info) + self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind) .map_err(|source| { ValidationError::GlobalVariable { handle: var_handle, @@ -502,7 +521,7 @@ impl Validator { } for (handle, fun) in module.functions.iter() { - match self.validate_function(fun, module, &mod_info, false) { + match self.validate_function(fun, module, &mod_info, false, &global_expr_kind) { Ok(info) => mod_info.functions.push(info), Err(error) => { return Err(error.and_then(|source| { @@ -528,7 +547,7 @@ impl Validator { .with_span()); // TODO: keep some EP span information? } - match self.validate_entry_point(ep, module, &mod_info) { + match self.validate_entry_point(ep, module, &mod_info, &global_expr_kind) { Ok(info) => mod_info.entry_points.push(info), Err(error) => { return Err(error.and_then(|source| { diff --git a/naga/src/valid/type.rs b/naga/src/valid/type.rs index b8eb618ed4..03e87fd99b 100644 --- a/naga/src/valid/type.rs +++ b/naga/src/valid/type.rs @@ -63,6 +63,7 @@ bitflags::bitflags! { } #[derive(Clone, Copy, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum Disalignment { #[error("The array stride {stride} is not a multiple of the required alignment {alignment}")] ArrayStride { stride: u32, alignment: Alignment }, @@ -87,6 +88,7 @@ pub enum Disalignment { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum TypeError { #[error("Capability {0:?} is required")] MissingCapability(Capabilities), diff --git a/naga/tests/in/overrides.wgsl b/naga/tests/in/overrides.wgsl index b498a8b527..41e99f9426 100644 --- a/naga/tests/in/overrides.wgsl +++ b/naga/tests/in/overrides.wgsl @@ -6,7 +6,7 @@ override depth: f32; // Specified at the API level using // the name "depth". // Must be overridden. - // override height = 2 * depth; // The default value + override height = 2 * depth; // The default value // (if not set at the API level), // depends on another // overridable constant. diff --git a/naga/tests/out/analysis/overrides.info.ron b/naga/tests/out/analysis/overrides.info.ron index 481c3eac99..7a2447f3c0 100644 --- a/naga/tests/out/analysis/overrides.info.ron +++ b/naga/tests/out/analysis/overrides.info.ron @@ -33,6 +33,12 @@ kind: Float, width: 4, ))), + Handle(2), + Value(Scalar(( + kind: Float, + width: 4, + ))), + Handle(2), Value(Scalar(( kind: Float, width: 4, diff --git a/naga/tests/out/hlsl/overrides.hlsl b/naga/tests/out/hlsl/overrides.hlsl index 63b13a5d2b..0a849fd4db 100644 --- a/naga/tests/out/hlsl/overrides.hlsl +++ b/naga/tests/out/hlsl/overrides.hlsl @@ -3,6 +3,7 @@ static const float specular_param = 2.3; static const float gain = 1.1; static const float width = 0.0; static const float depth = 2.3; +static const float height = 4.6; static const float inferred_f32_ = 2.718; [numthreads(1, 1, 1)] diff --git a/naga/tests/out/ir/overrides.compact.ron b/naga/tests/out/ir/overrides.compact.ron index af4b31eba9..d15abbd033 100644 --- a/naga/tests/out/ir/overrides.compact.ron +++ b/naga/tests/out/ir/overrides.compact.ron @@ -52,11 +52,17 @@ ty: 2, init: None, ), + ( + name: Some("height"), + id: None, + ty: 2, + init: Some(6), + ), ( name: Some("inferred_f32"), id: None, ty: 2, - init: Some(4), + init: Some(7), ), ], global_variables: [], @@ -64,6 +70,13 @@ Literal(Bool(true)), Literal(F32(2.3)), Literal(F32(0.0)), + Override(5), + Literal(F32(2.0)), + Binary( + op: Multiply, + left: 5, + right: 4, + ), Literal(F32(2.718)), ], functions: [], diff --git a/naga/tests/out/ir/overrides.ron b/naga/tests/out/ir/overrides.ron index af4b31eba9..d15abbd033 100644 --- a/naga/tests/out/ir/overrides.ron +++ b/naga/tests/out/ir/overrides.ron @@ -52,11 +52,17 @@ ty: 2, init: None, ), + ( + name: Some("height"), + id: None, + ty: 2, + init: Some(6), + ), ( name: Some("inferred_f32"), id: None, ty: 2, - init: Some(4), + init: Some(7), ), ], global_variables: [], @@ -64,6 +70,13 @@ Literal(Bool(true)), Literal(F32(2.3)), Literal(F32(0.0)), + Override(5), + Literal(F32(2.0)), + Binary( + op: Multiply, + left: 5, + right: 4, + ), Literal(F32(2.718)), ], functions: [], diff --git a/naga/tests/out/msl/overrides.msl b/naga/tests/out/msl/overrides.msl index 419edd8904..13a3b623a0 100644 --- a/naga/tests/out/msl/overrides.msl +++ b/naga/tests/out/msl/overrides.msl @@ -9,6 +9,7 @@ constant float specular_param = 2.3; constant float gain = 1.1; constant float width = 0.0; constant float depth = 2.3; +constant float height = 4.6; constant float inferred_f32_ = 2.718; kernel void main_( diff --git a/naga/tests/out/spv/overrides.main.spvasm b/naga/tests/out/spv/overrides.main.spvasm index 7dfa6df3e5..7731edfb93 100644 --- a/naga/tests/out/spv/overrides.main.spvasm +++ b/naga/tests/out/spv/overrides.main.spvasm @@ -1,25 +1,27 @@ ; SPIR-V ; Version: 1.0 ; Generator: rspirv -; Bound: 15 +; Bound: 17 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %12 "main" -OpExecutionMode %12 LocalSize 1 1 1 +OpEntryPoint GLCompute %14 "main" +OpExecutionMode %14 LocalSize 1 1 1 %2 = OpTypeVoid %3 = OpTypeBool %4 = OpTypeFloat 32 %5 = OpConstantTrue %3 %6 = OpConstant %4 2.3 %7 = OpConstant %4 0.0 -%8 = OpConstant %4 2.718 -%9 = OpConstantFalse %3 -%10 = OpConstant %4 1.1 -%13 = OpTypeFunction %2 -%12 = OpFunction %2 None %13 -%11 = OpLabel -OpBranch %14 -%14 = OpLabel +%8 = OpConstantFalse %3 +%9 = OpConstant %4 1.1 +%10 = OpConstant %4 2.0 +%11 = OpConstant %4 4.6 +%12 = OpConstant %4 2.718 +%15 = OpTypeFunction %2 +%14 = OpFunction %2 None %15 +%13 = OpLabel +OpBranch %16 +%16 = OpLabel OpReturn OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/wgsl/quad_glsl.vert.wgsl b/naga/tests/out/wgsl/quad_glsl.vert.wgsl index 8942e4c72f..0a3d7cecac 100644 --- a/naga/tests/out/wgsl/quad_glsl.vert.wgsl +++ b/naga/tests/out/wgsl/quad_glsl.vert.wgsl @@ -14,8 +14,8 @@ fn main_1() { let _e4 = a_uv_1; v_uv = _e4; let _e6 = a_pos_1; - let _e8 = (c_scale * _e6); - gl_Position = vec4(_e8.x, _e8.y, 0f, 1f); + let _e7 = (c_scale * _e6); + gl_Position = vec4(_e7.x, _e7.y, 0f, 1f); return; } From 98f4ca0742eae8e618b3d843f671dc88e8d9aced Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Wed, 14 Feb 2024 15:20:29 +0100 Subject: [PATCH 10/30] rename `ExpressionConstnessTracker` to `ExpressionKindTracker` --- naga/src/back/pipeline_constants.rs | 4 +-- naga/src/front/glsl/context.rs | 8 +++--- naga/src/front/glsl/parser.rs | 4 +-- naga/src/front/wgsl/lower/mod.rs | 38 +++++++++++++++-------------- naga/src/proc/constant_evaluator.rs | 33 +++++++++++-------------- naga/src/proc/mod.rs | 2 +- naga/src/valid/expression.rs | 4 +-- naga/src/valid/function.rs | 6 ++--- naga/src/valid/interface.rs | 4 +-- naga/src/valid/mod.rs | 6 ++--- 10 files changed, 53 insertions(+), 56 deletions(-) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 6b2792dd28..79c44f5e9f 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -36,7 +36,7 @@ pub(super) fn process_overrides<'a>( let mut adjusted_const_expressions = Vec::with_capacity(module.const_expressions.len()); let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len()); - let mut global_expression_kind_tracker = crate::proc::ExpressionConstnessTracker::new(); + let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new(); let mut override_iter = module.overrides.drain(); @@ -123,7 +123,7 @@ fn process_override( override_map: &mut Vec>, adjusted_const_expressions: &[Handle], adjusted_constant_initializers: &mut HashSet>, - global_expression_kind_tracker: &mut crate::proc::ExpressionConstnessTracker, + global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker, ) -> Result, PipelineConstantError> { let key = if let Some(id) = override_.id { Cow::Owned(id.to_string()) diff --git a/naga/src/front/glsl/context.rs b/naga/src/front/glsl/context.rs index 0c370cd5e5..ec844597d6 100644 --- a/naga/src/front/glsl/context.rs +++ b/naga/src/front/glsl/context.rs @@ -78,9 +78,9 @@ pub struct Context<'a> { pub module: &'a mut crate::Module, pub is_const: bool, /// Tracks the expression kind of `Expression`s residing in `self.expressions` - pub local_expression_kind_tracker: crate::proc::ExpressionConstnessTracker, + pub local_expression_kind_tracker: crate::proc::ExpressionKindTracker, /// Tracks the expression kind of `Expression`s residing in `self.module.const_expressions` - pub global_expression_kind_tracker: &'a mut crate::proc::ExpressionConstnessTracker, + pub global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker, } impl<'a> Context<'a> { @@ -88,7 +88,7 @@ impl<'a> Context<'a> { frontend: &Frontend, module: &'a mut crate::Module, is_const: bool, - global_expression_kind_tracker: &'a mut crate::proc::ExpressionConstnessTracker, + global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker, ) -> Result { let mut this = Context { expressions: Arena::new(), @@ -108,7 +108,7 @@ impl<'a> Context<'a> { body: Block::new(), module, is_const: false, - local_expression_kind_tracker: crate::proc::ExpressionConstnessTracker::new(), + local_expression_kind_tracker: crate::proc::ExpressionKindTracker::new(), global_expression_kind_tracker, }; diff --git a/naga/src/front/glsl/parser.rs b/naga/src/front/glsl/parser.rs index d4eb39b39b..28e0808063 100644 --- a/naga/src/front/glsl/parser.rs +++ b/naga/src/front/glsl/parser.rs @@ -164,7 +164,7 @@ impl<'source> ParsingContext<'source> { pub fn parse(&mut self, frontend: &mut Frontend) -> Result { let mut module = Module::default(); - let mut global_expression_kind_tracker = crate::proc::ExpressionConstnessTracker::new(); + let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new(); // Body and expression arena for global initialization let mut ctx = Context::new( @@ -229,7 +229,7 @@ impl<'source> ParsingContext<'source> { &mut self, frontend: &mut Frontend, module: &mut Module, - global_expression_kind_tracker: &mut crate::proc::ExpressionConstnessTracker, + global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker, ) -> Result<(Handle, Span)> { let mut ctx = Context::new(frontend, module, true, global_expression_kind_tracker)?; diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 662e318f8b..e689dda53a 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -87,7 +87,7 @@ pub struct GlobalContext<'source, 'temp, 'out> { const_typifier: &'temp mut Typifier, - global_expression_kind_tracker: &'temp mut crate::proc::ExpressionConstnessTracker, + global_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker, } impl<'source> GlobalContext<'source, '_, '_> { @@ -179,8 +179,8 @@ pub struct StatementContext<'source, 'temp, 'out> { /// with the form of the expressions; it is also tracking whether WGSL says /// we should consider them to be const. See the use of `force_non_const` in /// the code for lowering `let` bindings. - expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker, - global_expression_kind_tracker: &'temp mut crate::proc::ExpressionConstnessTracker, + local_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker, + global_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker, } impl<'a, 'temp> StatementContext<'a, 'temp, '_> { @@ -205,7 +205,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> { block, emitter, typifier: self.typifier, - expression_constness: self.expression_constness, + local_expression_kind_tracker: self.local_expression_kind_tracker, }), } } @@ -250,8 +250,8 @@ pub struct RuntimeExpressionContext<'temp, 'out> { /// Which `Expression`s in `self.naga_expressions` are const expressions, in /// the WGSL sense. /// - /// See [`StatementContext::expression_constness`] for details. - expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker, + /// See [`StatementContext::local_expression_kind_tracker`] for details. + local_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker, } /// The type of Naga IR expression we are lowering an [`ast::Expression`] to. @@ -337,7 +337,7 @@ pub struct ExpressionContext<'source, 'temp, 'out> { /// /// [`module::const_expressions`]: crate::Module::const_expressions const_typifier: &'temp mut Typifier, - global_expression_kind_tracker: &'temp mut crate::proc::ExpressionConstnessTracker, + global_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker, /// Whether we are lowering a constant expression or a general /// runtime expression, and the data needed in each case. @@ -373,7 +373,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { ExpressionContextType::Runtime(ref mut rctx) => ConstantEvaluator::for_wgsl_function( self.module, &mut rctx.function.expressions, - rctx.expression_constness, + rctx.local_expression_kind_tracker, rctx.emitter, rctx.block, ), @@ -403,7 +403,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { fn const_access(&self, handle: Handle) -> Option { match self.expr_type { ExpressionContextType::Runtime(ref ctx) => { - if !ctx.expression_constness.is_const(handle) { + if !ctx.local_expression_kind_tracker.is_const(handle) { return None; } @@ -455,7 +455,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { ) -> Result> { match self.expr_type { ExpressionContextType::Runtime(ref rctx) => { - if !rctx.expression_constness.is_const(expr) { + if !rctx.local_expression_kind_tracker.is_const(expr) { return Err(Error::ExpectedConstExprConcreteIntegerScalar( component_span, )); @@ -899,7 +899,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { types: &tu.types, module: &mut module, const_typifier: &mut Typifier::new(), - global_expression_kind_tracker: &mut crate::proc::ExpressionConstnessTracker::new(), + global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker::new(), }; for decl_handle in self.index.visit_ordered() { @@ -1097,7 +1097,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let mut local_table = FastHashMap::default(); let mut expressions = Arena::new(); let mut named_expressions = FastIndexMap::default(); - let mut local_expression_kind_tracker = crate::proc::ExpressionConstnessTracker::new(); + let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new(); let arguments = f .arguments @@ -1152,7 +1152,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { named_expressions: &mut named_expressions, types: ctx.types, module: ctx.module, - expression_constness: &mut local_expression_kind_tracker, + local_expression_kind_tracker: &mut local_expression_kind_tracker, global_expression_kind_tracker: ctx.global_expression_kind_tracker, }; let mut body = self.block(&f.body, false, &mut stmt_ctx)?; @@ -1232,7 +1232,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // affects when errors must be reported, so we can't even // treat suitable `let` bindings as constant as an // optimization. - ctx.expression_constness.force_non_const(value); + ctx.local_expression_kind_tracker.force_non_const(value); let explicit_ty = l.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_global())) @@ -1316,7 +1316,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // - the initialization is not a constant // expression, so its value depends on the // state at the point of initialization. - if is_inside_loop || !ctx.expression_constness.is_const(init) { + if is_inside_loop + || !ctx.local_expression_kind_tracker.is_const(init) + { (None, Some(init)) } else { (Some(init), None) @@ -1569,9 +1571,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .function .expressions .append(crate::Expression::Binary { op, left, right }, stmt.span); - rctx.expression_constness + rctx.local_expression_kind_tracker .insert(left, crate::proc::ExpressionKind::Runtime); - rctx.expression_constness + rctx.local_expression_kind_tracker .insert(value, crate::proc::ExpressionKind::Runtime); block.extend(emitter.finish(&ctx.function.expressions)); @@ -1950,7 +1952,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .function .expressions .append(crate::Expression::CallResult(function), span); - rctx.expression_constness + rctx.local_expression_kind_tracker .insert(result, crate::proc::ExpressionKind::Runtime); result }); diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 6318a57c00..6f09ec5444 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -298,7 +298,7 @@ pub struct ConstantEvaluator<'a> { expressions: &'a mut Arena, /// Tracks the constness of expressions residing in [`Self::expressions`] - expression_kind_tracker: &'a mut ExpressionConstnessTracker, + expression_kind_tracker: &'a mut ExpressionKindTracker, } #[derive(Debug)] @@ -340,11 +340,11 @@ pub enum ExpressionKind { } #[derive(Debug)] -pub struct ExpressionConstnessTracker { +pub struct ExpressionKindTracker { inner: Vec, } -impl ExpressionConstnessTracker { +impl ExpressionKindTracker { pub const fn new() -> Self { Self { inner: Vec::new() } } @@ -531,7 +531,7 @@ impl<'a> ConstantEvaluator<'a> { /// Report errors according to WGSL's rules for constant evaluation. pub fn for_wgsl_module( module: &'a mut crate::Module, - global_expression_kind_tracker: &'a mut ExpressionConstnessTracker, + global_expression_kind_tracker: &'a mut ExpressionKindTracker, in_override_ctx: bool, ) -> Self { Self::for_module( @@ -551,7 +551,7 @@ impl<'a> ConstantEvaluator<'a> { /// Report errors according to GLSL's rules for constant evaluation. pub fn for_glsl_module( module: &'a mut crate::Module, - global_expression_kind_tracker: &'a mut ExpressionConstnessTracker, + global_expression_kind_tracker: &'a mut ExpressionKindTracker, ) -> Self { Self::for_module( Behavior::Glsl(GlslRestrictions::Const), @@ -563,7 +563,7 @@ impl<'a> ConstantEvaluator<'a> { fn for_module( behavior: Behavior<'a>, module: &'a mut crate::Module, - global_expression_kind_tracker: &'a mut ExpressionConstnessTracker, + global_expression_kind_tracker: &'a mut ExpressionKindTracker, ) -> Self { Self { behavior, @@ -582,7 +582,7 @@ impl<'a> ConstantEvaluator<'a> { pub fn for_wgsl_function( module: &'a mut crate::Module, expressions: &'a mut Arena, - local_expression_kind_tracker: &'a mut ExpressionConstnessTracker, + local_expression_kind_tracker: &'a mut ExpressionKindTracker, emitter: &'a mut super::Emitter, block: &'a mut crate::Block, ) -> Self { @@ -607,7 +607,7 @@ impl<'a> ConstantEvaluator<'a> { pub fn for_glsl_function( module: &'a mut crate::Module, expressions: &'a mut Arena, - local_expression_kind_tracker: &'a mut ExpressionConstnessTracker, + local_expression_kind_tracker: &'a mut ExpressionKindTracker, emitter: &'a mut super::Emitter, block: &'a mut crate::Block, ) -> Self { @@ -2195,7 +2195,7 @@ mod tests { UniqueArena, VectorSize, }; - use super::{Behavior, ConstantEvaluator, ExpressionConstnessTracker, WgslRestrictions}; + use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions}; #[test] fn unary_op() { @@ -2276,8 +2276,7 @@ mod tests { expr: expr1, }; - let expression_kind_tracker = - &mut ExpressionConstnessTracker::from_arena(&const_expressions); + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&const_expressions); let mut solver = ConstantEvaluator { behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, @@ -2363,8 +2362,7 @@ mod tests { convert: Some(crate::BOOL_WIDTH), }; - let expression_kind_tracker = - &mut ExpressionConstnessTracker::from_arena(&const_expressions); + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&const_expressions); let mut solver = ConstantEvaluator { behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, @@ -2482,8 +2480,7 @@ mod tests { let base = const_expressions.append(Expression::Constant(h), Default::default()); - let expression_kind_tracker = - &mut ExpressionConstnessTracker::from_arena(&const_expressions); + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&const_expressions); let mut solver = ConstantEvaluator { behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, @@ -2576,8 +2573,7 @@ mod tests { let h_expr = const_expressions.append(Expression::Constant(h), Default::default()); - let expression_kind_tracker = - &mut ExpressionConstnessTracker::from_arena(&const_expressions); + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&const_expressions); let mut solver = ConstantEvaluator { behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, @@ -2659,8 +2655,7 @@ mod tests { let h_expr = const_expressions.append(Expression::Constant(h), Default::default()); - let expression_kind_tracker = - &mut ExpressionConstnessTracker::from_arena(&const_expressions); + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&const_expressions); let mut solver = ConstantEvaluator { behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 2db956ee0e..eda732978a 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -11,7 +11,7 @@ mod terminator; mod typifier; pub use constant_evaluator::{ - ConstantEvaluator, ConstantEvaluatorError, ExpressionConstnessTracker, ExpressionKind, + ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker, }; pub use emitter::Emitter; pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError}; diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 79180a0711..4a1020cb78 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -187,7 +187,7 @@ impl super::Validator { handle: Handle, gctx: crate::proc::GlobalCtx, mod_info: &ModuleInfo, - global_expr_kind: &crate::proc::ExpressionConstnessTracker, + global_expr_kind: &crate::proc::ExpressionKindTracker, ) -> Result<(), ConstExpressionError> { use crate::Expression as E; @@ -227,7 +227,7 @@ impl super::Validator { module: &crate::Module, info: &FunctionInfo, mod_info: &ModuleInfo, - global_expr_kind: &crate::proc::ExpressionConstnessTracker, + global_expr_kind: &crate::proc::ExpressionKindTracker, ) -> Result { use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti}; diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index dfb7fbc6ee..b8ad63cc6d 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -927,7 +927,7 @@ impl super::Validator { var: &crate::LocalVariable, gctx: crate::proc::GlobalCtx, fun_info: &FunctionInfo, - local_expr_kind: &crate::proc::ExpressionConstnessTracker, + local_expr_kind: &crate::proc::ExpressionKindTracker, ) -> Result<(), LocalVariableError> { log::debug!("var {:?}", var); let type_info = self @@ -959,11 +959,11 @@ impl super::Validator { module: &crate::Module, mod_info: &ModuleInfo, entry_point: bool, - global_expr_kind: &crate::proc::ExpressionConstnessTracker, + global_expr_kind: &crate::proc::ExpressionKindTracker, ) -> Result> { let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?; - let local_expr_kind = crate::proc::ExpressionConstnessTracker::from_arena(&fun.expressions); + let local_expr_kind = crate::proc::ExpressionKindTracker::from_arena(&fun.expressions); for (var_handle, var) in fun.local_variables.iter() { self.validate_local_var(var, module.to_ctx(), &info, &local_expr_kind) diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 945af946bb..0e42075de1 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -400,7 +400,7 @@ impl super::Validator { var: &crate::GlobalVariable, gctx: crate::proc::GlobalCtx, mod_info: &ModuleInfo, - global_expr_kind: &crate::proc::ExpressionConstnessTracker, + global_expr_kind: &crate::proc::ExpressionKindTracker, ) -> Result<(), GlobalVariableError> { use super::TypeFlags; @@ -548,7 +548,7 @@ impl super::Validator { ep: &crate::EntryPoint, module: &crate::Module, mod_info: &ModuleInfo, - global_expr_kind: &crate::proc::ExpressionConstnessTracker, + global_expr_kind: &crate::proc::ExpressionKindTracker, ) -> Result> { if ep.early_depth_test.is_some() { let required = Capabilities::EARLY_DEPTH_TEST; diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index b4b2063775..72da6377d9 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -12,7 +12,7 @@ mod r#type; use crate::{ arena::Handle, - proc::{ExpressionConstnessTracker, LayoutError, Layouter, TypeResolution}, + proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution}, FastHashSet, }; use bit_set::BitSet; @@ -342,7 +342,7 @@ impl Validator { handle: Handle, gctx: crate::proc::GlobalCtx, mod_info: &ModuleInfo, - global_expr_kind: &ExpressionConstnessTracker, + global_expr_kind: &ExpressionKindTracker, ) -> Result<(), ConstantError> { let con = &gctx.constants[handle]; @@ -467,7 +467,7 @@ impl Validator { } } - let global_expr_kind = ExpressionConstnessTracker::from_arena(&module.const_expressions); + let global_expr_kind = ExpressionKindTracker::from_arena(&module.const_expressions); if self.flags.contains(ValidationFlags::CONSTANTS) { for (handle, _) in module.const_expressions.iter() { From 29abd93762b07abe9dd7dafad96a47bd5466d7b1 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Wed, 14 Feb 2024 15:25:23 +0100 Subject: [PATCH 11/30] rename `const_expressions` to `global_expressions` --- naga/src/back/glsl/mod.rs | 2 +- naga/src/back/hlsl/writer.rs | 4 +- naga/src/back/msl/writer.rs | 2 +- naga/src/back/pipeline_constants.rs | 26 +++--- naga/src/back/spv/writer.rs | 8 +- naga/src/back/wgsl/writer.rs | 2 +- naga/src/compact/expressions.rs | 14 +-- naga/src/compact/functions.rs | 4 +- naga/src/compact/mod.rs | 34 ++++---- naga/src/front/glsl/context.rs | 4 +- naga/src/front/glsl/parser/functions.rs | 2 +- naga/src/front/glsl/parser_tests.rs | 4 +- naga/src/front/glsl/types.rs | 10 ++- naga/src/front/spv/function.rs | 4 +- naga/src/front/spv/image.rs | 2 +- naga/src/front/spv/mod.rs | 22 ++--- naga/src/front/spv/null.rs | 8 +- naga/src/front/wgsl/lower/mod.rs | 8 +- naga/src/front/wgsl/to_wgsl.rs | 2 +- naga/src/lib.rs | 10 +-- naga/src/proc/constant_evaluator.rs | 108 ++++++++++++------------ naga/src/proc/mod.rs | 12 +-- naga/src/valid/analyzer.rs | 2 +- naga/src/valid/expression.rs | 4 +- naga/src/valid/handles.rs | 12 +-- naga/src/valid/mod.rs | 12 +-- naga/tests/out/ir/access.compact.ron | 2 +- naga/tests/out/ir/access.ron | 2 +- naga/tests/out/ir/collatz.compact.ron | 2 +- naga/tests/out/ir/collatz.ron | 2 +- naga/tests/out/ir/overrides.compact.ron | 2 +- naga/tests/out/ir/overrides.ron | 2 +- naga/tests/out/ir/shadow.compact.ron | 2 +- naga/tests/out/ir/shadow.ron | 2 +- 34 files changed, 170 insertions(+), 168 deletions(-) diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 736a3b57b7..13811a2df0 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2410,7 +2410,7 @@ impl<'a, W: Write> Writer<'a, W> { fn write_const_expr(&mut self, expr: Handle) -> BackendResult { self.write_possibly_const_expr( expr, - &self.module.const_expressions, + &self.module.global_expressions, |expr| &self.info[expr], |writer, expr| writer.write_const_expr(expr), ) diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 1abc6ceca0..4bde1f6486 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -243,7 +243,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_special_functions(module)?; - self.write_wrapped_compose_functions(module, &module.const_expressions)?; + self.write_wrapped_compose_functions(module, &module.global_expressions)?; // Write all named constants let mut constants = module @@ -2007,7 +2007,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_possibly_const_expression( module, expr, - &module.const_expressions, + &module.global_expressions, |writer, expr| writer.write_const_expression(module, expr), ) } diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 3c2a741cd4..7797bc658f 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -1248,7 +1248,7 @@ impl Writer { ) -> BackendResult { self.put_possibly_const_expression( expr_handle, - &module.const_expressions, + &module.global_expressions, module, mod_info, &(module, mod_info), diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 79c44f5e9f..a301b4ff3d 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -33,14 +33,14 @@ pub(super) fn process_overrides<'a>( let mut module = module.clone(); let mut override_map = Vec::with_capacity(module.overrides.len()); - let mut adjusted_const_expressions = Vec::with_capacity(module.const_expressions.len()); + let mut adjusted_global_expressions = Vec::with_capacity(module.global_expressions.len()); let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len()); let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new(); let mut override_iter = module.overrides.drain(); - for (old_h, expr, span) in module.const_expressions.drain() { + for (old_h, expr, span) in module.global_expressions.drain() { let mut expr = match expr { Expression::Override(h) => { let c_h = if let Some(new_h) = override_map.get(h.index()) { @@ -54,7 +54,7 @@ pub(super) fn process_overrides<'a>( pipeline_constants, &mut module, &mut override_map, - &adjusted_const_expressions, + &adjusted_global_expressions, &mut adjusted_constant_initializers, &mut global_expression_kind_tracker, )?); @@ -68,7 +68,7 @@ pub(super) fn process_overrides<'a>( } Expression::Constant(c_h) => { adjusted_constant_initializers.insert(c_h); - module.constants[c_h].init = adjusted_const_expressions[c_h.index()]; + module.constants[c_h].init = adjusted_global_expressions[c_h.index()]; expr } expr => expr, @@ -78,10 +78,10 @@ pub(super) fn process_overrides<'a>( &mut global_expression_kind_tracker, false, ); - adjust_expr(&adjusted_const_expressions, &mut expr); + adjust_expr(&adjusted_global_expressions, &mut expr); let h = evaluator.try_eval_and_append(expr, span)?; - debug_assert_eq!(old_h.index(), adjusted_const_expressions.len()); - adjusted_const_expressions.push(h); + debug_assert_eq!(old_h.index(), adjusted_global_expressions.len()); + adjusted_global_expressions.push(h); } for entry in override_iter { @@ -90,7 +90,7 @@ pub(super) fn process_overrides<'a>( pipeline_constants, &mut module, &mut override_map, - &adjusted_const_expressions, + &adjusted_global_expressions, &mut adjusted_constant_initializers, &mut global_expression_kind_tracker, )?; @@ -101,12 +101,12 @@ pub(super) fn process_overrides<'a>( .iter_mut() .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h)) { - c.init = adjusted_const_expressions[c.init.index()]; + c.init = adjusted_global_expressions[c.init.index()]; } for (_, v) in module.global_variables.iter_mut() { if let Some(ref mut init) = v.init { - *init = adjusted_const_expressions[init.index()]; + *init = adjusted_global_expressions[init.index()]; } } @@ -121,7 +121,7 @@ fn process_override( pipeline_constants: &PipelineConstants, module: &mut Module, override_map: &mut Vec>, - adjusted_const_expressions: &[Handle], + adjusted_global_expressions: &[Handle], adjusted_constant_initializers: &mut HashSet>, global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker, ) -> Result, PipelineConstantError> { @@ -138,12 +138,12 @@ fn process_override( _ => unreachable!(), }; let expr = module - .const_expressions + .global_expressions .append(Expression::Literal(literal), Span::UNDEFINED); global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const); expr } else if let Some(init) = override_.init { - adjusted_const_expressions[init.index()] + adjusted_global_expressions[init.index()] } else { return Err(PipelineConstantError::MissingValue(key.to_string())); }; diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 868fad7fa2..0fc0227fb7 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -1258,7 +1258,7 @@ impl Writer { ir_module: &crate::Module, mod_info: &ModuleInfo, ) -> Result { - let id = match ir_module.const_expressions[handle] { + let id = match ir_module.global_expressions[handle] { crate::Expression::Literal(literal) => self.get_constant_scalar(literal), crate::Expression::Constant(constant) => { let constant = &ir_module.constants[constant]; @@ -1272,7 +1272,7 @@ impl Writer { let component_ids: Vec<_> = crate::proc::flatten_compose( ty, components, - &ir_module.const_expressions, + &ir_module.global_expressions, &ir_module.types, ) .map(|component| self.constant_ids[component.index()]) @@ -1914,8 +1914,8 @@ impl Writer { // write all const-expressions as constants self.constant_ids - .resize(ir_module.const_expressions.len(), 0); - for (handle, _) in ir_module.const_expressions.iter() { + .resize(ir_module.global_expressions.len(), 0); + for (handle, _) in ir_module.global_expressions.iter() { self.write_constant_expr(handle, ir_module, mod_info)?; } debug_assert!(self.constant_ids.iter().all(|&id| id != 0)); diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 7ca689f482..8005a27617 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -1076,7 +1076,7 @@ impl Writer { self.write_possibly_const_expression( module, expr, - &module.const_expressions, + &module.global_expressions, |writer, expr| writer.write_const_expression(module, expr), ) } diff --git a/naga/src/compact/expressions.rs b/naga/src/compact/expressions.rs index 21c4c9cdc2..0f2d8b1a02 100644 --- a/naga/src/compact/expressions.rs +++ b/naga/src/compact/expressions.rs @@ -21,11 +21,11 @@ pub struct ExpressionTracer<'tracer> { /// the module's constant expression arena. pub expressions_used: &'tracer mut HandleSet, - /// The used set for the module's `const_expressions` arena. + /// The used set for the module's `global_expressions` arena. /// /// If `None`, we are already tracing the constant expressions, /// and `expressions_used` already refers to their handle set. - pub const_expressions_used: Option<&'tracer mut HandleSet>, + pub global_expressions_used: Option<&'tracer mut HandleSet>, } impl<'tracer> ExpressionTracer<'tracer> { @@ -40,11 +40,11 @@ impl<'tracer> ExpressionTracer<'tracer> { /// marked. /// /// [fe]: crate::Function::expressions - /// [ce]: crate::Module::const_expressions + /// [ce]: crate::Module::global_expressions pub fn trace_expressions(&mut self) { log::trace!( "entering trace_expression of {}", - if self.const_expressions_used.is_some() { + if self.global_expressions_used.is_some() { "function expressions" } else { "const expressions" @@ -84,7 +84,7 @@ impl<'tracer> ExpressionTracer<'tracer> { // and the constant refers to the initializer, it must // precede `expr` in the arena. let init = self.constants[handle].init; - match self.const_expressions_used { + match self.global_expressions_used { Some(ref mut used) => used.insert(init), None => self.expressions_used.insert(init), } @@ -122,7 +122,7 @@ impl<'tracer> ExpressionTracer<'tracer> { self.expressions_used .insert_iter([image, sampler, coordinate]); self.expressions_used.insert_iter(array_index); - match self.const_expressions_used { + match self.global_expressions_used { Some(ref mut used) => used.insert_iter(offset), None => self.expressions_used.insert_iter(offset), } @@ -276,7 +276,7 @@ impl ModuleMap { adjust(coordinate); operand_map.adjust_option(array_index); if let Some(ref mut offset) = *offset { - self.const_expressions.adjust(offset); + self.global_expressions.adjust(offset); } self.adjust_sample_level(level, operand_map); operand_map.adjust_option(depth_ref); diff --git a/naga/src/compact/functions.rs b/naga/src/compact/functions.rs index 98a23acee0..4ac2223eb7 100644 --- a/naga/src/compact/functions.rs +++ b/naga/src/compact/functions.rs @@ -8,7 +8,7 @@ pub struct FunctionTracer<'a> { pub types_used: &'a mut HandleSet, pub constants_used: &'a mut HandleSet, - pub const_expressions_used: &'a mut HandleSet, + pub global_expressions_used: &'a mut HandleSet, /// Function-local expressions used. pub expressions_used: HandleSet, @@ -54,7 +54,7 @@ impl<'a> FunctionTracer<'a> { types_used: self.types_used, constants_used: self.constants_used, expressions_used: &mut self.expressions_used, - const_expressions_used: Some(&mut self.const_expressions_used), + global_expressions_used: Some(&mut self.global_expressions_used), } } } diff --git a/naga/src/compact/mod.rs b/naga/src/compact/mod.rs index 2b49d34995..0d7a37b579 100644 --- a/naga/src/compact/mod.rs +++ b/naga/src/compact/mod.rs @@ -38,7 +38,7 @@ pub fn compact(module: &mut crate::Module) { log::trace!("tracing global {:?}", global.name); module_tracer.types_used.insert(global.ty); if let Some(init) = global.init { - module_tracer.const_expressions_used.insert(init); + module_tracer.global_expressions_used.insert(init); } } } @@ -50,7 +50,7 @@ pub fn compact(module: &mut crate::Module) { for (handle, constant) in module.constants.iter() { if constant.name.is_some() { module_tracer.constants_used.insert(handle); - module_tracer.const_expressions_used.insert(constant.init); + module_tracer.global_expressions_used.insert(constant.init); } } @@ -58,7 +58,7 @@ pub fn compact(module: &mut crate::Module) { for (_, override_) in module.overrides.iter() { module_tracer.types_used.insert(override_.ty); if let Some(init) = override_.init { - module_tracer.const_expressions_used.insert(init); + module_tracer.global_expressions_used.insert(init); } } @@ -145,9 +145,9 @@ pub fn compact(module: &mut crate::Module) { // Drop unused constant expressions, reusing existing storage. log::trace!("adjusting constant expressions"); - module.const_expressions.retain_mut(|handle, expr| { - if module_map.const_expressions.used(handle) { - module_map.adjust_expression(expr, &module_map.const_expressions); + module.global_expressions.retain_mut(|handle, expr| { + if module_map.global_expressions.used(handle) { + module_map.adjust_expression(expr, &module_map.global_expressions); true } else { false @@ -159,7 +159,7 @@ pub fn compact(module: &mut crate::Module) { module.constants.retain_mut(|handle, constant| { if module_map.constants.used(handle) { module_map.types.adjust(&mut constant.ty); - module_map.const_expressions.adjust(&mut constant.init); + module_map.global_expressions.adjust(&mut constant.init); true } else { false @@ -171,7 +171,7 @@ pub fn compact(module: &mut crate::Module) { for (_, override_) in module.overrides.iter_mut() { module_map.types.adjust(&mut override_.ty); if let Some(init) = override_.init.as_mut() { - module_map.const_expressions.adjust(init); + module_map.global_expressions.adjust(init); } } @@ -181,7 +181,7 @@ pub fn compact(module: &mut crate::Module) { log::trace!("adjusting global {:?}", global.name); module_map.types.adjust(&mut global.ty); if let Some(ref mut init) = global.init { - module_map.const_expressions.adjust(init); + module_map.global_expressions.adjust(init); } } @@ -210,7 +210,7 @@ struct ModuleTracer<'module> { module: &'module crate::Module, types_used: HandleSet, constants_used: HandleSet, - const_expressions_used: HandleSet, + global_expressions_used: HandleSet, } impl<'module> ModuleTracer<'module> { @@ -219,7 +219,7 @@ impl<'module> ModuleTracer<'module> { module, types_used: HandleSet::for_arena(&module.types), constants_used: HandleSet::for_arena(&module.constants), - const_expressions_used: HandleSet::for_arena(&module.const_expressions), + global_expressions_used: HandleSet::for_arena(&module.global_expressions), } } @@ -250,13 +250,13 @@ impl<'module> ModuleTracer<'module> { fn as_const_expression(&mut self) -> expressions::ExpressionTracer { expressions::ExpressionTracer { - expressions: &self.module.const_expressions, + expressions: &self.module.global_expressions, constants: &self.module.constants, overrides: &self.module.overrides, types_used: &mut self.types_used, constants_used: &mut self.constants_used, - expressions_used: &mut self.const_expressions_used, - const_expressions_used: None, + expressions_used: &mut self.global_expressions_used, + global_expressions_used: None, } } @@ -270,7 +270,7 @@ impl<'module> ModuleTracer<'module> { overrides: &self.module.overrides, types_used: &mut self.types_used, constants_used: &mut self.constants_used, - const_expressions_used: &mut self.const_expressions_used, + global_expressions_used: &mut self.global_expressions_used, expressions_used: HandleSet::for_arena(&function.expressions), } } @@ -279,7 +279,7 @@ impl<'module> ModuleTracer<'module> { struct ModuleMap { types: HandleMap, constants: HandleMap, - const_expressions: HandleMap, + global_expressions: HandleMap, } impl From> for ModuleMap { @@ -287,7 +287,7 @@ impl From> for ModuleMap { ModuleMap { types: HandleMap::from_set(used.types_used), constants: HandleMap::from_set(used.constants_used), - const_expressions: HandleMap::from_set(used.const_expressions_used), + global_expressions: HandleMap::from_set(used.global_expressions_used), } } } diff --git a/naga/src/front/glsl/context.rs b/naga/src/front/glsl/context.rs index ec844597d6..6ba7df593a 100644 --- a/naga/src/front/glsl/context.rs +++ b/naga/src/front/glsl/context.rs @@ -79,7 +79,7 @@ pub struct Context<'a> { pub is_const: bool, /// Tracks the expression kind of `Expression`s residing in `self.expressions` pub local_expression_kind_tracker: crate::proc::ExpressionKindTracker, - /// Tracks the expression kind of `Expression`s residing in `self.module.const_expressions` + /// Tracks the expression kind of `Expression`s residing in `self.module.global_expressions` pub global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker, } @@ -1471,7 +1471,7 @@ impl Index> for Context<'_> { fn index(&self, index: Handle) -> &Self::Output { if self.is_const { - &self.module.const_expressions[index] + &self.module.global_expressions[index] } else { &self.expressions[index] } diff --git a/naga/src/front/glsl/parser/functions.rs b/naga/src/front/glsl/parser/functions.rs index 6d3b9d7ba4..d0c889e4d3 100644 --- a/naga/src/front/glsl/parser/functions.rs +++ b/naga/src/front/glsl/parser/functions.rs @@ -198,7 +198,7 @@ impl<'source> ParsingContext<'source> { ctx.global_expression_kind_tracker, )?; - match ctx.module.const_expressions[const_expr] { + match ctx.module.global_expressions[const_expr] { Expression::Literal(Literal::I32(value)) => match uint { // This unchecked cast isn't good, but since // we only reach this code when the selector diff --git a/naga/src/front/glsl/parser_tests.rs b/naga/src/front/glsl/parser_tests.rs index e6e2b2c853..c065dc15d6 100644 --- a/naga/src/front/glsl/parser_tests.rs +++ b/naga/src/front/glsl/parser_tests.rs @@ -539,7 +539,7 @@ fn constants() { let mut types = module.types.iter(); let mut constants = module.constants.iter(); - let mut const_expressions = module.const_expressions.iter(); + let mut global_expressions = module.global_expressions.iter(); let (ty_handle, ty) = types.next().unwrap(); assert_eq!( @@ -550,7 +550,7 @@ fn constants() { } ); - let (init_handle, init) = const_expressions.next().unwrap(); + let (init_handle, init) = global_expressions.next().unwrap(); assert_eq!(init, &Expression::Literal(crate::Literal::F32(1.0))); assert_eq!( diff --git a/naga/src/front/glsl/types.rs b/naga/src/front/glsl/types.rs index 8a04b23839..f6836169c0 100644 --- a/naga/src/front/glsl/types.rs +++ b/naga/src/front/glsl/types.rs @@ -233,7 +233,7 @@ impl Context<'_> { }; let expressions = if self.is_const { - &self.module.const_expressions + &self.module.global_expressions } else { &self.expressions }; @@ -333,20 +333,22 @@ impl Context<'_> { let h = match self.expressions[expr] { ref expr @ (Expression::Literal(_) | Expression::Constant(_) - | Expression::ZeroValue(_)) => self.module.const_expressions.append(expr.clone(), meta), + | Expression::ZeroValue(_)) => { + self.module.global_expressions.append(expr.clone(), meta) + } Expression::Compose { ty, ref components } => { let mut components = components.clone(); for component in &mut components { *component = self.lift_up_const_expression(*component)?; } self.module - .const_expressions + .global_expressions .append(Expression::Compose { ty, components }, meta) } Expression::Splat { size, value } => { let value = self.lift_up_const_expression(value)?; self.module - .const_expressions + .global_expressions .append(Expression::Splat { size, value }, meta) } _ => { diff --git a/naga/src/front/spv/function.rs b/naga/src/front/spv/function.rs index 7fefef02a2..5f8dd09608 100644 --- a/naga/src/front/spv/function.rs +++ b/naga/src/front/spv/function.rs @@ -129,7 +129,7 @@ impl> super::Frontend { local_arena: &mut fun.local_variables, const_arena: &mut module.constants, overrides: &mut module.overrides, - const_expressions: &mut module.const_expressions, + global_expressions: &mut module.global_expressions, type_arena: &module.types, global_arena: &module.global_variables, arguments: &fun.arguments, @@ -583,7 +583,7 @@ impl<'function> BlockContext<'function> { types: self.type_arena, constants: self.const_arena, overrides: self.overrides, - const_expressions: self.const_expressions, + global_expressions: self.global_expressions, } } diff --git a/naga/src/front/spv/image.rs b/naga/src/front/spv/image.rs index 0f25dd626b..21fff3f4af 100644 --- a/naga/src/front/spv/image.rs +++ b/naga/src/front/spv/image.rs @@ -508,7 +508,7 @@ impl> super::Frontend { spirv::ImageOperands::CONST_OFFSET => { let offset_constant = self.next()?; let offset_handle = self.lookup_constant.lookup(offset_constant)?.handle; - let offset_handle = ctx.const_expressions.append( + let offset_handle = ctx.global_expressions.append( crate::Expression::Constant(offset_handle), Default::default(), ); diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 697cbb7b4e..24053cf26b 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -532,7 +532,7 @@ struct BlockContext<'function> { /// Constants arena of the module being processed const_arena: &'function mut Arena, overrides: &'function mut Arena, - const_expressions: &'function mut Arena, + global_expressions: &'function mut Arena, /// Type arena of the module being processed type_arena: &'function UniqueArena, /// Global arena of the module being processed @@ -4916,7 +4916,7 @@ impl> Frontend { let span = self.span_from_with_op(start); let init = module - .const_expressions + .global_expressions .append(crate::Expression::Literal(literal), span); self.lookup_constant.insert( id, @@ -4956,7 +4956,7 @@ impl> Frontend { let span = self.span_from_with_op(start); let constant = self.lookup_constant.lookup(component_id)?; let expr = module - .const_expressions + .global_expressions .append(crate::Expression::Constant(constant.handle), span); components.push(expr); } @@ -4966,7 +4966,7 @@ impl> Frontend { let span = self.span_from_with_op(start); let init = module - .const_expressions + .global_expressions .append(crate::Expression::Compose { ty, components }, span); self.lookup_constant.insert( id, @@ -5003,7 +5003,7 @@ impl> Frontend { let decor = self.future_decor.remove(&id).unwrap_or_default(); let init = module - .const_expressions + .global_expressions .append(crate::Expression::ZeroValue(ty), span); let handle = module.constants.append( crate::Constant { @@ -5036,7 +5036,7 @@ impl> Frontend { let decor = self.future_decor.remove(&id).unwrap_or_default(); - let init = module.const_expressions.append( + let init = module.global_expressions.append( crate::Expression::Literal(crate::Literal::Bool(value)), span, ); @@ -5075,7 +5075,7 @@ impl> Frontend { let span = self.span_from_with_op(start); let lconst = self.lookup_constant.lookup(init_id)?; let expr = module - .const_expressions + .global_expressions .append(crate::Expression::Constant(lconst.handle), span); Some(expr) } else { @@ -5197,7 +5197,7 @@ impl> Frontend { match null::generate_default_built_in( Some(built_in), ty, - &mut module.const_expressions, + &mut module.global_expressions, span, ) { Ok(handle) => Some(handle), @@ -5219,14 +5219,14 @@ impl> Frontend { let handle = null::generate_default_built_in( built_in, member.ty, - &mut module.const_expressions, + &mut module.global_expressions, span, )?; components.push(handle); } Some( module - .const_expressions + .global_expressions .append(crate::Expression::Compose { ty, components }, span), ) } @@ -5295,7 +5295,7 @@ fn resolve_constant( gctx: crate::proc::GlobalCtx, constant: Handle, ) -> Option { - match gctx.const_expressions[gctx.constants[constant].init] { + match gctx.global_expressions[gctx.constants[constant].init] { crate::Expression::Literal(crate::Literal::U32(id)) => Some(id), crate::Expression::Literal(crate::Literal::I32(id)) => Some(id as u32), _ => None, diff --git a/naga/src/front/spv/null.rs b/naga/src/front/spv/null.rs index 42cccca80a..c7d3776841 100644 --- a/naga/src/front/spv/null.rs +++ b/naga/src/front/spv/null.rs @@ -5,14 +5,14 @@ use crate::arena::{Arena, Handle}; pub fn generate_default_built_in( built_in: Option, ty: Handle, - const_expressions: &mut Arena, + global_expressions: &mut Arena, span: crate::Span, ) -> Result, Error> { let expr = match built_in { Some(crate::BuiltIn::Position { .. }) => { - let zero = const_expressions + let zero = global_expressions .append(crate::Expression::Literal(crate::Literal::F32(0.0)), span); - let one = const_expressions + let one = global_expressions .append(crate::Expression::Literal(crate::Literal::F32(1.0)), span); crate::Expression::Compose { ty, @@ -27,5 +27,5 @@ pub fn generate_default_built_in( // Note: `crate::BuiltIn::ClipDistance` is intentionally left for the default path _ => crate::Expression::ZeroValue(ty), }; - Ok(const_expressions.append(expr, span)) + Ok(global_expressions.append(expr, span)) } diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index e689dda53a..1a8b75811b 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -333,9 +333,9 @@ pub struct ExpressionContext<'source, 'temp, 'out> { /// [`Module`]: crate::Module module: &'out mut crate::Module, - /// Type judgments for [`module::const_expressions`]. + /// Type judgments for [`module::global_expressions`]. /// - /// [`module::const_expressions`]: crate::Module::const_expressions + /// [`module::global_expressions`]: crate::Module::global_expressions const_typifier: &'temp mut Typifier, global_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker, @@ -421,7 +421,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { match self.expr_type { ExpressionContextType::Runtime(ref ctx) => ctx.function.expressions.get_span(handle), ExpressionContextType::Constant | ExpressionContextType::Override => { - self.module.const_expressions.get_span(handle) + self.module.global_expressions.get_span(handle) } } } @@ -554,7 +554,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { ExpressionContextType::Constant | ExpressionContextType::Override => { resolve_ctx = ResolveContext::with_locals(self.module, &empty_arena, &[]); typifier = self.const_typifier; - expressions = &self.module.const_expressions; + expressions = &self.module.global_expressions; } }; typifier diff --git a/naga/src/front/wgsl/to_wgsl.rs b/naga/src/front/wgsl/to_wgsl.rs index ba6063ab46..63bc9f7317 100644 --- a/naga/src/front/wgsl/to_wgsl.rs +++ b/naga/src/front/wgsl/to_wgsl.rs @@ -227,7 +227,7 @@ mod tests { types: &types, constants: &crate::Arena::new(), overrides: &crate::Arena::new(), - const_expressions: &crate::Arena::new(), + global_expressions: &crate::Arena::new(), }; let array = crate::TypeInner::Array { base: mytype1, diff --git a/naga/src/lib.rs b/naga/src/lib.rs index 671fcc97c6..4b421b08fd 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -897,7 +897,7 @@ pub struct Override { /// The default value of the pipeline-overridable constant. /// - /// This [`Handle`] refers to [`Module::const_expressions`], not + /// This [`Handle`] refers to [`Module::global_expressions`], not /// any [`Function::expressions`] arena. pub init: Option>, } @@ -913,7 +913,7 @@ pub struct Constant { /// The value of the constant. /// - /// This [`Handle`] refers to [`Module::const_expressions`], not + /// This [`Handle`] refers to [`Module::global_expressions`], not /// any [`Function::expressions`] arena. pub init: Handle, } @@ -980,7 +980,7 @@ pub struct GlobalVariable { pub ty: Handle, /// Initial value for this variable. /// - /// Expression handle lives in const_expressions + /// Expression handle lives in global_expressions pub init: Option>, } @@ -1430,7 +1430,7 @@ pub enum Expression { gather: Option, coordinate: Handle, array_index: Option>, - /// Expression handle lives in const_expressions + /// Expression handle lives in global_expressions offset: Option>, level: SampleLevel, depth_ref: Option>, @@ -2065,7 +2065,7 @@ pub struct Module { /// /// [Constant expressions]: index.html#constant-expressions /// [override expressions]: index.html#override-expressions - pub const_expressions: Arena, + pub global_expressions: Arena, /// Arena for the functions defined in this module. /// /// Each function must appear in this arena strictly before all its callers. diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 6f09ec5444..532f364532 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -327,7 +327,7 @@ enum GlslRestrictions<'a> { #[derive(Debug)] struct FunctionLocalData<'a> { /// Global constant expressions - const_expressions: &'a Arena, + global_expressions: &'a Arena, emitter: &'a mut super::Emitter, block: &'a mut crate::Block, } @@ -570,7 +570,7 @@ impl<'a> ConstantEvaluator<'a> { types: &mut module.types, constants: &module.constants, overrides: &module.overrides, - expressions: &mut module.const_expressions, + expressions: &mut module.global_expressions, expression_kind_tracker: global_expression_kind_tracker, } } @@ -588,7 +588,7 @@ impl<'a> ConstantEvaluator<'a> { ) -> Self { Self { behavior: Behavior::Wgsl(WgslRestrictions::Runtime(FunctionLocalData { - const_expressions: &module.const_expressions, + global_expressions: &module.global_expressions, emitter, block, })), @@ -613,7 +613,7 @@ impl<'a> ConstantEvaluator<'a> { ) -> Self { Self { behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData { - const_expressions: &module.const_expressions, + global_expressions: &module.global_expressions, emitter, block, })), @@ -630,8 +630,8 @@ impl<'a> ConstantEvaluator<'a> { types: self.types, constants: self.constants, overrides: self.overrides, - const_expressions: match self.function_local_data() { - Some(data) => data.const_expressions, + global_expressions: match self.function_local_data() { + Some(data) => data.global_expressions, None => self.expressions, }, } @@ -657,7 +657,7 @@ impl<'a> ConstantEvaluator<'a> { // Deep-copy the constant's value into our arena. self.copy_from( self.constants[c].init, - function_local_data.const_expressions, + function_local_data.global_expressions, ) } else { // "See through" the constant and use its initializer. @@ -2202,7 +2202,7 @@ mod tests { let mut types = UniqueArena::new(); let mut constants = Arena::new(); let overrides = Arena::new(); - let mut const_expressions = Arena::new(); + let mut global_expressions = Arena::new(); let scalar_ty = types.insert( Type { @@ -2227,7 +2227,7 @@ mod tests { Constant { name: None, ty: scalar_ty, - init: const_expressions + init: global_expressions .append(Expression::Literal(Literal::I32(4)), Default::default()), }, Default::default(), @@ -2237,7 +2237,7 @@ mod tests { Constant { name: None, ty: scalar_ty, - init: const_expressions + init: global_expressions .append(Expression::Literal(Literal::I32(8)), Default::default()), }, Default::default(), @@ -2247,7 +2247,7 @@ mod tests { Constant { name: None, ty: vec_ty, - init: const_expressions.append( + init: global_expressions.append( Expression::Compose { ty: vec_ty, components: vec![constants[h].init, constants[h1].init], @@ -2258,8 +2258,8 @@ mod tests { Default::default(), ); - let expr = const_expressions.append(Expression::Constant(h), Default::default()); - let expr1 = const_expressions.append(Expression::Constant(vec_h), Default::default()); + let expr = global_expressions.append(Expression::Constant(h), Default::default()); + let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default()); let expr2 = Expression::Unary { op: UnaryOperator::Negate, @@ -2276,13 +2276,13 @@ mod tests { expr: expr1, }; - let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&const_expressions); + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, overrides: &overrides, - expressions: &mut const_expressions, + expressions: &mut global_expressions, expression_kind_tracker, }; @@ -2297,16 +2297,16 @@ mod tests { .unwrap(); assert_eq!( - const_expressions[res1], + global_expressions[res1], Expression::Literal(Literal::I32(-4)) ); assert_eq!( - const_expressions[res2], + global_expressions[res2], Expression::Literal(Literal::I32(!4)) ); - let res3_inner = &const_expressions[res3]; + let res3_inner = &global_expressions[res3]; match *res3_inner { Expression::Compose { @@ -2316,11 +2316,11 @@ mod tests { assert_eq!(*ty, vec_ty); let mut components_iter = components.iter().copied(); assert_eq!( - const_expressions[components_iter.next().unwrap()], + global_expressions[components_iter.next().unwrap()], Expression::Literal(Literal::I32(!4)) ); assert_eq!( - const_expressions[components_iter.next().unwrap()], + global_expressions[components_iter.next().unwrap()], Expression::Literal(Literal::I32(!8)) ); assert!(components_iter.next().is_none()); @@ -2334,7 +2334,7 @@ mod tests { let mut types = UniqueArena::new(); let mut constants = Arena::new(); let overrides = Arena::new(); - let mut const_expressions = Arena::new(); + let mut global_expressions = Arena::new(); let scalar_ty = types.insert( Type { @@ -2348,13 +2348,13 @@ mod tests { Constant { name: None, ty: scalar_ty, - init: const_expressions + init: global_expressions .append(Expression::Literal(Literal::I32(4)), Default::default()), }, Default::default(), ); - let expr = const_expressions.append(Expression::Constant(h), Default::default()); + let expr = global_expressions.append(Expression::Constant(h), Default::default()); let root = Expression::As { expr, @@ -2362,13 +2362,13 @@ mod tests { convert: Some(crate::BOOL_WIDTH), }; - let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&const_expressions); + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, overrides: &overrides, - expressions: &mut const_expressions, + expressions: &mut global_expressions, expression_kind_tracker, }; @@ -2377,7 +2377,7 @@ mod tests { .unwrap(); assert_eq!( - const_expressions[res], + global_expressions[res], Expression::Literal(Literal::Bool(true)) ); } @@ -2387,7 +2387,7 @@ mod tests { let mut types = UniqueArena::new(); let mut constants = Arena::new(); let overrides = Arena::new(); - let mut const_expressions = Arena::new(); + let mut global_expressions = Arena::new(); let matrix_ty = types.insert( Type { @@ -2416,7 +2416,7 @@ mod tests { let mut vec2_components = Vec::with_capacity(3); for i in 0..3 { - let h = const_expressions.append( + let h = global_expressions.append( Expression::Literal(Literal::F32(i as f32)), Default::default(), ); @@ -2425,7 +2425,7 @@ mod tests { } for i in 3..6 { - let h = const_expressions.append( + let h = global_expressions.append( Expression::Literal(Literal::F32(i as f32)), Default::default(), ); @@ -2437,7 +2437,7 @@ mod tests { Constant { name: None, ty: vec_ty, - init: const_expressions.append( + init: global_expressions.append( Expression::Compose { ty: vec_ty, components: vec1_components, @@ -2452,7 +2452,7 @@ mod tests { Constant { name: None, ty: vec_ty, - init: const_expressions.append( + init: global_expressions.append( Expression::Compose { ty: vec_ty, components: vec2_components, @@ -2467,7 +2467,7 @@ mod tests { Constant { name: None, ty: matrix_ty, - init: const_expressions.append( + init: global_expressions.append( Expression::Compose { ty: matrix_ty, components: vec![constants[vec1].init, constants[vec2].init], @@ -2478,15 +2478,15 @@ mod tests { Default::default(), ); - let base = const_expressions.append(Expression::Constant(h), Default::default()); + let base = global_expressions.append(Expression::Constant(h), Default::default()); - let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&const_expressions); + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, overrides: &overrides, - expressions: &mut const_expressions, + expressions: &mut global_expressions, expression_kind_tracker, }; @@ -2505,7 +2505,7 @@ mod tests { .try_eval_and_append(root2, Default::default()) .unwrap(); - match const_expressions[res1] { + match global_expressions[res1] { Expression::Compose { ref ty, ref components, @@ -2513,15 +2513,15 @@ mod tests { assert_eq!(*ty, vec_ty); let mut components_iter = components.iter().copied(); assert_eq!( - const_expressions[components_iter.next().unwrap()], + global_expressions[components_iter.next().unwrap()], Expression::Literal(Literal::F32(3.)) ); assert_eq!( - const_expressions[components_iter.next().unwrap()], + global_expressions[components_iter.next().unwrap()], Expression::Literal(Literal::F32(4.)) ); assert_eq!( - const_expressions[components_iter.next().unwrap()], + global_expressions[components_iter.next().unwrap()], Expression::Literal(Literal::F32(5.)) ); assert!(components_iter.next().is_none()); @@ -2530,7 +2530,7 @@ mod tests { } assert_eq!( - const_expressions[res2], + global_expressions[res2], Expression::Literal(Literal::F32(5.)) ); } @@ -2540,7 +2540,7 @@ mod tests { let mut types = UniqueArena::new(); let mut constants = Arena::new(); let overrides = Arena::new(); - let mut const_expressions = Arena::new(); + let mut global_expressions = Arena::new(); let i32_ty = types.insert( Type { @@ -2565,21 +2565,21 @@ mod tests { Constant { name: None, ty: i32_ty, - init: const_expressions + init: global_expressions .append(Expression::Literal(Literal::I32(4)), Default::default()), }, Default::default(), ); - let h_expr = const_expressions.append(Expression::Constant(h), Default::default()); + let h_expr = global_expressions.append(Expression::Constant(h), Default::default()); - let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&const_expressions); + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, overrides: &overrides, - expressions: &mut const_expressions, + expressions: &mut global_expressions, expression_kind_tracker, }; @@ -2602,11 +2602,11 @@ mod tests { ) .unwrap(); - let pass = match const_expressions[solved_negate] { + let pass = match global_expressions[solved_negate] { Expression::Compose { ty, ref components } => { ty == vec2_i32_ty && components.iter().all(|&component| { - let component = &const_expressions[component]; + let component = &global_expressions[component]; matches!(*component, Expression::Literal(Literal::I32(-4))) }) } @@ -2622,7 +2622,7 @@ mod tests { let mut types = UniqueArena::new(); let mut constants = Arena::new(); let overrides = Arena::new(); - let mut const_expressions = Arena::new(); + let mut global_expressions = Arena::new(); let i32_ty = types.insert( Type { @@ -2647,21 +2647,21 @@ mod tests { Constant { name: None, ty: i32_ty, - init: const_expressions + init: global_expressions .append(Expression::Literal(Literal::I32(4)), Default::default()), }, Default::default(), ); - let h_expr = const_expressions.append(Expression::Constant(h), Default::default()); + let h_expr = global_expressions.append(Expression::Constant(h), Default::default()); - let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&const_expressions); + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, overrides: &overrides, - expressions: &mut const_expressions, + expressions: &mut global_expressions, expression_kind_tracker, }; @@ -2684,11 +2684,11 @@ mod tests { ) .unwrap(); - let pass = match const_expressions[solved_negate] { + let pass = match global_expressions[solved_negate] { Expression::Compose { ty, ref components } => { ty == vec2_i32_ty && components.iter().all(|&component| { - let component = &const_expressions[component]; + let component = &global_expressions[component]; matches!(*component, Expression::Literal(Literal::I32(-4))) }) } diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index eda732978a..0e89f29032 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -649,7 +649,7 @@ impl crate::Module { types: &self.types, constants: &self.constants, overrides: &self.overrides, - const_expressions: &self.const_expressions, + global_expressions: &self.global_expressions, } } } @@ -665,17 +665,17 @@ pub struct GlobalCtx<'a> { pub types: &'a crate::UniqueArena, pub constants: &'a crate::Arena, pub overrides: &'a crate::Arena, - pub const_expressions: &'a crate::Arena, + pub global_expressions: &'a crate::Arena, } impl GlobalCtx<'_> { - /// Try to evaluate the expression in `self.const_expressions` using its `handle` and return it as a `u32`. + /// Try to evaluate the expression in `self.global_expressions` using its `handle` and return it as a `u32`. #[allow(dead_code)] pub(super) fn eval_expr_to_u32( &self, handle: crate::Handle, ) -> Result { - self.eval_expr_to_u32_from(handle, self.const_expressions) + self.eval_expr_to_u32_from(handle, self.global_expressions) } /// Try to evaluate the expression in the `arena` using its `handle` and return it as a `u32`. @@ -698,7 +698,7 @@ impl GlobalCtx<'_> { &self, handle: crate::Handle, ) -> Option { - self.eval_expr_to_literal_from(handle, self.const_expressions) + self.eval_expr_to_literal_from(handle, self.global_expressions) } fn eval_expr_to_literal_from( @@ -722,7 +722,7 @@ impl GlobalCtx<'_> { } match arena[handle] { crate::Expression::Constant(c) => { - get(*self, self.constants[c].init, self.const_expressions) + get(*self, self.constants[c].init, self.global_expressions) } _ => get(*self, handle, arena), } diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index fbb4461e38..d45c25c62e 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -1047,7 +1047,7 @@ impl ModuleInfo { gctx: crate::proc::GlobalCtx, ) -> Result<(), super::ConstExpressionError> { self.const_expression_types[handle.index()] = - resolve_context.resolve(&gctx.const_expressions[handle], |h| Ok(&self[h]))?; + resolve_context.resolve(&gctx.global_expressions[handle], |h| Ok(&self[h]))?; Ok(()) } diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 4a1020cb78..289eb02011 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -195,7 +195,7 @@ impl super::Validator { return Err(super::ConstExpressionError::NonConstOrOverride); } - match gctx.const_expressions[handle] { + match gctx.global_expressions[handle] { E::Literal(literal) => { self.validate_literal(literal)?; } @@ -1729,7 +1729,7 @@ fn validate_with_const_expression( use crate::span::Span; let mut module = crate::Module::default(); - module.const_expressions.append(expr, Span::default()); + module.global_expressions.append(expr, Span::default()); let mut validator = super::Validator::new(super::ValidationFlags::CONSTANTS, caps); diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index bcda98b294..5d3087a28f 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -37,7 +37,7 @@ impl super::Validator { ref global_variables, ref types, ref special_types, - ref const_expressions, + ref global_expressions, } = module; // NOTE: Types being first is important. All other forms of validation depend on this. @@ -68,13 +68,13 @@ impl super::Validator { } } - for handle_and_expr in const_expressions.iter() { + for handle_and_expr in global_expressions.iter() { Self::validate_const_expression_handles(handle_and_expr, constants, overrides, types)?; } let validate_type = |handle| Self::validate_type_handle(handle, types); let validate_const_expr = - |handle| Self::validate_expression_handle(handle, const_expressions); + |handle| Self::validate_expression_handle(handle, global_expressions); for (_handle, constant) in constants.iter() { let &crate::Constant { name: _, ty, init } = constant; @@ -150,7 +150,7 @@ impl super::Validator { handle_and_expr, constants, overrides, - const_expressions, + global_expressions, types, local_variables, global_variables, @@ -256,7 +256,7 @@ impl super::Validator { (handle, expression): (Handle, &crate::Expression), constants: &Arena, overrides: &Arena, - const_expressions: &Arena, + global_expressions: &Arena, types: &UniqueArena, local_variables: &Arena, global_variables: &Arena, @@ -267,7 +267,7 @@ impl super::Validator { let validate_constant = |handle| Self::validate_constant_handle(handle, constants); let validate_override = |handle| Self::validate_override_handle(handle, overrides); let validate_const_expr = - |handle| Self::validate_expression_handle(handle, const_expressions); + |handle| Self::validate_expression_handle(handle, global_expressions); let validate_type = |handle| Self::validate_type_handle(handle, types); match *expression { diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index 72da6377d9..b9730c1f3d 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -435,7 +435,7 @@ impl Validator { type_flags: Vec::with_capacity(module.types.len()), functions: Vec::with_capacity(module.functions.len()), entry_points: Vec::with_capacity(module.entry_points.len()), - const_expression_types: vec![placeholder; module.const_expressions.len()] + const_expression_types: vec![placeholder; module.global_expressions.len()] .into_boxed_slice(), }; @@ -457,20 +457,20 @@ impl Validator { { let t = crate::Arena::new(); let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]); - for (handle, _) in module.const_expressions.iter() { + for (handle, _) in module.global_expressions.iter() { mod_info .process_const_expression(handle, &resolve_context, module.to_ctx()) .map_err(|source| { ValidationError::ConstExpression { handle, source } - .with_span_handle(handle, &module.const_expressions) + .with_span_handle(handle, &module.global_expressions) })? } } - let global_expr_kind = ExpressionKindTracker::from_arena(&module.const_expressions); + let global_expr_kind = ExpressionKindTracker::from_arena(&module.global_expressions); if self.flags.contains(ValidationFlags::CONSTANTS) { - for (handle, _) in module.const_expressions.iter() { + for (handle, _) in module.global_expressions.iter() { self.validate_const_expression( handle, module.to_ctx(), @@ -479,7 +479,7 @@ impl Validator { ) .map_err(|source| { ValidationError::ConstExpression { handle, source } - .with_span_handle(handle, &module.const_expressions) + .with_span_handle(handle, &module.global_expressions) })? } diff --git a/naga/tests/out/ir/access.compact.ron b/naga/tests/out/ir/access.compact.ron index 37ace5283f..4bae535e64 100644 --- a/naga/tests/out/ir/access.compact.ron +++ b/naga/tests/out/ir/access.compact.ron @@ -378,7 +378,7 @@ init: None, ), ], - const_expressions: [ + global_expressions: [ Literal(U32(0)), Literal(U32(0)), Literal(U32(0)), diff --git a/naga/tests/out/ir/access.ron b/naga/tests/out/ir/access.ron index 37ace5283f..4bae535e64 100644 --- a/naga/tests/out/ir/access.ron +++ b/naga/tests/out/ir/access.ron @@ -378,7 +378,7 @@ init: None, ), ], - const_expressions: [ + global_expressions: [ Literal(U32(0)), Literal(U32(0)), Literal(U32(0)), diff --git a/naga/tests/out/ir/collatz.compact.ron b/naga/tests/out/ir/collatz.compact.ron index fe4af55c1b..3312ddbf77 100644 --- a/naga/tests/out/ir/collatz.compact.ron +++ b/naga/tests/out/ir/collatz.compact.ron @@ -61,7 +61,7 @@ init: None, ), ], - const_expressions: [], + global_expressions: [], functions: [ ( name: Some("collatz_iterations"), diff --git a/naga/tests/out/ir/collatz.ron b/naga/tests/out/ir/collatz.ron index fe4af55c1b..3312ddbf77 100644 --- a/naga/tests/out/ir/collatz.ron +++ b/naga/tests/out/ir/collatz.ron @@ -61,7 +61,7 @@ init: None, ), ], - const_expressions: [], + global_expressions: [], functions: [ ( name: Some("collatz_iterations"), diff --git a/naga/tests/out/ir/overrides.compact.ron b/naga/tests/out/ir/overrides.compact.ron index d15abbd033..7a60f14239 100644 --- a/naga/tests/out/ir/overrides.compact.ron +++ b/naga/tests/out/ir/overrides.compact.ron @@ -66,7 +66,7 @@ ), ], global_variables: [], - const_expressions: [ + global_expressions: [ Literal(Bool(true)), Literal(F32(2.3)), Literal(F32(0.0)), diff --git a/naga/tests/out/ir/overrides.ron b/naga/tests/out/ir/overrides.ron index d15abbd033..7a60f14239 100644 --- a/naga/tests/out/ir/overrides.ron +++ b/naga/tests/out/ir/overrides.ron @@ -66,7 +66,7 @@ ), ], global_variables: [], - const_expressions: [ + global_expressions: [ Literal(Bool(true)), Literal(F32(2.3)), Literal(F32(0.0)), diff --git a/naga/tests/out/ir/shadow.compact.ron b/naga/tests/out/ir/shadow.compact.ron index fab0f1e2f6..45d819b9e0 100644 --- a/naga/tests/out/ir/shadow.compact.ron +++ b/naga/tests/out/ir/shadow.compact.ron @@ -319,7 +319,7 @@ init: None, ), ], - const_expressions: [ + global_expressions: [ Literal(F32(0.0)), Literal(F32(1.0)), Literal(F32(0.5)), diff --git a/naga/tests/out/ir/shadow.ron b/naga/tests/out/ir/shadow.ron index 9acbbdaadd..523c6d4192 100644 --- a/naga/tests/out/ir/shadow.ron +++ b/naga/tests/out/ir/shadow.ron @@ -522,7 +522,7 @@ init: None, ), ], - const_expressions: [ + global_expressions: [ Literal(F32(0.0)), Literal(F32(1.0)), Literal(F32(0.5)), From bb999fcbfe38d1a3c7c8a494fac2b851d31014d9 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Thu, 7 Mar 2024 16:11:33 +0100 Subject: [PATCH 12/30] [valid] error on non fully evaluated const-expressions --- naga/src/valid/expression.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 289eb02011..055ca6dfdb 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -135,6 +135,8 @@ pub enum ExpressionError { pub enum ConstExpressionError { #[error("The expression is not a constant or override expression")] NonConstOrOverride, + #[error("The expression is not a fully evaluated constant expression")] + NonFullyEvaluatedConst, #[error(transparent)] Compose(#[from] super::ComposeError), #[error("Splatting {0:?} can't be done")] @@ -211,6 +213,9 @@ impl super::Validator { crate::TypeInner::Scalar { .. } => {} _ => return Err(super::ConstExpressionError::InvalidSplatType(value)), }, + _ if global_expr_kind.is_const(handle) => { + return Err(super::ConstExpressionError::NonFullyEvaluatedConst) + } // the constant evaluator will report errors about override-expressions _ => {} } From 7d79d71742c20bd4593cc595c9366fe7b4e7e986 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Thu, 7 Mar 2024 16:35:18 +0100 Subject: [PATCH 13/30] [valid] make sure overrides are not present after evaluation --- naga/src/back/pipeline_constants.rs | 2 +- naga/src/valid/expression.rs | 2 +- naga/src/valid/mod.rs | 27 +++++++++++++++++++++++++++ 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index a301b4ff3d..9679aaecb9 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -111,7 +111,7 @@ pub(super) fn process_overrides<'a>( } let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all()); - let module_info = validator.validate(&module)?; + let module_info = validator.validate_no_overrides(&module)?; Ok((Cow::Owned(module), Cow::Owned(module_info))) } diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 055ca6dfdb..bf46fd3262 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -213,7 +213,7 @@ impl super::Validator { crate::TypeInner::Scalar { .. } => {} _ => return Err(super::ConstExpressionError::InvalidSplatType(value)), }, - _ if global_expr_kind.is_const(handle) => { + _ if global_expr_kind.is_const(handle) || !self.allow_overrides => { return Err(super::ConstExpressionError::NonFullyEvaluatedConst) } // the constant evaluator will report errors about override-expressions diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index b9730c1f3d..f34c0f6f1a 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -175,6 +175,7 @@ pub struct Validator { valid_expression_list: Vec>, valid_expression_set: BitSet, override_ids: FastHashSet, + allow_overrides: bool, } #[derive(Clone, Debug, thiserror::Error)] @@ -203,6 +204,8 @@ pub enum OverrideError { NonConstructibleType, #[error("The type is not a scalar")] TypeNotScalar, + #[error("Override declarations are not allowed")] + NotAllowed, } #[derive(Clone, Debug, thiserror::Error)] @@ -322,6 +325,7 @@ impl Validator { valid_expression_list: Vec::new(), valid_expression_set: BitSet::new(), override_ids: FastHashSet::default(), + allow_overrides: true, } } @@ -370,6 +374,10 @@ impl Validator { gctx: crate::proc::GlobalCtx, mod_info: &ModuleInfo, ) -> Result<(), OverrideError> { + if !self.allow_overrides { + return Err(OverrideError::NotAllowed); + } + let o = &gctx.overrides[handle]; if o.name.is_none() && o.id.is_none() { @@ -414,6 +422,25 @@ impl Validator { pub fn validate( &mut self, module: &crate::Module, + ) -> Result> { + self.allow_overrides = true; + self.validate_impl(module) + } + + /// Check the given module to be valid. + /// + /// With the additional restriction that overrides are not present. + pub fn validate_no_overrides( + &mut self, + module: &crate::Module, + ) -> Result> { + self.allow_overrides = false; + self.validate_impl(module) + } + + fn validate_impl( + &mut self, + module: &crate::Module, ) -> Result> { self.reset(); self.reset_types(module.types.len()); From f3ff6b3c76e5e8f10ce99339c5c52e5707d782ca Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Tue, 12 Mar 2024 18:10:59 -0700 Subject: [PATCH 14/30] [naga] Add some documentation to process_overrides and subroutines. --- naga/src/back/pipeline_constants.rs | 80 +++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 9679aaecb9..298ccbc0d3 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -22,6 +22,19 @@ pub enum PipelineConstantError { ValidationError(#[from] WithSpan), } +/// Replace all overrides in `module` with constants. +/// +/// If no changes are needed, this just returns `Cow::Borrowed` +/// references to `module` and `module_info`. Otherwise, it clones +/// `module`, edits its [`global_expressions`] arena to contain only +/// fully-evaluated expressions, and returns `Cow::Owned` values +/// holding the simplified module and its validation results. +/// +/// In either case, the module returned has an empty `overrides` +/// arena, and the `global_expressions` arena contains only +/// fully-evaluated expressions. +/// +/// [`global_expressions`]: Module::global_expressions pub(super) fn process_overrides<'a>( module: &'a Module, module_info: &'a ModuleInfo, @@ -32,14 +45,62 @@ pub(super) fn process_overrides<'a>( } let mut module = module.clone(); + + // A map from override handles to the handles of the constants + // we've replaced them with. let mut override_map = Vec::with_capacity(module.overrides.len()); + + // A map from `module`'s original global expression handles to + // handles in the new, simplified global expression arena. let mut adjusted_global_expressions = Vec::with_capacity(module.global_expressions.len()); + + // The set of constants whose initializer handles we've already + // updated to refer to the newly built global expression arena. + // + // All constants in `module` must have their `init` handles + // updated to point into the new, simplified global expression + // arena. Some of these we can most easily handle as a side effect + // during the simplification process, but we must handle the rest + // in a final fixup pass, guided by `adjusted_global_expressions`. We + // add their handles to this set, so that the final fixup step can + // leave them alone. let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len()); let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new(); + // An iterator through the original overrides table, consumed in + // approximate tandem with the global expressions. let mut override_iter = module.overrides.drain(); + // Do two things in tandem: + // + // - Rebuild the global expression arena from scratch, fully + // evaluating all expressions, and replacing each `Override` + // expression in `module.global_expressions` with a `Constant` + // expression. + // + // - Build a new `Constant` in `module.constants` to take the + // place of each `Override`. + // + // Build a map from old global expression handles to their + // fully-evaluated counterparts in `adjusted_global_expressions` as we + // go. + // + // Why in tandem? Overrides refer to expressions, and expressions + // refer to overrides, so we can't disentangle the two into + // separate phases. However, we can take advantage of the fact + // that the overrides and expressions must form a DAG, and work + // our way from the leaves to the roots, replacing and evaluating + // as we go. + // + // Although the two loops are nested, this is really two + // alternating phases: we adjust and evaluate constant expressions + // until we hit an `Override` expression, at which point we switch + // to building `Constant`s for `Overrides` until we've handled the + // one used by the expression. Then we switch back to processing + // expressions. Because we know they form a DAG, we know the + // `Override` expressions we encounter can only have initializers + // referring to global expressions we've already simplified. for (old_h, expr, span) in module.global_expressions.drain() { let mut expr = match expr { Expression::Override(h) => { @@ -84,6 +145,7 @@ pub(super) fn process_overrides<'a>( adjusted_global_expressions.push(h); } + // Finish processing any overrides we didn't visit in the loop above. for entry in override_iter { process_override( entry, @@ -96,6 +158,9 @@ pub(super) fn process_overrides<'a>( )?; } + // Update the initialization expression handles of all `Constant`s + // and `GlobalVariable`s. Skip `Constant`s we'd already updated en + // passant. for (_, c) in module .constants .iter_mut() @@ -110,12 +175,18 @@ pub(super) fn process_overrides<'a>( } } + // Now that the global expression arena has changed, we need to + // recompute those expressions' types. For the time being, do a + // full re-validation. let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all()); let module_info = validator.validate_no_overrides(&module)?; Ok((Cow::Owned(module), Cow::Owned(module_info))) } +/// Add a [`Constant`] to `module` for the override `old_h`. +/// +/// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`. fn process_override( (old_h, override_, span): (Handle, Override, Span), pipeline_constants: &PipelineConstants, @@ -125,6 +196,7 @@ fn process_override( adjusted_constant_initializers: &mut HashSet>, global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker, ) -> Result, PipelineConstantError> { + // Determine which key to use for `override_` in `pipeline_constants`. let key = if let Some(id) = override_.id { Cow::Owned(id.to_string()) } else if let Some(ref name) = override_.name { @@ -132,6 +204,10 @@ fn process_override( } else { unreachable!(); }; + + // Generate a global expression for `override_`'s value, either + // from the provided `pipeline_constants` table or its initializer + // in the module. let init = if let Some(value) = pipeline_constants.get::(&key) { let literal = match module.types[override_.ty].inner { TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?, @@ -147,6 +223,8 @@ fn process_override( } else { return Err(PipelineConstantError::MissingValue(key.to_string())); }; + + // Generate a new `Constant` to represent the override's value. let constant = Constant { name: override_.name, ty: override_.ty, @@ -159,6 +237,8 @@ fn process_override( Ok(h) } +/// Replace every expression handle in `expr` with its counterpart +/// given by `new_pos`. fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { let adjust = |expr: &mut Handle| { *expr = new_pos[expr.index()]; From 0c4b2d7993464c2737dc8fba52baa7f60b93fed6 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Thu, 14 Mar 2024 18:01:00 +0100 Subject: [PATCH 15/30] refactor `try_eval_and_append` body --- naga/src/proc/constant_evaluator.rs | 72 +++++++++++++++++------------ 1 file changed, 43 insertions(+), 29 deletions(-) diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 532f364532..f1f01e5855 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -258,6 +258,17 @@ enum Behavior<'a> { Glsl(GlslRestrictions<'a>), } +impl Behavior<'_> { + /// Returns `true` if the inner WGSL/GLSL restrictions are runtime restrictions. + const fn has_runtime_restrictions(&self) -> bool { + matches!( + self, + &Behavior::Wgsl(WgslRestrictions::Runtime(_)) + | &Behavior::Glsl(GlslRestrictions::Runtime(_)) + ) + } +} + /// A context for evaluating constant expressions. /// /// A `ConstantEvaluator` points at an expression arena to which it can append @@ -699,37 +710,40 @@ impl<'a> ConstantEvaluator<'a> { expr: Expression, span: Span, ) -> Result, ConstantEvaluatorError> { - match ( - &self.behavior, - self.expression_kind_tracker.type_of_with_expr(&expr), - ) { - // avoid errors on unimplemented functionality if possible - ( - &Behavior::Wgsl(WgslRestrictions::Runtime(_)) - | &Behavior::Glsl(GlslRestrictions::Runtime(_)), - ExpressionKind::Const, - ) => match self.try_eval_and_append_impl(&expr, span) { - Err( - ConstantEvaluatorError::NotImplemented(_) - | ConstantEvaluatorError::InvalidBinaryOpArgs, - ) => Ok(self.append_expr(expr, span, ExpressionKind::Runtime)), - res => res, + match self.expression_kind_tracker.type_of_with_expr(&expr) { + ExpressionKind::Const => { + let eval_result = self.try_eval_and_append_impl(&expr, span); + // avoid errors on unimplemented functionality if possible + if self.behavior.has_runtime_restrictions() + && matches!( + eval_result, + Err(ConstantEvaluatorError::NotImplemented(_) + | ConstantEvaluatorError::InvalidBinaryOpArgs,) + ) + { + Ok(self.append_expr(expr, span, ExpressionKind::Runtime)) + } else { + eval_result + } + } + ExpressionKind::Override => match self.behavior { + Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => { + Ok(self.append_expr(expr, span, ExpressionKind::Override)) + } + Behavior::Wgsl(WgslRestrictions::Const) => { + Err(ConstantEvaluatorError::OverrideExpr) + } + Behavior::Glsl(_) => { + unreachable!() + } }, - (_, ExpressionKind::Const) => self.try_eval_and_append_impl(&expr, span), - (&Behavior::Wgsl(WgslRestrictions::Const), ExpressionKind::Override) => { - Err(ConstantEvaluatorError::OverrideExpr) + ExpressionKind::Runtime => { + if self.behavior.has_runtime_restrictions() { + Ok(self.append_expr(expr, span, ExpressionKind::Runtime)) + } else { + Err(ConstantEvaluatorError::RuntimeExpr) + } } - ( - &Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)), - ExpressionKind::Override, - ) => Ok(self.append_expr(expr, span, ExpressionKind::Override)), - (&Behavior::Glsl(_), ExpressionKind::Override) => unreachable!(), - ( - &Behavior::Wgsl(WgslRestrictions::Runtime(_)) - | &Behavior::Glsl(GlslRestrictions::Runtime(_)), - ExpressionKind::Runtime, - ) => Ok(self.append_expr(expr, span, ExpressionKind::Runtime)), - (_, ExpressionKind::Runtime) => Err(ConstantEvaluatorError::RuntimeExpr), } } From e2ff98be125f97eaf4b4dd9124b5af3b9b9caafe Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Wed, 6 Mar 2024 12:22:33 +0100 Subject: [PATCH 16/30] evaluate override-expressions in functions --- naga/src/back/pipeline_constants.rs | 278 ++++++++++++++++++++- naga/tests/in/overrides.wgsl | 6 +- naga/tests/out/analysis/overrides.info.ron | 78 +++++- naga/tests/out/hlsl/overrides.hlsl | 5 + naga/tests/out/ir/overrides.compact.ron | 50 +++- naga/tests/out/ir/overrides.ron | 50 +++- naga/tests/out/msl/overrides.msl | 4 + naga/tests/out/spv/overrides.main.spvasm | 15 +- 8 files changed, 472 insertions(+), 14 deletions(-) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 298ccbc0d3..bd9eec76ee 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -1,10 +1,11 @@ use super::PipelineConstants; use crate::{ - proc::{ConstantEvaluator, ConstantEvaluatorError}, + proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter}, valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator}, - Constant, Expression, Handle, Literal, Module, Override, Scalar, Span, TypeInner, WithSpan, + Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar, + Span, Statement, SwitchCase, TypeInner, WithSpan, }; -use std::{borrow::Cow, collections::HashSet}; +use std::{borrow::Cow, collections::HashSet, mem}; use thiserror::Error; #[derive(Error, Debug, Clone)] @@ -175,6 +176,18 @@ pub(super) fn process_overrides<'a>( } } + let mut functions = mem::take(&mut module.functions); + for (_, function) in functions.iter_mut() { + process_function(&mut module, &override_map, function)?; + } + let _ = mem::replace(&mut module.functions, functions); + + let mut entry_points = mem::take(&mut module.entry_points); + for ep in entry_points.iter_mut() { + process_function(&mut module, &override_map, &mut ep.function)?; + } + let _ = mem::replace(&mut module.entry_points, entry_points); + // Now that the global expression arena has changed, we need to // recompute those expressions' types. For the time being, do a // full re-validation. @@ -237,6 +250,64 @@ fn process_override( Ok(h) } +/// Replaces all `Expression::Override`s in this function's expression arena +/// with `Expression::Constant` and evaluates all expressions in its arena. +fn process_function( + module: &mut Module, + override_map: &[Handle], + function: &mut Function, +) -> Result<(), ConstantEvaluatorError> { + // A map from original local expression handles to + // handles in the new, local expression arena. + let mut adjusted_local_expressions = Vec::with_capacity(function.expressions.len()); + + let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new(); + + let mut expressions = mem::take(&mut function.expressions); + + // Dummy `emitter` and `block` for the constant evaluator. + // We can ignore the concept of emitting expressions here since + // expressions have already been covered by a `Statement::Emit` + // in the frontend. + // The only thing we might have to do is remove some expressions + // that have been covered by a `Statement::Emit`. See the docs of + // `filter_emits_in_block` for the reasoning. + let mut emitter = Emitter::default(); + let mut block = Block::new(); + + for (old_h, expr, span) in expressions.drain() { + let mut expr = match expr { + Expression::Override(h) => Expression::Constant(override_map[h.index()]), + expr => expr, + }; + let mut evaluator = ConstantEvaluator::for_wgsl_function( + module, + &mut function.expressions, + &mut local_expression_kind_tracker, + &mut emitter, + &mut block, + ); + adjust_expr(&adjusted_local_expressions, &mut expr); + let h = evaluator.try_eval_and_append(expr, span)?; + debug_assert_eq!(old_h.index(), adjusted_local_expressions.len()); + adjusted_local_expressions.push(h); + } + + adjust_block(&adjusted_local_expressions, &mut function.body); + + let new_body = filter_emits_in_block(&function.body, &function.expressions); + let _ = mem::replace(&mut function.body, new_body); + + let named_expressions = mem::take(&mut function.named_expressions); + for (expr_h, name) in named_expressions { + function + .named_expressions + .insert(adjusted_local_expressions[expr_h.index()], name); + } + + Ok(()) +} + /// Replace every expression handle in `expr` with its counterpart /// given by `new_pos`. fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { @@ -409,6 +480,207 @@ fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { } } +/// Replace every expression handle in `block` with its counterpart +/// given by `new_pos`. +fn adjust_block(new_pos: &[Handle], block: &mut Block) { + for stmt in block.iter_mut() { + adjust_stmt(new_pos, stmt); + } +} + +/// Replace every expression handle in `stmt` with its counterpart +/// given by `new_pos`. +fn adjust_stmt(new_pos: &[Handle], stmt: &mut Statement) { + let adjust = |expr: &mut Handle| { + *expr = new_pos[expr.index()]; + }; + match *stmt { + Statement::Emit(ref mut range) => { + if let Some((mut first, mut last)) = range.first_and_last() { + adjust(&mut first); + adjust(&mut last); + *range = Range::new_from_bounds(first, last); + } + } + Statement::Block(ref mut block) => { + adjust_block(new_pos, block); + } + Statement::If { + ref mut condition, + ref mut accept, + ref mut reject, + } => { + adjust(condition); + adjust_block(new_pos, accept); + adjust_block(new_pos, reject); + } + Statement::Switch { + ref mut selector, + ref mut cases, + } => { + adjust(selector); + for case in cases.iter_mut() { + adjust_block(new_pos, &mut case.body); + } + } + Statement::Loop { + ref mut body, + ref mut continuing, + ref mut break_if, + } => { + adjust_block(new_pos, body); + adjust_block(new_pos, continuing); + if let Some(e) = break_if.as_mut() { + adjust(e); + } + } + Statement::Return { ref mut value } => { + if let Some(e) = value.as_mut() { + adjust(e); + } + } + Statement::Store { + ref mut pointer, + ref mut value, + } => { + adjust(pointer); + adjust(value); + } + Statement::ImageStore { + ref mut image, + ref mut coordinate, + ref mut array_index, + ref mut value, + } => { + adjust(image); + adjust(coordinate); + if let Some(e) = array_index.as_mut() { + adjust(e); + } + adjust(value); + } + crate::Statement::Atomic { + ref mut pointer, + ref mut value, + ref mut result, + .. + } => { + adjust(pointer); + adjust(value); + adjust(result); + } + Statement::WorkGroupUniformLoad { + ref mut pointer, + ref mut result, + } => { + adjust(pointer); + adjust(result); + } + Statement::Call { + ref mut arguments, + ref mut result, + .. + } => { + for argument in arguments.iter_mut() { + adjust(argument); + } + if let Some(e) = result.as_mut() { + adjust(e); + } + } + Statement::RayQuery { ref mut query, .. } => { + adjust(query); + } + Statement::Break | Statement::Continue | Statement::Kill | Statement::Barrier(_) => {} + } +} + +/// Filters out expressions that `needs_pre_emit`. This step is necessary after +/// const evaluation since unevaluated expressions could have been included in +/// `Statement::Emit`; but since they have been evaluated we need to filter those +/// out. +fn filter_emits_in_block(block: &Block, expressions: &Arena) -> Block { + let mut out = Block::with_capacity(block.len()); + for (stmt, span) in block.span_iter() { + match stmt { + &Statement::Emit(ref range) => { + let mut current = None; + for expr_h in range.clone() { + if expressions[expr_h].needs_pre_emit() { + if let Some((first, last)) = current { + out.push(Statement::Emit(Range::new_from_bounds(first, last)), *span); + } + + current = None; + } else if let Some((_, ref mut last)) = current { + *last = expr_h; + } else { + current = Some((expr_h, expr_h)); + } + } + if let Some((first, last)) = current { + out.push(Statement::Emit(Range::new_from_bounds(first, last)), *span); + } + } + &Statement::Block(ref block) => { + let block = filter_emits_in_block(block, expressions); + out.push(Statement::Block(block), *span); + } + &Statement::If { + condition, + ref accept, + ref reject, + } => { + let accept = filter_emits_in_block(accept, expressions); + let reject = filter_emits_in_block(reject, expressions); + out.push( + Statement::If { + condition, + accept, + reject, + }, + *span, + ); + } + &Statement::Switch { + selector, + ref cases, + } => { + let cases = cases + .iter() + .map(|case| { + let body = filter_emits_in_block(&case.body, expressions); + SwitchCase { + value: case.value, + body, + fall_through: case.fall_through, + } + }) + .collect(); + out.push(Statement::Switch { selector, cases }, *span); + } + &Statement::Loop { + ref body, + ref continuing, + break_if, + } => { + let body = filter_emits_in_block(body, expressions); + let continuing = filter_emits_in_block(continuing, expressions); + out.push( + Statement::Loop { + body, + continuing, + break_if, + }, + *span, + ); + } + stmt => out.push(stmt.clone(), *span), + } + } + out +} + fn map_value_to_literal(value: f64, scalar: Scalar) -> Result { // note that in rust 0.0 == -0.0 match scalar { diff --git a/naga/tests/in/overrides.wgsl b/naga/tests/in/overrides.wgsl index 41e99f9426..b06edecdb9 100644 --- a/naga/tests/in/overrides.wgsl +++ b/naga/tests/in/overrides.wgsl @@ -14,4 +14,8 @@ override inferred_f32 = 2.718; @compute @workgroup_size(1) -fn main() {} \ No newline at end of file +fn main() { + var t = height * 5; + let a = !has_point_light; + var x = a; +} \ No newline at end of file diff --git a/naga/tests/out/analysis/overrides.info.ron b/naga/tests/out/analysis/overrides.info.ron index 7a2447f3c0..389e7fba7f 100644 --- a/naga/tests/out/analysis/overrides.info.ron +++ b/naga/tests/out/analysis/overrides.info.ron @@ -15,7 +15,83 @@ may_kill: false, sampling_set: [], global_uses: [], - expressions: [], + expressions: [ + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(2), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: Some(4), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 2, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: Some(7), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 1, + space: Function, + )), + ), + ], sampling: [], dual_source_blending: false, ), diff --git a/naga/tests/out/hlsl/overrides.hlsl b/naga/tests/out/hlsl/overrides.hlsl index 0a849fd4db..1541ae7281 100644 --- a/naga/tests/out/hlsl/overrides.hlsl +++ b/naga/tests/out/hlsl/overrides.hlsl @@ -9,5 +9,10 @@ static const float inferred_f32_ = 2.718; [numthreads(1, 1, 1)] void main() { + float t = (float)0; + bool x = (bool)0; + + t = 23.0; + x = true; return; } diff --git a/naga/tests/out/ir/overrides.compact.ron b/naga/tests/out/ir/overrides.compact.ron index 7a60f14239..b0a230a716 100644 --- a/naga/tests/out/ir/overrides.compact.ron +++ b/naga/tests/out/ir/overrides.compact.ron @@ -90,10 +90,54 @@ name: Some("main"), arguments: [], result: None, - local_variables: [], - expressions: [], - named_expressions: {}, + local_variables: [ + ( + name: Some("t"), + ty: 2, + init: None, + ), + ( + name: Some("x"), + ty: 1, + init: None, + ), + ], + expressions: [ + Override(6), + Literal(F32(5.0)), + Binary( + op: Multiply, + left: 1, + right: 2, + ), + LocalVariable(1), + Override(1), + Unary( + op: LogicalNot, + expr: 5, + ), + LocalVariable(2), + ], + named_expressions: { + 6: "a", + }, body: [ + Emit(( + start: 2, + end: 3, + )), + Store( + pointer: 4, + value: 3, + ), + Emit(( + start: 5, + end: 6, + )), + Store( + pointer: 7, + value: 6, + ), Return( value: None, ), diff --git a/naga/tests/out/ir/overrides.ron b/naga/tests/out/ir/overrides.ron index 7a60f14239..b0a230a716 100644 --- a/naga/tests/out/ir/overrides.ron +++ b/naga/tests/out/ir/overrides.ron @@ -90,10 +90,54 @@ name: Some("main"), arguments: [], result: None, - local_variables: [], - expressions: [], - named_expressions: {}, + local_variables: [ + ( + name: Some("t"), + ty: 2, + init: None, + ), + ( + name: Some("x"), + ty: 1, + init: None, + ), + ], + expressions: [ + Override(6), + Literal(F32(5.0)), + Binary( + op: Multiply, + left: 1, + right: 2, + ), + LocalVariable(1), + Override(1), + Unary( + op: LogicalNot, + expr: 5, + ), + LocalVariable(2), + ], + named_expressions: { + 6: "a", + }, body: [ + Emit(( + start: 2, + end: 3, + )), + Store( + pointer: 4, + value: 3, + ), + Emit(( + start: 5, + end: 6, + )), + Store( + pointer: 7, + value: 6, + ), Return( value: None, ), diff --git a/naga/tests/out/msl/overrides.msl b/naga/tests/out/msl/overrides.msl index 13a3b623a0..0bc9e6b12c 100644 --- a/naga/tests/out/msl/overrides.msl +++ b/naga/tests/out/msl/overrides.msl @@ -14,5 +14,9 @@ constant float inferred_f32_ = 2.718; kernel void main_( ) { + float t = {}; + bool x = {}; + t = 23.0; + x = true; return; } diff --git a/naga/tests/out/spv/overrides.main.spvasm b/naga/tests/out/spv/overrides.main.spvasm index 7731edfb93..d421606ca9 100644 --- a/naga/tests/out/spv/overrides.main.spvasm +++ b/naga/tests/out/spv/overrides.main.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.0 ; Generator: rspirv -; Bound: 17 +; Bound: 24 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -19,9 +19,18 @@ OpExecutionMode %14 LocalSize 1 1 1 %11 = OpConstant %4 4.6 %12 = OpConstant %4 2.718 %15 = OpTypeFunction %2 +%16 = OpConstant %4 23.0 +%18 = OpTypePointer Function %4 +%19 = OpConstantNull %4 +%21 = OpTypePointer Function %3 +%22 = OpConstantNull %3 %14 = OpFunction %2 None %15 %13 = OpLabel -OpBranch %16 -%16 = OpLabel +%17 = OpVariable %18 Function %19 +%20 = OpVariable %21 Function %22 +OpBranch %23 +%23 = OpLabel +OpStore %17 %16 +OpStore %20 %5 OpReturn OpFunctionEnd \ No newline at end of file From e2fcc2603c07ec0196411161abe810148efcb2a8 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Wed, 6 Mar 2024 12:23:16 +0100 Subject: [PATCH 17/30] allow private variables to have an override-expression initializer --- naga/src/front/wgsl/lower/mod.rs | 2 +- naga/src/valid/interface.rs | 4 +- naga/tests/in/overrides.wgsl | 4 ++ naga/tests/out/analysis/overrides.info.ron | 70 +++++++++++++++++++++- naga/tests/out/hlsl/overrides.hlsl | 5 ++ naga/tests/out/ir/overrides.compact.ron | 45 +++++++++++++- naga/tests/out/ir/overrides.ron | 45 +++++++++++++- naga/tests/out/msl/overrides.msl | 4 ++ naga/tests/out/spv/overrides.main.spvasm | 43 +++++++------ 9 files changed, 199 insertions(+), 23 deletions(-) diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 1a8b75811b..7abd95114d 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -916,7 +916,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let init; if let Some(init_ast) = v.init { - let mut ectx = ctx.as_const(); + let mut ectx = ctx.as_override(); let lowered = self.expression_for_abstract(init_ast, &mut ectx)?; let ty_res = crate::proc::TypeResolution::Handle(ty); let converted = ectx diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 0e42075de1..2435b34c29 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -31,7 +31,7 @@ pub enum GlobalVariableError { Handle, #[source] Disalignment, ), - #[error("Initializer must be a const-expression")] + #[error("Initializer must be an override-expression")] InitializerExprType, #[error("Initializer doesn't match the variable type")] InitializerType, @@ -529,7 +529,7 @@ impl super::Validator { } } - if !global_expr_kind.is_const(init) { + if !global_expr_kind.is_const_or_override(init) { return Err(GlobalVariableError::InitializerExprType); } diff --git a/naga/tests/in/overrides.wgsl b/naga/tests/in/overrides.wgsl index b06edecdb9..ab1d637a11 100644 --- a/naga/tests/in/overrides.wgsl +++ b/naga/tests/in/overrides.wgsl @@ -13,9 +13,13 @@ override inferred_f32 = 2.718; +var gain_x_10: f32 = gain * 10.; + @compute @workgroup_size(1) fn main() { var t = height * 5; let a = !has_point_light; var x = a; + + var gain_x_100 = gain_x_10 * 10.; } \ No newline at end of file diff --git a/naga/tests/out/analysis/overrides.info.ron b/naga/tests/out/analysis/overrides.info.ron index 389e7fba7f..6ea54bb296 100644 --- a/naga/tests/out/analysis/overrides.info.ron +++ b/naga/tests/out/analysis/overrides.info.ron @@ -14,7 +14,9 @@ ), may_kill: false, sampling_set: [], - global_uses: [], + global_uses: [ + ("READ"), + ], expressions: [ ( uniformity: ( @@ -91,6 +93,63 @@ space: Function, )), ), + ( + uniformity: ( + non_uniform_result: Some(8), + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(1), + ty: Value(Pointer( + base: 2, + space: Private, + )), + ), + ( + uniformity: ( + non_uniform_result: Some(8), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(2), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: Some(8), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: Some(12), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 2, + space: Function, + )), + ), ], sampling: [], dual_source_blending: false, @@ -119,5 +178,14 @@ kind: Float, width: 4, ))), + Handle(2), + Value(Scalar(( + kind: Float, + width: 4, + ))), + Value(Scalar(( + kind: Float, + width: 4, + ))), ], ) \ No newline at end of file diff --git a/naga/tests/out/hlsl/overrides.hlsl b/naga/tests/out/hlsl/overrides.hlsl index 1541ae7281..072cd9ffcc 100644 --- a/naga/tests/out/hlsl/overrides.hlsl +++ b/naga/tests/out/hlsl/overrides.hlsl @@ -6,13 +6,18 @@ static const float depth = 2.3; static const float height = 4.6; static const float inferred_f32_ = 2.718; +static float gain_x_10_ = 11.0; + [numthreads(1, 1, 1)] void main() { float t = (float)0; bool x = (bool)0; + float gain_x_100_ = (float)0; t = 23.0; x = true; + float _expr10 = gain_x_10_; + gain_x_100_ = (_expr10 * 10.0); return; } diff --git a/naga/tests/out/ir/overrides.compact.ron b/naga/tests/out/ir/overrides.compact.ron index b0a230a716..4188354224 100644 --- a/naga/tests/out/ir/overrides.compact.ron +++ b/naga/tests/out/ir/overrides.compact.ron @@ -65,7 +65,15 @@ init: Some(7), ), ], - global_variables: [], + global_variables: [ + ( + name: Some("gain_x_10"), + space: Private, + binding: None, + ty: 2, + init: Some(10), + ), + ], global_expressions: [ Literal(Bool(true)), Literal(F32(2.3)), @@ -78,6 +86,13 @@ right: 4, ), Literal(F32(2.718)), + Override(3), + Literal(F32(10.0)), + Binary( + op: Multiply, + left: 8, + right: 9, + ), ], functions: [], entry_points: [ @@ -101,6 +116,11 @@ ty: 1, init: None, ), + ( + name: Some("gain_x_100"), + ty: 2, + init: None, + ), ], expressions: [ Override(6), @@ -117,6 +137,17 @@ expr: 5, ), LocalVariable(2), + GlobalVariable(1), + Load( + pointer: 8, + ), + Literal(F32(10.0)), + Binary( + op: Multiply, + left: 9, + right: 10, + ), + LocalVariable(3), ], named_expressions: { 6: "a", @@ -138,6 +169,18 @@ pointer: 7, value: 6, ), + Emit(( + start: 8, + end: 9, + )), + Emit(( + start: 10, + end: 11, + )), + Store( + pointer: 12, + value: 11, + ), Return( value: None, ), diff --git a/naga/tests/out/ir/overrides.ron b/naga/tests/out/ir/overrides.ron index b0a230a716..4188354224 100644 --- a/naga/tests/out/ir/overrides.ron +++ b/naga/tests/out/ir/overrides.ron @@ -65,7 +65,15 @@ init: Some(7), ), ], - global_variables: [], + global_variables: [ + ( + name: Some("gain_x_10"), + space: Private, + binding: None, + ty: 2, + init: Some(10), + ), + ], global_expressions: [ Literal(Bool(true)), Literal(F32(2.3)), @@ -78,6 +86,13 @@ right: 4, ), Literal(F32(2.718)), + Override(3), + Literal(F32(10.0)), + Binary( + op: Multiply, + left: 8, + right: 9, + ), ], functions: [], entry_points: [ @@ -101,6 +116,11 @@ ty: 1, init: None, ), + ( + name: Some("gain_x_100"), + ty: 2, + init: None, + ), ], expressions: [ Override(6), @@ -117,6 +137,17 @@ expr: 5, ), LocalVariable(2), + GlobalVariable(1), + Load( + pointer: 8, + ), + Literal(F32(10.0)), + Binary( + op: Multiply, + left: 9, + right: 10, + ), + LocalVariable(3), ], named_expressions: { 6: "a", @@ -138,6 +169,18 @@ pointer: 7, value: 6, ), + Emit(( + start: 8, + end: 9, + )), + Emit(( + start: 10, + end: 11, + )), + Store( + pointer: 12, + value: 11, + ), Return( value: None, ), diff --git a/naga/tests/out/msl/overrides.msl b/naga/tests/out/msl/overrides.msl index 0bc9e6b12c..f884d1b527 100644 --- a/naga/tests/out/msl/overrides.msl +++ b/naga/tests/out/msl/overrides.msl @@ -14,9 +14,13 @@ constant float inferred_f32_ = 2.718; kernel void main_( ) { + float gain_x_10_ = 11.0; float t = {}; bool x = {}; + float gain_x_100_ = {}; t = 23.0; x = true; + float _e10 = gain_x_10_; + gain_x_100_ = _e10 * 10.0; return; } diff --git a/naga/tests/out/spv/overrides.main.spvasm b/naga/tests/out/spv/overrides.main.spvasm index d421606ca9..d4ce4752ed 100644 --- a/naga/tests/out/spv/overrides.main.spvasm +++ b/naga/tests/out/spv/overrides.main.spvasm @@ -1,12 +1,12 @@ ; SPIR-V ; Version: 1.0 ; Generator: rspirv -; Bound: 24 +; Bound: 32 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %14 "main" -OpExecutionMode %14 LocalSize 1 1 1 +OpEntryPoint GLCompute %18 "main" +OpExecutionMode %18 LocalSize 1 1 1 %2 = OpTypeVoid %3 = OpTypeBool %4 = OpTypeFloat 32 @@ -18,19 +18,28 @@ OpExecutionMode %14 LocalSize 1 1 1 %10 = OpConstant %4 2.0 %11 = OpConstant %4 4.6 %12 = OpConstant %4 2.718 -%15 = OpTypeFunction %2 -%16 = OpConstant %4 23.0 -%18 = OpTypePointer Function %4 -%19 = OpConstantNull %4 -%21 = OpTypePointer Function %3 -%22 = OpConstantNull %3 -%14 = OpFunction %2 None %15 -%13 = OpLabel -%17 = OpVariable %18 Function %19 -%20 = OpVariable %21 Function %22 -OpBranch %23 -%23 = OpLabel -OpStore %17 %16 -OpStore %20 %5 +%13 = OpConstant %4 10.0 +%14 = OpConstant %4 11.0 +%16 = OpTypePointer Private %4 +%15 = OpVariable %16 Private %14 +%19 = OpTypeFunction %2 +%20 = OpConstant %4 23.0 +%22 = OpTypePointer Function %4 +%23 = OpConstantNull %4 +%25 = OpTypePointer Function %3 +%26 = OpConstantNull %3 +%28 = OpConstantNull %4 +%18 = OpFunction %2 None %19 +%17 = OpLabel +%21 = OpVariable %22 Function %23 +%24 = OpVariable %25 Function %26 +%27 = OpVariable %22 Function %28 +OpBranch %29 +%29 = OpLabel +OpStore %21 %20 +OpStore %24 %5 +%30 = OpLoad %4 %15 +%31 = OpFMul %4 %30 %13 +OpStore %27 %31 OpReturn OpFunctionEnd \ No newline at end of file From dcef6c6f36a128b33d3a587970fb8ac553483a4f Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Sat, 23 Mar 2024 07:52:14 -0700 Subject: [PATCH 18/30] [naga] Doc tweaks for `back::pipeline_constants`. --- naga/src/back/pipeline_constants.rs | 39 ++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index bd9eec76ee..143afd8a57 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -188,9 +188,9 @@ pub(super) fn process_overrides<'a>( } let _ = mem::replace(&mut module.entry_points, entry_points); - // Now that the global expression arena has changed, we need to - // recompute those expressions' types. For the time being, do a - // full re-validation. + // Now that we've rewritten all the expressions, we need to + // recompute their types and other metadata. For the time being, + // do a full re-validation. let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all()); let module_info = validator.validate_no_overrides(&module)?; @@ -250,8 +250,15 @@ fn process_override( Ok(h) } -/// Replaces all `Expression::Override`s in this function's expression arena -/// with `Expression::Constant` and evaluates all expressions in its arena. +/// Replace all override expressions in `function` with fully-evaluated constants. +/// +/// Replace all `Expression::Override`s in `function`'s expression arena with +/// the corresponding `Expression::Constant`s, as given in `override_map`. +/// Replace any expressions whose values are now known with their fully +/// evaluated form. +/// +/// If `h` is a `Handle`, then `override_map[h.index()]` is the +/// `Handle` for the override's final value. fn process_function( module: &mut Module, override_map: &[Handle], @@ -298,6 +305,8 @@ fn process_function( let new_body = filter_emits_in_block(&function.body, &function.expressions); let _ = mem::replace(&mut function.body, new_body); + // We've changed the keys of `function.named_expression`, so we have to + // rebuild it from scratch. let named_expressions = mem::take(&mut function.named_expressions); for (expr_h, name) in named_expressions { function @@ -595,10 +604,22 @@ fn adjust_stmt(new_pos: &[Handle], stmt: &mut Statement) { } } -/// Filters out expressions that `needs_pre_emit`. This step is necessary after -/// const evaluation since unevaluated expressions could have been included in -/// `Statement::Emit`; but since they have been evaluated we need to filter those -/// out. +/// Adjust [`Emit`] statements in `block` to skip [`needs_pre_emit`] expressions we have introduced. +/// +/// According to validation, [`Emit`] statements must not cover any expressions +/// for which [`Expression::needs_pre_emit`] returns true. All expressions built +/// by successful constant evaluation fall into that category, meaning that +/// `process_function` will usually rewrite [`Override`] expressions and those +/// that use their values into pre-emitted expressions, leaving any [`Emit`] +/// statements that cover them invalid. +/// +/// This function rewrites all [`Emit`] statements into zero or more new +/// [`Emit`] statements covering only those expressions in the original range +/// that are not pre-emitted. +/// +/// [`Emit`]: Statement::Emit +/// [`needs_pre_emit`]: Expression::needs_pre_emit +/// [`Override`]: Expression::Override fn filter_emits_in_block(block: &Block, expressions: &Arena) -> Block { let mut out = Block::with_capacity(block.len()); for (stmt, span) in block.span_iter() { From df8c8cb01919849a77e328fde4077751dc25f249 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Sat, 23 Mar 2024 07:53:05 -0700 Subject: [PATCH 19/30] [naga] Simplify uses of `replace` in `back::pipeline_constants`. --- naga/src/back/pipeline_constants.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 143afd8a57..0cc5df5732 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -180,13 +180,13 @@ pub(super) fn process_overrides<'a>( for (_, function) in functions.iter_mut() { process_function(&mut module, &override_map, function)?; } - let _ = mem::replace(&mut module.functions, functions); + module.functions = functions; let mut entry_points = mem::take(&mut module.entry_points); for ep in entry_points.iter_mut() { process_function(&mut module, &override_map, &mut ep.function)?; } - let _ = mem::replace(&mut module.entry_points, entry_points); + module.entry_points = entry_points; // Now that we've rewritten all the expressions, we need to // recompute their types and other metadata. For the time being, @@ -303,7 +303,7 @@ fn process_function( adjust_block(&adjusted_local_expressions, &mut function.body); let new_body = filter_emits_in_block(&function.body, &function.expressions); - let _ = mem::replace(&mut function.body, new_body); + function.body = new_body; // We've changed the keys of `function.named_expression`, so we have to // rebuild it from scratch. From 106f56a136aea7e450e70cf32161643ad39688df Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Sat, 23 Mar 2024 07:53:39 -0700 Subject: [PATCH 20/30] [naga] Hoist `ConstantEvaluator` construction in `process_function`. There's no need to build a fresh `ConstantEvaluator` for every expression; just build it once and reuse it. --- naga/src/back/pipeline_constants.rs | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 0cc5df5732..b8789d3b93 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -282,18 +282,18 @@ fn process_function( let mut emitter = Emitter::default(); let mut block = Block::new(); - for (old_h, expr, span) in expressions.drain() { - let mut expr = match expr { - Expression::Override(h) => Expression::Constant(override_map[h.index()]), - expr => expr, - }; - let mut evaluator = ConstantEvaluator::for_wgsl_function( - module, - &mut function.expressions, - &mut local_expression_kind_tracker, - &mut emitter, - &mut block, - ); + let mut evaluator = ConstantEvaluator::for_wgsl_function( + module, + &mut function.expressions, + &mut local_expression_kind_tracker, + &mut emitter, + &mut block, + ); + + for (old_h, mut expr, span) in expressions.drain() { + if let Expression::Override(h) = expr { + expr = Expression::Constant(override_map[h.index()]); + } adjust_expr(&adjusted_local_expressions, &mut expr); let h = evaluator.try_eval_and_append(expr, span)?; debug_assert_eq!(old_h.index(), adjusted_local_expressions.len()); From ce772879dce27ee30c8b68a28b1d734fd059ebc1 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Sat, 23 Mar 2024 07:47:49 -0700 Subject: [PATCH 21/30] [naga] Let `filter_emits_with_block` operate on a `&mut Block`. This removes some clones and collects, simplifies call sites, and isn't any more complicated to implement. --- naga/src/back/pipeline_constants.rs | 76 +++++++++++++---------------- naga/src/block.rs | 6 +++ 2 files changed, 39 insertions(+), 43 deletions(-) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index b8789d3b93..62e3cd0e42 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -3,7 +3,7 @@ use crate::{ proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter}, valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator}, Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar, - Span, Statement, SwitchCase, TypeInner, WithSpan, + Span, Statement, TypeInner, WithSpan, }; use std::{borrow::Cow, collections::HashSet, mem}; use thiserror::Error; @@ -302,8 +302,7 @@ fn process_function( adjust_block(&adjusted_local_expressions, &mut function.body); - let new_body = filter_emits_in_block(&function.body, &function.expressions); - function.body = new_body; + filter_emits_in_block(&mut function.body, &function.expressions); // We've changed the keys of `function.named_expression`, so we have to // rebuild it from scratch. @@ -620,16 +619,16 @@ fn adjust_stmt(new_pos: &[Handle], stmt: &mut Statement) { /// [`Emit`]: Statement::Emit /// [`needs_pre_emit`]: Expression::needs_pre_emit /// [`Override`]: Expression::Override -fn filter_emits_in_block(block: &Block, expressions: &Arena) -> Block { - let mut out = Block::with_capacity(block.len()); - for (stmt, span) in block.span_iter() { +fn filter_emits_in_block(block: &mut Block, expressions: &Arena) { + let original = std::mem::replace(block, Block::with_capacity(block.len())); + for (stmt, span) in original.span_into_iter() { match stmt { - &Statement::Emit(ref range) => { + Statement::Emit(range) => { let mut current = None; - for expr_h in range.clone() { + for expr_h in range { if expressions[expr_h].needs_pre_emit() { if let Some((first, last)) = current { - out.push(Statement::Emit(Range::new_from_bounds(first, last)), *span); + block.push(Statement::Emit(Range::new_from_bounds(first, last)), span); } current = None; @@ -640,66 +639,57 @@ fn filter_emits_in_block(block: &Block, expressions: &Arena) -> Bloc } } if let Some((first, last)) = current { - out.push(Statement::Emit(Range::new_from_bounds(first, last)), *span); + block.push(Statement::Emit(Range::new_from_bounds(first, last)), span); } } - &Statement::Block(ref block) => { - let block = filter_emits_in_block(block, expressions); - out.push(Statement::Block(block), *span); + Statement::Block(mut child) => { + filter_emits_in_block(&mut child, expressions); + block.push(Statement::Block(child), span); } - &Statement::If { + Statement::If { condition, - ref accept, - ref reject, + mut accept, + mut reject, } => { - let accept = filter_emits_in_block(accept, expressions); - let reject = filter_emits_in_block(reject, expressions); - out.push( + filter_emits_in_block(&mut accept, expressions); + filter_emits_in_block(&mut reject, expressions); + block.push( Statement::If { condition, accept, reject, }, - *span, + span, ); } - &Statement::Switch { + Statement::Switch { selector, - ref cases, + mut cases, } => { - let cases = cases - .iter() - .map(|case| { - let body = filter_emits_in_block(&case.body, expressions); - SwitchCase { - value: case.value, - body, - fall_through: case.fall_through, - } - }) - .collect(); - out.push(Statement::Switch { selector, cases }, *span); + for case in &mut cases { + filter_emits_in_block(&mut case.body, expressions); + } + block.push(Statement::Switch { selector, cases }, span); } - &Statement::Loop { - ref body, - ref continuing, + Statement::Loop { + mut body, + mut continuing, break_if, } => { - let body = filter_emits_in_block(body, expressions); - let continuing = filter_emits_in_block(continuing, expressions); - out.push( + filter_emits_in_block(&mut body, expressions); + filter_emits_in_block(&mut continuing, expressions); + block.push( Statement::Loop { body, continuing, break_if, }, - *span, + span, ); } - stmt => out.push(stmt.clone(), *span), + stmt => block.push(stmt.clone(), span), } } - out } fn map_value_to_literal(value: f64, scalar: Scalar) -> Result { diff --git a/naga/src/block.rs b/naga/src/block.rs index 0abda9da7c..2e86a928f1 100644 --- a/naga/src/block.rs +++ b/naga/src/block.rs @@ -65,6 +65,12 @@ impl Block { self.span_info.splice(range.clone(), other.span_info); self.body.splice(range, other.body); } + + pub fn span_into_iter(self) -> impl Iterator { + let Block { body, span_info } = self; + body.into_iter().zip(span_info) + } + pub fn span_iter(&self) -> impl Iterator { let span_iter = self.span_info.iter(); self.body.iter().zip(span_iter) From 2c29ecb7ece5eeb5f9f369d1887ac77b30735212 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Sun, 24 Mar 2024 13:47:39 -0700 Subject: [PATCH 22/30] [naga] Tweak comments in `ConstantEvaluator::try_eval_and_append`. I found I needed a little bit more detail here. --- naga/src/proc/constant_evaluator.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index f1f01e5855..547fbbc652 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -713,7 +713,10 @@ impl<'a> ConstantEvaluator<'a> { match self.expression_kind_tracker.type_of_with_expr(&expr) { ExpressionKind::Const => { let eval_result = self.try_eval_and_append_impl(&expr, span); - // avoid errors on unimplemented functionality if possible + // We should be able to evaluate `Const` expressions at this + // point. If we failed to, then that probably means we just + // haven't implemented that part of constant evaluation. Work + // around this by simply emitting it as a run-time expression. if self.behavior.has_runtime_restrictions() && matches!( eval_result, From e62621b584ff272f3ff781b42cbb26665e6effb7 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Mon, 25 Mar 2024 13:31:00 -0700 Subject: [PATCH 23/30] [naga] Add missing newline to test input file. --- naga/tests/in/overrides.wgsl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/naga/tests/in/overrides.wgsl b/naga/tests/in/overrides.wgsl index ab1d637a11..6173c3463f 100644 --- a/naga/tests/in/overrides.wgsl +++ b/naga/tests/in/overrides.wgsl @@ -22,4 +22,4 @@ fn main() { var x = a; var gain_x_100 = gain_x_10 * 10.; -} \ No newline at end of file +} From ddd1222e161e8bef434dccac053569090fa9f83e Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Mon, 25 Mar 2024 18:19:17 -0700 Subject: [PATCH 24/30] [naga] Handle comparison operands in pipeline constant evaluation. Properly adjust `AtomicFunction::Exchange::compare` after pipeline constant evaluation. --- naga/src/back/pipeline_constants.rs | 17 ++- ...rrides-atomicCompareExchangeWeak.param.ron | 9 ++ .../overrides-atomicCompareExchangeWeak.wgsl | 7 + ...ides-atomicCompareExchangeWeak.compact.ron | 128 ++++++++++++++++++ .../overrides-atomicCompareExchangeWeak.ron | 128 ++++++++++++++++++ ...errides-atomicCompareExchangeWeak.f.spvasm | 52 +++++++ naga/tests/snapshots.rs | 4 + 7 files changed, 344 insertions(+), 1 deletion(-) create mode 100644 naga/tests/in/overrides-atomicCompareExchangeWeak.param.ron create mode 100644 naga/tests/in/overrides-atomicCompareExchangeWeak.wgsl create mode 100644 naga/tests/out/ir/overrides-atomicCompareExchangeWeak.compact.ron create mode 100644 naga/tests/out/ir/overrides-atomicCompareExchangeWeak.ron create mode 100644 naga/tests/out/spv/overrides-atomicCompareExchangeWeak.f.spvasm diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 62e3cd0e42..a7606f5bb7 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -571,11 +571,26 @@ fn adjust_stmt(new_pos: &[Handle], stmt: &mut Statement) { ref mut pointer, ref mut value, ref mut result, - .. + ref mut fun, } => { adjust(pointer); adjust(value); adjust(result); + match *fun { + crate::AtomicFunction::Exchange { + compare: Some(ref mut compare), + } => { + adjust(compare); + } + crate::AtomicFunction::Add + | crate::AtomicFunction::Subtract + | crate::AtomicFunction::And + | crate::AtomicFunction::ExclusiveOr + | crate::AtomicFunction::InclusiveOr + | crate::AtomicFunction::Min + | crate::AtomicFunction::Max + | crate::AtomicFunction::Exchange { compare: None } => {} + } } Statement::WorkGroupUniformLoad { ref mut pointer, diff --git a/naga/tests/in/overrides-atomicCompareExchangeWeak.param.ron b/naga/tests/in/overrides-atomicCompareExchangeWeak.param.ron new file mode 100644 index 0000000000..ff9c84ac61 --- /dev/null +++ b/naga/tests/in/overrides-atomicCompareExchangeWeak.param.ron @@ -0,0 +1,9 @@ +( + spv: ( + version: (1, 0), + separate_entry_points: true, + ), + pipeline_constants: { + "o": 2.0 + } +) diff --git a/naga/tests/in/overrides-atomicCompareExchangeWeak.wgsl b/naga/tests/in/overrides-atomicCompareExchangeWeak.wgsl new file mode 100644 index 0000000000..03376b5931 --- /dev/null +++ b/naga/tests/in/overrides-atomicCompareExchangeWeak.wgsl @@ -0,0 +1,7 @@ +override o: i32; +var a: atomic; + +@compute @workgroup_size(1) +fn f() { + atomicCompareExchangeWeak(&a, u32(o), 1u); +} diff --git a/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.compact.ron b/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.compact.ron new file mode 100644 index 0000000000..8c889382dd --- /dev/null +++ b/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.compact.ron @@ -0,0 +1,128 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Sint, + width: 4, + )), + ), + ( + name: None, + inner: Atomic(( + kind: Uint, + width: 4, + )), + ), + ( + name: None, + inner: Scalar(( + kind: Uint, + width: 4, + )), + ), + ( + name: None, + inner: Scalar(( + kind: Bool, + width: 1, + )), + ), + ( + name: Some("__atomic_compare_exchange_result"), + inner: Struct( + members: [ + ( + name: Some("old_value"), + ty: 3, + binding: None, + offset: 0, + ), + ( + name: Some("exchanged"), + ty: 4, + binding: None, + offset: 4, + ), + ], + span: 8, + ), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + predeclared_types: { + AtomicCompareExchangeWeakResult(( + kind: Uint, + width: 4, + )): 5, + }, + ), + constants: [], + overrides: [ + ( + name: Some("o"), + id: None, + ty: 1, + init: None, + ), + ], + global_variables: [ + ( + name: Some("a"), + space: WorkGroup, + binding: None, + ty: 2, + init: None, + ), + ], + global_expressions: [], + functions: [], + entry_points: [ + ( + name: "f", + stage: Compute, + early_depth_test: None, + workgroup_size: (1, 1, 1), + function: ( + name: Some("f"), + arguments: [], + result: None, + local_variables: [], + expressions: [ + GlobalVariable(1), + Override(1), + As( + expr: 2, + kind: Uint, + convert: Some(4), + ), + Literal(U32(1)), + AtomicResult( + ty: 5, + comparison: true, + ), + ], + named_expressions: {}, + body: [ + Emit(( + start: 2, + end: 3, + )), + Atomic( + pointer: 1, + fun: Exchange( + compare: Some(3), + ), + value: 4, + result: 5, + ), + Return( + value: None, + ), + ], + ), + ), + ], +) \ No newline at end of file diff --git a/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.ron b/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.ron new file mode 100644 index 0000000000..8c889382dd --- /dev/null +++ b/naga/tests/out/ir/overrides-atomicCompareExchangeWeak.ron @@ -0,0 +1,128 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Sint, + width: 4, + )), + ), + ( + name: None, + inner: Atomic(( + kind: Uint, + width: 4, + )), + ), + ( + name: None, + inner: Scalar(( + kind: Uint, + width: 4, + )), + ), + ( + name: None, + inner: Scalar(( + kind: Bool, + width: 1, + )), + ), + ( + name: Some("__atomic_compare_exchange_result"), + inner: Struct( + members: [ + ( + name: Some("old_value"), + ty: 3, + binding: None, + offset: 0, + ), + ( + name: Some("exchanged"), + ty: 4, + binding: None, + offset: 4, + ), + ], + span: 8, + ), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + predeclared_types: { + AtomicCompareExchangeWeakResult(( + kind: Uint, + width: 4, + )): 5, + }, + ), + constants: [], + overrides: [ + ( + name: Some("o"), + id: None, + ty: 1, + init: None, + ), + ], + global_variables: [ + ( + name: Some("a"), + space: WorkGroup, + binding: None, + ty: 2, + init: None, + ), + ], + global_expressions: [], + functions: [], + entry_points: [ + ( + name: "f", + stage: Compute, + early_depth_test: None, + workgroup_size: (1, 1, 1), + function: ( + name: Some("f"), + arguments: [], + result: None, + local_variables: [], + expressions: [ + GlobalVariable(1), + Override(1), + As( + expr: 2, + kind: Uint, + convert: Some(4), + ), + Literal(U32(1)), + AtomicResult( + ty: 5, + comparison: true, + ), + ], + named_expressions: {}, + body: [ + Emit(( + start: 2, + end: 3, + )), + Atomic( + pointer: 1, + fun: Exchange( + compare: Some(3), + ), + value: 4, + result: 5, + ), + Return( + value: None, + ), + ], + ), + ), + ], +) \ No newline at end of file diff --git a/naga/tests/out/spv/overrides-atomicCompareExchangeWeak.f.spvasm b/naga/tests/out/spv/overrides-atomicCompareExchangeWeak.f.spvasm new file mode 100644 index 0000000000..59c69ae1fc --- /dev/null +++ b/naga/tests/out/spv/overrides-atomicCompareExchangeWeak.f.spvasm @@ -0,0 +1,52 @@ +; SPIR-V +; Version: 1.0 +; Generator: rspirv +; Bound: 33 +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %11 "f" %18 +OpExecutionMode %11 LocalSize 1 1 1 +OpMemberDecorate %6 0 Offset 0 +OpMemberDecorate %6 1 Offset 4 +OpDecorate %18 BuiltIn LocalInvocationId +%2 = OpTypeVoid +%3 = OpTypeInt 32 1 +%4 = OpTypeInt 32 0 +%5 = OpTypeBool +%6 = OpTypeStruct %4 %5 +%7 = OpConstant %3 2 +%9 = OpTypePointer Workgroup %4 +%8 = OpVariable %9 Workgroup +%12 = OpTypeFunction %2 +%13 = OpConstant %4 2 +%14 = OpConstant %4 1 +%16 = OpConstantNull %4 +%17 = OpTypeVector %4 3 +%19 = OpTypePointer Input %17 +%18 = OpVariable %19 Input +%21 = OpConstantNull %17 +%22 = OpTypeVector %5 3 +%27 = OpConstant %4 264 +%30 = OpConstant %4 256 +%11 = OpFunction %2 None %12 +%10 = OpLabel +OpBranch %15 +%15 = OpLabel +%20 = OpLoad %17 %18 +%23 = OpIEqual %22 %20 %21 +%24 = OpAll %5 %23 +OpSelectionMerge %25 None +OpBranchConditional %24 %26 %25 +%26 = OpLabel +OpStore %8 %16 +OpBranch %25 +%25 = OpLabel +OpControlBarrier %13 %13 %27 +OpBranch %28 +%28 = OpLabel +%31 = OpAtomicCompareExchange %4 %8 %7 %30 %30 %14 %13 +%32 = OpIEqual %5 %31 %13 +%29 = OpCompositeConstruct %6 %31 %32 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index e2f6dff25f..151e8b3da3 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -853,6 +853,10 @@ fn convert_wgsl() { "overrides", Targets::IR | Targets::ANALYSIS | Targets::SPIRV | Targets::METAL | Targets::HLSL, ), + ( + "overrides-atomicCompareExchangeWeak", + Targets::IR | Targets::SPIRV, + ), ]; for &(name, targets) in inputs.iter() { From f015fb16bd40d9beeabd343199798079ec9c4b78 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Mon, 25 Mar 2024 18:29:11 -0700 Subject: [PATCH 25/30] [naga] Spell out members in adjust_expr. --- naga/src/back/pipeline_constants.rs | 61 ++++++++++++++++++++++------- 1 file changed, 46 insertions(+), 15 deletions(-) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index a7606f5bb7..d41eeedef2 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -324,7 +324,8 @@ fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { }; match *expr { Expression::Compose { - ref mut components, .. + ref mut components, + ty: _, } => { for c in components.iter_mut() { adjust(c); @@ -337,13 +338,23 @@ fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { adjust(base); adjust(index); } - Expression::AccessIndex { ref mut base, .. } => { + Expression::AccessIndex { + ref mut base, + index: _, + } => { adjust(base); } - Expression::Splat { ref mut value, .. } => { + Expression::Splat { + ref mut value, + size: _, + } => { adjust(value); } - Expression::Swizzle { ref mut vector, .. } => { + Expression::Swizzle { + ref mut vector, + size: _, + pattern: _, + } => { adjust(vector); } Expression::Load { ref mut pointer } => { @@ -357,7 +368,7 @@ fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { ref mut offset, ref mut level, ref mut depth_ref, - .. + gather: _, } => { adjust(image); adjust(sampler); @@ -416,16 +427,21 @@ fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { adjust(e); } } - _ => {} + crate::ImageQuery::NumLevels + | crate::ImageQuery::NumLayers + | crate::ImageQuery::NumSamples => {} } } - Expression::Unary { ref mut expr, .. } => { + Expression::Unary { + ref mut expr, + op: _, + } => { adjust(expr); } Expression::Binary { ref mut left, ref mut right, - .. + op: _, } => { adjust(left); adjust(right); @@ -439,11 +455,16 @@ fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { adjust(accept); adjust(reject); } - Expression::Derivative { ref mut expr, .. } => { + Expression::Derivative { + ref mut expr, + axis: _, + ctrl: _, + } => { adjust(expr); } Expression::Relational { - ref mut argument, .. + ref mut argument, + fun: _, } => { adjust(argument); } @@ -452,7 +473,7 @@ fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { ref mut arg1, ref mut arg2, ref mut arg3, - .. + fun: _, } => { adjust(arg); if let Some(e) = arg1.as_mut() { @@ -465,13 +486,20 @@ fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { adjust(e); } } - Expression::As { ref mut expr, .. } => { + Expression::As { + ref mut expr, + kind: _, + convert: _, + } => { adjust(expr); } Expression::ArrayLength(ref mut expr) => { adjust(expr); } - Expression::RayQueryGetIntersection { ref mut query, .. } => { + Expression::RayQueryGetIntersection { + ref mut query, + committed: _, + } => { adjust(query); } Expression::Literal(_) @@ -483,8 +511,11 @@ fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { | Expression::Constant(_) | Expression::Override(_) | Expression::ZeroValue(_) - | Expression::AtomicResult { .. } - | Expression::WorkGroupUniformLoadResult { .. } => {} + | Expression::AtomicResult { + ty: _, + comparison: _, + } + | Expression::WorkGroupUniformLoadResult { ty: _ } => {} } } From de50eb08b95a27d999eafe4c4b3ae8a1f384e6e6 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Mon, 25 Mar 2024 19:11:23 -0700 Subject: [PATCH 26/30] [naga] Adjust RayQuery statements in override processing. --- naga/src/back/pipeline_constants.rs | 20 +- naga/tests/in/overrides-ray-query.param.ron | 18 ++ naga/tests/in/overrides-ray-query.wgsl | 21 ++ .../out/ir/overrides-ray-query.compact.ron | 259 ++++++++++++++++++ naga/tests/out/ir/overrides-ray-query.ron | 259 ++++++++++++++++++ naga/tests/out/msl/overrides-ray-query.msl | 45 +++ .../out/spv/overrides-ray-query.main.spvasm | 77 ++++++ naga/tests/snapshots.rs | 5 + 8 files changed, 702 insertions(+), 2 deletions(-) create mode 100644 naga/tests/in/overrides-ray-query.param.ron create mode 100644 naga/tests/in/overrides-ray-query.wgsl create mode 100644 naga/tests/out/ir/overrides-ray-query.compact.ron create mode 100644 naga/tests/out/ir/overrides-ray-query.ron create mode 100644 naga/tests/out/msl/overrides-ray-query.msl create mode 100644 naga/tests/out/spv/overrides-ray-query.main.spvasm diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index d41eeedef2..c1fd2d02cc 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -633,7 +633,7 @@ fn adjust_stmt(new_pos: &[Handle], stmt: &mut Statement) { Statement::Call { ref mut arguments, ref mut result, - .. + function: _, } => { for argument in arguments.iter_mut() { adjust(argument); @@ -642,8 +642,24 @@ fn adjust_stmt(new_pos: &[Handle], stmt: &mut Statement) { adjust(e); } } - Statement::RayQuery { ref mut query, .. } => { + Statement::RayQuery { + ref mut query, + ref mut fun, + } => { adjust(query); + match *fun { + crate::RayQueryFunction::Initialize { + ref mut acceleration_structure, + ref mut descriptor, + } => { + adjust(acceleration_structure); + adjust(descriptor); + } + crate::RayQueryFunction::Proceed { ref mut result } => { + adjust(result); + } + crate::RayQueryFunction::Terminate => {} + } } Statement::Break | Statement::Continue | Statement::Kill | Statement::Barrier(_) => {} } diff --git a/naga/tests/in/overrides-ray-query.param.ron b/naga/tests/in/overrides-ray-query.param.ron new file mode 100644 index 0000000000..588656aaac --- /dev/null +++ b/naga/tests/in/overrides-ray-query.param.ron @@ -0,0 +1,18 @@ +( + god_mode: true, + spv: ( + version: (1, 4), + separate_entry_points: true, + ), + msl: ( + lang_version: (2, 4), + spirv_cross_compatibility: false, + fake_missing_bindings: true, + zero_initialize_workgroup_memory: false, + per_entry_point_map: {}, + inline_samplers: [], + ), + pipeline_constants: { + "o": 2.0 + } +) diff --git a/naga/tests/in/overrides-ray-query.wgsl b/naga/tests/in/overrides-ray-query.wgsl new file mode 100644 index 0000000000..dca7447ed0 --- /dev/null +++ b/naga/tests/in/overrides-ray-query.wgsl @@ -0,0 +1,21 @@ +override o: f32; + +@group(0) @binding(0) +var acc_struct: acceleration_structure; + +@compute @workgroup_size(1) +fn main() { + var rq: ray_query; + + let desc = RayDesc( + RAY_FLAG_TERMINATE_ON_FIRST_HIT, + 0xFFu, + o * 17.0, + o * 19.0, + vec3(o * 23.0), + vec3(o * 29.0, o * 31.0, o * 37.0), + ); + rayQueryInitialize(&rq, acc_struct, desc); + + while (rayQueryProceed(&rq)) {} +} diff --git a/naga/tests/out/ir/overrides-ray-query.compact.ron b/naga/tests/out/ir/overrides-ray-query.compact.ron new file mode 100644 index 0000000000..b127259bbb --- /dev/null +++ b/naga/tests/out/ir/overrides-ray-query.compact.ron @@ -0,0 +1,259 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ( + name: None, + inner: AccelerationStructure, + ), + ( + name: None, + inner: RayQuery, + ), + ( + name: None, + inner: Scalar(( + kind: Uint, + width: 4, + )), + ), + ( + name: None, + inner: Vector( + size: Tri, + scalar: ( + kind: Float, + width: 4, + ), + ), + ), + ( + name: Some("RayDesc"), + inner: Struct( + members: [ + ( + name: Some("flags"), + ty: 4, + binding: None, + offset: 0, + ), + ( + name: Some("cull_mask"), + ty: 4, + binding: None, + offset: 4, + ), + ( + name: Some("tmin"), + ty: 1, + binding: None, + offset: 8, + ), + ( + name: Some("tmax"), + ty: 1, + binding: None, + offset: 12, + ), + ( + name: Some("origin"), + ty: 5, + binding: None, + offset: 16, + ), + ( + name: Some("dir"), + ty: 5, + binding: None, + offset: 32, + ), + ], + span: 48, + ), + ), + ], + special_types: ( + ray_desc: Some(6), + ray_intersection: None, + predeclared_types: {}, + ), + constants: [], + overrides: [ + ( + name: Some("o"), + id: None, + ty: 1, + init: None, + ), + ], + global_variables: [ + ( + name: Some("acc_struct"), + space: Handle, + binding: Some(( + group: 0, + binding: 0, + )), + ty: 2, + init: None, + ), + ], + global_expressions: [], + functions: [], + entry_points: [ + ( + name: "main", + stage: Compute, + early_depth_test: None, + workgroup_size: (1, 1, 1), + function: ( + name: Some("main"), + arguments: [], + result: None, + local_variables: [ + ( + name: Some("rq"), + ty: 3, + init: None, + ), + ], + expressions: [ + LocalVariable(1), + Literal(U32(4)), + Literal(U32(255)), + Override(1), + Literal(F32(17.0)), + Binary( + op: Multiply, + left: 4, + right: 5, + ), + Override(1), + Literal(F32(19.0)), + Binary( + op: Multiply, + left: 7, + right: 8, + ), + Override(1), + Literal(F32(23.0)), + Binary( + op: Multiply, + left: 10, + right: 11, + ), + Splat( + size: Tri, + value: 12, + ), + Override(1), + Literal(F32(29.0)), + Binary( + op: Multiply, + left: 14, + right: 15, + ), + Override(1), + Literal(F32(31.0)), + Binary( + op: Multiply, + left: 17, + right: 18, + ), + Override(1), + Literal(F32(37.0)), + Binary( + op: Multiply, + left: 20, + right: 21, + ), + Compose( + ty: 5, + components: [ + 16, + 19, + 22, + ], + ), + Compose( + ty: 6, + components: [ + 2, + 3, + 6, + 9, + 13, + 23, + ], + ), + GlobalVariable(1), + RayQueryProceedResult, + ], + named_expressions: { + 24: "desc", + }, + body: [ + Emit(( + start: 5, + end: 6, + )), + Emit(( + start: 8, + end: 9, + )), + Emit(( + start: 11, + end: 13, + )), + Emit(( + start: 15, + end: 16, + )), + Emit(( + start: 18, + end: 19, + )), + Emit(( + start: 21, + end: 24, + )), + RayQuery( + query: 1, + fun: Initialize( + acceleration_structure: 25, + descriptor: 24, + ), + ), + Loop( + body: [ + RayQuery( + query: 1, + fun: Proceed( + result: 26, + ), + ), + If( + condition: 26, + accept: [], + reject: [ + Break, + ], + ), + Block([]), + ], + continuing: [], + break_if: None, + ), + Return( + value: None, + ), + ], + ), + ), + ], +) \ No newline at end of file diff --git a/naga/tests/out/ir/overrides-ray-query.ron b/naga/tests/out/ir/overrides-ray-query.ron new file mode 100644 index 0000000000..b127259bbb --- /dev/null +++ b/naga/tests/out/ir/overrides-ray-query.ron @@ -0,0 +1,259 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ( + name: None, + inner: AccelerationStructure, + ), + ( + name: None, + inner: RayQuery, + ), + ( + name: None, + inner: Scalar(( + kind: Uint, + width: 4, + )), + ), + ( + name: None, + inner: Vector( + size: Tri, + scalar: ( + kind: Float, + width: 4, + ), + ), + ), + ( + name: Some("RayDesc"), + inner: Struct( + members: [ + ( + name: Some("flags"), + ty: 4, + binding: None, + offset: 0, + ), + ( + name: Some("cull_mask"), + ty: 4, + binding: None, + offset: 4, + ), + ( + name: Some("tmin"), + ty: 1, + binding: None, + offset: 8, + ), + ( + name: Some("tmax"), + ty: 1, + binding: None, + offset: 12, + ), + ( + name: Some("origin"), + ty: 5, + binding: None, + offset: 16, + ), + ( + name: Some("dir"), + ty: 5, + binding: None, + offset: 32, + ), + ], + span: 48, + ), + ), + ], + special_types: ( + ray_desc: Some(6), + ray_intersection: None, + predeclared_types: {}, + ), + constants: [], + overrides: [ + ( + name: Some("o"), + id: None, + ty: 1, + init: None, + ), + ], + global_variables: [ + ( + name: Some("acc_struct"), + space: Handle, + binding: Some(( + group: 0, + binding: 0, + )), + ty: 2, + init: None, + ), + ], + global_expressions: [], + functions: [], + entry_points: [ + ( + name: "main", + stage: Compute, + early_depth_test: None, + workgroup_size: (1, 1, 1), + function: ( + name: Some("main"), + arguments: [], + result: None, + local_variables: [ + ( + name: Some("rq"), + ty: 3, + init: None, + ), + ], + expressions: [ + LocalVariable(1), + Literal(U32(4)), + Literal(U32(255)), + Override(1), + Literal(F32(17.0)), + Binary( + op: Multiply, + left: 4, + right: 5, + ), + Override(1), + Literal(F32(19.0)), + Binary( + op: Multiply, + left: 7, + right: 8, + ), + Override(1), + Literal(F32(23.0)), + Binary( + op: Multiply, + left: 10, + right: 11, + ), + Splat( + size: Tri, + value: 12, + ), + Override(1), + Literal(F32(29.0)), + Binary( + op: Multiply, + left: 14, + right: 15, + ), + Override(1), + Literal(F32(31.0)), + Binary( + op: Multiply, + left: 17, + right: 18, + ), + Override(1), + Literal(F32(37.0)), + Binary( + op: Multiply, + left: 20, + right: 21, + ), + Compose( + ty: 5, + components: [ + 16, + 19, + 22, + ], + ), + Compose( + ty: 6, + components: [ + 2, + 3, + 6, + 9, + 13, + 23, + ], + ), + GlobalVariable(1), + RayQueryProceedResult, + ], + named_expressions: { + 24: "desc", + }, + body: [ + Emit(( + start: 5, + end: 6, + )), + Emit(( + start: 8, + end: 9, + )), + Emit(( + start: 11, + end: 13, + )), + Emit(( + start: 15, + end: 16, + )), + Emit(( + start: 18, + end: 19, + )), + Emit(( + start: 21, + end: 24, + )), + RayQuery( + query: 1, + fun: Initialize( + acceleration_structure: 25, + descriptor: 24, + ), + ), + Loop( + body: [ + RayQuery( + query: 1, + fun: Proceed( + result: 26, + ), + ), + If( + condition: 26, + accept: [], + reject: [ + Break, + ], + ), + Block([]), + ], + continuing: [], + break_if: None, + ), + Return( + value: None, + ), + ], + ), + ), + ], +) \ No newline at end of file diff --git a/naga/tests/out/msl/overrides-ray-query.msl b/naga/tests/out/msl/overrides-ray-query.msl new file mode 100644 index 0000000000..3a508b6f61 --- /dev/null +++ b/naga/tests/out/msl/overrides-ray-query.msl @@ -0,0 +1,45 @@ +// language: metal2.4 +#include +#include + +using metal::uint; +struct _RayQuery { + metal::raytracing::intersector intersector; + metal::raytracing::intersector::result_type intersection; + bool ready = false; +}; +constexpr metal::uint _map_intersection_type(const metal::raytracing::intersection_type ty) { + return ty==metal::raytracing::intersection_type::triangle ? 1 : + ty==metal::raytracing::intersection_type::bounding_box ? 4 : 0; +} + +struct RayDesc { + uint flags; + uint cull_mask; + float tmin; + float tmax; + metal::float3 origin; + metal::float3 dir; +}; +constant float o = 2.0; + +kernel void main_( + metal::raytracing::instance_acceleration_structure acc_struct [[user(fake0)]] +) { + _RayQuery rq = {}; + RayDesc desc = RayDesc {4u, 255u, 34.0, 38.0, metal::float3(46.0), metal::float3(58.0, 62.0, 74.0)}; + rq.intersector.assume_geometry_type(metal::raytracing::geometry_type::triangle); + rq.intersector.set_opacity_cull_mode((desc.flags & 64) != 0 ? metal::raytracing::opacity_cull_mode::opaque : (desc.flags & 128) != 0 ? metal::raytracing::opacity_cull_mode::non_opaque : metal::raytracing::opacity_cull_mode::none); + rq.intersector.force_opacity((desc.flags & 1) != 0 ? metal::raytracing::forced_opacity::opaque : (desc.flags & 2) != 0 ? metal::raytracing::forced_opacity::non_opaque : metal::raytracing::forced_opacity::none); + rq.intersector.accept_any_intersection((desc.flags & 4) != 0); + rq.intersection = rq.intersector.intersect(metal::raytracing::ray(desc.origin, desc.dir, desc.tmin, desc.tmax), acc_struct, desc.cull_mask); rq.ready = true; + while(true) { + bool _e31 = rq.ready; + rq.ready = false; + if (_e31) { + } else { + break; + } + } + return; +} diff --git a/naga/tests/out/spv/overrides-ray-query.main.spvasm b/naga/tests/out/spv/overrides-ray-query.main.spvasm new file mode 100644 index 0000000000..a341393468 --- /dev/null +++ b/naga/tests/out/spv/overrides-ray-query.main.spvasm @@ -0,0 +1,77 @@ +; SPIR-V +; Version: 1.4 +; Generator: rspirv +; Bound: 46 +OpCapability Shader +OpCapability RayQueryKHR +OpExtension "SPV_KHR_ray_query" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %13 "main" %10 +OpExecutionMode %13 LocalSize 1 1 1 +OpMemberDecorate %8 0 Offset 0 +OpMemberDecorate %8 1 Offset 4 +OpMemberDecorate %8 2 Offset 8 +OpMemberDecorate %8 3 Offset 12 +OpMemberDecorate %8 4 Offset 16 +OpMemberDecorate %8 5 Offset 32 +OpDecorate %10 DescriptorSet 0 +OpDecorate %10 Binding 0 +%2 = OpTypeVoid +%3 = OpTypeFloat 32 +%4 = OpTypeAccelerationStructureNV +%5 = OpTypeRayQueryKHR +%6 = OpTypeInt 32 0 +%7 = OpTypeVector %3 3 +%8 = OpTypeStruct %6 %6 %3 %3 %7 %7 +%9 = OpConstant %3 2.0 +%11 = OpTypePointer UniformConstant %4 +%10 = OpVariable %11 UniformConstant +%14 = OpTypeFunction %2 +%16 = OpConstant %6 4 +%17 = OpConstant %6 255 +%18 = OpConstant %3 34.0 +%19 = OpConstant %3 38.0 +%20 = OpConstant %3 46.0 +%21 = OpConstantComposite %7 %20 %20 %20 +%22 = OpConstant %3 58.0 +%23 = OpConstant %3 62.0 +%24 = OpConstant %3 74.0 +%25 = OpConstantComposite %7 %22 %23 %24 +%26 = OpConstantComposite %8 %16 %17 %18 %19 %21 %25 +%28 = OpTypePointer Function %5 +%41 = OpTypeBool +%13 = OpFunction %2 None %14 +%12 = OpLabel +%27 = OpVariable %28 Function +%15 = OpLoad %4 %10 +OpBranch %29 +%29 = OpLabel +%30 = OpCompositeExtract %6 %26 0 +%31 = OpCompositeExtract %6 %26 1 +%32 = OpCompositeExtract %3 %26 2 +%33 = OpCompositeExtract %3 %26 3 +%34 = OpCompositeExtract %7 %26 4 +%35 = OpCompositeExtract %7 %26 5 +OpRayQueryInitializeKHR %27 %15 %30 %31 %34 %32 %35 %33 +OpBranch %36 +%36 = OpLabel +OpLoopMerge %37 %39 None +OpBranch %38 +%38 = OpLabel +%40 = OpRayQueryProceedKHR %41 %27 +OpSelectionMerge %42 None +OpBranchConditional %40 %42 %43 +%43 = OpLabel +OpBranch %37 +%42 = OpLabel +OpBranch %44 +%44 = OpLabel +OpBranch %45 +%45 = OpLabel +OpBranch %39 +%39 = OpLabel +OpBranch %36 +%37 = OpLabel +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index 151e8b3da3..94c50c7975 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -466,6 +466,7 @@ fn write_output_spv( ); } } else { + assert!(pipeline_constants.is_empty()); write_output_spv_inner(input, module, info, &options, None, "spvasm"); } } @@ -857,6 +858,10 @@ fn convert_wgsl() { "overrides-atomicCompareExchangeWeak", Targets::IR | Targets::SPIRV, ), + ( + "overrides-ray-query", + Targets::IR | Targets::SPIRV | Targets::METAL, + ), ]; for &(name, targets) in inputs.iter() { From b9403c1ecd401d3302aabb415caf6b2b73a24c84 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Wed, 20 Mar 2024 17:11:38 -0400 Subject: [PATCH 27/30] [naga-cli] Add `--override` option. --- naga-cli/src/bin/naga.rs | 37 +++++++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/naga-cli/src/bin/naga.rs b/naga-cli/src/bin/naga.rs index a20611114b..58293d69fa 100644 --- a/naga-cli/src/bin/naga.rs +++ b/naga-cli/src/bin/naga.rs @@ -105,6 +105,10 @@ struct Args { #[argh(switch)] version: bool, + /// override value, of the form "foo=N,bar=M", repeatable + #[argh(option, long = "override")] + overrides: Vec, + /// the input and output files. /// /// First positional argument is the input file. If not specified, the @@ -202,12 +206,34 @@ impl FromStr for MslVersionArg { } } +#[derive(Clone, Debug)] +struct Overrides { + pairs: Vec<(String, f64)>, +} + +impl FromStr for Overrides { + type Err = String; + + fn from_str(s: &str) -> Result { + let mut pairs = vec![]; + for pair in s.split(',') { + let Some((name, value)) = pair.split_once('=') else { + return Err(format!("value needs a `=`: {pair:?}")); + }; + let value = f64::from_str(value.trim()).map_err(|err| format!("{err}: {value:?}"))?; + pairs.push((name.trim().to_string(), value)); + } + Ok(Overrides { pairs }) + } +} + #[derive(Default)] struct Parameters<'a> { validation_flags: naga::valid::ValidationFlags, bounds_check_policies: naga::proc::BoundsCheckPolicies, entry_point: Option, keep_coordinate_space: bool, + overrides: naga::back::PipelineConstants, spv_in: naga::front::spv::Options, spv_out: naga::back::spv::Options<'a>, dot: naga::back::dot::Options, @@ -301,7 +327,12 @@ fn run() -> Result<(), Box> { Some(arg) => arg.0, None => params.bounds_check_policies.index, }; - + params.overrides = args + .overrides + .iter() + .flat_map(|o| &o.pairs) + .cloned() + .collect(); params.spv_in = naga::front::spv::Options { adjust_coordinate_space: !args.keep_coordinate_space, strict_capabilities: false, @@ -670,7 +701,9 @@ fn write_output( "Generating hlsl output requires validation to \ succeed, and it failed in a previous step", ))?, - &hlsl::PipelineOptions::default(), + &hlsl::PipelineOptions { + constants: params.overrides.clone(), + }, ) .unwrap_pretty(); fs::write(output_path, buffer)?; From 9c3fc6cf867d025bf454cd5c6543b2c46a045533 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Thu, 4 Apr 2024 11:58:28 +0200 Subject: [PATCH 28/30] move the burden of evaluating override-expressions to users of naga's API --- naga-cli/src/bin/naga.rs | 83 ++++++++++--------- naga/benches/criterion.rs | 5 +- naga/src/back/glsl/mod.rs | 10 +-- naga/src/back/hlsl/mod.rs | 12 +-- naga/src/back/hlsl/writer.rs | 18 ++-- naga/src/back/mod.rs | 2 +- naga/src/back/msl/mod.rs | 6 +- naga/src/back/msl/writer.rs | 12 +-- naga/src/back/pipeline_constants.rs | 2 +- naga/src/back/spv/block.rs | 4 +- naga/src/back/spv/mod.rs | 6 +- naga/src/back/spv/writer.rs | 18 +--- naga/src/back/wgsl/writer.rs | 4 +- naga/tests/in/interface.param.ron | 1 - .../out/glsl/overrides.main.Compute.glsl | 29 +++++++ naga/tests/snapshots.rs | 57 +++++++------ wgpu-hal/src/dx12/device.rs | 14 ++-- wgpu-hal/src/gles/device.rs | 23 +++-- wgpu-hal/src/metal/device.rs | 21 ++--- wgpu-hal/src/vulkan/device.rs | 16 ++-- 20 files changed, 175 insertions(+), 168 deletions(-) create mode 100644 naga/tests/out/glsl/overrides.main.Compute.glsl diff --git a/naga-cli/src/bin/naga.rs b/naga-cli/src/bin/naga.rs index 58293d69fa..36ca1e99ae 100644 --- a/naga-cli/src/bin/naga.rs +++ b/naga-cli/src/bin/naga.rs @@ -597,17 +597,18 @@ fn write_output( let mut options = params.msl.clone(); options.bounds_check_policies = params.bounds_check_policies; + let info = info.as_ref().ok_or(CliError( + "Generating metal output requires validation to \ + succeed, and it failed in a previous step", + ))?; + + let (module, info) = + naga::back::pipeline_constants::process_overrides(module, info, ¶ms.overrides) + .unwrap_pretty(); + let pipeline_options = msl::PipelineOptions::default(); - let (msl, _) = msl::write_string( - module, - info.as_ref().ok_or(CliError( - "Generating metal output requires validation to \ - succeed, and it failed in a previous step", - ))?, - &options, - &pipeline_options, - ) - .unwrap_pretty(); + let (msl, _) = + msl::write_string(&module, &info, &options, &pipeline_options).unwrap_pretty(); fs::write(output_path, msl)?; } "spv" => { @@ -624,23 +625,23 @@ fn write_output( pipeline_options_owned = spv::PipelineOptions { entry_point: name.clone(), shader_stage: module.entry_points[ep_index].stage, - constants: naga::back::PipelineConstants::default(), }; Some(&pipeline_options_owned) } None => None, }; - let spv = spv::write_vec( - module, - info.as_ref().ok_or(CliError( - "Generating SPIR-V output requires validation to \ - succeed, and it failed in a previous step", - ))?, - ¶ms.spv_out, - pipeline_options, - ) - .unwrap_pretty(); + let info = info.as_ref().ok_or(CliError( + "Generating SPIR-V output requires validation to \ + succeed, and it failed in a previous step", + ))?; + + let (module, info) = + naga::back::pipeline_constants::process_overrides(module, info, ¶ms.overrides) + .unwrap_pretty(); + + let spv = + spv::write_vec(&module, &info, ¶ms.spv_out, pipeline_options).unwrap_pretty(); let bytes = spv .iter() .fold(Vec::with_capacity(spv.len() * 4), |mut v, w| { @@ -665,17 +666,22 @@ fn write_output( _ => unreachable!(), }, multiview: None, - constants: naga::back::PipelineConstants::default(), }; + let info = info.as_ref().ok_or(CliError( + "Generating glsl output requires validation to \ + succeed, and it failed in a previous step", + ))?; + + let (module, info) = + naga::back::pipeline_constants::process_overrides(module, info, ¶ms.overrides) + .unwrap_pretty(); + let mut buffer = String::new(); let mut writer = glsl::Writer::new( &mut buffer, - module, - info.as_ref().ok_or(CliError( - "Generating glsl output requires validation to \ - succeed, and it failed in a previous step", - ))?, + &module, + &info, ¶ms.glsl, &pipeline_options, params.bounds_check_policies, @@ -692,20 +698,19 @@ fn write_output( } "hlsl" => { use naga::back::hlsl; + + let info = info.as_ref().ok_or(CliError( + "Generating hlsl output requires validation to \ + succeed, and it failed in a previous step", + ))?; + + let (module, info) = + naga::back::pipeline_constants::process_overrides(module, info, ¶ms.overrides) + .unwrap_pretty(); + let mut buffer = String::new(); let mut writer = hlsl::Writer::new(&mut buffer, ¶ms.hlsl); - writer - .write( - module, - info.as_ref().ok_or(CliError( - "Generating hlsl output requires validation to \ - succeed, and it failed in a previous step", - ))?, - &hlsl::PipelineOptions { - constants: params.overrides.clone(), - }, - ) - .unwrap_pretty(); + writer.write(&module, &info).unwrap_pretty(); fs::write(output_path, buffer)?; } "wgsl" => { diff --git a/naga/benches/criterion.rs b/naga/benches/criterion.rs index 420c9ee335..e57c58a847 100644 --- a/naga/benches/criterion.rs +++ b/naga/benches/criterion.rs @@ -193,7 +193,6 @@ fn backends(c: &mut Criterion) { let pipeline_options = naga::back::spv::PipelineOptions { shader_stage: ep.stage, entry_point: ep.name.clone(), - constants: naga::back::PipelineConstants::default(), }; writer .write(module, info, Some(&pipeline_options), &None, &mut data) @@ -224,11 +223,10 @@ fn backends(c: &mut Criterion) { group.bench_function("hlsl", |b| { b.iter(|| { let options = naga::back::hlsl::Options::default(); - let pipeline_options = naga::back::hlsl::PipelineOptions::default(); let mut string = String::new(); for &(ref module, ref info) in inputs.iter() { let mut writer = naga::back::hlsl::Writer::new(&mut string, &options); - let _ = writer.write(module, info, &pipeline_options); // may fail on unimplemented things + let _ = writer.write(module, info); // may fail on unimplemented things string.clear(); } }); @@ -250,7 +248,6 @@ fn backends(c: &mut Criterion) { shader_stage: ep.stage, entry_point: ep.name.clone(), multiview: None, - constants: naga::back::PipelineConstants::default(), }; // might be `Err` if missing features diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 13811a2df0..bede79610a 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -294,8 +294,6 @@ pub struct PipelineOptions { pub entry_point: String, /// How many views to render to, if doing multiview rendering. pub multiview: Option, - /// Pipeline constants. - pub constants: back::PipelineConstants, } #[derive(Debug)] @@ -499,6 +497,8 @@ pub enum Error { ImageMultipleSamplers, #[error("{0}")] Custom(String), + #[error("overrides should not be present at this stage")] + Override, } /// Binary operation with a different logic on the GLSL side. @@ -568,9 +568,7 @@ impl<'a, W: Write> Writer<'a, W> { policies: proc::BoundsCheckPolicies, ) -> Result { if !module.overrides.is_empty() { - return Err(Error::Custom( - "Pipeline constants are not yet supported for this back-end".to_string(), - )); + return Err(Error::Override); } // Check if the requested version is supported @@ -2544,7 +2542,7 @@ impl<'a, W: Write> Writer<'a, W> { |writer, expr| writer.write_expr(expr, ctx), )?; } - Expression::Override(_) => return Err(Error::Custom("overrides are WIP".into())), + Expression::Override(_) => return Err(Error::Override), // `Access` is applied to arrays, vectors and matrices and is written as indexing Expression::Access { base, index } => { self.write_expr(base, ctx)?; diff --git a/naga/src/back/hlsl/mod.rs b/naga/src/back/hlsl/mod.rs index d423b003ff..392dc2c34a 100644 --- a/naga/src/back/hlsl/mod.rs +++ b/naga/src/back/hlsl/mod.rs @@ -195,14 +195,6 @@ pub struct Options { pub zero_initialize_workgroup_memory: bool, } -#[derive(Clone, Debug, Default)] -#[cfg_attr(feature = "serialize", derive(serde::Serialize))] -#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] -pub struct PipelineOptions { - /// Pipeline constants. - pub constants: back::PipelineConstants, -} - impl Default for Options { fn default() -> Self { Options { @@ -255,8 +247,8 @@ pub enum Error { Unimplemented(String), // TODO: Error used only during development #[error("{0}")] Custom(String), - #[error(transparent)] - PipelineConstant(#[from] Box), + #[error("overrides should not be present at this stage")] + Override, } #[derive(Default)] diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 4bde1f6486..d4c6097eb3 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -1,7 +1,7 @@ use super::{ help::{WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess}, storage::StoreValue, - BackendResult, Error, Options, PipelineOptions, + BackendResult, Error, Options, }; use crate::{ back, @@ -167,16 +167,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { &mut self, module: &Module, module_info: &valid::ModuleInfo, - pipeline_options: &PipelineOptions, ) -> Result { - let (module, module_info) = back::pipeline_constants::process_overrides( - module, - module_info, - &pipeline_options.constants, - ) - .map_err(Box::new)?; - let module = module.as_ref(); - let module_info = module_info.as_ref(); + if !module.overrides.is_empty() { + return Err(Error::Override); + } self.reset(module); @@ -2150,9 +2144,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { |writer, expr| writer.write_expr(module, expr, func_ctx), )?; } - Expression::Override(_) => { - return Err(Error::Unimplemented("overrides are WIP".into())) - } + Expression::Override(_) => return Err(Error::Override), // All of the multiplication can be expressed as `mul`, // except vector * vector, which needs to use the "*" operator. Expression::Binary { diff --git a/naga/src/back/mod.rs b/naga/src/back/mod.rs index a95328d4fa..0c9c5e4761 100644 --- a/naga/src/back/mod.rs +++ b/naga/src/back/mod.rs @@ -22,7 +22,7 @@ pub mod wgsl; feature = "spv-out", feature = "glsl-out" ))] -mod pipeline_constants; +pub mod pipeline_constants; /// Names of vector components. pub const COMPONENTS: &[char] = &['x', 'y', 'z', 'w']; diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index 6ba8227a20..2c7cdea6af 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -143,8 +143,8 @@ pub enum Error { UnsupportedArrayOfType(Handle), #[error("ray tracing is not supported prior to MSL 2.3")] UnsupportedRayTracing, - #[error(transparent)] - PipelineConstant(#[from] Box), + #[error("overrides should not be present at this stage")] + Override, } #[derive(Clone, Debug, PartialEq, thiserror::Error)] @@ -234,8 +234,6 @@ pub struct PipelineOptions { /// /// Enable this for vertex shaders with point primitive topologies. pub allow_and_force_point_size: bool, - /// Pipeline constants. - pub constants: crate::back::PipelineConstants, } impl Options { diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 7797bc658f..0d0f651665 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -1431,9 +1431,7 @@ impl Writer { |writer, context, expr| writer.put_expression(expr, context, true), )?; } - crate::Expression::Override(_) => { - return Err(Error::FeatureNotImplemented("overrides are WIP".into())) - } + crate::Expression::Override(_) => return Err(Error::Override), crate::Expression::Access { base, .. } | crate::Expression::AccessIndex { base, .. } => { // This is an acceptable place to generate a `ReadZeroSkipWrite` check. @@ -3223,11 +3221,9 @@ impl Writer { options: &Options, pipeline_options: &PipelineOptions, ) -> Result { - let (module, info) = - back::pipeline_constants::process_overrides(module, info, &pipeline_options.constants) - .map_err(Box::new)?; - let module = module.as_ref(); - let info = info.as_ref(); + if !module.overrides.is_empty() { + return Err(Error::Override); + } self.names.clear(); self.namer.reset( diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index c1fd2d02cc..be5760a733 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -36,7 +36,7 @@ pub enum PipelineConstantError { /// fully-evaluated expressions. /// /// [`global_expressions`]: Module::global_expressions -pub(super) fn process_overrides<'a>( +pub fn process_overrides<'a>( module: &'a Module, module_info: &'a ModuleInfo, pipeline_constants: &PipelineConstants, diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index dcec24d7d6..9b8430e861 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -239,9 +239,7 @@ impl<'w> BlockContext<'w> { let init = self.ir_module.constants[handle].init; self.writer.constant_ids[init.index()] } - crate::Expression::Override(_) => { - return Err(Error::FeatureNotImplemented("overrides are WIP")) - } + crate::Expression::Override(_) => return Err(Error::Override), crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id), crate::Expression::Compose { ty, ref components } => { self.temp_list.clear(); diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index f1bbaecce1..8626bb104d 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -70,8 +70,8 @@ pub enum Error { FeatureNotImplemented(&'static str), #[error("module is not validated properly: {0}")] Validation(&'static str), - #[error(transparent)] - PipelineConstant(#[from] Box), + #[error("overrides should not be present at this stage")] + Override, } #[derive(Default)] @@ -773,8 +773,6 @@ pub struct PipelineOptions { /// /// If no entry point that matches is found while creating a [`Writer`], a error will be thrown. pub entry_point: String, - /// Pipeline constants. - pub constants: crate::back::PipelineConstants, } pub fn write_vec( diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 0fc0227fb7..ef65ac7dad 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -2029,21 +2029,9 @@ impl Writer { debug_info: &Option, words: &mut Vec, ) -> Result<(), Error> { - let (ir_module, info) = if let Some(pipeline_options) = pipeline_options { - crate::back::pipeline_constants::process_overrides( - ir_module, - info, - &pipeline_options.constants, - ) - .map_err(Box::new)? - } else { - ( - std::borrow::Cow::Borrowed(ir_module), - std::borrow::Cow::Borrowed(info), - ) - }; - let ir_module = ir_module.as_ref(); - let info = info.as_ref(); + if !ir_module.overrides.is_empty() { + return Err(Error::Override); + } self.reset(); diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 8005a27617..b63e16da3b 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -1205,9 +1205,7 @@ impl Writer { |writer, expr| writer.write_expr(module, expr, func_ctx), )?; } - Expression::Override(_) => { - return Err(Error::Unimplemented("overrides are WIP".into())) - } + Expression::Override(_) => unreachable!(), Expression::FunctionArgument(pos) => { let name_key = func_ctx.argument_key(pos); let name = &self.names[&name_key]; diff --git a/naga/tests/in/interface.param.ron b/naga/tests/in/interface.param.ron index 19ed5e464c..4d85661767 100644 --- a/naga/tests/in/interface.param.ron +++ b/naga/tests/in/interface.param.ron @@ -27,6 +27,5 @@ ), msl_pipeline: ( allow_and_force_point_size: true, - constants: {}, ), ) diff --git a/naga/tests/out/glsl/overrides.main.Compute.glsl b/naga/tests/out/glsl/overrides.main.Compute.glsl new file mode 100644 index 0000000000..a4e4b004bb --- /dev/null +++ b/naga/tests/out/glsl/overrides.main.Compute.glsl @@ -0,0 +1,29 @@ +#version 310 es + +precision highp float; +precision highp int; + +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + +const bool has_point_light = false; +const float specular_param = 2.3; +const float gain = 1.1; +const float width = 0.0; +const float depth = 2.3; +const float height = 4.6; +const float inferred_f32_ = 2.718; + +float gain_x_10_ = 11.0; + + +void main() { + float t = 0.0; + bool x = false; + float gain_x_100_ = 0.0; + t = 23.0; + x = true; + float _e10 = gain_x_10_; + gain_x_100_ = (_e10 * 10.0); + return; +} + diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index 94c50c7975..3b12f192ab 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -349,19 +349,14 @@ fn check_targets( #[cfg(all(feature = "deserialize", feature = "msl-out"))] { if targets.contains(Targets::METAL) { - if !params.msl_pipeline.constants.is_empty() { - panic!("Supply pipeline constants via pipeline_constants instead of msl_pipeline.constants!"); - } - let mut pipeline_options = params.msl_pipeline.clone(); - pipeline_options.constants = params.pipeline_constants.clone(); - write_output_msl( input, module, &info, ¶ms.msl, - &pipeline_options, + ¶ms.msl_pipeline, params.bounds_check_policies, + ¶ms.pipeline_constants, ); } } @@ -449,25 +444,27 @@ fn write_output_spv( debug_info, }; + let (module, info) = + naga::back::pipeline_constants::process_overrides(module, info, pipeline_constants) + .expect("override evaluation failed"); + if params.separate_entry_points { for ep in module.entry_points.iter() { let pipeline_options = spv::PipelineOptions { entry_point: ep.name.clone(), shader_stage: ep.stage, - constants: pipeline_constants.clone(), }; write_output_spv_inner( input, - module, - info, + &module, + &info, &options, Some(&pipeline_options), &format!("{}.spvasm", ep.name), ); } } else { - assert!(pipeline_constants.is_empty()); - write_output_spv_inner(input, module, info, &options, None, "spvasm"); + write_output_spv_inner(input, &module, &info, &options, None, "spvasm"); } } @@ -505,14 +502,19 @@ fn write_output_msl( options: &naga::back::msl::Options, pipeline_options: &naga::back::msl::PipelineOptions, bounds_check_policies: naga::proc::BoundsCheckPolicies, + pipeline_constants: &naga::back::PipelineConstants, ) { use naga::back::msl; println!("generating MSL"); + let (module, info) = + naga::back::pipeline_constants::process_overrides(module, info, pipeline_constants) + .expect("override evaluation failed"); + let mut options = options.clone(); options.bounds_check_policies = bounds_check_policies; - let (string, tr_info) = msl::write_string(module, info, &options, pipeline_options) + let (string, tr_info) = msl::write_string(&module, &info, &options, pipeline_options) .unwrap_or_else(|err| panic!("Metal write failed: {err}")); for (ep, result) in module.entry_points.iter().zip(tr_info.entry_point_names) { @@ -545,14 +547,16 @@ fn write_output_glsl( shader_stage: stage, entry_point: ep_name.to_string(), multiview, - constants: pipeline_constants.clone(), }; let mut buffer = String::new(); + let (module, info) = + naga::back::pipeline_constants::process_overrides(module, info, pipeline_constants) + .expect("override evaluation failed"); let mut writer = glsl::Writer::new( &mut buffer, - module, - info, + &module, + &info, options, &pipeline_options, bounds_check_policies, @@ -577,17 +581,13 @@ fn write_output_hlsl( println!("generating HLSL"); + let (module, info) = + naga::back::pipeline_constants::process_overrides(module, info, pipeline_constants) + .expect("override evaluation failed"); + let mut buffer = String::new(); let mut writer = hlsl::Writer::new(&mut buffer, options); - let reflection_info = writer - .write( - module, - info, - &hlsl::PipelineOptions { - constants: pipeline_constants.clone(), - }, - ) - .expect("HLSL write failed"); + let reflection_info = writer.write(&module, &info).expect("HLSL write failed"); input.write_output_file("hlsl", "hlsl", buffer); @@ -852,7 +852,12 @@ fn convert_wgsl() { ), ( "overrides", - Targets::IR | Targets::ANALYSIS | Targets::SPIRV | Targets::METAL | Targets::HLSL, + Targets::IR + | Targets::ANALYSIS + | Targets::SPIRV + | Targets::METAL + | Targets::HLSL + | Targets::GLSL, ), ( "overrides-atomicCompareExchangeWeak", diff --git a/wgpu-hal/src/dx12/device.rs b/wgpu-hal/src/dx12/device.rs index 69a846d131..153dd6b90d 100644 --- a/wgpu-hal/src/dx12/device.rs +++ b/wgpu-hal/src/dx12/device.rs @@ -218,17 +218,21 @@ impl super::Device { use naga::back::hlsl; let stage_bit = crate::auxil::map_naga_stage(naga_stage); - let module = &stage.module.naga.module; + + let (module, info) = naga::back::pipeline_constants::process_overrides( + &stage.module.naga.module, + &stage.module.naga.info, + stage.constants, + ) + .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}")))?; + //TODO: reuse the writer let mut source = String::new(); let mut writer = hlsl::Writer::new(&mut source, &layout.naga_options); - let pipeline_options = hlsl::PipelineOptions { - constants: stage.constants.to_owned(), - }; let reflection_info = { profiling::scope!("naga::back::hlsl::write"); writer - .write(module, &stage.module.naga.info, &pipeline_options) + .write(&module, &info) .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}")))? }; diff --git a/wgpu-hal/src/gles/device.rs b/wgpu-hal/src/gles/device.rs index 171c53a93e..921941735c 100644 --- a/wgpu-hal/src/gles/device.rs +++ b/wgpu-hal/src/gles/device.rs @@ -218,12 +218,19 @@ impl super::Device { shader_stage: naga_stage, entry_point: stage.entry_point.to_string(), multiview: context.multiview, - constants: stage.constants.to_owned(), }; - let shader = &stage.module.naga; - let entry_point_index = shader - .module + let (module, info) = naga::back::pipeline_constants::process_overrides( + &stage.module.naga.module, + &stage.module.naga.info, + stage.constants, + ) + .map_err(|e| { + let msg = format!("{e}"); + crate::PipelineError::Linkage(map_naga_stage(naga_stage), msg) + })?; + + let entry_point_index = module .entry_points .iter() .position(|ep| ep.name.as_str() == stage.entry_point) @@ -250,8 +257,8 @@ impl super::Device { let mut output = String::new(); let mut writer = glsl::Writer::new( &mut output, - &shader.module, - &shader.info, + &module, + &info, &context.layout.naga_options, &pipeline_options, policies, @@ -270,8 +277,8 @@ impl super::Device { context.consume_reflection( gl, - &shader.module, - shader.info.get_entry_point(entry_point_index), + &module, + info.get_entry_point(entry_point_index), reflection_info, naga_stage, program, diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index 3826909387..377c5a483f 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -69,7 +69,13 @@ impl super::Device { ) -> Result { let stage_bit = map_naga_stage(naga_stage); - let module = &stage.module.naga.module; + let (module, module_info) = naga::back::pipeline_constants::process_overrides( + &stage.module.naga.module, + &stage.module.naga.info, + stage.constants, + ) + .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("MSL: {:?}", e)))?; + let ep_resources = &layout.per_stage_map[naga_stage]; let bounds_check_policy = if stage.module.runtime_checks { @@ -112,16 +118,11 @@ impl super::Device { metal::MTLPrimitiveTopologyClass::Point => true, _ => false, }, - constants: stage.constants.to_owned(), }; - let (source, info) = naga::back::msl::write_string( - module, - &stage.module.naga.info, - &options, - &pipeline_options, - ) - .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("MSL: {:?}", e)))?; + let (source, info) = + naga::back::msl::write_string(&module, &module_info, &options, &pipeline_options) + .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("MSL: {:?}", e)))?; log::debug!( "Naga generated shader for entry point '{}' and stage {:?}\n{}", @@ -169,7 +170,7 @@ impl super::Device { })?; // collect sizes indices, immutable buffers, and work group memory sizes - let ep_info = &stage.module.naga.info.get_entry_point(ep_index); + let ep_info = &module_info.get_entry_point(ep_index); let mut wg_memory_sizes = Vec::new(); let mut sized_bindings = Vec::new(); let mut immutable_buffer_mask = 0; diff --git a/wgpu-hal/src/vulkan/device.rs b/wgpu-hal/src/vulkan/device.rs index 989ad60c72..52b899900f 100644 --- a/wgpu-hal/src/vulkan/device.rs +++ b/wgpu-hal/src/vulkan/device.rs @@ -734,7 +734,6 @@ impl super::Device { let pipeline_options = naga::back::spv::PipelineOptions { entry_point: stage.entry_point.to_string(), shader_stage: naga_stage, - constants: stage.constants.to_owned(), }; let needs_temp_options = !runtime_checks || !binding_map.is_empty() @@ -766,14 +765,17 @@ impl super::Device { } else { &self.naga_options }; + + let (module, info) = naga::back::pipeline_constants::process_overrides( + &naga_shader.module, + &naga_shader.info, + stage.constants, + ) + .map_err(|e| crate::PipelineError::Linkage(stage_flags, format!("{e}")))?; + let spv = { profiling::scope!("naga::spv::write_vec"); - naga::back::spv::write_vec( - &naga_shader.module, - &naga_shader.info, - options, - Some(&pipeline_options), - ) + naga::back::spv::write_vec(&module, &info, options, Some(&pipeline_options)) } .map_err(|e| crate::PipelineError::Linkage(stage_flags, format!("{e}")))?; self.create_shader_module_impl(&spv)? From c104f088cc7edd0c10b8bf7f68ddd8208be1fab8 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Tue, 26 Mar 2024 15:25:57 -0700 Subject: [PATCH 29/30] [naga wgsl-in] Allow override expressions as local var initializers. Allow `LocalVariable::init` to be an override expression. Note that this is unrelated to WGSL compliance. The WGSL front end already accepts any sort of expression as an initializer for `LocalVariable`s, but initialization by an override expression was handled in the same way as initialization by a runtime expression, via an explicit `Store` statement. This commit merely lets us skip the `Store` when the initializer is an override expression, producing slightly cleaner output in some cases. --- naga/src/back/pipeline_constants.rs | 7 ++++ naga/src/front/wgsl/lower/mod.rs | 2 +- naga/src/lib.rs | 2 +- naga/src/valid/function.rs | 8 ++-- naga/tests/out/analysis/overrides.info.ron | 22 +++-------- .../out/glsl/overrides.main.Compute.glsl | 7 ++-- naga/tests/out/hlsl/overrides.hlsl | 7 ++-- naga/tests/out/ir/overrides.compact.ron | 37 ++++++++----------- naga/tests/out/ir/overrides.ron | 37 ++++++++----------- naga/tests/out/msl/overrides.msl | 7 ++-- naga/tests/out/spv/overrides.main.spvasm | 28 +++++++------- 11 files changed, 72 insertions(+), 92 deletions(-) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index be5760a733..50a6a3d57a 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -304,6 +304,13 @@ fn process_function( filter_emits_in_block(&mut function.body, &function.expressions); + // Update local expression initializers. + for (_, local) in function.local_variables.iter_mut() { + if let &mut Some(ref mut init) = &mut local.init { + *init = adjusted_local_expressions[init.index()]; + } + } + // We've changed the keys of `function.named_expression`, so we have to // rebuild it from scratch. let named_expressions = mem::take(&mut function.named_expressions); diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 7abd95114d..77212f2086 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1317,7 +1317,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // expression, so its value depends on the // state at the point of initialization. if is_inside_loop - || !ctx.local_expression_kind_tracker.is_const(init) + || !ctx.local_expression_kind_tracker.is_const_or_override(init) { (None, Some(init)) } else { diff --git a/naga/src/lib.rs b/naga/src/lib.rs index 4b421b08fd..ceb7e55b7b 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -998,7 +998,7 @@ pub struct LocalVariable { /// /// This handle refers to this `LocalVariable`'s function's /// [`expressions`] arena, but it is required to be an evaluated - /// constant expression. + /// override expression. /// /// [`expressions`]: Function::expressions pub init: Option>, diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index b8ad63cc6d..fe5681449e 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -54,8 +54,8 @@ pub enum LocalVariableError { InvalidType(Handle), #[error("Initializer doesn't match the variable type")] InitializerType, - #[error("Initializer is not const")] - NonConstInitializer, + #[error("Initializer is not a const or override expression")] + NonConstOrOverrideInitializer, } #[derive(Clone, Debug, thiserror::Error)] @@ -945,8 +945,8 @@ impl super::Validator { return Err(LocalVariableError::InitializerType); } - if !local_expr_kind.is_const(init) { - return Err(LocalVariableError::NonConstInitializer); + if !local_expr_kind.is_const_or_override(init) { + return Err(LocalVariableError::NonConstOrOverrideInitializer); } } diff --git a/naga/tests/out/analysis/overrides.info.ron b/naga/tests/out/analysis/overrides.info.ron index 6ea54bb296..00d8ce1ea8 100644 --- a/naga/tests/out/analysis/overrides.info.ron +++ b/naga/tests/out/analysis/overrides.info.ron @@ -51,18 +51,6 @@ width: 4, ))), ), - ( - uniformity: ( - non_uniform_result: Some(4), - requirements: (""), - ), - ref_count: 1, - assignable_global: None, - ty: Value(Pointer( - base: 2, - space: Function, - )), - ), ( uniformity: ( non_uniform_result: None, @@ -83,7 +71,7 @@ ), ( uniformity: ( - non_uniform_result: Some(7), + non_uniform_result: Some(6), requirements: (""), ), ref_count: 1, @@ -95,7 +83,7 @@ ), ( uniformity: ( - non_uniform_result: Some(8), + non_uniform_result: Some(7), requirements: (""), ), ref_count: 1, @@ -107,7 +95,7 @@ ), ( uniformity: ( - non_uniform_result: Some(8), + non_uniform_result: Some(7), requirements: (""), ), ref_count: 1, @@ -128,7 +116,7 @@ ), ( uniformity: ( - non_uniform_result: Some(8), + non_uniform_result: Some(7), requirements: (""), ), ref_count: 1, @@ -140,7 +128,7 @@ ), ( uniformity: ( - non_uniform_result: Some(12), + non_uniform_result: Some(11), requirements: (""), ), ref_count: 1, diff --git a/naga/tests/out/glsl/overrides.main.Compute.glsl b/naga/tests/out/glsl/overrides.main.Compute.glsl index a4e4b004bb..b6d86d50ba 100644 --- a/naga/tests/out/glsl/overrides.main.Compute.glsl +++ b/naga/tests/out/glsl/overrides.main.Compute.glsl @@ -17,13 +17,12 @@ float gain_x_10_ = 11.0; void main() { - float t = 0.0; + float t = 23.0; bool x = false; float gain_x_100_ = 0.0; - t = 23.0; x = true; - float _e10 = gain_x_10_; - gain_x_100_ = (_e10 * 10.0); + float _e9 = gain_x_10_; + gain_x_100_ = (_e9 * 10.0); return; } diff --git a/naga/tests/out/hlsl/overrides.hlsl b/naga/tests/out/hlsl/overrides.hlsl index 072cd9ffcc..b0582d544e 100644 --- a/naga/tests/out/hlsl/overrides.hlsl +++ b/naga/tests/out/hlsl/overrides.hlsl @@ -11,13 +11,12 @@ static float gain_x_10_ = 11.0; [numthreads(1, 1, 1)] void main() { - float t = (float)0; + float t = 23.0; bool x = (bool)0; float gain_x_100_ = (float)0; - t = 23.0; x = true; - float _expr10 = gain_x_10_; - gain_x_100_ = (_expr10 * 10.0); + float _expr9 = gain_x_10_; + gain_x_100_ = (_expr9 * 10.0); return; } diff --git a/naga/tests/out/ir/overrides.compact.ron b/naga/tests/out/ir/overrides.compact.ron index 4188354224..bc25af3bce 100644 --- a/naga/tests/out/ir/overrides.compact.ron +++ b/naga/tests/out/ir/overrides.compact.ron @@ -109,7 +109,7 @@ ( name: Some("t"), ty: 2, - init: None, + init: Some(3), ), ( name: Some("x"), @@ -130,56 +130,51 @@ left: 1, right: 2, ), - LocalVariable(1), Override(1), Unary( op: LogicalNot, - expr: 5, + expr: 4, ), LocalVariable(2), GlobalVariable(1), Load( - pointer: 8, + pointer: 7, ), Literal(F32(10.0)), Binary( op: Multiply, - left: 9, - right: 10, + left: 8, + right: 9, ), LocalVariable(3), ], named_expressions: { - 6: "a", + 5: "a", }, body: [ Emit(( start: 2, end: 3, )), - Store( - pointer: 4, - value: 3, - ), Emit(( - start: 5, - end: 6, + start: 4, + end: 5, )), Store( - pointer: 7, - value: 6, + pointer: 6, + value: 5, ), Emit(( - start: 8, - end: 9, + start: 7, + end: 8, )), Emit(( - start: 10, - end: 11, + start: 9, + end: 10, )), Store( - pointer: 12, - value: 11, + pointer: 11, + value: 10, ), Return( value: None, diff --git a/naga/tests/out/ir/overrides.ron b/naga/tests/out/ir/overrides.ron index 4188354224..bc25af3bce 100644 --- a/naga/tests/out/ir/overrides.ron +++ b/naga/tests/out/ir/overrides.ron @@ -109,7 +109,7 @@ ( name: Some("t"), ty: 2, - init: None, + init: Some(3), ), ( name: Some("x"), @@ -130,56 +130,51 @@ left: 1, right: 2, ), - LocalVariable(1), Override(1), Unary( op: LogicalNot, - expr: 5, + expr: 4, ), LocalVariable(2), GlobalVariable(1), Load( - pointer: 8, + pointer: 7, ), Literal(F32(10.0)), Binary( op: Multiply, - left: 9, - right: 10, + left: 8, + right: 9, ), LocalVariable(3), ], named_expressions: { - 6: "a", + 5: "a", }, body: [ Emit(( start: 2, end: 3, )), - Store( - pointer: 4, - value: 3, - ), Emit(( - start: 5, - end: 6, + start: 4, + end: 5, )), Store( - pointer: 7, - value: 6, + pointer: 6, + value: 5, ), Emit(( - start: 8, - end: 9, + start: 7, + end: 8, )), Emit(( - start: 10, - end: 11, + start: 9, + end: 10, )), Store( - pointer: 12, - value: 11, + pointer: 11, + value: 10, ), Return( value: None, diff --git a/naga/tests/out/msl/overrides.msl b/naga/tests/out/msl/overrides.msl index f884d1b527..d9e95d0704 100644 --- a/naga/tests/out/msl/overrides.msl +++ b/naga/tests/out/msl/overrides.msl @@ -15,12 +15,11 @@ constant float inferred_f32_ = 2.718; kernel void main_( ) { float gain_x_10_ = 11.0; - float t = {}; + float t = 23.0; bool x = {}; float gain_x_100_ = {}; - t = 23.0; x = true; - float _e10 = gain_x_10_; - gain_x_100_ = _e10 * 10.0; + float _e9 = gain_x_10_; + gain_x_100_ = _e9 * 10.0; return; } diff --git a/naga/tests/out/spv/overrides.main.spvasm b/naga/tests/out/spv/overrides.main.spvasm index d4ce4752ed..d21eb7c674 100644 --- a/naga/tests/out/spv/overrides.main.spvasm +++ b/naga/tests/out/spv/overrides.main.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.0 ; Generator: rspirv -; Bound: 32 +; Bound: 31 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -25,21 +25,19 @@ OpExecutionMode %18 LocalSize 1 1 1 %19 = OpTypeFunction %2 %20 = OpConstant %4 23.0 %22 = OpTypePointer Function %4 -%23 = OpConstantNull %4 -%25 = OpTypePointer Function %3 -%26 = OpConstantNull %3 -%28 = OpConstantNull %4 +%24 = OpTypePointer Function %3 +%25 = OpConstantNull %3 +%27 = OpConstantNull %4 %18 = OpFunction %2 None %19 %17 = OpLabel -%21 = OpVariable %22 Function %23 -%24 = OpVariable %25 Function %26 -%27 = OpVariable %22 Function %28 -OpBranch %29 -%29 = OpLabel -OpStore %21 %20 -OpStore %24 %5 -%30 = OpLoad %4 %15 -%31 = OpFMul %4 %30 %13 -OpStore %27 %31 +%21 = OpVariable %22 Function %20 +%23 = OpVariable %24 Function %25 +%26 = OpVariable %22 Function %27 +OpBranch %28 +%28 = OpLabel +OpStore %23 %5 +%29 = OpLoad %4 %15 +%30 = OpFMul %4 %29 %13 +OpStore %26 %30 OpReturn OpFunctionEnd \ No newline at end of file From 6d895482c5aa8d017c95d0ab35e332c6d965a61f Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Fri, 5 Apr 2024 17:13:48 +0200 Subject: [PATCH 30/30] add changelog entry --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c80fcc8c71..f7631d6fb6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -117,6 +117,8 @@ Bottom level categories: - Added `wgpu::TextureView::as_hal` - `wgpu::Texture::as_hal` now returns a user-defined type to match the other as_hal functions +- Added support for pipeline-overridable constants. By @teoxoy & @jimblandy in [#5500](https://github.com/gfx-rs/wgpu/pull/5500) + #### GLES - Log an error when GLES texture format heuristics fail. By @PolyMeilex in [#5266](https://github.com/gfx-rs/wgpu/issues/5266)