diff --git a/src/front/wgsl/error.rs b/src/front/wgsl/error.rs index 701eca992b..a3355f861a 100644 --- a/src/front/wgsl/error.rs +++ b/src/front/wgsl/error.rs @@ -168,6 +168,7 @@ pub enum Error<'a> { InvalidIdentifierUnderscore(Span), ReservedIdentifierPrefix(Span), UnknownAddressSpace(Span), + RepeatedAttribute(Span), UnknownAttribute(Span), UnknownBuiltin(Span), UnknownAccess(Span), @@ -430,6 +431,11 @@ impl<'a> Error<'a> { labels: vec![(bad_span, "unknown address space".into())], notes: vec![], }, + Error::RepeatedAttribute(bad_span) => ParseError { + message: format!("repeated attribute: '{}'", &source[bad_span]), + labels: vec![(bad_span, "repated attribute".into())], + notes: vec![], + }, Error::UnknownAttribute(bad_span) => ParseError { message: format!("unknown attribute: '{}'", &source[bad_span]), labels: vec![(bad_span, "unknown attribute".into())], diff --git a/src/front/wgsl/parse/mod.rs b/src/front/wgsl/parse/mod.rs index c13ee52775..55377cc884 100644 --- a/src/front/wgsl/parse/mod.rs +++ b/src/front/wgsl/parse/mod.rs @@ -120,6 +120,13 @@ enum Rule { GeneralExpr, } +const fn fail_if_repeated_attribute<'a>(repeated: bool, name_span: Span) -> Result<(), Error<'a>> { + if repeated { + return Err(Error::RepeatedAttribute(name_span)); + } + Ok(()) +} + #[derive(Default)] struct BindingParser { location: Option, @@ -139,18 +146,21 @@ impl BindingParser { match name { "location" => { lexer.expect(Token::Paren('('))?; + fail_if_repeated_attribute(self.location.is_some(), name_span)?; self.location = Some(Parser::non_negative_i32_literal(lexer)?); lexer.expect(Token::Paren(')'))?; } "builtin" => { lexer.expect(Token::Paren('('))?; let (raw, span) = lexer.next_ident_with_span()?; + fail_if_repeated_attribute(self.built_in.is_some(), name_span)?; self.built_in = Some(conv::map_built_in(raw, span)?); lexer.expect(Token::Paren(')'))?; } "interpolate" => { lexer.expect(Token::Paren('('))?; let (raw, span) = lexer.next_ident_with_span()?; + fail_if_repeated_attribute(self.interpolation.is_some(), name_span)?; self.interpolation = Some(conv::map_interpolation(raw, span)?); if lexer.skip(Token::Separator(',')) { let (raw, span) = lexer.next_ident_with_span()?; @@ -158,7 +168,10 @@ impl BindingParser { } lexer.expect(Token::Paren(')'))?; } - "invariant" => self.invariant = true, + "invariant" => { + fail_if_repeated_attribute(self.invariant, name_span)?; + self.invariant = true; + } _ => return Err(Error::UnknownAttribute(name_span)), } Ok(()) @@ -995,16 +1008,18 @@ impl Parser { let mut bind_parser = BindingParser::default(); while lexer.skip(Token::Attribute) { match lexer.next_ident_with_span()? { - ("size", _) => { + ("size", name_span) => { lexer.expect(Token::Paren('('))?; let (value, span) = lexer.capture_span(Self::non_negative_i32_literal)?; lexer.expect(Token::Paren(')'))?; + fail_if_repeated_attribute(size.is_some(), name_span)?; size = Some((value, span)); } - ("align", _) => { + ("align", name_span) => { lexer.expect(Token::Paren('('))?; let (value, span) = lexer.capture_span(Self::non_negative_i32_literal)?; lexer.expect(Token::Paren(')'))?; + fail_if_repeated_attribute(align.is_some(), name_span)?; align = Some((value, span)); } (word, word_span) => bind_parser.parse(lexer, word, word_span)?, @@ -2139,23 +2154,28 @@ impl Parser { self.push_rule_span(Rule::Attribute, lexer); while lexer.skip(Token::Attribute) { match lexer.next_ident_with_span()? { - ("binding", _) => { + ("binding", name_span) => { lexer.expect(Token::Paren('('))?; + fail_if_repeated_attribute(bind_index.is_some(), name_span)?; bind_index = Some(Self::non_negative_i32_literal(lexer)?); lexer.expect(Token::Paren(')'))?; } - ("group", _) => { + ("group", name_span) => { lexer.expect(Token::Paren('('))?; + fail_if_repeated_attribute(bind_group.is_some(), name_span)?; bind_group = Some(Self::non_negative_i32_literal(lexer)?); lexer.expect(Token::Paren(')'))?; } - ("vertex", _) => { + ("vertex", name_span) => { + fail_if_repeated_attribute(stage.is_some(), name_span)?; stage = Some(crate::ShaderStage::Vertex); } - ("fragment", _) => { + ("fragment", name_span) => { + fail_if_repeated_attribute(stage.is_some(), name_span)?; stage = Some(crate::ShaderStage::Fragment); } - ("compute", _) => { + ("compute", name_span) => { + fail_if_repeated_attribute(stage.is_some(), name_span)?; stage = Some(crate::ShaderStage::Compute); } ("workgroup_size", _) => { @@ -2175,7 +2195,7 @@ impl Parser { } } } - ("early_depth_test", _) => { + ("early_depth_test", name_span) => { let conservative = if lexer.skip(Token::Paren('(')) { let (ident, ident_span) = lexer.next_ident_with_span()?; let value = conv::map_conservative_depth(ident, ident_span)?; @@ -2184,6 +2204,7 @@ impl Parser { } else { None }; + fail_if_repeated_attribute(early_depth_test.is_some(), name_span)?; early_depth_test = Some(crate::EarlyDepthTest { conservative }); } (_, word_span) => return Err(Error::UnknownAttribute(word_span)), diff --git a/src/front/wgsl/tests.rs b/src/front/wgsl/tests.rs index 3aec0acf2d..af65127572 100644 --- a/src/front/wgsl/tests.rs +++ b/src/front/wgsl/tests.rs @@ -509,3 +509,43 @@ fn parse_texture_load_store_expecting_four_args() { ); } } + +#[test] +fn parse_repeated_attributes() { + use crate::{ + front::wgsl::{error::Error, Frontend}, + Span, + }; + + let template_vs = "@vertex fn vs() -> __REPLACE__ vec4 { return vec4(0.0); }"; + let template_struct = "struct A { __REPLACE__ data: vec3 }"; + let template_resource = "__REPLACE__ var tex_los_res: texture_2d_array;"; + let template_stage = "__REPLACE__ fn vs() -> vec4 { return vec4(0.0); }"; + for (attribute, template) in [ + ("align(16)", template_struct), + ("binding(0)", template_resource), + ("builtin(position)", template_vs), + ("compute", template_stage), + ("fragment", template_stage), + ("group(0)", template_resource), + ("interpolate(flat)", template_vs), + ("invariant", template_vs), + ("location(0)", template_vs), + ("size(16)", template_struct), + ("vertex", template_stage), + ("early_depth_test(less_equal)", template_resource), + ] { + let shader = template.replace("__REPLACE__", &format!("@{attribute} @{attribute}")); + let name_length = attribute.rfind('(').unwrap_or(attribute.len()) as u32; + let span_start = shader.rfind(attribute).unwrap() as u32; + let span_end = span_start + name_length; + let expected_span = Span::new(span_start, span_end); + + let result = Frontend::new().inner(&shader); + println!("WHAT? {} RESULT: {:?}", attribute, result); + assert!(matches!( + result.unwrap_err(), + Error::RepeatedAttribute(span) if span == expected_span + )); + } +}