From f49812d90011a9ab27513b632dc83343b709ca87 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Fri, 29 Sep 2023 23:31:39 -0400 Subject: [PATCH 01/17] subgroup: Implement subgroupBallot for wgsl-in, wgsl-out, spv-out, hlsl-out TODO: metal out, figure out what needs to be done in validation --- src/back/dot/mod.rs | 5 +++++ src/back/glsl/mod.rs | 13 ++++++++++++- src/back/hlsl/writer.rs | 12 +++++++++++- src/back/msl/writer.rs | 2 ++ src/back/spv/block.rs | 21 ++++++++++++++++++++- src/back/spv/instructions.rs | 17 +++++++++++++++++ src/back/wgsl/writer.rs | 9 +++++++++ src/compact/expressions.rs | 2 ++ src/compact/statements.rs | 4 ++++ src/front/spv/mod.rs | 1 + src/front/wgsl/lower/mod.rs | 10 ++++++++++ src/lib.rs | 16 ++++++++++++++-- src/proc/terminator.rs | 1 + src/proc/typifier.rs | 5 +++++ src/valid/analyzer.rs | 5 +++++ src/valid/expression.rs | 1 + src/valid/function.rs | 3 +++ src/valid/handles.rs | 5 +++++ 18 files changed, 127 insertions(+), 5 deletions(-) diff --git a/src/back/dot/mod.rs b/src/back/dot/mod.rs index 1556371df1..b24eeae5ad 100644 --- a/src/back/dot/mod.rs +++ b/src/back/dot/mod.rs @@ -279,6 +279,10 @@ impl StatementGraph { crate::RayQueryFunction::Terminate => "RayQueryTerminate", } } + S::SubgroupBallot { result } => { + self.emits.push((id, result)); + "SubgroupBallot" + } }; // Set the last node to the merge node last_node = merge_id; @@ -586,6 +590,7 @@ fn write_function_expressions( let ty = if committed { "Committed" } else { "Candidate" }; (format!("rayQueryGet{}Intersection", ty).into(), 4) } + E::SubgroupBallotResult => ("SubgroupBallotResult".into(), 4), }; // give uniform expressions an outline diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 60431e986e..43360f16a7 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -2250,6 +2250,16 @@ impl<'a, W: Write> Writer<'a, W> { writeln!(self.out, ");")?; } Statement::RayQuery { .. } => unreachable!(), + Statement::SubgroupBallot { result } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + writeln!(self.out, "subgroupBallot(true);")?; + } } Ok(()) @@ -3423,7 +3433,8 @@ impl<'a, W: Write> Writer<'a, W> { Expression::CallResult(_) | Expression::AtomicResult { .. } | Expression::RayQueryProceedResult - | Expression::WorkGroupUniformLoadResult { .. } => unreachable!(), + | Expression::WorkGroupUniformLoadResult { .. } + | Expression::SubgroupBallotResult => unreachable!(), // `ArrayLength` is written as `expr.length()` and we convert it to a uint Expression::ArrayLength(expr) => { write!(self.out, "uint(")?; diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index f26604476a..fc92cbd800 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -2004,6 +2004,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, "{level}}}")? } Statement::RayQuery { .. } => unreachable!(), + Statement::SubgroupBallot { result } => { + write!(self.out, "{level}")?; + + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + write!(self.out, "const uint4 {name} = ")?; + self.named_expressions.insert(result, name); + + writeln!(self.out, "WaveActiveBallot(true);")?; + } } Ok(()) @@ -3152,7 +3161,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Expression::CallResult(_) | Expression::AtomicResult { .. } | Expression::WorkGroupUniformLoadResult { .. } - | Expression::RayQueryProceedResult => {} + | Expression::RayQueryProceedResult + | Expression::SubgroupBallotResult => {} } if !closing_bracket.is_empty() { diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 09f7b1c73f..5874b09493 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -1997,6 +1997,7 @@ impl Writer { } write!(self.out, "}}")?; } + crate::Expression::SubgroupBallotResult => todo!(), } Ok(()) } @@ -3010,6 +3011,7 @@ impl Writer { } } } + crate::Statement::SubgroupBallot { result } => todo!(), } } diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 0471d957f0..357b7c3459 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -1130,7 +1130,8 @@ impl<'w> BlockContext<'w> { crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } - | crate::Expression::RayQueryProceedResult => self.cached[expr_handle], + | crate::Expression::RayQueryProceedResult + | crate::Expression::SubgroupBallotResult => self.cached[expr_handle], crate::Expression::As { expr, kind, @@ -2338,6 +2339,24 @@ impl<'w> BlockContext<'w> { crate::Statement::RayQuery { query, ref fun } => { self.write_ray_query_function(query, fun, &mut block); } + crate::Statement::SubgroupBallot { result } => { + let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Quad), + kind: crate::ScalarKind::Uint, + width: 4, + pointer_space: None, + })); + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + let predicate = self.writer.get_constant_scalar(crate::Literal::Bool(true)); + let id = self.gen_id(); + block.body.push(Instruction::group_non_uniform_ballot( + vec4_u32_type_id, + id, + exec_scope_id, + predicate, + )); + self.cached[result] = id; + } } } diff --git a/src/back/spv/instructions.rs b/src/back/spv/instructions.rs index b963793ad3..1ca58431d5 100644 --- a/src/back/spv/instructions.rs +++ b/src/back/spv/instructions.rs @@ -1037,6 +1037,23 @@ impl super::Instruction { instruction.add_operand(semantics_id); instruction } + + // Group Instructions + + pub(super) fn group_non_uniform_ballot( + result_type_id: Word, + id: Word, + exec_scope_id: Word, + predicate: Word, + ) -> Self { + let mut instruction = Self::new(Op::GroupNonUniformBallot); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(predicate); + + instruction + } } impl From for spirv::ImageFormat { diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 075d85558c..18887825ea 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -921,6 +921,14 @@ impl Writer { } } Statement::RayQuery { .. } => unreachable!(), + Statement::SubgroupBallot { result } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + writeln!(self.out, "subgroupBallot();")?; + } } Ok(()) @@ -1659,6 +1667,7 @@ impl Writer { Expression::CallResult(_) | Expression::AtomicResult { .. } | Expression::RayQueryProceedResult + | Expression::SubgroupBallotResult | Expression::WorkGroupUniformLoadResult { .. } => {} } diff --git a/src/compact/expressions.rs b/src/compact/expressions.rs index c1326e92be..d62d00a85f 100644 --- a/src/compact/expressions.rs +++ b/src/compact/expressions.rs @@ -55,6 +55,7 @@ impl<'tracer> ExpressionTracer<'tracer> { | Ex::GlobalVariable(_) | Ex::LocalVariable(_) | Ex::CallResult(_) + | Ex::SubgroupBallotResult // FIXME: ??? | Ex::RayQueryProceedResult => {} Ex::Constant(handle) => { @@ -222,6 +223,7 @@ impl ModuleMap { | Ex::GlobalVariable(_) | Ex::LocalVariable(_) | Ex::CallResult(_) + | Ex::SubgroupBallotResult // FIXME: ??? | Ex::RayQueryProceedResult => {} // Expressions that contain handles that need to be adjusted. diff --git a/src/compact/statements.rs b/src/compact/statements.rs index 4c62771023..3b27c8b71a 100644 --- a/src/compact/statements.rs +++ b/src/compact/statements.rs @@ -95,6 +95,9 @@ impl FunctionTracer<'_> { self.trace_expression(query); self.trace_ray_query_function(fun); } + St::SubgroupBallot { result } => { + self.trace_expression(result); + } // Trivial statements. St::Break @@ -244,6 +247,7 @@ impl FunctionMap { adjust(query); self.adjust_ray_query_function(fun); } + St::SubgroupBallot { ref mut result } => adjust(result), // Trivial statements. St::Break diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index 083205a45b..2658a95d30 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -3842,6 +3842,7 @@ impl> Frontend { } } S::WorkGroupUniformLoad { .. } => unreachable!(), + S::SubgroupBallot { .. } => unreachable!(), } i += 1; } diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index ae178cb702..0b727739be 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -2290,6 +2290,16 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { )?; return Ok(Some(handle)); } + "subgroupBallot" => { + ctx.prepare_args(arguments, 0, span).finish()?; + + let result = ctx + .interrupt_emitter(crate::Expression::SubgroupBallotResult, span); + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(crate::Statement::SubgroupBallot { result }, span); + return Ok(Some(result)); + } _ => return Err(Error::UnknownIdent(function.span, function.name)), } }; diff --git a/src/lib.rs b/src/lib.rs index 300c6e4820..31cae75d8d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1399,7 +1399,9 @@ pub enum Expression { /// /// For [`TypeInner::Atomic`] the result is a corresponding scalar. /// For other types behind the `pointer`, the result is `T`. - Load { pointer: Handle }, + Load { + pointer: Handle, + }, /// Sample a point from a sampled or a depth image. ImageSample { image: Handle, @@ -1539,7 +1541,10 @@ pub enum Expression { /// Result of calling another function. CallResult(Handle), /// Result of an atomic operation. - AtomicResult { ty: Handle, comparison: bool }, + AtomicResult { + ty: Handle, + comparison: bool, + }, /// Result of a [`WorkGroupUniformLoad`] statement. /// /// [`WorkGroupUniformLoad`]: Statement::WorkGroupUniformLoad @@ -1567,6 +1572,7 @@ pub enum Expression { query: Handle, committed: bool, }, + SubgroupBallotResult, } pub use block::Block; @@ -1839,6 +1845,12 @@ pub enum Statement { /// The specific operation we're performing on `query`. fun: RayQueryFunction, }, + SubgroupBallot { + /// The [`SubgroupBallotResult`] expression representing this load's result. + /// + /// [`SubgroupBallotResult`]: Expression::SubgroupBallotResult + result: Handle, + }, } /// A function argument. diff --git a/src/proc/terminator.rs b/src/proc/terminator.rs index a5239d4eca..d2dde729f1 100644 --- a/src/proc/terminator.rs +++ b/src/proc/terminator.rs @@ -37,6 +37,7 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::RayQuery { .. } | S::Atomic { .. } | S::WorkGroupUniformLoad { .. } + | S::SubgroupBallot { .. } | S::Barrier(_)), ) | None => block.push(S::Return { value: None }, Default::default()), diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index ad9eec94d2..6241c5bad8 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -905,6 +905,11 @@ impl<'a> ResolveContext<'a> { .ok_or(ResolveError::MissingSpecialType)?; TypeResolution::Handle(result) } + crate::Expression::SubgroupBallotResult => TypeResolution::Value(Ti::Vector { + kind: crate::ScalarKind::Uint, + size: crate::VectorSize::Quad, + width: 4, + }), }) } } diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index ff1db071c8..d23caaf473 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -740,6 +740,10 @@ impl FunctionInfo { non_uniform_result: self.add_ref(query), requirements: UniformityRequirements::empty(), }, + E::SubgroupBallotResult => Uniformity { + non_uniform_result: None, + requirements: UniformityRequirements::empty(), + }, }; let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?; @@ -983,6 +987,7 @@ impl FunctionInfo { } FunctionUniformity::new() } + S::SubgroupBallot { result: _ } => FunctionUniformity::new(), }; disruptor = disruptor.or(uniformity.exit_disruptor()); diff --git a/src/valid/expression.rs b/src/valid/expression.rs index f77844b4b1..890a4b9973 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1537,6 +1537,7 @@ impl super::Validator { return Err(ExpressionError::InvalidRayQueryType(query)); } }, + E::SubgroupBallotResult => ShaderStages::COMPUTE | ShaderStages::FRAGMENT, }; Ok(stages) } diff --git a/src/valid/function.rs b/src/valid/function.rs index d967f4b1f3..52f51a2810 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -919,6 +919,9 @@ impl super::Validator { crate::RayQueryFunction::Terminate => {} } } + S::SubgroupBallot { result } => { + self.emit_expression(result, context)?; + } } } Ok(BlockInfo { stages, finished }) diff --git a/src/valid/handles.rs b/src/valid/handles.rs index c68ded074b..547dfac551 100644 --- a/src/valid/handles.rs +++ b/src/valid/handles.rs @@ -394,6 +394,7 @@ impl super::Validator { } crate::Expression::AtomicResult { .. } | crate::Expression::RayQueryProceedResult + | crate::Expression::SubgroupBallotResult | crate::Expression::WorkGroupUniformLoadResult { .. } => (), crate::Expression::ArrayLength(array) => { handle.check_dep(array)?; @@ -539,6 +540,10 @@ impl super::Validator { } Ok(()) } + crate::Statement::SubgroupBallot { result } => { + validate_expr(result)?; + Ok(()) + } crate::Statement::Break | crate::Statement::Continue | crate::Statement::Kill From 827e7a140c373472a3f56ad3f9471380abd326a5 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Sat, 30 Sep 2023 15:54:00 -0400 Subject: [PATCH 02/17] subgroup: subgroupBallot metal out --- src/back/msl/writer.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 5874b09493..145a4f1ff0 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -1935,6 +1935,7 @@ impl Writer { crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } + | crate::Expression::SubgroupBallotResult | crate::Expression::RayQueryProceedResult => { unreachable!() } @@ -1997,7 +1998,6 @@ impl Writer { } write!(self.out, "}}")?; } - crate::Expression::SubgroupBallotResult => todo!(), } Ok(()) } @@ -3011,7 +3011,13 @@ impl Writer { } } } - crate::Statement::SubgroupBallot { result } => todo!(), + crate::Statement::SubgroupBallot { result } => { + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + write!(self.out, "{NAMESPACE}::simd_active_threads_mask();")?; + } } } From 60a803aff6f8f57cc9758a12c2b409c5ff966a4e Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Sat, 30 Sep 2023 17:53:44 -0400 Subject: [PATCH 03/17] subgroup: require GroupNonUnifomBallot capability --- src/back/spv/block.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 357b7c3459..5c36cd8bcc 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -2340,6 +2340,10 @@ impl<'w> BlockContext<'w> { self.write_ray_query_function(query, fun, &mut block); } crate::Statement::SubgroupBallot { result } => { + self.writer.require_any( + "GroupNonUniformBallot", + &[spirv::Capability::GroupNonUniformBallot], + )?; let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value { vector_size: Some(crate::VectorSize::Quad), kind: crate::ScalarKind::Uint, From f8e8e074fadd56fb3e3c8e147f4da8bf69474614 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Sat, 30 Sep 2023 22:59:39 -0400 Subject: [PATCH 04/17] subgroup: Add subgroup invocation id and subgroup size builtins --- src/back/glsl/mod.rs | 3 +++ src/back/hlsl/conv.rs | 9 ++++++--- src/back/msl/mod.rs | 3 +++ src/back/spv/writer.rs | 18 ++++++++++++++++++ src/back/wgsl/writer.rs | 2 ++ src/front/wgsl/parse/conv.rs | 3 +++ src/lib.rs | 3 +++ src/valid/interface.rs | 11 +++++++++++ 8 files changed, 49 insertions(+), 3 deletions(-) diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 43360f16a7..baefb7248b 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -4167,6 +4167,9 @@ const fn glsl_built_in( Bi::WorkGroupId => "gl_WorkGroupID", Bi::WorkGroupSize => "gl_WorkGroupSize", Bi::NumWorkGroups => "gl_NumWorkGroups", + // subgroup + Bi::SubgroupInvocationId => "gl_SubgroupInvocationID", + Bi::SubgroupSize => "gl_SubgroupSize", } } diff --git a/src/back/hlsl/conv.rs b/src/back/hlsl/conv.rs index 19bde6926a..19c4da5e74 100644 --- a/src/back/hlsl/conv.rs +++ b/src/back/hlsl/conv.rs @@ -166,9 +166,12 @@ impl crate::BuiltIn { // to this field will get replaced with references to `SPECIAL_CBUF_VAR` // in `Writer::write_expr`. Self::NumWorkGroups => "SV_GroupID", - Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => { - return Err(Error::Unimplemented(format!("builtin {self:?}"))) - } + + Self::SubgroupInvocationId + | Self::SubgroupSize + | Self::BaseInstance + | Self::BaseVertex + | Self::WorkGroupSize => return Err(Error::Unimplemented(format!("builtin {self:?}"))), Self::PointSize | Self::ViewIndex | Self::PointCoord => { return Err(Error::Custom(format!("Unsupported builtin {self:?}"))) } diff --git a/src/back/msl/mod.rs b/src/back/msl/mod.rs index 5ef18730c9..9e23d2a08d 100644 --- a/src/back/msl/mod.rs +++ b/src/back/msl/mod.rs @@ -437,6 +437,9 @@ impl ResolvedBinding { Bi::WorkGroupId => "threadgroup_position_in_grid", Bi::WorkGroupSize => "dispatch_threads_per_threadgroup", Bi::NumWorkGroups => "threadgroups_per_grid", + // subgroup + Bi::SubgroupInvocationId => "simdgroup_index_in_threadgroup", + Bi::SubgroupSize => "simdgroups_per_threadgroup", Bi::CullDistance | Bi::ViewIndex => { return Err(Error::UnsupportedBuiltIn(built_in)) } diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index 24cb14a161..a445fb2b10 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -1589,6 +1589,24 @@ impl Writer { Bi::WorkGroupId => BuiltIn::WorkgroupId, Bi::WorkGroupSize => BuiltIn::WorkgroupSize, Bi::NumWorkGroups => BuiltIn::NumWorkgroups, + // Subgroup + Bi::SubgroupInvocationId => { + self.require_any( + "`subgroup_invocation_id` built-in", + &[spirv::Capability::GroupNonUniform], + )?; + BuiltIn::SubgroupLocalInvocationId + } + Bi::SubgroupSize => { + self.require_any( + "`subgroup_invocation_id` built-in", + &[ + spirv::Capability::GroupNonUniform, + spirv::Capability::SubgroupBallotKHR, + ], + )?; + BuiltIn::SubgroupSize + } }; self.decorate(id, Decoration::BuiltIn, &[built_in as u32]); diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 18887825ea..33d45635c4 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -1769,6 +1769,8 @@ fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> { Bi::SampleMask => "sample_mask", Bi::PrimitiveIndex => "primitive_index", Bi::ViewIndex => "view_index", + Bi::SubgroupInvocationId => "subgroup_invocation_id", + Bi::SubgroupSize => "subgroup_size", Bi::BaseInstance | Bi::BaseVertex | Bi::ClipDistance diff --git a/src/front/wgsl/parse/conv.rs b/src/front/wgsl/parse/conv.rs index 51977173d6..a27bdb1cbc 100644 --- a/src/front/wgsl/parse/conv.rs +++ b/src/front/wgsl/parse/conv.rs @@ -34,6 +34,9 @@ pub fn map_built_in(word: &str, span: Span) -> Result> "local_invocation_index" => crate::BuiltIn::LocalInvocationIndex, "workgroup_id" => crate::BuiltIn::WorkGroupId, "num_workgroups" => crate::BuiltIn::NumWorkGroups, + // subgroup + "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId, + "subgroup_size" => crate::BuiltIn::SubgroupSize, _ => return Err(Error::UnknownBuiltin(span)), }) } diff --git a/src/lib.rs b/src/lib.rs index 31cae75d8d..c4d9eef359 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -435,6 +435,9 @@ pub enum BuiltIn { WorkGroupId, WorkGroupSize, NumWorkGroups, + // subgroup + SubgroupInvocationId, + SubgroupSize, } /// Number of bytes per scalar. diff --git a/src/valid/interface.rs b/src/valid/interface.rs index 6c41ece81f..c1ffc4447a 100644 --- a/src/valid/interface.rs +++ b/src/valid/interface.rs @@ -299,6 +299,17 @@ impl VaryingContext<'_> { width, }, ), + Bi::SubgroupInvocationId | Bi::SubgroupSize => ( + match self.stage { + St::Compute | St::Fragment => !self.output, + St::Vertex => false, + }, + *ty_inner + == Ti::Scalar { + kind: Sk::Uint, + width, + }, + ), }; if !visible { From c260344fec28dfcf8be8f1975476c8543bd2e2ad Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Sun, 1 Oct 2023 20:08:30 -0400 Subject: [PATCH 05/17] subgroup: SubgroupInvocationId is only valid in compute stages --- src/valid/interface.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/valid/interface.rs b/src/valid/interface.rs index c1ffc4447a..bf4f397224 100644 --- a/src/valid/interface.rs +++ b/src/valid/interface.rs @@ -299,7 +299,15 @@ impl VaryingContext<'_> { width, }, ), - Bi::SubgroupInvocationId | Bi::SubgroupSize => ( + Bi::SubgroupInvocationId => ( + self.stage == St::Compute && !self.output, + *ty_inner + == Ti::Scalar { + kind: Sk::Uint, + width, + }, + ), + Bi::SubgroupSize => ( match self.stage { St::Compute | St::Fragment => !self.output, St::Vertex => false, From f0d4eb998900a5959e86ec50b66a742f474ed98a Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Tue, 3 Oct 2023 16:20:25 -0400 Subject: [PATCH 06/17] subgroup: expierment with subgroupBarrier() based on OpControlbarrier SPIR-V OpControlBarrier with execution scope Subgroup has implementation defined behavior when executed nonuniformly. OpenCL SPIR-V execution spec say nonuniform execution is UB. Vulkan SPIR-V execution spec says nothing :). --- src/back/glsl/mod.rs | 3 +++ src/back/hlsl/writer.rs | 3 +++ src/back/msl/writer.rs | 3 +++ src/back/spv/writer.rs | 6 +++++- src/front/wgsl/lower/mod.rs | 8 ++++++++ src/lib.rs | 6 ++++-- 6 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index baefb7248b..7b0fc970cf 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -3991,6 +3991,9 @@ impl<'a, W: Write> Writer<'a, W> { if flags.contains(crate::Barrier::WORK_GROUP) { writeln!(self.out, "{level}memoryBarrierShared();")?; } + if flags.contains(crate::Barrier::SUB_GROUP) { + unimplemented!() // FIXME + } writeln!(self.out, "{level}barrier();")?; Ok(()) } diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index fc92cbd800..821dcc743d 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -3229,6 +3229,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { if barrier.contains(crate::Barrier::WORK_GROUP) { writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?; } + if barrier.contains(crate::Barrier::SUB_GROUP) { + unimplemented!() // FIXME + } Ok(()) } } diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 145a4f1ff0..c2a1eaddd0 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -4354,6 +4354,9 @@ impl Writer { "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);", )?; } + if flags.contains(crate::Barrier::SUB_GROUP) { + unimplemented!(); // FIXME + } Ok(()) } } diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index a445fb2b10..077751d90b 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -1314,7 +1314,11 @@ impl Writer { spirv::MemorySemantics::WORKGROUP_MEMORY, flags.contains(crate::Barrier::WORK_GROUP), ); - let exec_scope_id = self.get_index_constant(spirv::Scope::Workgroup as u32); + let exec_scope_id = if flags.contains(crate::Barrier::SUB_GROUP) { + self.get_index_constant(spirv::Scope::Subgroup as u32) + } else { + self.get_index_constant(spirv::Scope::Workgroup as u32) + }; let mem_scope_id = self.get_index_constant(memory_scope as u32); let semantics_id = self.get_index_constant(semantics.bits()); block.body.push(Instruction::control_barrier( diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index 0b727739be..49319cde63 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -2085,6 +2085,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span); return Ok(None); } + "subgroupBarrier" => { + ctx.prepare_args(arguments, 0, span).finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(crate::Statement::Barrier(crate::Barrier::SUB_GROUP), span); + return Ok(None); + } "workgroupUniformLoad" => { let mut args = ctx.prepare_args(arguments, 1, span); let expr = args.next()?; diff --git a/src/lib.rs b/src/lib.rs index c4d9eef359..a6f755fae6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1269,9 +1269,11 @@ bitflags::bitflags! { #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] pub struct Barrier: u32 { /// Barrier affects all `AddressSpace::Storage` accesses. - const STORAGE = 0x1; + const STORAGE = 1 << 0; /// Barrier affects all `AddressSpace::WorkGroup` accesses. - const WORK_GROUP = 0x2; + const WORK_GROUP = 1 << 1; + /// Barrier synchronizes execution across all invocations within a subgroup that exectue this instruction. + const SUB_GROUP = 1 << 2; } } From 3b60d7d0d98dd4c64a85a014d4848114de6133f0 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Sat, 14 Oct 2023 14:20:43 +0200 Subject: [PATCH 07/17] subgroup: add statement for rest of subgroup ops --- src/back/dot/mod.rs | 20 ++++++ src/back/glsl/mod.rs | 16 +++++ src/back/hlsl/writer.rs | 18 +++++- src/back/msl/writer.rs | 16 +++++ src/back/spv/block.rs | 18 +++++- src/back/spv/instructions.rs | 17 +++++ src/back/spv/mod.rs | 1 + src/back/spv/subgroup.rs | 83 ++++++++++++++++++++++++ src/back/wgsl/writer.rs | 16 +++++ src/compact/expressions.rs | 2 + src/compact/statements.rs | 40 ++++++++++++ src/front/spv/mod.rs | 4 +- src/front/wgsl/lower/mod.rs | 39 +++++++++++ src/lib.rs | 82 +++++++++++++++++++++++ src/proc/terminator.rs | 2 + src/proc/typifier.rs | 1 + src/valid/analyzer.rs | 26 +++++++- src/valid/expression.rs | 1 + src/valid/function.rs | 122 +++++++++++++++++++++++++++++++++++ src/valid/handles.rs | 22 +++++++ 20 files changed, 542 insertions(+), 4 deletions(-) create mode 100644 src/back/spv/subgroup.rs diff --git a/src/back/dot/mod.rs b/src/back/dot/mod.rs index b24eeae5ad..5ba3ffe49b 100644 --- a/src/back/dot/mod.rs +++ b/src/back/dot/mod.rs @@ -283,6 +283,25 @@ impl StatementGraph { self.emits.push((id, result)); "SubgroupBallot" } + S::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + self.dependencies.push((id, argument, "arg")); + self.emits.push((id, result)); + "SubgroupCollectiveOperation" // FIXME + } + S::SubgroupBroadcast { + ref mode, + argument, + result, + } => { + self.dependencies.push((id, argument, "arg")); + self.emits.push((id, result)); + "SubgroupBroadcast" // FIXME + } }; // Set the last node to the merge node last_node = merge_id; @@ -591,6 +610,7 @@ fn write_function_expressions( (format!("rayQueryGet{}Intersection", ty).into(), 4) } E::SubgroupBallotResult => ("SubgroupBallotResult".into(), 4), + E::SubgroupOperationResult { .. } => ("SubgroupOperationResult".into(), 4), }; // give uniform expressions an outline diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 7b0fc970cf..76bf2c6597 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -2260,6 +2260,21 @@ impl<'a, W: Write> Writer<'a, W> { writeln!(self.out, "subgroupBallot(true);")?; } + Statement::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + unimplemented!(); // FIXME: + } + Statement::SubgroupBroadcast { + ref mode, + argument, + result, + } => { + unimplemented!(); // FIXME + } } Ok(()) @@ -3434,6 +3449,7 @@ impl<'a, W: Write> Writer<'a, W> { | Expression::AtomicResult { .. } | Expression::RayQueryProceedResult | Expression::WorkGroupUniformLoadResult { .. } + | Expression::SubgroupOperationResult { .. } | Expression::SubgroupBallotResult => unreachable!(), // `ArrayLength` is written as `expr.length()` and we convert it to a uint Expression::ArrayLength(expr) => { diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 821dcc743d..b14d0c4bde 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -2013,6 +2013,21 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, "WaveActiveBallot(true);")?; } + Statement::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + unimplemented!(); // FIXME + } + Statement::SubgroupBroadcast { + ref mode, + argument, + result, + } => { + unimplemented!(); // FIXME + } } Ok(()) @@ -3162,7 +3177,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { | Expression::AtomicResult { .. } | Expression::WorkGroupUniformLoadResult { .. } | Expression::RayQueryProceedResult - | Expression::SubgroupBallotResult => {} + | Expression::SubgroupBallotResult + | Expression::SubgroupOperationResult { .. } => {} } if !closing_bracket.is_empty() { diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index c2a1eaddd0..cc4baf70ff 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -1936,6 +1936,7 @@ impl Writer { | crate::Expression::AtomicResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } | crate::Expression::SubgroupBallotResult + | crate::Expression::SubgroupOperationResult { .. } | crate::Expression::RayQueryProceedResult => { unreachable!() } @@ -3018,6 +3019,21 @@ impl Writer { self.named_expressions.insert(result, name); write!(self.out, "{NAMESPACE}::simd_active_threads_mask();")?; } + crate::Statement::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + unimplemented!(); // FIXME + } + crate::Statement::SubgroupBroadcast { + ref mode, + argument, + result, + } => { + unimplemented!(); // FIXME + } } } diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 5c36cd8bcc..fa799b12c4 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -1131,7 +1131,8 @@ impl<'w> BlockContext<'w> { | crate::Expression::AtomicResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } | crate::Expression::RayQueryProceedResult - | crate::Expression::SubgroupBallotResult => self.cached[expr_handle], + | crate::Expression::SubgroupBallotResult + | crate::Expression::SubgroupOperationResult { .. } => self.cached[expr_handle], crate::Expression::As { expr, kind, @@ -2361,6 +2362,21 @@ impl<'w> BlockContext<'w> { )); self.cached[result] = id; } + crate::Statement::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + self.write_subgroup_operation(op, collective_op, argument, result, &mut block); + } + crate::Statement::SubgroupBroadcast { + ref mode, + argument, + result, + } => { + unimplemented!() // FIXME + } } } diff --git a/src/back/spv/instructions.rs b/src/back/spv/instructions.rs index 1ca58431d5..725014dee4 100644 --- a/src/back/spv/instructions.rs +++ b/src/back/spv/instructions.rs @@ -1052,6 +1052,23 @@ impl super::Instruction { instruction.add_operand(exec_scope_id); instruction.add_operand(predicate); + instruction + } + pub(super) fn group_non_uniform_arithmetic( + op: Op, + result_type_id: Word, + id: Word, + exec_scope_id: Word, + group_op: spirv::GroupOperation, + value: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(group_op as u32); + instruction.add_operand(value); + instruction } } diff --git a/src/back/spv/mod.rs b/src/back/spv/mod.rs index ac7281fc6b..a7a4bd302c 100644 --- a/src/back/spv/mod.rs +++ b/src/back/spv/mod.rs @@ -13,6 +13,7 @@ mod layout; mod ray; mod recyclable; mod selection; +mod subgroup; mod writer; pub use spirv::Capability; diff --git a/src/back/spv/subgroup.rs b/src/back/spv/subgroup.rs new file mode 100644 index 0000000000..a4f7b018f4 --- /dev/null +++ b/src/back/spv/subgroup.rs @@ -0,0 +1,83 @@ +use super::{Block, BlockContext, Error, Instruction}; +use crate::{arena::Handle, TypeInner}; + +impl<'w> BlockContext<'w> { + pub(super) fn write_subgroup_operation( + &mut self, + op: &crate::SubgroupOperation, + collective_op: &crate::CollectiveOperation, + argument: Handle, + result: Handle, + block: &mut Block, + ) -> Result<(), Error> { + self.writer.require_any( + "GroupNonUniformArithmetic", + &[ + spirv::Capability::GroupNonUniformArithmetic, + spirv::Capability::GroupNonUniformClustered, + spirv::Capability::GroupNonUniformPartitionedNV, + ], + )?; + + let id = self.gen_id(); + let result_ty = &self.fun_info[result].ty; + let result_type_id = self.get_expression_type_id(result_ty); + let result_ty_inner = result_ty.inner_with(&self.ir_module.types); + let kind = result_ty_inner.scalar_kind().unwrap(); + + let (is_scalar, kind) = match result_ty_inner { + TypeInner::Scalar { kind, .. } => (true, kind), + TypeInner::Vector { kind, .. } => (false, kind), + _ => unimplemented!(), + }; + + use crate::ScalarKind as sk; + use crate::SubgroupOperation as sg; + let spirv_op = match (kind, op) { + (sk::Bool, sg::All) if is_scalar => spirv::Op::GroupNonUniformAll, + (sk::Bool, sg::Any) if is_scalar => spirv::Op::GroupNonUniformAny, + (_, sg::All | sg::Any) => unimplemented!(), + + (sk::Sint | sk::Uint, sg::Add) => spirv::Op::GroupNonUniformIAdd, + (sk::Float, sg::Add) => spirv::Op::GroupNonUniformFAdd, + (sk::Sint | sk::Uint, sg::Mul) => spirv::Op::GroupNonUniformIMul, + (sk::Float, sg::Mul) => spirv::Op::GroupNonUniformFMul, + (sk::Sint, sg::Max) => spirv::Op::GroupNonUniformSMax, + (sk::Uint, sg::Max) => spirv::Op::GroupNonUniformUMax, + (sk::Float, sg::Max) => spirv::Op::GroupNonUniformFMax, + (sk::Sint, sg::Min) => spirv::Op::GroupNonUniformSMin, + (sk::Uint, sg::Min) => spirv::Op::GroupNonUniformUMin, + (sk::Float, sg::Min) => spirv::Op::GroupNonUniformFMin, + (sk::Bool, sg::Add | sg::Mul | sg::Min | sg::Max) => unimplemented!(), + + (sk::Sint | sk::Uint, sg::And) => spirv::Op::GroupNonUniformBitwiseAnd, + (sk::Sint | sk::Uint, sg::Or) => spirv::Op::GroupNonUniformBitwiseOr, + (sk::Sint | sk::Uint, sg::Xor) => spirv::Op::GroupNonUniformBitwiseXor, + (sk::Float, sg::And | sg::Or | sg::Xor) => unimplemented!(), + (sk::Bool, sg::And) => spirv::Op::GroupNonUniformLogicalAnd, + (sk::Bool, sg::Or) => spirv::Op::GroupNonUniformLogicalOr, + (sk::Bool, sg::Xor) => spirv::Op::GroupNonUniformLogicalXor, + }; + + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + + use crate::CollectiveOperation as c; + let group_op = match collective_op { + c::Reduce => spirv::GroupOperation::Reduce, + c::InclusiveScan => spirv::GroupOperation::InclusiveScan, + c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan, + }; + + let arg_id = self.cached[argument]; + block.body.push(Instruction::group_non_uniform_arithmetic( + spirv_op, + result_type_id, + id, + exec_scope_id, + group_op, + arg_id, + )); + self.cached[result] = id; + Ok(()) + } +} diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 33d45635c4..b66863f1df 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -929,6 +929,21 @@ impl Writer { writeln!(self.out, "subgroupBallot();")?; } + Statement::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + unimplemented!() // FIXME + } + Statement::SubgroupBroadcast { + ref mode, + argument, + result, + } => { + unimplemented!() // FIXME + } } Ok(()) @@ -1668,6 +1683,7 @@ impl Writer { | Expression::AtomicResult { .. } | Expression::RayQueryProceedResult | Expression::SubgroupBallotResult + | Expression::SubgroupOperationResult { .. } | Expression::WorkGroupUniformLoadResult { .. } => {} } diff --git a/src/compact/expressions.rs b/src/compact/expressions.rs index d62d00a85f..1533362407 100644 --- a/src/compact/expressions.rs +++ b/src/compact/expressions.rs @@ -157,6 +157,7 @@ impl<'tracer> ExpressionTracer<'tracer> { Ex::AtomicResult { ty, comparison: _ } => self.trace_type(ty), Ex::WorkGroupUniformLoadResult { ty } => self.trace_type(ty), Ex::ArrayLength(expr) => work_list.push(expr), + Ex::SubgroupOperationResult { ty } => self.trace_type(ty), Ex::RayQueryGetIntersection { query, committed: _, @@ -351,6 +352,7 @@ impl ModuleMap { comparison: _, } => self.types.adjust(ty), Ex::WorkGroupUniformLoadResult { ref mut ty } => self.types.adjust(ty), + Ex::SubgroupOperationResult { ref mut ty } => self.types.adjust(ty), Ex::ArrayLength(ref mut expr) => adjust(expr), Ex::RayQueryGetIntersection { ref mut query, diff --git a/src/compact/statements.rs b/src/compact/statements.rs index 3b27c8b71a..462553b9d6 100644 --- a/src/compact/statements.rs +++ b/src/compact/statements.rs @@ -98,6 +98,26 @@ impl FunctionTracer<'_> { St::SubgroupBallot { result } => { self.trace_expression(result); } + St::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + self.trace_expression(argument); + self.trace_expression(result); + } + St::SubgroupBroadcast { + ref mode, + argument, + result, + } => { + if let crate::BroadcastMode::Index(expr) = *mode { + self.trace_expression(expr); + } + self.trace_expression(argument); + self.trace_expression(result); + } // Trivial statements. St::Break @@ -248,6 +268,26 @@ impl FunctionMap { self.adjust_ray_query_function(fun); } St::SubgroupBallot { ref mut result } => adjust(result), + St::SubgroupCollectiveOperation { + ref mut op, + ref mut collective_op, + ref mut argument, + ref mut result, + } => { + adjust(argument); + adjust(result); + } + St::SubgroupBroadcast { + ref mut mode, + ref mut argument, + ref mut result, + } => { + if let crate::BroadcastMode::Index(expr) = mode { + adjust(expr); + } + adjust(argument); + adjust(result); + } // Trivial statements. St::Break diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index 2658a95d30..561e32e22d 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -3842,7 +3842,9 @@ impl> Frontend { } } S::WorkGroupUniformLoad { .. } => unreachable!(), - S::SubgroupBallot { .. } => unreachable!(), + S::SubgroupBallot { .. } => unreachable!(), // FIXME?? + S::SubgroupCollectiveOperation { .. } => unreachable!(), + S::SubgroupBroadcast { .. } => unreachable!(), } i += 1; } diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index 49319cde63..0443e46cec 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -2308,6 +2308,45 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .push(crate::Statement::SubgroupBallot { result }, span); return Ok(Some(result)); } + "subgroupBroadcast" => { + unimplemented!(); // FIXME + } + "subgroupBroadcastFirst" => { + unimplemented!(); // FIXME + } + "subgroupAll" => { + unimplemented!(); // FIXME + } + "subgroupAny" => { + unimplemented!(); // FIXME + } + "subgroupAdd" => { + unimplemented!(); // FIXME + } + "subgroupMul" => { + unimplemented!(); // FIXME + } + "subgroupMin" => { + unimplemented!(); // FIXME + } + "subgroupMax" => { + unimplemented!(); // FIXME + } + "subgroupAnd" => { + unimplemented!(); // FIXME + } + "subgroupOr" => { + unimplemented!(); // FIXME + } + "subgroupXor" => { + unimplemented!(); // FIXME + } + "subgroupPrefixAdd" => { + unimplemented!(); // FIXME + } + "subgroupPrefixMul" => { + unimplemented!(); // FIXME + } _ => return Err(Error::UnknownIdent(function.span, function.name)), } }; diff --git a/src/lib.rs b/src/lib.rs index a6f755fae6..571c926e1d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1261,6 +1261,42 @@ pub enum SwizzleComponent { W = 3, } +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum BroadcastMode { + First, + Index(Handle), +} + +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum SubgroupOperation { + All, + Any, + Add, + Mul, + Min, + Max, + And, + Or, + Xor, +} + +#[repr(u8)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum CollectiveOperation { + Reduce = 0, + InclusiveScan = 1, + ExclusiveScan = 2, +} + bitflags::bitflags! { /// Memory barrier flags. #[cfg_attr(feature = "serialize", derive(Serialize))] @@ -1578,6 +1614,9 @@ pub enum Expression { committed: bool, }, SubgroupBallotResult, + SubgroupOperationResult { + ty: Handle, + }, } pub use block::Block; @@ -1850,12 +1889,55 @@ pub enum Statement { /// The specific operation we're performing on `query`. fun: RayQueryFunction, }, + // subgroupBallot(bool) -> vec4 SubgroupBallot { /// The [`SubgroupBallotResult`] expression representing this load's result. /// /// [`SubgroupBallotResult`]: Expression::SubgroupBallotResult result: Handle, }, + + // subgroupBroadcast(value, lane) -> value + // subgroupBroadcastFirst(value) -> value + SubgroupBroadcast { + /// Specifies which thread to broadcast from + mode: BroadcastMode, + /// The value to broadcast over + argument: Handle, + /// The [`SubgroupBroadcastResult`] expression representing this load's result. + /// + /// [`SubgroupBroadcastResult`]: Expression::SubgroupBroadcastResult + result: Handle, + }, + + // Reduction on bool + // subgroupAll(bool) -> bool + // subgroupAny(bool) -> bool + // Reduction on float, int + // subgroupMin(value) -> value + // subgroupMax(value) -> value + // subgroupAdd(value) -> value + // subgroupMul(value) -> value + // Reduction on int + // subgroupAnd(value) -> value + // subgroupOr(value) -> value + // subgroupXor(value) -> value + // Scan on float, int + // subgroupPrefixAdd(value) -> value + // subgroupPrefixMul(value) -> value + /// Compute a collective operation across all active threads in th subgroup + SubgroupCollectiveOperation { + /// What operation to compute + op: SubgroupOperation, + /// How to combine the results + collective_op: CollectiveOperation, + /// The value to compute over + argument: Handle, + /// The [`SubgroupOperationResult`] expression representing this load's result. + /// + /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult + result: Handle, + }, } /// A function argument. diff --git a/src/proc/terminator.rs b/src/proc/terminator.rs index d2dde729f1..35111a11de 100644 --- a/src/proc/terminator.rs +++ b/src/proc/terminator.rs @@ -38,6 +38,8 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::Atomic { .. } | S::WorkGroupUniformLoad { .. } | S::SubgroupBallot { .. } + | S::SubgroupCollectiveOperation { .. } + | S::SubgroupBroadcast { .. } | S::Barrier(_)), ) | None => block.push(S::Return { value: None }, Default::default()), diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index 6241c5bad8..c2b38ac73b 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -638,6 +638,7 @@ impl<'a> ResolveContext<'a> { | crate::BinaryOperator::ShiftRight => past(left)?.clone(), }, crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty), + crate::Expression::SubgroupOperationResult { ty } => TypeResolution::Handle(ty), crate::Expression::WorkGroupUniformLoadResult { ty } => TypeResolution::Handle(ty), crate::Expression::Select { accept, .. } => past(accept)?.clone(), crate::Expression::Derivative { expr, .. } => past(expr)?.clone(), diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index d23caaf473..e9ca4ee7c7 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -741,7 +741,11 @@ impl FunctionInfo { requirements: UniformityRequirements::empty(), }, E::SubgroupBallotResult => Uniformity { - non_uniform_result: None, + non_uniform_result: None, // FIXME + requirements: UniformityRequirements::empty(), + }, + E::SubgroupOperationResult { ty } => Uniformity { + non_uniform_result: None, // FIXME requirements: UniformityRequirements::empty(), }, }; @@ -988,6 +992,26 @@ impl FunctionInfo { FunctionUniformity::new() } S::SubgroupBallot { result: _ } => FunctionUniformity::new(), + S::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result: _, + } => { + let _ = self.add_ref(argument); + FunctionUniformity::new() + } + S::SubgroupBroadcast { + ref mode, + argument, + result, + } => { + let _ = self.add_ref(argument); + if let crate::BroadcastMode::Index(expr) = *mode { + let _ = self.add_ref(expr); + } + FunctionUniformity::new() + } }; disruptor = disruptor.or(uniformity.exit_disruptor()); diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 890a4b9973..f840bba42d 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1538,6 +1538,7 @@ impl super::Validator { } }, E::SubgroupBallotResult => ShaderStages::COMPUTE | ShaderStages::FRAGMENT, + E::SubgroupOperationResult { ty } => ShaderStages::COMPUTE, // FIXME }; Ok(stages) } diff --git a/src/valid/function.rs b/src/valid/function.rs index 52f51a2810..090fffd4c6 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -51,6 +51,15 @@ pub enum AtomicError { ResultTypeMismatch(Handle), } +#[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum SubgroupError { + #[error("Operand {0:?} has invalid type.")] + InvalidOperand(Handle), + #[error("Result type for {0:?} doesn't match the statement")] + ResultTypeMismatch(Handle), +} + #[derive(Clone, Debug, thiserror::Error)] #[cfg_attr(test, derive(PartialEq))] pub enum LocalVariableError { @@ -159,6 +168,8 @@ pub enum FunctionError { WorkgroupUniformLoadExpressionMismatch(Handle), #[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. It should be a Pointer in Workgroup address space")] WorkgroupUniformLoadInvalidPointer(Handle), + #[error("Subgroup operation is invalid")] + InvalidSubgroup(#[from] SubgroupError), } bitflags::bitflags! { @@ -413,6 +424,102 @@ impl super::Validator { } Ok(()) } + #[cfg(feature = "validate")] + fn validate_subgroup_operation( + &mut self, + op: &crate::SubgroupOperation, + _collective_op: &crate::CollectiveOperation, + argument: Handle, + result: Handle, + context: &BlockContext, + ) -> Result<(), WithSpan> { + self.emit_expression(argument, context)?; + let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; + + let (is_scalar, kind) = match argument_inner { + crate::TypeInner::Scalar { kind, .. } => (true, kind), + crate::TypeInner::Vector { kind, .. } => (false, kind), + _ => unimplemented!(), + }; + + use crate::ScalarKind as sk; + use crate::SubgroupOperation as sg; + match (kind, op) { + (sk::Bool, sg::All | sg::Any) if is_scalar => {} + (sk::Sint | sk::Uint | sk::Float, sg::Add | sg::Mul | sg::Min | sg::Max) => {} + (sk::Sint | sk::Uint | sk::Bool, sg::And | sg::Or | sg::Xor) => {} + + (_, sg::All | sg::Any) + | (sk::Bool, sg::Add | sg::Mul | sg::Min | sg::Max) + | (sk::Float, sg::And | sg::Or | sg::Xor) => { + log::error!("Subgroup operand type {:?}", argument_inner); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); + } + }; + + self.emit_expression(result, context)?; + match context.expressions[result] { + crate::Expression::SubgroupOperationResult { ty } + if { &context.types[ty].inner == argument_inner } => {} + _ => { + return Err(SubgroupError::ResultTypeMismatch(result) + .with_span_handle(result, context.expressions) + .into_other()) + } + } + Ok(()) + } + #[cfg(feature = "validate")] + fn validate_subgroup_broadcast( + &mut self, + mode: &crate::BroadcastMode, + argument: Handle, + result: Handle, + context: &BlockContext, + ) -> Result<(), WithSpan> { + if let crate::BroadcastMode::Index(expr) = *mode { + self.emit_expression(expr, context)?; + let index_ty = context.resolve_type(expr, &self.valid_expression_set)?; + match index_ty { + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + .. + } => {} + _ => { + log::error!("Subgroup broadcast index type {:?}", index_ty); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); + } + } + } + self.emit_expression(argument, context)?; + let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; + + match argument_inner { + crate::TypeInner::Scalar { .. } | crate::TypeInner::Vector { .. } => {} + _ => { + log::error!("Subgroup operand type {:?}", argument_inner); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); + } + } + + self.emit_expression(result, context)?; + match context.expressions[result] { + crate::Expression::SubgroupOperationResult { ty } + if { &context.types[ty].inner == argument_inner } => {} + _ => { + return Err(SubgroupError::ResultTypeMismatch(result) + .with_span_handle(result, context.expressions) + .into_other()) + } + } + Ok(()) + } #[cfg(feature = "validate")] fn validate_block_impl( @@ -922,6 +1029,21 @@ impl super::Validator { S::SubgroupBallot { result } => { self.emit_expression(result, context)?; } + S::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + self.validate_subgroup_operation(op, collective_op, argument, result, context)?; + } + S::SubgroupBroadcast { + ref mode, + argument, + result, + } => { + self.validate_subgroup_broadcast(mode, argument, result, context)?; + } } } Ok(BlockInfo { stages, finished }) diff --git a/src/valid/handles.rs b/src/valid/handles.rs index 547dfac551..399665342f 100644 --- a/src/valid/handles.rs +++ b/src/valid/handles.rs @@ -395,6 +395,7 @@ impl super::Validator { crate::Expression::AtomicResult { .. } | crate::Expression::RayQueryProceedResult | crate::Expression::SubgroupBallotResult + | crate::Expression::SubgroupOperationResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } => (), crate::Expression::ArrayLength(array) => { handle.check_dep(array)?; @@ -544,6 +545,27 @@ impl super::Validator { validate_expr(result)?; Ok(()) } + crate::Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + } => { + validate_expr(argument)?; + validate_expr(result)?; + Ok(()) + } + crate::Statement::SubgroupBroadcast { + mode, + argument, + result, + } => { + if let crate::BroadcastMode::Index(expr) = mode { + validate_expr(expr)?; + } + validate_expr(result)?; + Ok(()) + } crate::Statement::Break | crate::Statement::Continue | crate::Statement::Kill From 6c6c4beab18b596ec550a743524ba0c89aff206d Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Thu, 5 Oct 2023 09:43:29 -0400 Subject: [PATCH 08/17] subgroup: fix doc error on SubgroupBroadcast --- src/lib.rs | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 571c926e1d..ca7cca8fc0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1897,34 +1897,17 @@ pub enum Statement { result: Handle, }, - // subgroupBroadcast(value, lane) -> value - // subgroupBroadcastFirst(value) -> value SubgroupBroadcast { /// Specifies which thread to broadcast from mode: BroadcastMode, /// The value to broadcast over argument: Handle, - /// The [`SubgroupBroadcastResult`] expression representing this load's result. + /// The [`SubgroupOperationResult`] expression representing this load's result. /// - /// [`SubgroupBroadcastResult`]: Expression::SubgroupBroadcastResult + /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult result: Handle, }, - // Reduction on bool - // subgroupAll(bool) -> bool - // subgroupAny(bool) -> bool - // Reduction on float, int - // subgroupMin(value) -> value - // subgroupMax(value) -> value - // subgroupAdd(value) -> value - // subgroupMul(value) -> value - // Reduction on int - // subgroupAnd(value) -> value - // subgroupOr(value) -> value - // subgroupXor(value) -> value - // Scan on float, int - // subgroupPrefixAdd(value) -> value - // subgroupPrefixMul(value) -> value /// Compute a collective operation across all active threads in th subgroup SubgroupCollectiveOperation { /// What operation to compute From 77c4e12bb10673abd34d8e776ae5595e345aa48b Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Thu, 5 Oct 2023 11:46:55 -0400 Subject: [PATCH 09/17] subgroup: wgsl-in and spv-out for subgroup operations --- src/back/spv/block.rs | 4 +- src/back/spv/instructions.rs | 40 ++++++++++++- src/back/spv/subgroup.rs | 84 ++++++++++++++++++++++----- src/front/wgsl/lower/mod.rs | 108 +++++++++++++++++++++++------------ src/front/wgsl/parse/conv.rs | 21 +++++++ src/valid/function.rs | 12 ++-- 6 files changed, 210 insertions(+), 59 deletions(-) diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index fa799b12c4..b2366f3b16 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -2368,14 +2368,14 @@ impl<'w> BlockContext<'w> { argument, result, } => { - self.write_subgroup_operation(op, collective_op, argument, result, &mut block); + self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?; } crate::Statement::SubgroupBroadcast { ref mode, argument, result, } => { - unimplemented!() // FIXME + self.write_subgroup_broadcast(mode, argument, result, &mut block)?; } } } diff --git a/src/back/spv/instructions.rs b/src/back/spv/instructions.rs index 725014dee4..8a528065ef 100644 --- a/src/back/spv/instructions.rs +++ b/src/back/spv/instructions.rs @@ -1054,19 +1054,55 @@ impl super::Instruction { instruction } + pub(super) fn group_non_uniform_broadcast( + result_type_id: Word, + id: Word, + exec_scope_id: Word, + value: Word, + index: Word, + ) -> Self { + let mut instruction = Self::new(Op::GroupNonUniformBroadcast); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(value); + instruction.add_operand(index); + + instruction + } + pub(super) fn group_non_uniform_broadcast_first( + result_type_id: Word, + id: Word, + exec_scope_id: Word, + value: Word, + ) -> Self { + let mut instruction = Self::new(Op::GroupNonUniformBroadcastFirst); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(value); + + instruction + } pub(super) fn group_non_uniform_arithmetic( op: Op, result_type_id: Word, id: Word, exec_scope_id: Word, - group_op: spirv::GroupOperation, + group_op: Option, value: Word, ) -> Self { + println!( + "{:?}", + (op, result_type_id, id, exec_scope_id, group_op, value) + ); let mut instruction = Self::new(op); instruction.set_type(result_type_id); instruction.set_result(id); instruction.add_operand(exec_scope_id); - instruction.add_operand(group_op as u32); + if let Some(group_op) = group_op { + instruction.add_operand(group_op as u32); + } instruction.add_operand(value); instruction diff --git a/src/back/spv/subgroup.rs b/src/back/spv/subgroup.rs index a4f7b018f4..3c8b7d827a 100644 --- a/src/back/spv/subgroup.rs +++ b/src/back/spv/subgroup.rs @@ -10,20 +10,30 @@ impl<'w> BlockContext<'w> { result: Handle, block: &mut Block, ) -> Result<(), Error> { - self.writer.require_any( - "GroupNonUniformArithmetic", - &[ - spirv::Capability::GroupNonUniformArithmetic, - spirv::Capability::GroupNonUniformClustered, - spirv::Capability::GroupNonUniformPartitionedNV, - ], - )?; + use crate::SubgroupOperation as sg; + match op { + sg::All | sg::Any => { + self.writer.require_any( + "GroupNonUniformVote", + &[spirv::Capability::GroupNonUniformVote], + )?; + } + _ => { + self.writer.require_any( + "GroupNonUniformArithmetic", + &[ + spirv::Capability::GroupNonUniformArithmetic, + spirv::Capability::GroupNonUniformClustered, + spirv::Capability::GroupNonUniformPartitionedNV, + ], + )?; + } + } let id = self.gen_id(); let result_ty = &self.fun_info[result].ty; let result_type_id = self.get_expression_type_id(result_ty); let result_ty_inner = result_ty.inner_with(&self.ir_module.types); - let kind = result_ty_inner.scalar_kind().unwrap(); let (is_scalar, kind) = match result_ty_inner { TypeInner::Scalar { kind, .. } => (true, kind), @@ -32,7 +42,6 @@ impl<'w> BlockContext<'w> { }; use crate::ScalarKind as sk; - use crate::SubgroupOperation as sg; let spirv_op = match (kind, op) { (sk::Bool, sg::All) if is_scalar => spirv::Op::GroupNonUniformAll, (sk::Bool, sg::Any) if is_scalar => spirv::Op::GroupNonUniformAny, @@ -62,10 +71,13 @@ impl<'w> BlockContext<'w> { let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); use crate::CollectiveOperation as c; - let group_op = match collective_op { - c::Reduce => spirv::GroupOperation::Reduce, - c::InclusiveScan => spirv::GroupOperation::InclusiveScan, - c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan, + let group_op = match op { + sg::All | sg::Any => None, + _ => Some(match collective_op { + c::Reduce => spirv::GroupOperation::Reduce, + c::InclusiveScan => spirv::GroupOperation::InclusiveScan, + c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan, + }), }; let arg_id = self.cached[argument]; @@ -80,4 +92,48 @@ impl<'w> BlockContext<'w> { self.cached[result] = id; Ok(()) } + pub(super) fn write_subgroup_broadcast( + &mut self, + mode: &crate::BroadcastMode, + argument: Handle, + result: Handle, + block: &mut Block, + ) -> Result<(), Error> { + self.writer.require_any( + "GroupNonUniformBallot", + &[spirv::Capability::GroupNonUniformBallot], + )?; + + let id = self.gen_id(); + let result_ty = &self.fun_info[result].ty; + let result_type_id = self.get_expression_type_id(result_ty); + + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + + let arg_id = self.cached[argument]; + match mode { + crate::BroadcastMode::Index(index) => { + let index_id = self.cached[*index]; + block.body.push(Instruction::group_non_uniform_broadcast( + result_type_id, + id, + exec_scope_id, + arg_id, + index_id, + )); + } + crate::BroadcastMode::First => { + block + .body + .push(Instruction::group_non_uniform_broadcast_first( + result_type_id, + id, + exec_scope_id, + arg_id, + )); + } + } + self.cached[result] = id; + Ok(()) + } } diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index 0443e46cec..0d1d7f31cd 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -1917,6 +1917,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } } else if let Some(fun) = Texture::map(function.name) { self.texture_sample_helper(fun, arguments, span, ctx.reborrow())? + } else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) { + return Ok(Some(self.subgroup_helper(span, op, cop, arguments, ctx)?)); } else { match function.name { "select" => { @@ -2309,43 +2311,51 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { return Ok(Some(result)); } "subgroupBroadcast" => { - unimplemented!(); // FIXME + let mut args = ctx.prepare_args(arguments, 2, span); + + let index = self.expression(args.next()?, ctx.reborrow())?; + let argument = self.expression(args.next()?, ctx.reborrow())?; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = ctx.interrupt_emitter( + crate::Expression::SubgroupOperationResult { ty }, + span, + ); + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupBroadcast { + mode: crate::BroadcastMode::Index(index), + argument, + result, + }, + span, + ); + return Ok(Some(result)); } "subgroupBroadcastFirst" => { - unimplemented!(); // FIXME - } - "subgroupAll" => { - unimplemented!(); // FIXME - } - "subgroupAny" => { - unimplemented!(); // FIXME - } - "subgroupAdd" => { - unimplemented!(); // FIXME - } - "subgroupMul" => { - unimplemented!(); // FIXME - } - "subgroupMin" => { - unimplemented!(); // FIXME - } - "subgroupMax" => { - unimplemented!(); // FIXME - } - "subgroupAnd" => { - unimplemented!(); // FIXME - } - "subgroupOr" => { - unimplemented!(); // FIXME - } - "subgroupXor" => { - unimplemented!(); // FIXME - } - "subgroupPrefixAdd" => { - unimplemented!(); // FIXME - } - "subgroupPrefixMul" => { - unimplemented!(); // FIXME + let mut args = ctx.prepare_args(arguments, 1, span); + + let argument = self.expression(args.next()?, ctx.reborrow())?; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = ctx.interrupt_emitter( + crate::Expression::SubgroupOperationResult { ty }, + span, + ); + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupBroadcast { + mode: crate::BroadcastMode::First, + argument, + result, + }, + span, + ); + return Ok(Some(result)); } _ => return Err(Error::UnknownIdent(function.span, function.name)), } @@ -2538,6 +2548,34 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { depth_ref, }) } + fn subgroup_helper( + &mut self, + span: Span, + op: crate::SubgroupOperation, + collective_op: crate::CollectiveOperation, + arguments: &[Handle>], + mut ctx: ExpressionContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + let mut args = ctx.prepare_args(arguments, 1, span); + + let argument = self.expression(args.next()?, ctx.reborrow())?; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span); + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + }, + span, + ); + Ok(result) + } fn r#struct( &mut self, diff --git a/src/front/wgsl/parse/conv.rs b/src/front/wgsl/parse/conv.rs index a27bdb1cbc..160213e4a3 100644 --- a/src/front/wgsl/parse/conv.rs +++ b/src/front/wgsl/parse/conv.rs @@ -238,3 +238,24 @@ pub fn map_conservative_depth( _ => Err(Error::UnknownConservativeDepth(span)), } } + +pub fn map_subgroup_operation( + word: &str, +) -> Option<(crate::SubgroupOperation, crate::CollectiveOperation)> { + use crate::CollectiveOperation as co; + use crate::SubgroupOperation as sg; + Some(match word { + "subgroupAll" => (sg::All, co::Reduce), + "subgroupAny" => (sg::Any, co::Reduce), + "subgroupAdd" => (sg::Add, co::Reduce), + "subgroupMul" => (sg::Mul, co::Reduce), + "subgroupMin" => (sg::Min, co::Reduce), + "subgroupMax" => (sg::Max, co::Reduce), + "subgroupAnd" => (sg::And, co::Reduce), + "subgroupOr" => (sg::Or, co::Reduce), + "subgroupXor" => (sg::Xor, co::Reduce), + "subgroupPrefixAdd" => (sg::Add, co::InclusiveScan), + "subgroupPrefixMul" => (sg::Mul, co::InclusiveScan), + _ => return None, + }) +} diff --git a/src/valid/function.rs b/src/valid/function.rs index 090fffd4c6..171e6a747b 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -433,7 +433,6 @@ impl super::Validator { result: Handle, context: &BlockContext, ) -> Result<(), WithSpan> { - self.emit_expression(argument, context)?; let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; let (is_scalar, kind) = match argument_inner { @@ -480,7 +479,6 @@ impl super::Validator { context: &BlockContext, ) -> Result<(), WithSpan> { if let crate::BroadcastMode::Index(expr) = *mode { - self.emit_expression(expr, context)?; let index_ty = context.resolve_type(expr, &self.valid_expression_set)?; match index_ty { crate::TypeInner::Scalar { @@ -488,20 +486,22 @@ impl super::Validator { .. } => {} _ => { - log::error!("Subgroup broadcast index type {:?}", index_ty); + log::error!( + "Subgroup broadcast index type {:?}, expected unsigned int", + index_ty + ); return Err(SubgroupError::InvalidOperand(argument) - .with_span_handle(argument, context.expressions) + .with_span_handle(expr, context.expressions) .into_other()); } } } - self.emit_expression(argument, context)?; let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; match argument_inner { crate::TypeInner::Scalar { .. } | crate::TypeInner::Vector { .. } => {} _ => { - log::error!("Subgroup operand type {:?}", argument_inner); + log::error!("Subgroup broadcast operand type {:?}", argument_inner); return Err(SubgroupError::InvalidOperand(argument) .with_span_handle(argument, context.expressions) .into_other()); From 4cc1475a70f1abbf33a6d2951f58d1045284d1e8 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Thu, 5 Oct 2023 13:06:47 -0400 Subject: [PATCH 10/17] subgroup: add optional predicate for subgroupBallot --- src/back/dot/mod.rs | 5 ++++- src/back/glsl/mod.rs | 9 +++++++-- src/back/hlsl/writer.rs | 9 +++++++-- src/back/msl/writer.rs | 11 +++++++++-- src/back/spv/block.rs | 7 +++++-- src/back/wgsl/writer.rs | 8 ++++++-- src/compact/statements.rs | 15 +++++++++++++-- src/front/wgsl/lower/mod.rs | 10 ++++++++-- src/lib.rs | 2 ++ src/valid/analyzer.rs | 12 ++++++++++-- src/valid/function.rs | 21 ++++++++++++++++++++- src/valid/handles.rs | 3 ++- 12 files changed, 93 insertions(+), 19 deletions(-) diff --git a/src/back/dot/mod.rs b/src/back/dot/mod.rs index 5ba3ffe49b..fec2c60d32 100644 --- a/src/back/dot/mod.rs +++ b/src/back/dot/mod.rs @@ -279,7 +279,10 @@ impl StatementGraph { crate::RayQueryFunction::Terminate => "RayQueryTerminate", } } - S::SubgroupBallot { result } => { + S::SubgroupBallot { result, predicate } => { + if let Some(predicate) = predicate { + self.dependencies.push((id, predicate, "predicate")); + } self.emits.push((id, result)); "SubgroupBallot" } diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 76bf2c6597..c3a80c652d 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -2250,7 +2250,7 @@ impl<'a, W: Write> Writer<'a, W> { writeln!(self.out, ");")?; } Statement::RayQuery { .. } => unreachable!(), - Statement::SubgroupBallot { result } => { + Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); let res_ty = ctx.info[result].ty.inner_with(&self.module.types); @@ -2258,7 +2258,12 @@ impl<'a, W: Write> Writer<'a, W> { write!(self.out, " {res_name} = ")?; self.named_expressions.insert(result, res_name); - writeln!(self.out, "subgroupBallot(true);")?; + write!(self.out, "subgroupBallot(")?; + match predicate { + Some(predicate) => self.write_expr(predicate, ctx)?, + None => write!(self.out, "true")?, + } + write!(self.out, ");")?; } Statement::SubgroupCollectiveOperation { ref op, diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index b14d0c4bde..98b60f4954 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -2004,14 +2004,19 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, "{level}}}")? } Statement::RayQuery { .. } => unreachable!(), - Statement::SubgroupBallot { result } => { + Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let name = format!("{}{}", back::BAKE_PREFIX, result.index()); write!(self.out, "const uint4 {name} = ")?; self.named_expressions.insert(result, name); - writeln!(self.out, "WaveActiveBallot(true);")?; + write!(self.out, "WaveActiveBallot(")?; + match predicate { + Some(predicate) => self.write_expr(module, predicate, func_ctx)?, + None => write!(self.out, "true")?, + } + writeln!(self.out, ");")?; } Statement::SubgroupCollectiveOperation { ref op, diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index cc4baf70ff..30cf02ddc4 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -3012,12 +3012,19 @@ impl Writer { } } } - crate::Statement::SubgroupBallot { result } => { + crate::Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let name = self.namer.call(""); self.start_baking_expression(result, &context.expression, &name)?; self.named_expressions.insert(result, name); - write!(self.out, "{NAMESPACE}::simd_active_threads_mask();")?; + write!(self.out, "{NAMESPACE}::simd_ballot(;")?; + match predicate { + Some(predicate) => { + self.put_expression(predicate, &context.expression, true)? + } + None => write!(self.out, "true")?, + } + writeln!(self.out, ");")?; } crate::Statement::SubgroupCollectiveOperation { ref op, diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index b2366f3b16..222d0cde39 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -2340,7 +2340,7 @@ impl<'w> BlockContext<'w> { crate::Statement::RayQuery { query, ref fun } => { self.write_ray_query_function(query, fun, &mut block); } - crate::Statement::SubgroupBallot { result } => { + crate::Statement::SubgroupBallot { result, predicate } => { self.writer.require_any( "GroupNonUniformBallot", &[spirv::Capability::GroupNonUniformBallot], @@ -2352,7 +2352,10 @@ impl<'w> BlockContext<'w> { pointer_space: None, })); let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); - let predicate = self.writer.get_constant_scalar(crate::Literal::Bool(true)); + let predicate = match predicate { + Some(predicate) => self.cached[predicate], + None => self.writer.get_constant_scalar(crate::Literal::Bool(true)), + }; let id = self.gen_id(); block.body.push(Instruction::group_non_uniform_ballot( vec4_u32_type_id, diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index b66863f1df..91718d4f4d 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -921,13 +921,17 @@ impl Writer { } } Statement::RayQuery { .. } => unreachable!(), - Statement::SubgroupBallot { result } => { + Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); self.start_named_expr(module, result, func_ctx, &res_name)?; self.named_expressions.insert(result, res_name); - writeln!(self.out, "subgroupBallot();")?; + writeln!(self.out, "subgroupBallot(")?; + if let Some(predicate) = predicate { + self.write_expr(module, predicate, func_ctx)?; + } + writeln!(self.out, ");")?; } Statement::SubgroupCollectiveOperation { ref op, diff --git a/src/compact/statements.rs b/src/compact/statements.rs index 462553b9d6..0e8c0e4e81 100644 --- a/src/compact/statements.rs +++ b/src/compact/statements.rs @@ -95,7 +95,10 @@ impl FunctionTracer<'_> { self.trace_expression(query); self.trace_ray_query_function(fun); } - St::SubgroupBallot { result } => { + St::SubgroupBallot { result, predicate } => { + if let Some(predicate) = predicate { + self.trace_expression(predicate); + } self.trace_expression(result); } St::SubgroupCollectiveOperation { @@ -267,7 +270,15 @@ impl FunctionMap { adjust(query); self.adjust_ray_query_function(fun); } - St::SubgroupBallot { ref mut result } => adjust(result), + St::SubgroupBallot { + ref mut result, + ref mut predicate, + } => { + if let Some(ref mut predicate) = predicate { + adjust(predicate); + } + adjust(result); + } St::SubgroupCollectiveOperation { ref mut op, ref mut collective_op, diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index 0d1d7f31cd..d0e6ee3b40 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -2301,13 +2301,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { return Ok(Some(handle)); } "subgroupBallot" => { - ctx.prepare_args(arguments, 0, span).finish()?; + let mut args = ctx.prepare_args(arguments, 0, span); + let predicate = if arguments.len() == 1 { + Some(self.expression(args.next()?, ctx.reborrow())?) + } else { + None + }; + args.finish()?; let result = ctx .interrupt_emitter(crate::Expression::SubgroupBallotResult, span); let rctx = ctx.runtime_expression_ctx(span)?; rctx.block - .push(crate::Statement::SubgroupBallot { result }, span); + .push(crate::Statement::SubgroupBallot { result, predicate }, span); return Ok(Some(result)); } "subgroupBroadcast" => { diff --git a/src/lib.rs b/src/lib.rs index ca7cca8fc0..a74e86b33f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1895,6 +1895,8 @@ pub enum Statement { /// /// [`SubgroupBallotResult`]: Expression::SubgroupBallotResult result: Handle, + /// The value from this thread to store in the ballot + predicate: Option>, }, SubgroupBroadcast { diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index e9ca4ee7c7..262e3ffec7 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -991,14 +991,22 @@ impl FunctionInfo { } FunctionUniformity::new() } - S::SubgroupBallot { result: _ } => FunctionUniformity::new(), + S::SubgroupBallot { + result: _, + predicate, + } => { + if let Some(predicate) = predicate { + let _ = self.add_ref(predicate); + } + FunctionUniformity::new() + } S::SubgroupCollectiveOperation { ref op, ref collective_op, argument, result: _, } => { - let _ = self.add_ref(argument); + let _ = self.add_ref(argument); // FIXME FunctionUniformity::new() } S::SubgroupBroadcast { diff --git a/src/valid/function.rs b/src/valid/function.rs index 171e6a747b..91127a40ad 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -1026,7 +1026,26 @@ impl super::Validator { crate::RayQueryFunction::Terminate => {} } } - S::SubgroupBallot { result } => { + S::SubgroupBallot { result, predicate } => { + if let Some(predicate) = predicate { + let predicate_inner = + context.resolve_type(predicate, &self.valid_expression_set)?; + match predicate_inner { + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Bool, + .. + } => {} + _ => { + log::error!( + "Subgroup ballot predicate type {:?} expected bool", + predicate_inner + ); + return Err(SubgroupError::InvalidOperand(predicate) + .with_span_handle(predicate, context.expressions) + .into_other()); + } + } + } self.emit_expression(result, context)?; } S::SubgroupCollectiveOperation { diff --git a/src/valid/handles.rs b/src/valid/handles.rs index 399665342f..91209b460b 100644 --- a/src/valid/handles.rs +++ b/src/valid/handles.rs @@ -541,7 +541,8 @@ impl super::Validator { } Ok(()) } - crate::Statement::SubgroupBallot { result } => { + crate::Statement::SubgroupBallot { result, predicate } => { + validate_expr_opt(predicate)?; validate_expr(result)?; Ok(()) } From 7f44dab87be57539bc91b5dd0986c91de4a58f0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Thu, 12 Oct 2023 01:07:12 +0200 Subject: [PATCH 11/17] Renames SubgroupBroadcast => SubgroupGather and BroadcastMode => GatherMode. --- src/back/dot/mod.rs | 4 ++-- src/back/glsl/mod.rs | 2 +- src/back/hlsl/writer.rs | 2 +- src/back/msl/writer.rs | 2 +- src/back/spv/block.rs | 2 +- src/back/spv/subgroup.rs | 6 +++--- src/back/wgsl/writer.rs | 2 +- src/compact/statements.rs | 8 ++++---- src/front/spv/mod.rs | 2 +- src/front/wgsl/lower/mod.rs | 8 ++++---- src/lib.rs | 14 +++++++------- src/proc/terminator.rs | 2 +- src/valid/analyzer.rs | 4 ++-- src/valid/function.rs | 6 +++--- src/valid/handles.rs | 12 +++++++----- 15 files changed, 39 insertions(+), 37 deletions(-) diff --git a/src/back/dot/mod.rs b/src/back/dot/mod.rs index fec2c60d32..d08eee631a 100644 --- a/src/back/dot/mod.rs +++ b/src/back/dot/mod.rs @@ -296,14 +296,14 @@ impl StatementGraph { self.emits.push((id, result)); "SubgroupCollectiveOperation" // FIXME } - S::SubgroupBroadcast { + S::SubgroupGather { ref mode, argument, result, } => { self.dependencies.push((id, argument, "arg")); self.emits.push((id, result)); - "SubgroupBroadcast" // FIXME + "SubgroupGather" // FIXME } }; // Set the last node to the merge node diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index c3a80c652d..77c7e414c4 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -2273,7 +2273,7 @@ impl<'a, W: Write> Writer<'a, W> { } => { unimplemented!(); // FIXME: } - Statement::SubgroupBroadcast { + Statement::SubgroupGather { ref mode, argument, result, diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 98b60f4954..222a0bf02c 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -2026,7 +2026,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } => { unimplemented!(); // FIXME } - Statement::SubgroupBroadcast { + Statement::SubgroupGather { ref mode, argument, result, diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 30cf02ddc4..9d1bc84955 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -3034,7 +3034,7 @@ impl Writer { } => { unimplemented!(); // FIXME } - crate::Statement::SubgroupBroadcast { + crate::Statement::SubgroupGather { ref mode, argument, result, diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 222d0cde39..9fba489f79 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -2373,7 +2373,7 @@ impl<'w> BlockContext<'w> { } => { self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?; } - crate::Statement::SubgroupBroadcast { + crate::Statement::SubgroupGather { ref mode, argument, result, diff --git a/src/back/spv/subgroup.rs b/src/back/spv/subgroup.rs index 3c8b7d827a..ca193562b7 100644 --- a/src/back/spv/subgroup.rs +++ b/src/back/spv/subgroup.rs @@ -94,7 +94,7 @@ impl<'w> BlockContext<'w> { } pub(super) fn write_subgroup_broadcast( &mut self, - mode: &crate::BroadcastMode, + mode: &crate::GatherMode, argument: Handle, result: Handle, block: &mut Block, @@ -112,7 +112,7 @@ impl<'w> BlockContext<'w> { let arg_id = self.cached[argument]; match mode { - crate::BroadcastMode::Index(index) => { + crate::GatherMode::Broadcast(index) => { let index_id = self.cached[*index]; block.body.push(Instruction::group_non_uniform_broadcast( result_type_id, @@ -122,7 +122,7 @@ impl<'w> BlockContext<'w> { index_id, )); } - crate::BroadcastMode::First => { + crate::GatherMode::BroadcastFirst => { block .body .push(Instruction::group_non_uniform_broadcast_first( diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 91718d4f4d..c6217ea2a8 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -941,7 +941,7 @@ impl Writer { } => { unimplemented!() // FIXME } - Statement::SubgroupBroadcast { + Statement::SubgroupGather { ref mode, argument, result, diff --git a/src/compact/statements.rs b/src/compact/statements.rs index 0e8c0e4e81..14526bef61 100644 --- a/src/compact/statements.rs +++ b/src/compact/statements.rs @@ -110,12 +110,12 @@ impl FunctionTracer<'_> { self.trace_expression(argument); self.trace_expression(result); } - St::SubgroupBroadcast { + St::SubgroupGather { ref mode, argument, result, } => { - if let crate::BroadcastMode::Index(expr) = *mode { + if let crate::GatherMode::Broadcast(expr) = *mode { self.trace_expression(expr); } self.trace_expression(argument); @@ -288,12 +288,12 @@ impl FunctionMap { adjust(argument); adjust(result); } - St::SubgroupBroadcast { + St::SubgroupGather { ref mut mode, ref mut argument, ref mut result, } => { - if let crate::BroadcastMode::Index(expr) = mode { + if let crate::GatherMode::Broadcast(expr) = mode { adjust(expr); } adjust(argument); diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index 561e32e22d..c78407ef99 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -3844,7 +3844,7 @@ impl> Frontend { S::WorkGroupUniformLoad { .. } => unreachable!(), S::SubgroupBallot { .. } => unreachable!(), // FIXME?? S::SubgroupCollectiveOperation { .. } => unreachable!(), - S::SubgroupBroadcast { .. } => unreachable!(), + S::SubgroupGather { .. } => unreachable!(), } i += 1; } diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index d0e6ee3b40..01fb5005ba 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -2331,8 +2331,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ); let rctx = ctx.runtime_expression_ctx(span)?; rctx.block.push( - crate::Statement::SubgroupBroadcast { - mode: crate::BroadcastMode::Index(index), + crate::Statement::SubgroupGather { + mode: crate::GatherMode::Broadcast(index), argument, result, }, @@ -2354,8 +2354,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ); let rctx = ctx.runtime_expression_ctx(span)?; rctx.block.push( - crate::Statement::SubgroupBroadcast { - mode: crate::BroadcastMode::First, + crate::Statement::SubgroupGather { + mode: crate::GatherMode::BroadcastFirst, argument, result, }, diff --git a/src/lib.rs b/src/lib.rs index a74e86b33f..dee02e49bf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -436,8 +436,8 @@ pub enum BuiltIn { WorkGroupSize, NumWorkGroups, // subgroup - SubgroupInvocationId, SubgroupSize, + SubgroupInvocationId, } /// Number of bytes per scalar. @@ -1265,9 +1265,9 @@ pub enum SwizzleComponent { #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] -pub enum BroadcastMode { - First, - Index(Handle), +pub enum GatherMode { + BroadcastFirst, + Broadcast(Handle), } #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] @@ -1899,9 +1899,9 @@ pub enum Statement { predicate: Option>, }, - SubgroupBroadcast { - /// Specifies which thread to broadcast from - mode: BroadcastMode, + SubgroupGather { + /// Specifies which thread to gather from + mode: GatherMode, /// The value to broadcast over argument: Handle, /// The [`SubgroupOperationResult`] expression representing this load's result. diff --git a/src/proc/terminator.rs b/src/proc/terminator.rs index 35111a11de..5edf55cb73 100644 --- a/src/proc/terminator.rs +++ b/src/proc/terminator.rs @@ -39,7 +39,7 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::WorkGroupUniformLoad { .. } | S::SubgroupBallot { .. } | S::SubgroupCollectiveOperation { .. } - | S::SubgroupBroadcast { .. } + | S::SubgroupGather { .. } | S::Barrier(_)), ) | None => block.push(S::Return { value: None }, Default::default()), diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index 262e3ffec7..deb775aebf 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -1009,13 +1009,13 @@ impl FunctionInfo { let _ = self.add_ref(argument); // FIXME FunctionUniformity::new() } - S::SubgroupBroadcast { + S::SubgroupGather { ref mode, argument, result, } => { let _ = self.add_ref(argument); - if let crate::BroadcastMode::Index(expr) = *mode { + if let crate::GatherMode::Broadcast(expr) = *mode { let _ = self.add_ref(expr); } FunctionUniformity::new() diff --git a/src/valid/function.rs b/src/valid/function.rs index 91127a40ad..03e67f5f18 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -473,12 +473,12 @@ impl super::Validator { #[cfg(feature = "validate")] fn validate_subgroup_broadcast( &mut self, - mode: &crate::BroadcastMode, + mode: &crate::GatherMode, argument: Handle, result: Handle, context: &BlockContext, ) -> Result<(), WithSpan> { - if let crate::BroadcastMode::Index(expr) = *mode { + if let crate::GatherMode::Broadcast(expr) = *mode { let index_ty = context.resolve_type(expr, &self.valid_expression_set)?; match index_ty { crate::TypeInner::Scalar { @@ -1056,7 +1056,7 @@ impl super::Validator { } => { self.validate_subgroup_operation(op, collective_op, argument, result, context)?; } - S::SubgroupBroadcast { + S::SubgroupGather { ref mode, argument, result, diff --git a/src/valid/handles.rs b/src/valid/handles.rs index 91209b460b..6d9923a23b 100644 --- a/src/valid/handles.rs +++ b/src/valid/handles.rs @@ -547,8 +547,8 @@ impl super::Validator { Ok(()) } crate::Statement::SubgroupCollectiveOperation { - op, - collective_op, + op: _, + collective_op: _, argument, result, } => { @@ -556,13 +556,15 @@ impl super::Validator { validate_expr(result)?; Ok(()) } - crate::Statement::SubgroupBroadcast { + crate::Statement::SubgroupGather { mode, argument, result, } => { - if let crate::BroadcastMode::Index(expr) = mode { - validate_expr(expr)?; + validate_expr(argument)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) => validate_expr(index)?, } validate_expr(result)?; Ok(()) From 754b7279aabe5757d4afd6289752561e943e71b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Sat, 14 Oct 2023 18:35:27 +0200 Subject: [PATCH 12/17] General fixes. --- src/back/dot/mod.rs | 6 +-- src/back/glsl/mod.rs | 6 +-- src/back/hlsl/writer.rs | 6 +-- src/back/msl/writer.rs | 6 +-- src/back/wgsl/writer.rs | 6 +-- src/compact/statements.rs | 22 ++++----- src/front/wgsl/lower/mod.rs | 9 ++-- src/proc/constant_evaluator.rs | 8 ++++ src/valid/analyzer.rs | 14 +++--- src/valid/expression.rs | 2 +- src/valid/function.rs | 81 ++++++++++++++++++---------------- 11 files changed, 92 insertions(+), 74 deletions(-) diff --git a/src/back/dot/mod.rs b/src/back/dot/mod.rs index d08eee631a..3174b7c6b6 100644 --- a/src/back/dot/mod.rs +++ b/src/back/dot/mod.rs @@ -287,8 +287,8 @@ impl StatementGraph { "SubgroupBallot" } S::SubgroupCollectiveOperation { - ref op, - ref collective_op, + op, + collective_op, argument, result, } => { @@ -297,7 +297,7 @@ impl StatementGraph { "SubgroupCollectiveOperation" // FIXME } S::SubgroupGather { - ref mode, + mode, argument, result, } => { diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 77c7e414c4..1b51c7cf8b 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -2266,15 +2266,15 @@ impl<'a, W: Write> Writer<'a, W> { write!(self.out, ");")?; } Statement::SubgroupCollectiveOperation { - ref op, - ref collective_op, + op, + collective_op, argument, result, } => { unimplemented!(); // FIXME: } Statement::SubgroupGather { - ref mode, + mode, argument, result, } => { diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 222a0bf02c..a04fef7a1b 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -2019,15 +2019,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, ");")?; } Statement::SubgroupCollectiveOperation { - ref op, - ref collective_op, + op, + collective_op, argument, result, } => { unimplemented!(); // FIXME } Statement::SubgroupGather { - ref mode, + mode, argument, result, } => { diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 9d1bc84955..629b98c96d 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -3027,15 +3027,15 @@ impl Writer { writeln!(self.out, ");")?; } crate::Statement::SubgroupCollectiveOperation { - ref op, - ref collective_op, + op, + collective_op, argument, result, } => { unimplemented!(); // FIXME } crate::Statement::SubgroupGather { - ref mode, + mode, argument, result, } => { diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index c6217ea2a8..a2b2497c0c 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -934,15 +934,15 @@ impl Writer { writeln!(self.out, ");")?; } Statement::SubgroupCollectiveOperation { - ref op, - ref collective_op, + op, + collective_op, argument, result, } => { unimplemented!() // FIXME } Statement::SubgroupGather { - ref mode, + mode, argument, result, } => { diff --git a/src/compact/statements.rs b/src/compact/statements.rs index 14526bef61..04a184daf8 100644 --- a/src/compact/statements.rs +++ b/src/compact/statements.rs @@ -102,8 +102,8 @@ impl FunctionTracer<'_> { self.trace_expression(result); } St::SubgroupCollectiveOperation { - ref op, - ref collective_op, + op: _, + collective_op: _, argument, result, } => { @@ -111,12 +111,13 @@ impl FunctionTracer<'_> { self.trace_expression(result); } St::SubgroupGather { - ref mode, + mode, argument, result, } => { - if let crate::GatherMode::Broadcast(expr) = *mode { - self.trace_expression(expr); + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) => self.trace_expression(index), } self.trace_expression(argument); self.trace_expression(result); @@ -274,14 +275,14 @@ impl FunctionMap { ref mut result, ref mut predicate, } => { - if let Some(ref mut predicate) = predicate { + if let Some(ref mut predicate) = *predicate { adjust(predicate); } adjust(result); } St::SubgroupCollectiveOperation { - ref mut op, - ref mut collective_op, + op: _, + collective_op: _, ref mut argument, ref mut result, } => { @@ -293,8 +294,9 @@ impl FunctionMap { ref mut argument, ref mut result, } => { - if let crate::GatherMode::Broadcast(expr) = mode { - adjust(expr); + match *mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(ref mut index) => adjust(index), } adjust(argument); adjust(result); diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index 01fb5005ba..127008d437 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -2310,7 +2310,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { args.finish()?; let result = ctx - .interrupt_emitter(crate::Expression::SubgroupBallotResult, span); + .interrupt_emitter(crate::Expression::SubgroupBallotResult, span)?; let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .push(crate::Statement::SubgroupBallot { result, predicate }, span); @@ -2328,7 +2328,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let result = ctx.interrupt_emitter( crate::Expression::SubgroupOperationResult { ty }, span, - ); + )?; let rctx = ctx.runtime_expression_ctx(span)?; rctx.block.push( crate::Statement::SubgroupGather { @@ -2351,7 +2351,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let result = ctx.interrupt_emitter( crate::Expression::SubgroupOperationResult { ty }, span, - ); + )?; let rctx = ctx.runtime_expression_ctx(span)?; rctx.block.push( crate::Statement::SubgroupGather { @@ -2569,7 +2569,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let ty = ctx.register_type(argument)?; - let result = ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span); + let result = + ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?; let rctx = ctx.runtime_expression_ctx(span)?; rctx.block.push( crate::Statement::SubgroupCollectiveOperation { diff --git a/src/proc/constant_evaluator.rs b/src/proc/constant_evaluator.rs index 2082743975..c945805813 100644 --- a/src/proc/constant_evaluator.rs +++ b/src/proc/constant_evaluator.rs @@ -133,6 +133,8 @@ pub enum ConstantEvaluatorError { ImageExpression, #[error("Constants don't support ray query expressions")] RayQueryExpression, + #[error("Constants don't support subgroup expressions")] + SubgroupExpression, #[error("Cannot access the type")] InvalidAccessBase, #[error("Cannot access at the index")] @@ -439,6 +441,12 @@ impl<'a> ConstantEvaluator<'a> { Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => { Err(ConstantEvaluatorError::RayQueryExpression) } + Expression::SubgroupBallotResult { .. } => { + Err(ConstantEvaluatorError::SubgroupExpression) + } + Expression::SubgroupOperationResult { .. } => { + Err(ConstantEvaluatorError::SubgroupExpression) + } } } diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index deb775aebf..4095f03e41 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -744,7 +744,7 @@ impl FunctionInfo { non_uniform_result: None, // FIXME requirements: UniformityRequirements::empty(), }, - E::SubgroupOperationResult { ty } => Uniformity { + E::SubgroupOperationResult { .. } => Uniformity { non_uniform_result: None, // FIXME requirements: UniformityRequirements::empty(), }, @@ -1001,21 +1001,21 @@ impl FunctionInfo { FunctionUniformity::new() } S::SubgroupCollectiveOperation { - ref op, - ref collective_op, + op: _, + collective_op: _, argument, result: _, } => { - let _ = self.add_ref(argument); // FIXME + let _ = self.add_ref(argument); FunctionUniformity::new() } S::SubgroupGather { - ref mode, + mode, argument, - result, + result: _, } => { let _ = self.add_ref(argument); - if let crate::GatherMode::Broadcast(expr) = *mode { + if let crate::GatherMode::Broadcast(expr) = mode { let _ = self.add_ref(expr); } FunctionUniformity::new() diff --git a/src/valid/expression.rs b/src/valid/expression.rs index f840bba42d..03ad851dbf 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1538,7 +1538,7 @@ impl super::Validator { } }, E::SubgroupBallotResult => ShaderStages::COMPUTE | ShaderStages::FRAGMENT, - E::SubgroupOperationResult { ty } => ShaderStages::COMPUTE, // FIXME + E::SubgroupOperationResult { .. } => ShaderStages::COMPUTE, // FIXME }; Ok(stages) } diff --git a/src/valid/function.rs b/src/valid/function.rs index 03e67f5f18..e88b763258 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -435,15 +435,20 @@ impl super::Validator { ) -> Result<(), WithSpan> { let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; - let (is_scalar, kind) = match argument_inner { + let (is_scalar, kind) = match *argument_inner { crate::TypeInner::Scalar { kind, .. } => (true, kind), crate::TypeInner::Vector { kind, .. } => (false, kind), - _ => unimplemented!(), + _ => { + log::error!("Subgroup operand type {:?}", argument_inner); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); + } }; use crate::ScalarKind as sk; use crate::SubgroupOperation as sg; - match (kind, op) { + match (kind, *op) { (sk::Bool, sg::All | sg::Any) if is_scalar => {} (sk::Sint | sk::Uint | sk::Float, sg::Add | sg::Mul | sg::Min | sg::Max) => {} (sk::Sint | sk::Uint | sk::Bool, sg::And | sg::Or | sg::Xor) => {} @@ -478,34 +483,36 @@ impl super::Validator { result: Handle, context: &BlockContext, ) -> Result<(), WithSpan> { - if let crate::GatherMode::Broadcast(expr) = *mode { - let index_ty = context.resolve_type(expr, &self.valid_expression_set)?; - match index_ty { - crate::TypeInner::Scalar { - kind: crate::ScalarKind::Uint, - .. - } => {} - _ => { - log::error!( - "Subgroup broadcast index type {:?}, expected unsigned int", - index_ty - ); - return Err(SubgroupError::InvalidOperand(argument) - .with_span_handle(expr, context.expressions) - .into_other()); + match *mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) => { + let index_ty = context.resolve_type(index, &self.valid_expression_set)?; + match *index_ty { + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + .. + } => {} + _ => { + log::error!( + "Subgroup gather index type {:?}, expected unsigned int", + index_ty + ); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(index, context.expressions) + .into_other()); + } } } } let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; - - match argument_inner { - crate::TypeInner::Scalar { .. } | crate::TypeInner::Vector { .. } => {} - _ => { - log::error!("Subgroup broadcast operand type {:?}", argument_inner); - return Err(SubgroupError::InvalidOperand(argument) - .with_span_handle(argument, context.expressions) - .into_other()); - } + if !matches!(*argument_inner, + crate::TypeInner::Scalar { kind, .. } | crate::TypeInner::Vector { kind, .. } + if matches!(kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float) + ) { + log::error!("Subgroup gather operand type {:?}", argument_inner); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); } self.emit_expression(result, context)?; @@ -1030,20 +1037,20 @@ impl super::Validator { if let Some(predicate) = predicate { let predicate_inner = context.resolve_type(predicate, &self.valid_expression_set)?; - match predicate_inner { + if !matches!( + *predicate_inner, crate::TypeInner::Scalar { kind: crate::ScalarKind::Bool, .. - } => {} - _ => { - log::error!( - "Subgroup ballot predicate type {:?} expected bool", - predicate_inner - ); - return Err(SubgroupError::InvalidOperand(predicate) - .with_span_handle(predicate, context.expressions) - .into_other()); } + ) { + log::error!( + "Subgroup ballot predicate type {:?} expected bool", + predicate_inner + ); + return Err(SubgroupError::InvalidOperand(predicate) + .with_span_handle(predicate, context.expressions) + .into_other()); } } self.emit_expression(result, context)?; From 359875367c17e4c3d40092e58460ef3b5bdf7026 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Sat, 14 Oct 2023 18:55:14 +0200 Subject: [PATCH 13/17] Adds BuiltIn::NumSubgroups, BuiltIn::SubgroupId. --- src/back/glsl/mod.rs | 1 + src/back/hlsl/conv.rs | 1 + src/back/msl/mod.rs | 1 + src/back/spv/writer.rs | 1 + src/back/wgsl/writer.rs | 1 + src/lib.rs | 2 ++ src/valid/interface.rs | 4 ++-- 7 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 1b51c7cf8b..e6d3648f1d 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -4192,6 +4192,7 @@ const fn glsl_built_in( Bi::WorkGroupSize => "gl_WorkGroupSize", Bi::NumWorkGroups => "gl_NumWorkGroups", // subgroup + Bi::NumSubgroups | Bi::SubgroupId => todo!(), Bi::SubgroupInvocationId => "gl_SubgroupInvocationID", Bi::SubgroupSize => "gl_SubgroupSize", } diff --git a/src/back/hlsl/conv.rs b/src/back/hlsl/conv.rs index 19c4da5e74..3f51278cc1 100644 --- a/src/back/hlsl/conv.rs +++ b/src/back/hlsl/conv.rs @@ -167,6 +167,7 @@ impl crate::BuiltIn { // in `Writer::write_expr`. Self::NumWorkGroups => "SV_GroupID", + Self::NumSubgroups | Self::SubgroupId => todo!(), Self::SubgroupInvocationId | Self::SubgroupSize | Self::BaseInstance diff --git a/src/back/msl/mod.rs b/src/back/msl/mod.rs index 9e23d2a08d..4e0d8489e4 100644 --- a/src/back/msl/mod.rs +++ b/src/back/msl/mod.rs @@ -438,6 +438,7 @@ impl ResolvedBinding { Bi::WorkGroupSize => "dispatch_threads_per_threadgroup", Bi::NumWorkGroups => "threadgroups_per_grid", // subgroup + Bi::NumSubgroups | Bi::SubgroupId => todo!(), Bi::SubgroupInvocationId => "simdgroup_index_in_threadgroup", Bi::SubgroupSize => "simdgroups_per_threadgroup", Bi::CullDistance | Bi::ViewIndex => { diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index 077751d90b..6a55288308 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -1594,6 +1594,7 @@ impl Writer { Bi::WorkGroupSize => BuiltIn::WorkgroupSize, Bi::NumWorkGroups => BuiltIn::NumWorkgroups, // Subgroup + Bi::NumSubgroups | Bi::SubgroupId => todo!(), Bi::SubgroupInvocationId => { self.require_any( "`subgroup_invocation_id` built-in", diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index a2b2497c0c..8fc9881f10 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -1789,6 +1789,7 @@ fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> { Bi::SampleMask => "sample_mask", Bi::PrimitiveIndex => "primitive_index", Bi::ViewIndex => "view_index", + Bi::NumSubgroups | Bi::SubgroupId => todo!(), Bi::SubgroupInvocationId => "subgroup_invocation_id", Bi::SubgroupSize => "subgroup_size", Bi::BaseInstance diff --git a/src/lib.rs b/src/lib.rs index dee02e49bf..93a5d9b62a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -436,6 +436,8 @@ pub enum BuiltIn { WorkGroupSize, NumWorkGroups, // subgroup + NumSubgroups, + SubgroupId, SubgroupSize, SubgroupInvocationId, } diff --git a/src/valid/interface.rs b/src/valid/interface.rs index bf4f397224..4b5f66492c 100644 --- a/src/valid/interface.rs +++ b/src/valid/interface.rs @@ -299,7 +299,7 @@ impl VaryingContext<'_> { width, }, ), - Bi::SubgroupInvocationId => ( + Bi::NumSubgroups | Bi::SubgroupId => ( self.stage == St::Compute && !self.output, *ty_inner == Ti::Scalar { @@ -307,7 +307,7 @@ impl VaryingContext<'_> { width, }, ), - Bi::SubgroupSize => ( + Bi::SubgroupSize | Bi::SubgroupInvocationId => ( match self.stage { St::Compute | St::Fragment => !self.output, St::Vertex => false, From 53120910dac34679f061740721308148e6a9b069 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Sat, 14 Oct 2023 18:55:14 +0200 Subject: [PATCH 14/17] Adds GatherMode::Shuffle, GatherMode::ShuffleDown, GatherMode::ShuffleUp, GatherMode::ShuffleXor. --- src/back/spv/subgroup.rs | 4 ++++ src/compact/statements.rs | 12 ++++++++++-- src/lib.rs | 4 ++++ src/valid/analyzer.rs | 11 +++++++++-- src/valid/function.rs | 6 +++++- src/valid/handles.rs | 6 +++++- 6 files changed, 37 insertions(+), 6 deletions(-) diff --git a/src/back/spv/subgroup.rs b/src/back/spv/subgroup.rs index ca193562b7..7206ec9312 100644 --- a/src/back/spv/subgroup.rs +++ b/src/back/spv/subgroup.rs @@ -132,6 +132,10 @@ impl<'w> BlockContext<'w> { arg_id, )); } + crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => todo!(), } self.cached[result] = id; Ok(()) diff --git a/src/compact/statements.rs b/src/compact/statements.rs index 04a184daf8..37074c4299 100644 --- a/src/compact/statements.rs +++ b/src/compact/statements.rs @@ -117,7 +117,11 @@ impl FunctionTracer<'_> { } => { match mode { crate::GatherMode::BroadcastFirst => {} - crate::GatherMode::Broadcast(index) => self.trace_expression(index), + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => self.trace_expression(index), } self.trace_expression(argument); self.trace_expression(result); @@ -296,7 +300,11 @@ impl FunctionMap { } => { match *mode { crate::GatherMode::BroadcastFirst => {} - crate::GatherMode::Broadcast(ref mut index) => adjust(index), + crate::GatherMode::Broadcast(ref mut index) + | crate::GatherMode::Shuffle(ref mut index) + | crate::GatherMode::ShuffleDown(ref mut index) + | crate::GatherMode::ShuffleUp(ref mut index) + | crate::GatherMode::ShuffleXor(ref mut index) => adjust(index), } adjust(argument); adjust(result); diff --git a/src/lib.rs b/src/lib.rs index 93a5d9b62a..c754405ed7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1270,6 +1270,10 @@ pub enum SwizzleComponent { pub enum GatherMode { BroadcastFirst, Broadcast(Handle), + Shuffle(Handle), + ShuffleDown(Handle), + ShuffleUp(Handle), + ShuffleXor(Handle), } #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index 4095f03e41..f4347df1dd 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -1015,8 +1015,15 @@ impl FunctionInfo { result: _, } => { let _ = self.add_ref(argument); - if let crate::GatherMode::Broadcast(expr) = mode { - let _ = self.add_ref(expr); + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + let _ = self.add_ref(index); + } } FunctionUniformity::new() } diff --git a/src/valid/function.rs b/src/valid/function.rs index e88b763258..729a6405c2 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -485,7 +485,11 @@ impl super::Validator { ) -> Result<(), WithSpan> { match *mode { crate::GatherMode::BroadcastFirst => {} - crate::GatherMode::Broadcast(index) => { + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { let index_ty = context.resolve_type(index, &self.valid_expression_set)?; match *index_ty { crate::TypeInner::Scalar { diff --git a/src/valid/handles.rs b/src/valid/handles.rs index 6d9923a23b..e1674f0804 100644 --- a/src/valid/handles.rs +++ b/src/valid/handles.rs @@ -564,7 +564,11 @@ impl super::Validator { validate_expr(argument)?; match mode { crate::GatherMode::BroadcastFirst => {} - crate::GatherMode::Broadcast(index) => validate_expr(index)?, + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => validate_expr(index)?, } validate_expr(result)?; Ok(()) From ab577c6586c5b40ed84b5a6e72af0dc330a4ab4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Sat, 14 Oct 2023 19:35:39 +0200 Subject: [PATCH 15/17] Implements all frontends and backends. --- src/back/dot/mod.rs | 66 ++++++++++- src/back/glsl/features.rs | 23 ++++ src/back/glsl/mod.rs | 102 ++++++++++++++++- src/back/hlsl/conv.rs | 14 +-- src/back/hlsl/writer.rs | 177 +++++++++++++++++++++++++++-- src/back/msl/mod.rs | 7 +- src/back/msl/writer.rs | 109 ++++++++++++++++-- src/back/spv/block.rs | 31 +---- src/back/spv/instructions.rs | 17 ++- src/back/spv/subgroup.rs | 111 ++++++++++++++---- src/back/spv/writer.rs | 26 ++++- src/back/wgsl/writer.rs | 98 +++++++++++++++- src/front/spv/error.rs | 6 +- src/front/spv/mod.rs | 213 ++++++++++++++++++++++++++++++++++- src/front/wgsl/lower/mod.rs | 98 ++++++++-------- src/front/wgsl/parse/conv.rs | 25 +++- 16 files changed, 959 insertions(+), 164 deletions(-) diff --git a/src/back/dot/mod.rs b/src/back/dot/mod.rs index 3174b7c6b6..86f4797b56 100644 --- a/src/back/dot/mod.rs +++ b/src/back/dot/mod.rs @@ -294,16 +294,78 @@ impl StatementGraph { } => { self.dependencies.push((id, argument, "arg")); self.emits.push((id, result)); - "SubgroupCollectiveOperation" // FIXME + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + "SubgroupAll" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + "SubgroupAny" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + "SubgroupAdd" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + "SubgroupMul" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + "SubgroupMax" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + "SubgroupMin" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + "SubgroupAnd" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + "SubgroupOr" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + "SubgroupXor" + } + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Add, + ) => "SubgroupPrefixExclusiveAdd", + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Mul, + ) => "SubgroupPrefixExclusiveMul", + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Add, + ) => "SubgroupPrefixInclusiveAdd", + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Mul, + ) => "SubgroupPrefixInclusiveMul", + _ => unimplemented!(), + } } S::SubgroupGather { mode, argument, result, } => { + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + self.dependencies.push((id, index, "index")) + } + } self.dependencies.push((id, argument, "arg")); self.emits.push((id, result)); - "SubgroupGather" // FIXME + match mode { + crate::GatherMode::BroadcastFirst => "SubgroupBroadcastFirst", + crate::GatherMode::Broadcast(_) => "SubgroupBroadcast", + crate::GatherMode::Shuffle(_) => "SubgroupShuffle", + crate::GatherMode::ShuffleDown(_) => "SubgroupShuffleDown", + crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp", + crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor", + } } }; // Set the last node to the merge node diff --git a/src/back/glsl/features.rs b/src/back/glsl/features.rs index b1fff4d4bc..483a3a9348 100644 --- a/src/back/glsl/features.rs +++ b/src/back/glsl/features.rs @@ -43,6 +43,8 @@ bitflags::bitflags! { const IMAGE_SIZE = 1 << 20; /// Dual source blending const DUAL_SOURCE_BLENDING = 1 << 21; + /// Subgroup operations + const SUBGROUP_OPERATIONS = 1 << 22; } } @@ -106,6 +108,7 @@ impl FeaturesManager { check_feature!(SAMPLE_VARIABLES, 400, 300); check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310); check_feature!(DUAL_SOURCE_BLENDING, 330, 300 /* with extension */); + check_feature!(SUBGROUP_OPERATIONS, 430, 310); match version { Version::Embedded { is_webgl: true, .. } => check_feature!(MULTI_VIEW, 140, 300), _ => check_feature!(MULTI_VIEW, 140, 310), @@ -235,6 +238,22 @@ impl FeaturesManager { writeln!(out, "#extension GL_EXT_blend_func_extended : require")?; } + if self.0.contains(Features::SUBGROUP_OPERATIONS) { + // https://registry.khronos.org/OpenGL/extensions/KHR/KHR_shader_subgroup.txt + writeln!(out, "#extension GL_KHR_shader_subgroup_basic : require")?; + writeln!(out, "#extension GL_KHR_shader_subgroup_vote : require")?; + writeln!( + out, + "#extension GL_KHR_shader_subgroup_arithmetic : require" + )?; + writeln!(out, "#extension GL_KHR_shader_subgroup_ballot : require")?; + writeln!(out, "#extension GL_KHR_shader_subgroup_shuffle : require")?; + writeln!( + out, + "#extension GL_KHR_shader_subgroup_shuffle_relative : require" + )?; + } + Ok(()) } } @@ -455,6 +474,10 @@ impl<'a, W> Writer<'a, W> { } } } + Expression::SubgroupBallotResult | + Expression::SubgroupOperationResult { .. } => { + features.request(Features::SUBGROUP_OPERATIONS) + } _ => {} } } diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index e6d3648f1d..03a06890f8 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -2263,7 +2263,7 @@ impl<'a, W: Write> Writer<'a, W> { Some(predicate) => self.write_expr(predicate, ctx)?, None => write!(self.out, "true")?, } - write!(self.out, ");")?; + writeln!(self.out, ");")?; } Statement::SubgroupCollectiveOperation { op, @@ -2271,14 +2271,103 @@ impl<'a, W: Write> Writer<'a, W> { argument, result, } => { - unimplemented!(); // FIXME: + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "subgroupAll(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "subgroupAny(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupAdd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupMul(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "subgroupMax(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "subgroupMin(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "subgroupAnd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "subgroupOr(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "subgroupXor(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupExclusiveAdd(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupExclusiveMul(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupInclusiveAdd(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupInclusiveMul(")? + } + _ => unimplemented!(), + } + self.write_expr(argument, ctx)?; + writeln!(self.out, ");")?; } Statement::SubgroupGather { mode, argument, result, } => { - unimplemented!(); // FIXME + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "subgroupBroadcastFirst(")?; + } + crate::GatherMode::Broadcast(_) => { + write!(self.out, "subgroupBroadcast(")?; + } + crate::GatherMode::Shuffle(_) => { + write!(self.out, "subgroupShuffle(")?; + } + crate::GatherMode::ShuffleDown(_) => { + write!(self.out, "subgroupShuffleDown(")?; + } + crate::GatherMode::ShuffleUp(_) => { + write!(self.out, "subgroupShuffleUp(")?; + } + crate::GatherMode::ShuffleXor(_) => { + write!(self.out, "subgroupShuffleXor(")?; + } + } + self.write_expr(argument, ctx)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + write!(self.out, ", ")?; + self.write_expr(index, ctx)?; + } + } + writeln!(self.out, ");")?; } } @@ -4013,7 +4102,7 @@ impl<'a, W: Write> Writer<'a, W> { writeln!(self.out, "{level}memoryBarrierShared();")?; } if flags.contains(crate::Barrier::SUB_GROUP) { - unimplemented!() // FIXME + writeln!(self.out, "{level}subgroupMemoryBarrier();")?; } writeln!(self.out, "{level}barrier();")?; Ok(()) @@ -4192,9 +4281,10 @@ const fn glsl_built_in( Bi::WorkGroupSize => "gl_WorkGroupSize", Bi::NumWorkGroups => "gl_NumWorkGroups", // subgroup - Bi::NumSubgroups | Bi::SubgroupId => todo!(), - Bi::SubgroupInvocationId => "gl_SubgroupInvocationID", + Bi::NumSubgroups => "gl_NumSubgroups", + Bi::SubgroupId => "gl_SubgroupID", Bi::SubgroupSize => "gl_SubgroupSize", + Bi::SubgroupInvocationId => "gl_SubgroupInvocationID", } } diff --git a/src/back/hlsl/conv.rs b/src/back/hlsl/conv.rs index 3f51278cc1..d3fb76e401 100644 --- a/src/back/hlsl/conv.rs +++ b/src/back/hlsl/conv.rs @@ -166,13 +166,13 @@ impl crate::BuiltIn { // to this field will get replaced with references to `SPECIAL_CBUF_VAR` // in `Writer::write_expr`. Self::NumWorkGroups => "SV_GroupID", - - Self::NumSubgroups | Self::SubgroupId => todo!(), - Self::SubgroupInvocationId - | Self::SubgroupSize - | Self::BaseInstance - | Self::BaseVertex - | Self::WorkGroupSize => return Err(Error::Unimplemented(format!("builtin {self:?}"))), + Self::SubgroupSize + | Self::SubgroupInvocationId + | Self::NumSubgroups + | Self::SubgroupId => return Err(Error::Unimplemented(format!("builtin {self:?}"))), + Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => { + return Err(Error::Unimplemented(format!("builtin {self:?}"))) + } Self::PointSize | Self::ViewIndex | Self::PointCoord => { return Err(Error::Custom(format!("Unsupported builtin {self:?}"))) } diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index a04fef7a1b..1eab43a4c3 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -1130,7 +1130,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, " {name}(")?; let need_workgroup_variables_initialization = - self.need_workgroup_variables_initialization(func_ctx, module); + self.need_workgroup_variables_initialization(func, func_ctx, module); // Write function arguments for non entry point functions match func_ctx.ty { @@ -1166,7 +1166,21 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name,)?; } else { let stage = module.entry_points[ep_index as usize].stage; + let mut arg_num = 0; for (index, arg) in func.arguments.iter().enumerate() { + if matches!( + arg.binding, + Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) + | Some(crate::Binding::BuiltIn( + crate::BuiltIn::SubgroupInvocationId + )) + | Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) + | Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) + ) { + continue; + } + arg_num += 1; + if index != 0 { write!(self.out, ", ")?; } @@ -1186,7 +1200,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } if need_workgroup_variables_initialization { - if !func.arguments.is_empty() { + if arg_num > 0 { write!(self.out, ", ")?; } write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?; @@ -1217,6 +1231,53 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_workgroup_variables_initialization(func_ctx, module)?; } + if let back::FunctionType::EntryPoint(ep_index) = func_ctx.ty { + let ep = &module.entry_points[ep_index as usize]; + for (index, arg) in func.arguments.iter().enumerate() { + if let Some(crate::Binding::BuiltIn(builtin)) = arg.binding { + if matches!( + builtin, + crate::BuiltIn::SubgroupSize + | crate::BuiltIn::SubgroupInvocationId + | crate::BuiltIn::NumSubgroups + | crate::BuiltIn::SubgroupId + ) { + let level = back::Level(1); + write!(self.out, "{level}const ")?; + + self.write_type(module, arg.ty)?; + + let argument_name = + &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)]; + write!(self.out, " {argument_name} = ")?; + + match builtin { + crate::BuiltIn::SubgroupSize => { + writeln!(self.out, "WaveGetLaneCount();")? + } + crate::BuiltIn::SubgroupInvocationId => { + writeln!(self.out, "WaveGetLaneIndex();")? + } + crate::BuiltIn::NumSubgroups => writeln!( + self.out, + "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount();", + ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2] + )?, + crate::BuiltIn::SubgroupId => { + writeln!( + self.out, + "(__local_invocation_id.x * {}u + __local_invocation_id.y * {}u + __local_invocation_id.z) / WaveGetLaneCount();", + ep.workgroup_size[0] * ep.workgroup_size[1], + ep.workgroup_size[1], + )?; + } + _ => unreachable!(), + } + } + } + } + } + if let back::FunctionType::EntryPoint(index) = func_ctx.ty { self.write_ep_arguments_initialization(module, func, index)?; } @@ -1267,14 +1328,20 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { fn need_workgroup_variables_initialization( &mut self, + func: &crate::Function, func_ctx: &back::FunctionCtx, module: &Module, ) -> bool { - self.options.zero_initialize_workgroup_memory + func.arguments.iter().any(|arg| { + matches!( + arg.binding, + Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) + ) + }) || (self.options.zero_initialize_workgroup_memory && func_ctx.ty.is_compute_entry_point(module) && module.global_variables.iter().any(|(handle, var)| { !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup - }) + })) } fn write_workgroup_variables_initialization( @@ -2006,7 +2073,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Statement::RayQuery { .. } => unreachable!(), Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; - let name = format!("{}{}", back::BAKE_PREFIX, result.index()); write!(self.out, "const uint4 {name} = ")?; self.named_expressions.insert(result, name); @@ -2024,14 +2090,109 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { argument, result, } => { - unimplemented!(); // FIXME + write!(self.out, "{level}")?; + write!(self.out, "const ")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + match func_ctx.info[result].ty { + proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, + proc::TypeResolution::Value(ref value) => { + self.write_value_type(module, value)? + } + }; + write!(self.out, " {name} = ")?; + self.named_expressions.insert(result, name); + + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "WaveActiveAllTrue(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "WaveActiveAnyTrue(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "WaveActiveSum(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "WaveActiveProduct(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "WaveActiveMax(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "WaveActiveMin(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "WaveActiveBitAnd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "WaveActiveBitOr(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "WaveActiveBitXor(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "WavePrefixSum(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "WavePrefixProduct(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { + self.write_expr(module, argument, func_ctx)?; + write!(self.out, " + WavePrefixSum(")?; + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { + self.write_expr(module, argument, func_ctx)?; + write!(self.out, " * WavePrefixProduct(")?; + } + _ => unimplemented!(), + } + self.write_expr(module, argument, func_ctx)?; + writeln!(self.out, ");")?; } Statement::SubgroupGather { mode, argument, result, } => { - unimplemented!(); // FIXME + write!(self.out, "{level}")?; + write!(self.out, "const ")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + match func_ctx.info[result].ty { + proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, + proc::TypeResolution::Value(ref value) => { + self.write_value_type(module, value)? + } + }; + write!(self.out, " {name} = ")?; + self.named_expressions.insert(result, name); + + if matches!(mode, crate::GatherMode::BroadcastFirst) { + write!(self.out, "WaveReadLaneFirst(")?; + self.write_expr(module, argument, func_ctx)?; + } else { + write!(self.out, "WaveReadLaneAt(")?; + self.write_expr(module, argument, func_ctx)?; + write!(self.out, ", ")?; + match mode { + crate::GatherMode::BroadcastFirst => unreachable!(), + crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) => { + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleDown(index) => { + write!(self.out, "WaveGetLaneIndex() + ")?; + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleUp(index) => { + write!(self.out, "WaveGetLaneIndex() - ")?; + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleXor(index) => { + write!(self.out, "WaveGetLaneIndex() ^ ")?; + self.write_expr(module, index, func_ctx)?; + } + } + } + writeln!(self.out, ");")?; } } @@ -3251,7 +3412,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?; } if barrier.contains(crate::Barrier::SUB_GROUP) { - unimplemented!() // FIXME + // Does not exist in DirectX } Ok(()) } diff --git a/src/back/msl/mod.rs b/src/back/msl/mod.rs index 4e0d8489e4..eee825a83b 100644 --- a/src/back/msl/mod.rs +++ b/src/back/msl/mod.rs @@ -438,9 +438,10 @@ impl ResolvedBinding { Bi::WorkGroupSize => "dispatch_threads_per_threadgroup", Bi::NumWorkGroups => "threadgroups_per_grid", // subgroup - Bi::NumSubgroups | Bi::SubgroupId => todo!(), - Bi::SubgroupInvocationId => "simdgroup_index_in_threadgroup", - Bi::SubgroupSize => "simdgroups_per_threadgroup", + Bi::NumSubgroups => "simdgroups_per_threadgroup", + Bi::SubgroupId => "simdgroup_index_in_threadgroup", + Bi::SubgroupSize => "threads_per_simdgroup", + Bi::SubgroupInvocationId => "thread_index_in_simdgroup", Bi::CullDistance | Bi::ViewIndex => { return Err(Error::UnsupportedBuiltIn(built_in)) } diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 629b98c96d..ce57588240 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -3017,14 +3017,13 @@ impl Writer { let name = self.namer.call(""); self.start_baking_expression(result, &context.expression, &name)?; self.named_expressions.insert(result, name); - write!(self.out, "{NAMESPACE}::simd_ballot(;")?; - match predicate { - Some(predicate) => { - self.put_expression(predicate, &context.expression, true)? - } - None => write!(self.out, "true")?, + write!(self.out, "uint4((uint64_t){NAMESPACE}::simd_ballot(")?; + if let Some(predicate) = predicate { + self.put_expression(predicate, &context.expression, true)?; + } else { + write!(self.out, "true")?; } - writeln!(self.out, ");")?; + writeln!(self.out, "), 0, 0, 0);")?; } crate::Statement::SubgroupCollectiveOperation { op, @@ -3032,14 +3031,101 @@ impl Writer { argument, result, } => { - unimplemented!(); // FIXME + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "{NAMESPACE}::simd_all(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "{NAMESPACE}::simd_any(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "{NAMESPACE}::simd_sum(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "{NAMESPACE}::simd_product(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "{NAMESPACE}::simd_max(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "{NAMESPACE}::simd_min(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "{NAMESPACE}::simd_and(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "{NAMESPACE}::simd_or(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "{NAMESPACE}::simd_xor(")? + } + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Add, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?, + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Mul, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?, + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Add, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?, + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Mul, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?, + _ => unimplemented!(), + } + self.put_expression(argument, &context.expression, true)?; + writeln!(self.out, ");")?; } crate::Statement::SubgroupGather { mode, argument, result, } => { - unimplemented!(); // FIXME + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?; + } + crate::GatherMode::Broadcast(_) => { + write!(self.out, "{NAMESPACE}::simd_broadcast(")?; + } + crate::GatherMode::Shuffle(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle(")?; + } + crate::GatherMode::ShuffleDown(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?; + } + crate::GatherMode::ShuffleUp(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?; + } + crate::GatherMode::ShuffleXor(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?; + } + } + self.put_expression(argument, &context.expression, true)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + write!(self.out, ", ")?; + self.put_expression(index, &context.expression, true)?; + } + } + writeln!(self.out, ");")?; } } } @@ -4378,7 +4464,10 @@ impl Writer { )?; } if flags.contains(crate::Barrier::SUB_GROUP) { - unimplemented!(); // FIXME + writeln!( + self.out, + "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);", + )?; } Ok(()) } diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 9fba489f79..50883ce071 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -2340,30 +2340,11 @@ impl<'w> BlockContext<'w> { crate::Statement::RayQuery { query, ref fun } => { self.write_ray_query_function(query, fun, &mut block); } - crate::Statement::SubgroupBallot { result, predicate } => { - self.writer.require_any( - "GroupNonUniformBallot", - &[spirv::Capability::GroupNonUniformBallot], - )?; - let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: Some(crate::VectorSize::Quad), - kind: crate::ScalarKind::Uint, - width: 4, - pointer_space: None, - })); - let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); - let predicate = match predicate { - Some(predicate) => self.cached[predicate], - None => self.writer.get_constant_scalar(crate::Literal::Bool(true)), - }; - let id = self.gen_id(); - block.body.push(Instruction::group_non_uniform_ballot( - vec4_u32_type_id, - id, - exec_scope_id, - predicate, - )); - self.cached[result] = id; + crate::Statement::SubgroupBallot { + result, + ref predicate, + } => { + self.write_subgroup_ballot(predicate, result, &mut block)?; } crate::Statement::SubgroupCollectiveOperation { ref op, @@ -2378,7 +2359,7 @@ impl<'w> BlockContext<'w> { argument, result, } => { - self.write_subgroup_broadcast(mode, argument, result, &mut block)?; + self.write_subgroup_gather(mode, argument, result, &mut block)?; } } } diff --git a/src/back/spv/instructions.rs b/src/back/spv/instructions.rs index 8a528065ef..5f7c6b34fd 100644 --- a/src/back/spv/instructions.rs +++ b/src/back/spv/instructions.rs @@ -1054,33 +1054,34 @@ impl super::Instruction { instruction } - pub(super) fn group_non_uniform_broadcast( + pub(super) fn group_non_uniform_broadcast_first( result_type_id: Word, id: Word, exec_scope_id: Word, value: Word, - index: Word, ) -> Self { - let mut instruction = Self::new(Op::GroupNonUniformBroadcast); + let mut instruction = Self::new(Op::GroupNonUniformBroadcastFirst); instruction.set_type(result_type_id); instruction.set_result(id); instruction.add_operand(exec_scope_id); instruction.add_operand(value); - instruction.add_operand(index); instruction } - pub(super) fn group_non_uniform_broadcast_first( + pub(super) fn group_non_uniform_gather( + op: Op, result_type_id: Word, id: Word, exec_scope_id: Word, value: Word, + index: Word, ) -> Self { - let mut instruction = Self::new(Op::GroupNonUniformBroadcastFirst); + let mut instruction = Self::new(op); instruction.set_type(result_type_id); instruction.set_result(id); instruction.add_operand(exec_scope_id); instruction.add_operand(value); + instruction.add_operand(index); instruction } @@ -1092,10 +1093,6 @@ impl super::Instruction { group_op: Option, value: Word, ) -> Self { - println!( - "{:?}", - (op, result_type_id, id, exec_scope_id, group_op, value) - ); let mut instruction = Self::new(op); instruction.set_type(result_type_id); instruction.set_result(id); diff --git a/src/back/spv/subgroup.rs b/src/back/spv/subgroup.rs index 7206ec9312..79db752a6c 100644 --- a/src/back/spv/subgroup.rs +++ b/src/back/spv/subgroup.rs @@ -1,7 +1,43 @@ use super::{Block, BlockContext, Error, Instruction}; -use crate::{arena::Handle, TypeInner}; +use crate::{ + arena::Handle, + back::spv::{LocalType, LookupType}, + TypeInner, +}; impl<'w> BlockContext<'w> { + pub(super) fn write_subgroup_ballot( + &mut self, + predicate: &Option>, + result: Handle, + block: &mut Block, + ) -> Result<(), Error> { + self.writer.require_any( + "GroupNonUniformBallot", + &[spirv::Capability::GroupNonUniformBallot], + )?; + let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Quad), + kind: crate::ScalarKind::Uint, + width: 4, + pointer_space: None, + })); + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + let predicate = if let Some(predicate) = *predicate { + self.cached[predicate] + } else { + self.writer.get_constant_scalar(crate::Literal::Bool(true)) + }; + let id = self.gen_id(); + block.body.push(Instruction::group_non_uniform_ballot( + vec4_u32_type_id, + id, + exec_scope_id, + predicate, + )); + self.cached[result] = id; + Ok(()) + } pub(super) fn write_subgroup_operation( &mut self, op: &crate::SubgroupOperation, @@ -11,7 +47,7 @@ impl<'w> BlockContext<'w> { block: &mut Block, ) -> Result<(), Error> { use crate::SubgroupOperation as sg; - match op { + match *op { sg::All | sg::Any => { self.writer.require_any( "GroupNonUniformVote", @@ -21,11 +57,7 @@ impl<'w> BlockContext<'w> { _ => { self.writer.require_any( "GroupNonUniformArithmetic", - &[ - spirv::Capability::GroupNonUniformArithmetic, - spirv::Capability::GroupNonUniformClustered, - spirv::Capability::GroupNonUniformPartitionedNV, - ], + &[spirv::Capability::GroupNonUniformArithmetic], )?; } } @@ -35,14 +67,14 @@ impl<'w> BlockContext<'w> { let result_type_id = self.get_expression_type_id(result_ty); let result_ty_inner = result_ty.inner_with(&self.ir_module.types); - let (is_scalar, kind) = match result_ty_inner { + let (is_scalar, kind) = match *result_ty_inner { TypeInner::Scalar { kind, .. } => (true, kind), TypeInner::Vector { kind, .. } => (false, kind), _ => unimplemented!(), }; use crate::ScalarKind as sk; - let spirv_op = match (kind, op) { + let spirv_op = match (kind, *op) { (sk::Bool, sg::All) if is_scalar => spirv::Op::GroupNonUniformAll, (sk::Bool, sg::Any) if is_scalar => spirv::Op::GroupNonUniformAny, (_, sg::All | sg::Any) => unimplemented!(), @@ -71,9 +103,9 @@ impl<'w> BlockContext<'w> { let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); use crate::CollectiveOperation as c; - let group_op = match op { + let group_op = match *op { sg::All | sg::Any => None, - _ => Some(match collective_op { + _ => Some(match *collective_op { c::Reduce => spirv::GroupOperation::Reduce, c::InclusiveScan => spirv::GroupOperation::InclusiveScan, c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan, @@ -92,7 +124,7 @@ impl<'w> BlockContext<'w> { self.cached[result] = id; Ok(()) } - pub(super) fn write_subgroup_broadcast( + pub(super) fn write_subgroup_gather( &mut self, mode: &crate::GatherMode, argument: Handle, @@ -103,6 +135,26 @@ impl<'w> BlockContext<'w> { "GroupNonUniformBallot", &[spirv::Capability::GroupNonUniformBallot], )?; + match *mode { + crate::GatherMode::BroadcastFirst | crate::GatherMode::Broadcast(_) => { + self.writer.require_any( + "GroupNonUniformBallot", + &[spirv::Capability::GroupNonUniformBallot], + )?; + } + crate::GatherMode::Shuffle(_) | crate::GatherMode::ShuffleXor(_) => { + self.writer.require_any( + "GroupNonUniformShuffle", + &[spirv::Capability::GroupNonUniformShuffle], + )?; + } + crate::GatherMode::ShuffleDown(_) | crate::GatherMode::ShuffleUp(_) => { + self.writer.require_any( + "GroupNonUniformShuffleRelative", + &[spirv::Capability::GroupNonUniformShuffleRelative], + )?; + } + } let id = self.gen_id(); let result_ty = &self.fun_info[result].ty; @@ -111,17 +163,7 @@ impl<'w> BlockContext<'w> { let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); let arg_id = self.cached[argument]; - match mode { - crate::GatherMode::Broadcast(index) => { - let index_id = self.cached[*index]; - block.body.push(Instruction::group_non_uniform_broadcast( - result_type_id, - id, - exec_scope_id, - arg_id, - index_id, - )); - } + match *mode { crate::GatherMode::BroadcastFirst => { block .body @@ -132,10 +174,29 @@ impl<'w> BlockContext<'w> { arg_id, )); } - crate::GatherMode::Shuffle(index) + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) - | crate::GatherMode::ShuffleXor(index) => todo!(), + | crate::GatherMode::ShuffleXor(index) => { + let index_id = self.cached[index]; + let op = match *mode { + crate::GatherMode::BroadcastFirst => unreachable!(), + crate::GatherMode::Broadcast(_) => spirv::Op::GroupNonUniformBroadcast, + crate::GatherMode::Shuffle(_) => spirv::Op::GroupNonUniformShuffle, + crate::GatherMode::ShuffleDown(_) => spirv::Op::GroupNonUniformShuffleDown, + crate::GatherMode::ShuffleUp(_) => spirv::Op::GroupNonUniformShuffleUp, + crate::GatherMode::ShuffleXor(_) => spirv::Op::GroupNonUniformShuffleXor, + }; + block.body.push(Instruction::group_non_uniform_gather( + op, + result_type_id, + id, + exec_scope_id, + arg_id, + index_id, + )); + } } self.cached[result] = id; Ok(()) diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index 6a55288308..da0fdf766f 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -1594,17 +1594,23 @@ impl Writer { Bi::WorkGroupSize => BuiltIn::WorkgroupSize, Bi::NumWorkGroups => BuiltIn::NumWorkgroups, // Subgroup - Bi::NumSubgroups | Bi::SubgroupId => todo!(), - Bi::SubgroupInvocationId => { + Bi::NumSubgroups => { self.require_any( - "`subgroup_invocation_id` built-in", + "`num_subgroups` built-in", &[spirv::Capability::GroupNonUniform], )?; - BuiltIn::SubgroupLocalInvocationId + BuiltIn::NumSubgroups + } + Bi::SubgroupId => { + self.require_any( + "`subgroup_id` built-in", + &[spirv::Capability::GroupNonUniform], + )?; + BuiltIn::SubgroupId } Bi::SubgroupSize => { self.require_any( - "`subgroup_invocation_id` built-in", + "`subgroup_size` built-in", &[ spirv::Capability::GroupNonUniform, spirv::Capability::SubgroupBallotKHR, @@ -1612,6 +1618,16 @@ impl Writer { )?; BuiltIn::SubgroupSize } + Bi::SubgroupInvocationId => { + self.require_any( + "`subgroup_invocation_id` built-in", + &[ + spirv::Capability::GroupNonUniform, + spirv::Capability::SubgroupBallotKHR, + ], + )?; + BuiltIn::SubgroupLocalInvocationId + } }; self.decorate(id, Decoration::BuiltIn, &[built_in as u32]); diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 8fc9881f10..0bc2dfceb0 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -919,6 +919,10 @@ impl Writer { if barrier.contains(crate::Barrier::WORK_GROUP) { writeln!(self.out, "{level}workgroupBarrier();")?; } + + if barrier.contains(crate::Barrier::SUB_GROUP) { + writeln!(self.out, "{level}subgroupBarrier();")?; + } } Statement::RayQuery { .. } => unreachable!(), Statement::SubgroupBallot { result, predicate } => { @@ -939,14 +943,99 @@ impl Writer { argument, result, } => { - unimplemented!() // FIXME + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "subgroupAll(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "subgroupAny(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupAdd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupMul(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "subgroupMax(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "subgroupMin(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "subgroupAnd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "subgroupOr(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "subgroupXor(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupPrefixExclusiveAdd(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupPrefixExclusiveMul(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupPrefixInclusiveAdd(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupPrefixInclusiveMul(")? + } + _ => unimplemented!(), + } + self.write_expr(module, argument, func_ctx)?; + writeln!(self.out, ");")?; } Statement::SubgroupGather { mode, argument, result, } => { - unimplemented!() // FIXME + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "subgroupBroadcastFirst(")?; + } + crate::GatherMode::Broadcast(_) => { + write!(self.out, "subgroupBroadcast(")?; + } + crate::GatherMode::Shuffle(_) => { + write!(self.out, "subgroupShuffle(")?; + } + crate::GatherMode::ShuffleDown(_) => { + write!(self.out, "subgroupShuffleDown(")?; + } + crate::GatherMode::ShuffleUp(_) => { + write!(self.out, "subgroupShuffleUp(")?; + } + crate::GatherMode::ShuffleXor(_) => { + write!(self.out, "subgroupShuffleXor(")?; + } + } + self.write_expr(module, argument, func_ctx)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + write!(self.out, ", ")?; + self.write_expr(module, index, func_ctx)?; + } + } + writeln!(self.out, ");")?; } } @@ -1789,9 +1878,10 @@ fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> { Bi::SampleMask => "sample_mask", Bi::PrimitiveIndex => "primitive_index", Bi::ViewIndex => "view_index", - Bi::NumSubgroups | Bi::SubgroupId => todo!(), - Bi::SubgroupInvocationId => "subgroup_invocation_id", + Bi::NumSubgroups => "num_subgroups", + Bi::SubgroupId => "subgroup_id", Bi::SubgroupSize => "subgroup_size", + Bi::SubgroupInvocationId => "subgroup_invocation_id", Bi::BaseInstance | Bi::BaseVertex | Bi::ClipDistance diff --git a/src/front/spv/error.rs b/src/front/spv/error.rs index 2f9bf2d1bc..8508ede042 100644 --- a/src/front/spv/error.rs +++ b/src/front/spv/error.rs @@ -54,6 +54,8 @@ pub enum Error { UnknownBinaryOperator(spirv::Op), #[error("unknown relational function {0:?}")] UnknownRelationalFunction(spirv::Op), + #[error("unsupported group opeation %{0}")] + UnsupportedGroupOperation(spirv::Word), #[error("invalid parameter {0:?}")] InvalidParameter(spirv::Op), #[error("invalid operand count {1} for {0:?}")] @@ -116,8 +118,8 @@ pub enum Error { FunctionCallCycle(spirv::Word), #[error("invalid array size {0:?}")] InvalidArraySize(Handle), - #[error("invalid barrier scope %{0}")] - InvalidBarrierScope(spirv::Word), + #[error("invalid execution scope %{0}")] + InvalidExecutionScope(spirv::Word), #[error("invalid barrier memory semantics %{0}")] InvalidBarrierMemorySemantics(spirv::Word), #[error( diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index c78407ef99..e7c082c6a5 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -3650,7 +3650,7 @@ impl> Frontend { let semantics_const = self.lookup_constant.lookup(semantics_id)?; let exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) - .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; + .ok_or(Error::InvalidExecutionScope(exec_scope_id))?; let semantics = resolve_constant(ctx.gctx(), semantics_const.handle) .ok_or(Error::InvalidBarrierMemorySemantics(semantics_id))?; @@ -3691,6 +3691,209 @@ impl> Frontend { }, ); } + Op::GroupNonUniformBallot => { + inst.expect(4)?; + let _result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let predicate_id = self.next()?; + + let result_lookup = self.lookup_expression.lookup(result_id)?; + let result_handle = get_expr_handle!(result_id, result_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidExecutionScope(exec_scope_id))?; + + let predicate = if self + .lookup_constant + .lookup(predicate_id) + .ok() + .filter(|predicate_const| { + matches!( + ctx.gctx().const_expressions + [ctx.gctx().constants[predicate_const.handle].init], + crate::Expression::Literal(crate::Literal::Bool(true)) + ) + }) + .is_some() + { + None + } else { + let predicate_lookup = self.lookup_expression.lookup(predicate_id)?; + let predicate_handle = get_expr_handle!(predicate_id, predicate_lookup); + Some(predicate_handle) + }; + + block.push( + crate::Statement::SubgroupBallot { + result: result_handle, + predicate, + }, + span, + ); + } + spirv::Op::GroupNonUniformAll + | spirv::Op::GroupNonUniformAny + | spirv::Op::GroupNonUniformIAdd + | spirv::Op::GroupNonUniformFAdd + | spirv::Op::GroupNonUniformIMul + | spirv::Op::GroupNonUniformFMul + | spirv::Op::GroupNonUniformSMax + | spirv::Op::GroupNonUniformUMax + | spirv::Op::GroupNonUniformFMax + | spirv::Op::GroupNonUniformSMin + | spirv::Op::GroupNonUniformUMin + | spirv::Op::GroupNonUniformFMin + | spirv::Op::GroupNonUniformBitwiseAnd + | spirv::Op::GroupNonUniformBitwiseOr + | spirv::Op::GroupNonUniformBitwiseXor + | spirv::Op::GroupNonUniformLogicalAnd + | spirv::Op::GroupNonUniformLogicalOr + | spirv::Op::GroupNonUniformLogicalXor => { + inst.expect( + if matches!( + inst.op, + spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny + ) { + 4 + } else { + 5 + }, + )?; + let _result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let collective_op_id = match inst.op { + spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny => { + crate::CollectiveOperation::Reduce + } + _ => { + let group_op_id = self.next()?; + match spirv::GroupOperation::from_u32(group_op_id) { + Some(spirv::GroupOperation::Reduce) => { + crate::CollectiveOperation::Reduce + } + Some(spirv::GroupOperation::InclusiveScan) => { + crate::CollectiveOperation::InclusiveScan + } + Some(spirv::GroupOperation::ExclusiveScan) => { + crate::CollectiveOperation::ExclusiveScan + } + _ => return Err(Error::UnsupportedGroupOperation(group_op_id)), + } + } + }; + let argument_id = self.next()?; + + let result_lookup = self.lookup_expression.lookup(result_id)?; + let result_handle = get_expr_handle!(result_id, result_lookup); + let argument_lookup = self.lookup_expression.lookup(argument_id)?; + let argument_handle = get_expr_handle!(argument_id, argument_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidExecutionScope(exec_scope_id))?; + + let op_id = match inst.op { + spirv::Op::GroupNonUniformAll => crate::SubgroupOperation::All, + spirv::Op::GroupNonUniformAny => crate::SubgroupOperation::Any, + spirv::Op::GroupNonUniformIAdd | spirv::Op::GroupNonUniformFAdd => { + crate::SubgroupOperation::Add + } + spirv::Op::GroupNonUniformIMul | spirv::Op::GroupNonUniformFMul => { + crate::SubgroupOperation::Mul + } + spirv::Op::GroupNonUniformSMax + | spirv::Op::GroupNonUniformUMax + | spirv::Op::GroupNonUniformFMax => crate::SubgroupOperation::Max, + spirv::Op::GroupNonUniformSMin + | spirv::Op::GroupNonUniformUMin + | spirv::Op::GroupNonUniformFMin => crate::SubgroupOperation::Min, + spirv::Op::GroupNonUniformBitwiseAnd + | spirv::Op::GroupNonUniformLogicalAnd => crate::SubgroupOperation::And, + spirv::Op::GroupNonUniformBitwiseOr + | spirv::Op::GroupNonUniformLogicalOr => crate::SubgroupOperation::Or, + spirv::Op::GroupNonUniformBitwiseXor + | spirv::Op::GroupNonUniformLogicalXor => crate::SubgroupOperation::Xor, + _ => unreachable!(), + }; + + block.push( + crate::Statement::SubgroupCollectiveOperation { + result: result_handle, + op: op_id, + collective_op: collective_op_id, + argument: argument_handle, + }, + span, + ); + } + Op::GroupNonUniformBroadcastFirst + | Op::GroupNonUniformBroadcast + | Op::GroupNonUniformShuffle + | Op::GroupNonUniformShuffleDown + | Op::GroupNonUniformShuffleUp + | Op::GroupNonUniformShuffleXor => { + inst.expect( + if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) { + 4 + } else { + 5 + }, + )?; + let _result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let argument_id = self.next()?; + + let result_lookup = self.lookup_expression.lookup(result_id)?; + let result_handle = get_expr_handle!(result_id, result_lookup); + let argument_lookup = self.lookup_expression.lookup(argument_id)?; + let argument_handle = get_expr_handle!(argument_id, argument_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidExecutionScope(exec_scope_id))?; + + let mode = if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) { + crate::GatherMode::BroadcastFirst + } else { + let index_id = self.next()?; + let index_lookup = self.lookup_expression.lookup(index_id)?; + let index_handle = get_expr_handle!(index_id, index_lookup); + match inst.op { + spirv::Op::GroupNonUniformBroadcast => { + crate::GatherMode::Broadcast(index_handle) + } + spirv::Op::GroupNonUniformShuffle => { + crate::GatherMode::Shuffle(index_handle) + } + spirv::Op::GroupNonUniformShuffleDown => { + crate::GatherMode::ShuffleDown(index_handle) + } + spirv::Op::GroupNonUniformShuffleUp => { + crate::GatherMode::ShuffleUp(index_handle) + } + spirv::Op::GroupNonUniformShuffleXor => { + crate::GatherMode::ShuffleXor(index_handle) + } + _ => unreachable!(), + } + }; + + block.push( + crate::Statement::SubgroupGather { + result: result_handle, + mode, + argument: argument_handle, + }, + span, + ); + } _ => return Err(Error::UnsupportedInstruction(self.state, inst.op)), } }; @@ -3811,7 +4014,10 @@ impl> Frontend { | S::Store { .. } | S::ImageStore { .. } | S::Atomic { .. } - | S::RayQuery { .. } => {} + | S::RayQuery { .. } + | S::SubgroupBallot { .. } + | S::SubgroupCollectiveOperation { .. } + | S::SubgroupGather { .. } => {} S::Call { function: ref mut callee, ref arguments, @@ -3842,9 +4048,6 @@ impl> Frontend { } } S::WorkGroupUniformLoad { .. } => unreachable!(), - S::SubgroupBallot { .. } => unreachable!(), // FIXME?? - S::SubgroupCollectiveOperation { .. } => unreachable!(), - S::SubgroupGather { .. } => unreachable!(), } i += 1; } diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index 127008d437..d3e47f6959 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -1918,7 +1918,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } else if let Some(fun) = Texture::map(function.name) { self.texture_sample_helper(fun, arguments, span, ctx.reborrow())? } else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) { - return Ok(Some(self.subgroup_helper(span, op, cop, arguments, ctx)?)); + return Ok(Some( + self.subgroup_operation_helper(span, op, cop, arguments, ctx)?, + )); + } else if let Some(mode) = conv::map_subgroup_gather(function.name) { + return Ok(Some( + self.subgroup_gather_helper(span, mode, arguments, ctx)?, + )); } else { match function.name { "select" => { @@ -2316,53 +2322,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .push(crate::Statement::SubgroupBallot { result, predicate }, span); return Ok(Some(result)); } - "subgroupBroadcast" => { - let mut args = ctx.prepare_args(arguments, 2, span); - - let index = self.expression(args.next()?, ctx.reborrow())?; - let argument = self.expression(args.next()?, ctx.reborrow())?; - args.finish()?; - - let ty = ctx.register_type(argument)?; - - let result = ctx.interrupt_emitter( - crate::Expression::SubgroupOperationResult { ty }, - span, - )?; - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block.push( - crate::Statement::SubgroupGather { - mode: crate::GatherMode::Broadcast(index), - argument, - result, - }, - span, - ); - return Ok(Some(result)); - } - "subgroupBroadcastFirst" => { - let mut args = ctx.prepare_args(arguments, 1, span); - - let argument = self.expression(args.next()?, ctx.reborrow())?; - args.finish()?; - - let ty = ctx.register_type(argument)?; - - let result = ctx.interrupt_emitter( - crate::Expression::SubgroupOperationResult { ty }, - span, - )?; - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block.push( - crate::Statement::SubgroupGather { - mode: crate::GatherMode::BroadcastFirst, - argument, - result, - }, - span, - ); - return Ok(Some(result)); - } _ => return Err(Error::UnknownIdent(function.span, function.name)), } }; @@ -2554,7 +2513,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { depth_ref, }) } - fn subgroup_helper( + + fn subgroup_operation_helper( &mut self, span: Span, op: crate::SubgroupOperation, @@ -2584,6 +2544,46 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Ok(result) } + fn subgroup_gather_helper( + &mut self, + span: Span, + mode: crate::GatherMode, + arguments: &[Handle>], + mut ctx: ExpressionContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + let mut args = ctx.prepare_args(arguments, 2, span); + + let argument = self.expression(args.next()?, ctx.reborrow())?; + let index = if let crate::GatherMode::BroadcastFirst = mode { + Handle::new(NonZeroU32::new(u32::MAX).unwrap()) + } else { + self.expression(args.next()?, ctx.reborrow())? + }; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = + ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupGather { + mode: match mode { + crate::GatherMode::BroadcastFirst => crate::GatherMode::BroadcastFirst, + crate::GatherMode::Broadcast(_) => crate::GatherMode::Broadcast(index), + crate::GatherMode::Shuffle(_) => crate::GatherMode::Shuffle(index), + crate::GatherMode::ShuffleDown(_) => crate::GatherMode::ShuffleDown(index), + crate::GatherMode::ShuffleUp(_) => crate::GatherMode::ShuffleUp(index), + crate::GatherMode::ShuffleXor(_) => crate::GatherMode::ShuffleXor(index), + }, + argument, + result, + }, + span, + ); + Ok(result) + } + fn r#struct( &mut self, s: &ast::Struct<'source>, diff --git a/src/front/wgsl/parse/conv.rs b/src/front/wgsl/parse/conv.rs index 160213e4a3..c53f4df753 100644 --- a/src/front/wgsl/parse/conv.rs +++ b/src/front/wgsl/parse/conv.rs @@ -35,8 +35,10 @@ pub fn map_built_in(word: &str, span: Span) -> Result> "workgroup_id" => crate::BuiltIn::WorkGroupId, "num_workgroups" => crate::BuiltIn::NumWorkGroups, // subgroup - "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId, + "num_subgroups" => crate::BuiltIn::NumSubgroups, + "subgroup_id" => crate::BuiltIn::SubgroupId, "subgroup_size" => crate::BuiltIn::SubgroupSize, + "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId, _ => return Err(Error::UnknownBuiltin(span)), }) } @@ -254,8 +256,25 @@ pub fn map_subgroup_operation( "subgroupAnd" => (sg::And, co::Reduce), "subgroupOr" => (sg::Or, co::Reduce), "subgroupXor" => (sg::Xor, co::Reduce), - "subgroupPrefixAdd" => (sg::Add, co::InclusiveScan), - "subgroupPrefixMul" => (sg::Mul, co::InclusiveScan), + "subgroupPrefixExclusiveAdd" => (sg::Add, co::ExclusiveScan), + "subgroupPrefixExclusiveMul" => (sg::Mul, co::ExclusiveScan), + "subgroupPrefixInclusiveAdd" => (sg::Add, co::InclusiveScan), + "subgroupPrefixInclusiveMul" => (sg::Mul, co::InclusiveScan), + _ => return None, + }) +} + +pub fn map_subgroup_gather(word: &str) -> Option { + use crate::GatherMode as gm; + use crate::Handle; + use std::num::NonZeroU32; + Some(match word { + "subgroupBroadcastFirst" => gm::BroadcastFirst, + "subgroupBroadcast" => gm::Broadcast(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), + "subgroupShuffle" => gm::Shuffle(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), + "subgroupShuffleDown" => gm::ShuffleDown(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), + "subgroupShuffleUp" => gm::ShuffleUp(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), + "subgroupShuffleXor" => gm::ShuffleXor(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), _ => return None, }) } From 0db0b37a12aaf6968186c357db3ce4a2c42ad999 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Sat, 14 Oct 2023 20:23:53 +0200 Subject: [PATCH 16/17] Adjusts metal backend test_stack_size(). --- src/back/msl/writer.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index ce57588240..1b0791b573 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -4739,8 +4739,8 @@ fn test_stack_size() { } let stack_size = addresses_end - addresses_start; // check the size (in debug only) - // last observed macOS value: 19152 (CI) - if !(9000..=20000).contains(&stack_size) { + // last observed macOS value: 22256 (CI) + if !(15000..=25000).contains(&stack_size) { panic!("`put_block` stack size {stack_size} has changed!"); } } From 1fe737502cba971267a04b3010ae995e06038c76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Fri, 20 Oct 2023 23:28:20 +0200 Subject: [PATCH 17/17] Adds test and snapshots. --- tests/in/subgroup-operations.param.ron | 26 +++++++ tests/in/subgroup-operations.wgsl | 32 ++++++++ .../subgroup-operations.main.Compute.glsl | 41 +++++++++++ tests/out/hlsl/subgroup-operations.hlsl | 32 ++++++++ tests/out/hlsl/subgroup-operations.ron | 12 +++ tests/out/msl/subgroup-operations.msl | 38 ++++++++++ tests/out/spv/subgroup-operations.spvasm | 73 +++++++++++++++++++ tests/out/wgsl/subgroup-operations.wgsl | 26 +++++++ tests/snapshots.rs | 4 + 9 files changed, 284 insertions(+) create mode 100644 tests/in/subgroup-operations.param.ron create mode 100644 tests/in/subgroup-operations.wgsl create mode 100644 tests/out/glsl/subgroup-operations.main.Compute.glsl create mode 100644 tests/out/hlsl/subgroup-operations.hlsl create mode 100644 tests/out/hlsl/subgroup-operations.ron create mode 100644 tests/out/msl/subgroup-operations.msl create mode 100644 tests/out/spv/subgroup-operations.spvasm create mode 100644 tests/out/wgsl/subgroup-operations.wgsl diff --git a/tests/in/subgroup-operations.param.ron b/tests/in/subgroup-operations.param.ron new file mode 100644 index 0000000000..fc444a3efe --- /dev/null +++ b/tests/in/subgroup-operations.param.ron @@ -0,0 +1,26 @@ +( + spv: ( + version: (1, 3), + ), + msl: ( + lang_version: (2, 4), + per_entry_point_map: {}, + inline_samplers: [], + spirv_cross_compatibility: false, + fake_missing_bindings: false, + zero_initialize_workgroup_memory: true, + ), + glsl: ( + version: Desktop(430), + writer_flags: (""), + binding_map: { }, + zero_initialize_workgroup_memory: true, + ), + hlsl: ( + shader_model: V6_0, + binding_map: {}, + fake_missing_bindings: true, + special_constants_binding: None, + zero_initialize_workgroup_memory: true, + ), +) diff --git a/tests/in/subgroup-operations.wgsl b/tests/in/subgroup-operations.wgsl new file mode 100644 index 0000000000..f30b60be47 --- /dev/null +++ b/tests/in/subgroup-operations.wgsl @@ -0,0 +1,32 @@ +@compute @workgroup_size(1) +fn main( + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(subgroup_invocation_id) subgroup_invocation_id: u32, +) { + subgroupBarrier(); + + subgroupBallot((subgroup_invocation_id & 1u) == 1u); + + subgroupAll(subgroup_invocation_id != 0u); + subgroupAny(subgroup_invocation_id == 0u); + subgroupAdd(subgroup_invocation_id); + subgroupMul(subgroup_invocation_id); + subgroupMin(subgroup_invocation_id); + subgroupMax(subgroup_invocation_id); + subgroupAnd(subgroup_invocation_id); + subgroupOr(subgroup_invocation_id); + subgroupXor(subgroup_invocation_id); + subgroupPrefixExclusiveAdd(subgroup_invocation_id); + subgroupPrefixExclusiveMul(subgroup_invocation_id); + subgroupPrefixInclusiveAdd(subgroup_invocation_id); + subgroupPrefixInclusiveMul(subgroup_invocation_id); + + subgroupBroadcastFirst(subgroup_invocation_id); + subgroupBroadcast(subgroup_invocation_id, 4u); + subgroupShuffle(subgroup_invocation_id, subgroup_size - 1u - subgroup_invocation_id); + subgroupShuffleDown(subgroup_invocation_id, 1u); + subgroupShuffleUp(subgroup_invocation_id, 1u); + subgroupShuffleXor(subgroup_invocation_id, subgroup_size - 1u); +} diff --git a/tests/out/glsl/subgroup-operations.main.Compute.glsl b/tests/out/glsl/subgroup-operations.main.Compute.glsl new file mode 100644 index 0000000000..a37cf8e247 --- /dev/null +++ b/tests/out/glsl/subgroup-operations.main.Compute.glsl @@ -0,0 +1,41 @@ +#version 430 core +#extension GL_ARB_compute_shader : require +#extension GL_KHR_shader_subgroup_basic : require +#extension GL_KHR_shader_subgroup_vote : require +#extension GL_KHR_shader_subgroup_arithmetic : require +#extension GL_KHR_shader_subgroup_ballot : require +#extension GL_KHR_shader_subgroup_shuffle : require +#extension GL_KHR_shader_subgroup_shuffle_relative : require +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + + +void main() { + uint num_subgroups = gl_NumSubgroups; + uint subgroup_id = gl_SubgroupID; + uint subgroup_size = gl_SubgroupSize; + uint subgroup_invocation_id = gl_SubgroupInvocationID; + subgroupMemoryBarrier(); + barrier(); + uvec4 _e8 = subgroupBallot(((subgroup_invocation_id & 1u) == 1u)); + bool _e11 = subgroupAll((subgroup_invocation_id != 0u)); + bool _e14 = subgroupAny((subgroup_invocation_id == 0u)); + uint _e15 = subgroupAdd(subgroup_invocation_id); + uint _e16 = subgroupMul(subgroup_invocation_id); + uint _e17 = subgroupMin(subgroup_invocation_id); + uint _e18 = subgroupMax(subgroup_invocation_id); + uint _e19 = subgroupAnd(subgroup_invocation_id); + uint _e20 = subgroupOr(subgroup_invocation_id); + uint _e21 = subgroupXor(subgroup_invocation_id); + uint _e22 = subgroupExclusiveAdd(subgroup_invocation_id); + uint _e23 = subgroupExclusiveMul(subgroup_invocation_id); + uint _e24 = subgroupInclusiveAdd(subgroup_invocation_id); + uint _e25 = subgroupInclusiveMul(subgroup_invocation_id); + uint _e26 = subgroupBroadcastFirst(subgroup_invocation_id); + uint _e28 = subgroupBroadcast(subgroup_invocation_id, 4u); + uint _e32 = subgroupShuffle(subgroup_invocation_id, ((subgroup_size - 1u) - subgroup_invocation_id)); + uint _e34 = subgroupShuffleDown(subgroup_invocation_id, 1u); + uint _e36 = subgroupShuffleUp(subgroup_invocation_id, 1u); + uint _e39 = subgroupShuffleXor(subgroup_invocation_id, (subgroup_size - 1u)); + return; +} + diff --git a/tests/out/hlsl/subgroup-operations.hlsl b/tests/out/hlsl/subgroup-operations.hlsl new file mode 100644 index 0000000000..baa37826e0 --- /dev/null +++ b/tests/out/hlsl/subgroup-operations.hlsl @@ -0,0 +1,32 @@ +[numthreads(1, 1, 1)] +void main(uint3 __local_invocation_id : SV_GroupThreadID) +{ + if (all(__local_invocation_id == uint3(0u, 0u, 0u))) { + } + GroupMemoryBarrierWithGroupSync(); + const uint num_subgroups = (1u + WaveGetLaneCount() - 1u) / WaveGetLaneCount(); + const uint subgroup_id = (__local_invocation_id.x * 1u + __local_invocation_id.y * 1u + __local_invocation_id.z) / WaveGetLaneCount(); + const uint subgroup_size = WaveGetLaneCount(); + const uint subgroup_invocation_id = WaveGetLaneIndex(); + const uint4 _e8 = WaveActiveBallot(((subgroup_invocation_id & 1u) == 1u)); + const bool _e11 = WaveActiveAllTrue((subgroup_invocation_id != 0u)); + const bool _e14 = WaveActiveAnyTrue((subgroup_invocation_id == 0u)); + const uint _e15 = WaveActiveSum(subgroup_invocation_id); + const uint _e16 = WaveActiveProduct(subgroup_invocation_id); + const uint _e17 = WaveActiveMin(subgroup_invocation_id); + const uint _e18 = WaveActiveMax(subgroup_invocation_id); + const uint _e19 = WaveActiveBitAnd(subgroup_invocation_id); + const uint _e20 = WaveActiveBitOr(subgroup_invocation_id); + const uint _e21 = WaveActiveBitXor(subgroup_invocation_id); + const uint _e22 = WavePrefixSum(subgroup_invocation_id); + const uint _e23 = WavePrefixProduct(subgroup_invocation_id); + const uint _e24 = subgroup_invocation_id + WavePrefixSum(subgroup_invocation_id); + const uint _e25 = subgroup_invocation_id * WavePrefixProduct(subgroup_invocation_id); + const uint _e26 = WaveReadLaneFirst(subgroup_invocation_id); + const uint _e28 = WaveReadLaneAt(subgroup_invocation_id, 4u); + const uint _e32 = WaveReadLaneAt(subgroup_invocation_id, ((subgroup_size - 1u) - subgroup_invocation_id)); + const uint _e34 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() + 1u); + const uint _e36 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() - 1u); + const uint _e39 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() ^ (subgroup_size - 1u)); + return; +} diff --git a/tests/out/hlsl/subgroup-operations.ron b/tests/out/hlsl/subgroup-operations.ron new file mode 100644 index 0000000000..b973fe3da1 --- /dev/null +++ b/tests/out/hlsl/subgroup-operations.ron @@ -0,0 +1,12 @@ +( + vertex:[ + ], + fragment:[ + ], + compute:[ + ( + entry_point:"main", + target_profile:"cs_6_0", + ), + ], +) diff --git a/tests/out/msl/subgroup-operations.msl b/tests/out/msl/subgroup-operations.msl new file mode 100644 index 0000000000..576fa3b84e --- /dev/null +++ b/tests/out/msl/subgroup-operations.msl @@ -0,0 +1,38 @@ +// language: metal2.4 +#include +#include + +using metal::uint; + + +struct main_Input { +}; +kernel void main_( + uint num_subgroups [[simdgroups_per_threadgroup]] +, uint subgroup_id [[simdgroup_index_in_threadgroup]] +, uint subgroup_size [[threads_per_simdgroup]] +, uint subgroup_invocation_id [[thread_index_in_simdgroup]] +) { + metal::simdgroup_barrier(metal::mem_flags::mem_threadgroup); + metal::uint4 unnamed = uint4((uint64_t)metal::simd_ballot((subgroup_invocation_id & 1u) == 1u), 0, 0, 0); + bool unnamed_1 = metal::simd_all(subgroup_invocation_id != 0u); + bool unnamed_2 = metal::simd_any(subgroup_invocation_id == 0u); + uint unnamed_3 = metal::simd_sum(subgroup_invocation_id); + uint unnamed_4 = metal::simd_product(subgroup_invocation_id); + uint unnamed_5 = metal::simd_min(subgroup_invocation_id); + uint unnamed_6 = metal::simd_max(subgroup_invocation_id); + uint unnamed_7 = metal::simd_and(subgroup_invocation_id); + uint unnamed_8 = metal::simd_or(subgroup_invocation_id); + uint unnamed_9 = metal::simd_xor(subgroup_invocation_id); + uint unnamed_10 = metal::simd_prefix_exclusive_sum(subgroup_invocation_id); + uint unnamed_11 = metal::simd_prefix_exclusive_product(subgroup_invocation_id); + uint unnamed_12 = metal::simd_prefix_inclusive_sum(subgroup_invocation_id); + uint unnamed_13 = metal::simd_prefix_inclusive_product(subgroup_invocation_id); + uint unnamed_14 = metal::simd_broadcast_first(subgroup_invocation_id); + uint unnamed_15 = metal::simd_broadcast(subgroup_invocation_id, 4u); + uint unnamed_16 = metal::simd_shuffle(subgroup_invocation_id, (subgroup_size - 1u) - subgroup_invocation_id); + uint unnamed_17 = metal::simd_shuffle_down(subgroup_invocation_id, 1u); + uint unnamed_18 = metal::simd_shuffle_up(subgroup_invocation_id, 1u); + uint unnamed_19 = metal::simd_shuffle_xor(subgroup_invocation_id, subgroup_size - 1u); + return; +} diff --git a/tests/out/spv/subgroup-operations.spvasm b/tests/out/spv/subgroup-operations.spvasm new file mode 100644 index 0000000000..c2023c5473 --- /dev/null +++ b/tests/out/spv/subgroup-operations.spvasm @@ -0,0 +1,73 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 52 +OpCapability Shader +OpCapability GroupNonUniform +OpCapability GroupNonUniformBallot +OpCapability GroupNonUniformVote +OpCapability GroupNonUniformArithmetic +OpCapability GroupNonUniformShuffle +OpCapability GroupNonUniformShuffleRelative +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %15 "main" %6 %9 %11 %13 +OpExecutionMode %15 LocalSize 1 1 1 +OpDecorate %6 BuiltIn NumSubgroups +OpDecorate %9 BuiltIn SubgroupId +OpDecorate %11 BuiltIn SubgroupSize +OpDecorate %13 BuiltIn SubgroupLocalInvocationId +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypeBool +%7 = OpTypePointer Input %3 +%6 = OpVariable %7 Input +%9 = OpVariable %7 Input +%11 = OpVariable %7 Input +%13 = OpVariable %7 Input +%16 = OpTypeFunction %2 +%17 = OpConstant %3 1 +%18 = OpConstant %3 0 +%19 = OpConstant %3 4 +%21 = OpConstant %3 3 +%22 = OpConstant %3 2 +%23 = OpConstant %3 8 +%26 = OpTypeVector %3 4 +%15 = OpFunction %2 None %16 +%5 = OpLabel +%8 = OpLoad %3 %6 +%10 = OpLoad %3 %9 +%12 = OpLoad %3 %11 +%14 = OpLoad %3 %13 +OpBranch %20 +%20 = OpLabel +OpControlBarrier %21 %22 %23 +%24 = OpBitwiseAnd %3 %14 %17 +%25 = OpIEqual %4 %24 %17 +%27 = OpGroupNonUniformBallot %26 %21 %25 +%28 = OpINotEqual %4 %14 %18 +%29 = OpGroupNonUniformAll %4 %21 %28 +%30 = OpIEqual %4 %14 %18 +%31 = OpGroupNonUniformAny %4 %21 %30 +%32 = OpGroupNonUniformIAdd %3 %21 Reduce %14 +%33 = OpGroupNonUniformIMul %3 %21 Reduce %14 +%34 = OpGroupNonUniformUMin %3 %21 Reduce %14 +%35 = OpGroupNonUniformUMax %3 %21 Reduce %14 +%36 = OpGroupNonUniformBitwiseAnd %3 %21 Reduce %14 +%37 = OpGroupNonUniformBitwiseOr %3 %21 Reduce %14 +%38 = OpGroupNonUniformBitwiseXor %3 %21 Reduce %14 +%39 = OpGroupNonUniformIAdd %3 %21 ExclusiveScan %14 +%40 = OpGroupNonUniformIMul %3 %21 ExclusiveScan %14 +%41 = OpGroupNonUniformIAdd %3 %21 InclusiveScan %14 +%42 = OpGroupNonUniformIMul %3 %21 InclusiveScan %14 +%43 = OpGroupNonUniformBroadcastFirst %3 %21 %14 +%44 = OpGroupNonUniformBroadcast %3 %21 %14 %19 +%45 = OpISub %3 %12 %17 +%46 = OpISub %3 %45 %14 +%47 = OpGroupNonUniformShuffle %3 %21 %14 %46 +%48 = OpGroupNonUniformShuffleDown %3 %21 %14 %17 +%49 = OpGroupNonUniformShuffleUp %3 %21 %14 %17 +%50 = OpISub %3 %12 %17 +%51 = OpGroupNonUniformShuffleXor %3 %21 %14 %50 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/subgroup-operations.wgsl b/tests/out/wgsl/subgroup-operations.wgsl new file mode 100644 index 0000000000..f12f226387 --- /dev/null +++ b/tests/out/wgsl/subgroup-operations.wgsl @@ -0,0 +1,26 @@ +@compute @workgroup_size(1, 1, 1) +fn main(@builtin(num_subgroups) num_subgroups: u32, @builtin(subgroup_id) subgroup_id: u32, @builtin(subgroup_size) subgroup_size: u32, @builtin(subgroup_invocation_id) subgroup_invocation_id: u32) { + subgroupBarrier(); + let _e8 = subgroupBallot( +((subgroup_invocation_id & 1u) == 1u)); + let _e11 = subgroupAll((subgroup_invocation_id != 0u)); + let _e14 = subgroupAny((subgroup_invocation_id == 0u)); + let _e15 = subgroupAdd(subgroup_invocation_id); + let _e16 = subgroupMul(subgroup_invocation_id); + let _e17 = subgroupMin(subgroup_invocation_id); + let _e18 = subgroupMax(subgroup_invocation_id); + let _e19 = subgroupAnd(subgroup_invocation_id); + let _e20 = subgroupOr(subgroup_invocation_id); + let _e21 = subgroupXor(subgroup_invocation_id); + let _e22 = subgroupPrefixExclusiveAdd(subgroup_invocation_id); + let _e23 = subgroupPrefixExclusiveMul(subgroup_invocation_id); + let _e24 = subgroupPrefixInclusiveAdd(subgroup_invocation_id); + let _e25 = subgroupPrefixInclusiveMul(subgroup_invocation_id); + let _e26 = subgroupBroadcastFirst(subgroup_invocation_id); + let _e28 = subgroupBroadcast(subgroup_invocation_id, 4u); + let _e32 = subgroupShuffle(subgroup_invocation_id, ((subgroup_size - 1u) - subgroup_invocation_id)); + let _e34 = subgroupShuffleDown(subgroup_invocation_id, 1u); + let _e36 = subgroupShuffleUp(subgroup_invocation_id, 1u); + let _e39 = subgroupShuffleXor(subgroup_invocation_id, (subgroup_size - 1u)); + return; +} diff --git a/tests/snapshots.rs b/tests/snapshots.rs index c3455dd864..c720e2efd1 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -782,6 +782,10 @@ fn convert_wgsl() { Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, ), ("separate-entry-points", Targets::SPIRV | Targets::GLSL), + ( + "subgroup-operations", + Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, + ), ]; for &(name, targets) in inputs.iter() {