Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wgsl-in] Handle modf and frexp #2452

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/back/glsl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,4 +477,9 @@ pub const RESERVED_KEYWORDS: &[&str] = &[
// entry point name (should not be shadowed)
//
"main",
// Naga utilities:
super::MODF_FUNCTION,
super::MODF_STRUCT,
super::FREXP_FUNCTION,
super::FREXP_STRUCT,
];
57 changes: 55 additions & 2 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ pub const SUPPORTED_ES_VERSIONS: &[u16] = &[300, 310, 320];
/// of detail for bounds checking in `ImageLoad`
const CLAMPED_LOD_SUFFIX: &str = "_clamped_lod";

pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const FREXP_STRUCT: &str = "naga_frexp_result";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const MODF_STRUCT: &str = "naga_modf_result";

/// Mapping between resources and bindings.
pub type BindingMap = std::collections::BTreeMap<crate::ResourceBinding, u8>;

Expand Down Expand Up @@ -604,6 +609,38 @@ impl<'a, W: Write> Writer<'a, W> {
}
}

if self.module.special_types.frexp_result.is_some() {
writeln!(
self.out,
"struct {FREXP_STRUCT} {{
float fract;
int exp;
}};

{FREXP_STRUCT} {FREXP_FUNCTION}(float arg) {{
int exp;
float fract = frexp(arg, exp);
return {FREXP_STRUCT}(fract, exp);
}}"
)?;
}

if self.module.special_types.modf_result.is_some() {
writeln!(
self.out,
"struct {MODF_STRUCT} {{
float fract;
float whole;
}};

{MODF_STRUCT} {MODF_FUNCTION}(float arg) {{
float whole;
float fract = modf(arg, whole);
return {MODF_STRUCT}(fract, whole);
}}"
)?;
}

// Write struct types.
//
// This are always ordered because the IR is structured in a way that
Expand Down Expand Up @@ -860,6 +897,8 @@ impl<'a, W: Write> Writer<'a, W> {
| TypeInner::Sampler { .. }
| TypeInner::AccelerationStructure
| TypeInner::RayQuery
| TypeInner::ModfResult
| TypeInner::FrexpResult
| TypeInner::BindingArray { .. } => {
return Err(Error::Custom(format!("Unable to write type {inner:?}")))
}
Expand All @@ -885,6 +924,14 @@ impl<'a, W: Write> Writer<'a, W> {
}
// glsl array has the size separated from the base type
TypeInner::Array { base, .. } => self.write_type(base),
TypeInner::FrexpResult => {
write!(self.out, "{FREXP_STRUCT}")?;
Ok(())
}
TypeInner::ModfResult => {
write!(self.out, "{MODF_STRUCT}")?;
Ok(())
}
ref other => self.write_value_type(other),
}
}
Expand Down Expand Up @@ -2325,6 +2372,12 @@ impl<'a, W: Write> Writer<'a, W> {
&self.names[&NameKey::StructMember(ty, index)]
)?
}
TypeInner::FrexpResult => {
write!(self.out, ".{}", if index == 0 { "fract" } else { "exp" })?
}
TypeInner::ModfResult => {
write!(self.out, ".{}", if index == 0 { "fract" } else { "whole" })?
}
ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
}
}
Expand Down Expand Up @@ -2985,8 +3038,8 @@ impl<'a, W: Write> Writer<'a, W> {
Mf::Round => "roundEven",
Mf::Fract => "fract",
Mf::Trunc => "trunc",
Mf::Modf => "modf",
Mf::Frexp => "frexp",
Mf::Modf => MODF_FUNCTION,
Mf::Frexp => FREXP_FUNCTION,
Mf::Ldexp => "ldexp",
// exponent
Mf::Exp => "exp",
Expand Down
46 changes: 46 additions & 0 deletions src/back/hlsl/help.rs
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,52 @@ impl<'a, W: Write> super::Writer<'a, W> {
Ok(())
}

pub(super) fn write_special_functions(&mut self, module: &crate::Module) -> BackendResult {
if module.special_types.frexp_result.is_some() {
let function_name = super::writer::FREXP_FUNCTION;
let struct_name = super::writer::FREXP_STRUCT;
writeln!(
self.out,
"struct {struct_name} {{
float fract;
int exp;
}};

{struct_name} {function_name}(in float arg) {{
float exp;
float fract = frexp(arg, exp);
{struct_name} result;
result.exp = exp;
result.fract = fract;
return result;
}}"
)?;
writeln!(self.out)?;
}
if module.special_types.modf_result.is_some() {
let function_name = super::writer::MODF_FUNCTION;
let struct_name = super::writer::MODF_STRUCT;
writeln!(
self.out,
"struct {struct_name} {{
float fract;
float whole;
}};

{struct_name} {function_name}(in float arg) {{
float whole;
float fract = modf(arg, whole);
{struct_name} result;
result.whole = whole;
result.fract = fract;
return result;
}}"
)?;
writeln!(self.out)?;
}
Ok(())
}

/// Helper function that writes compose wrapped functions
pub(super) fn write_wrapped_compose_functions(
&mut self,
Expand Down
5 changes: 5 additions & 0 deletions src/back/hlsl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,11 @@ pub const RESERVED: &[&str] = &[
"TextureBuffer",
"ConstantBuffer",
"RayQuery",
// Naga utilities
super::writer::FREXP_FUNCTION,
super::writer::FREXP_STRUCT,
super::writer::MODF_FUNCTION,
super::writer::MODF_STRUCT,
];

// DXC scalar types, from https://github.com/microsoft/DirectXShaderCompiler/blob/18c9e114f9c314f93e68fbc72ce207d4ed2e65ae/tools/clang/lib/AST/ASTContextHLSL.cpp#L48-L254
Expand Down
25 changes: 23 additions & 2 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ const SPECIAL_BASE_VERTEX: &str = "base_vertex";
const SPECIAL_BASE_INSTANCE: &str = "base_instance";
const SPECIAL_OTHER: &str = "other";

pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const FREXP_STRUCT: &str = "naga_frexp_result";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const MODF_STRUCT: &str = "naga_modf_result";

struct EpStructMember {
name: String,
ty: Handle<crate::Type>,
Expand Down Expand Up @@ -244,6 +249,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}

self.write_special_functions(module)?;

self.write_wrapped_compose_functions(module, &module.const_expressions)?;

// Write all named constants
Expand Down Expand Up @@ -1058,6 +1065,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
self.write_array_size(module, base, size)?;
}
TypeInner::FrexpResult => {
write!(self.out, "struct {FREXP_STRUCT}")?;
}
TypeInner::ModfResult => {
write!(self.out, "struct {MODF_STRUCT}")?;
}
_ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
}

Expand Down Expand Up @@ -2276,6 +2289,14 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
&writer.names[&NameKey::StructMember(ty, index)]
)?
}
TypeInner::FrexpResult => {
write!(writer.out, ".{}", if index == 0 { "fract" } else { "exp" })?
}
TypeInner::ModfResult => write!(
writer.out,
".{}",
if index == 0 { "fract" } else { "whole" }
)?,
ref other => {
return Err(Error::Custom(format!("Cannot index {other:?}")))
}
Expand Down Expand Up @@ -2665,8 +2686,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Mf::Round => Function::Regular("round"),
Mf::Fract => Function::Regular("frac"),
Mf::Trunc => Function::Regular("trunc"),
Mf::Modf => Function::Regular("modf"),
Mf::Frexp => Function::Regular("frexp"),
Mf::Modf => Function::Regular(MODF_FUNCTION),
Mf::Frexp => Function::Regular(FREXP_FUNCTION),
Mf::Ldexp => Function::Regular("ldexp"),
// exponent
Mf::Exp => Function::Regular("exp"),
Expand Down
4 changes: 4 additions & 0 deletions src/back/msl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,4 +214,8 @@ pub const RESERVED: &[&str] = &[
// Naga utilities
"DefaultConstructible",
"clamped_lod_e",
super::writer::FREXP_FUNCTION,
super::writer::FREXP_STRUCT,
super::writer::MODF_FUNCTION,
super::writer::MODF_STRUCT,
];
90 changes: 87 additions & 3 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection";
const RAY_QUERY_FIELD_READY: &str = "ready";
const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type";

pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const FREXP_STRUCT: &str = "naga_frexp_result";
pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const MODF_STRUCT: &str = "naga_modf_result";

/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
///
/// The `sizes` slice determines whether this function writes a
Expand Down Expand Up @@ -140,7 +145,15 @@ impl<'a> Display for TypeContext<'a> {
// so just print the element type here.
write!(out, "{sub}")
}
crate::TypeInner::Struct { .. } => unreachable!(),
crate::TypeInner::Struct { .. } => {
unreachable!()
}
crate::TypeInner::FrexpResult { .. } => {
write!(out, "{}", FREXP_STRUCT)
}
crate::TypeInner::ModfResult { .. } => {
write!(out, "{}", MODF_STRUCT)
}
crate::TypeInner::Image {
dim,
arrayed,
Expand Down Expand Up @@ -452,6 +465,8 @@ impl crate::Type {
| Ti::Sampler { .. }
| Ti::AccelerationStructure
| Ti::RayQuery
| Ti::FrexpResult
| Ti::ModfResult
| Ti::BindingArray { .. } => false,
}
}
Expand Down Expand Up @@ -1635,6 +1650,24 @@ impl<W: Write> Writer<W> {
write!(self.out, "{NAMESPACE}::{op}")?;
self.put_call_parameters(iter::once(argument), context)?;
}
crate::Expression::Math {
fun: crate::MathFunction::Frexp,
arg,
..
} => {
write!(self.out, "{FREXP_FUNCTION}(")?;
self.put_expression(arg, context, false)?;
write!(self.out, ")")?;
}
crate::Expression::Math {
fun: crate::MathFunction::Modf,
arg,
..
} => {
write!(self.out, "{MODF_FUNCTION}(")?;
self.put_expression(arg, context, false)?;
write!(self.out, ")")?;
}
crate::Expression::Math {
fun,
arg,
Expand All @@ -1644,7 +1677,7 @@ impl<W: Write> Writer<W> {
} => {
use crate::MathFunction as Mf;

let scalar_argument = match *context.resolve_type(arg) {
let scalar_argument: bool = match *context.resolve_type(arg) {
crate::TypeInner::Scalar { .. } => true,
_ => false,
};
Expand Down Expand Up @@ -2018,7 +2051,9 @@ impl<W: Write> Writer<W> {
base_inner = &context.module.types[base].inner;
}
match *base_inner {
crate::TypeInner::Struct { .. } => (base, None),
crate::TypeInner::Struct { .. }
| crate::TypeInner::FrexpResult
| crate::TypeInner::ModfResult { .. } => (base, None),
_ => (base, Some(index::GuardedIndex::Known(index))),
}
}
Expand Down Expand Up @@ -2133,6 +2168,20 @@ impl<W: Write> Writer<W> {
write!(self.out, ".{}", back::COMPONENTS[index as usize])?;
}
}
crate::TypeInner::FrexpResult | crate::TypeInner::ModfResult => {
self.put_access_chain(base, policy, context)?;
write!(
self.out,
".{}",
if index == 0 {
"fract"
} else if *base_ty == crate::TypeInner::FrexpResult {
"exp"
} else {
"whole"
}
)?;
}
_ => {
self.put_subscripted_access_chain(
base,
Expand Down Expand Up @@ -3236,6 +3285,41 @@ impl<W: Write> Writer<W> {
}
}
}

if module.special_types.frexp_result.is_some() {
writeln!(
self.out,
"
struct {FREXP_STRUCT} {{
float fract;
int exp;
}};

struct {FREXP_STRUCT} {FREXP_FUNCTION}(float arg) {{
int exp;
float fract = {NAMESPACE}::frexp(arg, exp);
return {FREXP_STRUCT}{{ fract, exp }};
}};"
)?;
}

if module.special_types.modf_result.is_some() {
writeln!(
self.out,
"
struct {MODF_STRUCT} {{
float fract;
float whole;
}};

struct {MODF_STRUCT} {MODF_FUNCTION}(float arg) {{
float whole;
float fract = {NAMESPACE}::modf(arg, whole);
return {MODF_STRUCT}{{ fract, whole }};
}};"
)?;
}

Ok(())
}

Expand Down
Loading
Loading