Skip to content

Commit

Permalink
subgroup: refactor wgsl subgroup gather parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
exrook committed Oct 21, 2023
1 parent 3f67bc4 commit 064afd8
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 28 deletions.
53 changes: 40 additions & 13 deletions src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,29 @@ impl Texture {
}
}

enum SubgroupGather {
BroadcastFirst,
Broadcast,
Shuffle,
ShuffleDown,
ShuffleUp,
ShuffleXor,
}

impl SubgroupGather {
pub fn map(word: &str) -> Option<Self> {
Some(match word {
"subgroupBroadcastFirst" => Self::BroadcastFirst,
"subgroupBroadcast" => Self::Broadcast,
"subgroupShuffle" => Self::Shuffle,
"subgroupShuffleDown" => Self::ShuffleDown,
"subgroupShuffleUp" => Self::ShuffleUp,
"subgroupShuffleXor" => Self::ShuffleXor,
_ => return None,
})
}
}

pub struct Lowerer<'source, 'temp> {
index: &'temp Index<'source>,
layouter: Layouter,
Expand Down Expand Up @@ -1921,7 +1944,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
return Ok(Some(
self.subgroup_operation_helper(span, op, cop, arguments, ctx)?,
));
} else if let Some(mode) = conv::map_subgroup_gather(function.name) {
} else if let Some(mode) = SubgroupGather::map(function.name) {
return Ok(Some(
self.subgroup_gather_helper(span, mode, arguments, ctx)?,
));
Expand Down Expand Up @@ -2547,18 +2570,29 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
fn subgroup_gather_helper(
&mut self,
span: Span,
mode: crate::GatherMode,
mode: SubgroupGather,
arguments: &[Handle<ast::Expression<'source>>],
mut ctx: ExpressionContext<'source, '_, '_>,
) -> Result<Handle<crate::Expression>, Error<'source>> {
let mut args = ctx.prepare_args(arguments, 2, span);

let argument = self.expression(args.next()?, ctx.reborrow())?;
let index = if let crate::GatherMode::BroadcastFirst = mode {
Handle::new(NonZeroU32::new(u32::MAX).unwrap())

use SubgroupGather as Sg;
let mode = if let Sg::BroadcastFirst = mode {
crate::GatherMode::BroadcastFirst
} else {
self.expression(args.next()?, ctx.reborrow())?
let index = self.expression(args.next()?, ctx.reborrow())?;
match mode {
Sg::Broadcast => crate::GatherMode::Broadcast(index),
Sg::Shuffle => crate::GatherMode::Shuffle(index),
Sg::ShuffleDown => crate::GatherMode::ShuffleDown(index),
Sg::ShuffleUp => crate::GatherMode::ShuffleUp(index),
Sg::ShuffleXor => crate::GatherMode::ShuffleXor(index),
Sg::BroadcastFirst => unreachable!(),
}
};

args.finish()?;

let ty = ctx.register_type(argument)?;
Expand All @@ -2568,14 +2602,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let rctx = ctx.runtime_expression_ctx(span)?;
rctx.block.push(
crate::Statement::SubgroupGather {
mode: match mode {
crate::GatherMode::BroadcastFirst => crate::GatherMode::BroadcastFirst,
crate::GatherMode::Broadcast(_) => crate::GatherMode::Broadcast(index),
crate::GatherMode::Shuffle(_) => crate::GatherMode::Shuffle(index),
crate::GatherMode::ShuffleDown(_) => crate::GatherMode::ShuffleDown(index),
crate::GatherMode::ShuffleUp(_) => crate::GatherMode::ShuffleUp(index),
crate::GatherMode::ShuffleXor(_) => crate::GatherMode::ShuffleXor(index),
},
mode,
argument,
result,
},
Expand Down
15 changes: 0 additions & 15 deletions src/front/wgsl/parse/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,18 +263,3 @@ pub fn map_subgroup_operation(
_ => return None,
})
}

pub fn map_subgroup_gather(word: &str) -> Option<crate::GatherMode> {
use crate::GatherMode as gm;
use crate::Handle;
use std::num::NonZeroU32;
Some(match word {
"subgroupBroadcastFirst" => gm::BroadcastFirst,
"subgroupBroadcast" => gm::Broadcast(Handle::new(NonZeroU32::new(u32::MAX).unwrap())),
"subgroupShuffle" => gm::Shuffle(Handle::new(NonZeroU32::new(u32::MAX).unwrap())),
"subgroupShuffleDown" => gm::ShuffleDown(Handle::new(NonZeroU32::new(u32::MAX).unwrap())),
"subgroupShuffleUp" => gm::ShuffleUp(Handle::new(NonZeroU32::new(u32::MAX).unwrap())),
"subgroupShuffleXor" => gm::ShuffleXor(Handle::new(NonZeroU32::new(u32::MAX).unwrap())),
_ => return None,
})
}

0 comments on commit 064afd8

Please sign in to comment.