Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[naga] Teach the constant evaluator vector/vector operators. #4861

Merged
merged 3 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ This feature allowed you to call `global_id` on any wgpu opaque handle to get a

#### Naga

- Naga's WGSL front and back ends now have experimental support for 64-bit floating-point literals: `1.0lf` denotes an `f64` value. There has been experimental support for an `f64` type for a while, but until now there was no syntax for writing literals with that type. As before, Naga module validation rejects `f64` values unless `naga::valid::Capabilities::FLOAT64` is requested. By @jimblandy in [#4747](https://github.com/gfx-rs/wgpu/pull/4747).
- Naga'sn WGSL front and back ends now have experimental support for 64-bit floating-point literals: `1.0lf` denotes an `f64` value. There has been experimental support for an `f64` type for a while, but until now there was no syntax for writing literals with that type. As before, Naga module validation rejects `f64` values unless `naga::valid::Capabilities::FLOAT64` is requested. By @jimblandy in [#4747](https://github.com/gfx-rs/wgpu/pull/4747).
- Naga constant evaluation can now process binary operators whose operands are both vectors. By @jimblandy in [#4861](https://github.com/gfx-rs/wgpu/pull/4861).

### Changes

Expand Down
20 changes: 19 additions & 1 deletion naga/src/front/wgsl/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,12 @@ pub enum Error<'a> {
source_span: Span,
source_type: String,
},
ConcretizationFailed {
expr_span: Span,
expr_type: String,
scalar: String,
inner: ConstantEvaluatorError,
},
}

impl<'a> Error<'a> {
Expand Down Expand Up @@ -731,7 +737,19 @@ impl<'a> Error<'a> {
)
],
notes: vec![],
}
},
Error::ConcretizationFailed { expr_span, ref expr_type, ref scalar, ref inner } => ParseError {
message: format!("failed to convert expression to a concrete type: {}", inner),
labels: vec![
(
expr_span,
format!("this expression has type {}", expr_type).into(),
)
],
notes: vec![
format!("the expression should have been converted to have {} scalar type", scalar),
]
},
}
}
}
18 changes: 15 additions & 3 deletions naga/src/front/wgsl/lower/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,23 @@ impl<'source, 'temp, 'out> super::ExpressionContext<'source, 'temp, 'out> {
if let Some(scalar) = inner.automatically_convertible_scalar(&self.module.types) {
let concretized = scalar.concretize();
if concretized != scalar {
let span = self.get_expression_span(expr);
assert!(scalar.is_abstract());
let expr_span = self.get_expression_span(expr);
expr = self
.as_const_evaluator()
.cast_array(expr, concretized, span)
.map_err(|err| super::Error::ConstantEvaluatorError(err, span))?;
.cast_array(expr, concretized, expr_span)
.map_err(|err| {
// A `TypeResolution` includes the type's full name, if
// it has one. Also, avoid holding the borrow of `inner`
// across the call to `cast_array`.
let expr_type = &self.typifier()[expr];
super::Error::ConcretizationFailed {
expr_span,
expr_type: expr_type.to_wgsl(&self.module.to_ctx()),
scalar: concretized.to_wgsl(),
inner: err,
}
})?;
}
}

Expand Down
95 changes: 95 additions & 0 deletions naga/src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1361,12 +1361,107 @@ impl<'a> ConstantEvaluator<'a> {
}
Expression::Compose { ty, components }
}
(
&Expression::Compose {
components: ref left_components,
ty: left_ty,
},
&Expression::Compose {
components: ref right_components,
ty: right_ty,
},
) => {
// We have to make a copy of the component lists, because the
// call to `binary_op_vector` needs `&mut self`, but `self` owns
// the component lists.
let left_flattened = crate::proc::flatten_compose(
left_ty,
left_components,
self.expressions,
self.types,
);
let right_flattened = crate::proc::flatten_compose(
right_ty,
right_components,
self.expressions,
self.types,
);

// `flatten_compose` doesn't return an `ExactSizeIterator`, so
// make a reasonable guess of the capacity we'll need.
let mut flattened = Vec::with_capacity(left_components.len());
flattened.extend(left_flattened.zip(right_flattened));

match (&self.types[left_ty].inner, &self.types[right_ty].inner) {
(
&TypeInner::Vector {
size: left_size, ..
},
&TypeInner::Vector {
size: right_size, ..
},
) if left_size == right_size => {
self.binary_op_vector(op, left_size, &flattened, left_ty, span)?
}
_ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
}
}
_ => return Err(ConstantEvaluatorError::InvalidBinaryOpArgs),
};

self.register_evaluated_expr(expr, span)
}

fn binary_op_vector(
&mut self,
op: BinaryOperator,
size: crate::VectorSize,
components: &[(Handle<Expression>, Handle<Expression>)],
left_ty: Handle<Type>,
span: Span,
) -> Result<Expression, ConstantEvaluatorError> {
let ty = match op {
// Relational operators produce vectors of booleans.
BinaryOperator::Equal
| BinaryOperator::NotEqual
| BinaryOperator::Less
| BinaryOperator::LessEqual
| BinaryOperator::Greater
| BinaryOperator::GreaterEqual => self.types.insert(
Type {
name: None,
inner: TypeInner::Vector {
size,
scalar: crate::Scalar::BOOL,
},
},
span,
),

// Other operators produce the same type as their left
// operand.
BinaryOperator::Add
| BinaryOperator::Subtract
| BinaryOperator::Multiply
| BinaryOperator::Divide
| BinaryOperator::Modulo
| BinaryOperator::And
| BinaryOperator::ExclusiveOr
| BinaryOperator::InclusiveOr
| BinaryOperator::LogicalAnd
| BinaryOperator::LogicalOr
| BinaryOperator::ShiftLeft
| BinaryOperator::ShiftRight => left_ty,
};

let components = components
.iter()
.map(|&(left, right)| self.binary_op(op, left, right, span))
.collect::<Result<Vec<_>, _>>()?;

Ok(Expression::Compose { ty, components })
}

/// Deep copy `expr` from `expressions` into `self.expressions`.
///
/// Return the root of the new copy.
Expand Down
10 changes: 10 additions & 0 deletions naga/src/proc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,16 @@ impl super::Scalar {
width: crate::ABSTRACT_WIDTH,
};

pub const fn is_abstract(self) -> bool {
match self.kind {
crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => true,
crate::ScalarKind::Sint
| crate::ScalarKind::Uint
| crate::ScalarKind::Float
| crate::ScalarKind::Bool => false,
}
}

/// Construct a float `Scalar` with the given width.
///
/// This is especially common when dealing with
Expand Down
7 changes: 6 additions & 1 deletion naga/src/valid/compose.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@ pub fn validate_compose(
scalar: comp_scalar,
} if comp_scalar == scalar => comp_size as u32,
ref other => {
log::error!("Vector component[{}] type {:?}", index, other);
log::error!(
"Vector component[{}] type {:?}, building {:?}",
index,
other,
scalar
);
return Err(ComposeError::ComponentType {
index: index as u32,
});
Expand Down
3 changes: 3 additions & 0 deletions naga/tests/in/const-exprs.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,6 @@ fn map_texture_kind(texture_kind: i32) -> u32 {
fn compose_of_splat() {
var x = vec4f(vec3f(1.0), 2.0).wzyx;
}

const add_vec = vec2(1.0f) + vec2(3.0f, 4.0f);
const compare_vec = vec2(3.0f) == vec2(3.0f, 4.0f);
10 changes: 5 additions & 5 deletions naga/tests/in/operators.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ fn builtins() -> vec4<f32> {
return vec4<f32>(vec4<i32>(s1) + v_i32_zero) + s2 + m1 + m2 + b1 + b2;
}

fn splat() -> vec4<f32> {
let a = (1.0 + vec2<f32>(2.0) - 3.0) / 4.0;
let b = vec4<i32>(5) % 2;
fn splat(m: f32, n: i32) -> vec4<f32> {
let a = (2.0 + vec2<f32>(m) - 4.0) / 8.0;
let b = vec4<i32>(n) % 2;
return a.xyxy + vec4<f32>(b);
}

Expand Down Expand Up @@ -280,9 +280,9 @@ fn assignment() {
}

@compute @workgroup_size(1)
fn main() {
fn main(@builtin(workgroup_id) id: vec3<u32>) {
builtins();
splat();
splat(f32(id.x), i32(id.y));
bool_cast(v_f32_one.xyz);

logical();
Expand Down
2 changes: 2 additions & 0 deletions naga/tests/out/glsl/const-exprs.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ const vec4 DIV = vec4(0.44444445, 0.0, 0.0, 0.0);
const int TEXTURE_KIND_REGULAR = 0;
const int TEXTURE_KIND_WARP = 1;
const int TEXTURE_KIND_SKY = 2;
const vec2 add_vec = vec2(4.0, 5.0);
const bvec2 compare_vec = bvec2(true, false);


void swizzle_of_compose() {
Expand Down
13 changes: 7 additions & 6 deletions naga/tests/out/glsl/operators.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ vec4 builtins() {
return (((((vec4((ivec4(s1_) + v_i32_zero)) + s2_) + m1_) + m2_) + vec4(b1_)) + b2_);
}

vec4 splat() {
vec2 a_2 = (((vec2(1.0) + vec2(2.0)) - vec2(3.0)) / vec2(4.0));
ivec4 b = (ivec4(5) % ivec4(2));
vec4 splat(float m, int n) {
vec2 a_2 = (((vec2(2.0) + vec2(m)) - vec2(4.0)) / vec2(8.0));
ivec4 b = (ivec4(n) % ivec4(2));
return (a_2.xyxy + vec4(b));
}

Expand Down Expand Up @@ -247,9 +247,10 @@ void negation_avoids_prefix_decrement() {
}

void main() {
vec4 _e0 = builtins();
vec4 _e1 = splat();
vec3 _e6 = bool_cast(vec3(1.0, 1.0, 1.0));
uvec3 id = gl_WorkGroupID;
vec4 _e1 = builtins();
vec4 _e6 = splat(float(id.x), int(id.y));
vec3 _e11 = bool_cast(vec3(1.0, 1.0, 1.0));
logical();
arithmetic();
bit();
Expand Down
2 changes: 2 additions & 0 deletions naga/tests/out/hlsl/const-exprs.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ static const float4 DIV = float4(0.44444445, 0.0, 0.0, 0.0);
static const int TEXTURE_KIND_REGULAR = 0;
static const int TEXTURE_KIND_WARP = 1;
static const int TEXTURE_KIND_SKY = 2;
static const float2 add_vec = float2(4.0, 5.0);
static const bool2 compare_vec = bool2(true, false);

void swizzle_of_compose()
{
Expand Down
14 changes: 7 additions & 7 deletions naga/tests/out/hlsl/operators.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ float4 builtins()
return (((((float4(((s1_).xxxx + v_i32_zero)) + s2_) + m1_) + m2_) + (b1_).xxxx) + b2_);
}

float4 splat()
float4 splat(float m, int n)
{
float2 a_2 = ((((1.0).xx + (2.0).xx) - (3.0).xx) / (4.0).xx);
int4 b = ((5).xxxx % (2).xxxx);
float2 a_2 = ((((2.0).xx + (m).xx) - (4.0).xx) / (8.0).xx);
int4 b = ((n).xxxx % (2).xxxx);
return (a_2.xyxy + float4(b));
}

Expand Down Expand Up @@ -251,11 +251,11 @@ void negation_avoids_prefix_decrement()
}

[numthreads(1, 1, 1)]
void main()
void main(uint3 id : SV_GroupID)
{
const float4 _e0 = builtins();
const float4 _e1 = splat();
const float3 _e6 = bool_cast(float3(1.0, 1.0, 1.0));
const float4 _e1 = builtins();
const float4 _e6 = splat(float(id.x), int(id.y));
const float3 _e11 = bool_cast(float3(1.0, 1.0, 1.0));
logical();
arithmetic();
bit();
Expand Down
2 changes: 2 additions & 0 deletions naga/tests/out/msl/const-exprs.msl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ constant metal::float4 DIV = metal::float4(0.44444445, 0.0, 0.0, 0.0);
constant int TEXTURE_KIND_REGULAR = 0;
constant int TEXTURE_KIND_WARP = 1;
constant int TEXTURE_KIND_SKY = 2;
constant metal::float2 add_vec = metal::float2(4.0, 5.0);
constant metal::bool2 compare_vec = metal::bool2(true, false);

void swizzle_of_compose(
) {
Expand Down
15 changes: 10 additions & 5 deletions naga/tests/out/msl/operators.msl
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ metal::float4 builtins(
}

metal::float4 splat(
float m,
int n
) {
metal::float2 a_2 = ((metal::float2(1.0) + metal::float2(2.0)) - metal::float2(3.0)) / metal::float2(4.0);
metal::int4 b = metal::int4(5) % metal::int4(2);
metal::float2 a_2 = ((metal::float2(2.0) + metal::float2(m)) - metal::float2(4.0)) / metal::float2(8.0);
metal::int4 b = metal::int4(n) % metal::int4(2);
return a_2.xyxy + static_cast<metal::float4>(b);
}

Expand Down Expand Up @@ -255,11 +257,14 @@ void negation_avoids_prefix_decrement(
int p7_ = -(-(-(-(-(1)))));
}

struct main_Input {
};
kernel void main_(
metal::uint3 id [[threadgroup_position_in_grid]]
) {
metal::float4 _e0 = builtins();
metal::float4 _e1 = splat();
metal::float3 _e6 = bool_cast(metal::float3(1.0, 1.0, 1.0));
metal::float4 _e1 = builtins();
metal::float4 _e6 = splat(static_cast<float>(id.x), static_cast<int>(id.y));
metal::float3 _e11 = bool_cast(metal::float3(1.0, 1.0, 1.0));
logical();
arithmetic();
bit();
Expand Down
Loading
Loading