diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index d698931f6d..342a3805f5 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -243,21 +243,24 @@ impl<'w> BlockContext<'w> { self.writer.constant_ids[init.index()] } crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id), - crate::Expression::Compose { - ty: _, - ref components, - } => { + crate::Expression::Compose { ty, ref components } => { self.temp_list.clear(); - for &component in components { - self.temp_list.push(self.cached[component]); - } - if self.ir_function.expressions.is_const(expr_handle) { - let ty = self - .writer - .get_expression_lookup_type(&self.fun_info[expr_handle].ty); - self.writer.get_constant_composite(ty, &self.temp_list) + self.temp_list.extend( + crate::proc::flatten_compose( + ty, + components, + &self.ir_function.expressions, + &self.ir_module.types, + ) + .map(|component| self.cached[component]), + ); + self.writer + .get_constant_composite(LookupType::Handle(ty), &self.temp_list) } else { + self.temp_list + .extend(components.iter().map(|&component| self.cached[component])); + let id = self.gen_id(); block.body.push(Instruction::composite_construct( result_type_id, diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index 7d79377786..0f4809565c 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -1269,10 +1269,14 @@ impl Writer { self.get_constant_null(type_id) } crate::Expression::Compose { ty, ref components } => { - let component_ids: Vec<_> = components - .iter() - .map(|component| self.constant_ids[component.index()]) - .collect(); + let component_ids: Vec<_> = crate::proc::flatten_compose( + ty, + components, + &ir_module.const_expressions, + &ir_module.types, + ) + .map(|component| self.constant_ids[component.index()]) + .collect(); self.get_constant_composite(LookupType::Handle(ty), component_ids.as_slice()) } crate::Expression::Splat { size, value } => { diff --git a/src/proc/constant_evaluator.rs b/src/proc/constant_evaluator.rs index 6b917c4d67..9d2055ff6a 100644 --- a/src/proc/constant_evaluator.rs +++ b/src/proc/constant_evaluator.rs @@ -73,6 +73,8 @@ pub enum ConstantEvaluatorError { SplatScalarOnly, #[error("Can only swizzle vector constants")] SwizzleVectorOnly, + #[error("swizzle component not present in source expression")] + SwizzleOutOfBounds, #[error("Type is not constructible")] TypeNotConstructible, #[error("Subexpression(s) are not constant")] @@ -306,20 +308,31 @@ impl ConstantEvaluator<'_> { let expr = Expression::Splat { size, value }; Ok(self.register_evaluated_expr(expr, span)) } - Expression::Compose { - ty, - components: ref src_components, - } => { + Expression::Compose { ty, ref components } => { let dst_ty = get_dst_ty(ty)?; - let components = pattern + let mut flattened = [src_constant; 4]; // dummy value + let len = + crate::proc::flatten_compose(ty, components, self.expressions, self.types) + .zip(flattened.iter_mut()) + .map(|(component, elt)| *elt = component) + .count(); + let flattened = &flattened[..len]; + + let swizzled_components = pattern[..size as usize] .iter() - .take(size as usize) - .map(|&sc| src_components[sc as usize]) - .collect(); + .map(|&sc| { + let sc = sc as usize; + if let Some(elt) = flattened.get(sc) { + Ok(*elt) + } else { + Err(ConstantEvaluatorError::SwizzleOutOfBounds) + } + }) + .collect::>, _>>()?; let expr = Expression::Compose { ty: dst_ty, - components, + components: swizzled_components, }; Ok(self.register_evaluated_expr(expr, span)) } @@ -455,9 +468,8 @@ impl ConstantEvaluator<'_> { .components() .ok_or(ConstantEvaluatorError::InvalidAccessBase)?; - components - .get(index) - .copied() + crate::proc::flatten_compose(ty, components, self.expressions, self.types) + .nth(index) .ok_or(ConstantEvaluatorError::InvalidAccessIndex) } _ => Err(ConstantEvaluatorError::InvalidAccessBase), diff --git a/src/proc/mod.rs b/src/proc/mod.rs index b654f5c4b2..cb08ce49e6 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -656,6 +656,61 @@ impl GlobalCtx<'_> { } } +/// Return an iterator over the individual components assembled by a +/// `Compose` expression. +/// +/// Given `ty` and `components` from an `Expression::Compose`, return an +/// iterator over the components of the resulting value. +/// +/// Normally, this would just be an iterator over `components`. However, +/// `Compose` expressions can concatenate vectors, in which case the i'th +/// value being composed is not generally the i'th element of `components`. +/// This function consults `ty` to decide if this concatenation is occuring, +/// and returns an iterator that produces the components of the result of +/// the `Compose` expression in either case. +pub fn flatten_compose<'arenas>( + ty: crate::Handle, + components: &'arenas [crate::Handle], + expressions: &'arenas crate::Arena, + types: &'arenas crate::UniqueArena, +) -> impl Iterator> + 'arenas { + // Returning `impl Iterator` is a bit tricky. We may or may not want to + // flatten the components, but we have to settle on a single concrete + // type to return. The below is a single iterator chain that handles + // both the flattening and non-flattening cases. + let (size, is_vector) = if let crate::TypeInner::Vector { size, .. } = types[ty].inner { + (size as usize, true) + } else { + (components.len(), false) + }; + + fn flattener<'c>( + component: &'c crate::Handle, + is_vector: bool, + expressions: &'c crate::Arena, + ) -> &'c [crate::Handle] { + if is_vector { + if let crate::Expression::Compose { + ty: _, + components: ref subcomponents, + } = expressions[*component] + { + return subcomponents; + } + } + std::slice::from_ref(component) + } + + // Expressions like `vec4(vec3(vec2(6, 7), 8), 9)` require us to flatten + // two levels. + components + .iter() + .flat_map(move |component| flattener(component, is_vector, expressions)) + .flat_map(move |component| flattener(component, is_vector, expressions)) + .take(size) + .cloned() +} + #[test] fn test_matrix_size() { let module = crate::Module::default(); diff --git a/tests/in/const-exprs.wgsl b/tests/in/const-exprs.wgsl new file mode 100644 index 0000000000..c89e61d499 --- /dev/null +++ b/tests/in/const-exprs.wgsl @@ -0,0 +1,14 @@ +@group(0) @binding(0) var out: vec4; +@group(0) @binding(1) var out2: i32; +@group(0) @binding(2) var out3: i32; + +@compute @workgroup_size(1) +fn main() { + let a = vec2(1, 2); + let b = vec2(3, 4); + out = vec4(a, b).wzyx; + + out2 = vec4(a, b)[1]; + + out3 = vec4(vec3(vec2(6, 7), 8), 9)[0]; +} diff --git a/tests/out/glsl/const-exprs.main.Compute.glsl b/tests/out/glsl/const-exprs.main.Compute.glsl new file mode 100644 index 0000000000..ff634004ca --- /dev/null +++ b/tests/out/glsl/const-exprs.main.Compute.glsl @@ -0,0 +1,23 @@ +#version 310 es + +precision highp float; +precision highp int; + +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + +layout(std430) buffer type_block_0Compute { ivec4 _group_0_binding_0_cs; }; + +layout(std430) buffer type_1_block_1Compute { int _group_0_binding_1_cs; }; + +layout(std430) buffer type_1_block_2Compute { int _group_0_binding_2_cs; }; + + +void main() { + ivec2 a = ivec2(1, 2); + ivec2 b = ivec2(3, 4); + _group_0_binding_0_cs = ivec4(4, 3, 2, 1); + _group_0_binding_1_cs = 2; + _group_0_binding_2_cs = 6; + return; +} + diff --git a/tests/out/hlsl/const-exprs.hlsl b/tests/out/hlsl/const-exprs.hlsl new file mode 100644 index 0000000000..f6faee1d40 --- /dev/null +++ b/tests/out/hlsl/const-exprs.hlsl @@ -0,0 +1,14 @@ +RWByteAddressBuffer out_ : register(u0); +RWByteAddressBuffer out2_ : register(u1); +RWByteAddressBuffer out3_ : register(u2); + +[numthreads(1, 1, 1)] +void main() +{ + int2 a = int2(1, 2); + int2 b = int2(3, 4); + out_.Store4(0, asuint(int4(4, 3, 2, 1))); + out2_.Store(0, asuint(2)); + out3_.Store(0, asuint(6)); + return; +} diff --git a/tests/out/hlsl/const-exprs.ron b/tests/out/hlsl/const-exprs.ron new file mode 100644 index 0000000000..a07b03300b --- /dev/null +++ b/tests/out/hlsl/const-exprs.ron @@ -0,0 +1,12 @@ +( + vertex:[ + ], + fragment:[ + ], + compute:[ + ( + entry_point:"main", + target_profile:"cs_5_1", + ), + ], +) diff --git a/tests/out/msl/const-exprs.msl b/tests/out/msl/const-exprs.msl new file mode 100644 index 0000000000..19b9b727fb --- /dev/null +++ b/tests/out/msl/const-exprs.msl @@ -0,0 +1,19 @@ +// language: metal2.0 +#include +#include + +using metal::uint; + + +kernel void main_( + device metal::int4& out [[user(fake0)]] +, device int& out2_ [[user(fake0)]] +, device int& out3_ [[user(fake0)]] +) { + metal::int2 a = metal::int2(1, 2); + metal::int2 b = metal::int2(3, 4); + out = metal::int4(4, 3, 2, 1); + out2_ = 2; + out3_ = 6; + return; +} diff --git a/tests/out/spv/const-exprs.spvasm b/tests/out/spv/const-exprs.spvasm new file mode 100644 index 0000000000..ef00c16b59 --- /dev/null +++ b/tests/out/spv/const-exprs.spvasm @@ -0,0 +1,66 @@ +; SPIR-V +; Version: 1.1 +; Generator: rspirv +; Bound: 40 +OpCapability Shader +OpExtension "SPV_KHR_storage_buffer_storage_class" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %17 "main" +OpExecutionMode %17 LocalSize 1 1 1 +OpDecorate %7 DescriptorSet 0 +OpDecorate %7 Binding 0 +OpDecorate %8 Block +OpMemberDecorate %8 0 Offset 0 +OpDecorate %10 DescriptorSet 0 +OpDecorate %10 Binding 1 +OpDecorate %11 Block +OpMemberDecorate %11 0 Offset 0 +OpDecorate %13 DescriptorSet 0 +OpDecorate %13 Binding 2 +OpDecorate %14 Block +OpMemberDecorate %14 0 Offset 0 +%2 = OpTypeVoid +%4 = OpTypeInt 32 1 +%3 = OpTypeVector %4 4 +%5 = OpTypeVector %4 2 +%6 = OpTypeVector %4 3 +%8 = OpTypeStruct %3 +%9 = OpTypePointer StorageBuffer %8 +%7 = OpVariable %9 StorageBuffer +%11 = OpTypeStruct %4 +%12 = OpTypePointer StorageBuffer %11 +%10 = OpVariable %12 StorageBuffer +%14 = OpTypeStruct %4 +%15 = OpTypePointer StorageBuffer %14 +%13 = OpVariable %15 StorageBuffer +%18 = OpTypeFunction %2 +%19 = OpTypePointer StorageBuffer %3 +%21 = OpTypeInt 32 0 +%20 = OpConstant %21 0 +%23 = OpTypePointer StorageBuffer %4 +%26 = OpConstant %4 1 +%27 = OpConstant %4 2 +%28 = OpConstantComposite %5 %26 %27 +%29 = OpConstant %4 3 +%30 = OpConstant %4 4 +%31 = OpConstantComposite %5 %29 %30 +%32 = OpConstantComposite %3 %30 %29 %27 %26 +%33 = OpConstant %4 6 +%34 = OpConstant %4 7 +%35 = OpConstantComposite %5 %33 %34 +%36 = OpConstant %4 8 +%37 = OpConstantComposite %6 %33 %34 %36 +%38 = OpConstant %4 9 +%17 = OpFunction %2 None %18 +%16 = OpLabel +%22 = OpAccessChain %19 %7 %20 +%24 = OpAccessChain %23 %10 %20 +%25 = OpAccessChain %23 %13 %20 +OpBranch %39 +%39 = OpLabel +OpStore %22 %32 +OpStore %24 %27 +OpStore %25 %33 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/const-exprs.wgsl b/tests/out/wgsl/const-exprs.wgsl new file mode 100644 index 0000000000..201535836e --- /dev/null +++ b/tests/out/wgsl/const-exprs.wgsl @@ -0,0 +1,16 @@ +@group(0) @binding(0) +var out: vec4; +@group(0) @binding(1) +var out2_: i32; +@group(0) @binding(2) +var out3_: i32; + +@compute @workgroup_size(1, 1, 1) +fn main() { + let a = vec2(1, 2); + let b = vec2(3, 4); + out = vec4(4, 3, 2, 1); + out2_ = 2; + out3_ = 6; + return; +} diff --git a/tests/snapshots.rs b/tests/snapshots.rs index 495bea9598..2bc7f45444 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -777,6 +777,10 @@ fn convert_wgsl() { Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, ), ("msl-varyings", Targets::METAL), + ( + "const-exprs", + Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, + ), ]; for &(name, targets) in inputs.iter() {