Commit fb660c2

Small refactoring of paeth filter logic (#544)

fintelia authored Dec 6, 2024
1 parent 5202c6d commit fb660c2

Showing 1 changed file with 140 additions and 153 deletions.

src/filter.rs: 140 additions & 153 deletions
@@ -38,7 +38,7 @@ mod simd {
     {
         let mut out = [0; N];
         for i in 0..N {
-            out[i] = super::filter_paeth_decode_i16(a[i].into(), b[i].into(), c[i].into());
+            out[i] = super::filter_paeth_stbi_i16(a[i].into(), b[i].into(), c[i].into());
         }
         out.into()
     }
@@ -55,7 +55,7 @@
     {
         let mut out = [0; N];
         for i in 0..N {
-            out[i] = super::filter_paeth_decode(a[i].into(), b[i].into(), c[i].into());
+            out[i] = super::filter_paeth_stbi(a[i].into(), b[i].into(), c[i].into());
         }
         out.into()
     }
@@ -277,9 +277,30 @@ impl Default for AdaptiveFilterType {
     }
 }
 
-#[cfg(target_arch = "x86_64")]
-fn filter_paeth_decode(a: u8, b: u8, c: u8) -> u8 {
-    // Decoding optimizes better with this algorithm than with `filter_paeth()`
+fn filter_paeth(a: u8, b: u8, c: u8) -> u8 {
+    // On ARM this algorithm performs much better than the one adapted from stb below,
+    // and this is the better-studied algorithm we've always used here,
+    // so we default to it on all non-x86 platforms.
+    let pa = (i16::from(b) - i16::from(c)).abs();
+    let pb = (i16::from(a) - i16::from(c)).abs();
+    let pc = ((i16::from(a) - i16::from(c)) + (i16::from(b) - i16::from(c))).abs();
+
+    let mut out = a;
+    let mut min = pa;
+
+    if pb < min {
+        min = pb;
+        out = b;
+    }
+    if pc < min {
+        out = c;
+    }
+
+    out
+}
+
+fn filter_paeth_stbi(a: u8, b: u8, c: u8) -> u8 {
+    // Decoding optimizes better with this algorithm than with `filter_paeth`
     //
     // This formulation looks very different from the reference in the PNG spec, but is
     // actually equivalent and has favorable data dependencies and admits straightforward
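
For context, the "reference in the PNG spec" mentioned above is the textbook Paeth predictor: pick whichever of a (left), b (above), c (upper-left) is closest to p = a + b - c, preferring a, then b, then c on ties. A minimal sketch of that reference formulation (not part of this commit) which `filter_paeth` reproduces:

```rust
// Textbook Paeth predictor from the PNG specification, widened to i16 so the
// intermediate p = a + b - c cannot overflow. `filter_paeth` above computes
// the same distances directly, because p - a = b - c, p - b = a - c, and
// p - c = (a - c) + (b - c).
fn paeth_reference(a: u8, b: u8, c: u8) -> u8 {
    let p = i16::from(a) + i16::from(b) - i16::from(c);
    let pa = (p - i16::from(a)).abs();
    let pb = (p - i16::from(b)).abs();
    let pc = (p - i16::from(c)).abs();
    if pa <= pb && pa <= pc {
        a
    } else if pb <= pc {
        b
    } else {
        c
    }
}
```
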
@@ -295,9 +316,9 @@ fn filter_paeth_decode(a: u8, b: u8, c: u8) -> u8 {
     return t1;
 }
 
-#[cfg(all(feature = "unstable", target_arch = "x86_64"))]
-fn filter_paeth_decode_i16(a: i16, b: i16, c: i16) -> i16 {
-    // Like `filter_paeth_decode` but vectorizes better when wrapped in SIMD types.
+#[cfg(any(test, all(feature = "unstable", target_arch = "x86_64")))]
+fn filter_paeth_stbi_i16(a: i16, b: i16, c: i16) -> i16 {
+    // Like `filter_paeth_stbi` but vectorizes better when wrapped in SIMD types.
     // Used for bpp=3 and bpp=6
     let thresh = c * 3 - (a + b);
     let lo = a.min(b);
@@ -307,30 +328,7 @@ fn filter_paeth_decode_i16(a: i16, b: i16, c: i16) -> i16 {
     let t1 = if thresh <= lo { hi } else { t0 };
     return t1;
 }
 
-#[cfg(not(target_arch = "x86_64"))]
-fn filter_paeth_decode(a: u8, b: u8, c: u8) -> u8 {
-    // On ARM this algorithm performs much better than the one above adapted from stb,
-    // and this is the better-studied algorithm we've always used here,
-    // so we default to it on all non-x86 platforms.
-    let pa = (i16::from(b) - i16::from(c)).abs();
-    let pb = (i16::from(a) - i16::from(c)).abs();
-    let pc = ((i16::from(a) - i16::from(c)) + (i16::from(b) - i16::from(c))).abs();
-
-    let mut out = a;
-    let mut min = pa;
-
-    if pb < min {
-        min = pb;
-        out = b;
-    }
-    if pc < min {
-        out = c;
-    }
-
-    out
-}
-
-fn filter_paeth(a: u8, b: u8, c: u8) -> u8 {
+fn filter_paeth_fpnge(a: u8, b: u8, c: u8) -> u8 {
     // This is an optimized version of the paeth filter from the PNG specification, proposed by
     // Luca Versari for [FPNGE](https://www.lucaversari.it/FJXL_and_FPNGE.pdf). It operates
     // entirely on unsigned 8-bit quantities, making it more conducive to vectorization.
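
GitHub collapses the middle of `filter_paeth_stbi` in the hunks above. Based on the fragments that are visible (`thresh`, `lo`, `hi`, and the `t0`/`t1` selects), here is a sketch of how the stb-style selection plausibly fits together; treat the exact body as an assumption, since the diff does not show it in full:

```rust
// Sketch assembled from the visible fragments: an hi/lo ordering of a and b
// plus two threshold tests replaces the three absolute differences of the
// reference predictor. thresh = 3c - (a + b) stands in for comparisons
// against p = a + b - c without ever computing p itself.
fn filter_paeth_stbi_sketch(a: u8, b: u8, c: u8) -> u8 {
    let thresh = i16::from(c) * 3 - (i16::from(a) + i16::from(b));
    let lo = a.min(b);
    let hi = a.max(b);
    let t0 = if i16::from(hi) <= thresh { lo } else { c };
    let t1 = if thresh <= i16::from(lo) { hi } else { t0 };
    t1
}
```

The two branchless selects compile to min/max/compare-select instructions, which is what makes this form vectorize so well; the exhaustive test at the bottom of this diff is what justifies trusting the equivalence.
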
@@ -706,7 +704,15 @@ pub(crate) fn unfilter(
                 }
             }
         },
+        #[allow(unreachable_code)]
         Paeth => {
+            // Select the fastest Paeth filter implementation based on the target architecture.
+            let filter_paeth_decode = if cfg!(target_arch = "x86_64") {
+                filter_paeth_stbi
+            } else {
+                filter_paeth
+            };
+
             // Paeth filter pixels:
             // C B D
             // A X
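
Note that `cfg!` differs from the `#[cfg]` attributes it replaces here: it expands to a compile-time `true`/`false` constant rather than conditionally compiling code, so both predictor functions must compile on every target, but the optimizer folds the branch away statically. (The `#[allow(unreachable_code)]` silences the warning caused by the cfg-gated `return`s further down when the SIMD paths are compiled in.) A standalone sketch of the pattern, with hypothetical function names:

```rust
fn fast_path(x: u8) -> u8 {
    x.wrapping_mul(2)
}

fn portable_path(x: u8) -> u8 {
    x.wrapping_shl(1)
}

fn main() {
    // `cfg!` evaluates to a boolean literal at compile time, so this branch
    // is resolved during optimization and the unselected function is dead code.
    let f: fn(u8) -> u8 = if cfg!(target_arch = "x86_64") {
        fast_path
    } else {
        portable_path
    };
    assert_eq!(f(3), 6);
}
```
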
@@ -742,141 +748,116 @@
                 BytesPerPixel::Three => {
                     // Do not enable this algorithm on ARM, that would be a big performance hit
                     #[cfg(all(feature = "unstable", target_arch = "x86_64"))]
-                    simd::unfilter_paeth3(previous, current);
+                    {
+                        simd::unfilter_paeth3(previous, current);
+                        return;
+                    }
 
-                    #[cfg(not(feature = "unstable"))]
+                    let mut a_bpp = [0; 3];
+                    let mut c_bpp = [0; 3];
+                    for (chunk, b_bpp) in current.chunks_exact_mut(3).zip(previous.chunks_exact(3))
                     {
-                        let mut a_bpp = [0; 3];
-                        let mut c_bpp = [0; 3];
-                        for (chunk, b_bpp) in
-                            current.chunks_exact_mut(3).zip(previous.chunks_exact(3))
-                        {
-                            let new_chunk = [
-                                chunk[0].wrapping_add(filter_paeth_decode(
-                                    a_bpp[0], b_bpp[0], c_bpp[0],
-                                )),
-                                chunk[1].wrapping_add(filter_paeth_decode(
-                                    a_bpp[1], b_bpp[1], c_bpp[1],
-                                )),
-                                chunk[2].wrapping_add(filter_paeth_decode(
-                                    a_bpp[2], b_bpp[2], c_bpp[2],
-                                )),
-                            ];
-                            *TryInto::<&mut [u8; 3]>::try_into(chunk).unwrap() = new_chunk;
-                            a_bpp = new_chunk;
-                            c_bpp = b_bpp.try_into().unwrap();
-                        }
+                        let new_chunk = [
+                            chunk[0]
+                                .wrapping_add(filter_paeth_decode(a_bpp[0], b_bpp[0], c_bpp[0])),
+                            chunk[1]
+                                .wrapping_add(filter_paeth_decode(a_bpp[1], b_bpp[1], c_bpp[1])),
+                            chunk[2]
+                                .wrapping_add(filter_paeth_decode(a_bpp[2], b_bpp[2], c_bpp[2])),
+                        ];
+                        *TryInto::<&mut [u8; 3]>::try_into(chunk).unwrap() = new_chunk;
+                        a_bpp = new_chunk;
+                        c_bpp = b_bpp.try_into().unwrap();
                     }
                 }
                 BytesPerPixel::Four => {
                     #[cfg(all(feature = "unstable", target_arch = "x86_64"))]
-                    simd::unfilter_paeth_u8::<4>(previous, current);
+                    {
+                        simd::unfilter_paeth_u8::<4>(previous, current);
+                        return;
+                    }
 
-                    #[cfg(not(feature = "unstable"))]
+                    let mut a_bpp = [0; 4];
+                    let mut c_bpp = [0; 4];
+                    for (chunk, b_bpp) in current.chunks_exact_mut(4).zip(previous.chunks_exact(4))
                     {
-                        let mut a_bpp = [0; 4];
-                        let mut c_bpp = [0; 4];
-                        for (chunk, b_bpp) in
-                            current.chunks_exact_mut(4).zip(previous.chunks_exact(4))
-                        {
-                            let new_chunk = [
-                                chunk[0].wrapping_add(filter_paeth_decode(
-                                    a_bpp[0], b_bpp[0], c_bpp[0],
-                                )),
-                                chunk[1].wrapping_add(filter_paeth_decode(
-                                    a_bpp[1], b_bpp[1], c_bpp[1],
-                                )),
-                                chunk[2].wrapping_add(filter_paeth_decode(
-                                    a_bpp[2], b_bpp[2], c_bpp[2],
-                                )),
-                                chunk[3].wrapping_add(filter_paeth_decode(
-                                    a_bpp[3], b_bpp[3], c_bpp[3],
-                                )),
-                            ];
-                            *TryInto::<&mut [u8; 4]>::try_into(chunk).unwrap() = new_chunk;
-                            a_bpp = new_chunk;
-                            c_bpp = b_bpp.try_into().unwrap();
-                        }
+                        let new_chunk = [
+                            chunk[0]
+                                .wrapping_add(filter_paeth_decode(a_bpp[0], b_bpp[0], c_bpp[0])),
+                            chunk[1]
+                                .wrapping_add(filter_paeth_decode(a_bpp[1], b_bpp[1], c_bpp[1])),
+                            chunk[2]
+                                .wrapping_add(filter_paeth_decode(a_bpp[2], b_bpp[2], c_bpp[2])),
+                            chunk[3]
+                                .wrapping_add(filter_paeth_decode(a_bpp[3], b_bpp[3], c_bpp[3])),
+                        ];
+                        *TryInto::<&mut [u8; 4]>::try_into(chunk).unwrap() = new_chunk;
+                        a_bpp = new_chunk;
+                        c_bpp = b_bpp.try_into().unwrap();
                     }
                 }
                 BytesPerPixel::Six => {
                     #[cfg(all(feature = "unstable", target_arch = "x86_64"))]
-                    simd::unfilter_paeth6(previous, current);
+                    {
+                        simd::unfilter_paeth6(previous, current);
+                        return;
+                    }
 
-                    #[cfg(not(feature = "unstable"))]
+                    let mut a_bpp = [0; 6];
+                    let mut c_bpp = [0; 6];
+                    for (chunk, b_bpp) in current.chunks_exact_mut(6).zip(previous.chunks_exact(6))
                     {
-                        let mut a_bpp = [0; 6];
-                        let mut c_bpp = [0; 6];
-                        for (chunk, b_bpp) in
-                            current.chunks_exact_mut(6).zip(previous.chunks_exact(6))
-                        {
-                            let new_chunk = [
-                                chunk[0].wrapping_add(filter_paeth_decode(
-                                    a_bpp[0], b_bpp[0], c_bpp[0],
-                                )),
-                                chunk[1].wrapping_add(filter_paeth_decode(
-                                    a_bpp[1], b_bpp[1], c_bpp[1],
-                                )),
-                                chunk[2].wrapping_add(filter_paeth_decode(
-                                    a_bpp[2], b_bpp[2], c_bpp[2],
-                                )),
-                                chunk[3].wrapping_add(filter_paeth_decode(
-                                    a_bpp[3], b_bpp[3], c_bpp[3],
-                                )),
-                                chunk[4].wrapping_add(filter_paeth_decode(
-                                    a_bpp[4], b_bpp[4], c_bpp[4],
-                                )),
-                                chunk[5].wrapping_add(filter_paeth_decode(
-                                    a_bpp[5], b_bpp[5], c_bpp[5],
-                                )),
-                            ];
-                            *TryInto::<&mut [u8; 6]>::try_into(chunk).unwrap() = new_chunk;
-                            a_bpp = new_chunk;
-                            c_bpp = b_bpp.try_into().unwrap();
-                        }
+                        let new_chunk = [
+                            chunk[0]
+                                .wrapping_add(filter_paeth_decode(a_bpp[0], b_bpp[0], c_bpp[0])),
+                            chunk[1]
+                                .wrapping_add(filter_paeth_decode(a_bpp[1], b_bpp[1], c_bpp[1])),
+                            chunk[2]
+                                .wrapping_add(filter_paeth_decode(a_bpp[2], b_bpp[2], c_bpp[2])),
+                            chunk[3]
+                                .wrapping_add(filter_paeth_decode(a_bpp[3], b_bpp[3], c_bpp[3])),
+                            chunk[4]
+                                .wrapping_add(filter_paeth_decode(a_bpp[4], b_bpp[4], c_bpp[4])),
+                            chunk[5]
+                                .wrapping_add(filter_paeth_decode(a_bpp[5], b_bpp[5], c_bpp[5])),
+                        ];
+                        *TryInto::<&mut [u8; 6]>::try_into(chunk).unwrap() = new_chunk;
+                        a_bpp = new_chunk;
+                        c_bpp = b_bpp.try_into().unwrap();
                    }
                 }
                 BytesPerPixel::Eight => {
                     #[cfg(all(feature = "unstable", target_arch = "x86_64"))]
-                    simd::unfilter_paeth_u8::<8>(previous, current);
+                    {
+                        simd::unfilter_paeth_u8::<8>(previous, current);
+                        return;
+                    }
 
-                    #[cfg(not(feature = "unstable"))]
+                    let mut a_bpp = [0; 8];
+                    let mut c_bpp = [0; 8];
+                    for (chunk, b_bpp) in current.chunks_exact_mut(8).zip(previous.chunks_exact(8))
                     {
-                        let mut a_bpp = [0; 8];
-                        let mut c_bpp = [0; 8];
-                        for (chunk, b_bpp) in
-                            current.chunks_exact_mut(8).zip(previous.chunks_exact(8))
-                        {
-                            let new_chunk = [
-                                chunk[0].wrapping_add(filter_paeth_decode(
-                                    a_bpp[0], b_bpp[0], c_bpp[0],
-                                )),
-                                chunk[1].wrapping_add(filter_paeth_decode(
-                                    a_bpp[1], b_bpp[1], c_bpp[1],
-                                )),
-                                chunk[2].wrapping_add(filter_paeth_decode(
-                                    a_bpp[2], b_bpp[2], c_bpp[2],
-                                )),
-                                chunk[3].wrapping_add(filter_paeth_decode(
-                                    a_bpp[3], b_bpp[3], c_bpp[3],
-                                )),
-                                chunk[4].wrapping_add(filter_paeth_decode(
-                                    a_bpp[4], b_bpp[4], c_bpp[4],
-                                )),
-                                chunk[5].wrapping_add(filter_paeth_decode(
-                                    a_bpp[5], b_bpp[5], c_bpp[5],
-                                )),
-                                chunk[6].wrapping_add(filter_paeth_decode(
-                                    a_bpp[6], b_bpp[6], c_bpp[6],
-                                )),
-                                chunk[7].wrapping_add(filter_paeth_decode(
-                                    a_bpp[7], b_bpp[7], c_bpp[7],
-                                )),
-                            ];
-                            *TryInto::<&mut [u8; 8]>::try_into(chunk).unwrap() = new_chunk;
-                            a_bpp = new_chunk;
-                            c_bpp = b_bpp.try_into().unwrap();
-                        }
+                        let new_chunk = [
+                            chunk[0]
+                                .wrapping_add(filter_paeth_decode(a_bpp[0], b_bpp[0], c_bpp[0])),
+                            chunk[1]
+                                .wrapping_add(filter_paeth_decode(a_bpp[1], b_bpp[1], c_bpp[1])),
+                            chunk[2]
+                                .wrapping_add(filter_paeth_decode(a_bpp[2], b_bpp[2], c_bpp[2])),
+                            chunk[3]
+                                .wrapping_add(filter_paeth_decode(a_bpp[3], b_bpp[3], c_bpp[3])),
+                            chunk[4]
+                                .wrapping_add(filter_paeth_decode(a_bpp[4], b_bpp[4], c_bpp[4])),
+                            chunk[5]
+                                .wrapping_add(filter_paeth_decode(a_bpp[5], b_bpp[5], c_bpp[5])),
+                            chunk[6]
+                                .wrapping_add(filter_paeth_decode(a_bpp[6], b_bpp[6], c_bpp[6])),
+                            chunk[7]
+                                .wrapping_add(filter_paeth_decode(a_bpp[7], b_bpp[7], c_bpp[7])),
+                        ];
+                        *TryInto::<&mut [u8; 8]>::try_into(chunk).unwrap() = new_chunk;
+                        a_bpp = new_chunk;
+                        c_bpp = b_bpp.try_into().unwrap();
                     }
                 }
             }
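
The four arms above repeat one pattern at different widths. A hedged sketch of that pattern written once with a const generic (not how the commit structures it; the real code is spelled out per bpp, presumably so the optimizer fully unrolls each inner loop):

```rust
// Data flow of the scalar Paeth unfilter: for each pixel X, `a` is the pixel
// to its left (the one just decoded), `b` is the pixel above (from the
// previous row), and `c` is the pixel above-left (the previous `b`):
//   C B D
//   A X
// `filter_paeth_decode` stands for whichever scalar predictor was selected.
fn unfilter_paeth_scalar<const N: usize>(
    previous: &[u8],
    current: &mut [u8],
    filter_paeth_decode: fn(u8, u8, u8) -> u8,
) {
    let mut a = [0u8; N];
    let mut c = [0u8; N];
    for (chunk, b) in current.chunks_exact_mut(N).zip(previous.chunks_exact(N)) {
        for i in 0..N {
            chunk[i] = chunk[i].wrapping_add(filter_paeth_decode(a[i], b[i], c[i]));
        }
        a.copy_from_slice(chunk);
        c.copy_from_slice(b);
    }
}
```

Usage would look like `unfilter_paeth_scalar::<4>(previous, current, filter_paeth)` for an RGBA row.
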
@@ -1000,7 +981,7 @@ fn filter_internal(
             .zip(&mut c_chunks)
         {
             for i in 0..CHUNK_SIZE {
-                out[i] = cur[i].wrapping_sub(filter_paeth(a[i], b[i], c[i]));
+                out[i] = cur[i].wrapping_sub(filter_paeth_fpnge(a[i], b[i], c[i]));
             }
         }
 
@@ -1012,11 +993,11 @@
                 .zip(b_chunks.remainder())
                 .zip(c_chunks.remainder())
             {
-                *out = cur.wrapping_sub(filter_paeth(a, b, c));
+                *out = cur.wrapping_sub(filter_paeth_fpnge(a, b, c));
             }
 
             for i in 0..bpp {
-                output[i] = current[i].wrapping_sub(filter_paeth(0, previous[i], 0));
+                output[i] = current[i].wrapping_sub(filter_paeth_fpnge(0, previous[i], 0));
             }
             Paeth
         }
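
The encoder side is the mirror image of `unfilter`: filtering subtracts a predictor with `wrapping_sub`, unfiltering adds one back with `wrapping_add`, and because all the predictors agree on every input, the encoder's fpnge variant and the decoder's chosen variant can be mixed freely. A small round-trip sketch, assuming the module's predictors are in scope:

```rust
// Round-trip property the codec relies on: subtract a predictor when
// filtering, add an equivalent predictor back when unfiltering.
#[test]
fn paeth_round_trip() {
    let (a, b, c) = (10u8, 200, 250);
    for x in 0..=255u8 {
        let filtered = x.wrapping_sub(filter_paeth_fpnge(a, b, c));
        let restored = filtered.wrapping_add(filter_paeth(a, b, c));
        assert_eq!(restored, x);
    }
}
```
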
@@ -1085,7 +1066,7 @@ fn sum_buffer(buf: &[u8]) -> u64 {
 
 #[cfg(test)]
 mod test {
-    use super::{filter, unfilter, AdaptiveFilterType, BytesPerPixel, FilterType};
+    use super::*;
     use core::iter;
 
     #[test]
Expand Down Expand Up @@ -1135,11 +1116,17 @@ mod test {
#[test]
#[ignore] // takes ~20s without optimizations
fn paeth_impls_are_equivalent() {
use super::{filter_paeth, filter_paeth_decode};
for a in 0..=255 {
for b in 0..=255 {
for c in 0..=255 {
assert_eq!(filter_paeth(a, b, c), filter_paeth_decode(a, b, c));
let baseline = filter_paeth(a, b, c);
let fpnge = filter_paeth_fpnge(a, b, c);
let stbi = filter_paeth_stbi(a, b, c);
let stbi_i16 = filter_paeth_stbi_i16(a as i16, b as i16, c as i16);

assert_eq!(baseline, fpnge);
assert_eq!(baseline, stbi);
assert_eq!(baseline as i16, stbi_i16);
}
}
}
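
With `filter_paeth_stbi_i16` now gated on `any(test, ...)` (see the hunk at old line 295), the exhaustive check covers all four implementations over the full 256 x 256 x 256 input space, about 16.7 million combinations. Since the test is `#[ignore]`d for its runtime, it can be run explicitly with `cargo test --release -- --ignored paeth_impls_are_equivalent`.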