From 10d5323a80488dea43c83fac4d93c0f783bff039 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Sun, 17 Sep 2023 17:51:14 -0700 Subject: [PATCH 1/4] ConstantEvaluator::swizzle: Handle vector concatenation, indexing. --- src/proc/constant_evaluator.rs | 83 +++++++++++++++++--- tests/in/const-exprs.wgsl | 14 ++++ tests/out/glsl/const-exprs.main.Compute.glsl | 20 +++++ tests/out/hlsl/const-exprs.hlsl | 12 +++ tests/out/hlsl/const-exprs.ron | 12 +++ tests/out/msl/const-exprs.msl | 17 ++++ tests/out/spv/const-exprs.spvasm | 51 ++++++++++++ tests/out/wgsl/const-exprs.wgsl | 13 +++ tests/snapshots.rs | 4 + 9 files changed, 214 insertions(+), 12 deletions(-) create mode 100644 tests/in/const-exprs.wgsl create mode 100644 tests/out/glsl/const-exprs.main.Compute.glsl create mode 100644 tests/out/hlsl/const-exprs.hlsl create mode 100644 tests/out/hlsl/const-exprs.ron create mode 100644 tests/out/msl/const-exprs.msl create mode 100644 tests/out/spv/const-exprs.spvasm create mode 100644 tests/out/wgsl/const-exprs.wgsl diff --git a/src/proc/constant_evaluator.rs b/src/proc/constant_evaluator.rs index 19b635072e..b2e0c8346d 100644 --- a/src/proc/constant_evaluator.rs +++ b/src/proc/constant_evaluator.rs @@ -80,6 +80,8 @@ pub enum ConstantEvaluatorError { SplatScalarOnly, #[error("Can only swizzle vector constants")] SwizzleVectorOnly, + #[error("swizzle component not present in source expression")] + SwizzleOutOfBounds, #[error("Type is not constructible")] TypeNotConstructible, #[error("Subexpression(s) are not constant")] @@ -305,20 +307,31 @@ impl ConstantEvaluator<'_> { let expr = Expression::Splat { size, value }; Ok(self.register_constant(expr, span)) } - Expression::Compose { - ty, - components: ref src_components, - } => { + Expression::Compose { ty, ref components } => { let dst_ty = get_dst_ty(ty)?; - let components = pattern + let mut flattened = [src_constant; 4]; // dummy value + let len = self + .flatten_compose(ty, components) + .zip(flattened.iter_mut()) + .map(|(component, elt)| *elt = component) + .count(); + let flattened = &flattened[..len]; + + let swizzled_components = pattern[..size as usize] .iter() - .take(size as usize) - .map(|&sc| src_components[sc as usize]) - .collect(); + .map(|&sc| { + let sc = sc as usize; + if let Some(elt) = flattened.get(sc) { + Ok(*elt) + } else { + Err(ConstantEvaluatorError::SwizzleOutOfBounds) + } + }) + .collect::>, _>>()?; let expr = Expression::Compose { ty: dst_ty, - components, + components: swizzled_components, }; Ok(self.register_constant(expr, span)) } @@ -454,9 +467,8 @@ impl ConstantEvaluator<'_> { .components() .ok_or(ConstantEvaluatorError::InvalidAccessBase)?; - components - .get(index) - .copied() + self.flatten_compose(ty, components) + .nth(index) .ok_or(ConstantEvaluatorError::InvalidAccessIndex) } _ => Err(ConstantEvaluatorError::InvalidAccessBase), @@ -827,6 +839,53 @@ impl ConstantEvaluator<'_> { self.expressions.append(expr, span) } + + /// Return an iterator over the individual components assembled by a + /// `Compose` expression. + /// + /// Given `ty` and `components` from an `Expression::Compose`, return an + /// iterator over the components of the resulting value. + /// + /// Normally, this would just be an iterator over `components`. However, + /// `Compose` expressions can concatenate vectors, in which case the i'th + /// value being composed is not generally the i'th element of `components`. + /// This function consults `ty` to decide if this concatenation is occuring, + /// and returns an iterator that produces the components of the result of + /// the `Compose` expression in either case. + fn flatten_compose<'c>( + &'c self, + ty: Handle, + components: &'c [Handle], + ) -> impl Iterator> + 'c { + // Returning `impl Iterator` is a bit tricky. We may or may not want to + // flatten the components, but we have to settle on a single concrete + // type to return. The below is a single iterator chain that handles + // both the flattening and non-flattening cases. + let (size, is_vector) = if let TypeInner::Vector { size, .. } = self.types[ty].inner { + (size as usize, true) + } else { + (components.len(), false) + }; + + components + .iter() + .flat_map(move |component| { + if let ( + true, + &Expression::Compose { + ty: _, + components: ref subcomponents, + }, + ) = (is_vector, &self.expressions[*component]) + { + subcomponents + } else { + std::slice::from_ref(component) + } + }) + .take(size) + .cloned() + } } /// Helper function to implement the GLSL `max` function for floats. diff --git a/tests/in/const-exprs.wgsl b/tests/in/const-exprs.wgsl new file mode 100644 index 0000000000..d51deee9ef --- /dev/null +++ b/tests/in/const-exprs.wgsl @@ -0,0 +1,14 @@ +@group(0) @binding(0) +var out: vec4; + +@group(0) @binding(1) +var out2: i32; + +@compute @workgroup_size(1) +fn main() { + let a = vec2(1, 2); + let b = vec2(3, 4); + out = vec4(a, b).wzyx; + + out2 = vec4(a, b)[1]; +} diff --git a/tests/out/glsl/const-exprs.main.Compute.glsl b/tests/out/glsl/const-exprs.main.Compute.glsl new file mode 100644 index 0000000000..9918cd68ab --- /dev/null +++ b/tests/out/glsl/const-exprs.main.Compute.glsl @@ -0,0 +1,20 @@ +#version 310 es + +precision highp float; +precision highp int; + +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + +layout(std430) buffer type_block_0Compute { ivec4 _group_0_binding_0_cs; }; + +layout(std430) buffer type_1_block_1Compute { int _group_0_binding_1_cs; }; + + +void main() { + ivec2 a = ivec2(1, 2); + ivec2 b = ivec2(3, 4); + _group_0_binding_0_cs = ivec4(4, 3, 2, 1); + _group_0_binding_1_cs = 2; + return; +} + diff --git a/tests/out/hlsl/const-exprs.hlsl b/tests/out/hlsl/const-exprs.hlsl new file mode 100644 index 0000000000..f483a9b56d --- /dev/null +++ b/tests/out/hlsl/const-exprs.hlsl @@ -0,0 +1,12 @@ +RWByteAddressBuffer out_ : register(u0); +RWByteAddressBuffer out2_ : register(u1); + +[numthreads(1, 1, 1)] +void main() +{ + int2 a = int2(1, 2); + int2 b = int2(3, 4); + out_.Store4(0, asuint(int4(4, 3, 2, 1))); + out2_.Store(0, asuint(2)); + return; +} diff --git a/tests/out/hlsl/const-exprs.ron b/tests/out/hlsl/const-exprs.ron new file mode 100644 index 0000000000..a07b03300b --- /dev/null +++ b/tests/out/hlsl/const-exprs.ron @@ -0,0 +1,12 @@ +( + vertex:[ + ], + fragment:[ + ], + compute:[ + ( + entry_point:"main", + target_profile:"cs_5_1", + ), + ], +) diff --git a/tests/out/msl/const-exprs.msl b/tests/out/msl/const-exprs.msl new file mode 100644 index 0000000000..d8f38b4fc2 --- /dev/null +++ b/tests/out/msl/const-exprs.msl @@ -0,0 +1,17 @@ +// language: metal2.0 +#include +#include + +using metal::uint; + + +kernel void main_( + device metal::int4& out [[user(fake0)]] +, device int& out2_ [[user(fake0)]] +) { + metal::int2 a = metal::int2(1, 2); + metal::int2 b = metal::int2(3, 4); + out = metal::int4(4, 3, 2, 1); + out2_ = 2; + return; +} diff --git a/tests/out/spv/const-exprs.spvasm b/tests/out/spv/const-exprs.spvasm new file mode 100644 index 0000000000..fa9fb4fd51 --- /dev/null +++ b/tests/out/spv/const-exprs.spvasm @@ -0,0 +1,51 @@ +; SPIR-V +; Version: 1.1 +; Generator: rspirv +; Bound: 30 +OpCapability Shader +OpExtension "SPV_KHR_storage_buffer_storage_class" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %15 "main" +OpExecutionMode %15 LocalSize 1 1 1 +OpDecorate %8 DescriptorSet 0 +OpDecorate %8 Binding 0 +OpDecorate %9 Block +OpMemberDecorate %9 0 Offset 0 +OpDecorate %11 DescriptorSet 0 +OpDecorate %11 Binding 1 +OpDecorate %12 Block +OpMemberDecorate %12 0 Offset 0 +%2 = OpTypeVoid +%4 = OpTypeInt 32 1 +%3 = OpTypeVector %4 4 +%5 = OpTypeVector %4 2 +%6 = OpConstant %4 0 +%7 = OpConstant %4 1 +%9 = OpTypeStruct %3 +%10 = OpTypePointer StorageBuffer %9 +%8 = OpVariable %10 StorageBuffer +%12 = OpTypeStruct %4 +%13 = OpTypePointer StorageBuffer %12 +%11 = OpVariable %13 StorageBuffer +%16 = OpTypeFunction %2 +%17 = OpTypePointer StorageBuffer %3 +%19 = OpTypeInt 32 0 +%18 = OpConstant %19 0 +%21 = OpTypePointer StorageBuffer %4 +%23 = OpConstant %4 2 +%24 = OpConstantComposite %5 %7 %23 +%25 = OpConstant %4 3 +%26 = OpConstant %4 4 +%27 = OpConstantComposite %5 %25 %26 +%28 = OpConstantComposite %3 %26 %25 %23 %7 +%15 = OpFunction %2 None %16 +%14 = OpLabel +%20 = OpAccessChain %17 %8 %18 +%22 = OpAccessChain %21 %11 %18 +OpBranch %29 +%29 = OpLabel +OpStore %20 %28 +OpStore %22 %23 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/const-exprs.wgsl b/tests/out/wgsl/const-exprs.wgsl new file mode 100644 index 0000000000..b58339b8af --- /dev/null +++ b/tests/out/wgsl/const-exprs.wgsl @@ -0,0 +1,13 @@ +@group(0) @binding(0) +var out: vec4; +@group(0) @binding(1) +var out2_: i32; + +@compute @workgroup_size(1, 1, 1) +fn main() { + let a = vec2(1, 2); + let b = vec2(3, 4); + out = vec4(4, 3, 2, 1); + out2_ = 2; + return; +} diff --git a/tests/snapshots.rs b/tests/snapshots.rs index dce0a7edf9..95a4137a8a 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -616,6 +616,10 @@ fn convert_wgsl() { "constructors", Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, ), + ( + "const-exprs", + Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, + ), ]; for &(name, targets) in inputs.iter() { From d2468a17e2d79efce19fc72b12e5b47a9639cd43 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Tue, 19 Sep 2023 10:50:12 -0700 Subject: [PATCH 2/4] Handle vector Compose expressions nested two deep. --- src/proc/constant_evaluator.rs | 35 +++++--- tests/in/const-exprs.wgsl | 10 +-- tests/out/glsl/const-exprs.main.Compute.glsl | 3 + tests/out/hlsl/const-exprs.hlsl | 2 + tests/out/msl/const-exprs.msl | 2 + tests/out/spv/const-exprs.spvasm | 92 ++++++++++++-------- tests/out/wgsl/const-exprs.wgsl | 3 + 7 files changed, 90 insertions(+), 57 deletions(-) diff --git a/src/proc/constant_evaluator.rs b/src/proc/constant_evaluator.rs index b2e0c8346d..0f4b039b18 100644 --- a/src/proc/constant_evaluator.rs +++ b/src/proc/constant_evaluator.rs @@ -867,22 +867,29 @@ impl ConstantEvaluator<'_> { (components.len(), false) }; - components - .iter() - .flat_map(move |component| { - if let ( - true, - &Expression::Compose { - ty: _, - components: ref subcomponents, - }, - ) = (is_vector, &self.expressions[*component]) + fn flattener<'c>( + component: &'c Handle, + is_vector: bool, + expressions: &'c Arena, + ) -> &'c [Handle] { + if is_vector { + if let Expression::Compose { + ty: _, + components: ref subcomponents, + } = expressions[*component] { - subcomponents - } else { - std::slice::from_ref(component) + return subcomponents; } - }) + } + std::slice::from_ref(component) + } + + // Expressions like `vec4(vec3(vec2(6, 7), 8), 9)` require us to flatten + // two levels. + components + .iter() + .flat_map(move |component| flattener(component, is_vector, self.expressions)) + .flat_map(move |component| flattener(component, is_vector, self.expressions)) .take(size) .cloned() } diff --git a/tests/in/const-exprs.wgsl b/tests/in/const-exprs.wgsl index d51deee9ef..c89e61d499 100644 --- a/tests/in/const-exprs.wgsl +++ b/tests/in/const-exprs.wgsl @@ -1,8 +1,6 @@ -@group(0) @binding(0) -var out: vec4; - -@group(0) @binding(1) -var out2: i32; +@group(0) @binding(0) var out: vec4; +@group(0) @binding(1) var out2: i32; +@group(0) @binding(2) var out3: i32; @compute @workgroup_size(1) fn main() { @@ -11,4 +9,6 @@ fn main() { out = vec4(a, b).wzyx; out2 = vec4(a, b)[1]; + + out3 = vec4(vec3(vec2(6, 7), 8), 9)[0]; } diff --git a/tests/out/glsl/const-exprs.main.Compute.glsl b/tests/out/glsl/const-exprs.main.Compute.glsl index 9918cd68ab..ff634004ca 100644 --- a/tests/out/glsl/const-exprs.main.Compute.glsl +++ b/tests/out/glsl/const-exprs.main.Compute.glsl @@ -9,12 +9,15 @@ layout(std430) buffer type_block_0Compute { ivec4 _group_0_binding_0_cs; }; layout(std430) buffer type_1_block_1Compute { int _group_0_binding_1_cs; }; +layout(std430) buffer type_1_block_2Compute { int _group_0_binding_2_cs; }; + void main() { ivec2 a = ivec2(1, 2); ivec2 b = ivec2(3, 4); _group_0_binding_0_cs = ivec4(4, 3, 2, 1); _group_0_binding_1_cs = 2; + _group_0_binding_2_cs = 6; return; } diff --git a/tests/out/hlsl/const-exprs.hlsl b/tests/out/hlsl/const-exprs.hlsl index f483a9b56d..f6faee1d40 100644 --- a/tests/out/hlsl/const-exprs.hlsl +++ b/tests/out/hlsl/const-exprs.hlsl @@ -1,5 +1,6 @@ RWByteAddressBuffer out_ : register(u0); RWByteAddressBuffer out2_ : register(u1); +RWByteAddressBuffer out3_ : register(u2); [numthreads(1, 1, 1)] void main() @@ -8,5 +9,6 @@ void main() int2 b = int2(3, 4); out_.Store4(0, asuint(int4(4, 3, 2, 1))); out2_.Store(0, asuint(2)); + out3_.Store(0, asuint(6)); return; } diff --git a/tests/out/msl/const-exprs.msl b/tests/out/msl/const-exprs.msl index d8f38b4fc2..19b9b727fb 100644 --- a/tests/out/msl/const-exprs.msl +++ b/tests/out/msl/const-exprs.msl @@ -8,10 +8,12 @@ using metal::uint; kernel void main_( device metal::int4& out [[user(fake0)]] , device int& out2_ [[user(fake0)]] +, device int& out3_ [[user(fake0)]] ) { metal::int2 a = metal::int2(1, 2); metal::int2 b = metal::int2(3, 4); out = metal::int4(4, 3, 2, 1); out2_ = 2; + out3_ = 6; return; } diff --git a/tests/out/spv/const-exprs.spvasm b/tests/out/spv/const-exprs.spvasm index fa9fb4fd51..ae26ded97e 100644 --- a/tests/out/spv/const-exprs.spvasm +++ b/tests/out/spv/const-exprs.spvasm @@ -1,51 +1,67 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 30 +; Bound: 41 OpCapability Shader OpExtension "SPV_KHR_storage_buffer_storage_class" %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %15 "main" -OpExecutionMode %15 LocalSize 1 1 1 -OpDecorate %8 DescriptorSet 0 -OpDecorate %8 Binding 0 -OpDecorate %9 Block -OpMemberDecorate %9 0 Offset 0 -OpDecorate %11 DescriptorSet 0 -OpDecorate %11 Binding 1 -OpDecorate %12 Block -OpMemberDecorate %12 0 Offset 0 +OpEntryPoint GLCompute %20 "main" +OpExecutionMode %20 LocalSize 1 1 1 +OpDecorate %10 DescriptorSet 0 +OpDecorate %10 Binding 0 +OpDecorate %11 Block +OpMemberDecorate %11 0 Offset 0 +OpDecorate %13 DescriptorSet 0 +OpDecorate %13 Binding 1 +OpDecorate %14 Block +OpMemberDecorate %14 0 Offset 0 +OpDecorate %16 DescriptorSet 0 +OpDecorate %16 Binding 2 +OpDecorate %17 Block +OpMemberDecorate %17 0 Offset 0 %2 = OpTypeVoid %4 = OpTypeInt 32 1 %3 = OpTypeVector %4 4 %5 = OpTypeVector %4 2 -%6 = OpConstant %4 0 -%7 = OpConstant %4 1 -%9 = OpTypeStruct %3 -%10 = OpTypePointer StorageBuffer %9 -%8 = OpVariable %10 StorageBuffer -%12 = OpTypeStruct %4 -%13 = OpTypePointer StorageBuffer %12 -%11 = OpVariable %13 StorageBuffer -%16 = OpTypeFunction %2 -%17 = OpTypePointer StorageBuffer %3 -%19 = OpTypeInt 32 0 -%18 = OpConstant %19 0 -%21 = OpTypePointer StorageBuffer %4 -%23 = OpConstant %4 2 -%24 = OpConstantComposite %5 %7 %23 -%25 = OpConstant %4 3 -%26 = OpConstant %4 4 -%27 = OpConstantComposite %5 %25 %26 -%28 = OpConstantComposite %3 %26 %25 %23 %7 -%15 = OpFunction %2 None %16 -%14 = OpLabel -%20 = OpAccessChain %17 %8 %18 -%22 = OpAccessChain %21 %11 %18 -OpBranch %29 -%29 = OpLabel -OpStore %20 %28 -OpStore %22 %23 +%6 = OpTypeVector %4 3 +%7 = OpConstant %4 0 +%8 = OpConstant %4 1 +%9 = OpConstant %4 2 +%11 = OpTypeStruct %3 +%12 = OpTypePointer StorageBuffer %11 +%10 = OpVariable %12 StorageBuffer +%14 = OpTypeStruct %4 +%15 = OpTypePointer StorageBuffer %14 +%13 = OpVariable %15 StorageBuffer +%17 = OpTypeStruct %4 +%18 = OpTypePointer StorageBuffer %17 +%16 = OpVariable %18 StorageBuffer +%21 = OpTypeFunction %2 +%22 = OpTypePointer StorageBuffer %3 +%24 = OpTypeInt 32 0 +%23 = OpConstant %24 0 +%26 = OpTypePointer StorageBuffer %4 +%29 = OpConstantComposite %5 %8 %9 +%30 = OpConstant %4 3 +%31 = OpConstant %4 4 +%32 = OpConstantComposite %5 %30 %31 +%33 = OpConstantComposite %3 %31 %30 %9 %8 +%34 = OpConstant %4 6 +%35 = OpConstant %4 7 +%36 = OpConstantComposite %5 %34 %35 +%37 = OpConstant %4 8 +%38 = OpConstantComposite %6 %36 %37 +%39 = OpConstant %4 9 +%20 = OpFunction %2 None %21 +%19 = OpLabel +%25 = OpAccessChain %22 %10 %23 +%27 = OpAccessChain %26 %13 %23 +%28 = OpAccessChain %26 %16 %23 +OpBranch %40 +%40 = OpLabel +OpStore %25 %33 +OpStore %27 %9 +OpStore %28 %34 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/const-exprs.wgsl b/tests/out/wgsl/const-exprs.wgsl index b58339b8af..201535836e 100644 --- a/tests/out/wgsl/const-exprs.wgsl +++ b/tests/out/wgsl/const-exprs.wgsl @@ -2,6 +2,8 @@ var out: vec4; @group(0) @binding(1) var out2_: i32; +@group(0) @binding(2) +var out3_: i32; @compute @workgroup_size(1, 1, 1) fn main() { @@ -9,5 +11,6 @@ fn main() { let b = vec2(3, 4); out = vec4(4, 3, 2, 1); out2_ = 2; + out3_ = 6; return; } From f562f3a0af4745dd859bc96fe942a5400746b0e4 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Tue, 19 Sep 2023 12:13:56 -0700 Subject: [PATCH 3/4] Move `flatten_compose` to `proc`, and make it a free function. --- src/proc/constant_evaluator.rs | 66 ++++------------------------------ src/proc/mod.rs | 55 ++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 60 deletions(-) diff --git a/src/proc/constant_evaluator.rs b/src/proc/constant_evaluator.rs index 0f4b039b18..a2ebd53ea4 100644 --- a/src/proc/constant_evaluator.rs +++ b/src/proc/constant_evaluator.rs @@ -311,11 +311,11 @@ impl ConstantEvaluator<'_> { let dst_ty = get_dst_ty(ty)?; let mut flattened = [src_constant; 4]; // dummy value - let len = self - .flatten_compose(ty, components) - .zip(flattened.iter_mut()) - .map(|(component, elt)| *elt = component) - .count(); + let len = + crate::proc::flatten_compose(ty, components, self.expressions, self.types) + .zip(flattened.iter_mut()) + .map(|(component, elt)| *elt = component) + .count(); let flattened = &flattened[..len]; let swizzled_components = pattern[..size as usize] @@ -467,7 +467,7 @@ impl ConstantEvaluator<'_> { .components() .ok_or(ConstantEvaluatorError::InvalidAccessBase)?; - self.flatten_compose(ty, components) + crate::proc::flatten_compose(ty, components, self.expressions, self.types) .nth(index) .ok_or(ConstantEvaluatorError::InvalidAccessIndex) } @@ -839,60 +839,6 @@ impl ConstantEvaluator<'_> { self.expressions.append(expr, span) } - - /// Return an iterator over the individual components assembled by a - /// `Compose` expression. - /// - /// Given `ty` and `components` from an `Expression::Compose`, return an - /// iterator over the components of the resulting value. - /// - /// Normally, this would just be an iterator over `components`. However, - /// `Compose` expressions can concatenate vectors, in which case the i'th - /// value being composed is not generally the i'th element of `components`. - /// This function consults `ty` to decide if this concatenation is occuring, - /// and returns an iterator that produces the components of the result of - /// the `Compose` expression in either case. - fn flatten_compose<'c>( - &'c self, - ty: Handle, - components: &'c [Handle], - ) -> impl Iterator> + 'c { - // Returning `impl Iterator` is a bit tricky. We may or may not want to - // flatten the components, but we have to settle on a single concrete - // type to return. The below is a single iterator chain that handles - // both the flattening and non-flattening cases. - let (size, is_vector) = if let TypeInner::Vector { size, .. } = self.types[ty].inner { - (size as usize, true) - } else { - (components.len(), false) - }; - - fn flattener<'c>( - component: &'c Handle, - is_vector: bool, - expressions: &'c Arena, - ) -> &'c [Handle] { - if is_vector { - if let Expression::Compose { - ty: _, - components: ref subcomponents, - } = expressions[*component] - { - return subcomponents; - } - } - std::slice::from_ref(component) - } - - // Expressions like `vec4(vec3(vec2(6, 7), 8), 9)` require us to flatten - // two levels. - components - .iter() - .flat_map(move |component| flattener(component, is_vector, self.expressions)) - .flat_map(move |component| flattener(component, is_vector, self.expressions)) - .take(size) - .cloned() - } } /// Helper function to implement the GLSL `max` function for floats. diff --git a/src/proc/mod.rs b/src/proc/mod.rs index 35b88537c1..405fb75394 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -655,6 +655,61 @@ impl GlobalCtx<'_> { } } +/// Return an iterator over the individual components assembled by a +/// `Compose` expression. +/// +/// Given `ty` and `components` from an `Expression::Compose`, return an +/// iterator over the components of the resulting value. +/// +/// Normally, this would just be an iterator over `components`. However, +/// `Compose` expressions can concatenate vectors, in which case the i'th +/// value being composed is not generally the i'th element of `components`. +/// This function consults `ty` to decide if this concatenation is occuring, +/// and returns an iterator that produces the components of the result of +/// the `Compose` expression in either case. +pub fn flatten_compose<'arenas>( + ty: crate::Handle, + components: &'arenas [crate::Handle], + expressions: &'arenas crate::Arena, + types: &'arenas crate::UniqueArena, +) -> impl Iterator> + 'arenas { + // Returning `impl Iterator` is a bit tricky. We may or may not want to + // flatten the components, but we have to settle on a single concrete + // type to return. The below is a single iterator chain that handles + // both the flattening and non-flattening cases. + let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner { + (size as usize, true) + } else { + (components.len(), false) + }; + + fn flattener<'c>( + component: &'c crate::Handle, + is_vector: bool, + expressions: &'c crate::Arena, + ) -> &'c [crate::Handle] { + if is_vector { + if let crate::Expression::Compose { + ty: _, + components: ref subcomponents, + } = expressions[*component] + { + return subcomponents; + } + } + std::slice::from_ref(component) + } + + // Expressions like `vec4(vec3(vec2(6, 7), 8), 9)` require us to flatten + // two levels. + components + .iter() + .flat_map(move |component| flattener(component, is_vector, expressions)) + .flat_map(move |component| flattener(component, is_vector, expressions)) + .take(size) + .cloned() +} + #[test] fn test_matrix_size() { let module = crate::Module::default(); From 01704e193ba2b09a9be723a4cdbedb9f01e22589 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Tue, 19 Sep 2023 12:14:47 -0700 Subject: [PATCH 4/4] [spv-out] Ensure that we flatten Compose for OpConstantCompose. --- src/back/spv/block.rs | 27 +++++++++++++++------------ src/back/spv/writer.rs | 12 ++++++++---- tests/out/spv/const-exprs.spvasm | 2 +- 3 files changed, 24 insertions(+), 17 deletions(-) diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index d698931f6d..342a3805f5 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -243,21 +243,24 @@ impl<'w> BlockContext<'w> { self.writer.constant_ids[init.index()] } crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id), - crate::Expression::Compose { - ty: _, - ref components, - } => { + crate::Expression::Compose { ty, ref components } => { self.temp_list.clear(); - for &component in components { - self.temp_list.push(self.cached[component]); - } - if self.ir_function.expressions.is_const(expr_handle) { - let ty = self - .writer - .get_expression_lookup_type(&self.fun_info[expr_handle].ty); - self.writer.get_constant_composite(ty, &self.temp_list) + self.temp_list.extend( + crate::proc::flatten_compose( + ty, + components, + &self.ir_function.expressions, + &self.ir_module.types, + ) + .map(|component| self.cached[component]), + ); + self.writer + .get_constant_composite(LookupType::Handle(ty), &self.temp_list) } else { + self.temp_list + .extend(components.iter().map(|&component| self.cached[component])); + let id = self.gen_id(); block.body.push(Instruction::composite_construct( result_type_id, diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index be0cc02e59..a0b068f0e6 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -1269,10 +1269,14 @@ impl Writer { self.get_constant_null(type_id) } crate::Expression::Compose { ty, ref components } => { - let component_ids: Vec<_> = components - .iter() - .map(|component| self.constant_ids[component.index()]) - .collect(); + let component_ids: Vec<_> = crate::proc::flatten_compose( + ty, + components, + &ir_module.const_expressions, + &ir_module.types, + ) + .map(|component| self.constant_ids[component.index()]) + .collect(); self.get_constant_composite(LookupType::Handle(ty), component_ids.as_slice()) } crate::Expression::Splat { size, value } => { diff --git a/tests/out/spv/const-exprs.spvasm b/tests/out/spv/const-exprs.spvasm index ae26ded97e..673ab1fb00 100644 --- a/tests/out/spv/const-exprs.spvasm +++ b/tests/out/spv/const-exprs.spvasm @@ -51,7 +51,7 @@ OpMemberDecorate %17 0 Offset 0 %35 = OpConstant %4 7 %36 = OpConstantComposite %5 %34 %35 %37 = OpConstant %4 8 -%38 = OpConstantComposite %6 %36 %37 +%38 = OpConstantComposite %6 %34 %35 %37 %39 = OpConstant %4 9 %20 = OpFunction %2 None %21 %19 = OpLabel