Skip to content

Commit

Permalink
[msl-out] Correctly construct out-of-bounds values for all types.
Browse files Browse the repository at this point in the history
  • Loading branch information
jimblandy authored and kvark committed Dec 15, 2021
1 parent 8a2fbd3 commit 3867ef4
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 12 deletions.
2 changes: 2 additions & 0 deletions src/back/msl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,4 +210,6 @@ pub const RESERVED: &[&str] = &[
"M_2_SQRTPI",
"M_SQRT2",
"M_SQRT1_2",
// Naga utilities
"DefaultConstructible",
];
33 changes: 31 additions & 2 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ impl<W: Write> Writer<W> {
{
write!(self.out, " ? ")?;
self.put_access_chain(expr_handle, policy, context)?;
write!(self.out, " : 0")?;
write!(self.out, " : DefaultConstructible()")?;

if !is_scoped {
write!(self.out, ")")?;
Expand Down Expand Up @@ -1529,7 +1529,7 @@ impl<W: Write> Writer<W> {
{
write!(self.out, " ? ")?;
self.put_unchecked_load(pointer, policy, context)?;
write!(self.out, " : 0")?;
write!(self.out, " : DefaultConstructible()")?;

if !is_scoped {
write!(self.out, ")")?;
Expand Down Expand Up @@ -2154,6 +2154,13 @@ impl<W: Write> Writer<W> {
writeln!(self.out, "#include <simd/simd.h>")?;
writeln!(self.out)?;

if options
.bounds_check_policies
.contains(index::BoundsCheckPolicy::ReadZeroSkipWrite)
{
self.put_default_constructible()?;
}

{
let mut indices = vec![];
for (handle, var) in module.global_variables.iter() {
Expand Down Expand Up @@ -2181,6 +2188,28 @@ impl<W: Write> Writer<W> {
self.write_functions(module, info, options, pipeline_options)
}

/// Write the definition for the `DefaultConstructible` class.
///
/// The [`ReadZeroSkipWrite`] bounds check policy requires us to be able to
/// produce 'zero' values for any type, including structs, arrays, and so
/// on. We could do this by emitting default constructor applications, but
/// that would entail printing the name of the type, which is more trouble
/// than you'd think. Instead, we just construct this magic C++14 class that
/// can be converted to any type that can be default constructed, using
/// template parameter inference to detect which type is needed, so we don't
/// have to figure out the name.
///
/// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite
fn put_default_constructible(&mut self) -> BackendResult {
writeln!(self.out, "struct DefaultConstructible {{")?;
writeln!(self.out, " template<typename T>")?;
writeln!(self.out, " operator T() && {{")?;
writeln!(self.out, " return T {{}};")?;
writeln!(self.out, " }}")?;
writeln!(self.out, "}};")?;
Ok(())
}

fn write_type_defs(&mut self, module: &crate::Module) -> BackendResult {
for (handle, ty) in module.types.iter() {
if !ty.needs_alias() {
Expand Down
9 changes: 6 additions & 3 deletions src/proc/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ impl BoundsCheckPolicies {
_ => self.index,
}
}

/// Return `true` if any of `self`'s policies are `policy`.
pub fn contains(&self, policy: BoundsCheckPolicy) -> bool {
self.index == policy || self.buffer == policy || self.image == policy
}
}

/// An index that may be statically known, or may need to be computed at runtime.
Expand Down Expand Up @@ -204,9 +209,7 @@ pub fn find_checked_indexes(
let mut guarded_indices = BitSet::new();

// Don't bother scanning if we never need `ReadZeroSkipWrite`.
if policies.index == BoundsCheckPolicy::ReadZeroSkipWrite
|| policies.buffer == BoundsCheckPolicy::ReadZeroSkipWrite
{
if policies.contains(BoundsCheckPolicy::ReadZeroSkipWrite) {
for (_handle, expr) in function.expressions.iter() {
// There's no need to handle `AccessIndex` expressions, as their
// indices never need to be cached.
Expand Down
20 changes: 13 additions & 7 deletions tests/out/msl/bounds-check-zero.msl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
#include <metal_stdlib>
#include <simd/simd.h>

struct DefaultConstructible {
template<typename T>
operator T() && {
return T {};
}
};
struct _mslBufferSizes {
metal::uint size0;
};
Expand All @@ -23,7 +29,7 @@ float index_array(
device Globals& globals,
constant _mslBufferSizes& _buffer_sizes
) {
float _e4 = metal::uint(i) < 10 ? globals.a.inner[i] : 0;
float _e4 = metal::uint(i) < 10 ? globals.a.inner[i] : DefaultConstructible();
return _e4;
}

Expand All @@ -32,7 +38,7 @@ float index_dynamic_array(
device Globals& globals,
constant _mslBufferSizes& _buffer_sizes
) {
float _e4 = metal::uint(i_1) < 1 + (_buffer_sizes.size0 - 112 - 4) / 4 ? globals.d[i_1] : 0;
float _e4 = metal::uint(i_1) < 1 + (_buffer_sizes.size0 - 112 - 4) / 4 ? globals.d[i_1] : DefaultConstructible();
return _e4;
}

Expand All @@ -41,23 +47,23 @@ float index_vector(
device Globals& globals,
constant _mslBufferSizes& _buffer_sizes
) {
float _e4 = metal::uint(i_2) < 4 ? globals.v[i_2] : 0;
float _e4 = metal::uint(i_2) < 4 ? globals.v[i_2] : DefaultConstructible();
return _e4;
}

float index_vector_by_value(
metal::float4 v,
int i_3
) {
return metal::uint(i_3) < 4 ? v[i_3] : 0;
return metal::uint(i_3) < 4 ? v[i_3] : DefaultConstructible();
}

metal::float4 index_matrix(
int i_4,
device Globals& globals,
constant _mslBufferSizes& _buffer_sizes
) {
metal::float4 _e4 = metal::uint(i_4) < 3 ? globals.m[i_4] : 0;
metal::float4 _e4 = metal::uint(i_4) < 3 ? globals.m[i_4] : DefaultConstructible();
return _e4;
}

Expand All @@ -67,7 +73,7 @@ float index_twice(
device Globals& globals,
constant _mslBufferSizes& _buffer_sizes
) {
float _e6 = metal::uint(j) < 4 && metal::uint(i_5) < 3 ? globals.m[i_5][j] : 0;
float _e6 = metal::uint(j) < 4 && metal::uint(i_5) < 3 ? globals.m[i_5][j] : DefaultConstructible();
return _e6;
}

Expand All @@ -77,7 +83,7 @@ float index_expensive(
constant _mslBufferSizes& _buffer_sizes
) {
int _e9 = static_cast<int>(metal::sin(static_cast<float>(i_6) / 100.0) * 100.0);
float _e11 = metal::uint(_e9) < 10 ? globals.a.inner[_e9] : 0;
float _e11 = metal::uint(_e9) < 10 ? globals.a.inner[_e9] : DefaultConstructible();
return _e11;
}

Expand Down
6 changes: 6 additions & 0 deletions tests/out/msl/policy-mix.msl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
#include <metal_stdlib>
#include <simd/simd.h>

struct DefaultConstructible {
template<typename T>
operator T() && {
return T {};
}
};
struct type_1 {
metal::float4 inner[10];
};
Expand Down

0 comments on commit 3867ef4

Please sign in to comment.