Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hlsl: respect array stride in storage buffers #1507

Merged
merged 1 commit into from
Nov 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 44 additions & 29 deletions src/back/hlsl/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,39 +367,22 @@ impl<W: fmt::Write> super::Writer<'_, W> {
mut cur_expr: Handle<crate::Expression>,
func_ctx: &FunctionCtx,
) -> Result<Handle<crate::GlobalVariable>, Error> {
enum AccessIndex {
Expression(Handle<crate::Expression>),
Constant(u32),
}
enum Parent<'a> {
Array { stride: u32 },
Struct(&'a [crate::StructMember]),
}
self.temp_access_chain.clear();
loop {
// determine the size of the pointee
let stride = match *func_ctx.info[cur_expr].ty.inner_with(&module.types) {
crate::TypeInner::Pointer { base, class: _ } => {
module.types[base].inner.span(&module.constants)
}
crate::TypeInner::ValuePointer { size, width, .. } => {
size.map_or(1, |s| s as u32) * width as u32
}
_ => 0,
};

let (next_expr, sub) = match func_ctx.expressions[cur_expr] {
loop {
let (next_expr, access_index) = match func_ctx.expressions[cur_expr] {
crate::Expression::GlobalVariable(handle) => return Ok(handle),
crate::Expression::Access { base, index } => (
base,
SubAccess::Index {
value: index,
stride,
},
),
crate::Expression::Access { base, index } => (base, AccessIndex::Expression(index)),
crate::Expression::AccessIndex { base, index } => {
let sub = match *func_ctx.info[base].ty.inner_with(&module.types) {
crate::TypeInner::Pointer { base, .. } => match module.types[base].inner {
crate::TypeInner::Struct { ref members, .. } => {
SubAccess::Offset(members[index as usize].offset)
}
_ => SubAccess::Offset(index * stride),
},
_ => SubAccess::Offset(index * stride),
};
(base, sub)
(base, AccessIndex::Constant(index))
}
ref other => {
return Err(Error::Unimplemented(format!(
Expand All @@ -408,6 +391,38 @@ impl<W: fmt::Write> super::Writer<'_, W> {
)))
}
};

let parent = match *func_ctx.info[next_expr].ty.inner_with(&module.types) {
crate::TypeInner::Pointer { base, .. } => match module.types[base].inner {
crate::TypeInner::Struct { ref members, .. } => Parent::Struct(members),
crate::TypeInner::Array { stride, .. } => Parent::Array { stride },
crate::TypeInner::Vector { width, .. } => Parent::Array {
stride: width as u32,
},
crate::TypeInner::Matrix { rows, width, .. } => Parent::Array {
stride: width as u32 * if rows > crate::VectorSize::Bi { 4 } else { 2 },
},
_ => unreachable!(),
},
crate::TypeInner::ValuePointer { width, .. } => Parent::Array {
stride: width as u32,
},
_ => unreachable!(),
};

let sub = match (parent, access_index) {
(Parent::Array { stride }, AccessIndex::Expression(value)) => {
SubAccess::Index { value, stride }
}
(Parent::Array { stride }, AccessIndex::Constant(index)) => {
SubAccess::Offset(stride * index)
}
(Parent::Struct(members), AccessIndex::Constant(index)) => {
SubAccess::Offset(members[index as usize].offset)
}
(Parent::Struct(_), AccessIndex::Expression(_)) => unreachable!(),
};

self.temp_access_chain.push(sub);
cur_expr = next_expr;
}
Expand Down
3 changes: 2 additions & 1 deletion tests/in/access.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ struct Bar {
matrix: mat4x4<f32>;
atom: atomic<i32>;
arr: [[stride(8)]] array<vec2<u32>, 2>;
data: [[stride(4)]] array<i32>;
data: [[stride(8)]] array<i32>;
};

[[group(0), binding(0)]]
Expand Down Expand Up @@ -37,6 +37,7 @@ fn foo([[builtin(vertex_index)]] vi: u32) -> [[builtin(position)]] vec4<f32> {
bar.matrix[1].z = 1.0;
bar.matrix = mat4x4<f32>(vec4<f32>(0.0), vec4<f32>(1.0), vec4<f32>(2.0), vec4<f32>(3.0));
bar.arr = array<vec2<u32>, 2>(vec2<u32>(0u), vec2<u32>(1u));
bar.data[1] = 1;

// test array indexing
var c = array<i32, 5>(a, i32(b), 3, 4, 5);
Expand Down
1 change: 1 addition & 0 deletions tests/out/glsl/access.foo.Vertex.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ void main() {
_group_0_binding_0.matrix[1][2] = 1.0;
_group_0_binding_0.matrix = mat4x4(vec4(0.0), vec4(1.0), vec4(2.0), vec4(3.0));
_group_0_binding_0.arr = uvec2[2](uvec2(0u), uvec2(1u));
_group_0_binding_0.data[1] = 1;
c = int[5](a, int(b), 3, 4, 5);
c[(vi + 1u)] = 42;
int value = c[vi];
Expand Down
3 changes: 2 additions & 1 deletion tests/out/hlsl/access.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ float4 foo(uint vi : SV_VertexID) : SV_Position
float4x4 matrix1 = float4x4(asfloat(bar.Load4(0+0)), asfloat(bar.Load4(0+16)), asfloat(bar.Load4(0+32)), asfloat(bar.Load4(0+48)));
uint2 arr[2] = {asuint(bar.Load2(72+0)), asuint(bar.Load2(72+8))};
float b = asfloat(bar.Load(0+48+0));
int a = asint(bar.Load((((NagaBufferLengthRW(bar) - 88) / 4) - 2u)*4+88));
int a = asint(bar.Load((((NagaBufferLengthRW(bar) - 88) / 8) - 2u)*8+88));
const float _e25 = read_from_private(foo1);
bar.Store(8+16+0, asuint(1.0));
{
Expand All @@ -39,6 +39,7 @@ float4 foo(uint vi : SV_VertexID) : SV_Position
bar.Store2(72+0, asuint(_value2[0]));
bar.Store2(72+8, asuint(_value2[1]));
}
bar.Store(8+88, asuint(1));
{
int _result[5]={ a, int(b), 3, 4, 5 };
for(int _i=0; _i<5; ++_i) c[_i] = _result[_i];
Expand Down
3 changes: 2 additions & 1 deletion tests/out/msl/access.msl
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ vertex fooOutput foo(
metal::float4x4 matrix = bar.matrix;
type3 arr = bar.arr;
float b = bar.matrix[3].x;
int a = bar.data[(1 + (_buffer_sizes.size0 - 88 - 4) / 4) - 2u];
int a = bar.data[(1 + (_buffer_sizes.size0 - 88 - 4) / 8) - 2u];
float _e25 = read_from_private(foo1);
bar.matrix[1].z = 1.0;
bar.matrix = metal::float4x4(metal::float4(0.0), metal::float4(1.0), metal::float4(2.0), metal::float4(3.0));
for(int _i=0; _i<2; ++_i) bar.arr.inner[_i] = type3 {metal::uint2(0u), metal::uint2(1u)}.inner[_i];
bar.data[1] = 1;
for(int _i=0; _i<5; ++_i) c.inner[_i] = type11 {a, static_cast<int>(b), 3, 4, 5}.inner[_i];
c.inner[vi + 1u] = 42;
int value = c.inner[vi];
Expand Down
112 changes: 57 additions & 55 deletions tests/out/spv/access.spvasm
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 114
; Bound: 115
OpCapability Shader
OpExtension "SPV_KHR_storage_buffer_storage_class"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Vertex %47 "foo" %42 %45
OpEntryPoint GLCompute %91 "atomics"
OpExecutionMode %91 LocalSize 1 1 1
OpEntryPoint GLCompute %92 "atomics"
OpExecutionMode %92 LocalSize 1 1 1
OpSource GLSL 450
OpMemberName %26 0 "matrix"
OpMemberName %26 1 "atom"
Expand All @@ -22,10 +22,10 @@ OpName %38 "foo"
OpName %39 "c"
OpName %42 "vi"
OpName %47 "foo"
OpName %89 "tmp"
OpName %91 "atomics"
OpName %90 "tmp"
OpName %92 "atomics"
OpDecorate %24 ArrayStride 8
OpDecorate %25 ArrayStride 4
OpDecorate %25 ArrayStride 8
OpDecorate %26 Block
OpMemberDecorate %26 0 Offset 0
OpMemberDecorate %26 0 ColMajor
Expand Down Expand Up @@ -80,10 +80,10 @@ OpDecorate %45 BuiltIn Position
%57 = OpTypePointer StorageBuffer %22
%58 = OpTypePointer StorageBuffer %6
%61 = OpTypePointer StorageBuffer %25
%81 = OpTypePointer Function %4
%85 = OpTypeVector %4 4
%93 = OpTypePointer StorageBuffer %4
%96 = OpConstant %9 64
%82 = OpTypePointer Function %4
%86 = OpTypeVector %4 4
%94 = OpTypePointer StorageBuffer %4
%97 = OpConstant %9 64
%34 = OpFunction %6 None %35
%33 = OpFunctionParameter %27
%32 = OpLabel
Expand Down Expand Up @@ -126,52 +126,54 @@ OpStore %73 %72
%76 = OpCompositeConstruct %24 %74 %75
%77 = OpAccessChain %54 %30 %10
OpStore %77 %76
%78 = OpConvertFToS %4 %60
%79 = OpCompositeConstruct %29 %65 %78 %18 %19 %17
OpStore %39 %79
%80 = OpIAdd %9 %44 %16
%82 = OpAccessChain %81 %39 %80
OpStore %82 %20
%83 = OpAccessChain %81 %39 %44
%84 = OpLoad %4 %83
%86 = OpCompositeConstruct %85 %84 %84 %84 %84
%87 = OpConvertSToF %22 %86
%88 = OpMatrixTimesVector %22 %53 %87
OpStore %45 %88
%78 = OpAccessChain %28 %30 %8 %16
OpStore %78 %12
%79 = OpConvertFToS %4 %60
%80 = OpCompositeConstruct %29 %65 %79 %18 %19 %17
OpStore %39 %80
%81 = OpIAdd %9 %44 %16
%83 = OpAccessChain %82 %39 %81
OpStore %83 %20
%84 = OpAccessChain %82 %39 %44
%85 = OpLoad %4 %84
%87 = OpCompositeConstruct %86 %85 %85 %85 %85
%88 = OpConvertSToF %22 %87
%89 = OpMatrixTimesVector %22 %53 %88
OpStore %45 %89
OpReturn
OpFunctionEnd
%91 = OpFunction %2 None %48
%90 = OpLabel
%89 = OpVariable %81 Function
OpBranch %92
%92 = OpLabel
%94 = OpAccessChain %93 %30 %16
%95 = OpAtomicLoad %4 %94 %12 %96
%98 = OpAccessChain %93 %30 %16
%97 = OpAtomicIAdd %4 %98 %12 %96 %17
OpStore %89 %97
%100 = OpAccessChain %93 %30 %16
%99 = OpAtomicISub %4 %100 %12 %96 %17
OpStore %89 %99
%102 = OpAccessChain %93 %30 %16
%101 = OpAtomicAnd %4 %102 %12 %96 %17
OpStore %89 %101
%104 = OpAccessChain %93 %30 %16
%103 = OpAtomicOr %4 %104 %12 %96 %17
OpStore %89 %103
%106 = OpAccessChain %93 %30 %16
%105 = OpAtomicXor %4 %106 %12 %96 %17
OpStore %89 %105
%108 = OpAccessChain %93 %30 %16
%107 = OpAtomicSMin %4 %108 %12 %96 %17
OpStore %89 %107
%110 = OpAccessChain %93 %30 %16
%109 = OpAtomicSMax %4 %110 %12 %96 %17
OpStore %89 %109
%112 = OpAccessChain %93 %30 %16
%111 = OpAtomicExchange %4 %112 %12 %96 %17
OpStore %89 %111
%113 = OpAccessChain %93 %30 %16
OpAtomicStore %113 %12 %96 %95
%92 = OpFunction %2 None %48
%91 = OpLabel
%90 = OpVariable %82 Function
OpBranch %93
%93 = OpLabel
%95 = OpAccessChain %94 %30 %16
%96 = OpAtomicLoad %4 %95 %12 %97
%99 = OpAccessChain %94 %30 %16
%98 = OpAtomicIAdd %4 %99 %12 %97 %17
OpStore %90 %98
%101 = OpAccessChain %94 %30 %16
%100 = OpAtomicISub %4 %101 %12 %97 %17
OpStore %90 %100
%103 = OpAccessChain %94 %30 %16
%102 = OpAtomicAnd %4 %103 %12 %97 %17
OpStore %90 %102
%105 = OpAccessChain %94 %30 %16
%104 = OpAtomicOr %4 %105 %12 %97 %17
OpStore %90 %104
%107 = OpAccessChain %94 %30 %16
%106 = OpAtomicXor %4 %107 %12 %97 %17
OpStore %90 %106
%109 = OpAccessChain %94 %30 %16
%108 = OpAtomicSMin %4 %109 %12 %97 %17
OpStore %90 %108
%111 = OpAccessChain %94 %30 %16
%110 = OpAtomicSMax %4 %111 %12 %97 %17
OpStore %90 %110
%113 = OpAccessChain %94 %30 %16
%112 = OpAtomicExchange %4 %113 %12 %97 %17
OpStore %90 %112
%114 = OpAccessChain %94 %30 %16
OpAtomicStore %114 %12 %97 %96
OpReturn
OpFunctionEnd
3 changes: 2 additions & 1 deletion tests/out/wgsl/access.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ struct Bar {
matrix: mat4x4<f32>;
atom: atomic<i32>;
arr: [[stride(8)]] array<vec2<u32>,2>;
data: [[stride(4)]] array<i32>;
data: [[stride(8)]] array<i32>;
};

[[group(0), binding(0)]]
Expand All @@ -30,6 +30,7 @@ fn foo([[builtin(vertex_index)]] vi: u32) -> [[builtin(position)]] vec4<f32> {
bar.matrix[1][2] = 1.0;
bar.matrix = mat4x4<f32>(vec4<f32>(0.0), vec4<f32>(1.0), vec4<f32>(2.0), vec4<f32>(3.0));
bar.arr = array<vec2<u32>,2>(vec2<u32>(0u), vec2<u32>(1u));
bar.data[1] = 1;
c = array<i32,5>(a, i32(b), 3, 4, 5);
c[(vi + 1u)] = 42;
let value: i32 = c[vi];
Expand Down