From 064afd8944ff0909b6004f9d6fa6565cdb954130 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Sat, 21 Oct 2023 17:23:34 -0400 Subject: [PATCH] subgroup: refactor wgsl subgroup gather parsing --- src/front/wgsl/lower/mod.rs | 53 +++++++++++++++++++++++++++--------- src/front/wgsl/parse/conv.rs | 15 ---------- 2 files changed, 40 insertions(+), 28 deletions(-) diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index d3e47f6959..d2784070cc 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -884,6 +884,29 @@ impl Texture { } } +enum SubgroupGather { + BroadcastFirst, + Broadcast, + Shuffle, + ShuffleDown, + ShuffleUp, + ShuffleXor, +} + +impl SubgroupGather { + pub fn map(word: &str) -> Option { + Some(match word { + "subgroupBroadcastFirst" => Self::BroadcastFirst, + "subgroupBroadcast" => Self::Broadcast, + "subgroupShuffle" => Self::Shuffle, + "subgroupShuffleDown" => Self::ShuffleDown, + "subgroupShuffleUp" => Self::ShuffleUp, + "subgroupShuffleXor" => Self::ShuffleXor, + _ => return None, + }) + } +} + pub struct Lowerer<'source, 'temp> { index: &'temp Index<'source>, layouter: Layouter, @@ -1921,7 +1944,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { return Ok(Some( self.subgroup_operation_helper(span, op, cop, arguments, ctx)?, )); - } else if let Some(mode) = conv::map_subgroup_gather(function.name) { + } else if let Some(mode) = SubgroupGather::map(function.name) { return Ok(Some( self.subgroup_gather_helper(span, mode, arguments, ctx)?, )); @@ -2547,18 +2570,29 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { fn subgroup_gather_helper( &mut self, span: Span, - mode: crate::GatherMode, + mode: SubgroupGather, arguments: &[Handle>], mut ctx: ExpressionContext<'source, '_, '_>, ) -> Result, Error<'source>> { let mut args = ctx.prepare_args(arguments, 2, span); let argument = self.expression(args.next()?, ctx.reborrow())?; - let index = if let crate::GatherMode::BroadcastFirst = mode { - Handle::new(NonZeroU32::new(u32::MAX).unwrap()) + + use SubgroupGather as Sg; + let mode = if let Sg::BroadcastFirst = mode { + crate::GatherMode::BroadcastFirst } else { - self.expression(args.next()?, ctx.reborrow())? + let index = self.expression(args.next()?, ctx.reborrow())?; + match mode { + Sg::Broadcast => crate::GatherMode::Broadcast(index), + Sg::Shuffle => crate::GatherMode::Shuffle(index), + Sg::ShuffleDown => crate::GatherMode::ShuffleDown(index), + Sg::ShuffleUp => crate::GatherMode::ShuffleUp(index), + Sg::ShuffleXor => crate::GatherMode::ShuffleXor(index), + Sg::BroadcastFirst => unreachable!(), + } }; + args.finish()?; let ty = ctx.register_type(argument)?; @@ -2568,14 +2602,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block.push( crate::Statement::SubgroupGather { - mode: match mode { - crate::GatherMode::BroadcastFirst => crate::GatherMode::BroadcastFirst, - crate::GatherMode::Broadcast(_) => crate::GatherMode::Broadcast(index), - crate::GatherMode::Shuffle(_) => crate::GatherMode::Shuffle(index), - crate::GatherMode::ShuffleDown(_) => crate::GatherMode::ShuffleDown(index), - crate::GatherMode::ShuffleUp(_) => crate::GatherMode::ShuffleUp(index), - crate::GatherMode::ShuffleXor(_) => crate::GatherMode::ShuffleXor(index), - }, + mode, argument, result, }, diff --git a/src/front/wgsl/parse/conv.rs b/src/front/wgsl/parse/conv.rs index c53f4df753..61fd1bb37e 100644 --- a/src/front/wgsl/parse/conv.rs +++ b/src/front/wgsl/parse/conv.rs @@ -263,18 +263,3 @@ pub fn map_subgroup_operation( _ => return None, }) } - -pub fn map_subgroup_gather(word: &str) -> Option { - use crate::GatherMode as gm; - use crate::Handle; - use std::num::NonZeroU32; - Some(match word { - "subgroupBroadcastFirst" => gm::BroadcastFirst, - "subgroupBroadcast" => gm::Broadcast(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), - "subgroupShuffle" => gm::Shuffle(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), - "subgroupShuffleDown" => gm::ShuffleDown(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), - "subgroupShuffleUp" => gm::ShuffleUp(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), - "subgroupShuffleXor" => gm::ShuffleXor(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), - _ => return None, - }) -}