Skip to content

Commit

Permalink
ConstantEvaluator::swizzle: Handle vector concatenation.
Browse files Browse the repository at this point in the history
  • Loading branch information
jimblandy committed Sep 18, 2023
1 parent 9a4db9b commit 3eec55a
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 9 deletions.
78 changes: 69 additions & 9 deletions src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ pub enum ConstantEvaluatorError {
SplatScalarOnly,
#[error("Can only swizzle vector constants")]
SwizzleVectorOnly,
#[error("swizzle component not present in source expression")]
SwizzleOutOfBounds,
#[error("Type is not constructible")]
TypeNotConstructible,
#[error("Subexpression(s) are not constant")]
Expand Down Expand Up @@ -305,20 +307,31 @@ impl ConstantEvaluator<'_> {
let expr = Expression::Splat { size, value };
Ok(self.register_constant(expr, span))
}
Expression::Compose {
ty,
components: ref src_components,
} => {
Expression::Compose { ty, ref components } => {
let dst_ty = get_dst_ty(ty)?;

let components = pattern
let mut flattened = [src_constant; 4]; // dummy value
let len = self
.flatten_compose(ty, components)
.zip(flattened.iter_mut())
.map(|(component, elt)| *elt = component)
.count();
let flattened = &flattened[..len];

let swizzled_components = pattern[..size as usize]
.iter()
.take(size as usize)
.map(|&sc| src_components[sc as usize])
.collect();
.map(|&sc| {
let sc = sc as usize;
if let Some(elt) = flattened.get(sc) {
Ok(*elt)
} else {
Err(ConstantEvaluatorError::SwizzleOutOfBounds)
}
})
.collect::<Result<Vec<Handle<Expression>>, _>>()?;
let expr = Expression::Compose {
ty: dst_ty,
components,
components: swizzled_components,
};
Ok(self.register_constant(expr, span))
}
Expand Down Expand Up @@ -827,6 +840,53 @@ impl ConstantEvaluator<'_> {

self.expressions.append(expr, span)
}

/// Return an iterator over the individual components assembled by a
/// `Compose` expression.
///
/// Given `ty` and `components` from an `Expression::Compose`, return an
/// iterator over the components of the resulting value.
///
/// Normally, this would just be an iterator over `components`. However,
/// `Compose` expressions can concatenate vectors, in which case the i'th
/// value being composed is not generally the i'th element of `components`.
/// This function consults `ty` to decide if this concatenation is occuring,
/// and returns an iterator that produces the components of the result of
/// the `Compose` expression in either case.
fn flatten_compose<'c>(
&'c self,
ty: Handle<Type>,
components: &'c [Handle<Expression>],
) -> impl Iterator<Item = Handle<Expression>> + 'c {
// Returning `impl Iterator` is a bit tricky. We may or may not want to
// flatten the components, but we have to settle on a single concrete
// type to return. The below is a single iterator chain that handles
// both the flattening and non-flattening cases.
let (size, is_vector) = if let TypeInner::Vector { size, .. } = self.types[ty].inner {
(size as usize, true)
} else {
(components.len(), false)
};

components
.iter()
.flat_map(move |component| {
if let (
true,
&Expression::Compose {
ty: _,
components: ref subcomponents,
},
) = (is_vector, &self.expressions[*component])
{
subcomponents
} else {
std::slice::from_ref(component)
}
})
.take(size)
.cloned()
}
}

/// Helper function to implement the GLSL `max` function for floats.
Expand Down
5 changes: 5 additions & 0 deletions tests/in/const-exprs.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
fn f() -> vec4<i32> {
let a = vec2(1, 2);
let b = vec2(3, 4);
return vec4(a, b).wzyx;
}
7 changes: 7 additions & 0 deletions tests/out/hlsl/const-exprs.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
int4 f()
{
int2 a = int2(1, 2);
int2 b = int2(3, 4);
return int4(4, 3, 2, 1);
}

8 changes: 8 additions & 0 deletions tests/out/hlsl/const-exprs.ron
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
(
vertex:[
],
fragment:[
],
compute:[
],
)
13 changes: 13 additions & 0 deletions tests/out/msl/const-exprs.msl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// language: metal2.0
#include <metal_stdlib>
#include <simd/simd.h>

using metal::uint;


metal::int4 f(
) {
metal::int2 a = metal::int2(1, 2);
metal::int2 b = metal::int2(3, 4);
return metal::int4(4, 3, 2, 1);
}
26 changes: 26 additions & 0 deletions tests/out/spv/const-exprs.spvasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 17
OpCapability Shader
OpCapability Linkage
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
%2 = OpTypeVoid
%4 = OpTypeInt 32 1
%3 = OpTypeVector %4 4
%5 = OpTypeVector %4 2
%8 = OpTypeFunction %3
%9 = OpConstant %4 1
%10 = OpConstant %4 2
%11 = OpConstantComposite %5 %9 %10
%12 = OpConstant %4 3
%13 = OpConstant %4 4
%14 = OpConstantComposite %5 %12 %13
%15 = OpConstantComposite %3 %13 %12 %10 %9
%7 = OpFunction %3 None %8
%6 = OpLabel
OpBranch %16
%16 = OpLabel
OpReturnValue %15
OpFunctionEnd
6 changes: 6 additions & 0 deletions tests/out/wgsl/const-exprs.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
fn f() -> vec4<i32> {
let a = vec2<i32>(1, 2);
let b = vec2<i32>(3, 4);
return vec4<i32>(4, 3, 2, 1);
}

4 changes: 4 additions & 0 deletions tests/snapshots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,10 @@ fn convert_wgsl() {
"constructors",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
),
(
"const-exprs",
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
),
];

for &(name, targets) in inputs.iter() {
Expand Down

0 comments on commit 3eec55a

Please sign in to comment.