Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow non-structure buffer types #1682

Merged
merged 1 commit into from
Feb 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -585,9 +585,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {

let name = &self.names[&NameKey::GlobalVariable(handle)];
write!(self.out, " {}", name)?;
if let TypeInner::Array { size, .. } = module.types[global.ty].inner {
self.write_array_size(module, size)?;
}

if let Some(ref binding) = global.binding {
// this was already resolved earlier when we started evaluating an entry point.
Expand All @@ -597,20 +594,31 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, ", space{}", bt.space)?;
}
write!(self.out, ")")?;
} else if global.space == crate::AddressSpace::Private {
write!(self.out, " = ")?;
if let Some(init) = global.init {
self.write_constant(module, init)?;
} else {
self.write_default_init(module, global.ty)?;
} else {
// need to write the array size if the type was emitted with `write_type`
if let TypeInner::Array { size, .. } = module.types[global.ty].inner {
self.write_array_size(module, size)?;
}
if global.space == crate::AddressSpace::Private {
write!(self.out, " = ")?;
if let Some(init) = global.init {
self.write_constant(module, init)?;
} else {
self.write_default_init(module, global.ty)?;
}
}
}

if global.space == crate::AddressSpace::Uniform {
write!(self.out, " {{ ")?;
self.write_type(module, global.ty)?;
let name = &self.names[&NameKey::GlobalVariable(handle)];
writeln!(self.out, " {}; }}", name)?;
let sub_name = &self.names[&NameKey::GlobalVariable(handle)];
write!(self.out, " {}", sub_name)?;
// need to write the array size if the type was emitted with `write_type`
if let TypeInner::Array { size, .. } = module.types[global.ty].inner {
self.write_array_size(module, size)?;
}
writeln!(self.out, "; }}")?;
} else {
writeln!(self.out, ";")?;
}
Expand Down
41 changes: 25 additions & 16 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,18 +368,25 @@ fn should_pack_struct_member(
}

fn needs_array_length(ty: Handle<crate::Type>, arena: &crate::UniqueArena<crate::Type>) -> bool {
if let crate::TypeInner::Struct { ref members, .. } = arena[ty].inner {
if let Some(member) = members.last() {
if let crate::TypeInner::Array {
size: crate::ArraySize::Dynamic,
..
} = arena[member.ty].inner
{
return true;
match arena[ty].inner {
crate::TypeInner::Struct { ref members, .. } => {
if let Some(member) = members.last() {
if let crate::TypeInner::Array {
size: crate::ArraySize::Dynamic,
..
} = arena[member.ty].inner
{
return true;
}
}
false
}
crate::TypeInner::Array {
size: crate::ArraySize::Dynamic,
..
} => true,
_ => false,
}
false
}

impl crate::AddressSpace {
Expand Down Expand Up @@ -741,16 +748,18 @@ impl<W: Write> Writer<W> {
context: &ExpressionContext,
) -> BackendResult {
let global = &context.module.global_variables[handle];
let members = match context.module.types[global.ty].inner {
crate::TypeInner::Struct { ref members, .. } => members,
let (offset, array_ty) = match context.module.types[global.ty].inner {
crate::TypeInner::Struct { ref members, .. } => match members.last() {
Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
None => return Err(Error::Validation),
},
crate::TypeInner::Array {
size: crate::ArraySize::Dynamic,
..
} => (0, global.ty),
_ => return Err(Error::Validation),
};

let (offset, array_ty) = match members.last() {
Some(&crate::StructMember { offset, ty, .. }) => (offset, ty),
None => return Err(Error::Validation),
};

let (size, stride) = match context.module.types[array_ty].inner {
crate::TypeInner::Array { base, stride, .. } => (
context.module.types[base]
Expand Down
3 changes: 2 additions & 1 deletion src/back/spv/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ pub fn global_needs_wrapper(ir_module: &crate::Module, var: &crate::GlobalVariab
},
None => false,
},
_ => false,
// if it's not a structure, let's wrap it to be able to put "Block"
_ => true,
}
}
20 changes: 9 additions & 11 deletions src/valid/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,15 @@ impl super::Validator {
true,
)
}
crate::AddressSpace::Handle => (TypeFlags::empty(), true),
crate::AddressSpace::Handle => {
match types[var.ty].inner {
crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => {}
_ => {
return Err(GlobalVariableError::InvalidType);
}
};
(TypeFlags::empty(), true)
}
crate::AddressSpace::Private | crate::AddressSpace::WorkGroup => {
(TypeFlags::DATA | TypeFlags::SIZED, false)
}
Expand All @@ -375,16 +383,6 @@ impl super::Validator {
}
};

let is_handle = var.space == crate::AddressSpace::Handle;
let good_type = match types[var.ty].inner {
crate::TypeInner::Struct { .. } => !is_handle,
crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => is_handle,
_ => false,
};
if is_resource && !good_type {
return Err(GlobalVariableError::InvalidType);
}

if !type_info.flags.contains(required_type_flags) {
return Err(GlobalVariableError::MissingTypeFlags {
seen: type_info.flags,
Expand Down
9 changes: 4 additions & 5 deletions tests/in/globals.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@ struct Foo {
@group(0) @binding(1)
var<storage> alignment: Foo;

struct Dummy {
arr: array<vec2<f32>>;
};

@group(0) @binding(2)
var<storage> dummy: Dummy;
var<storage> dummy: array<vec2<f32>>;

@group(0) @binding(3)
var<uniform> float_vecs: array<vec4<f32>, 20>;

@stage(compute) @workgroup_size(1)
fn main() {
Expand Down
8 changes: 4 additions & 4 deletions tests/out/glsl/globals.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ layout(std430) readonly buffer Foo_block_0Compute { Foo _group_0_binding_1_cs; }
void main() {
float Foo_1 = 1.0;
bool at = true;
float _e8 = _group_0_binding_1_cs.v1_;
wg[3] = _e8;
float _e13 = _group_0_binding_1_cs.v3_.x;
wg[2] = _e13;
float _e9 = _group_0_binding_1_cs.v1_;
wg[3] = _e9;
float _e14 = _group_0_binding_1_cs.v3_.x;
wg[2] = _e14;
at_1 = 2u;
return;
}
Expand Down
9 changes: 5 additions & 4 deletions tests/out/hlsl/globals.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,18 @@ groupshared float wg[10];
groupshared uint at_1;
ByteAddressBuffer alignment : register(t1);
ByteAddressBuffer dummy : register(t2);
cbuffer float_vecs : register(b3) { float4 float_vecs[20]; }

[numthreads(1, 1, 1)]
void main()
{
float Foo_1 = 1.0;
bool at = true;

float _expr8 = asfloat(alignment.Load(12));
wg[3] = _expr8;
float _expr13 = asfloat(alignment.Load(0+0));
wg[2] = _expr13;
float _expr9 = asfloat(alignment.Load(12));
wg[3] = _expr9;
float _expr14 = asfloat(alignment.Load(0+0));
wg[2] = _expr14;
at_1 = 2u;
return;
}
12 changes: 6 additions & 6 deletions tests/out/msl/globals.msl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ struct Foo {
float v1_;
};
typedef metal::float2 type_6[1];
struct Dummy {
type_6 arr;
struct type_8 {
metal::float4 inner[20];
};

kernel void main_(
Expand All @@ -26,10 +26,10 @@ kernel void main_(
) {
float Foo_1 = 1.0;
bool at = true;
float _e8 = alignment.v1_;
wg.inner[3] = _e8;
float _e13 = metal::float3(alignment.v3_).x;
wg.inner[2] = _e13;
float _e9 = alignment.v1_;
wg.inner[3] = _e9;
float _e14 = metal::float3(alignment.v3_).x;
wg.inner[2] = _e14;
metal::atomic_store_explicit(&at_1, 2u, metal::memory_order_relaxed);
return;
}
141 changes: 77 additions & 64 deletions tests/out/spv/globals.spvasm
Original file line number Diff line number Diff line change
@@ -1,81 +1,94 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 53
; Bound: 61
OpCapability Shader
OpExtension "SPV_KHR_storage_buffer_storage_class"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %34 "main"
OpExecutionMode %34 LocalSize 1 1 1
OpDecorate %14 ArrayStride 4
OpMemberDecorate %16 0 Offset 0
OpMemberDecorate %16 1 Offset 12
OpDecorate %18 ArrayStride 8
OpMemberDecorate %19 0 Offset 0
OpDecorate %24 NonWritable
OpDecorate %24 DescriptorSet 0
OpDecorate %24 Binding 1
OpDecorate %25 Block
OpMemberDecorate %25 0 Offset 0
OpDecorate %27 NonWritable
OpDecorate %27 DescriptorSet 0
OpDecorate %27 Binding 2
OpDecorate %19 Block
OpEntryPoint GLCompute %40 "main"
OpExecutionMode %40 LocalSize 1 1 1
OpDecorate %15 ArrayStride 4
OpMemberDecorate %17 0 Offset 0
OpMemberDecorate %17 1 Offset 12
OpDecorate %19 ArrayStride 8
OpDecorate %21 ArrayStride 16
OpDecorate %26 NonWritable
OpDecorate %26 DescriptorSet 0
OpDecorate %26 Binding 1
OpDecorate %27 Block
OpMemberDecorate %27 0 Offset 0
OpDecorate %29 NonWritable
OpDecorate %29 DescriptorSet 0
OpDecorate %29 Binding 2
OpDecorate %30 Block
OpMemberDecorate %30 0 Offset 0
OpDecorate %32 DescriptorSet 0
OpDecorate %32 Binding 3
OpDecorate %33 Block
OpMemberDecorate %33 0 Offset 0
%2 = OpTypeVoid
%4 = OpTypeBool
%3 = OpConstantTrue %4
%6 = OpTypeInt 32 0
%5 = OpConstant %6 10
%8 = OpTypeInt 32 1
%7 = OpConstant %8 3
%9 = OpConstant %8 2
%10 = OpConstant %6 2
%12 = OpTypeFloat 32
%11 = OpConstant %12 1.0
%13 = OpConstantTrue %4
%14 = OpTypeArray %12 %5
%15 = OpTypeVector %12 3
%16 = OpTypeStruct %15 %12
%17 = OpTypeVector %12 2
%18 = OpTypeRuntimeArray %17
%19 = OpTypeStruct %18
%21 = OpTypePointer Workgroup %14
%20 = OpVariable %21 Workgroup
%23 = OpTypePointer Workgroup %6
%7 = OpConstant %8 20
%9 = OpConstant %8 3
%10 = OpConstant %8 2
%11 = OpConstant %6 2
%13 = OpTypeFloat 32
%12 = OpConstant %13 1.0
%14 = OpConstantTrue %4
%15 = OpTypeArray %13 %5
%16 = OpTypeVector %13 3
%17 = OpTypeStruct %16 %13
%18 = OpTypeVector %13 2
%19 = OpTypeRuntimeArray %18
%20 = OpTypeVector %13 4
%21 = OpTypeArray %20 %7
%23 = OpTypePointer Workgroup %15
%22 = OpVariable %23 Workgroup
%25 = OpTypeStruct %16
%26 = OpTypePointer StorageBuffer %25
%24 = OpVariable %26 StorageBuffer
%28 = OpTypePointer StorageBuffer %19
%27 = OpVariable %28 StorageBuffer
%30 = OpTypePointer Function %12
%32 = OpTypePointer Function %4
%35 = OpTypeFunction %2
%36 = OpTypePointer StorageBuffer %16
%37 = OpConstant %6 0
%40 = OpTypePointer Workgroup %12
%41 = OpTypePointer StorageBuffer %12
%42 = OpConstant %6 1
%45 = OpConstant %6 3
%47 = OpTypePointer StorageBuffer %15
%48 = OpTypePointer StorageBuffer %12
%52 = OpConstant %6 256
%34 = OpFunction %2 None %35
%33 = OpLabel
%29 = OpVariable %30 Function %11
%31 = OpVariable %32 Function %13
%38 = OpAccessChain %36 %24 %37
OpBranch %39
%25 = OpTypePointer Workgroup %6
%24 = OpVariable %25 Workgroup
%27 = OpTypeStruct %17
%28 = OpTypePointer StorageBuffer %27
%26 = OpVariable %28 StorageBuffer
%30 = OpTypeStruct %19
%31 = OpTypePointer StorageBuffer %30
%29 = OpVariable %31 StorageBuffer
%33 = OpTypeStruct %21
%34 = OpTypePointer Uniform %33
%32 = OpVariable %34 Uniform
%36 = OpTypePointer Function %13
%38 = OpTypePointer Function %4
%41 = OpTypeFunction %2
%42 = OpTypePointer StorageBuffer %17
%43 = OpConstant %6 0
%45 = OpTypePointer StorageBuffer %19
%46 = OpTypePointer Uniform %21
%48 = OpTypePointer Workgroup %13
%49 = OpTypePointer StorageBuffer %13
%50 = OpConstant %6 1
%53 = OpConstant %6 3
%55 = OpTypePointer StorageBuffer %16
%56 = OpTypePointer StorageBuffer %13
%60 = OpConstant %6 256
%40 = OpFunction %2 None %41
%39 = OpLabel
%43 = OpAccessChain %41 %38 %42
%44 = OpLoad %12 %43
%46 = OpAccessChain %40 %20 %45
OpStore %46 %44
%49 = OpAccessChain %48 %38 %37 %37
%50 = OpLoad %12 %49
%51 = OpAccessChain %40 %20 %10
OpStore %51 %50
OpAtomicStore %22 %9 %52 %10
%35 = OpVariable %36 Function %12
%37 = OpVariable %38 Function %14
%44 = OpAccessChain %42 %26 %43
OpBranch %47
%47 = OpLabel
%51 = OpAccessChain %49 %44 %50
%52 = OpLoad %13 %51
%54 = OpAccessChain %48 %22 %53
OpStore %54 %52
%57 = OpAccessChain %56 %44 %43 %43
%58 = OpLoad %13 %57
%59 = OpAccessChain %48 %22 %11
OpStore %59 %58
OpAtomicStore %24 %10 %60 %11
OpReturn
OpFunctionEnd
Loading