From a00638f1be124c25a434758d5ea488eaf4285131 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Fri, 21 Jan 2022 12:30:18 -0500 Subject: [PATCH] Allow non-structure buffer types --- src/back/hlsl/writer.rs | 30 +++-- src/back/msl/writer.rs | 41 ++++--- src/back/spv/helpers.rs | 3 +- src/valid/interface.rs | 20 ++-- tests/in/globals.wgsl | 9 +- tests/out/glsl/globals.main.Compute.glsl | 8 +- tests/out/hlsl/globals.hlsl | 9 +- tests/out/msl/globals.msl | 12 +- tests/out/spv/globals.spvasm | 141 +++++++++++++---------- tests/out/wgsl/globals.wgsl | 16 ++- 10 files changed, 158 insertions(+), 131 deletions(-) diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 01a58cf3fc..10e928c62d 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -585,9 +585,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { let name = &self.names[&NameKey::GlobalVariable(handle)]; write!(self.out, " {}", name)?; - if let TypeInner::Array { size, .. } = module.types[global.ty].inner { - self.write_array_size(module, size)?; - } if let Some(ref binding) = global.binding { // this was already resolved earlier when we started evaluating an entry point. @@ -597,20 +594,31 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, ", space{}", bt.space)?; } write!(self.out, ")")?; - } else if global.space == crate::AddressSpace::Private { - write!(self.out, " = ")?; - if let Some(init) = global.init { - self.write_constant(module, init)?; - } else { - self.write_default_init(module, global.ty)?; + } else { + // need to write the array size if the type was emitted with `write_type` + if let TypeInner::Array { size, .. } = module.types[global.ty].inner { + self.write_array_size(module, size)?; + } + if global.space == crate::AddressSpace::Private { + write!(self.out, " = ")?; + if let Some(init) = global.init { + self.write_constant(module, init)?; + } else { + self.write_default_init(module, global.ty)?; + } } } if global.space == crate::AddressSpace::Uniform { write!(self.out, " {{ ")?; self.write_type(module, global.ty)?; - let name = &self.names[&NameKey::GlobalVariable(handle)]; - writeln!(self.out, " {}; }}", name)?; + let sub_name = &self.names[&NameKey::GlobalVariable(handle)]; + write!(self.out, " {}", sub_name)?; + // need to write the array size if the type was emitted with `write_type` + if let TypeInner::Array { size, .. } = module.types[global.ty].inner { + self.write_array_size(module, size)?; + } + writeln!(self.out, "; }}")?; } else { writeln!(self.out, ";")?; } diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index f2fd1bd89b..13e570ba1d 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -368,18 +368,25 @@ fn should_pack_struct_member( } fn needs_array_length(ty: Handle, arena: &crate::UniqueArena) -> bool { - if let crate::TypeInner::Struct { ref members, .. } = arena[ty].inner { - if let Some(member) = members.last() { - if let crate::TypeInner::Array { - size: crate::ArraySize::Dynamic, - .. - } = arena[member.ty].inner - { - return true; + match arena[ty].inner { + crate::TypeInner::Struct { ref members, .. } => { + if let Some(member) = members.last() { + if let crate::TypeInner::Array { + size: crate::ArraySize::Dynamic, + .. + } = arena[member.ty].inner + { + return true; + } } + false } + crate::TypeInner::Array { + size: crate::ArraySize::Dynamic, + .. + } => true, + _ => false, } - false } impl crate::AddressSpace { @@ -741,16 +748,18 @@ impl Writer { context: &ExpressionContext, ) -> BackendResult { let global = &context.module.global_variables[handle]; - let members = match context.module.types[global.ty].inner { - crate::TypeInner::Struct { ref members, .. } => members, + let (offset, array_ty) = match context.module.types[global.ty].inner { + crate::TypeInner::Struct { ref members, .. } => match members.last() { + Some(&crate::StructMember { offset, ty, .. }) => (offset, ty), + None => return Err(Error::Validation), + }, + crate::TypeInner::Array { + size: crate::ArraySize::Dynamic, + .. + } => (0, global.ty), _ => return Err(Error::Validation), }; - let (offset, array_ty) = match members.last() { - Some(&crate::StructMember { offset, ty, .. }) => (offset, ty), - None => return Err(Error::Validation), - }; - let (size, stride) = match context.module.types[array_ty].inner { crate::TypeInner::Array { base, stride, .. } => ( context.module.types[base] diff --git a/src/back/spv/helpers.rs b/src/back/spv/helpers.rs index cac21add02..acde9504d9 100644 --- a/src/back/spv/helpers.rs +++ b/src/back/spv/helpers.rs @@ -86,6 +86,7 @@ pub fn global_needs_wrapper(ir_module: &crate::Module, var: &crate::GlobalVariab }, None => false, }, - _ => false, + // if it's not a structure, let's wrap it to be able to put "Block" + _ => true, } } diff --git a/src/valid/interface.rs b/src/valid/interface.rs index 2cfaf5f67d..c322aac2b3 100644 --- a/src/valid/interface.rs +++ b/src/valid/interface.rs @@ -358,7 +358,15 @@ impl super::Validator { true, ) } - crate::AddressSpace::Handle => (TypeFlags::empty(), true), + crate::AddressSpace::Handle => { + match types[var.ty].inner { + crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => {} + _ => { + return Err(GlobalVariableError::InvalidType); + } + }; + (TypeFlags::empty(), true) + } crate::AddressSpace::Private | crate::AddressSpace::WorkGroup => { (TypeFlags::DATA | TypeFlags::SIZED, false) } @@ -375,16 +383,6 @@ impl super::Validator { } }; - let is_handle = var.space == crate::AddressSpace::Handle; - let good_type = match types[var.ty].inner { - crate::TypeInner::Struct { .. } => !is_handle, - crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => is_handle, - _ => false, - }; - if is_resource && !good_type { - return Err(GlobalVariableError::InvalidType); - } - if !type_info.flags.contains(required_type_flags) { return Err(GlobalVariableError::MissingTypeFlags { seen: type_info.flags, diff --git a/tests/in/globals.wgsl b/tests/in/globals.wgsl index 02c9e069a2..23dabf1b1b 100644 --- a/tests/in/globals.wgsl +++ b/tests/in/globals.wgsl @@ -13,12 +13,11 @@ struct Foo { @group(0) @binding(1) var alignment: Foo; -struct Dummy { - arr: array>; -}; - @group(0) @binding(2) -var dummy: Dummy; +var dummy: array>; + +@group(0) @binding(3) +var float_vecs: array, 20>; @stage(compute) @workgroup_size(1) fn main() { diff --git a/tests/out/glsl/globals.main.Compute.glsl b/tests/out/glsl/globals.main.Compute.glsl index 4b53779517..a40e6bbac9 100644 --- a/tests/out/glsl/globals.main.Compute.glsl +++ b/tests/out/glsl/globals.main.Compute.glsl @@ -19,10 +19,10 @@ layout(std430) readonly buffer Foo_block_0Compute { Foo _group_0_binding_1_cs; } void main() { float Foo_1 = 1.0; bool at = true; - float _e8 = _group_0_binding_1_cs.v1_; - wg[3] = _e8; - float _e13 = _group_0_binding_1_cs.v3_.x; - wg[2] = _e13; + float _e9 = _group_0_binding_1_cs.v1_; + wg[3] = _e9; + float _e14 = _group_0_binding_1_cs.v3_.x; + wg[2] = _e14; at_1 = 2u; return; } diff --git a/tests/out/hlsl/globals.hlsl b/tests/out/hlsl/globals.hlsl index 41a114f5b8..a6b91c9523 100644 --- a/tests/out/hlsl/globals.hlsl +++ b/tests/out/hlsl/globals.hlsl @@ -9,6 +9,7 @@ groupshared float wg[10]; groupshared uint at_1; ByteAddressBuffer alignment : register(t1); ByteAddressBuffer dummy : register(t2); +cbuffer float_vecs : register(b3) { float4 float_vecs[20]; } [numthreads(1, 1, 1)] void main() @@ -16,10 +17,10 @@ void main() float Foo_1 = 1.0; bool at = true; - float _expr8 = asfloat(alignment.Load(12)); - wg[3] = _expr8; - float _expr13 = asfloat(alignment.Load(0+0)); - wg[2] = _expr13; + float _expr9 = asfloat(alignment.Load(12)); + wg[3] = _expr9; + float _expr14 = asfloat(alignment.Load(0+0)); + wg[2] = _expr14; at_1 = 2u; return; } diff --git a/tests/out/msl/globals.msl b/tests/out/msl/globals.msl index f2c37b5ab1..dde8160766 100644 --- a/tests/out/msl/globals.msl +++ b/tests/out/msl/globals.msl @@ -15,8 +15,8 @@ struct Foo { float v1_; }; typedef metal::float2 type_6[1]; -struct Dummy { - type_6 arr; +struct type_8 { + metal::float4 inner[20]; }; kernel void main_( @@ -26,10 +26,10 @@ kernel void main_( ) { float Foo_1 = 1.0; bool at = true; - float _e8 = alignment.v1_; - wg.inner[3] = _e8; - float _e13 = metal::float3(alignment.v3_).x; - wg.inner[2] = _e13; + float _e9 = alignment.v1_; + wg.inner[3] = _e9; + float _e14 = metal::float3(alignment.v3_).x; + wg.inner[2] = _e14; metal::atomic_store_explicit(&at_1, 2u, metal::memory_order_relaxed); return; } diff --git a/tests/out/spv/globals.spvasm b/tests/out/spv/globals.spvasm index a11785e8c4..ecc2805ef4 100644 --- a/tests/out/spv/globals.spvasm +++ b/tests/out/spv/globals.spvasm @@ -1,81 +1,94 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 53 +; Bound: 61 OpCapability Shader OpExtension "SPV_KHR_storage_buffer_storage_class" %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %34 "main" -OpExecutionMode %34 LocalSize 1 1 1 -OpDecorate %14 ArrayStride 4 -OpMemberDecorate %16 0 Offset 0 -OpMemberDecorate %16 1 Offset 12 -OpDecorate %18 ArrayStride 8 -OpMemberDecorate %19 0 Offset 0 -OpDecorate %24 NonWritable -OpDecorate %24 DescriptorSet 0 -OpDecorate %24 Binding 1 -OpDecorate %25 Block -OpMemberDecorate %25 0 Offset 0 -OpDecorate %27 NonWritable -OpDecorate %27 DescriptorSet 0 -OpDecorate %27 Binding 2 -OpDecorate %19 Block +OpEntryPoint GLCompute %40 "main" +OpExecutionMode %40 LocalSize 1 1 1 +OpDecorate %15 ArrayStride 4 +OpMemberDecorate %17 0 Offset 0 +OpMemberDecorate %17 1 Offset 12 +OpDecorate %19 ArrayStride 8 +OpDecorate %21 ArrayStride 16 +OpDecorate %26 NonWritable +OpDecorate %26 DescriptorSet 0 +OpDecorate %26 Binding 1 +OpDecorate %27 Block +OpMemberDecorate %27 0 Offset 0 +OpDecorate %29 NonWritable +OpDecorate %29 DescriptorSet 0 +OpDecorate %29 Binding 2 +OpDecorate %30 Block +OpMemberDecorate %30 0 Offset 0 +OpDecorate %32 DescriptorSet 0 +OpDecorate %32 Binding 3 +OpDecorate %33 Block +OpMemberDecorate %33 0 Offset 0 %2 = OpTypeVoid %4 = OpTypeBool %3 = OpConstantTrue %4 %6 = OpTypeInt 32 0 %5 = OpConstant %6 10 %8 = OpTypeInt 32 1 -%7 = OpConstant %8 3 -%9 = OpConstant %8 2 -%10 = OpConstant %6 2 -%12 = OpTypeFloat 32 -%11 = OpConstant %12 1.0 -%13 = OpConstantTrue %4 -%14 = OpTypeArray %12 %5 -%15 = OpTypeVector %12 3 -%16 = OpTypeStruct %15 %12 -%17 = OpTypeVector %12 2 -%18 = OpTypeRuntimeArray %17 -%19 = OpTypeStruct %18 -%21 = OpTypePointer Workgroup %14 -%20 = OpVariable %21 Workgroup -%23 = OpTypePointer Workgroup %6 +%7 = OpConstant %8 20 +%9 = OpConstant %8 3 +%10 = OpConstant %8 2 +%11 = OpConstant %6 2 +%13 = OpTypeFloat 32 +%12 = OpConstant %13 1.0 +%14 = OpConstantTrue %4 +%15 = OpTypeArray %13 %5 +%16 = OpTypeVector %13 3 +%17 = OpTypeStruct %16 %13 +%18 = OpTypeVector %13 2 +%19 = OpTypeRuntimeArray %18 +%20 = OpTypeVector %13 4 +%21 = OpTypeArray %20 %7 +%23 = OpTypePointer Workgroup %15 %22 = OpVariable %23 Workgroup -%25 = OpTypeStruct %16 -%26 = OpTypePointer StorageBuffer %25 -%24 = OpVariable %26 StorageBuffer -%28 = OpTypePointer StorageBuffer %19 -%27 = OpVariable %28 StorageBuffer -%30 = OpTypePointer Function %12 -%32 = OpTypePointer Function %4 -%35 = OpTypeFunction %2 -%36 = OpTypePointer StorageBuffer %16 -%37 = OpConstant %6 0 -%40 = OpTypePointer Workgroup %12 -%41 = OpTypePointer StorageBuffer %12 -%42 = OpConstant %6 1 -%45 = OpConstant %6 3 -%47 = OpTypePointer StorageBuffer %15 -%48 = OpTypePointer StorageBuffer %12 -%52 = OpConstant %6 256 -%34 = OpFunction %2 None %35 -%33 = OpLabel -%29 = OpVariable %30 Function %11 -%31 = OpVariable %32 Function %13 -%38 = OpAccessChain %36 %24 %37 -OpBranch %39 +%25 = OpTypePointer Workgroup %6 +%24 = OpVariable %25 Workgroup +%27 = OpTypeStruct %17 +%28 = OpTypePointer StorageBuffer %27 +%26 = OpVariable %28 StorageBuffer +%30 = OpTypeStruct %19 +%31 = OpTypePointer StorageBuffer %30 +%29 = OpVariable %31 StorageBuffer +%33 = OpTypeStruct %21 +%34 = OpTypePointer Uniform %33 +%32 = OpVariable %34 Uniform +%36 = OpTypePointer Function %13 +%38 = OpTypePointer Function %4 +%41 = OpTypeFunction %2 +%42 = OpTypePointer StorageBuffer %17 +%43 = OpConstant %6 0 +%45 = OpTypePointer StorageBuffer %19 +%46 = OpTypePointer Uniform %21 +%48 = OpTypePointer Workgroup %13 +%49 = OpTypePointer StorageBuffer %13 +%50 = OpConstant %6 1 +%53 = OpConstant %6 3 +%55 = OpTypePointer StorageBuffer %16 +%56 = OpTypePointer StorageBuffer %13 +%60 = OpConstant %6 256 +%40 = OpFunction %2 None %41 %39 = OpLabel -%43 = OpAccessChain %41 %38 %42 -%44 = OpLoad %12 %43 -%46 = OpAccessChain %40 %20 %45 -OpStore %46 %44 -%49 = OpAccessChain %48 %38 %37 %37 -%50 = OpLoad %12 %49 -%51 = OpAccessChain %40 %20 %10 -OpStore %51 %50 -OpAtomicStore %22 %9 %52 %10 +%35 = OpVariable %36 Function %12 +%37 = OpVariable %38 Function %14 +%44 = OpAccessChain %42 %26 %43 +OpBranch %47 +%47 = OpLabel +%51 = OpAccessChain %49 %44 %50 +%52 = OpLoad %13 %51 +%54 = OpAccessChain %48 %22 %53 +OpStore %54 %52 +%57 = OpAccessChain %56 %44 %43 %43 +%58 = OpLoad %13 %57 +%59 = OpAccessChain %48 %22 %11 +OpStore %59 %58 +OpAtomicStore %24 %10 %60 %11 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/globals.wgsl b/tests/out/wgsl/globals.wgsl index 6184f987c1..3196468c50 100644 --- a/tests/out/wgsl/globals.wgsl +++ b/tests/out/wgsl/globals.wgsl @@ -3,10 +3,6 @@ struct Foo { v1_: f32; }; -struct Dummy { - arr: array>; -}; - let Foo_2: bool = true; var wg: array; @@ -14,17 +10,19 @@ var at_1: atomic; @group(0) @binding(1) var alignment: Foo; @group(0) @binding(2) -var dummy: Dummy; +var dummy: array>; +@group(0) @binding(3) +var float_vecs: array,20>; @stage(compute) @workgroup_size(1, 1, 1) fn main() { var Foo_1: f32 = 1.0; var at: bool = true; - let _e8 = alignment.v1_; - wg[3] = _e8; - let _e13 = alignment.v3_.x; - wg[2] = _e13; + let _e9 = alignment.v1_; + wg[3] = _e9; + let _e14 = alignment.v3_.x; + wg[2] = _e14; atomicStore((&at_1), 2u); return; }