Skip to content

Commit

Permalink
Add FindLsb / FindMsb (#1473)
Browse files Browse the repository at this point in the history
* Add FindLsb / FindMsb

* Fixes and tests for FindLsb/FindMsb

* Add findLsb / findMsb as WGSL builtins

* Fix tests

* Fix incompatible type issue with MSL output

* Requested changes

* Test fewer cases of findLsb/findMsb
  • Loading branch information
fintelia authored Dec 20, 2021
1 parent c2328fe commit f9b3485
Show file tree
Hide file tree
Showing 19 changed files with 157 additions and 3 deletions.
2 changes: 2 additions & 0 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2485,6 +2485,8 @@ impl<'a, W: Write> Writer<'a, W> {
Mf::ReverseBits => "bitfieldReverse",
Mf::ExtractBits => "bitfieldExtract",
Mf::InsertBits => "bitfieldInsert",
Mf::FindLsb => "findLSB",
Mf::FindMsb => "findMSB",
// data packing
Mf::Pack4x8snorm => "packSnorm4x8",
Mf::Pack4x8unorm => "packUnorm4x8",
Expand Down
2 changes: 2 additions & 0 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1874,6 +1874,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// bits
Mf::CountOneBits => Function::Regular("countbits"),
Mf::ReverseBits => Function::Regular("reversebits"),
Mf::FindLsb => Function::Regular("firstbitlow"),
Mf::FindMsb => Function::Regular("firstbithigh"),
_ => return Err(Error::Unimplemented(format!("write_expr_math {:?}", fun))),
};

Expand Down
33 changes: 33 additions & 0 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1099,6 +1099,21 @@ impl<W: Write> Writer<W> {
crate::TypeInner::Scalar { .. } => true,
_ => false,
};
let argument_size_suffix = match *context.resolve_type(arg) {
crate::TypeInner::Vector {
size: crate::VectorSize::Bi,
..
} => "2",
crate::TypeInner::Vector {
size: crate::VectorSize::Tri,
..
} => "3",
crate::TypeInner::Vector {
size: crate::VectorSize::Quad,
..
} => "4",
_ => "",
};

let fun_name = match fun {
// comparison
Expand Down Expand Up @@ -1162,6 +1177,8 @@ impl<W: Write> Writer<W> {
Mf::ReverseBits => "reverse_bits",
Mf::ExtractBits => "extract_bits",
Mf::InsertBits => "insert_bits",
Mf::FindLsb => "",
Mf::FindMsb => "",
// data packing
Mf::Pack4x8snorm => "pack_float_to_unorm4x8",
Mf::Pack4x8unorm => "pack_float_to_snorm4x8",
Expand All @@ -1182,6 +1199,22 @@ impl<W: Write> Writer<W> {
write!(self.out, " - ")?;
self.put_expression(arg1.unwrap(), context, false)?;
write!(self.out, ")")?;
} else if fun == Mf::FindLsb {
write!(
self.out,
"(((1 + int{}({}::ctz(",
argument_size_suffix, NAMESPACE
)?;
self.put_expression(arg, context, true)?;
write!(self.out, "))) % 33) - 1)")?;
} else if fun == Mf::FindMsb {
write!(
self.out,
"(((1 + int{}({}::clz(",
argument_size_suffix, NAMESPACE
)?;
self.put_expression(arg, context, true)?;
write!(self.out, "))) % 33) - 1)")?;
} else if fun == Mf::Unpack2x16float {
write!(self.out, "float2(as_type<half2>(")?;
self.put_expression(arg, context, false)?;
Expand Down
6 changes: 6 additions & 0 deletions src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,12 @@ impl<'w> BlockContext<'w> {
arg2_id,
arg3_id,
)),
Mf::FindLsb => MathOp::Ext(spirv::GLOp::FindILsb),
Mf::FindMsb => MathOp::Ext(match arg_scalar_kind {
Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb,
Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb,
other => unimplemented!("Unexpected findMSB({:?})", other),
}),
Mf::Pack4x8unorm => MathOp::Ext(spirv::GLOp::PackUnorm4x8),
Mf::Pack4x8snorm => MathOp::Ext(spirv::GLOp::PackSnorm4x8),
Mf::Pack2x16float => MathOp::Ext(spirv::GLOp::PackHalf2x16),
Expand Down
2 changes: 2 additions & 0 deletions src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1566,6 +1566,8 @@ impl<W: Write> Writer<W> {
Mf::ReverseBits => Function::Regular("reverseBits"),
Mf::ExtractBits => Function::Regular("extractBits"),
Mf::InsertBits => Function::Regular("insertBits"),
Mf::FindLsb => Function::Regular("findLsb"),
Mf::FindMsb => Function::Regular("findMsb"),
// data packing
Mf::Pack4x8snorm => Function::Regular("pack4x8snorm"),
Mf::Pack4x8unorm => Function::Regular("pack4x8unorm"),
Expand Down
5 changes: 4 additions & 1 deletion src/front/glsl/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -630,12 +630,15 @@ pub fn inject_builtin(declaration: &mut FunctionDeclaration, module: &mut Module
))
}
}
"bitCount" | "bitfieldReverse" | "bitfieldExtract" | "bitfieldInsert" => {
"bitCount" | "bitfieldReverse" | "bitfieldExtract" | "bitfieldInsert" | "findLSB"
| "findMSB" => {
let fun = match name {
"bitCount" => MathFunction::CountOneBits,
"bitfieldReverse" => MathFunction::ReverseBits,
"bitfieldExtract" => MathFunction::ExtractBits,
"bitfieldInsert" => MathFunction::InsertBits,
"findLSB" => MathFunction::FindLsb,
"findMSB" => MathFunction::FindMsb,
_ => unreachable!(),
};

Expand Down
2 changes: 2 additions & 0 deletions src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2637,6 +2637,8 @@ impl<I: Iterator<Item = u32>> Parser<I> {
Glo::UnpackHalf2x16 => Mf::Unpack2x16float,
Glo::UnpackUnorm2x16 => Mf::Unpack2x16unorm,
Glo::UnpackSnorm2x16 => Mf::Unpack2x16snorm,
Glo::FindILsb => Mf::FindLsb,
Glo::FindUMsb | Glo::FindSMsb => Mf::FindMsb,
_ => return Err(Error::UnsupportedExtInst(inst_id)),
};

Expand Down
2 changes: 2 additions & 0 deletions src/front/wgsl/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@ pub fn map_standard_fun(word: &str) -> Option<crate::MathFunction> {
"reverseBits" => Mf::ReverseBits,
"extractBits" => Mf::ExtractBits,
"insertBits" => Mf::InsertBits,
"findLsb" => Mf::FindLsb,
"findMsb" => Mf::FindMsb,
// data packing
"pack4x8snorm" => Mf::Pack4x8snorm,
"pack4x8unorm" => Mf::Pack4x8unorm,
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,8 @@ pub enum MathFunction {
ReverseBits,
ExtractBits,
InsertBits,
FindLsb,
FindMsb,
// data packing
Pack4x8snorm,
Pack4x8unorm,
Expand Down
2 changes: 2 additions & 0 deletions src/proc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,8 @@ impl super::MathFunction {
Self::ReverseBits => 1,
Self::ExtractBits => 3,
Self::InsertBits => 4,
Self::FindLsb => 1,
Self::FindMsb => 1,
// data packing
Self::Pack4x8snorm => 1,
Self::Pack4x8unorm => 1,
Expand Down
10 changes: 10 additions & 0 deletions src/proc/typifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,16 @@ impl<'a> ResolveContext<'a> {
Mf::ReverseBits |
Mf::ExtractBits |
Mf::InsertBits => res_arg.clone(),
Mf::FindLsb |
Mf::FindMsb => match *res_arg.inner_with(types) {
Ti::Scalar { kind: _, width } =>
TypeResolution::Value(Ti::Scalar { kind: crate::ScalarKind::Sint, width }),
Ti::Vector { size, kind: _, width } =>
TypeResolution::Value(Ti::Vector { size, kind: crate::ScalarKind::Sint, width }),
ref other => return Err(ResolveError::IncompatibleOperands(
format!("{:?}({:?})", fun, other)
)),
},
// data packing
Mf::Pack4x8snorm |
Mf::Pack4x8unorm |
Expand Down
2 changes: 1 addition & 1 deletion src/valid/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1231,7 +1231,7 @@ impl super::Validator {
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
Mf::CountOneBits | Mf::ReverseBits => {
Mf::CountOneBits | Mf::ReverseBits | Mf::FindLsb | Mf::FindMsb => {
if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
Expand Down
4 changes: 4 additions & 0 deletions tests/in/bits.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,8 @@ fn main() {
u2 = extractBits(u2, 5u, 10u);
u3 = extractBits(u3, 5u, 10u);
u4 = extractBits(u4, 5u, 10u);
i = findLsb(i);
i2 = findLsb(u2);
i3 = findMsb(i3);
i = findMsb(u);
}
16 changes: 16 additions & 0 deletions tests/in/glsl/bits_glsl.frag
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,20 @@ void main() {
u2 = bitfieldExtract(u2, 5, 10);
u3 = bitfieldExtract(u3, 5, 10);
u4 = bitfieldExtract(u4, 5, 10);
i = findLSB(i);
i2 = findLSB(i2);
i3 = findLSB(i3);
i4 = findLSB(i4);
i = findLSB(u);
i2 = findLSB(u2);
i3 = findLSB(u3);
i4 = findLSB(u4);
i = findMSB(i);
i2 = findMSB(i2);
i3 = findMSB(i3);
i4 = findMSB(i4);
i = findMSB(u);
i2 = findMSB(u2);
i3 = findMSB(u3);
i4 = findMSB(u4);
}
8 changes: 8 additions & 0 deletions tests/out/glsl/bits.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ void main() {
u3_ = bitfieldExtract(_e112, int(5u), int(10u));
uvec4 _e116 = u4_;
u4_ = bitfieldExtract(_e116, int(5u), int(10u));
int _e120 = i;
i = findLSB(_e120);
uvec2 _e122 = u2_;
i2_ = findLSB(_e122);
ivec3 _e124 = i3_;
i3_ = findMSB(_e124);
uint _e126 = u;
i = findMSB(_e126);
return;
}

8 changes: 8 additions & 0 deletions tests/out/msl/bits.msl
Original file line number Diff line number Diff line change
Expand Up @@ -83,5 +83,13 @@ kernel void main_(
u3_ = metal::extract_bits(_e112, 5u, 10u);
metal::uint4 _e116 = u4_;
u4_ = metal::extract_bits(_e116, 5u, 10u);
int _e120 = i;
i = (((1 + int(metal::ctz(_e120))) % 33) - 1);
metal::uint2 _e122 = u2_;
i2_ = (((1 + int2(metal::ctz(_e122))) % 33) - 1);
metal::int3 _e124 = i3_;
i3_ = (((1 + int3(metal::clz(_e124))) % 33) - 1);
metal::uint _e126 = u;
i = (((1 + int(metal::clz(_e126))) % 33) - 1);
return;
}
14 changes: 13 additions & 1 deletion tests/out/spv/bits.spvasm
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 111
; Bound: 119
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
Expand Down Expand Up @@ -151,5 +151,17 @@ OpStore %31 %108
%109 = OpLoad %16 %33
%110 = OpBitFieldUExtract %16 %109 %9 %10
OpStore %33 %110
%111 = OpLoad %4 %19
%112 = OpExtInst %4 %1 FindILsb %111
OpStore %19 %112
%113 = OpLoad %14 %29
%114 = OpExtInst %11 %1 FindILsb %113
OpStore %21 %114
%115 = OpLoad %12 %23
%116 = OpExtInst %12 %1 FindSMsb %115
OpStore %23 %116
%117 = OpLoad %6 %27
%118 = OpExtInst %4 %1 FindUMsb %117
OpStore %19 %118
OpReturn
OpFunctionEnd
8 changes: 8 additions & 0 deletions tests/out/wgsl/bits.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,13 @@ fn main() {
u3_ = extractBits(_e112, 5u, 10u);
let _e116 = u4_;
u4_ = extractBits(_e116, 5u, 10u);
let _e120 = i;
i = findLsb(_e120);
let _e122 = u2_;
i2_ = findLsb(_e122);
let _e124 = i3_;
i3_ = findMsb(_e124);
let _e126 = u;
i = findMsb(_e126);
return;
}
32 changes: 32 additions & 0 deletions tests/out/wgsl/bits_glsl-frag.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,38 @@ fn main_1() {
u3_ = extractBits(_e207, u32(5), u32(10));
let _e216 = u4_;
u4_ = extractBits(_e216, u32(5), u32(10));
let _e223 = i;
i = findLsb(_e223);
let _e226 = i2_;
i2_ = findLsb(_e226);
let _e229 = i3_;
i3_ = findLsb(_e229);
let _e232 = i4_;
i4_ = findLsb(_e232);
let _e235 = u;
i = findLsb(_e235);
let _e238 = u2_;
i2_ = findLsb(_e238);
let _e241 = u3_;
i3_ = findLsb(_e241);
let _e244 = u4_;
i4_ = findLsb(_e244);
let _e247 = i;
i = findMsb(_e247);
let _e250 = i2_;
i2_ = findMsb(_e250);
let _e253 = i3_;
i3_ = findMsb(_e253);
let _e256 = i4_;
i4_ = findMsb(_e256);
let _e259 = u;
i = findMsb(_e259);
let _e262 = u2_;
i2_ = findMsb(_e262);
let _e265 = u3_;
i3_ = findMsb(_e265);
let _e268 = u4_;
i4_ = findMsb(_e268);
return;
}

Expand Down

0 comments on commit f9b3485

Please sign in to comment.