diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index e32c803f05..34a22b5247 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -2485,6 +2485,8 @@ impl<'a, W: Write> Writer<'a, W> { Mf::ReverseBits => "bitfieldReverse", Mf::ExtractBits => "bitfieldExtract", Mf::InsertBits => "bitfieldInsert", + Mf::FindLsb => "findLSB", + Mf::FindMsb => "findMSB", // data packing Mf::Pack4x8snorm => "packSnorm4x8", Mf::Pack4x8unorm => "packUnorm4x8", diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index ca00d889c0..2f7b4e7ad9 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -1874,6 +1874,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { // bits Mf::CountOneBits => Function::Regular("countbits"), Mf::ReverseBits => Function::Regular("reversebits"), + Mf::FindLsb => Function::Regular("firstbitlow"), + Mf::FindMsb => Function::Regular("firstbithigh"), _ => return Err(Error::Unimplemented(format!("write_expr_math {:?}", fun))), }; diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 1050f1434c..664b7ea2ff 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -1099,6 +1099,21 @@ impl Writer { crate::TypeInner::Scalar { .. } => true, _ => false, }; + let argument_size_suffix = match *context.resolve_type(arg) { + crate::TypeInner::Vector { + size: crate::VectorSize::Bi, + .. + } => "2", + crate::TypeInner::Vector { + size: crate::VectorSize::Tri, + .. + } => "3", + crate::TypeInner::Vector { + size: crate::VectorSize::Quad, + .. + } => "4", + _ => "", + }; let fun_name = match fun { // comparison @@ -1162,6 +1177,8 @@ impl Writer { Mf::ReverseBits => "reverse_bits", Mf::ExtractBits => "extract_bits", Mf::InsertBits => "insert_bits", + Mf::FindLsb => "", + Mf::FindMsb => "", // data packing Mf::Pack4x8snorm => "pack_float_to_unorm4x8", Mf::Pack4x8unorm => "pack_float_to_snorm4x8", @@ -1182,6 +1199,22 @@ impl Writer { write!(self.out, " - ")?; self.put_expression(arg1.unwrap(), context, false)?; write!(self.out, ")")?; + } else if fun == Mf::FindLsb { + write!( + self.out, + "(((1 + int{}({}::ctz(", + argument_size_suffix, NAMESPACE + )?; + self.put_expression(arg, context, true)?; + write!(self.out, "))) % 33) - 1)")?; + } else if fun == Mf::FindMsb { + write!( + self.out, + "(((1 + int{}({}::clz(", + argument_size_suffix, NAMESPACE + )?; + self.put_expression(arg, context, true)?; + write!(self.out, "))) % 33) - 1)")?; } else if fun == Mf::Unpack2x16float { write!(self.out, "float2(as_type(")?; self.put_expression(arg, context, false)?; diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 5f01ca028f..aba92b9b87 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -662,6 +662,12 @@ impl<'w> BlockContext<'w> { arg2_id, arg3_id, )), + Mf::FindLsb => MathOp::Ext(spirv::GLOp::FindILsb), + Mf::FindMsb => MathOp::Ext(match arg_scalar_kind { + Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb, + Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb, + other => unimplemented!("Unexpected findMSB({:?})", other), + }), Mf::Pack4x8unorm => MathOp::Ext(spirv::GLOp::PackUnorm4x8), Mf::Pack4x8snorm => MathOp::Ext(spirv::GLOp::PackSnorm4x8), Mf::Pack2x16float => MathOp::Ext(spirv::GLOp::PackHalf2x16), diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 6c226fe443..dbff2db3c3 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -1566,6 +1566,8 @@ impl Writer { Mf::ReverseBits => Function::Regular("reverseBits"), Mf::ExtractBits => Function::Regular("extractBits"), Mf::InsertBits => Function::Regular("insertBits"), + Mf::FindLsb => Function::Regular("findLsb"), + Mf::FindMsb => Function::Regular("findMsb"), // data packing Mf::Pack4x8snorm => Function::Regular("pack4x8snorm"), Mf::Pack4x8unorm => Function::Regular("pack4x8unorm"), diff --git a/src/front/glsl/builtins.rs b/src/front/glsl/builtins.rs index a7416ded2d..1a114a57b2 100644 --- a/src/front/glsl/builtins.rs +++ b/src/front/glsl/builtins.rs @@ -630,12 +630,15 @@ pub fn inject_builtin(declaration: &mut FunctionDeclaration, module: &mut Module )) } } - "bitCount" | "bitfieldReverse" | "bitfieldExtract" | "bitfieldInsert" => { + "bitCount" | "bitfieldReverse" | "bitfieldExtract" | "bitfieldInsert" | "findLSB" + | "findMSB" => { let fun = match name { "bitCount" => MathFunction::CountOneBits, "bitfieldReverse" => MathFunction::ReverseBits, "bitfieldExtract" => MathFunction::ExtractBits, "bitfieldInsert" => MathFunction::InsertBits, + "findLSB" => MathFunction::FindLsb, + "findMSB" => MathFunction::FindMsb, _ => unreachable!(), }; diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index cfdf1fe88f..dc578374b5 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -2637,6 +2637,8 @@ impl> Parser { Glo::UnpackHalf2x16 => Mf::Unpack2x16float, Glo::UnpackUnorm2x16 => Mf::Unpack2x16unorm, Glo::UnpackSnorm2x16 => Mf::Unpack2x16snorm, + Glo::FindILsb => Mf::FindLsb, + Glo::FindUMsb | Glo::FindSMsb => Mf::FindMsb, _ => return Err(Error::UnsupportedExtInst(inst_id)), }; diff --git a/src/front/wgsl/conv.rs b/src/front/wgsl/conv.rs index 12d195ba6a..b93869071e 100644 --- a/src/front/wgsl/conv.rs +++ b/src/front/wgsl/conv.rs @@ -199,6 +199,8 @@ pub fn map_standard_fun(word: &str) -> Option { "reverseBits" => Mf::ReverseBits, "extractBits" => Mf::ExtractBits, "insertBits" => Mf::InsertBits, + "findLsb" => Mf::FindLsb, + "findMsb" => Mf::FindMsb, // data packing "pack4x8snorm" => Mf::Pack4x8snorm, "pack4x8unorm" => Mf::Pack4x8unorm, diff --git a/src/lib.rs b/src/lib.rs index 7be2b53773..349ae21f96 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -926,6 +926,8 @@ pub enum MathFunction { ReverseBits, ExtractBits, InsertBits, + FindLsb, + FindMsb, // data packing Pack4x8snorm, Pack4x8unorm, diff --git a/src/proc/mod.rs b/src/proc/mod.rs index 8344f934c3..04ae4a594c 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -262,6 +262,8 @@ impl super::MathFunction { Self::ReverseBits => 1, Self::ExtractBits => 3, Self::InsertBits => 4, + Self::FindLsb => 1, + Self::FindMsb => 1, // data packing Self::Pack4x8snorm => 1, Self::Pack4x8unorm => 1, diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index cf6035802f..270276243b 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -803,6 +803,16 @@ impl<'a> ResolveContext<'a> { Mf::ReverseBits | Mf::ExtractBits | Mf::InsertBits => res_arg.clone(), + Mf::FindLsb | + Mf::FindMsb => match *res_arg.inner_with(types) { + Ti::Scalar { kind: _, width } => + TypeResolution::Value(Ti::Scalar { kind: crate::ScalarKind::Sint, width }), + Ti::Vector { size, kind: _, width } => + TypeResolution::Value(Ti::Vector { size, kind: crate::ScalarKind::Sint, width }), + ref other => return Err(ResolveError::IncompatibleOperands( + format!("{:?}({:?})", fun, other) + )), + }, // data packing Mf::Pack4x8snorm | Mf::Pack4x8unorm | diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 26ec967b19..29409a1bcd 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1231,7 +1231,7 @@ impl super::Validator { _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), } } - Mf::CountOneBits | Mf::ReverseBits => { + Mf::CountOneBits | Mf::ReverseBits | Mf::FindLsb | Mf::FindMsb => { if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() { return Err(ExpressionError::WrongArgumentCount(fun)); } diff --git a/tests/in/bits.wgsl b/tests/in/bits.wgsl index 1834411f35..4bec623aa4 100644 --- a/tests/in/bits.wgsl +++ b/tests/in/bits.wgsl @@ -36,4 +36,8 @@ fn main() { u2 = extractBits(u2, 5u, 10u); u3 = extractBits(u3, 5u, 10u); u4 = extractBits(u4, 5u, 10u); + i = findLsb(i); + i2 = findLsb(u2); + i3 = findMsb(i3); + i = findMsb(u); } diff --git a/tests/in/glsl/bits_glsl.frag b/tests/in/glsl/bits_glsl.frag index cd960ed861..807ca32160 100644 --- a/tests/in/glsl/bits_glsl.frag +++ b/tests/in/glsl/bits_glsl.frag @@ -37,4 +37,20 @@ void main() { u2 = bitfieldExtract(u2, 5, 10); u3 = bitfieldExtract(u3, 5, 10); u4 = bitfieldExtract(u4, 5, 10); + i = findLSB(i); + i2 = findLSB(i2); + i3 = findLSB(i3); + i4 = findLSB(i4); + i = findLSB(u); + i2 = findLSB(u2); + i3 = findLSB(u3); + i4 = findLSB(u4); + i = findMSB(i); + i2 = findMSB(i2); + i3 = findMSB(i3); + i4 = findMSB(i4); + i = findMSB(u); + i2 = findMSB(u2); + i3 = findMSB(u3); + i4 = findMSB(u4); } \ No newline at end of file diff --git a/tests/out/glsl/bits.main.Compute.glsl b/tests/out/glsl/bits.main.Compute.glsl index 303cdbc440..dd44101be1 100644 --- a/tests/out/glsl/bits.main.Compute.glsl +++ b/tests/out/glsl/bits.main.Compute.glsl @@ -85,6 +85,14 @@ void main() { u3_ = bitfieldExtract(_e112, int(5u), int(10u)); uvec4 _e116 = u4_; u4_ = bitfieldExtract(_e116, int(5u), int(10u)); + int _e120 = i; + i = findLSB(_e120); + uvec2 _e122 = u2_; + i2_ = findLSB(_e122); + ivec3 _e124 = i3_; + i3_ = findMSB(_e124); + uint _e126 = u; + i = findMSB(_e126); return; } diff --git a/tests/out/msl/bits.msl b/tests/out/msl/bits.msl index 3067ab4165..b0c76ff82c 100644 --- a/tests/out/msl/bits.msl +++ b/tests/out/msl/bits.msl @@ -83,5 +83,13 @@ kernel void main_( u3_ = metal::extract_bits(_e112, 5u, 10u); metal::uint4 _e116 = u4_; u4_ = metal::extract_bits(_e116, 5u, 10u); + int _e120 = i; + i = (((1 + int(metal::ctz(_e120))) % 33) - 1); + metal::uint2 _e122 = u2_; + i2_ = (((1 + int2(metal::ctz(_e122))) % 33) - 1); + metal::int3 _e124 = i3_; + i3_ = (((1 + int3(metal::clz(_e124))) % 33) - 1); + metal::uint _e126 = u; + i = (((1 + int(metal::clz(_e126))) % 33) - 1); return; } diff --git a/tests/out/spv/bits.spvasm b/tests/out/spv/bits.spvasm index 3ef507eb4c..b224d04ab3 100644 --- a/tests/out/spv/bits.spvasm +++ b/tests/out/spv/bits.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 111 +; Bound: 119 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -151,5 +151,17 @@ OpStore %31 %108 %109 = OpLoad %16 %33 %110 = OpBitFieldUExtract %16 %109 %9 %10 OpStore %33 %110 +%111 = OpLoad %4 %19 +%112 = OpExtInst %4 %1 FindILsb %111 +OpStore %19 %112 +%113 = OpLoad %14 %29 +%114 = OpExtInst %11 %1 FindILsb %113 +OpStore %21 %114 +%115 = OpLoad %12 %23 +%116 = OpExtInst %12 %1 FindSMsb %115 +OpStore %23 %116 +%117 = OpLoad %6 %27 +%118 = OpExtInst %4 %1 FindUMsb %117 +OpStore %19 %118 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/bits.wgsl b/tests/out/wgsl/bits.wgsl index 75eb3a5f3c..45037118bf 100644 --- a/tests/out/wgsl/bits.wgsl +++ b/tests/out/wgsl/bits.wgsl @@ -79,5 +79,13 @@ fn main() { u3_ = extractBits(_e112, 5u, 10u); let _e116 = u4_; u4_ = extractBits(_e116, 5u, 10u); + let _e120 = i; + i = findLsb(_e120); + let _e122 = u2_; + i2_ = findLsb(_e122); + let _e124 = i3_; + i3_ = findMsb(_e124); + let _e126 = u; + i = findMsb(_e126); return; } diff --git a/tests/out/wgsl/bits_glsl-frag.wgsl b/tests/out/wgsl/bits_glsl-frag.wgsl index 504adf8d7d..29f77ca764 100644 --- a/tests/out/wgsl/bits_glsl-frag.wgsl +++ b/tests/out/wgsl/bits_glsl-frag.wgsl @@ -70,6 +70,38 @@ fn main_1() { u3_ = extractBits(_e207, u32(5), u32(10)); let _e216 = u4_; u4_ = extractBits(_e216, u32(5), u32(10)); + let _e223 = i; + i = findLsb(_e223); + let _e226 = i2_; + i2_ = findLsb(_e226); + let _e229 = i3_; + i3_ = findLsb(_e229); + let _e232 = i4_; + i4_ = findLsb(_e232); + let _e235 = u; + i = findLsb(_e235); + let _e238 = u2_; + i2_ = findLsb(_e238); + let _e241 = u3_; + i3_ = findLsb(_e241); + let _e244 = u4_; + i4_ = findLsb(_e244); + let _e247 = i; + i = findMsb(_e247); + let _e250 = i2_; + i2_ = findMsb(_e250); + let _e253 = i3_; + i3_ = findMsb(_e253); + let _e256 = i4_; + i4_ = findMsb(_e256); + let _e259 = u; + i = findMsb(_e259); + let _e262 = u2_; + i2_ = findMsb(_e262); + let _e265 = u3_; + i3_ = findMsb(_e265); + let _e268 = u4_; + i4_ = findMsb(_e268); return; }