Skip to content

Commit

Permalink
[naga-cli] Add input-kind and shader-stage args (gfx-rs#5411)
Browse files Browse the repository at this point in the history
  • Loading branch information
ratmice authored Apr 18, 2024
1 parent 4e77762 commit e0ac24a
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 45 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ Bottom level categories:

### New features

- Added `--shader-stage` and `--input-kind` options to naga-cli for specifying vertex/fragment/compute shaders, and frontend. by @ratmice in [#5411](https://github.com/gfx-rs/wgpu/pull/5411)

#### General

- Implemented the `Unorm10_10_10_2` VertexFormat.
Expand Down
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ path = "./naga"
version = "0.19.2"

[workspace.dependencies]
anyhow = "1.0"
anyhow = "1.0.23"
arrayvec = "0.7"
bit-vec = "0.6"
bitflags = "2"
Expand Down
1 change: 1 addition & 0 deletions naga-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ log = "0.4"
codespan-reporting = "0.11"
env_logger = "0.11"
argh = "0.1.5"
anyhow.workspace = true

[dependencies.naga]
version = "0.19"
Expand Down
146 changes: 102 additions & 44 deletions naga-cli/src/bin/naga.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#![allow(clippy::manual_strip)]
use anyhow::{anyhow, Context as _};
#[allow(unused_imports)]
use std::fs;
use std::{error::Error, fmt, io::Read, path::Path, str::FromStr};
Expand Down Expand Up @@ -62,6 +63,16 @@ struct Args {
#[argh(option)]
shader_model: Option<ShaderModelArg>,

/// the shader stage, for example 'frag', 'vert', or 'compute'.
/// if the shader stage is unspecified it will be derived from
/// the file extension.
#[argh(option)]
shader_stage: Option<ShaderStage>,

/// the kind of input, e.g. 'glsl', 'wgsl', 'spv', or 'bin'.
#[argh(option)]
input_kind: Option<InputKind>,

/// the metal version to use, for example, 1.0, 1.1, 1.2, etc.
#[argh(option)]
metal_version: Option<MslVersionArg>,
Expand Down Expand Up @@ -170,6 +181,46 @@ impl FromStr for ShaderModelArg {
}
}

/// Newtype so we can implement [`FromStr`] for `ShaderSource`.
#[derive(Debug, Clone, Copy)]
struct ShaderStage(naga::ShaderStage);

impl FromStr for ShaderStage {
type Err = anyhow::Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
use naga::ShaderStage;
Ok(Self(match s.to_lowercase().as_str() {
"frag" | "fragment" => ShaderStage::Fragment,
"comp" | "compute" => ShaderStage::Compute,
"vert" | "vertex" => ShaderStage::Vertex,
_ => return Err(anyhow!("Invalid shader stage: {s}")),
}))
}
}

/// Input kind/file extension mapping
#[derive(Debug, Clone, Copy)]
enum InputKind {
Bincode,
Glsl,
SpirV,
Wgsl,
}
impl FromStr for InputKind {
type Err = anyhow::Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(match s.to_lowercase().as_str() {
"bin" => InputKind::Bincode,
"glsl" => InputKind::Glsl,
"spv" => InputKind::SpirV,
"wgsl" => InputKind::Wgsl,
_ => return Err(anyhow!("Invalid value for --input-kind: {s}")),
})
}
}

/// Newtype so we can implement [`FromStr`] for [`naga::back::glsl::Version`].
#[derive(Clone, Debug)]
struct GlslProfileArg(naga::back::glsl::Version);
Expand Down Expand Up @@ -247,6 +298,8 @@ struct Parameters<'a> {
msl: naga::back::msl::Options,
glsl: naga::back::glsl::Options,
hlsl: naga::back::hlsl::Options,
input_kind: Option<InputKind>,
shader_stage: Option<ShaderStage>,
}

trait PrettyResult {
Expand Down Expand Up @@ -300,7 +353,7 @@ impl fmt::Display for CliError {
}
impl std::error::Error for CliError {}

fn run() -> Result<(), Box<dyn std::error::Error>> {
fn run() -> anyhow::Result<()> {
env_logger::init();

// Parse commandline arguments
Expand Down Expand Up @@ -381,6 +434,9 @@ fn run() -> Result<(), Box<dyn std::error::Error>> {
return Err(CliError("Input file path is not specified").into());
};

params.input_kind = args.input_kind;
params.shader_stage = args.shader_stage;

let Parsed {
mut module,
input_text,
Expand Down Expand Up @@ -500,75 +556,77 @@ struct Parsed {
input_text: Option<String>,
}

fn parse_input(
input_path: &Path,
input: Vec<u8>,
params: &Parameters,
) -> Result<Parsed, Box<dyn std::error::Error>> {
let (module, input_text) = match Path::new(&input_path)
.extension()
.ok_or(CliError("Input filename has no extension"))?
.to_str()
.ok_or(CliError("Input filename not valid unicode"))?
{
"bin" => (bincode::deserialize(&input)?, None),
"spv" => naga::front::spv::parse_u8_slice(&input, &params.spv_in).map(|m| (m, None))?,
"wgsl" => {
fn parse_input(input_path: &Path, input: Vec<u8>, params: &Parameters) -> anyhow::Result<Parsed> {
let input_kind = match params.input_kind {
Some(kind) => kind,
None => input_path
.extension()
.context("Input filename has no extension")?
.to_str()
.context("Input filename not valid unicode")?
.parse()
.context("Unable to determine --input-kind from filename")?,
};

let (module, input_text) = match input_kind {
InputKind::Bincode => (bincode::deserialize(&input)?, None),
InputKind::SpirV => {
naga::front::spv::parse_u8_slice(&input, &params.spv_in).map(|m| (m, None))?
}
InputKind::Wgsl => {
let input = String::from_utf8(input)?;
let result = naga::front::wgsl::parse_str(&input);
match result {
Ok(v) => (v, Some(input)),
Err(ref e) => {
let message = format!(
let message = anyhow!(
"Could not parse WGSL:\n{}",
e.emit_to_string_with_path(&input, input_path)
);
return Err(message.into());
return Err(message);
}
}
}
ext @ ("vert" | "frag" | "comp" | "glsl") => {
InputKind::Glsl => {
let shader_stage = match params.shader_stage {
Some(shader_stage) => shader_stage,
None => {
// filename.shader_stage.glsl -> filename.shader_stage
let file_stem = input_path
.file_stem()
.context("Unable to determine file stem from input filename.")?;
// filename.shader_stage -> shader_stage
let inner_ext = Path::new(file_stem)
.extension()
.context("Unable to determine inner extension from input filename.")?
.to_str()
.context("Input filename not valid unicode")?;
inner_ext.parse().context("from input filename")?
}
};
let input = String::from_utf8(input)?;
let mut parser = naga::front::glsl::Frontend::default();

(
parser
.parse(
&naga::front::glsl::Options {
stage: match ext {
"vert" => naga::ShaderStage::Vertex,
"frag" => naga::ShaderStage::Fragment,
"comp" => naga::ShaderStage::Compute,
"glsl" => {
let internal_name = input_path.to_string_lossy();
match Path::new(&internal_name[..internal_name.len()-5])
.extension()
.ok_or(CliError("Input filename ending with .glsl has no internal extension"))?
.to_str()
.ok_or(CliError("Input filename not valid unicode"))?
{
"vert" => naga::ShaderStage::Vertex,
"frag" => naga::ShaderStage::Fragment,
"comp" => naga::ShaderStage::Compute,
_ => unreachable!(),
}
},
_ => unreachable!(),
},
stage: shader_stage.0,
defines: Default::default(),
},
&input,
)
.unwrap_or_else(|error| {
let filename = input_path.file_name().and_then(std::ffi::OsStr::to_str).unwrap_or("glsl");
let filename = input_path
.file_name()
.and_then(std::ffi::OsStr::to_str)
.unwrap_or("glsl");
let mut writer = StandardStream::stderr(ColorChoice::Auto);
error.emit_to_writer_with_path(&mut writer, &input, filename);
std::process::exit(1);
}),
Some(input),
)
}
_ => return Err(CliError("Unknown input file extension").into()),
};

Ok(Parsed { module, input_text })
Expand All @@ -579,7 +637,7 @@ fn write_output(
info: &Option<naga::valid::ModuleInfo>,
params: &Parameters,
output_path: &str,
) -> Result<(), Box<dyn std::error::Error>> {
) -> anyhow::Result<()> {
match Path::new(&output_path)
.extension()
.ok_or(CliError("Output filename has no extension"))?
Expand Down Expand Up @@ -744,7 +802,7 @@ fn write_output(
Ok(())
}

fn bulk_validate(args: Args, params: &Parameters) -> Result<(), Box<dyn std::error::Error>> {
fn bulk_validate(args: Args, params: &Parameters) -> anyhow::Result<()> {
let mut invalid = vec![];
for input_path in args.files {
let path = Path::new(&input_path);
Expand Down Expand Up @@ -787,7 +845,7 @@ fn bulk_validate(args: Args, params: &Parameters) -> Result<(), Box<dyn std::err
for path in invalid {
writeln!(&mut formatted, " {path}").unwrap();
}
return Err(formatted.into());
return Err(anyhow!(formatted));
}

Ok(())
Expand Down

0 comments on commit e0ac24a

Please sign in to comment.