Skip to content

Commit

Permalink
[wgsl-in] use AccessIndex for accesses using constant index
Browse files Browse the repository at this point in the history
  • Loading branch information
Frizi authored and kvark committed Jun 17, 2021
1 parent b29e853 commit 3370147
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,7 @@ impl<W: Write> Writer<W> {
write!(self.out, "[{}]", index)?;
}
crate::TypeInner::Array { .. } => {
write!(self.out, "[{}]", index)?;
write!(self.out, ".{}[{}]", WRAPPED_ARRAY_FIELD, index)?;
}
_ => {
// unexpected indexing, should fail validation
Expand Down
12 changes: 10 additions & 2 deletions src/front/wgsl/lexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,15 +233,23 @@ impl<'a> Lexer<'a> {
token
}

pub(super) fn expect(&mut self, expected: Token<'a>) -> Result<(), Error<'a>> {
pub(super) fn expect_span(
&mut self,
expected: Token<'a>,
) -> Result<std::ops::Range<usize>, Error<'a>> {
let next = self.next();
if next.0 == expected {
Ok(())
Ok(next.1)
} else {
Err(Error::Unexpected(next, ExpectedToken::Token(expected)))
}
}

pub(super) fn expect(&mut self, expected: Token<'a>) -> Result<(), Error<'a>> {
self.expect_span(expected)?;
Ok(())
}

pub(super) fn expect_generic_paren(&mut self, expected: char) -> Result<(), Error<'a>> {
let next = self.next_generic();
if next.0 == Token::Paren(expected) {
Expand Down
48 changes: 41 additions & 7 deletions src/front/wgsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
proc::{
ensure_block_returns, Alignment, Layouter, ResolveContext, ResolveError, TypeResolution,
},
FastHashMap,
ConstantInner, FastHashMap, ScalarValue,
};

use self::lexer::Lexer;
Expand All @@ -26,6 +26,7 @@ use codespan_reporting::{
};
use std::{
borrow::Cow,
convert::TryFrom,
io::{self, Write},
iter,
num::{NonZeroU32, ParseFloatError, ParseIntError},
Expand Down Expand Up @@ -98,6 +99,8 @@ pub enum Error<'a> {
#[error("")]
BadFloat(Span, ParseFloatError),
#[error("")]
BadU32Constant(Span),
#[error("")]
BadScalarWidth(Span, &'a str),
#[error("")]
BadAccessor(Span),
Expand Down Expand Up @@ -250,6 +253,15 @@ impl<'a> Error<'a> {
labels: vec![(bad_span.clone(), "expected floating-point literal".into())],
notes: vec![err.to_string()],
},
Error::BadU32Constant(ref bad_span) => ParseError {
message: format!(
"expected non-negative integer constant expression, found `{}`",
&source[bad_span.clone()],
),
labels: vec![(bad_span.clone(), "expected non-negative integer".into())],
notes: vec![],
},

Error::BadScalarWidth(ref bad_span, width) => ParseError {
message: format!("invalid width of `{}` for literal", width,),
labels: vec![(bad_span.clone(), "invalid width".into())],
Expand Down Expand Up @@ -1023,7 +1035,7 @@ impl Parser {
ty: char,
width: &'a str,
token: TokenSpan<'a>,
) -> Result<crate::ConstantInner, Error<'a>> {
) -> Result<ConstantInner, Error<'a>> {
let span = token.1;
let value = match ty {
'i' => word
Expand Down Expand Up @@ -1737,12 +1749,34 @@ impl Parser {
}
}
Token::Paren('[') => {
let _ = lexer.next();
let (_, open_brace_span) = lexer.next();
let index = self.parse_general_expression(lexer, ctx.reborrow())?;
lexer.expect(Token::Paren(']'))?;
crate::Expression::Access {
base: handle,
index,
let close_brace_span = lexer.expect_span(Token::Paren(']'))?;

if let crate::Expression::Constant(constant) = ctx.expressions[index] {
let expr_span = open_brace_span.end..close_brace_span.start;

let index = match ctx.constants[constant].inner {
ConstantInner::Scalar {
value: ScalarValue::Uint(int),
..
} => u32::try_from(int).map_err(|_| Error::BadU32Constant(expr_span)),
ConstantInner::Scalar {
value: ScalarValue::Sint(int),
..
} => u32::try_from(int).map_err(|_| Error::BadU32Constant(expr_span)),
_ => Err(Error::BadU32Constant(expr_span)),
}?;

crate::Expression::AccessIndex {
base: handle,
index,
}
} else {
crate::Expression::Access {
base: handle,
index,
}
}
}
_ => {
Expand Down
2 changes: 1 addition & 1 deletion tests/out/access.msl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ vertex fooOutput foo(
type6 c;
float baz = foo1;
foo1 = 1.0;
metal::float4 _e9 = bar.matrix[3u];
metal::float4 _e9 = bar.matrix[3];
float b = _e9.x;
int a = bar.data[(1 + (_buffer_sizes.size0 - 64 - 4) / 4) - 1u];
for(int _i=0; _i<5; ++_i) c.inner[_i] = type6 {a, static_cast<int>(b), 3, 4, 5}.inner[_i];
Expand Down
2 changes: 1 addition & 1 deletion tests/out/access.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fn foo([[builtin(vertex_index)]] vi: u32) -> [[builtin(position)]] vec4<f32> {

let baz: f32 = foo1;
foo1 = 1.0;
let _e9: vec4<f32> = bar.matrix[3u];
let _e9: vec4<f32> = bar.matrix[3];
let b: f32 = _e9.x;
let a: i32 = bar.data[(arrayLength(&bar.data) - 1u)];
c = array<i32,5>(a, i32(b), 3, 4, 5);
Expand Down
7 changes: 4 additions & 3 deletions tests/out/globals.spvasm
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
; SPIR-V
; Version: 1.0
; Generator: rspirv
; Bound: 20
; Bound: 21
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
Expand All @@ -22,11 +22,12 @@ OpDecorate %11 ArrayStride 4
%12 = OpVariable %13 Workgroup
%16 = OpTypeFunction %2
%18 = OpTypePointer Workgroup %10
%19 = OpConstant %6 3
%15 = OpFunction %2 None %16
%14 = OpLabel
OpBranch %17
%17 = OpLabel
%19 = OpAccessChain %18 %12 %7
OpStore %19 %9
%20 = OpAccessChain %18 %12 %19
OpStore %20 %9
OpReturn
OpFunctionEnd

0 comments on commit 3370147

Please sign in to comment.