diff --git a/src/back/dot/mod.rs b/src/back/dot/mod.rs index 1556371df1..86f4797b56 100644 --- a/src/back/dot/mod.rs +++ b/src/back/dot/mod.rs @@ -279,6 +279,94 @@ impl StatementGraph { crate::RayQueryFunction::Terminate => "RayQueryTerminate", } } + S::SubgroupBallot { result, predicate } => { + if let Some(predicate) = predicate { + self.dependencies.push((id, predicate, "predicate")); + } + self.emits.push((id, result)); + "SubgroupBallot" + } + S::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + } => { + self.dependencies.push((id, argument, "arg")); + self.emits.push((id, result)); + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + "SubgroupAll" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + "SubgroupAny" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + "SubgroupAdd" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + "SubgroupMul" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + "SubgroupMax" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + "SubgroupMin" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + "SubgroupAnd" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + "SubgroupOr" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + "SubgroupXor" + } + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Add, + ) => "SubgroupPrefixExclusiveAdd", + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Mul, + ) => "SubgroupPrefixExclusiveMul", + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Add, + ) => "SubgroupPrefixInclusiveAdd", + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Mul, + ) => "SubgroupPrefixInclusiveMul", + _ => unimplemented!(), + } + } + S::SubgroupGather { + mode, + argument, + result, + } => { + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + self.dependencies.push((id, index, "index")) + } + } + self.dependencies.push((id, argument, "arg")); + self.emits.push((id, result)); + match mode { + crate::GatherMode::BroadcastFirst => "SubgroupBroadcastFirst", + crate::GatherMode::Broadcast(_) => "SubgroupBroadcast", + crate::GatherMode::Shuffle(_) => "SubgroupShuffle", + crate::GatherMode::ShuffleDown(_) => "SubgroupShuffleDown", + crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp", + crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor", + } + } }; // Set the last node to the merge node last_node = merge_id; @@ -586,6 +674,8 @@ fn write_function_expressions( let ty = if committed { "Committed" } else { "Candidate" }; (format!("rayQueryGet{}Intersection", ty).into(), 4) } + E::SubgroupBallotResult => ("SubgroupBallotResult".into(), 4), + E::SubgroupOperationResult { .. } => ("SubgroupOperationResult".into(), 4), }; // give uniform expressions an outline diff --git a/src/back/glsl/features.rs b/src/back/glsl/features.rs index b1fff4d4bc..483a3a9348 100644 --- a/src/back/glsl/features.rs +++ b/src/back/glsl/features.rs @@ -43,6 +43,8 @@ bitflags::bitflags! { const IMAGE_SIZE = 1 << 20; /// Dual source blending const DUAL_SOURCE_BLENDING = 1 << 21; + /// Subgroup operations + const SUBGROUP_OPERATIONS = 1 << 22; } } @@ -106,6 +108,7 @@ impl FeaturesManager { check_feature!(SAMPLE_VARIABLES, 400, 300); check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310); check_feature!(DUAL_SOURCE_BLENDING, 330, 300 /* with extension */); + check_feature!(SUBGROUP_OPERATIONS, 430, 310); match version { Version::Embedded { is_webgl: true, .. } => check_feature!(MULTI_VIEW, 140, 300), _ => check_feature!(MULTI_VIEW, 140, 310), @@ -235,6 +238,22 @@ impl FeaturesManager { writeln!(out, "#extension GL_EXT_blend_func_extended : require")?; } + if self.0.contains(Features::SUBGROUP_OPERATIONS) { + // https://registry.khronos.org/OpenGL/extensions/KHR/KHR_shader_subgroup.txt + writeln!(out, "#extension GL_KHR_shader_subgroup_basic : require")?; + writeln!(out, "#extension GL_KHR_shader_subgroup_vote : require")?; + writeln!( + out, + "#extension GL_KHR_shader_subgroup_arithmetic : require" + )?; + writeln!(out, "#extension GL_KHR_shader_subgroup_ballot : require")?; + writeln!(out, "#extension GL_KHR_shader_subgroup_shuffle : require")?; + writeln!( + out, + "#extension GL_KHR_shader_subgroup_shuffle_relative : require" + )?; + } + Ok(()) } } @@ -455,6 +474,10 @@ impl<'a, W> Writer<'a, W> { } } } + Expression::SubgroupBallotResult | + Expression::SubgroupOperationResult { .. } => { + features.request(Features::SUBGROUP_OPERATIONS) + } _ => {} } } diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 60431e986e..03a06890f8 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -2250,6 +2250,125 @@ impl<'a, W: Write> Writer<'a, W> { writeln!(self.out, ");")?; } Statement::RayQuery { .. } => unreachable!(), + Statement::SubgroupBallot { result, predicate } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + write!(self.out, "subgroupBallot(")?; + match predicate { + Some(predicate) => self.write_expr(predicate, ctx)?, + None => write!(self.out, "true")?, + } + writeln!(self.out, ");")?; + } + Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "subgroupAll(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "subgroupAny(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupAdd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupMul(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "subgroupMax(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "subgroupMin(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "subgroupAnd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "subgroupOr(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "subgroupXor(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupExclusiveAdd(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupExclusiveMul(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupInclusiveAdd(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupInclusiveMul(")? + } + _ => unimplemented!(), + } + self.write_expr(argument, ctx)?; + writeln!(self.out, ");")?; + } + Statement::SubgroupGather { + mode, + argument, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "subgroupBroadcastFirst(")?; + } + crate::GatherMode::Broadcast(_) => { + write!(self.out, "subgroupBroadcast(")?; + } + crate::GatherMode::Shuffle(_) => { + write!(self.out, "subgroupShuffle(")?; + } + crate::GatherMode::ShuffleDown(_) => { + write!(self.out, "subgroupShuffleDown(")?; + } + crate::GatherMode::ShuffleUp(_) => { + write!(self.out, "subgroupShuffleUp(")?; + } + crate::GatherMode::ShuffleXor(_) => { + write!(self.out, "subgroupShuffleXor(")?; + } + } + self.write_expr(argument, ctx)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + write!(self.out, ", ")?; + self.write_expr(index, ctx)?; + } + } + writeln!(self.out, ");")?; + } } Ok(()) @@ -3423,7 +3542,9 @@ impl<'a, W: Write> Writer<'a, W> { Expression::CallResult(_) | Expression::AtomicResult { .. } | Expression::RayQueryProceedResult - | Expression::WorkGroupUniformLoadResult { .. } => unreachable!(), + | Expression::WorkGroupUniformLoadResult { .. } + | Expression::SubgroupOperationResult { .. } + | Expression::SubgroupBallotResult => unreachable!(), // `ArrayLength` is written as `expr.length()` and we convert it to a uint Expression::ArrayLength(expr) => { write!(self.out, "uint(")?; @@ -3980,6 +4101,9 @@ impl<'a, W: Write> Writer<'a, W> { if flags.contains(crate::Barrier::WORK_GROUP) { writeln!(self.out, "{level}memoryBarrierShared();")?; } + if flags.contains(crate::Barrier::SUB_GROUP) { + writeln!(self.out, "{level}subgroupMemoryBarrier();")?; + } writeln!(self.out, "{level}barrier();")?; Ok(()) } @@ -4156,6 +4280,11 @@ const fn glsl_built_in( Bi::WorkGroupId => "gl_WorkGroupID", Bi::WorkGroupSize => "gl_WorkGroupSize", Bi::NumWorkGroups => "gl_NumWorkGroups", + // subgroup + Bi::NumSubgroups => "gl_NumSubgroups", + Bi::SubgroupId => "gl_SubgroupID", + Bi::SubgroupSize => "gl_SubgroupSize", + Bi::SubgroupInvocationId => "gl_SubgroupInvocationID", } } diff --git a/src/back/hlsl/conv.rs b/src/back/hlsl/conv.rs index 19bde6926a..d3fb76e401 100644 --- a/src/back/hlsl/conv.rs +++ b/src/back/hlsl/conv.rs @@ -166,6 +166,10 @@ impl crate::BuiltIn { // to this field will get replaced with references to `SPECIAL_CBUF_VAR` // in `Writer::write_expr`. Self::NumWorkGroups => "SV_GroupID", + Self::SubgroupSize + | Self::SubgroupInvocationId + | Self::NumSubgroups + | Self::SubgroupId => return Err(Error::Unimplemented(format!("builtin {self:?}"))), Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => { return Err(Error::Unimplemented(format!("builtin {self:?}"))) } diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index f26604476a..1eab43a4c3 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -1130,7 +1130,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, " {name}(")?; let need_workgroup_variables_initialization = - self.need_workgroup_variables_initialization(func_ctx, module); + self.need_workgroup_variables_initialization(func, func_ctx, module); // Write function arguments for non entry point functions match func_ctx.ty { @@ -1166,7 +1166,21 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name,)?; } else { let stage = module.entry_points[ep_index as usize].stage; + let mut arg_num = 0; for (index, arg) in func.arguments.iter().enumerate() { + if matches!( + arg.binding, + Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) + | Some(crate::Binding::BuiltIn( + crate::BuiltIn::SubgroupInvocationId + )) + | Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) + | Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) + ) { + continue; + } + arg_num += 1; + if index != 0 { write!(self.out, ", ")?; } @@ -1186,7 +1200,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } if need_workgroup_variables_initialization { - if !func.arguments.is_empty() { + if arg_num > 0 { write!(self.out, ", ")?; } write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?; @@ -1217,6 +1231,53 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_workgroup_variables_initialization(func_ctx, module)?; } + if let back::FunctionType::EntryPoint(ep_index) = func_ctx.ty { + let ep = &module.entry_points[ep_index as usize]; + for (index, arg) in func.arguments.iter().enumerate() { + if let Some(crate::Binding::BuiltIn(builtin)) = arg.binding { + if matches!( + builtin, + crate::BuiltIn::SubgroupSize + | crate::BuiltIn::SubgroupInvocationId + | crate::BuiltIn::NumSubgroups + | crate::BuiltIn::SubgroupId + ) { + let level = back::Level(1); + write!(self.out, "{level}const ")?; + + self.write_type(module, arg.ty)?; + + let argument_name = + &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)]; + write!(self.out, " {argument_name} = ")?; + + match builtin { + crate::BuiltIn::SubgroupSize => { + writeln!(self.out, "WaveGetLaneCount();")? + } + crate::BuiltIn::SubgroupInvocationId => { + writeln!(self.out, "WaveGetLaneIndex();")? + } + crate::BuiltIn::NumSubgroups => writeln!( + self.out, + "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount();", + ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2] + )?, + crate::BuiltIn::SubgroupId => { + writeln!( + self.out, + "(__local_invocation_id.x * {}u + __local_invocation_id.y * {}u + __local_invocation_id.z) / WaveGetLaneCount();", + ep.workgroup_size[0] * ep.workgroup_size[1], + ep.workgroup_size[1], + )?; + } + _ => unreachable!(), + } + } + } + } + } + if let back::FunctionType::EntryPoint(index) = func_ctx.ty { self.write_ep_arguments_initialization(module, func, index)?; } @@ -1267,14 +1328,20 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { fn need_workgroup_variables_initialization( &mut self, + func: &crate::Function, func_ctx: &back::FunctionCtx, module: &Module, ) -> bool { - self.options.zero_initialize_workgroup_memory + func.arguments.iter().any(|arg| { + matches!( + arg.binding, + Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) + ) + }) || (self.options.zero_initialize_workgroup_memory && func_ctx.ty.is_compute_entry_point(module) && module.global_variables.iter().any(|(handle, var)| { !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup - }) + })) } fn write_workgroup_variables_initialization( @@ -2004,6 +2071,129 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, "{level}}}")? } Statement::RayQuery { .. } => unreachable!(), + Statement::SubgroupBallot { result, predicate } => { + write!(self.out, "{level}")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + write!(self.out, "const uint4 {name} = ")?; + self.named_expressions.insert(result, name); + + write!(self.out, "WaveActiveBallot(")?; + match predicate { + Some(predicate) => self.write_expr(module, predicate, func_ctx)?, + None => write!(self.out, "true")?, + } + writeln!(self.out, ");")?; + } + Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + } => { + write!(self.out, "{level}")?; + write!(self.out, "const ")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + match func_ctx.info[result].ty { + proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, + proc::TypeResolution::Value(ref value) => { + self.write_value_type(module, value)? + } + }; + write!(self.out, " {name} = ")?; + self.named_expressions.insert(result, name); + + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "WaveActiveAllTrue(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "WaveActiveAnyTrue(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "WaveActiveSum(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "WaveActiveProduct(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "WaveActiveMax(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "WaveActiveMin(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "WaveActiveBitAnd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "WaveActiveBitOr(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "WaveActiveBitXor(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "WavePrefixSum(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "WavePrefixProduct(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { + self.write_expr(module, argument, func_ctx)?; + write!(self.out, " + WavePrefixSum(")?; + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { + self.write_expr(module, argument, func_ctx)?; + write!(self.out, " * WavePrefixProduct(")?; + } + _ => unimplemented!(), + } + self.write_expr(module, argument, func_ctx)?; + writeln!(self.out, ");")?; + } + Statement::SubgroupGather { + mode, + argument, + result, + } => { + write!(self.out, "{level}")?; + write!(self.out, "const ")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + match func_ctx.info[result].ty { + proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, + proc::TypeResolution::Value(ref value) => { + self.write_value_type(module, value)? + } + }; + write!(self.out, " {name} = ")?; + self.named_expressions.insert(result, name); + + if matches!(mode, crate::GatherMode::BroadcastFirst) { + write!(self.out, "WaveReadLaneFirst(")?; + self.write_expr(module, argument, func_ctx)?; + } else { + write!(self.out, "WaveReadLaneAt(")?; + self.write_expr(module, argument, func_ctx)?; + write!(self.out, ", ")?; + match mode { + crate::GatherMode::BroadcastFirst => unreachable!(), + crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) => { + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleDown(index) => { + write!(self.out, "WaveGetLaneIndex() + ")?; + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleUp(index) => { + write!(self.out, "WaveGetLaneIndex() - ")?; + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleXor(index) => { + write!(self.out, "WaveGetLaneIndex() ^ ")?; + self.write_expr(module, index, func_ctx)?; + } + } + } + writeln!(self.out, ");")?; + } } Ok(()) @@ -3152,7 +3342,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Expression::CallResult(_) | Expression::AtomicResult { .. } | Expression::WorkGroupUniformLoadResult { .. } - | Expression::RayQueryProceedResult => {} + | Expression::RayQueryProceedResult + | Expression::SubgroupBallotResult + | Expression::SubgroupOperationResult { .. } => {} } if !closing_bracket.is_empty() { @@ -3219,6 +3411,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { if barrier.contains(crate::Barrier::WORK_GROUP) { writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?; } + if barrier.contains(crate::Barrier::SUB_GROUP) { + // Does not exist in DirectX + } Ok(()) } } diff --git a/src/back/msl/mod.rs b/src/back/msl/mod.rs index 5ef18730c9..eee825a83b 100644 --- a/src/back/msl/mod.rs +++ b/src/back/msl/mod.rs @@ -437,6 +437,11 @@ impl ResolvedBinding { Bi::WorkGroupId => "threadgroup_position_in_grid", Bi::WorkGroupSize => "dispatch_threads_per_threadgroup", Bi::NumWorkGroups => "threadgroups_per_grid", + // subgroup + Bi::NumSubgroups => "simdgroups_per_threadgroup", + Bi::SubgroupId => "simdgroup_index_in_threadgroup", + Bi::SubgroupSize => "threads_per_simdgroup", + Bi::SubgroupInvocationId => "thread_index_in_simdgroup", Bi::CullDistance | Bi::ViewIndex => { return Err(Error::UnsupportedBuiltIn(built_in)) } diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 09f7b1c73f..1b0791b573 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -1935,6 +1935,8 @@ impl Writer { crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } + | crate::Expression::SubgroupBallotResult + | crate::Expression::SubgroupOperationResult { .. } | crate::Expression::RayQueryProceedResult => { unreachable!() } @@ -3010,6 +3012,121 @@ impl Writer { } } } + crate::Statement::SubgroupBallot { result, predicate } => { + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + write!(self.out, "uint4((uint64_t){NAMESPACE}::simd_ballot(")?; + if let Some(predicate) = predicate { + self.put_expression(predicate, &context.expression, true)?; + } else { + write!(self.out, "true")?; + } + writeln!(self.out, "), 0, 0, 0);")?; + } + crate::Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + } => { + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "{NAMESPACE}::simd_all(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "{NAMESPACE}::simd_any(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "{NAMESPACE}::simd_sum(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "{NAMESPACE}::simd_product(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "{NAMESPACE}::simd_max(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "{NAMESPACE}::simd_min(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "{NAMESPACE}::simd_and(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "{NAMESPACE}::simd_or(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "{NAMESPACE}::simd_xor(")? + } + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Add, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?, + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Mul, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?, + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Add, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?, + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Mul, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?, + _ => unimplemented!(), + } + self.put_expression(argument, &context.expression, true)?; + writeln!(self.out, ");")?; + } + crate::Statement::SubgroupGather { + mode, + argument, + result, + } => { + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?; + } + crate::GatherMode::Broadcast(_) => { + write!(self.out, "{NAMESPACE}::simd_broadcast(")?; + } + crate::GatherMode::Shuffle(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle(")?; + } + crate::GatherMode::ShuffleDown(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?; + } + crate::GatherMode::ShuffleUp(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?; + } + crate::GatherMode::ShuffleXor(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?; + } + } + self.put_expression(argument, &context.expression, true)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + write!(self.out, ", ")?; + self.put_expression(index, &context.expression, true)?; + } + } + writeln!(self.out, ");")?; + } } } @@ -4346,6 +4463,12 @@ impl Writer { "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);", )?; } + if flags.contains(crate::Barrier::SUB_GROUP) { + writeln!( + self.out, + "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);", + )?; + } Ok(()) } } @@ -4616,8 +4739,8 @@ fn test_stack_size() { } let stack_size = addresses_end - addresses_start; // check the size (in debug only) - // last observed macOS value: 19152 (CI) - if !(9000..=20000).contains(&stack_size) { + // last observed macOS value: 22256 (CI) + if !(15000..=25000).contains(&stack_size) { panic!("`put_block` stack size {stack_size} has changed!"); } } diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 0471d957f0..50883ce071 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -1130,7 +1130,9 @@ impl<'w> BlockContext<'w> { crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } - | crate::Expression::RayQueryProceedResult => self.cached[expr_handle], + | crate::Expression::RayQueryProceedResult + | crate::Expression::SubgroupBallotResult + | crate::Expression::SubgroupOperationResult { .. } => self.cached[expr_handle], crate::Expression::As { expr, kind, @@ -2338,6 +2340,27 @@ impl<'w> BlockContext<'w> { crate::Statement::RayQuery { query, ref fun } => { self.write_ray_query_function(query, fun, &mut block); } + crate::Statement::SubgroupBallot { + result, + ref predicate, + } => { + self.write_subgroup_ballot(predicate, result, &mut block)?; + } + crate::Statement::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?; + } + crate::Statement::SubgroupGather { + ref mode, + argument, + result, + } => { + self.write_subgroup_gather(mode, argument, result, &mut block)?; + } } } diff --git a/src/back/spv/instructions.rs b/src/back/spv/instructions.rs index b963793ad3..5f7c6b34fd 100644 --- a/src/back/spv/instructions.rs +++ b/src/back/spv/instructions.rs @@ -1037,6 +1037,73 @@ impl super::Instruction { instruction.add_operand(semantics_id); instruction } + + // Group Instructions + + pub(super) fn group_non_uniform_ballot( + result_type_id: Word, + id: Word, + exec_scope_id: Word, + predicate: Word, + ) -> Self { + let mut instruction = Self::new(Op::GroupNonUniformBallot); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(predicate); + + instruction + } + pub(super) fn group_non_uniform_broadcast_first( + result_type_id: Word, + id: Word, + exec_scope_id: Word, + value: Word, + ) -> Self { + let mut instruction = Self::new(Op::GroupNonUniformBroadcastFirst); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(value); + + instruction + } + pub(super) fn group_non_uniform_gather( + op: Op, + result_type_id: Word, + id: Word, + exec_scope_id: Word, + value: Word, + index: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(value); + instruction.add_operand(index); + + instruction + } + pub(super) fn group_non_uniform_arithmetic( + op: Op, + result_type_id: Word, + id: Word, + exec_scope_id: Word, + group_op: Option, + value: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + if let Some(group_op) = group_op { + instruction.add_operand(group_op as u32); + } + instruction.add_operand(value); + + instruction + } } impl From for spirv::ImageFormat { diff --git a/src/back/spv/mod.rs b/src/back/spv/mod.rs index ac7281fc6b..a7a4bd302c 100644 --- a/src/back/spv/mod.rs +++ b/src/back/spv/mod.rs @@ -13,6 +13,7 @@ mod layout; mod ray; mod recyclable; mod selection; +mod subgroup; mod writer; pub use spirv::Capability; diff --git a/src/back/spv/subgroup.rs b/src/back/spv/subgroup.rs new file mode 100644 index 0000000000..79db752a6c --- /dev/null +++ b/src/back/spv/subgroup.rs @@ -0,0 +1,204 @@ +use super::{Block, BlockContext, Error, Instruction}; +use crate::{ + arena::Handle, + back::spv::{LocalType, LookupType}, + TypeInner, +}; + +impl<'w> BlockContext<'w> { + pub(super) fn write_subgroup_ballot( + &mut self, + predicate: &Option>, + result: Handle, + block: &mut Block, + ) -> Result<(), Error> { + self.writer.require_any( + "GroupNonUniformBallot", + &[spirv::Capability::GroupNonUniformBallot], + )?; + let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Quad), + kind: crate::ScalarKind::Uint, + width: 4, + pointer_space: None, + })); + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + let predicate = if let Some(predicate) = *predicate { + self.cached[predicate] + } else { + self.writer.get_constant_scalar(crate::Literal::Bool(true)) + }; + let id = self.gen_id(); + block.body.push(Instruction::group_non_uniform_ballot( + vec4_u32_type_id, + id, + exec_scope_id, + predicate, + )); + self.cached[result] = id; + Ok(()) + } + pub(super) fn write_subgroup_operation( + &mut self, + op: &crate::SubgroupOperation, + collective_op: &crate::CollectiveOperation, + argument: Handle, + result: Handle, + block: &mut Block, + ) -> Result<(), Error> { + use crate::SubgroupOperation as sg; + match *op { + sg::All | sg::Any => { + self.writer.require_any( + "GroupNonUniformVote", + &[spirv::Capability::GroupNonUniformVote], + )?; + } + _ => { + self.writer.require_any( + "GroupNonUniformArithmetic", + &[spirv::Capability::GroupNonUniformArithmetic], + )?; + } + } + + let id = self.gen_id(); + let result_ty = &self.fun_info[result].ty; + let result_type_id = self.get_expression_type_id(result_ty); + let result_ty_inner = result_ty.inner_with(&self.ir_module.types); + + let (is_scalar, kind) = match *result_ty_inner { + TypeInner::Scalar { kind, .. } => (true, kind), + TypeInner::Vector { kind, .. } => (false, kind), + _ => unimplemented!(), + }; + + use crate::ScalarKind as sk; + let spirv_op = match (kind, *op) { + (sk::Bool, sg::All) if is_scalar => spirv::Op::GroupNonUniformAll, + (sk::Bool, sg::Any) if is_scalar => spirv::Op::GroupNonUniformAny, + (_, sg::All | sg::Any) => unimplemented!(), + + (sk::Sint | sk::Uint, sg::Add) => spirv::Op::GroupNonUniformIAdd, + (sk::Float, sg::Add) => spirv::Op::GroupNonUniformFAdd, + (sk::Sint | sk::Uint, sg::Mul) => spirv::Op::GroupNonUniformIMul, + (sk::Float, sg::Mul) => spirv::Op::GroupNonUniformFMul, + (sk::Sint, sg::Max) => spirv::Op::GroupNonUniformSMax, + (sk::Uint, sg::Max) => spirv::Op::GroupNonUniformUMax, + (sk::Float, sg::Max) => spirv::Op::GroupNonUniformFMax, + (sk::Sint, sg::Min) => spirv::Op::GroupNonUniformSMin, + (sk::Uint, sg::Min) => spirv::Op::GroupNonUniformUMin, + (sk::Float, sg::Min) => spirv::Op::GroupNonUniformFMin, + (sk::Bool, sg::Add | sg::Mul | sg::Min | sg::Max) => unimplemented!(), + + (sk::Sint | sk::Uint, sg::And) => spirv::Op::GroupNonUniformBitwiseAnd, + (sk::Sint | sk::Uint, sg::Or) => spirv::Op::GroupNonUniformBitwiseOr, + (sk::Sint | sk::Uint, sg::Xor) => spirv::Op::GroupNonUniformBitwiseXor, + (sk::Float, sg::And | sg::Or | sg::Xor) => unimplemented!(), + (sk::Bool, sg::And) => spirv::Op::GroupNonUniformLogicalAnd, + (sk::Bool, sg::Or) => spirv::Op::GroupNonUniformLogicalOr, + (sk::Bool, sg::Xor) => spirv::Op::GroupNonUniformLogicalXor, + }; + + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + + use crate::CollectiveOperation as c; + let group_op = match *op { + sg::All | sg::Any => None, + _ => Some(match *collective_op { + c::Reduce => spirv::GroupOperation::Reduce, + c::InclusiveScan => spirv::GroupOperation::InclusiveScan, + c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan, + }), + }; + + let arg_id = self.cached[argument]; + block.body.push(Instruction::group_non_uniform_arithmetic( + spirv_op, + result_type_id, + id, + exec_scope_id, + group_op, + arg_id, + )); + self.cached[result] = id; + Ok(()) + } + pub(super) fn write_subgroup_gather( + &mut self, + mode: &crate::GatherMode, + argument: Handle, + result: Handle, + block: &mut Block, + ) -> Result<(), Error> { + self.writer.require_any( + "GroupNonUniformBallot", + &[spirv::Capability::GroupNonUniformBallot], + )?; + match *mode { + crate::GatherMode::BroadcastFirst | crate::GatherMode::Broadcast(_) => { + self.writer.require_any( + "GroupNonUniformBallot", + &[spirv::Capability::GroupNonUniformBallot], + )?; + } + crate::GatherMode::Shuffle(_) | crate::GatherMode::ShuffleXor(_) => { + self.writer.require_any( + "GroupNonUniformShuffle", + &[spirv::Capability::GroupNonUniformShuffle], + )?; + } + crate::GatherMode::ShuffleDown(_) | crate::GatherMode::ShuffleUp(_) => { + self.writer.require_any( + "GroupNonUniformShuffleRelative", + &[spirv::Capability::GroupNonUniformShuffleRelative], + )?; + } + } + + let id = self.gen_id(); + let result_ty = &self.fun_info[result].ty; + let result_type_id = self.get_expression_type_id(result_ty); + + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + + let arg_id = self.cached[argument]; + match *mode { + crate::GatherMode::BroadcastFirst => { + block + .body + .push(Instruction::group_non_uniform_broadcast_first( + result_type_id, + id, + exec_scope_id, + arg_id, + )); + } + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + let index_id = self.cached[index]; + let op = match *mode { + crate::GatherMode::BroadcastFirst => unreachable!(), + crate::GatherMode::Broadcast(_) => spirv::Op::GroupNonUniformBroadcast, + crate::GatherMode::Shuffle(_) => spirv::Op::GroupNonUniformShuffle, + crate::GatherMode::ShuffleDown(_) => spirv::Op::GroupNonUniformShuffleDown, + crate::GatherMode::ShuffleUp(_) => spirv::Op::GroupNonUniformShuffleUp, + crate::GatherMode::ShuffleXor(_) => spirv::Op::GroupNonUniformShuffleXor, + }; + block.body.push(Instruction::group_non_uniform_gather( + op, + result_type_id, + id, + exec_scope_id, + arg_id, + index_id, + )); + } + } + self.cached[result] = id; + Ok(()) + } +} diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index 24cb14a161..da0fdf766f 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -1314,7 +1314,11 @@ impl Writer { spirv::MemorySemantics::WORKGROUP_MEMORY, flags.contains(crate::Barrier::WORK_GROUP), ); - let exec_scope_id = self.get_index_constant(spirv::Scope::Workgroup as u32); + let exec_scope_id = if flags.contains(crate::Barrier::SUB_GROUP) { + self.get_index_constant(spirv::Scope::Subgroup as u32) + } else { + self.get_index_constant(spirv::Scope::Workgroup as u32) + }; let mem_scope_id = self.get_index_constant(memory_scope as u32); let semantics_id = self.get_index_constant(semantics.bits()); block.body.push(Instruction::control_barrier( @@ -1589,6 +1593,41 @@ impl Writer { Bi::WorkGroupId => BuiltIn::WorkgroupId, Bi::WorkGroupSize => BuiltIn::WorkgroupSize, Bi::NumWorkGroups => BuiltIn::NumWorkgroups, + // Subgroup + Bi::NumSubgroups => { + self.require_any( + "`num_subgroups` built-in", + &[spirv::Capability::GroupNonUniform], + )?; + BuiltIn::NumSubgroups + } + Bi::SubgroupId => { + self.require_any( + "`subgroup_id` built-in", + &[spirv::Capability::GroupNonUniform], + )?; + BuiltIn::SubgroupId + } + Bi::SubgroupSize => { + self.require_any( + "`subgroup_size` built-in", + &[ + spirv::Capability::GroupNonUniform, + spirv::Capability::SubgroupBallotKHR, + ], + )?; + BuiltIn::SubgroupSize + } + Bi::SubgroupInvocationId => { + self.require_any( + "`subgroup_invocation_id` built-in", + &[ + spirv::Capability::GroupNonUniform, + spirv::Capability::SubgroupBallotKHR, + ], + )?; + BuiltIn::SubgroupLocalInvocationId + } }; self.decorate(id, Decoration::BuiltIn, &[built_in as u32]); diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 075d85558c..0bc2dfceb0 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -919,8 +919,124 @@ impl Writer { if barrier.contains(crate::Barrier::WORK_GROUP) { writeln!(self.out, "{level}workgroupBarrier();")?; } + + if barrier.contains(crate::Barrier::SUB_GROUP) { + writeln!(self.out, "{level}subgroupBarrier();")?; + } } Statement::RayQuery { .. } => unreachable!(), + Statement::SubgroupBallot { result, predicate } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + writeln!(self.out, "subgroupBallot(")?; + if let Some(predicate) = predicate { + self.write_expr(module, predicate, func_ctx)?; + } + writeln!(self.out, ");")?; + } + Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "subgroupAll(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "subgroupAny(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupAdd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupMul(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "subgroupMax(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "subgroupMin(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "subgroupAnd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "subgroupOr(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "subgroupXor(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupPrefixExclusiveAdd(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupPrefixExclusiveMul(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupPrefixInclusiveAdd(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupPrefixInclusiveMul(")? + } + _ => unimplemented!(), + } + self.write_expr(module, argument, func_ctx)?; + writeln!(self.out, ");")?; + } + Statement::SubgroupGather { + mode, + argument, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "subgroupBroadcastFirst(")?; + } + crate::GatherMode::Broadcast(_) => { + write!(self.out, "subgroupBroadcast(")?; + } + crate::GatherMode::Shuffle(_) => { + write!(self.out, "subgroupShuffle(")?; + } + crate::GatherMode::ShuffleDown(_) => { + write!(self.out, "subgroupShuffleDown(")?; + } + crate::GatherMode::ShuffleUp(_) => { + write!(self.out, "subgroupShuffleUp(")?; + } + crate::GatherMode::ShuffleXor(_) => { + write!(self.out, "subgroupShuffleXor(")?; + } + } + self.write_expr(module, argument, func_ctx)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + write!(self.out, ", ")?; + self.write_expr(module, index, func_ctx)?; + } + } + writeln!(self.out, ");")?; + } } Ok(()) @@ -1659,6 +1775,8 @@ impl Writer { Expression::CallResult(_) | Expression::AtomicResult { .. } | Expression::RayQueryProceedResult + | Expression::SubgroupBallotResult + | Expression::SubgroupOperationResult { .. } | Expression::WorkGroupUniformLoadResult { .. } => {} } @@ -1760,6 +1878,10 @@ fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> { Bi::SampleMask => "sample_mask", Bi::PrimitiveIndex => "primitive_index", Bi::ViewIndex => "view_index", + Bi::NumSubgroups => "num_subgroups", + Bi::SubgroupId => "subgroup_id", + Bi::SubgroupSize => "subgroup_size", + Bi::SubgroupInvocationId => "subgroup_invocation_id", Bi::BaseInstance | Bi::BaseVertex | Bi::ClipDistance diff --git a/src/compact/expressions.rs b/src/compact/expressions.rs index c1326e92be..1533362407 100644 --- a/src/compact/expressions.rs +++ b/src/compact/expressions.rs @@ -55,6 +55,7 @@ impl<'tracer> ExpressionTracer<'tracer> { | Ex::GlobalVariable(_) | Ex::LocalVariable(_) | Ex::CallResult(_) + | Ex::SubgroupBallotResult // FIXME: ??? | Ex::RayQueryProceedResult => {} Ex::Constant(handle) => { @@ -156,6 +157,7 @@ impl<'tracer> ExpressionTracer<'tracer> { Ex::AtomicResult { ty, comparison: _ } => self.trace_type(ty), Ex::WorkGroupUniformLoadResult { ty } => self.trace_type(ty), Ex::ArrayLength(expr) => work_list.push(expr), + Ex::SubgroupOperationResult { ty } => self.trace_type(ty), Ex::RayQueryGetIntersection { query, committed: _, @@ -222,6 +224,7 @@ impl ModuleMap { | Ex::GlobalVariable(_) | Ex::LocalVariable(_) | Ex::CallResult(_) + | Ex::SubgroupBallotResult // FIXME: ??? | Ex::RayQueryProceedResult => {} // Expressions that contain handles that need to be adjusted. @@ -349,6 +352,7 @@ impl ModuleMap { comparison: _, } => self.types.adjust(ty), Ex::WorkGroupUniformLoadResult { ref mut ty } => self.types.adjust(ty), + Ex::SubgroupOperationResult { ref mut ty } => self.types.adjust(ty), Ex::ArrayLength(ref mut expr) => adjust(expr), Ex::RayQueryGetIntersection { ref mut query, diff --git a/src/compact/statements.rs b/src/compact/statements.rs index 4c62771023..37074c4299 100644 --- a/src/compact/statements.rs +++ b/src/compact/statements.rs @@ -95,6 +95,37 @@ impl FunctionTracer<'_> { self.trace_expression(query); self.trace_ray_query_function(fun); } + St::SubgroupBallot { result, predicate } => { + if let Some(predicate) = predicate { + self.trace_expression(predicate); + } + self.trace_expression(result); + } + St::SubgroupCollectiveOperation { + op: _, + collective_op: _, + argument, + result, + } => { + self.trace_expression(argument); + self.trace_expression(result); + } + St::SubgroupGather { + mode, + argument, + result, + } => { + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => self.trace_expression(index), + } + self.trace_expression(argument); + self.trace_expression(result); + } // Trivial statements. St::Break @@ -244,6 +275,40 @@ impl FunctionMap { adjust(query); self.adjust_ray_query_function(fun); } + St::SubgroupBallot { + ref mut result, + ref mut predicate, + } => { + if let Some(ref mut predicate) = *predicate { + adjust(predicate); + } + adjust(result); + } + St::SubgroupCollectiveOperation { + op: _, + collective_op: _, + ref mut argument, + ref mut result, + } => { + adjust(argument); + adjust(result); + } + St::SubgroupGather { + ref mut mode, + ref mut argument, + ref mut result, + } => { + match *mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(ref mut index) + | crate::GatherMode::Shuffle(ref mut index) + | crate::GatherMode::ShuffleDown(ref mut index) + | crate::GatherMode::ShuffleUp(ref mut index) + | crate::GatherMode::ShuffleXor(ref mut index) => adjust(index), + } + adjust(argument); + adjust(result); + } // Trivial statements. St::Break diff --git a/src/front/spv/error.rs b/src/front/spv/error.rs index 2f9bf2d1bc..8508ede042 100644 --- a/src/front/spv/error.rs +++ b/src/front/spv/error.rs @@ -54,6 +54,8 @@ pub enum Error { UnknownBinaryOperator(spirv::Op), #[error("unknown relational function {0:?}")] UnknownRelationalFunction(spirv::Op), + #[error("unsupported group opeation %{0}")] + UnsupportedGroupOperation(spirv::Word), #[error("invalid parameter {0:?}")] InvalidParameter(spirv::Op), #[error("invalid operand count {1} for {0:?}")] @@ -116,8 +118,8 @@ pub enum Error { FunctionCallCycle(spirv::Word), #[error("invalid array size {0:?}")] InvalidArraySize(Handle), - #[error("invalid barrier scope %{0}")] - InvalidBarrierScope(spirv::Word), + #[error("invalid execution scope %{0}")] + InvalidExecutionScope(spirv::Word), #[error("invalid barrier memory semantics %{0}")] InvalidBarrierMemorySemantics(spirv::Word), #[error( diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index 083205a45b..e7c082c6a5 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -3650,7 +3650,7 @@ impl> Frontend { let semantics_const = self.lookup_constant.lookup(semantics_id)?; let exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) - .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; + .ok_or(Error::InvalidExecutionScope(exec_scope_id))?; let semantics = resolve_constant(ctx.gctx(), semantics_const.handle) .ok_or(Error::InvalidBarrierMemorySemantics(semantics_id))?; @@ -3691,6 +3691,209 @@ impl> Frontend { }, ); } + Op::GroupNonUniformBallot => { + inst.expect(4)?; + let _result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let predicate_id = self.next()?; + + let result_lookup = self.lookup_expression.lookup(result_id)?; + let result_handle = get_expr_handle!(result_id, result_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidExecutionScope(exec_scope_id))?; + + let predicate = if self + .lookup_constant + .lookup(predicate_id) + .ok() + .filter(|predicate_const| { + matches!( + ctx.gctx().const_expressions + [ctx.gctx().constants[predicate_const.handle].init], + crate::Expression::Literal(crate::Literal::Bool(true)) + ) + }) + .is_some() + { + None + } else { + let predicate_lookup = self.lookup_expression.lookup(predicate_id)?; + let predicate_handle = get_expr_handle!(predicate_id, predicate_lookup); + Some(predicate_handle) + }; + + block.push( + crate::Statement::SubgroupBallot { + result: result_handle, + predicate, + }, + span, + ); + } + spirv::Op::GroupNonUniformAll + | spirv::Op::GroupNonUniformAny + | spirv::Op::GroupNonUniformIAdd + | spirv::Op::GroupNonUniformFAdd + | spirv::Op::GroupNonUniformIMul + | spirv::Op::GroupNonUniformFMul + | spirv::Op::GroupNonUniformSMax + | spirv::Op::GroupNonUniformUMax + | spirv::Op::GroupNonUniformFMax + | spirv::Op::GroupNonUniformSMin + | spirv::Op::GroupNonUniformUMin + | spirv::Op::GroupNonUniformFMin + | spirv::Op::GroupNonUniformBitwiseAnd + | spirv::Op::GroupNonUniformBitwiseOr + | spirv::Op::GroupNonUniformBitwiseXor + | spirv::Op::GroupNonUniformLogicalAnd + | spirv::Op::GroupNonUniformLogicalOr + | spirv::Op::GroupNonUniformLogicalXor => { + inst.expect( + if matches!( + inst.op, + spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny + ) { + 4 + } else { + 5 + }, + )?; + let _result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let collective_op_id = match inst.op { + spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny => { + crate::CollectiveOperation::Reduce + } + _ => { + let group_op_id = self.next()?; + match spirv::GroupOperation::from_u32(group_op_id) { + Some(spirv::GroupOperation::Reduce) => { + crate::CollectiveOperation::Reduce + } + Some(spirv::GroupOperation::InclusiveScan) => { + crate::CollectiveOperation::InclusiveScan + } + Some(spirv::GroupOperation::ExclusiveScan) => { + crate::CollectiveOperation::ExclusiveScan + } + _ => return Err(Error::UnsupportedGroupOperation(group_op_id)), + } + } + }; + let argument_id = self.next()?; + + let result_lookup = self.lookup_expression.lookup(result_id)?; + let result_handle = get_expr_handle!(result_id, result_lookup); + let argument_lookup = self.lookup_expression.lookup(argument_id)?; + let argument_handle = get_expr_handle!(argument_id, argument_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidExecutionScope(exec_scope_id))?; + + let op_id = match inst.op { + spirv::Op::GroupNonUniformAll => crate::SubgroupOperation::All, + spirv::Op::GroupNonUniformAny => crate::SubgroupOperation::Any, + spirv::Op::GroupNonUniformIAdd | spirv::Op::GroupNonUniformFAdd => { + crate::SubgroupOperation::Add + } + spirv::Op::GroupNonUniformIMul | spirv::Op::GroupNonUniformFMul => { + crate::SubgroupOperation::Mul + } + spirv::Op::GroupNonUniformSMax + | spirv::Op::GroupNonUniformUMax + | spirv::Op::GroupNonUniformFMax => crate::SubgroupOperation::Max, + spirv::Op::GroupNonUniformSMin + | spirv::Op::GroupNonUniformUMin + | spirv::Op::GroupNonUniformFMin => crate::SubgroupOperation::Min, + spirv::Op::GroupNonUniformBitwiseAnd + | spirv::Op::GroupNonUniformLogicalAnd => crate::SubgroupOperation::And, + spirv::Op::GroupNonUniformBitwiseOr + | spirv::Op::GroupNonUniformLogicalOr => crate::SubgroupOperation::Or, + spirv::Op::GroupNonUniformBitwiseXor + | spirv::Op::GroupNonUniformLogicalXor => crate::SubgroupOperation::Xor, + _ => unreachable!(), + }; + + block.push( + crate::Statement::SubgroupCollectiveOperation { + result: result_handle, + op: op_id, + collective_op: collective_op_id, + argument: argument_handle, + }, + span, + ); + } + Op::GroupNonUniformBroadcastFirst + | Op::GroupNonUniformBroadcast + | Op::GroupNonUniformShuffle + | Op::GroupNonUniformShuffleDown + | Op::GroupNonUniformShuffleUp + | Op::GroupNonUniformShuffleXor => { + inst.expect( + if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) { + 4 + } else { + 5 + }, + )?; + let _result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let argument_id = self.next()?; + + let result_lookup = self.lookup_expression.lookup(result_id)?; + let result_handle = get_expr_handle!(result_id, result_lookup); + let argument_lookup = self.lookup_expression.lookup(argument_id)?; + let argument_handle = get_expr_handle!(argument_id, argument_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidExecutionScope(exec_scope_id))?; + + let mode = if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) { + crate::GatherMode::BroadcastFirst + } else { + let index_id = self.next()?; + let index_lookup = self.lookup_expression.lookup(index_id)?; + let index_handle = get_expr_handle!(index_id, index_lookup); + match inst.op { + spirv::Op::GroupNonUniformBroadcast => { + crate::GatherMode::Broadcast(index_handle) + } + spirv::Op::GroupNonUniformShuffle => { + crate::GatherMode::Shuffle(index_handle) + } + spirv::Op::GroupNonUniformShuffleDown => { + crate::GatherMode::ShuffleDown(index_handle) + } + spirv::Op::GroupNonUniformShuffleUp => { + crate::GatherMode::ShuffleUp(index_handle) + } + spirv::Op::GroupNonUniformShuffleXor => { + crate::GatherMode::ShuffleXor(index_handle) + } + _ => unreachable!(), + } + }; + + block.push( + crate::Statement::SubgroupGather { + result: result_handle, + mode, + argument: argument_handle, + }, + span, + ); + } _ => return Err(Error::UnsupportedInstruction(self.state, inst.op)), } }; @@ -3811,7 +4014,10 @@ impl> Frontend { | S::Store { .. } | S::ImageStore { .. } | S::Atomic { .. } - | S::RayQuery { .. } => {} + | S::RayQuery { .. } + | S::SubgroupBallot { .. } + | S::SubgroupCollectiveOperation { .. } + | S::SubgroupGather { .. } => {} S::Call { function: ref mut callee, ref arguments, diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index ae178cb702..d3e47f6959 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -1917,6 +1917,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } } else if let Some(fun) = Texture::map(function.name) { self.texture_sample_helper(fun, arguments, span, ctx.reborrow())? + } else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) { + return Ok(Some( + self.subgroup_operation_helper(span, op, cop, arguments, ctx)?, + )); + } else if let Some(mode) = conv::map_subgroup_gather(function.name) { + return Ok(Some( + self.subgroup_gather_helper(span, mode, arguments, ctx)?, + )); } else { match function.name { "select" => { @@ -2085,6 +2093,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span); return Ok(None); } + "subgroupBarrier" => { + ctx.prepare_args(arguments, 0, span).finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(crate::Statement::Barrier(crate::Barrier::SUB_GROUP), span); + return Ok(None); + } "workgroupUniformLoad" => { let mut args = ctx.prepare_args(arguments, 1, span); let expr = args.next()?; @@ -2290,6 +2306,22 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { )?; return Ok(Some(handle)); } + "subgroupBallot" => { + let mut args = ctx.prepare_args(arguments, 0, span); + let predicate = if arguments.len() == 1 { + Some(self.expression(args.next()?, ctx.reborrow())?) + } else { + None + }; + args.finish()?; + + let result = ctx + .interrupt_emitter(crate::Expression::SubgroupBallotResult, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(crate::Statement::SubgroupBallot { result, predicate }, span); + return Ok(Some(result)); + } _ => return Err(Error::UnknownIdent(function.span, function.name)), } }; @@ -2482,6 +2514,76 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }) } + fn subgroup_operation_helper( + &mut self, + span: Span, + op: crate::SubgroupOperation, + collective_op: crate::CollectiveOperation, + arguments: &[Handle>], + mut ctx: ExpressionContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + let mut args = ctx.prepare_args(arguments, 1, span); + + let argument = self.expression(args.next()?, ctx.reborrow())?; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = + ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + }, + span, + ); + Ok(result) + } + + fn subgroup_gather_helper( + &mut self, + span: Span, + mode: crate::GatherMode, + arguments: &[Handle>], + mut ctx: ExpressionContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + let mut args = ctx.prepare_args(arguments, 2, span); + + let argument = self.expression(args.next()?, ctx.reborrow())?; + let index = if let crate::GatherMode::BroadcastFirst = mode { + Handle::new(NonZeroU32::new(u32::MAX).unwrap()) + } else { + self.expression(args.next()?, ctx.reborrow())? + }; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = + ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupGather { + mode: match mode { + crate::GatherMode::BroadcastFirst => crate::GatherMode::BroadcastFirst, + crate::GatherMode::Broadcast(_) => crate::GatherMode::Broadcast(index), + crate::GatherMode::Shuffle(_) => crate::GatherMode::Shuffle(index), + crate::GatherMode::ShuffleDown(_) => crate::GatherMode::ShuffleDown(index), + crate::GatherMode::ShuffleUp(_) => crate::GatherMode::ShuffleUp(index), + crate::GatherMode::ShuffleXor(_) => crate::GatherMode::ShuffleXor(index), + }, + argument, + result, + }, + span, + ); + Ok(result) + } + fn r#struct( &mut self, s: &ast::Struct<'source>, diff --git a/src/front/wgsl/parse/conv.rs b/src/front/wgsl/parse/conv.rs index 51977173d6..c53f4df753 100644 --- a/src/front/wgsl/parse/conv.rs +++ b/src/front/wgsl/parse/conv.rs @@ -34,6 +34,11 @@ pub fn map_built_in(word: &str, span: Span) -> Result> "local_invocation_index" => crate::BuiltIn::LocalInvocationIndex, "workgroup_id" => crate::BuiltIn::WorkGroupId, "num_workgroups" => crate::BuiltIn::NumWorkGroups, + // subgroup + "num_subgroups" => crate::BuiltIn::NumSubgroups, + "subgroup_id" => crate::BuiltIn::SubgroupId, + "subgroup_size" => crate::BuiltIn::SubgroupSize, + "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId, _ => return Err(Error::UnknownBuiltin(span)), }) } @@ -235,3 +240,41 @@ pub fn map_conservative_depth( _ => Err(Error::UnknownConservativeDepth(span)), } } + +pub fn map_subgroup_operation( + word: &str, +) -> Option<(crate::SubgroupOperation, crate::CollectiveOperation)> { + use crate::CollectiveOperation as co; + use crate::SubgroupOperation as sg; + Some(match word { + "subgroupAll" => (sg::All, co::Reduce), + "subgroupAny" => (sg::Any, co::Reduce), + "subgroupAdd" => (sg::Add, co::Reduce), + "subgroupMul" => (sg::Mul, co::Reduce), + "subgroupMin" => (sg::Min, co::Reduce), + "subgroupMax" => (sg::Max, co::Reduce), + "subgroupAnd" => (sg::And, co::Reduce), + "subgroupOr" => (sg::Or, co::Reduce), + "subgroupXor" => (sg::Xor, co::Reduce), + "subgroupPrefixExclusiveAdd" => (sg::Add, co::ExclusiveScan), + "subgroupPrefixExclusiveMul" => (sg::Mul, co::ExclusiveScan), + "subgroupPrefixInclusiveAdd" => (sg::Add, co::InclusiveScan), + "subgroupPrefixInclusiveMul" => (sg::Mul, co::InclusiveScan), + _ => return None, + }) +} + +pub fn map_subgroup_gather(word: &str) -> Option { + use crate::GatherMode as gm; + use crate::Handle; + use std::num::NonZeroU32; + Some(match word { + "subgroupBroadcastFirst" => gm::BroadcastFirst, + "subgroupBroadcast" => gm::Broadcast(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), + "subgroupShuffle" => gm::Shuffle(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), + "subgroupShuffleDown" => gm::ShuffleDown(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), + "subgroupShuffleUp" => gm::ShuffleUp(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), + "subgroupShuffleXor" => gm::ShuffleXor(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), + _ => return None, + }) +} diff --git a/src/lib.rs b/src/lib.rs index 300c6e4820..c754405ed7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -435,6 +435,11 @@ pub enum BuiltIn { WorkGroupId, WorkGroupSize, NumWorkGroups, + // subgroup + NumSubgroups, + SubgroupId, + SubgroupSize, + SubgroupInvocationId, } /// Number of bytes per scalar. @@ -1258,6 +1263,46 @@ pub enum SwizzleComponent { W = 3, } +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum GatherMode { + BroadcastFirst, + Broadcast(Handle), + Shuffle(Handle), + ShuffleDown(Handle), + ShuffleUp(Handle), + ShuffleXor(Handle), +} + +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum SubgroupOperation { + All, + Any, + Add, + Mul, + Min, + Max, + And, + Or, + Xor, +} + +#[repr(u8)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum CollectiveOperation { + Reduce = 0, + InclusiveScan = 1, + ExclusiveScan = 2, +} + bitflags::bitflags! { /// Memory barrier flags. #[cfg_attr(feature = "serialize", derive(Serialize))] @@ -1266,9 +1311,11 @@ bitflags::bitflags! { #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] pub struct Barrier: u32 { /// Barrier affects all `AddressSpace::Storage` accesses. - const STORAGE = 0x1; + const STORAGE = 1 << 0; /// Barrier affects all `AddressSpace::WorkGroup` accesses. - const WORK_GROUP = 0x2; + const WORK_GROUP = 1 << 1; + /// Barrier synchronizes execution across all invocations within a subgroup that exectue this instruction. + const SUB_GROUP = 1 << 2; } } @@ -1399,7 +1446,9 @@ pub enum Expression { /// /// For [`TypeInner::Atomic`] the result is a corresponding scalar. /// For other types behind the `pointer`, the result is `T`. - Load { pointer: Handle }, + Load { + pointer: Handle, + }, /// Sample a point from a sampled or a depth image. ImageSample { image: Handle, @@ -1539,7 +1588,10 @@ pub enum Expression { /// Result of calling another function. CallResult(Handle), /// Result of an atomic operation. - AtomicResult { ty: Handle, comparison: bool }, + AtomicResult { + ty: Handle, + comparison: bool, + }, /// Result of a [`WorkGroupUniformLoad`] statement. /// /// [`WorkGroupUniformLoad`]: Statement::WorkGroupUniformLoad @@ -1567,6 +1619,10 @@ pub enum Expression { query: Handle, committed: bool, }, + SubgroupBallotResult, + SubgroupOperationResult { + ty: Handle, + }, } pub use block::Block; @@ -1839,6 +1895,40 @@ pub enum Statement { /// The specific operation we're performing on `query`. fun: RayQueryFunction, }, + // subgroupBallot(bool) -> vec4 + SubgroupBallot { + /// The [`SubgroupBallotResult`] expression representing this load's result. + /// + /// [`SubgroupBallotResult`]: Expression::SubgroupBallotResult + result: Handle, + /// The value from this thread to store in the ballot + predicate: Option>, + }, + + SubgroupGather { + /// Specifies which thread to gather from + mode: GatherMode, + /// The value to broadcast over + argument: Handle, + /// The [`SubgroupOperationResult`] expression representing this load's result. + /// + /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult + result: Handle, + }, + + /// Compute a collective operation across all active threads in th subgroup + SubgroupCollectiveOperation { + /// What operation to compute + op: SubgroupOperation, + /// How to combine the results + collective_op: CollectiveOperation, + /// The value to compute over + argument: Handle, + /// The [`SubgroupOperationResult`] expression representing this load's result. + /// + /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult + result: Handle, + }, } /// A function argument. diff --git a/src/proc/constant_evaluator.rs b/src/proc/constant_evaluator.rs index 2082743975..c945805813 100644 --- a/src/proc/constant_evaluator.rs +++ b/src/proc/constant_evaluator.rs @@ -133,6 +133,8 @@ pub enum ConstantEvaluatorError { ImageExpression, #[error("Constants don't support ray query expressions")] RayQueryExpression, + #[error("Constants don't support subgroup expressions")] + SubgroupExpression, #[error("Cannot access the type")] InvalidAccessBase, #[error("Cannot access at the index")] @@ -439,6 +441,12 @@ impl<'a> ConstantEvaluator<'a> { Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => { Err(ConstantEvaluatorError::RayQueryExpression) } + Expression::SubgroupBallotResult { .. } => { + Err(ConstantEvaluatorError::SubgroupExpression) + } + Expression::SubgroupOperationResult { .. } => { + Err(ConstantEvaluatorError::SubgroupExpression) + } } } diff --git a/src/proc/terminator.rs b/src/proc/terminator.rs index a5239d4eca..5edf55cb73 100644 --- a/src/proc/terminator.rs +++ b/src/proc/terminator.rs @@ -37,6 +37,9 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::RayQuery { .. } | S::Atomic { .. } | S::WorkGroupUniformLoad { .. } + | S::SubgroupBallot { .. } + | S::SubgroupCollectiveOperation { .. } + | S::SubgroupGather { .. } | S::Barrier(_)), ) | None => block.push(S::Return { value: None }, Default::default()), diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index ad9eec94d2..c2b38ac73b 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -638,6 +638,7 @@ impl<'a> ResolveContext<'a> { | crate::BinaryOperator::ShiftRight => past(left)?.clone(), }, crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty), + crate::Expression::SubgroupOperationResult { ty } => TypeResolution::Handle(ty), crate::Expression::WorkGroupUniformLoadResult { ty } => TypeResolution::Handle(ty), crate::Expression::Select { accept, .. } => past(accept)?.clone(), crate::Expression::Derivative { expr, .. } => past(expr)?.clone(), @@ -905,6 +906,11 @@ impl<'a> ResolveContext<'a> { .ok_or(ResolveError::MissingSpecialType)?; TypeResolution::Handle(result) } + crate::Expression::SubgroupBallotResult => TypeResolution::Value(Ti::Vector { + kind: crate::ScalarKind::Uint, + size: crate::VectorSize::Quad, + width: 4, + }), }) } } diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index ff1db071c8..f4347df1dd 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -740,6 +740,14 @@ impl FunctionInfo { non_uniform_result: self.add_ref(query), requirements: UniformityRequirements::empty(), }, + E::SubgroupBallotResult => Uniformity { + non_uniform_result: None, // FIXME + requirements: UniformityRequirements::empty(), + }, + E::SubgroupOperationResult { .. } => Uniformity { + non_uniform_result: None, // FIXME + requirements: UniformityRequirements::empty(), + }, }; let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?; @@ -983,6 +991,42 @@ impl FunctionInfo { } FunctionUniformity::new() } + S::SubgroupBallot { + result: _, + predicate, + } => { + if let Some(predicate) = predicate { + let _ = self.add_ref(predicate); + } + FunctionUniformity::new() + } + S::SubgroupCollectiveOperation { + op: _, + collective_op: _, + argument, + result: _, + } => { + let _ = self.add_ref(argument); + FunctionUniformity::new() + } + S::SubgroupGather { + mode, + argument, + result: _, + } => { + let _ = self.add_ref(argument); + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + let _ = self.add_ref(index); + } + } + FunctionUniformity::new() + } }; disruptor = disruptor.or(uniformity.exit_disruptor()); diff --git a/src/valid/expression.rs b/src/valid/expression.rs index f77844b4b1..03ad851dbf 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1537,6 +1537,8 @@ impl super::Validator { return Err(ExpressionError::InvalidRayQueryType(query)); } }, + E::SubgroupBallotResult => ShaderStages::COMPUTE | ShaderStages::FRAGMENT, + E::SubgroupOperationResult { .. } => ShaderStages::COMPUTE, // FIXME }; Ok(stages) } diff --git a/src/valid/function.rs b/src/valid/function.rs index d967f4b1f3..729a6405c2 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -51,6 +51,15 @@ pub enum AtomicError { ResultTypeMismatch(Handle), } +#[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum SubgroupError { + #[error("Operand {0:?} has invalid type.")] + InvalidOperand(Handle), + #[error("Result type for {0:?} doesn't match the statement")] + ResultTypeMismatch(Handle), +} + #[derive(Clone, Debug, thiserror::Error)] #[cfg_attr(test, derive(PartialEq))] pub enum LocalVariableError { @@ -159,6 +168,8 @@ pub enum FunctionError { WorkgroupUniformLoadExpressionMismatch(Handle), #[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. It should be a Pointer in Workgroup address space")] WorkgroupUniformLoadInvalidPointer(Handle), + #[error("Subgroup operation is invalid")] + InvalidSubgroup(#[from] SubgroupError), } bitflags::bitflags! { @@ -413,6 +424,113 @@ impl super::Validator { } Ok(()) } + #[cfg(feature = "validate")] + fn validate_subgroup_operation( + &mut self, + op: &crate::SubgroupOperation, + _collective_op: &crate::CollectiveOperation, + argument: Handle, + result: Handle, + context: &BlockContext, + ) -> Result<(), WithSpan> { + let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; + + let (is_scalar, kind) = match *argument_inner { + crate::TypeInner::Scalar { kind, .. } => (true, kind), + crate::TypeInner::Vector { kind, .. } => (false, kind), + _ => { + log::error!("Subgroup operand type {:?}", argument_inner); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); + } + }; + + use crate::ScalarKind as sk; + use crate::SubgroupOperation as sg; + match (kind, *op) { + (sk::Bool, sg::All | sg::Any) if is_scalar => {} + (sk::Sint | sk::Uint | sk::Float, sg::Add | sg::Mul | sg::Min | sg::Max) => {} + (sk::Sint | sk::Uint | sk::Bool, sg::And | sg::Or | sg::Xor) => {} + + (_, sg::All | sg::Any) + | (sk::Bool, sg::Add | sg::Mul | sg::Min | sg::Max) + | (sk::Float, sg::And | sg::Or | sg::Xor) => { + log::error!("Subgroup operand type {:?}", argument_inner); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); + } + }; + + self.emit_expression(result, context)?; + match context.expressions[result] { + crate::Expression::SubgroupOperationResult { ty } + if { &context.types[ty].inner == argument_inner } => {} + _ => { + return Err(SubgroupError::ResultTypeMismatch(result) + .with_span_handle(result, context.expressions) + .into_other()) + } + } + Ok(()) + } + #[cfg(feature = "validate")] + fn validate_subgroup_broadcast( + &mut self, + mode: &crate::GatherMode, + argument: Handle, + result: Handle, + context: &BlockContext, + ) -> Result<(), WithSpan> { + match *mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + let index_ty = context.resolve_type(index, &self.valid_expression_set)?; + match *index_ty { + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + .. + } => {} + _ => { + log::error!( + "Subgroup gather index type {:?}, expected unsigned int", + index_ty + ); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(index, context.expressions) + .into_other()); + } + } + } + } + let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; + if !matches!(*argument_inner, + crate::TypeInner::Scalar { kind, .. } | crate::TypeInner::Vector { kind, .. } + if matches!(kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float) + ) { + log::error!("Subgroup gather operand type {:?}", argument_inner); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); + } + + self.emit_expression(result, context)?; + match context.expressions[result] { + crate::Expression::SubgroupOperationResult { ty } + if { &context.types[ty].inner == argument_inner } => {} + _ => { + return Err(SubgroupError::ResultTypeMismatch(result) + .with_span_handle(result, context.expressions) + .into_other()) + } + } + Ok(()) + } #[cfg(feature = "validate")] fn validate_block_impl( @@ -919,6 +1037,43 @@ impl super::Validator { crate::RayQueryFunction::Terminate => {} } } + S::SubgroupBallot { result, predicate } => { + if let Some(predicate) = predicate { + let predicate_inner = + context.resolve_type(predicate, &self.valid_expression_set)?; + if !matches!( + *predicate_inner, + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Bool, + .. + } + ) { + log::error!( + "Subgroup ballot predicate type {:?} expected bool", + predicate_inner + ); + return Err(SubgroupError::InvalidOperand(predicate) + .with_span_handle(predicate, context.expressions) + .into_other()); + } + } + self.emit_expression(result, context)?; + } + S::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + self.validate_subgroup_operation(op, collective_op, argument, result, context)?; + } + S::SubgroupGather { + ref mode, + argument, + result, + } => { + self.validate_subgroup_broadcast(mode, argument, result, context)?; + } } } Ok(BlockInfo { stages, finished }) diff --git a/src/valid/handles.rs b/src/valid/handles.rs index c68ded074b..e1674f0804 100644 --- a/src/valid/handles.rs +++ b/src/valid/handles.rs @@ -394,6 +394,8 @@ impl super::Validator { } crate::Expression::AtomicResult { .. } | crate::Expression::RayQueryProceedResult + | crate::Expression::SubgroupBallotResult + | crate::Expression::SubgroupOperationResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } => (), crate::Expression::ArrayLength(array) => { handle.check_dep(array)?; @@ -539,6 +541,38 @@ impl super::Validator { } Ok(()) } + crate::Statement::SubgroupBallot { result, predicate } => { + validate_expr_opt(predicate)?; + validate_expr(result)?; + Ok(()) + } + crate::Statement::SubgroupCollectiveOperation { + op: _, + collective_op: _, + argument, + result, + } => { + validate_expr(argument)?; + validate_expr(result)?; + Ok(()) + } + crate::Statement::SubgroupGather { + mode, + argument, + result, + } => { + validate_expr(argument)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => validate_expr(index)?, + } + validate_expr(result)?; + Ok(()) + } crate::Statement::Break | crate::Statement::Continue | crate::Statement::Kill diff --git a/src/valid/interface.rs b/src/valid/interface.rs index 6c41ece81f..4b5f66492c 100644 --- a/src/valid/interface.rs +++ b/src/valid/interface.rs @@ -299,6 +299,25 @@ impl VaryingContext<'_> { width, }, ), + Bi::NumSubgroups | Bi::SubgroupId => ( + self.stage == St::Compute && !self.output, + *ty_inner + == Ti::Scalar { + kind: Sk::Uint, + width, + }, + ), + Bi::SubgroupSize | Bi::SubgroupInvocationId => ( + match self.stage { + St::Compute | St::Fragment => !self.output, + St::Vertex => false, + }, + *ty_inner + == Ti::Scalar { + kind: Sk::Uint, + width, + }, + ), }; if !visible { diff --git a/tests/in/subgroup-operations.param.ron b/tests/in/subgroup-operations.param.ron new file mode 100644 index 0000000000..fc444a3efe --- /dev/null +++ b/tests/in/subgroup-operations.param.ron @@ -0,0 +1,26 @@ +( + spv: ( + version: (1, 3), + ), + msl: ( + lang_version: (2, 4), + per_entry_point_map: {}, + inline_samplers: [], + spirv_cross_compatibility: false, + fake_missing_bindings: false, + zero_initialize_workgroup_memory: true, + ), + glsl: ( + version: Desktop(430), + writer_flags: (""), + binding_map: { }, + zero_initialize_workgroup_memory: true, + ), + hlsl: ( + shader_model: V6_0, + binding_map: {}, + fake_missing_bindings: true, + special_constants_binding: None, + zero_initialize_workgroup_memory: true, + ), +) diff --git a/tests/in/subgroup-operations.wgsl b/tests/in/subgroup-operations.wgsl new file mode 100644 index 0000000000..f30b60be47 --- /dev/null +++ b/tests/in/subgroup-operations.wgsl @@ -0,0 +1,32 @@ +@compute @workgroup_size(1) +fn main( + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(subgroup_invocation_id) subgroup_invocation_id: u32, +) { + subgroupBarrier(); + + subgroupBallot((subgroup_invocation_id & 1u) == 1u); + + subgroupAll(subgroup_invocation_id != 0u); + subgroupAny(subgroup_invocation_id == 0u); + subgroupAdd(subgroup_invocation_id); + subgroupMul(subgroup_invocation_id); + subgroupMin(subgroup_invocation_id); + subgroupMax(subgroup_invocation_id); + subgroupAnd(subgroup_invocation_id); + subgroupOr(subgroup_invocation_id); + subgroupXor(subgroup_invocation_id); + subgroupPrefixExclusiveAdd(subgroup_invocation_id); + subgroupPrefixExclusiveMul(subgroup_invocation_id); + subgroupPrefixInclusiveAdd(subgroup_invocation_id); + subgroupPrefixInclusiveMul(subgroup_invocation_id); + + subgroupBroadcastFirst(subgroup_invocation_id); + subgroupBroadcast(subgroup_invocation_id, 4u); + subgroupShuffle(subgroup_invocation_id, subgroup_size - 1u - subgroup_invocation_id); + subgroupShuffleDown(subgroup_invocation_id, 1u); + subgroupShuffleUp(subgroup_invocation_id, 1u); + subgroupShuffleXor(subgroup_invocation_id, subgroup_size - 1u); +} diff --git a/tests/out/glsl/subgroup-operations.main.Compute.glsl b/tests/out/glsl/subgroup-operations.main.Compute.glsl new file mode 100644 index 0000000000..a37cf8e247 --- /dev/null +++ b/tests/out/glsl/subgroup-operations.main.Compute.glsl @@ -0,0 +1,41 @@ +#version 430 core +#extension GL_ARB_compute_shader : require +#extension GL_KHR_shader_subgroup_basic : require +#extension GL_KHR_shader_subgroup_vote : require +#extension GL_KHR_shader_subgroup_arithmetic : require +#extension GL_KHR_shader_subgroup_ballot : require +#extension GL_KHR_shader_subgroup_shuffle : require +#extension GL_KHR_shader_subgroup_shuffle_relative : require +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + + +void main() { + uint num_subgroups = gl_NumSubgroups; + uint subgroup_id = gl_SubgroupID; + uint subgroup_size = gl_SubgroupSize; + uint subgroup_invocation_id = gl_SubgroupInvocationID; + subgroupMemoryBarrier(); + barrier(); + uvec4 _e8 = subgroupBallot(((subgroup_invocation_id & 1u) == 1u)); + bool _e11 = subgroupAll((subgroup_invocation_id != 0u)); + bool _e14 = subgroupAny((subgroup_invocation_id == 0u)); + uint _e15 = subgroupAdd(subgroup_invocation_id); + uint _e16 = subgroupMul(subgroup_invocation_id); + uint _e17 = subgroupMin(subgroup_invocation_id); + uint _e18 = subgroupMax(subgroup_invocation_id); + uint _e19 = subgroupAnd(subgroup_invocation_id); + uint _e20 = subgroupOr(subgroup_invocation_id); + uint _e21 = subgroupXor(subgroup_invocation_id); + uint _e22 = subgroupExclusiveAdd(subgroup_invocation_id); + uint _e23 = subgroupExclusiveMul(subgroup_invocation_id); + uint _e24 = subgroupInclusiveAdd(subgroup_invocation_id); + uint _e25 = subgroupInclusiveMul(subgroup_invocation_id); + uint _e26 = subgroupBroadcastFirst(subgroup_invocation_id); + uint _e28 = subgroupBroadcast(subgroup_invocation_id, 4u); + uint _e32 = subgroupShuffle(subgroup_invocation_id, ((subgroup_size - 1u) - subgroup_invocation_id)); + uint _e34 = subgroupShuffleDown(subgroup_invocation_id, 1u); + uint _e36 = subgroupShuffleUp(subgroup_invocation_id, 1u); + uint _e39 = subgroupShuffleXor(subgroup_invocation_id, (subgroup_size - 1u)); + return; +} + diff --git a/tests/out/hlsl/subgroup-operations.hlsl b/tests/out/hlsl/subgroup-operations.hlsl new file mode 100644 index 0000000000..baa37826e0 --- /dev/null +++ b/tests/out/hlsl/subgroup-operations.hlsl @@ -0,0 +1,32 @@ +[numthreads(1, 1, 1)] +void main(uint3 __local_invocation_id : SV_GroupThreadID) +{ + if (all(__local_invocation_id == uint3(0u, 0u, 0u))) { + } + GroupMemoryBarrierWithGroupSync(); + const uint num_subgroups = (1u + WaveGetLaneCount() - 1u) / WaveGetLaneCount(); + const uint subgroup_id = (__local_invocation_id.x * 1u + __local_invocation_id.y * 1u + __local_invocation_id.z) / WaveGetLaneCount(); + const uint subgroup_size = WaveGetLaneCount(); + const uint subgroup_invocation_id = WaveGetLaneIndex(); + const uint4 _e8 = WaveActiveBallot(((subgroup_invocation_id & 1u) == 1u)); + const bool _e11 = WaveActiveAllTrue((subgroup_invocation_id != 0u)); + const bool _e14 = WaveActiveAnyTrue((subgroup_invocation_id == 0u)); + const uint _e15 = WaveActiveSum(subgroup_invocation_id); + const uint _e16 = WaveActiveProduct(subgroup_invocation_id); + const uint _e17 = WaveActiveMin(subgroup_invocation_id); + const uint _e18 = WaveActiveMax(subgroup_invocation_id); + const uint _e19 = WaveActiveBitAnd(subgroup_invocation_id); + const uint _e20 = WaveActiveBitOr(subgroup_invocation_id); + const uint _e21 = WaveActiveBitXor(subgroup_invocation_id); + const uint _e22 = WavePrefixSum(subgroup_invocation_id); + const uint _e23 = WavePrefixProduct(subgroup_invocation_id); + const uint _e24 = subgroup_invocation_id + WavePrefixSum(subgroup_invocation_id); + const uint _e25 = subgroup_invocation_id * WavePrefixProduct(subgroup_invocation_id); + const uint _e26 = WaveReadLaneFirst(subgroup_invocation_id); + const uint _e28 = WaveReadLaneAt(subgroup_invocation_id, 4u); + const uint _e32 = WaveReadLaneAt(subgroup_invocation_id, ((subgroup_size - 1u) - subgroup_invocation_id)); + const uint _e34 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() + 1u); + const uint _e36 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() - 1u); + const uint _e39 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() ^ (subgroup_size - 1u)); + return; +} diff --git a/tests/out/hlsl/subgroup-operations.ron b/tests/out/hlsl/subgroup-operations.ron new file mode 100644 index 0000000000..b973fe3da1 --- /dev/null +++ b/tests/out/hlsl/subgroup-operations.ron @@ -0,0 +1,12 @@ +( + vertex:[ + ], + fragment:[ + ], + compute:[ + ( + entry_point:"main", + target_profile:"cs_6_0", + ), + ], +) diff --git a/tests/out/msl/subgroup-operations.msl b/tests/out/msl/subgroup-operations.msl new file mode 100644 index 0000000000..576fa3b84e --- /dev/null +++ b/tests/out/msl/subgroup-operations.msl @@ -0,0 +1,38 @@ +// language: metal2.4 +#include +#include + +using metal::uint; + + +struct main_Input { +}; +kernel void main_( + uint num_subgroups [[simdgroups_per_threadgroup]] +, uint subgroup_id [[simdgroup_index_in_threadgroup]] +, uint subgroup_size [[threads_per_simdgroup]] +, uint subgroup_invocation_id [[thread_index_in_simdgroup]] +) { + metal::simdgroup_barrier(metal::mem_flags::mem_threadgroup); + metal::uint4 unnamed = uint4((uint64_t)metal::simd_ballot((subgroup_invocation_id & 1u) == 1u), 0, 0, 0); + bool unnamed_1 = metal::simd_all(subgroup_invocation_id != 0u); + bool unnamed_2 = metal::simd_any(subgroup_invocation_id == 0u); + uint unnamed_3 = metal::simd_sum(subgroup_invocation_id); + uint unnamed_4 = metal::simd_product(subgroup_invocation_id); + uint unnamed_5 = metal::simd_min(subgroup_invocation_id); + uint unnamed_6 = metal::simd_max(subgroup_invocation_id); + uint unnamed_7 = metal::simd_and(subgroup_invocation_id); + uint unnamed_8 = metal::simd_or(subgroup_invocation_id); + uint unnamed_9 = metal::simd_xor(subgroup_invocation_id); + uint unnamed_10 = metal::simd_prefix_exclusive_sum(subgroup_invocation_id); + uint unnamed_11 = metal::simd_prefix_exclusive_product(subgroup_invocation_id); + uint unnamed_12 = metal::simd_prefix_inclusive_sum(subgroup_invocation_id); + uint unnamed_13 = metal::simd_prefix_inclusive_product(subgroup_invocation_id); + uint unnamed_14 = metal::simd_broadcast_first(subgroup_invocation_id); + uint unnamed_15 = metal::simd_broadcast(subgroup_invocation_id, 4u); + uint unnamed_16 = metal::simd_shuffle(subgroup_invocation_id, (subgroup_size - 1u) - subgroup_invocation_id); + uint unnamed_17 = metal::simd_shuffle_down(subgroup_invocation_id, 1u); + uint unnamed_18 = metal::simd_shuffle_up(subgroup_invocation_id, 1u); + uint unnamed_19 = metal::simd_shuffle_xor(subgroup_invocation_id, subgroup_size - 1u); + return; +} diff --git a/tests/out/spv/subgroup-operations.spvasm b/tests/out/spv/subgroup-operations.spvasm new file mode 100644 index 0000000000..c2023c5473 --- /dev/null +++ b/tests/out/spv/subgroup-operations.spvasm @@ -0,0 +1,73 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 52 +OpCapability Shader +OpCapability GroupNonUniform +OpCapability GroupNonUniformBallot +OpCapability GroupNonUniformVote +OpCapability GroupNonUniformArithmetic +OpCapability GroupNonUniformShuffle +OpCapability GroupNonUniformShuffleRelative +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %15 "main" %6 %9 %11 %13 +OpExecutionMode %15 LocalSize 1 1 1 +OpDecorate %6 BuiltIn NumSubgroups +OpDecorate %9 BuiltIn SubgroupId +OpDecorate %11 BuiltIn SubgroupSize +OpDecorate %13 BuiltIn SubgroupLocalInvocationId +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypeBool +%7 = OpTypePointer Input %3 +%6 = OpVariable %7 Input +%9 = OpVariable %7 Input +%11 = OpVariable %7 Input +%13 = OpVariable %7 Input +%16 = OpTypeFunction %2 +%17 = OpConstant %3 1 +%18 = OpConstant %3 0 +%19 = OpConstant %3 4 +%21 = OpConstant %3 3 +%22 = OpConstant %3 2 +%23 = OpConstant %3 8 +%26 = OpTypeVector %3 4 +%15 = OpFunction %2 None %16 +%5 = OpLabel +%8 = OpLoad %3 %6 +%10 = OpLoad %3 %9 +%12 = OpLoad %3 %11 +%14 = OpLoad %3 %13 +OpBranch %20 +%20 = OpLabel +OpControlBarrier %21 %22 %23 +%24 = OpBitwiseAnd %3 %14 %17 +%25 = OpIEqual %4 %24 %17 +%27 = OpGroupNonUniformBallot %26 %21 %25 +%28 = OpINotEqual %4 %14 %18 +%29 = OpGroupNonUniformAll %4 %21 %28 +%30 = OpIEqual %4 %14 %18 +%31 = OpGroupNonUniformAny %4 %21 %30 +%32 = OpGroupNonUniformIAdd %3 %21 Reduce %14 +%33 = OpGroupNonUniformIMul %3 %21 Reduce %14 +%34 = OpGroupNonUniformUMin %3 %21 Reduce %14 +%35 = OpGroupNonUniformUMax %3 %21 Reduce %14 +%36 = OpGroupNonUniformBitwiseAnd %3 %21 Reduce %14 +%37 = OpGroupNonUniformBitwiseOr %3 %21 Reduce %14 +%38 = OpGroupNonUniformBitwiseXor %3 %21 Reduce %14 +%39 = OpGroupNonUniformIAdd %3 %21 ExclusiveScan %14 +%40 = OpGroupNonUniformIMul %3 %21 ExclusiveScan %14 +%41 = OpGroupNonUniformIAdd %3 %21 InclusiveScan %14 +%42 = OpGroupNonUniformIMul %3 %21 InclusiveScan %14 +%43 = OpGroupNonUniformBroadcastFirst %3 %21 %14 +%44 = OpGroupNonUniformBroadcast %3 %21 %14 %19 +%45 = OpISub %3 %12 %17 +%46 = OpISub %3 %45 %14 +%47 = OpGroupNonUniformShuffle %3 %21 %14 %46 +%48 = OpGroupNonUniformShuffleDown %3 %21 %14 %17 +%49 = OpGroupNonUniformShuffleUp %3 %21 %14 %17 +%50 = OpISub %3 %12 %17 +%51 = OpGroupNonUniformShuffleXor %3 %21 %14 %50 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/subgroup-operations.wgsl b/tests/out/wgsl/subgroup-operations.wgsl new file mode 100644 index 0000000000..f12f226387 --- /dev/null +++ b/tests/out/wgsl/subgroup-operations.wgsl @@ -0,0 +1,26 @@ +@compute @workgroup_size(1, 1, 1) +fn main(@builtin(num_subgroups) num_subgroups: u32, @builtin(subgroup_id) subgroup_id: u32, @builtin(subgroup_size) subgroup_size: u32, @builtin(subgroup_invocation_id) subgroup_invocation_id: u32) { + subgroupBarrier(); + let _e8 = subgroupBallot( +((subgroup_invocation_id & 1u) == 1u)); + let _e11 = subgroupAll((subgroup_invocation_id != 0u)); + let _e14 = subgroupAny((subgroup_invocation_id == 0u)); + let _e15 = subgroupAdd(subgroup_invocation_id); + let _e16 = subgroupMul(subgroup_invocation_id); + let _e17 = subgroupMin(subgroup_invocation_id); + let _e18 = subgroupMax(subgroup_invocation_id); + let _e19 = subgroupAnd(subgroup_invocation_id); + let _e20 = subgroupOr(subgroup_invocation_id); + let _e21 = subgroupXor(subgroup_invocation_id); + let _e22 = subgroupPrefixExclusiveAdd(subgroup_invocation_id); + let _e23 = subgroupPrefixExclusiveMul(subgroup_invocation_id); + let _e24 = subgroupPrefixInclusiveAdd(subgroup_invocation_id); + let _e25 = subgroupPrefixInclusiveMul(subgroup_invocation_id); + let _e26 = subgroupBroadcastFirst(subgroup_invocation_id); + let _e28 = subgroupBroadcast(subgroup_invocation_id, 4u); + let _e32 = subgroupShuffle(subgroup_invocation_id, ((subgroup_size - 1u) - subgroup_invocation_id)); + let _e34 = subgroupShuffleDown(subgroup_invocation_id, 1u); + let _e36 = subgroupShuffleUp(subgroup_invocation_id, 1u); + let _e39 = subgroupShuffleXor(subgroup_invocation_id, (subgroup_size - 1u)); + return; +} diff --git a/tests/snapshots.rs b/tests/snapshots.rs index c3455dd864..c720e2efd1 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -782,6 +782,10 @@ fn convert_wgsl() { Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, ), ("separate-entry-points", Targets::SPIRV | Targets::GLSL), + ( + "subgroup-operations", + Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, + ), ]; for &(name, targets) in inputs.iter() {