Skip to content

Commit

Permalink
[wgsl-in] Only splat 'scalar op vec', not '_ op vec' (#2444)
Browse files Browse the repository at this point in the history
Change binary_op_splat() from splatting:

        vec op scalar
        _ op vec

to only splat:

        vec op scalar
        scalar op vec
  • Loading branch information
fornwall authored Aug 18, 2023
1 parent 3da9355 commit b001313
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
11 changes: 3 additions & 8 deletions src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -525,13 +525,8 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
) {
self.grow_types(*left)?.grow_types(*right)?;

let left_size = match *self.resolved_inner(*left) {
crate::TypeInner::Vector { size, .. } => Some(size),
_ => None,
};

match (left_size, self.resolved_inner(*right)) {
(Some(size), &crate::TypeInner::Scalar { .. }) => {
match (self.resolved_inner(*left), self.resolved_inner(*right)) {
(&crate::TypeInner::Vector { size, .. }, &crate::TypeInner::Scalar { .. }) => {
*right = self.append_expression(
crate::Expression::Splat {
size,
Expand All @@ -540,7 +535,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
self.get_expression_span(*right),
);
}
(None, &crate::TypeInner::Vector { size, .. }) => {
(&crate::TypeInner::Scalar { .. }, &crate::TypeInner::Vector { size, .. }) => {
*left = self.append_expression(
crate::Expression::Splat { size, value: *left },
self.get_expression_span(*left),
Expand Down
18 changes: 18 additions & 0 deletions src/front/wgsl/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,24 @@ fn binary_expression_mixed_scalar_and_vector_operands() {

assert_eq!(found_expressions, 1);
}

let module = parse_str(
"@fragment
fn main(mat: mat3x3<f32>) {
let vec = vec3<f32>(1.0, 1.0, 1.0);
let result = mat / vec;
}",
)
.unwrap();
let expressions = &&module.entry_points[0].function.expressions;
let found_splat = expressions.iter().any(|(_, e)| {
if let crate::Expression::Binary { left, .. } = *e {
matches!(&expressions[left], &crate::Expression::Splat { .. })
} else {
false
}
});
assert!(!found_splat, "'mat / vec' should not be splatted");
}

#[test]
Expand Down

0 comments on commit b001313

Please sign in to comment.