From d86df5983bde837fe5cfeea28a9dfaf3dfb6f3d5 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Thu, 5 Oct 2023 11:46:55 -0400 Subject: [PATCH] subgroup: wgsl-in and spv-out for subgroup operations --- src/back/spv/block.rs | 4 +- src/back/spv/instructions.rs | 40 ++++++++++++- src/back/spv/subgroup.rs | 84 ++++++++++++++++++++++----- src/front/wgsl/lower/mod.rs | 108 +++++++++++++++++++++++------------ src/front/wgsl/parse/conv.rs | 21 +++++++ src/valid/analyzer.rs | 1 + src/valid/function.rs | 16 ++++-- src/valid/handles.rs | 1 + 8 files changed, 216 insertions(+), 59 deletions(-) diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index c1963450dc..555495ffb5 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -2357,14 +2357,14 @@ impl<'w> BlockContext<'w> { argument, result, } => { - self.write_subgroup_operation(op, collective_op, argument, result, &mut block); + self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?; } crate::Statement::SubgroupBroadcast { ref mode, argument, result, } => { - unimplemented!() // FIXME + self.write_subgroup_broadcast(mode, argument, result, &mut block)?; } } } diff --git a/src/back/spv/instructions.rs b/src/back/spv/instructions.rs index 2eb2a0be4e..9c889a0f2d 100644 --- a/src/back/spv/instructions.rs +++ b/src/back/spv/instructions.rs @@ -1054,19 +1054,55 @@ impl super::Instruction { instruction } + pub(super) fn group_non_uniform_broadcast( + result_type_id: Word, + id: Word, + exec_scope_id: Word, + value: Word, + index: Word, + ) -> Self { + let mut instruction = Self::new(Op::GroupNonUniformBroadcast); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(value); + instruction.add_operand(index); + + instruction + } + pub(super) fn group_non_uniform_broadcast_first( + result_type_id: Word, + id: Word, + exec_scope_id: Word, + value: Word, + ) -> Self { + let mut instruction = Self::new(Op::GroupNonUniformBroadcastFirst); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(value); + + instruction + } pub(super) fn group_non_uniform_arithmetic( op: Op, result_type_id: Word, id: Word, exec_scope_id: Word, - group_op: spirv::GroupOperation, + group_op: Option, value: Word, ) -> Self { + println!( + "{:?}", + (op, result_type_id, id, exec_scope_id, group_op, value) + ); let mut instruction = Self::new(op); instruction.set_type(result_type_id); instruction.set_result(id); instruction.add_operand(exec_scope_id); - instruction.add_operand(group_op as u32); + if let Some(group_op) = group_op { + instruction.add_operand(group_op as u32); + } instruction.add_operand(value); instruction diff --git a/src/back/spv/subgroup.rs b/src/back/spv/subgroup.rs index a4f7b018f4..3c8b7d827a 100644 --- a/src/back/spv/subgroup.rs +++ b/src/back/spv/subgroup.rs @@ -10,20 +10,30 @@ impl<'w> BlockContext<'w> { result: Handle, block: &mut Block, ) -> Result<(), Error> { - self.writer.require_any( - "GroupNonUniformArithmetic", - &[ - spirv::Capability::GroupNonUniformArithmetic, - spirv::Capability::GroupNonUniformClustered, - spirv::Capability::GroupNonUniformPartitionedNV, - ], - )?; + use crate::SubgroupOperation as sg; + match op { + sg::All | sg::Any => { + self.writer.require_any( + "GroupNonUniformVote", + &[spirv::Capability::GroupNonUniformVote], + )?; + } + _ => { + self.writer.require_any( + "GroupNonUniformArithmetic", + &[ + spirv::Capability::GroupNonUniformArithmetic, + spirv::Capability::GroupNonUniformClustered, + spirv::Capability::GroupNonUniformPartitionedNV, + ], + )?; + } + } let id = self.gen_id(); let result_ty = &self.fun_info[result].ty; let result_type_id = self.get_expression_type_id(result_ty); let result_ty_inner = result_ty.inner_with(&self.ir_module.types); - let kind = result_ty_inner.scalar_kind().unwrap(); let (is_scalar, kind) = match result_ty_inner { TypeInner::Scalar { kind, .. } => (true, kind), @@ -32,7 +42,6 @@ impl<'w> BlockContext<'w> { }; use crate::ScalarKind as sk; - use crate::SubgroupOperation as sg; let spirv_op = match (kind, op) { (sk::Bool, sg::All) if is_scalar => spirv::Op::GroupNonUniformAll, (sk::Bool, sg::Any) if is_scalar => spirv::Op::GroupNonUniformAny, @@ -62,10 +71,13 @@ impl<'w> BlockContext<'w> { let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); use crate::CollectiveOperation as c; - let group_op = match collective_op { - c::Reduce => spirv::GroupOperation::Reduce, - c::InclusiveScan => spirv::GroupOperation::InclusiveScan, - c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan, + let group_op = match op { + sg::All | sg::Any => None, + _ => Some(match collective_op { + c::Reduce => spirv::GroupOperation::Reduce, + c::InclusiveScan => spirv::GroupOperation::InclusiveScan, + c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan, + }), }; let arg_id = self.cached[argument]; @@ -80,4 +92,48 @@ impl<'w> BlockContext<'w> { self.cached[result] = id; Ok(()) } + pub(super) fn write_subgroup_broadcast( + &mut self, + mode: &crate::BroadcastMode, + argument: Handle, + result: Handle, + block: &mut Block, + ) -> Result<(), Error> { + self.writer.require_any( + "GroupNonUniformBallot", + &[spirv::Capability::GroupNonUniformBallot], + )?; + + let id = self.gen_id(); + let result_ty = &self.fun_info[result].ty; + let result_type_id = self.get_expression_type_id(result_ty); + + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + + let arg_id = self.cached[argument]; + match mode { + crate::BroadcastMode::Index(index) => { + let index_id = self.cached[*index]; + block.body.push(Instruction::group_non_uniform_broadcast( + result_type_id, + id, + exec_scope_id, + arg_id, + index_id, + )); + } + crate::BroadcastMode::First => { + block + .body + .push(Instruction::group_non_uniform_broadcast_first( + result_type_id, + id, + exec_scope_id, + arg_id, + )); + } + } + self.cached[result] = id; + Ok(()) + } } diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index d74be46f22..1d9e8c40e7 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -1783,6 +1783,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } } else if let Some(fun) = Texture::map(function.name) { self.texture_sample_helper(fun, arguments, span, ctx.reborrow())? + } else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) { + return Ok(Some(self.subgroup_helper(span, op, cop, arguments, ctx)?)); } else { match function.name { "select" => { @@ -2176,43 +2178,51 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { return Ok(Some(result)); } "subgroupBroadcast" => { - unimplemented!(); // FIXME + let mut args = ctx.prepare_args(arguments, 2, span); + + let index = self.expression(args.next()?, ctx.reborrow())?; + let argument = self.expression(args.next()?, ctx.reborrow())?; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = ctx.interrupt_emitter( + crate::Expression::SubgroupOperationResult { ty }, + span, + ); + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupBroadcast { + mode: crate::BroadcastMode::Index(index), + argument, + result, + }, + span, + ); + return Ok(Some(result)); } "subgroupBroadcastFirst" => { - unimplemented!(); // FIXME - } - "subgroupAll" => { - unimplemented!(); // FIXME - } - "subgroupAny" => { - unimplemented!(); // FIXME - } - "subgroupAdd" => { - unimplemented!(); // FIXME - } - "subgroupMul" => { - unimplemented!(); // FIXME - } - "subgroupMin" => { - unimplemented!(); // FIXME - } - "subgroupMax" => { - unimplemented!(); // FIXME - } - "subgroupAnd" => { - unimplemented!(); // FIXME - } - "subgroupOr" => { - unimplemented!(); // FIXME - } - "subgroupXor" => { - unimplemented!(); // FIXME - } - "subgroupPrefixAdd" => { - unimplemented!(); // FIXME - } - "subgroupPrefixMul" => { - unimplemented!(); // FIXME + let mut args = ctx.prepare_args(arguments, 1, span); + + let argument = self.expression(args.next()?, ctx.reborrow())?; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = ctx.interrupt_emitter( + crate::Expression::SubgroupOperationResult { ty }, + span, + ); + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupBroadcast { + mode: crate::BroadcastMode::First, + argument, + result, + }, + span, + ); + return Ok(Some(result)); } _ => return Err(Error::UnknownIdent(function.span, function.name)), } @@ -2405,6 +2415,34 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { depth_ref, }) } + fn subgroup_helper( + &mut self, + span: Span, + op: crate::SubgroupOperation, + collective_op: crate::CollectiveOperation, + arguments: &[Handle>], + mut ctx: ExpressionContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + let mut args = ctx.prepare_args(arguments, 1, span); + + let argument = self.expression(args.next()?, ctx.reborrow())?; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span); + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + }, + span, + ); + Ok(result) + } fn r#struct( &mut self, diff --git a/src/front/wgsl/parse/conv.rs b/src/front/wgsl/parse/conv.rs index cbfe37d51c..a67fb2e922 100644 --- a/src/front/wgsl/parse/conv.rs +++ b/src/front/wgsl/parse/conv.rs @@ -237,3 +237,24 @@ pub fn map_conservative_depth( _ => Err(Error::UnknownConservativeDepth(span)), } } + +pub fn map_subgroup_operation( + word: &str, +) -> Option<(crate::SubgroupOperation, crate::CollectiveOperation)> { + use crate::CollectiveOperation as co; + use crate::SubgroupOperation as sg; + Some(match word { + "subgroupAll" => (sg::All, co::Reduce), + "subgroupAny" => (sg::Any, co::Reduce), + "subgroupAdd" => (sg::Add, co::Reduce), + "subgroupMul" => (sg::Mul, co::Reduce), + "subgroupMin" => (sg::Min, co::Reduce), + "subgroupMax" => (sg::Max, co::Reduce), + "subgroupAnd" => (sg::And, co::Reduce), + "subgroupOr" => (sg::Or, co::Reduce), + "subgroupXor" => (sg::Xor, co::Reduce), + "subgroupPrefixAdd" => (sg::Add, co::InclusiveScan), + "subgroupPrefixMul" => (sg::Mul, co::InclusiveScan), + _ => return None, + }) +} diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index 7c266428c5..248a934a69 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -989,6 +989,7 @@ impl FunctionInfo { argument, result, } => { + println!("analyzer"); let _ = self.add_ref(argument); if let crate::BroadcastMode::Index(expr) = *mode { let _ = self.add_ref(expr); diff --git a/src/valid/function.rs b/src/valid/function.rs index 9f18da49ec..3659187958 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -431,7 +431,6 @@ impl super::Validator { result: Handle, context: &BlockContext, ) -> Result<(), WithSpan> { - self.emit_expression(argument, context)?; let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; let (is_scalar, kind) = match argument_inner { @@ -477,8 +476,8 @@ impl super::Validator { result: Handle, context: &BlockContext, ) -> Result<(), WithSpan> { + println!("function HENLO??"); if let crate::BroadcastMode::Index(expr) = *mode { - self.emit_expression(expr, context)?; let index_ty = context.resolve_type(expr, &self.valid_expression_set)?; match index_ty { crate::TypeInner::Scalar { @@ -486,25 +485,29 @@ impl super::Validator { .. } => {} _ => { - log::error!("Subgroup broadcast index type {:?}", index_ty); + log::error!( + "Subgroup broadcast index type {:?}, expected unsigned int", + index_ty + ); return Err(SubgroupError::InvalidOperand(argument) - .with_span_handle(argument, context.expressions) + .with_span_handle(expr, context.expressions) .into_other()); } } } - self.emit_expression(argument, context)?; + println!("function aaa"); let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; match argument_inner { crate::TypeInner::Scalar { .. } | crate::TypeInner::Vector { .. } => {} _ => { - log::error!("Subgroup operand type {:?}", argument_inner); + log::error!("Subgroup broadcast operand type {:?}", argument_inner); return Err(SubgroupError::InvalidOperand(argument) .with_span_handle(argument, context.expressions) .into_other()); } } + println!("function bbb"); self.emit_expression(result, context)?; match context.expressions[result] { @@ -516,6 +519,7 @@ impl super::Validator { .into_other()) } } + println!("function ccc"); Ok(()) } diff --git a/src/valid/handles.rs b/src/valid/handles.rs index 69b8079675..78e22d4ef1 100644 --- a/src/valid/handles.rs +++ b/src/valid/handles.rs @@ -560,6 +560,7 @@ impl super::Validator { argument, result, } => { + println!("handles"); if let crate::BroadcastMode::Index(expr) = mode { validate_expr(expr)?; }