Keep a single version of gxhash (the one that passes smhasher) and declare flag to use 256-bit version on-demand
ogxd committed Oct 28, 2023
1 parent 873ad41 commit 29d191a
Showing 7 changed files with 71 additions and 134 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -8,7 +8,7 @@ edition = "2021"
 # The 256-bit state GxHash is faster for large inputs than the default 128-bit state implementation.
 # Please note however that the 256-bit GxHash and the 128-bit GxHash don't generate the same hashes for the same input.
 # Requires AVX2 and VAES (X86).
-256-bit = []
+avx2 = []
 
 [dependencies]
 rand = "0.8"
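Assuming the renamed flag behaves like any other Cargo feature (this diff only shows the declaration), a downstream crate would opt into the 256-bit path by adding features = ["avx2"] to its gxhash dependency, and the repository's own benchmarks would be run with `cargo bench --features avx2`. Call sites stay the same either way; only the produced hashes differ, as the comment above notes. A minimal sketch:

    // Sketch, not part of the commit: the same call compiles with and without the `avx2`
    // feature; per the Cargo.toml comment, the two builds return different hashes for the same input.
    fn hash_with_active_width(data: &[u8]) -> u64 {
        gxhash64(data, 0) // assumes gxhash64 (introduced later in this commit) is in scope
    }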
16 changes: 6 additions & 10 deletions benches/throughput.rs
@@ -32,9 +32,9 @@ fn benchmark_all(c: &mut Criterion) {
     let mut rng = rand::thread_rng();
 
     // Allocate 32-bytes-aligned
-    let layout = Layout::from_size_align(100000, 32).unwrap();
+    let layout = Layout::from_size_align(50_000, 32).unwrap();
     let ptr = unsafe { alloc(layout) };
-    let slice: &mut [u8] = unsafe { slice::from_raw_parts_mut(ptr, 100000) };
+    let slice: &mut [u8] = unsafe { slice::from_raw_parts_mut(ptr, layout.size()) };
 
     // Fill with random bytes
     rng.fill(slice);
@@ -43,14 +43,10 @@ fn benchmark_all(c: &mut Criterion) {
     let plot_config = PlotConfiguration::default().summary_scale(AxisScale::Logarithmic);
     group.plot_config(plot_config);
 
-    // GxHash0
-    // benchmark(&mut group, slice, "gxhash0", |data: &[u8], _: i32| -> u64 {
-    //     gxhash0_64(data, 0)
-    // });
-
-    // GxHash1
-    benchmark(&mut group, slice, "gxhash", |data: &[u8], _: i32| -> u64 {
-        gxhash1_64(data, 0)
+    // GxHash
+    let algo_name = if cfg!(feature = "avx2") { "gxhash-avx2" } else { "gxhash" };
+    benchmark(&mut group, slice, algo_name, |data: &[u8], _: i32| -> u64 {
+        gxhash64(data, 0)
     });
 
     // AHash
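The benchmark buffer is now sized through the Layout itself, so the slice length can no longer drift from the allocation. Roughly, the aligned-buffer pattern used above looks like this as a standalone snippet (cleanup added here purely for illustration):

    // Standalone sketch of the 32-byte-aligned buffer setup used by the benchmark.
    use std::alloc::{alloc, dealloc, Layout};
    use std::slice;

    fn aligned_buffer_demo() {
        let layout = Layout::from_size_align(50_000, 32).unwrap();
        unsafe {
            let ptr = alloc(layout);
            assert!(!ptr.is_null(), "allocation failed");
            let buf: &mut [u8] = slice::from_raw_parts_mut(ptr, layout.size());
            buf.fill(0xAB); // stand-in for rng.fill(slice)
            // ... hand `buf` to the hash functions under test ...
            dealloc(ptr, layout);
        }
    }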
108 changes: 53 additions & 55 deletions src/gxhash/mod.rs
@@ -1,69 +1,67 @@
 mod platform;
 use platform::*;
 
-#[inline(always)] // To be disabled when profiling
-pub fn gxhash0_32(input: &[u8], seed: i32) -> u32 {
+#[cfg(not(feature = "avx2"))]
+#[inline(always)]
+pub fn gxhash32(input: &[u8], seed: i32) -> u32 {
     unsafe {
-        let p = &gxhash::<0>(input, seed) as *const state as *const u32;
+        let p = &gxhash(input, seed) as *const state as *const u32;
         *p
     }
 }
 
-#[inline(always)] // To be disabled when profiling
-pub fn gxhash0_64(input: &[u8], seed: i32) -> u64 {
+// Since the 256-bit runs AES operations on two 128-bit lanes, we need to extract
+// the hash from the center, picking the same entropy amount from the two lanes
+#[cfg(feature = "avx2")]
+#[inline(always)]
+pub fn gxhash32(input: &[u8], seed: i32) -> u32 {
     unsafe {
-        let p = &gxhash::<0>(input, seed) as *const state as *const u64;
-        *p
+        let p = &gxhash(input, seed) as *const state as *const u8;
+        let offset = std::mem::size_of::<state>() / 2 - std::mem::size_of::<u32>() / 2 - 1;
+        let shifted_ptr = p.offset(offset as isize) as *const u32;
+        *shifted_ptr
     }
 }
 
-#[inline(always)] // To be disabled when profiling
-pub fn gxhash1_32(input: &[u8], seed: i32) -> u32 {
+#[cfg(not(feature = "avx2"))]
+#[inline(always)]
+pub fn gxhash64(input: &[u8], seed: i32) -> u64 {
     unsafe {
-        let p = &gxhash::<1>(input, seed) as *const state as *const u32;
+        let p = &gxhash(input, seed) as *const state as *const u64;
         *p
     }
 }
 
-#[inline(always)] // To be disabled when profiling
-pub fn gxhash1_64(input: &[u8], seed: i32) -> u64 {
+// Since the 256-bit runs AES operations on two 128-bit lanes, we need to extract
+// the hash from the center, picking the same entropy amount from the two lanes
+#[cfg(feature = "avx2")]
+#[inline(always)]
+pub fn gxhash64(input: &[u8], seed: i32) -> u64 {
     unsafe {
-        let p = &gxhash::<1>(input, seed) as *const state as *const u64;
-        *p
-
-        // Alternative idea is to extract the center, to avoid xoring for 256 bit state
-        // let p = &gxhash::<1>(input, seed) as *const state as *const u8;
-        // let shifted_ptr = p.offset(3) as *const u64;
-        // *shifted_ptr
+        let p = &gxhash(input, seed) as *const state as *const u8;
+        let offset = std::mem::size_of::<state>() / 2 - std::mem::size_of::<u64>() / 2 - 1;
+        let shifted_ptr = p.offset(offset as isize) as *const u64;
+        *shifted_ptr
     }
 }
 
 const VECTOR_SIZE: isize = std::mem::size_of::<state>() as isize;
 
-#[inline(always)]
-unsafe fn compress<const N: usize>(a: state, b: state) -> state {
-    match N {
-        0 => compress_0(a, b),
-        1 => compress_1(a, b),
-        _ => compress_1(a, b)
-    }
-}
-
 #[inline(always)]
-fn gxhash<const N: usize>(input: &[u8], seed: i32) -> state {
+fn gxhash(input: &[u8], seed: i32) -> state {
     unsafe {
         let len: isize = input.len() as isize;
         let ptr = input.as_ptr() as *const state;
 
         // Lower sizes first, as comparison/branching overhead will become negligible as input size grows.
         let hash_vector = if len <= VECTOR_SIZE {
-            gxhash_process_last::<N>(ptr, create_empty(), len)
+            gxhash_process_last(ptr, create_empty(), len)
         } else if len <= VECTOR_SIZE * 2 {
-            gxhash_process_last::<N>(ptr.offset(1), compress::<N>(*ptr, create_empty()), len - VECTOR_SIZE)
+            gxhash_process_last(ptr.offset(1), compress(*ptr, create_empty()), len - VECTOR_SIZE)
         } else if len < VECTOR_SIZE * 8 {
-            gxhash_process_1::<N>(ptr, create_empty(), len)
+            gxhash_process_1(ptr, create_empty(), len)
         } else {
-            gxhash_process_8::<N>(ptr, create_empty(), len)
+            gxhash_process_8(ptr, create_empty(), len)
         };
 
         finalize(hash_vector, seed)
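For the lane-centered extraction added above, the arithmetic can be checked with a small sketch. Assuming size_of::<state>() is 32 bytes under the avx2 feature (an assumption; the type lives in the platform modules), the offset is 16 - 4 - 1 = 11 for gxhash64 and 16 - 2 - 1 = 13 for gxhash32, so the word that is read straddles the boundary between the two 128-bit lanes at byte 16:

    // Illustration only; 32 is an assumed size_of::<state>() for the AVX2 build.
    fn extraction_windows() {
        let state_size = 32usize;
        let off64 = state_size / 2 - std::mem::size_of::<u64>() / 2 - 1; // 16 - 4 - 1 = 11
        let off32 = state_size / 2 - std::mem::size_of::<u32>() / 2 - 1; // 16 - 2 - 1 = 13
        println!("gxhash64 reads bytes {}..{} of the state", off64, off64 + 8); // 11..19
        println!("gxhash32 reads bytes {}..{} of the state", off32, off32 + 4); // 13..17
    }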
@@ -81,7 +79,7 @@ macro_rules! load_unaligned {
 }
 
 #[inline(always)]
-unsafe fn gxhash_process_8<const N: usize>(mut ptr: *const state, hash_vector: state, remaining_bytes: isize) -> state {
+unsafe fn gxhash_process_8(mut ptr: *const state, hash_vector: state, remaining_bytes: isize) -> state {
 
     const UNROLL_FACTOR: isize = 8;
 
@@ -95,42 +93,42 @@ unsafe fn gxhash_process_8<const N: usize>(mut ptr: *const state, hash_vector: s
 
         prefetch(ptr);
 
-        v0 = compress::<0>(v0, v1);
-        v0 = compress::<0>(v0, v2);
-        v0 = compress::<0>(v0, v3);
-        v0 = compress::<0>(v0, v4);
-        v0 = compress::<0>(v0, v5);
-        v0 = compress::<0>(v0, v6);
-        v0 = compress::<0>(v0, v7);
+        v0 = compress_fast(v0, v1);
+        v0 = compress_fast(v0, v2);
+        v0 = compress_fast(v0, v3);
+        v0 = compress_fast(v0, v4);
+        v0 = compress_fast(v0, v5);
+        v0 = compress_fast(v0, v6);
+        v0 = compress_fast(v0, v7);
 
-        hash_vector = compress::<N>(hash_vector, v0);
+        hash_vector = compress(hash_vector, v0);
     }
 
-    gxhash_process_1::<N>(ptr, hash_vector, remaining_bytes - unrollable_blocks_count * VECTOR_SIZE)
+    gxhash_process_1(ptr, hash_vector, remaining_bytes - unrollable_blocks_count * VECTOR_SIZE)
 }
 
 #[inline(always)]
-unsafe fn gxhash_process_1<const N: usize>(mut ptr: *const state, hash_vector: state, remaining_bytes: isize) -> state {
+unsafe fn gxhash_process_1(mut ptr: *const state, hash_vector: state, remaining_bytes: isize) -> state {
 
     let end_address = ptr.offset((remaining_bytes / VECTOR_SIZE) as isize) as usize;
 
     let mut hash_vector = hash_vector;
     while (ptr as usize) < end_address {
         load_unaligned!(ptr, v0);
-        hash_vector = compress::<N>(hash_vector, v0);
+        hash_vector = compress(hash_vector, v0);
     }
 
     let remaining_bytes = remaining_bytes & (VECTOR_SIZE - 1);
     if remaining_bytes > 0 {
-        hash_vector = gxhash_process_last::<N>(ptr, hash_vector, remaining_bytes);
+        hash_vector = gxhash_process_last(ptr, hash_vector, remaining_bytes);
     }
     hash_vector
 }
 
 #[inline(always)]
-unsafe fn gxhash_process_last<const N: usize>(ptr: *const state, hash_vector: state, remaining_bytes: isize) -> state {
+unsafe fn gxhash_process_last(ptr: *const state, hash_vector: state, remaining_bytes: isize) -> state {
     let partial_vector = get_partial(ptr, remaining_bytes);
-    compress::<N>(hash_vector, partial_vector)
+    compress(hash_vector, partial_vector)
 }
 
 #[cfg(test)]
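With the const-generic dispatch removed, the two compression roles are named directly: compress_fast (a single unkeyed AES round) folds the eight unrolled blocks together, and the keyed compress absorbs that result into hash_vector once per group. Stripped of SIMD and with made-up scalar stand-ins for the two mixers, the control flow of gxhash_process_8 has roughly this shape:

    // Shape of the unrolled loop above with u64 stand-ins (illustration; not the real mixing functions).
    fn process_8_shape(blocks: &[u64], mut hash_acc: u64) -> u64 {
        let compress_fast = |a: u64, b: u64| a.rotate_left(17) ^ b;                  // cheap per-block fold
        let compress = |a: u64, b: u64| (a ^ b).wrapping_mul(0x9E37_79B9_7F4A_7C15); // keyed fold, once per group
        for group in blocks.chunks_exact(8) {
            let mut v0 = group[0];
            for &v in &group[1..] {
                v0 = compress_fast(v0, v);
            }
            hash_acc = compress(hash_acc, v0);
        }
        hash_acc
    }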
@@ -155,12 +153,12 @@ mod tests {
     fn all_blocks_are_consumed(#[case] size_bits: usize) {
         let mut bytes = vec![42u8; size_bits];
 
-        let ref_hash = gxhash0_32(&bytes, 0);
+        let ref_hash = gxhash32(&bytes, 0);
 
         for i in 0..bytes.len() {
             let swap = bytes[i];
             bytes[i] = 82;
-            let new_hash = gxhash0_32(&bytes, 0);
+            let new_hash = gxhash32(&bytes, 0);
             bytes[i] = swap;
 
             assert_ne!(ref_hash, new_hash, "byte {i} not processed");
@@ -177,7 +175,7 @@ mod tests {
         let mut ref_hash = 0;
 
         for i in 32..100 {
-            let new_hash = gxhash0_32(&mut bytes[..i], 0);
+            let new_hash = gxhash32(&mut bytes[..i], 0);
             assert_ne!(ref_hash, new_hash, "Same hash at size {i} ({new_hash})");
             ref_hash = new_hash;
         }
@@ -218,7 +216,7 @@ mod tests {
             }
 
             i += 1;
-            set.insert(gxhash1_64(&bytes, 0));
+            set.insert(gxhash64(&bytes, 0));
 
             // Reset bits
             for d in digits.iter() {
@@ -254,8 +252,8 @@ mod tests {
 
     #[test]
     fn hash_of_zero_is_not_zero() {
-        assert_ne!(0, gxhash0_32(&[0u8; 0], 0));
-        assert_ne!(0, gxhash0_32(&[0u8; 1], 0));
-        assert_ne!(0, gxhash0_32(&[0u8; 1200], 0));
+        assert_ne!(0, gxhash32(&[0u8; 0], 0));
+        assert_ne!(0, gxhash32(&[0u8; 1], 0));
+        assert_ne!(0, gxhash32(&[0u8; 1200], 0));
     }
 }
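Since every test now goes through the single gxhash32/gxhash64 pair, a determinism check in the same style could sit alongside them (a sketch, not part of this commit):

    #[test]
    fn renamed_api_is_deterministic() {
        let bytes = vec![42u8; 1000];
        assert_eq!(gxhash64(&bytes, 0), gxhash64(&bytes, 0));
        assert_eq!(gxhash32(&bytes, 0), gxhash32(&bytes, 0));
    }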
63 changes: 4 additions & 59 deletions src/gxhash/platform/arm_128.rs
@@ -63,73 +63,18 @@ unsafe fn get_partial_safe(data: *const i8, len: usize) -> state {
 }
 
 #[inline(always)]
-pub unsafe fn compress_1(a: int8x16_t, b: int8x16_t) -> int8x16_t {
-    // 37 GiB/s
+pub unsafe fn compress(a: int8x16_t, b: int8x16_t) -> int8x16_t {
     let keys_1 = vld1q_u32([0xFC3BC28E, 0x89C222E5, 0xB09D3E21, 0xF2784542].as_ptr());
     let keys_2 = vld1q_u32([0x03FCE279, 0xCB6B2E9B, 0xB361DC58, 0x39136BD9].as_ptr());
 
     let b = aes_encrypt(vreinterpretq_u8_s8(b), vreinterpretq_u8_u32(keys_1));
     let a = aes_encrypt(vreinterpretq_u8_s8(a), vreinterpretq_u8_u32(keys_2));
-    vreinterpretq_s8_u8(aes_encrypt_last(a, b))
-
-    // 70 GiB/s
-    //vreinterpretq_s8_u8(aes_encrypt(vreinterpretq_u8_s8(a), vreinterpretq_u8_s8(b)))
-
-    //vreinterpretq_s8_u8(chmuck(vreinterpretq_u8_s8(a), vreinterpretq_u8_s8(b)))
-    // 26 GiB/s
-    // let keys_1 = vld1q_u32([0x713B01D0, 0x8F2F35DB, 0xAF163956, 0x85459F85].as_ptr());
-    // let keys_2 = vld1q_u32([0x1DE09647, 0x92CFA39C, 0x3DD99ACA, 0xB89C054F].as_ptr());
-    // let keys_3 = vld1q_u32([0xC78B122B, 0x5544B1B7, 0x689D2B7D, 0xD0012E32].as_ptr());
-    // let b1 = vaddq_u32(vreinterpretq_u32_s8(b), keys_2); // Cheap
-    // let b2 = vreinterpretq_s8_u32(vmulq_u32(b1, keys_1)); // Cheap
-    // let b3 = vreinterpretq_u32_s8(vextq_s8(b2, b2, 3)); // Expensive
-    // let b4 = vaddq_u32(b3, keys_3); // Cheap
-    // let b5 = vreinterpretq_s8_u32(vmulq_u32(b4, keys_2)); // Cheap
-    // let b6 = vreinterpretq_u32_s8(vextq_s8(b5, b5, 3)); // Expensive
-    // let b7 = vaddq_u32(b6, keys_1); // Cheap
-    // let b8 = vmulq_u32(b7, keys_3); // Cheap
-    // let b9 = veorq_s8(a, vreinterpretq_s8_u32(b8));
-    // vextq_s8(b9, b9, 7)
-
-    //let primes = vld1q_u32([0x9e3779b9, 0x9e3779b9, 0x9e3779b9, 0x9e3779b9].as_ptr());
-    // let keys_2 = vld1q_u32([0x1DE09647, 0x92CFA39C, 0x3DD99ACA, 0xB89C054F].as_ptr());
-    // let keys_3 = vld1q_u32([0xC78B122B, 0x5544B1B7, 0x689D2B7D, 0xD0012E32].as_ptr());
-    // let b1 = vaddq_u32(vreinterpretq_u32_s8(b), primes); // Cheap
-    // let b2 = vreinterpretq_s8_u32(vmulq_u32(vreinterpretq_u32_s8(b), primes)); // Cheap
-    // let shifted = vshlq_n_s8::<1>(b2);
-    // let b3 = veorq_s8(b, shifted);
-    //let b3: uint32x4_t = vreinterpretq_u32_s8(vextq_s8(b2, b2, 3)); // Expensive
-    // let b4 = vaddq_u32(b3, keys_3); // Cheap
-    // let b5 = vreinterpretq_s8_u32(vmulq_u32(b4, keys_2)); // Cheap
-    // let b6 = vreinterpretq_u32_s8(vextq_s8(b5, b5, 3)); // Expensive
-    // let b7 = vaddq_u32(b6, keys_1); // Cheap
-    // let b8 = vmulq_u32(b7, keys_3); // Cheap
-    // let b9 = vaddq_u32(
-    //     vreinterpretq_u32_s8(b),
-    //     veorq_u32(
-    //         primes,
-    //         vaddq_u32(
-    //             vshrq_n_u32::<2>(vreinterpretq_u32_s8(a)),
-    //             vshlq_n_u32::<6>(vreinterpretq_u32_s8(a)))));
-
-    // vextq_s8(vreinterpretq_s8_u32(b9), vreinterpretq_s8_u32(b9), 1)
-
-    // let mut x = vreinterpretq_u32_s8(b);
-    // // Round 1
-    // x = veorq_u32(x, vshrq_n_u32::<16>(x));
-    // x = vmulq_u32(x, vld1q_u32([0x7feb352d, 0x7feb352d, 0x7feb352d, 0x7feb352d].as_ptr()));
-    // // Round 2
-    // x = veorq_u32(x, vshrq_n_u32::<15>(x));
-    // x = vmulq_u32(x, vld1q_u32([0x846ca68b, 0x846ca68b, 0x846ca68b, 0x846ca68b].as_ptr()));
-    // // Round 3
-    // x = veorq_u32(x, vshrq_n_u32::<16>(x));
-    // let f = vaddq_s8(a, vreinterpretq_s8_u32(x));
-    // vextq_s8(f, f, 1)
-
-    //ve
+    vreinterpretq_s8_u8(aes_encrypt_last(a, b))
 }
 
 #[inline(always)]
-pub unsafe fn compress_0(a: int8x16_t, b: int8x16_t) -> int8x16_t {
+pub unsafe fn compress_fast(a: int8x16_t, b: int8x16_t) -> int8x16_t {
     vreinterpretq_s8_u8(aes_encrypt(vreinterpretq_u8_s8(a), vreinterpretq_u8_s8(b)))
 }
 
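On aarch64, aes_encrypt and aes_encrypt_last are the crate's own helpers (defined earlier in this file, outside this hunk). One common way such helpers emulate an x86-style keyed AES round with NEON intrinsics is sketched below; the crate's actual implementation may differ in detail:

    // Sketch only, assuming an aarch64 target with the "aes" feature; names are hypothetical.
    #[cfg(target_arch = "aarch64")]
    mod neon_aes_sketch {
        use std::arch::aarch64::*;

        // Full round: SubBytes + ShiftRows (vaeseq), MixColumns (vaesmcq), then AddRoundKey via XOR.
        #[target_feature(enable = "aes")]
        pub unsafe fn aes_round(data: uint8x16_t, key: uint8x16_t) -> uint8x16_t {
            veorq_u8(vaesmcq_u8(vaeseq_u8(data, vdupq_n_u8(0))), key)
        }

        // Last round: same, but without MixColumns.
        #[target_feature(enable = "aes")]
        pub unsafe fn aes_round_last(data: uint8x16_t, key: uint8x16_t) -> uint8x16_t {
            veorq_u8(vaeseq_u8(data, vdupq_n_u8(0)), key)
        }
    }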
4 changes: 2 additions & 2 deletions src/gxhash/platform/mod.rs
@@ -3,15 +3,15 @@
 pub mod platform;
 
 #[cfg(all(
-    feature = "256-bit",
+    feature = "avx2",
     target_arch = "x86_64",
     target_feature = "avx2")
 )]
 #[path = "x86_256.rs"]
 pub mod platform;
 
 #[cfg(all(
-    not(feature = "256-bit"),
+    not(feature = "avx2"),
     target_arch = "x86_64"
 ))]
 #[path = "x86_128.rs"]
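Because the x86_256 module is only selected when both the avx2 Cargo feature and the avx2 target feature are present, a build that enables the feature on an unsupported target simply matches neither x86 arm. An illustrative guard (not in the repository) could make that failure mode explicit:

    // Hypothetical addition: fail with a clear message instead of an unresolved platform module.
    #[cfg(all(feature = "avx2", not(all(target_arch = "x86_64", target_feature = "avx2"))))]
    compile_error!("the `avx2` feature requires an x86_64 target built with AVX2 enabled");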
4 changes: 2 additions & 2 deletions src/gxhash/platform/x86_128.rs
@@ -57,7 +57,7 @@ unsafe fn get_partial_safe(data: *const u8, len: usize) -> state {
 
 #[inline]
 #[allow(overflowing_literals)]
-pub unsafe fn compress_1(a: state, b: state) -> state {
+pub unsafe fn compress(a: state, b: state) -> state {
     let keys_1 = _mm_set_epi32(0xFC3BC28E, 0x89C222E5, 0xB09D3E21, 0xF2784542);
     let keys_2 = _mm_set_epi32(0x03FCE279, 0xCB6B2E9B, 0xB361DC58, 0x39136BD9);
 
@@ -69,7 +69,7 @@ pub unsafe fn compress_1(a: state, b: state) -> state {
 
 #[inline]
 #[allow(overflowing_literals)]
-pub unsafe fn compress_0(a: state, b: state) -> state {
+pub unsafe fn compress_fast(a: state, b: state) -> state {
     return _mm_aesenc_si128(a, b);
 }
 
8 changes: 3 additions & 5 deletions src/gxhash/platform/x86_256.rs
@@ -57,7 +57,7 @@ unsafe fn get_partial_safe(data: *const u8, len: usize) -> state {
 
 #[inline]
 #[allow(overflowing_literals)]
-pub unsafe fn compress_1(a: state, b: state) -> state {
+pub unsafe fn compress(a: state, b: state) -> state {
     let keys_1 = _mm256_set_epi32(0xFC3BC28E, 0x89C222E5, 0xB09D3E21, 0xF2784542, 0x4155EE07, 0xC897CCE2, 0x780AF2C3, 0x8A72B781);
     let keys_2 = _mm256_set_epi32(0x03FCE279, 0xCB6B2E9B, 0xB361DC58, 0x39136BD9, 0x7A83D76B, 0xB1E8F9F0, 0x028925A8, 0x3B9A4E71);
 
@@ -69,7 +69,7 @@ pub unsafe fn compress_1(a: state, b: state) -> state {
 
 #[inline]
 #[allow(overflowing_literals)]
-pub unsafe fn compress_0(a: state, b: state) -> state {
+pub unsafe fn compress_fast(a: state, b: state) -> state {
     return _mm256_aesenc_epi128(a, b);
 }
 
@@ -87,7 +87,5 @@ pub unsafe fn finalize(hash: state, seed: i32) -> state {
     hash = _mm256_aesenc_epi128(hash, keys_2);
     hash = _mm256_aesenclast_epi128(hash, keys_3);
 
-    // Merge the two 128 bit lanes entropy, so we can after safely truncate up to 128-bits
-    let permuted = _mm256_permute2x128_si256(hash, hash, 0x21);
-    _mm256_xor_si256(hash, permuted)
+    hash
 }
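The removed tail of finalize relates directly to the new extraction in src/gxhash/mod.rs: previously the two 128-bit lanes were swapped and XORed so that, as the deleted comment says, the result could safely be truncated to 128 bits or less; now gxhash32/gxhash64 read across the lane boundary instead, so the merge is dropped. For reference, the removed merge corresponds to roughly this (a sketch assuming an AVX2-capable x86_64 target):

    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn merge_lanes(hash: std::arch::x86_64::__m256i) -> std::arch::x86_64::__m256i {
        use std::arch::x86_64::*;
        let swapped = _mm256_permute2x128_si256(hash, hash, 0x21); // swap the two 128-bit lanes
        _mm256_xor_si256(hash, swapped)                            // each lane now mixes both
    }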
