diff --git a/src/front/wgsl/error.rs b/src/front/wgsl/error.rs index 701eca992b..a3355f861a 100644 --- a/src/front/wgsl/error.rs +++ b/src/front/wgsl/error.rs @@ -168,6 +168,7 @@ pub enum Error<'a> { InvalidIdentifierUnderscore(Span), ReservedIdentifierPrefix(Span), UnknownAddressSpace(Span), + RepeatedAttribute(Span), UnknownAttribute(Span), UnknownBuiltin(Span), UnknownAccess(Span), @@ -430,6 +431,11 @@ impl<'a> Error<'a> { labels: vec![(bad_span, "unknown address space".into())], notes: vec![], }, + Error::RepeatedAttribute(bad_span) => ParseError { + message: format!("repeated attribute: '{}'", &source[bad_span]), + labels: vec![(bad_span, "repated attribute".into())], + notes: vec![], + }, Error::UnknownAttribute(bad_span) => ParseError { message: format!("unknown attribute: '{}'", &source[bad_span]), labels: vec![(bad_span, "unknown attribute".into())], diff --git a/src/front/wgsl/parse/mod.rs b/src/front/wgsl/parse/mod.rs index c13ee52775..b6798027ea 100644 --- a/src/front/wgsl/parse/mod.rs +++ b/src/front/wgsl/parse/mod.rs @@ -120,13 +120,33 @@ enum Rule { GeneralExpr, } +struct ParsedAttribute { + value: Option, +} + +impl Default for ParsedAttribute { + fn default() -> Self { + Self { value: None } + } +} + +impl ParsedAttribute { + fn set(&mut self, value: T, name_span: Span) -> Result<(), Error<'static>> { + if self.value.is_some() { + return Err(Error::RepeatedAttribute(name_span)); + } + self.value = Some(value); + Ok(()) + } +} + #[derive(Default)] struct BindingParser { - location: Option, - built_in: Option, - interpolation: Option, - sampling: Option, - invariant: bool, + location: ParsedAttribute, + built_in: ParsedAttribute, + interpolation: ParsedAttribute, + sampling: ParsedAttribute, + invariant: ParsedAttribute, } impl BindingParser { @@ -139,38 +159,44 @@ impl BindingParser { match name { "location" => { lexer.expect(Token::Paren('('))?; - self.location = Some(Parser::non_negative_i32_literal(lexer)?); + self.location + .set(Parser::non_negative_i32_literal(lexer)?, name_span)?; lexer.expect(Token::Paren(')'))?; } "builtin" => { lexer.expect(Token::Paren('('))?; let (raw, span) = lexer.next_ident_with_span()?; - self.built_in = Some(conv::map_built_in(raw, span)?); + self.built_in + .set(conv::map_built_in(raw, span)?, name_span)?; lexer.expect(Token::Paren(')'))?; } "interpolate" => { lexer.expect(Token::Paren('('))?; let (raw, span) = lexer.next_ident_with_span()?; - self.interpolation = Some(conv::map_interpolation(raw, span)?); + self.interpolation + .set(conv::map_interpolation(raw, span)?, name_span)?; if lexer.skip(Token::Separator(',')) { let (raw, span) = lexer.next_ident_with_span()?; - self.sampling = Some(conv::map_sampling(raw, span)?); + self.sampling + .set(conv::map_sampling(raw, span)?, name_span)?; } lexer.expect(Token::Paren(')'))?; } - "invariant" => self.invariant = true, + "invariant" => { + self.invariant.set(true, name_span)?; + } _ => return Err(Error::UnknownAttribute(name_span)), } Ok(()) } - const fn finish<'a>(self, span: Span) -> Result, Error<'a>> { + fn finish<'a>(self, span: Span) -> Result, Error<'a>> { match ( - self.location, - self.built_in, - self.interpolation, - self.sampling, - self.invariant, + self.location.value, + self.built_in.value, + self.interpolation.value, + self.sampling.value, + self.invariant.value.unwrap_or_default(), ) { (None, None, None, None, false) => Ok(None), (Some(location), None, interpolation, sampling, false) => { @@ -990,22 +1016,22 @@ impl Parser { ExpectedToken::Token(Token::Separator(',')), )); } - let (mut size, mut align) = (None, None); + let (mut size, mut align) = (ParsedAttribute::default(), ParsedAttribute::default()); self.push_rule_span(Rule::Attribute, lexer); let mut bind_parser = BindingParser::default(); while lexer.skip(Token::Attribute) { match lexer.next_ident_with_span()? { - ("size", _) => { + ("size", name_span) => { lexer.expect(Token::Paren('('))?; let (value, span) = lexer.capture_span(Self::non_negative_i32_literal)?; lexer.expect(Token::Paren(')'))?; - size = Some((value, span)); + size.set((value, span), name_span)?; } - ("align", _) => { + ("align", name_span) => { lexer.expect(Token::Paren('('))?; let (value, span) = lexer.capture_span(Self::non_negative_i32_literal)?; lexer.expect(Token::Paren(')'))?; - align = Some((value, span)); + align.set((value, span), name_span)?; } (word, word_span) => bind_parser.parse(lexer, word, word_span)?, } @@ -1023,8 +1049,8 @@ impl Parser { name, ty, binding, - size, - align, + size: size.value, + align: align.value, }); } @@ -2131,32 +2157,33 @@ impl Parser { ) -> Result<(), Error<'a>> { // read attributes let mut binding = None; - let mut stage = None; + let mut stage = ParsedAttribute::default(); let mut workgroup_size = [0u32; 3]; - let mut early_depth_test = None; - let (mut bind_index, mut bind_group) = (None, None); + let mut early_depth_test = ParsedAttribute::default(); + let (mut bind_index, mut bind_group) = + (ParsedAttribute::default(), ParsedAttribute::default()); self.push_rule_span(Rule::Attribute, lexer); while lexer.skip(Token::Attribute) { match lexer.next_ident_with_span()? { - ("binding", _) => { + ("binding", name_span) => { lexer.expect(Token::Paren('('))?; - bind_index = Some(Self::non_negative_i32_literal(lexer)?); + bind_index.set(Self::non_negative_i32_literal(lexer)?, name_span)?; lexer.expect(Token::Paren(')'))?; } - ("group", _) => { + ("group", name_span) => { lexer.expect(Token::Paren('('))?; - bind_group = Some(Self::non_negative_i32_literal(lexer)?); + bind_group.set(Self::non_negative_i32_literal(lexer)?, name_span)?; lexer.expect(Token::Paren(')'))?; } - ("vertex", _) => { - stage = Some(crate::ShaderStage::Vertex); + ("vertex", name_span) => { + stage.set(crate::ShaderStage::Vertex, name_span)?; } - ("fragment", _) => { - stage = Some(crate::ShaderStage::Fragment); + ("fragment", name_span) => { + stage.set(crate::ShaderStage::Fragment, name_span)?; } - ("compute", _) => { - stage = Some(crate::ShaderStage::Compute); + ("compute", name_span) => { + stage.set(crate::ShaderStage::Compute, name_span)?; } ("workgroup_size", _) => { lexer.expect(Token::Paren('('))?; @@ -2175,7 +2202,7 @@ impl Parser { } } } - ("early_depth_test", _) => { + ("early_depth_test", name_span) => { let conservative = if lexer.skip(Token::Paren('(')) { let (ident, ident_span) = lexer.next_ident_with_span()?; let value = conv::map_conservative_depth(ident, ident_span)?; @@ -2184,14 +2211,14 @@ impl Parser { } else { None }; - early_depth_test = Some(crate::EarlyDepthTest { conservative }); + early_depth_test.set(crate::EarlyDepthTest { conservative }, name_span)?; } (_, word_span) => return Err(Error::UnknownAttribute(word_span)), } } let attrib_span = self.pop_rule_span(lexer); - match (bind_group, bind_index) { + match (bind_group.value, bind_index.value) { (Some(group), Some(index)) => { binding = Some(crate::ResourceBinding { group, @@ -2254,9 +2281,9 @@ impl Parser { (Token::Word("fn"), _) => { let function = self.function_decl(lexer, out, &mut dependencies)?; Some(ast::GlobalDeclKind::Fn(ast::Function { - entry_point: stage.map(|stage| ast::EntryPoint { + entry_point: stage.value.map(|stage| ast::EntryPoint { stage, - early_depth_test, + early_depth_test: early_depth_test.value, workgroup_size, }), ..function diff --git a/src/front/wgsl/tests.rs b/src/front/wgsl/tests.rs index 3aec0acf2d..80ea261434 100644 --- a/src/front/wgsl/tests.rs +++ b/src/front/wgsl/tests.rs @@ -509,3 +509,42 @@ fn parse_texture_load_store_expecting_four_args() { ); } } + +#[test] +fn parse_repeated_attributes() { + use crate::{ + front::wgsl::{error::Error, Frontend}, + Span, + }; + + let template_vs = "@vertex fn vs() -> __REPLACE__ vec4 { return vec4(0.0); }"; + let template_struct = "struct A { __REPLACE__ data: vec3 }"; + let template_resource = "__REPLACE__ var tex_los_res: texture_2d_array;"; + let template_stage = "__REPLACE__ fn vs() -> vec4 { return vec4(0.0); }"; + for (attribute, template) in [ + ("align(16)", template_struct), + ("binding(0)", template_resource), + ("builtin(position)", template_vs), + ("compute", template_stage), + ("fragment", template_stage), + ("group(0)", template_resource), + ("interpolate(flat)", template_vs), + ("invariant", template_vs), + ("location(0)", template_vs), + ("size(16)", template_struct), + ("vertex", template_stage), + ("early_depth_test(less_equal)", template_resource), + ] { + let shader = template.replace("__REPLACE__", &format!("@{attribute} @{attribute}")); + let name_length = attribute.rfind('(').unwrap_or(attribute.len()) as u32; + let span_start = shader.rfind(attribute).unwrap() as u32; + let span_end = span_start + name_length; + let expected_span = Span::new(span_start, span_end); + + let result = Frontend::new().inner(&shader); + assert!(matches!( + result.unwrap_err(), + Error::RepeatedAttribute(span) if span == expected_span + )); + } +}