From 085324e6d0da8cffb22b084baedc794393be37b1 Mon Sep 17 00:00:00 2001 From: Fredrik Fornwall Date: Fri, 18 Aug 2023 14:59:32 +0200 Subject: [PATCH] [wgsl-in] Only splat 'scalar op vec', not '_ op vec' Change binary_op_splat() from splatting: vec op scalar _ op vec to only splat: vec op scalar scalar op vec --- src/front/wgsl/lower/mod.rs | 11 +++-------- src/front/wgsl/tests.rs | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index f3a7b22dd3..4c5acf0ad9 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -525,13 +525,8 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { ) { self.grow_types(*left)?.grow_types(*right)?; - let left_size = match *self.resolved_inner(*left) { - crate::TypeInner::Vector { size, .. } => Some(size), - _ => None, - }; - - match (left_size, self.resolved_inner(*right)) { - (Some(size), &crate::TypeInner::Scalar { .. }) => { + match (self.resolved_inner(*left), self.resolved_inner(*right)) { + (&crate::TypeInner::Vector { size, .. }, &crate::TypeInner::Scalar { .. }) => { *right = self.append_expression( crate::Expression::Splat { size, @@ -540,7 +535,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { self.get_expression_span(*right), ); } - (None, &crate::TypeInner::Vector { size, .. }) => { + (&crate::TypeInner::Scalar { .. }, &crate::TypeInner::Vector { size, .. }) => { *left = self.append_expression( crate::Expression::Splat { size, value: *left }, self.get_expression_span(*left), diff --git a/src/front/wgsl/tests.rs b/src/front/wgsl/tests.rs index 02fc110cae..09a4356b07 100644 --- a/src/front/wgsl/tests.rs +++ b/src/front/wgsl/tests.rs @@ -433,6 +433,24 @@ fn binary_expression_mixed_scalar_and_vector_operands() { assert_eq!(found_expressions, 1); } + + let module = parse_str( + "@fragment + fn main(mat: mat3x3) { + let vec = vec3(1.0, 1.0, 1.0); + let result = mat / vec; + }", + ) + .unwrap(); + let expressions = &&module.entry_points[0].function.expressions; + let found_splat = expressions.iter().any(|(_, e)| { + if let crate::Expression::Binary { left, .. } = *e { + matches!(&expressions[left], &crate::Expression::Splat { .. }) + } else { + false + } + }); + assert!(!found_splat, "'mat / vec' should not be splatted"); } #[test]