Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wgsl-in] Fail on repeated attributes #2428

Merged
merged 2 commits into from
Aug 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/front/wgsl/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ pub enum Error<'a> {
InvalidIdentifierUnderscore(Span),
ReservedIdentifierPrefix(Span),
UnknownAddressSpace(Span),
RepeatedAttribute(Span),
UnknownAttribute(Span),
UnknownBuiltin(Span),
UnknownAccess(Span),
Expand Down Expand Up @@ -430,6 +431,11 @@ impl<'a> Error<'a> {
labels: vec![(bad_span, "unknown address space".into())],
notes: vec![],
},
Error::RepeatedAttribute(bad_span) => ParseError {
message: format!("repeated attribute: '{}'", &source[bad_span]),
labels: vec![(bad_span, "repated attribute".into())],
notes: vec![],
},
Error::UnknownAttribute(bad_span) => ParseError {
message: format!("unknown attribute: '{}'", &source[bad_span]),
labels: vec![(bad_span, "unknown attribute".into())],
Expand Down
109 changes: 68 additions & 41 deletions src/front/wgsl/parse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,33 @@ enum Rule {
GeneralExpr,
}

struct ParsedAttribute<T> {
value: Option<T>,
}

impl<T> Default for ParsedAttribute<T> {
fn default() -> Self {
Self { value: None }
}
}

impl<T> ParsedAttribute<T> {
fn set(&mut self, value: T, name_span: Span) -> Result<(), Error<'static>> {
if self.value.is_some() {
return Err(Error::RepeatedAttribute(name_span));
}
self.value = Some(value);
Ok(())
}
}

#[derive(Default)]
struct BindingParser {
location: Option<u32>,
built_in: Option<crate::BuiltIn>,
interpolation: Option<crate::Interpolation>,
sampling: Option<crate::Sampling>,
invariant: bool,
location: ParsedAttribute<u32>,
built_in: ParsedAttribute<crate::BuiltIn>,
interpolation: ParsedAttribute<crate::Interpolation>,
sampling: ParsedAttribute<crate::Sampling>,
invariant: ParsedAttribute<bool>,
}

impl BindingParser {
Expand All @@ -139,38 +159,44 @@ impl BindingParser {
match name {
"location" => {
lexer.expect(Token::Paren('('))?;
self.location = Some(Parser::non_negative_i32_literal(lexer)?);
self.location
.set(Parser::non_negative_i32_literal(lexer)?, name_span)?;
lexer.expect(Token::Paren(')'))?;
}
"builtin" => {
lexer.expect(Token::Paren('('))?;
let (raw, span) = lexer.next_ident_with_span()?;
self.built_in = Some(conv::map_built_in(raw, span)?);
self.built_in
.set(conv::map_built_in(raw, span)?, name_span)?;
lexer.expect(Token::Paren(')'))?;
}
"interpolate" => {
lexer.expect(Token::Paren('('))?;
let (raw, span) = lexer.next_ident_with_span()?;
self.interpolation = Some(conv::map_interpolation(raw, span)?);
self.interpolation
.set(conv::map_interpolation(raw, span)?, name_span)?;
if lexer.skip(Token::Separator(',')) {
let (raw, span) = lexer.next_ident_with_span()?;
self.sampling = Some(conv::map_sampling(raw, span)?);
self.sampling
.set(conv::map_sampling(raw, span)?, name_span)?;
}
lexer.expect(Token::Paren(')'))?;
}
"invariant" => self.invariant = true,
"invariant" => {
self.invariant.set(true, name_span)?;
}
_ => return Err(Error::UnknownAttribute(name_span)),
}
Ok(())
}

const fn finish<'a>(self, span: Span) -> Result<Option<crate::Binding>, Error<'a>> {
fn finish<'a>(self, span: Span) -> Result<Option<crate::Binding>, Error<'a>> {
match (
self.location,
self.built_in,
self.interpolation,
self.sampling,
self.invariant,
self.location.value,
self.built_in.value,
self.interpolation.value,
self.sampling.value,
self.invariant.value.unwrap_or_default(),
) {
(None, None, None, None, false) => Ok(None),
(Some(location), None, interpolation, sampling, false) => {
Expand Down Expand Up @@ -990,22 +1016,22 @@ impl Parser {
ExpectedToken::Token(Token::Separator(',')),
));
}
let (mut size, mut align) = (None, None);
let (mut size, mut align) = (ParsedAttribute::default(), ParsedAttribute::default());
self.push_rule_span(Rule::Attribute, lexer);
let mut bind_parser = BindingParser::default();
while lexer.skip(Token::Attribute) {
match lexer.next_ident_with_span()? {
("size", _) => {
("size", name_span) => {
lexer.expect(Token::Paren('('))?;
let (value, span) = lexer.capture_span(Self::non_negative_i32_literal)?;
lexer.expect(Token::Paren(')'))?;
size = Some((value, span));
size.set((value, span), name_span)?;
}
("align", _) => {
("align", name_span) => {
lexer.expect(Token::Paren('('))?;
let (value, span) = lexer.capture_span(Self::non_negative_i32_literal)?;
lexer.expect(Token::Paren(')'))?;
align = Some((value, span));
align.set((value, span), name_span)?;
}
(word, word_span) => bind_parser.parse(lexer, word, word_span)?,
}
Expand All @@ -1023,8 +1049,8 @@ impl Parser {
name,
ty,
binding,
size,
align,
size: size.value,
align: align.value,
});
}

Expand Down Expand Up @@ -2131,32 +2157,33 @@ impl Parser {
) -> Result<(), Error<'a>> {
// read attributes
let mut binding = None;
let mut stage = None;
let mut stage = ParsedAttribute::default();
let mut workgroup_size = [0u32; 3];
let mut early_depth_test = None;
let (mut bind_index, mut bind_group) = (None, None);
let mut early_depth_test = ParsedAttribute::default();
let (mut bind_index, mut bind_group) =
(ParsedAttribute::default(), ParsedAttribute::default());

self.push_rule_span(Rule::Attribute, lexer);
while lexer.skip(Token::Attribute) {
match lexer.next_ident_with_span()? {
("binding", _) => {
("binding", name_span) => {
lexer.expect(Token::Paren('('))?;
bind_index = Some(Self::non_negative_i32_literal(lexer)?);
bind_index.set(Self::non_negative_i32_literal(lexer)?, name_span)?;
lexer.expect(Token::Paren(')'))?;
}
("group", _) => {
("group", name_span) => {
lexer.expect(Token::Paren('('))?;
bind_group = Some(Self::non_negative_i32_literal(lexer)?);
bind_group.set(Self::non_negative_i32_literal(lexer)?, name_span)?;
lexer.expect(Token::Paren(')'))?;
}
("vertex", _) => {
stage = Some(crate::ShaderStage::Vertex);
("vertex", name_span) => {
stage.set(crate::ShaderStage::Vertex, name_span)?;
}
("fragment", _) => {
stage = Some(crate::ShaderStage::Fragment);
("fragment", name_span) => {
stage.set(crate::ShaderStage::Fragment, name_span)?;
}
("compute", _) => {
stage = Some(crate::ShaderStage::Compute);
("compute", name_span) => {
stage.set(crate::ShaderStage::Compute, name_span)?;
}
("workgroup_size", _) => {
lexer.expect(Token::Paren('('))?;
Expand All @@ -2175,7 +2202,7 @@ impl Parser {
}
}
}
("early_depth_test", _) => {
("early_depth_test", name_span) => {
let conservative = if lexer.skip(Token::Paren('(')) {
let (ident, ident_span) = lexer.next_ident_with_span()?;
let value = conv::map_conservative_depth(ident, ident_span)?;
Expand All @@ -2184,14 +2211,14 @@ impl Parser {
} else {
None
};
early_depth_test = Some(crate::EarlyDepthTest { conservative });
early_depth_test.set(crate::EarlyDepthTest { conservative }, name_span)?;
}
(_, word_span) => return Err(Error::UnknownAttribute(word_span)),
}
}

let attrib_span = self.pop_rule_span(lexer);
match (bind_group, bind_index) {
match (bind_group.value, bind_index.value) {
(Some(group), Some(index)) => {
binding = Some(crate::ResourceBinding {
group,
Expand Down Expand Up @@ -2254,9 +2281,9 @@ impl Parser {
(Token::Word("fn"), _) => {
let function = self.function_decl(lexer, out, &mut dependencies)?;
Some(ast::GlobalDeclKind::Fn(ast::Function {
entry_point: stage.map(|stage| ast::EntryPoint {
entry_point: stage.value.map(|stage| ast::EntryPoint {
stage,
early_depth_test,
early_depth_test: early_depth_test.value,
workgroup_size,
}),
..function
Expand Down
39 changes: 39 additions & 0 deletions src/front/wgsl/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -509,3 +509,42 @@ fn parse_texture_load_store_expecting_four_args() {
);
}
}

#[test]
fn parse_repeated_attributes() {
use crate::{
front::wgsl::{error::Error, Frontend},
Span,
};

let template_vs = "@vertex fn vs() -> __REPLACE__ vec4<f32> { return vec4<f32>(0.0); }";
let template_struct = "struct A { __REPLACE__ data: vec3<f32> }";
let template_resource = "__REPLACE__ var tex_los_res: texture_2d_array<i32>;";
let template_stage = "__REPLACE__ fn vs() -> vec4<f32> { return vec4<f32>(0.0); }";
for (attribute, template) in [
("align(16)", template_struct),
("binding(0)", template_resource),
("builtin(position)", template_vs),
("compute", template_stage),
("fragment", template_stage),
("group(0)", template_resource),
("interpolate(flat)", template_vs),
("invariant", template_vs),
("location(0)", template_vs),
("size(16)", template_struct),
("vertex", template_stage),
("early_depth_test(less_equal)", template_resource),
] {
let shader = template.replace("__REPLACE__", &format!("@{attribute} @{attribute}"));
let name_length = attribute.rfind('(').unwrap_or(attribute.len()) as u32;
let span_start = shader.rfind(attribute).unwrap() as u32;
let span_end = span_start + name_length;
let expected_span = Span::new(span_start, span_end);

let result = Frontend::new().inner(&shader);
assert!(matches!(
result.unwrap_err(),
Error::RepeatedAttribute(span) if span == expected_span
));
}
}