Skip to content

Commit

Permalink
hlsl: respect array stride in storage buffers (#1507)
Browse files Browse the repository at this point in the history
  • Loading branch information
kvark authored Nov 4, 2021
1 parent d9b1668 commit 28c4532
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 88 deletions.
73 changes: 44 additions & 29 deletions src/back/hlsl/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,39 +367,22 @@ impl<W: fmt::Write> super::Writer<'_, W> {
mut cur_expr: Handle<crate::Expression>,
func_ctx: &FunctionCtx,
) -> Result<Handle<crate::GlobalVariable>, Error> {
enum AccessIndex {
Expression(Handle<crate::Expression>),
Constant(u32),
}
enum Parent<'a> {
Array { stride: u32 },
Struct(&'a [crate::StructMember]),
}
self.temp_access_chain.clear();
loop {
// determine the size of the pointee
let stride = match *func_ctx.info[cur_expr].ty.inner_with(&module.types) {
crate::TypeInner::Pointer { base, class: _ } => {
module.types[base].inner.span(&module.constants)
}
crate::TypeInner::ValuePointer { size, width, .. } => {
size.map_or(1, |s| s as u32) * width as u32
}
_ => 0,
};

let (next_expr, sub) = match func_ctx.expressions[cur_expr] {
loop {
let (next_expr, access_index) = match func_ctx.expressions[cur_expr] {
crate::Expression::GlobalVariable(handle) => return Ok(handle),
crate::Expression::Access { base, index } => (
base,
SubAccess::Index {
value: index,
stride,
},
),
crate::Expression::Access { base, index } => (base, AccessIndex::Expression(index)),
crate::Expression::AccessIndex { base, index } => {
let sub = match *func_ctx.info[base].ty.inner_with(&module.types) {
crate::TypeInner::Pointer { base, .. } => match module.types[base].inner {
crate::TypeInner::Struct { ref members, .. } => {
SubAccess::Offset(members[index as usize].offset)
}
_ => SubAccess::Offset(index * stride),
},
_ => SubAccess::Offset(index * stride),
};
(base, sub)
(base, AccessIndex::Constant(index))
}
ref other => {
return Err(Error::Unimplemented(format!(
Expand All @@ -408,6 +391,38 @@ impl<W: fmt::Write> super::Writer<'_, W> {
)))
}
};

let parent = match *func_ctx.info[next_expr].ty.inner_with(&module.types) {
crate::TypeInner::Pointer { base, .. } => match module.types[base].inner {
crate::TypeInner::Struct { ref members, .. } => Parent::Struct(members),
crate::TypeInner::Array { stride, .. } => Parent::Array { stride },
crate::TypeInner::Vector { width, .. } => Parent::Array {
stride: width as u32,
},
crate::TypeInner::Matrix { rows, width, .. } => Parent::Array {
stride: width as u32 * if rows > crate::VectorSize::Bi { 4 } else { 2 },
},
_ => unreachable!(),
},
crate::TypeInner::ValuePointer { width, .. } => Parent::Array {
stride: width as u32,
},
_ => unreachable!(),
};

let sub = match (parent, access_index) {
(Parent::Array { stride }, AccessIndex::Expression(value)) => {
SubAccess::Index { value, stride }
}
(Parent::Array { stride }, AccessIndex::Constant(index)) => {
SubAccess::Offset(stride * index)
}
(Parent::Struct(members), AccessIndex::Constant(index)) => {
SubAccess::Offset(members[index as usize].offset)
}
(Parent::Struct(_), AccessIndex::Expression(_)) => unreachable!(),
};

self.temp_access_chain.push(sub);
cur_expr = next_expr;
}
Expand Down
3 changes: 2 additions & 1 deletion tests/in/access.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ struct Bar {
matrix: mat4x4<f32>;
atom: atomic<i32>;
arr: [[stride(8)]] array<vec2<u32>, 2>;
data: [[stride(4)]] array<i32>;
data: [[stride(8)]] array<i32>;
};

[[group(0), binding(0)]]
Expand Down Expand Up @@ -37,6 +37,7 @@ fn foo([[builtin(vertex_index)]] vi: u32) -> [[builtin(position)]] vec4<f32> {
bar.matrix[1].z = 1.0;
bar.matrix = mat4x4<f32>(vec4<f32>(0.0), vec4<f32>(1.0), vec4<f32>(2.0), vec4<f32>(3.0));
bar.arr = array<vec2<u32>, 2>(vec2<u32>(0u), vec2<u32>(1u));
bar.data[1] = 1;

// test array indexing
var c = array<i32, 5>(a, i32(b), 3, 4, 5);
Expand Down
1 change: 1 addition & 0 deletions tests/out/glsl/access.foo.Vertex.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ void main() {
_group_0_binding_0.matrix[1][2] = 1.0;
_group_0_binding_0.matrix = mat4x4(vec4(0.0), vec4(1.0), vec4(2.0), vec4(3.0));
_group_0_binding_0.arr = uvec2[2](uvec2(0u), uvec2(1u));
_group_0_binding_0.data[1] = 1;
c = int[5](a, int(b), 3, 4, 5);
c[(vi + 1u)] = 42;
int value = c[vi];
Expand Down
3 changes: 2 additions & 1 deletion tests/out/hlsl/access.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ float4 foo(uint vi : SV_VertexID) : SV_Position
float4x4 matrix1 = float4x4(asfloat(bar.Load4(0+0)), asfloat(bar.Load4(0+16)), asfloat(bar.Load4(0+32)), asfloat(bar.Load4(0+48)));
uint2 arr[2] = {asuint(bar.Load2(72+0)), asuint(bar.Load2(72+8))};
float b = asfloat(bar.Load(0+48+0));
int a = asint(bar.Load((((NagaBufferLengthRW(bar) - 88) / 4) - 2u)*4+88));
int a = asint(bar.Load((((NagaBufferLengthRW(bar) - 88) / 8) - 2u)*8+88));
const float _e25 = read_from_private(foo1);
bar.Store(8+16+0, asuint(1.0));
{
Expand All @@ -39,6 +39,7 @@ float4 foo(uint vi : SV_VertexID) : SV_Position
bar.Store2(72+0, asuint(_value2[0]));
bar.Store2(72+8, asuint(_value2[1]));
}
bar.Store(8+88, asuint(1));
{
int _result[5]={ a, int(b), 3, 4, 5 };
for(int _i=0; _i<5; ++_i) c[_i] = _result[_i];
Expand Down
3 changes: 2 additions & 1 deletion tests/out/msl/access.msl
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ vertex fooOutput foo(
metal::float4x4 matrix = bar.matrix;
type3 arr = bar.arr;
float b = bar.matrix[3].x;
int a = bar.data[(1 + (_buffer_sizes.size0 - 88 - 4) / 4) - 2u];
int a = bar.data[(1 + (_buffer_sizes.size0 - 88 - 4) / 8) - 2u];
float _e25 = read_from_private(foo1);
bar.matrix[1].z = 1.0;
bar.matrix = metal::float4x4(metal::float4(0.0), metal::float4(1.0), metal::float4(2.0), metal::float4(3.0));
for(int _i=0; _i<2; ++_i) bar.arr.inner[_i] = type3 {metal::uint2(0u), metal::uint2(1u)}.inner[_i];
bar.data[1] = 1;
for(int _i=0; _i<5; ++_i) c.inner[_i] = type11 {a, static_cast<int>(b), 3, 4, 5}.inner[_i];
c.inner[vi + 1u] = 42;
int value = c.inner[vi];
Expand Down
112 changes: 57 additions & 55 deletions tests/out/spv/access.spvasm
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 114
; Bound: 115
OpCapability Shader
OpExtension "SPV_KHR_storage_buffer_storage_class"
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint Vertex %47 "foo" %42 %45
OpEntryPoint GLCompute %91 "atomics"
OpExecutionMode %91 LocalSize 1 1 1
OpEntryPoint GLCompute %92 "atomics"
OpExecutionMode %92 LocalSize 1 1 1
OpSource GLSL 450
OpMemberName %26 0 "matrix"
OpMemberName %26 1 "atom"
Expand All @@ -22,10 +22,10 @@ OpName %38 "foo"
OpName %39 "c"
OpName %42 "vi"
OpName %47 "foo"
OpName %89 "tmp"
OpName %91 "atomics"
OpName %90 "tmp"
OpName %92 "atomics"
OpDecorate %24 ArrayStride 8
OpDecorate %25 ArrayStride 4
OpDecorate %25 ArrayStride 8
OpDecorate %26 Block
OpMemberDecorate %26 0 Offset 0
OpMemberDecorate %26 0 ColMajor
Expand Down Expand Up @@ -80,10 +80,10 @@ OpDecorate %45 BuiltIn Position
%57 = OpTypePointer StorageBuffer %22
%58 = OpTypePointer StorageBuffer %6
%61 = OpTypePointer StorageBuffer %25
%81 = OpTypePointer Function %4
%85 = OpTypeVector %4 4
%93 = OpTypePointer StorageBuffer %4
%96 = OpConstant %9 64
%82 = OpTypePointer Function %4
%86 = OpTypeVector %4 4
%94 = OpTypePointer StorageBuffer %4
%97 = OpConstant %9 64
%34 = OpFunction %6 None %35
%33 = OpFunctionParameter %27
%32 = OpLabel
Expand Down Expand Up @@ -126,52 +126,54 @@ OpStore %73 %72
%76 = OpCompositeConstruct %24 %74 %75
%77 = OpAccessChain %54 %30 %10
OpStore %77 %76
%78 = OpConvertFToS %4 %60
%79 = OpCompositeConstruct %29 %65 %78 %18 %19 %17
OpStore %39 %79
%80 = OpIAdd %9 %44 %16
%82 = OpAccessChain %81 %39 %80
OpStore %82 %20
%83 = OpAccessChain %81 %39 %44
%84 = OpLoad %4 %83
%86 = OpCompositeConstruct %85 %84 %84 %84 %84
%87 = OpConvertSToF %22 %86
%88 = OpMatrixTimesVector %22 %53 %87
OpStore %45 %88
%78 = OpAccessChain %28 %30 %8 %16
OpStore %78 %12
%79 = OpConvertFToS %4 %60
%80 = OpCompositeConstruct %29 %65 %79 %18 %19 %17
OpStore %39 %80
%81 = OpIAdd %9 %44 %16
%83 = OpAccessChain %82 %39 %81
OpStore %83 %20
%84 = OpAccessChain %82 %39 %44
%85 = OpLoad %4 %84
%87 = OpCompositeConstruct %86 %85 %85 %85 %85
%88 = OpConvertSToF %22 %87
%89 = OpMatrixTimesVector %22 %53 %88
OpStore %45 %89
OpReturn
OpFunctionEnd
%91 = OpFunction %2 None %48
%90 = OpLabel
%89 = OpVariable %81 Function
OpBranch %92
%92 = OpLabel
%94 = OpAccessChain %93 %30 %16
%95 = OpAtomicLoad %4 %94 %12 %96
%98 = OpAccessChain %93 %30 %16
%97 = OpAtomicIAdd %4 %98 %12 %96 %17
OpStore %89 %97
%100 = OpAccessChain %93 %30 %16
%99 = OpAtomicISub %4 %100 %12 %96 %17
OpStore %89 %99
%102 = OpAccessChain %93 %30 %16
%101 = OpAtomicAnd %4 %102 %12 %96 %17
OpStore %89 %101
%104 = OpAccessChain %93 %30 %16
%103 = OpAtomicOr %4 %104 %12 %96 %17
OpStore %89 %103
%106 = OpAccessChain %93 %30 %16
%105 = OpAtomicXor %4 %106 %12 %96 %17
OpStore %89 %105
%108 = OpAccessChain %93 %30 %16
%107 = OpAtomicSMin %4 %108 %12 %96 %17
OpStore %89 %107
%110 = OpAccessChain %93 %30 %16
%109 = OpAtomicSMax %4 %110 %12 %96 %17
OpStore %89 %109
%112 = OpAccessChain %93 %30 %16
%111 = OpAtomicExchange %4 %112 %12 %96 %17
OpStore %89 %111
%113 = OpAccessChain %93 %30 %16
OpAtomicStore %113 %12 %96 %95
%92 = OpFunction %2 None %48
%91 = OpLabel
%90 = OpVariable %82 Function
OpBranch %93
%93 = OpLabel
%95 = OpAccessChain %94 %30 %16
%96 = OpAtomicLoad %4 %95 %12 %97
%99 = OpAccessChain %94 %30 %16
%98 = OpAtomicIAdd %4 %99 %12 %97 %17
OpStore %90 %98
%101 = OpAccessChain %94 %30 %16
%100 = OpAtomicISub %4 %101 %12 %97 %17
OpStore %90 %100
%103 = OpAccessChain %94 %30 %16
%102 = OpAtomicAnd %4 %103 %12 %97 %17
OpStore %90 %102
%105 = OpAccessChain %94 %30 %16
%104 = OpAtomicOr %4 %105 %12 %97 %17
OpStore %90 %104
%107 = OpAccessChain %94 %30 %16
%106 = OpAtomicXor %4 %107 %12 %97 %17
OpStore %90 %106
%109 = OpAccessChain %94 %30 %16
%108 = OpAtomicSMin %4 %109 %12 %97 %17
OpStore %90 %108
%111 = OpAccessChain %94 %30 %16
%110 = OpAtomicSMax %4 %111 %12 %97 %17
OpStore %90 %110
%113 = OpAccessChain %94 %30 %16
%112 = OpAtomicExchange %4 %113 %12 %97 %17
OpStore %90 %112
%114 = OpAccessChain %94 %30 %16
OpAtomicStore %114 %12 %97 %96
OpReturn
OpFunctionEnd
3 changes: 2 additions & 1 deletion tests/out/wgsl/access.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ struct Bar {
matrix: mat4x4<f32>;
atom: atomic<i32>;
arr: [[stride(8)]] array<vec2<u32>,2>;
data: [[stride(4)]] array<i32>;
data: [[stride(8)]] array<i32>;
};

[[group(0), binding(0)]]
Expand All @@ -30,6 +30,7 @@ fn foo([[builtin(vertex_index)]] vi: u32) -> [[builtin(position)]] vec4<f32> {
bar.matrix[1][2] = 1.0;
bar.matrix = mat4x4<f32>(vec4<f32>(0.0), vec4<f32>(1.0), vec4<f32>(2.0), vec4<f32>(3.0));
bar.arr = array<vec2<u32>,2>(vec2<u32>(0u), vec2<u32>(1u));
bar.data[1] = 1;
c = array<i32,5>(a, i32(b), 3, 4, 5);
c[(vi + 1u)] = 42;
let value: i32 = c[vi];
Expand Down

0 comments on commit 28c4532

Please sign in to comment.