Skip to content

Commit

Permalink
subgroup: wgsl-in and spv-out for subgroup operations
Browse files Browse the repository at this point in the history
  • Loading branch information
exrook committed Oct 5, 2023
1 parent 2bbc09e commit d86df59
Show file tree
Hide file tree
Showing 8 changed files with 216 additions and 59 deletions.
4 changes: 2 additions & 2 deletions src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2357,14 +2357,14 @@ impl<'w> BlockContext<'w> {
argument,
result,
} => {
self.write_subgroup_operation(op, collective_op, argument, result, &mut block);
self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?;
}
crate::Statement::SubgroupBroadcast {
ref mode,
argument,
result,
} => {
unimplemented!() // FIXME
self.write_subgroup_broadcast(mode, argument, result, &mut block)?;
}
}
}
Expand Down
40 changes: 38 additions & 2 deletions src/back/spv/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1054,19 +1054,55 @@ impl super::Instruction {

instruction
}
pub(super) fn group_non_uniform_broadcast(
result_type_id: Word,
id: Word,
exec_scope_id: Word,
value: Word,
index: Word,
) -> Self {
let mut instruction = Self::new(Op::GroupNonUniformBroadcast);
instruction.set_type(result_type_id);
instruction.set_result(id);
instruction.add_operand(exec_scope_id);
instruction.add_operand(value);
instruction.add_operand(index);

instruction
}
pub(super) fn group_non_uniform_broadcast_first(
result_type_id: Word,
id: Word,
exec_scope_id: Word,
value: Word,
) -> Self {
let mut instruction = Self::new(Op::GroupNonUniformBroadcastFirst);
instruction.set_type(result_type_id);
instruction.set_result(id);
instruction.add_operand(exec_scope_id);
instruction.add_operand(value);

instruction
}
pub(super) fn group_non_uniform_arithmetic(
op: Op,
result_type_id: Word,
id: Word,
exec_scope_id: Word,
group_op: spirv::GroupOperation,
group_op: Option<spirv::GroupOperation>,
value: Word,
) -> Self {
println!(
"{:?}",
(op, result_type_id, id, exec_scope_id, group_op, value)
);
let mut instruction = Self::new(op);
instruction.set_type(result_type_id);
instruction.set_result(id);
instruction.add_operand(exec_scope_id);
instruction.add_operand(group_op as u32);
if let Some(group_op) = group_op {
instruction.add_operand(group_op as u32);
}
instruction.add_operand(value);

instruction
Expand Down
84 changes: 70 additions & 14 deletions src/back/spv/subgroup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,30 @@ impl<'w> BlockContext<'w> {
result: Handle<crate::Expression>,
block: &mut Block,
) -> Result<(), Error> {
self.writer.require_any(
"GroupNonUniformArithmetic",
&[
spirv::Capability::GroupNonUniformArithmetic,
spirv::Capability::GroupNonUniformClustered,
spirv::Capability::GroupNonUniformPartitionedNV,
],
)?;
use crate::SubgroupOperation as sg;
match op {
sg::All | sg::Any => {
self.writer.require_any(
"GroupNonUniformVote",
&[spirv::Capability::GroupNonUniformVote],
)?;
}
_ => {
self.writer.require_any(
"GroupNonUniformArithmetic",
&[
spirv::Capability::GroupNonUniformArithmetic,
spirv::Capability::GroupNonUniformClustered,
spirv::Capability::GroupNonUniformPartitionedNV,
],
)?;
}
}

let id = self.gen_id();
let result_ty = &self.fun_info[result].ty;
let result_type_id = self.get_expression_type_id(result_ty);
let result_ty_inner = result_ty.inner_with(&self.ir_module.types);
let kind = result_ty_inner.scalar_kind().unwrap();

let (is_scalar, kind) = match result_ty_inner {
TypeInner::Scalar { kind, .. } => (true, kind),
Expand All @@ -32,7 +42,6 @@ impl<'w> BlockContext<'w> {
};

use crate::ScalarKind as sk;
use crate::SubgroupOperation as sg;
let spirv_op = match (kind, op) {
(sk::Bool, sg::All) if is_scalar => spirv::Op::GroupNonUniformAll,
(sk::Bool, sg::Any) if is_scalar => spirv::Op::GroupNonUniformAny,
Expand Down Expand Up @@ -62,10 +71,13 @@ impl<'w> BlockContext<'w> {
let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);

use crate::CollectiveOperation as c;
let group_op = match collective_op {
c::Reduce => spirv::GroupOperation::Reduce,
c::InclusiveScan => spirv::GroupOperation::InclusiveScan,
c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan,
let group_op = match op {
sg::All | sg::Any => None,
_ => Some(match collective_op {
c::Reduce => spirv::GroupOperation::Reduce,
c::InclusiveScan => spirv::GroupOperation::InclusiveScan,
c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan,
}),
};

let arg_id = self.cached[argument];
Expand All @@ -80,4 +92,48 @@ impl<'w> BlockContext<'w> {
self.cached[result] = id;
Ok(())
}
pub(super) fn write_subgroup_broadcast(
&mut self,
mode: &crate::BroadcastMode,
argument: Handle<crate::Expression>,
result: Handle<crate::Expression>,
block: &mut Block,
) -> Result<(), Error> {
self.writer.require_any(
"GroupNonUniformBallot",
&[spirv::Capability::GroupNonUniformBallot],
)?;

let id = self.gen_id();
let result_ty = &self.fun_info[result].ty;
let result_type_id = self.get_expression_type_id(result_ty);

let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);

let arg_id = self.cached[argument];
match mode {
crate::BroadcastMode::Index(index) => {
let index_id = self.cached[*index];
block.body.push(Instruction::group_non_uniform_broadcast(
result_type_id,
id,
exec_scope_id,
arg_id,
index_id,
));
}
crate::BroadcastMode::First => {
block
.body
.push(Instruction::group_non_uniform_broadcast_first(
result_type_id,
id,
exec_scope_id,
arg_id,
));
}
}
self.cached[result] = id;
Ok(())
}
}
108 changes: 73 additions & 35 deletions src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1783,6 +1783,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
} else if let Some(fun) = Texture::map(function.name) {
self.texture_sample_helper(fun, arguments, span, ctx.reborrow())?
} else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) {
return Ok(Some(self.subgroup_helper(span, op, cop, arguments, ctx)?));
} else {
match function.name {
"select" => {
Expand Down Expand Up @@ -2176,43 +2178,51 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
return Ok(Some(result));
}
"subgroupBroadcast" => {
unimplemented!(); // FIXME
let mut args = ctx.prepare_args(arguments, 2, span);

let index = self.expression(args.next()?, ctx.reborrow())?;
let argument = self.expression(args.next()?, ctx.reborrow())?;
args.finish()?;

let ty = ctx.register_type(argument)?;

let result = ctx.interrupt_emitter(
crate::Expression::SubgroupOperationResult { ty },
span,
);
let rctx = ctx.runtime_expression_ctx(span)?;
rctx.block.push(
crate::Statement::SubgroupBroadcast {
mode: crate::BroadcastMode::Index(index),
argument,
result,
},
span,
);
return Ok(Some(result));
}
"subgroupBroadcastFirst" => {
unimplemented!(); // FIXME
}
"subgroupAll" => {
unimplemented!(); // FIXME
}
"subgroupAny" => {
unimplemented!(); // FIXME
}
"subgroupAdd" => {
unimplemented!(); // FIXME
}
"subgroupMul" => {
unimplemented!(); // FIXME
}
"subgroupMin" => {
unimplemented!(); // FIXME
}
"subgroupMax" => {
unimplemented!(); // FIXME
}
"subgroupAnd" => {
unimplemented!(); // FIXME
}
"subgroupOr" => {
unimplemented!(); // FIXME
}
"subgroupXor" => {
unimplemented!(); // FIXME
}
"subgroupPrefixAdd" => {
unimplemented!(); // FIXME
}
"subgroupPrefixMul" => {
unimplemented!(); // FIXME
let mut args = ctx.prepare_args(arguments, 1, span);

let argument = self.expression(args.next()?, ctx.reborrow())?;
args.finish()?;

let ty = ctx.register_type(argument)?;

let result = ctx.interrupt_emitter(
crate::Expression::SubgroupOperationResult { ty },
span,
);
let rctx = ctx.runtime_expression_ctx(span)?;
rctx.block.push(
crate::Statement::SubgroupBroadcast {
mode: crate::BroadcastMode::First,
argument,
result,
},
span,
);
return Ok(Some(result));
}
_ => return Err(Error::UnknownIdent(function.span, function.name)),
}
Expand Down Expand Up @@ -2405,6 +2415,34 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
depth_ref,
})
}
fn subgroup_helper(
&mut self,
span: Span,
op: crate::SubgroupOperation,
collective_op: crate::CollectiveOperation,
arguments: &[Handle<ast::Expression<'source>>],
mut ctx: ExpressionContext<'source, '_, '_>,
) -> Result<Handle<crate::Expression>, Error<'source>> {
let mut args = ctx.prepare_args(arguments, 1, span);

let argument = self.expression(args.next()?, ctx.reborrow())?;
args.finish()?;

let ty = ctx.register_type(argument)?;

let result = ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span);
let rctx = ctx.runtime_expression_ctx(span)?;
rctx.block.push(
crate::Statement::SubgroupCollectiveOperation {
op,
collective_op,
argument,
result,
},
span,
);
Ok(result)
}

fn r#struct(
&mut self,
Expand Down
21 changes: 21 additions & 0 deletions src/front/wgsl/parse/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,24 @@ pub fn map_conservative_depth(
_ => Err(Error::UnknownConservativeDepth(span)),
}
}

pub fn map_subgroup_operation(
word: &str,
) -> Option<(crate::SubgroupOperation, crate::CollectiveOperation)> {
use crate::CollectiveOperation as co;
use crate::SubgroupOperation as sg;
Some(match word {
"subgroupAll" => (sg::All, co::Reduce),
"subgroupAny" => (sg::Any, co::Reduce),
"subgroupAdd" => (sg::Add, co::Reduce),
"subgroupMul" => (sg::Mul, co::Reduce),
"subgroupMin" => (sg::Min, co::Reduce),
"subgroupMax" => (sg::Max, co::Reduce),
"subgroupAnd" => (sg::And, co::Reduce),
"subgroupOr" => (sg::Or, co::Reduce),
"subgroupXor" => (sg::Xor, co::Reduce),
"subgroupPrefixAdd" => (sg::Add, co::InclusiveScan),
"subgroupPrefixMul" => (sg::Mul, co::InclusiveScan),
_ => return None,
})
}
1 change: 1 addition & 0 deletions src/valid/analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,7 @@ impl FunctionInfo {
argument,
result,
} => {
println!("analyzer");
let _ = self.add_ref(argument);
if let crate::BroadcastMode::Index(expr) = *mode {
let _ = self.add_ref(expr);
Expand Down
Loading

0 comments on commit d86df59

Please sign in to comment.