From 29d191a615b617cf77bb978317af885e8878e6fb Mon Sep 17 00:00:00 2001
From: Olivier Giniaux
Date: Sun, 29 Oct 2023 01:35:12 +0200
Subject: [PATCH] Keep a single version of gxhash (the one that passes smhasher) and declare flag to use 256-bit version on-demand

---
 Cargo.toml                     |   2 +-
 benches/throughput.rs          |  16 ++---
 src/gxhash/mod.rs              | 108 ++++++++++++++++-----------------
 src/gxhash/platform/arm_128.rs |  63 ++-----------------
 src/gxhash/platform/mod.rs     |   4 +-
 src/gxhash/platform/x86_128.rs |   4 +-
 src/gxhash/platform/x86_256.rs |   8 +--
 7 files changed, 71 insertions(+), 134 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 42ac185..ad59ed9 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -8,7 +8,7 @@ edition = "2021"
 # The 256-bit state GxHash is faster for large inputs than the default 128-bit state implementation.
 # Please not however that the 256-bit GxHash and the 128-bit GxHash don't generate the same hashes for a same input.
 # Requires AVX2 and VAES (X86).
-256-bit = []
+avx2 = []
 
 [dependencies]
 rand = "0.8"
diff --git a/benches/throughput.rs b/benches/throughput.rs
index 48f9609..c40c64f 100644
--- a/benches/throughput.rs
+++ b/benches/throughput.rs
@@ -32,9 +32,9 @@ fn benchmark_all(c: &mut Criterion) {
     let mut rng = rand::thread_rng();
 
     // Allocate 32-bytes-aligned
-    let layout = Layout::from_size_align(100000, 32).unwrap();
+    let layout = Layout::from_size_align(50_000, 32).unwrap();
     let ptr = unsafe { alloc(layout) };
-    let slice: &mut [u8] = unsafe { slice::from_raw_parts_mut(ptr, 100000) };
+    let slice: &mut [u8] = unsafe { slice::from_raw_parts_mut(ptr, layout.size()) };
 
     // Fill with random bytes
     rng.fill(slice);
@@ -43,14 +43,10 @@ fn benchmark_all(c: &mut Criterion) {
     let plot_config = PlotConfiguration::default().summary_scale(AxisScale::Logarithmic);
     group.plot_config(plot_config);
 
-    // GxHash0
-    // benchmark(&mut group, slice, "gxhash0", |data: &[u8], _: i32| -> u64 {
-    //     gxhash0_64(data, 0)
-    // });
-
-    // GxHash1
-    benchmark(&mut group, slice, "gxhash", |data: &[u8], _: i32| -> u64 {
-        gxhash1_64(data, 0)
+    // GxHash
+    let algo_name = if cfg!(feature = "avx2") { "gxhash-avx2" } else { "gxhash" };
+    benchmark(&mut group, slice, algo_name, |data: &[u8], _: i32| -> u64 {
+        gxhash64(data, 0)
     });
 
     // AHash
diff --git a/src/gxhash/mod.rs b/src/gxhash/mod.rs
index 75d57b3..99435e6 100644
--- a/src/gxhash/mod.rs
+++ b/src/gxhash/mod.rs
@@ -1,69 +1,67 @@
 mod platform;
 use platform::*;
-#[inline(always)] // To be disabled when profiling
-pub fn gxhash0_32(input: &[u8], seed: i32) -> u32 {
+#[cfg(not(feature = "avx2"))]
+#[inline(always)]
+pub fn gxhash32(input: &[u8], seed: i32) -> u32 {
     unsafe {
-        let p = &gxhash::<0>(input, seed) as *const state as *const u32;
+        let p = &gxhash(input, seed) as *const state as *const u32;
         *p
     }
 }
 
-#[inline(always)] // To be disabled when profiling
-pub fn gxhash0_64(input: &[u8], seed: i32) -> u64 {
+// Since the 256-bit runs AES operations on two 128-bit lanes, we need to extract
+// the hash from the center, picking the same entropy amount from the two lanes
+#[cfg(feature = "avx2")]
+#[inline(always)]
+pub fn gxhash32(input: &[u8], seed: i32) -> u32 {
     unsafe {
-        let p = &gxhash::<0>(input, seed) as *const state as *const u64;
-        *p
+        let p = &gxhash(input, seed) as *const state as *const u8;
+        let offset = std::mem::size_of::<state>() / 2 - std::mem::size_of::<u32>() / 2 - 1;
+        let shifted_ptr = p.offset(offset as isize) as *const u32;
+        *shifted_ptr
     }
 }
 
-#[inline(always)] // To be disabled when profiling
-pub fn gxhash1_32(input: &[u8], seed: i32) -> u32 {
+#[cfg(not(feature = "avx2"))]
+#[inline(always)]
+pub fn gxhash64(input: &[u8], seed: i32) -> u64 {
     unsafe {
-        let p = &gxhash::<1>(input, seed) as *const state as *const u32;
+        let p = &gxhash(input, seed) as *const state as *const u64;
         *p
     }
 }
 
-#[inline(always)] // To be disabled when profiling
-pub fn gxhash1_64(input: &[u8], seed: i32) -> u64 {
+// Since the 256-bit runs AES operations on two 128-bit lanes, we need to extract
+// the hash from the center, picking the same entropy amount from the two lanes
+#[cfg(feature = "avx2")]
+#[inline(always)]
+pub fn gxhash64(input: &[u8], seed: i32) -> u64 {
     unsafe {
-        let p = &gxhash::<1>(input, seed) as *const state as *const u64;
-        *p
-
-        // Alternative idea is to extract the center, to avoid xoring for 256 bit state
-        // let p = &gxhash::<1>(input, seed) as *const state as *const u8;
-        // let shifted_ptr = p.offset(3) as *const u64;
-        // *shifted_ptr
+        let p = &gxhash(input, seed) as *const state as *const u8;
+        let offset = std::mem::size_of::<state>() / 2 - std::mem::size_of::<u64>() / 2 - 1;
+        let shifted_ptr = p.offset(offset as isize) as *const u64;
+        *shifted_ptr
     }
 }
 
 const VECTOR_SIZE: isize = std::mem::size_of::<state>() as isize;
 
 #[inline(always)]
-unsafe fn compress<const N: usize>(a: state, b: state) -> state {
-    match N {
-        0 => compress_0(a, b),
-        1 => compress_1(a, b),
-        _ => compress_1(a, b)
-    }
-}
-
-#[inline(always)]
-fn gxhash<const N: usize>(input: &[u8], seed: i32) -> state {
+fn gxhash(input: &[u8], seed: i32) -> state {
     unsafe {
         let len: isize = input.len() as isize;
 
         let ptr = input.as_ptr() as *const state;
 
         // Lower sizes first, as comparison/branching overhead will become negligible as input size grows.
         let hash_vector = if len <= VECTOR_SIZE {
-            gxhash_process_last::<N>(ptr, create_empty(), len)
+            gxhash_process_last(ptr, create_empty(), len)
         } else if len <= VECTOR_SIZE * 2 {
-            gxhash_process_last::<N>(ptr.offset(1), compress::<N>(*ptr, create_empty()), len - VECTOR_SIZE)
+            gxhash_process_last(ptr.offset(1), compress(*ptr, create_empty()), len - VECTOR_SIZE)
         } else if len < VECTOR_SIZE * 8 {
-            gxhash_process_1::<N>(ptr, create_empty(), len)
+            gxhash_process_1(ptr, create_empty(), len)
         } else {
-            gxhash_process_8::<N>(ptr, create_empty(), len)
+            gxhash_process_8(ptr, create_empty(), len)
         };
 
         finalize(hash_vector, seed)
@@ -81,7 +79,7 @@ macro_rules! load_unaligned {
 }
 
 #[inline(always)]
-unsafe fn gxhash_process_8<const N: usize>(mut ptr: *const state, hash_vector: state, remaining_bytes: isize) -> state {
+unsafe fn gxhash_process_8(mut ptr: *const state, hash_vector: state, remaining_bytes: isize) -> state {
 
     const UNROLL_FACTOR: isize = 8;
 
@@ -95,42 +93,42 @@ unsafe fn gxhash_process_8<const N: usize>(mut ptr: *const state, hash_vector: s
 
         prefetch(ptr);
 
-        v0 = compress::<0>(v0, v1);
-        v0 = compress::<0>(v0, v2);
-        v0 = compress::<0>(v0, v3);
-        v0 = compress::<0>(v0, v4);
-        v0 = compress::<0>(v0, v5);
-        v0 = compress::<0>(v0, v6);
-        v0 = compress::<0>(v0, v7);
+        v0 = compress_fast(v0, v1);
+        v0 = compress_fast(v0, v2);
+        v0 = compress_fast(v0, v3);
+        v0 = compress_fast(v0, v4);
+        v0 = compress_fast(v0, v5);
+        v0 = compress_fast(v0, v6);
+        v0 = compress_fast(v0, v7);
 
-        hash_vector = compress::<N>(hash_vector, v0);
+        hash_vector = compress(hash_vector, v0);
     }
 
-    gxhash_process_1::<N>(ptr, hash_vector, remaining_bytes - unrollable_blocks_count * VECTOR_SIZE)
+    gxhash_process_1(ptr, hash_vector, remaining_bytes - unrollable_blocks_count * VECTOR_SIZE)
 }
 
 #[inline(always)]
-unsafe fn gxhash_process_1<const N: usize>(mut ptr: *const state, hash_vector: state, remaining_bytes: isize) -> state {
+unsafe fn gxhash_process_1(mut ptr: *const state, hash_vector: state, remaining_bytes: isize) -> state {
 
     let end_address = ptr.offset((remaining_bytes / VECTOR_SIZE) as isize) as usize;
 
     let mut hash_vector = hash_vector;
     while (ptr as usize) < end_address {
 
         load_unaligned!(ptr, v0);
 
-        hash_vector = compress::<N>(hash_vector, v0);
+        hash_vector = compress(hash_vector, v0);
     }
 
     let remaining_bytes = remaining_bytes & (VECTOR_SIZE - 1);
     if remaining_bytes > 0 {
-        hash_vector = gxhash_process_last::<N>(ptr, hash_vector, remaining_bytes);
+        hash_vector = gxhash_process_last(ptr, hash_vector, remaining_bytes);
    }
 
     hash_vector
 }
 
 #[inline(always)]
-unsafe fn gxhash_process_last<const N: usize>(ptr: *const state, hash_vector: state, remaining_bytes: isize) -> state {
+unsafe fn gxhash_process_last(ptr: *const state, hash_vector: state, remaining_bytes: isize) -> state {
     let partial_vector = get_partial(ptr, remaining_bytes);
-    compress::<N>(hash_vector, partial_vector)
+    compress(hash_vector, partial_vector)
 }
 
 #[cfg(test)]
@@ -155,12 +153,12 @@ mod tests {
     fn all_blocks_are_consumed(#[case] size_bits: usize) {
         let mut bytes = vec![42u8; size_bits];
 
-        let ref_hash = gxhash0_32(&bytes, 0);
+        let ref_hash = gxhash32(&bytes, 0);
 
         for i in 0..bytes.len() {
             let swap = bytes[i];
             bytes[i] = 82;
-            let new_hash = gxhash0_32(&bytes, 0);
+            let new_hash = gxhash32(&bytes, 0);
             bytes[i] = swap;
 
             assert_ne!(ref_hash, new_hash, "byte {i} not processed");
@@ -177,7 +175,7 @@ mod tests {
         let mut ref_hash = 0;
 
         for i in 32..100 {
-            let new_hash = gxhash0_32(&mut bytes[..i], 0);
+            let new_hash = gxhash32(&mut bytes[..i], 0);
             assert_ne!(ref_hash, new_hash, "Same hash at size {i} ({new_hash})");
             ref_hash = new_hash;
         }
@@ -218,7 +216,7 @@ mod tests {
         }
 
         i += 1;
-        set.insert(gxhash1_64(&bytes, 0));
+        set.insert(gxhash64(&bytes, 0));
 
         // Reset bits
         for d in digits.iter() {
@@ -254,8 +252,8 @@ mod tests {
 
     #[test]
     fn hash_of_zero_is_not_zero() {
-        assert_ne!(0, gxhash0_32(&[0u8; 0], 0));
-        assert_ne!(0, gxhash0_32(&[0u8; 1], 0));
-        assert_ne!(0, gxhash0_32(&[0u8; 1200], 0));
+        assert_ne!(0, gxhash32(&[0u8; 0], 0));
+        assert_ne!(0, gxhash32(&[0u8; 1], 0));
+        assert_ne!(0, gxhash32(&[0u8; 1200], 0));
     }
 }
\ No newline at end of file
diff --git a/src/gxhash/platform/arm_128.rs b/src/gxhash/platform/arm_128.rs
index e888226..35dc12d 100644
--- a/src/gxhash/platform/arm_128.rs
+++ b/src/gxhash/platform/arm_128.rs
@@ -63,73 +63,18 @@ unsafe fn get_partial_safe(data: *const i8, len: usize) -> state {
 }
 
 #[inline(always)]
-pub unsafe fn compress_1(a: int8x16_t, b: int8x16_t) -> int8x16_t {
-    // 37 GiB/s
+pub unsafe fn compress(a: int8x16_t, b: int8x16_t) -> int8x16_t {
     let keys_1 = vld1q_u32([0xFC3BC28E, 0x89C222E5, 0xB09D3E21, 0xF2784542].as_ptr());
     let keys_2 = vld1q_u32([0x03FCE279, 0xCB6B2E9B, 0xB361DC58, 0x39136BD9].as_ptr());
+
     let b = aes_encrypt(vreinterpretq_u8_s8(b), vreinterpretq_u8_u32(keys_1));
     let a = aes_encrypt(vreinterpretq_u8_s8(a), vreinterpretq_u8_u32(keys_2));
 
-    vreinterpretq_s8_u8(aes_encrypt_last(a, b))
-    // 70 GiB/s
-    //vreinterpretq_s8_u8(aes_encrypt(vreinterpretq_u8_s8(a), vreinterpretq_u8_s8(b)))
-
-    //vreinterpretq_s8_u8(chmuck(vreinterpretq_u8_s8(a), vreinterpretq_u8_s8(b)))
-    // 26 GiB/s
-    // let keys_1 = vld1q_u32([0x713B01D0, 0x8F2F35DB, 0xAF163956, 0x85459F85].as_ptr());
-    // let keys_2 = vld1q_u32([0x1DE09647, 0x92CFA39C, 0x3DD99ACA, 0xB89C054F].as_ptr());
-    // let keys_3 = vld1q_u32([0xC78B122B, 0x5544B1B7, 0x689D2B7D, 0xD0012E32].as_ptr());
-    // let b1 = vaddq_u32(vreinterpretq_u32_s8(b), keys_2); // Cheap
-    // let b2 = vreinterpretq_s8_u32(vmulq_u32(b1, keys_1)); // Cheap
-    // let b3 = vreinterpretq_u32_s8(vextq_s8(b2, b2, 3)); // Expensive
-    // let b4 = vaddq_u32(b3, keys_3); // Cheap
-    // let b5 = vreinterpretq_s8_u32(vmulq_u32(b4, keys_2)); // Cheap
-    // let b6 = vreinterpretq_u32_s8(vextq_s8(b5, b5, 3)); // Expensive
-    // let b7 = vaddq_u32(b6, keys_1); // Cheap
-    // let b8 = vmulq_u32(b7, keys_3); // Cheap
-    // let b9 = veorq_s8(a, vreinterpretq_s8_u32(b8));
-    // vextq_s8(b9, b9, 7)
-
-    //let primes = vld1q_u32([0x9e3779b9, 0x9e3779b9, 0x9e3779b9, 0x9e3779b9].as_ptr());
-    // let keys_2 = vld1q_u32([0x1DE09647, 0x92CFA39C, 0x3DD99ACA, 0xB89C054F].as_ptr());
-    // let keys_3 = vld1q_u32([0xC78B122B, 0x5544B1B7, 0x689D2B7D, 0xD0012E32].as_ptr());
-    // let b1 = vaddq_u32(vreinterpretq_u32_s8(b), primes); // Cheap
-    // let b2 = vreinterpretq_s8_u32(vmulq_u32(vreinterpretq_u32_s8(b), primes)); // Cheap
-    // let shifted = vshlq_n_s8::<1>(b2);
-    // let b3 = veorq_s8(b, shifted);
-    //let b3: uint32x4_t = vreinterpretq_u32_s8(vextq_s8(b2, b2, 3)); // Expensive
-    // let b4 = vaddq_u32(b3, keys_3); // Cheap
-    // let b5 = vreinterpretq_s8_u32(vmulq_u32(b4, keys_2)); // Cheap
-    // let b6 = vreinterpretq_u32_s8(vextq_s8(b5, b5, 3)); // Expensive
-    // let b7 = vaddq_u32(b6, keys_1); // Cheap
-    // let b8 = vmulq_u32(b7, keys_3); // Cheap
-    // let b9 = vaddq_u32(
-    //     vreinterpretq_u32_s8(b),
-    //     veorq_u32(
-    //         primes,
-    //         vaddq_u32(
-    //             vshrq_n_u32::<2>(vreinterpretq_u32_s8(a)),
-    //             vshlq_n_u32::<6>(vreinterpretq_u32_s8(a)))));
-
-    // vextq_s8(vreinterpretq_s8_u32(b9), vreinterpretq_s8_u32(b9), 1)
-
-    // let mut x = vreinterpretq_u32_s8(b);
-    // // Round 1
-    // x = veorq_u32(x, vshrq_n_u32::<16>(x));
-    // x = vmulq_u32(x, vld1q_u32([0x7feb352d, 0x7feb352d, 0x7feb352d, 0x7feb352d].as_ptr()));
-    // // Round 2
-    // x = veorq_u32(x, vshrq_n_u32::<15>(x));
-    // x = vmulq_u32(x, vld1q_u32([0x846ca68b, 0x846ca68b, 0x846ca68b, 0x846ca68b].as_ptr()));
-    // // Round 3
-    // x = veorq_u32(x, vshrq_n_u32::<16>(x));
-    // let f = vaddq_s8(a, vreinterpretq_s8_u32(x));
-    // vextq_s8(f, f, 1)
-
-    //ve
+    vreinterpretq_s8_u8(aes_encrypt_last(a, b))
 }
 
 #[inline(always)]
-pub unsafe fn compress_0(a: int8x16_t, b: int8x16_t) -> int8x16_t {
+pub unsafe fn compress_fast(a: int8x16_t, b: int8x16_t) -> int8x16_t {
     vreinterpretq_s8_u8(aes_encrypt(vreinterpretq_u8_s8(a), vreinterpretq_u8_s8(b)))
 }
 
diff --git a/src/gxhash/platform/mod.rs b/src/gxhash/platform/mod.rs
index 230c2b2..e274fac 100644
--- a/src/gxhash/platform/mod.rs
+++ b/src/gxhash/platform/mod.rs
@@ -3,7 +3,7 @@
 pub mod platform;
 
 #[cfg(all(
-    feature = "256-bit",
+    feature = "avx2",
     target_arch = "x86_64",
     target_feature = "avx2")
 )]
@@ -11,7 +11,7 @@ pub mod platform;
 pub mod platform;
 
 #[cfg(all(
-    not(feature = "256-bit"),
+    not(feature = "avx2"),
     target_arch = "x86_64"
 ))]
 #[path = "x86_128.rs"]
 pub mod platform;
diff --git a/src/gxhash/platform/x86_128.rs b/src/gxhash/platform/x86_128.rs
index d9f71dc..5d00fb8 100644
--- a/src/gxhash/platform/x86_128.rs
+++ b/src/gxhash/platform/x86_128.rs
@@ -57,7 +57,7 @@ unsafe fn get_partial_safe(data: *const u8, len: usize) -> state {
 
 #[inline]
 #[allow(overflowing_literals)]
-pub unsafe fn compress_1(a: state, b: state) -> state {
+pub unsafe fn compress(a: state, b: state) -> state {
     let keys_1 = _mm_set_epi32(0xFC3BC28E, 0x89C222E5, 0xB09D3E21, 0xF2784542);
     let keys_2 = _mm_set_epi32(0x03FCE279, 0xCB6B2E9B, 0xB361DC58, 0x39136BD9);
 
@@ -69,7 +69,7 @@ pub unsafe fn compress_1(a: state, b: state) -> state {
 
 #[inline]
 #[allow(overflowing_literals)]
-pub unsafe fn compress_0(a: state, b: state) -> state {
+pub unsafe fn compress_fast(a: state, b: state) -> state {
     return _mm_aesenc_si128(a, b);
 }
diff --git a/src/gxhash/platform/x86_256.rs b/src/gxhash/platform/x86_256.rs
index c5ae5ac..84d19ae 100644
--- a/src/gxhash/platform/x86_256.rs
+++ b/src/gxhash/platform/x86_256.rs
@@ -57,7 +57,7 @@ unsafe fn get_partial_safe(data: *const u8, len: usize) -> state {
 
 #[inline]
 #[allow(overflowing_literals)]
-pub unsafe fn compress_1(a: state, b: state) -> state {
+pub unsafe fn compress(a: state, b: state) -> state {
     let keys_1 = _mm256_set_epi32(0xFC3BC28E, 0x89C222E5, 0xB09D3E21, 0xF2784542, 0x4155EE07, 0xC897CCE2, 0x780AF2C3, 0x8A72B781);
     let keys_2 = _mm256_set_epi32(0x03FCE279, 0xCB6B2E9B, 0xB361DC58, 0x39136BD9, 0x7A83D76B, 0xB1E8F9F0, 0x028925A8, 0x3B9A4E71);
 
@@ -69,7 +69,7 @@ pub unsafe fn compress_1(a: state, b: state) -> state {
 
 #[inline]
 #[allow(overflowing_literals)]
-pub unsafe fn compress_0(a: state, b: state) -> state {
+pub unsafe fn compress_fast(a: state, b: state) -> state {
     return _mm256_aesenc_epi128(a, b);
 }
 
@@ -87,7 +87,5 @@ pub unsafe fn finalize(hash: state, seed: i32) -> state {
     hash = _mm256_aesenc_epi128(hash, keys_2);
     hash = _mm256_aesenclast_epi128(hash, keys_3);
 
-    // Merge the two 128 bit lanes entropy, so we can after safely truncate up to 128-bits
-    let permuted = _mm256_permute2x128_si256(hash, hash, 0x21);
-    _mm256_xor_si256(hash, permuted)
+    hash
 }
\ No newline at end of file
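
Note on the center-extraction trick introduced for the avx2 builds: each 256-bit aesenc mixes two independent 128-bit lanes, so reading the digest from the start of the state would sample only one lane. The new gxhash32/gxhash64 therefore read from the middle of the state, straddling both lanes. The following standalone Rust sketch mirrors that offset arithmetic on a plain byte array; the names extract_center_u64 and STATE_SIZE are illustrative and not part of the crate:

// Illustrative only: reproduces the offset computation of the avx2 gxhash64,
// with STATE_SIZE standing in for std::mem::size_of::<state>() (32 bytes).
fn extract_center_u64(state_bytes: &[u8; 32]) -> u64 {
    const STATE_SIZE: usize = 32;
    let out = std::mem::size_of::<u64>();
    // Start half a state in, step back half the output width (and one extra
    // byte, as in the patch), so the read covers the tail of lane 0 and the
    // head of lane 1: bytes 11..19 of the 32-byte state.
    let offset = STATE_SIZE / 2 - out / 2 - 1;
    let mut buf = [0u8; 8];
    buf.copy_from_slice(&state_bytes[offset..offset + out]);
    u64::from_ne_bytes(buf)
}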
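The rewritten gxhash_process_8 also makes the two compressors' roles explicit: within an unrolled group of eight blocks the cheap single-round compress_fast folds the blocks together, and only the folded result of each group is mixed into the running hash with the keyed compress. A scalar stand-in for that loop shape, with cheap_mix and strong_mix as placeholders for the AES-based primitives (illustrative, not the crate's code):

fn process_blocks(blocks: &[u64], mut hash: u64) -> u64 {
    // Placeholder mixers: the real compress_fast is a single aesenc, the real
    // compress is two keyed aesenc rounds plus an aesenclast.
    fn cheap_mix(a: u64, b: u64) -> u64 { a.rotate_left(17) ^ b }
    fn strong_mix(a: u64, b: u64) -> u64 {
        (a ^ b).wrapping_mul(0x9E37_79B9_7F4A_7C15).rotate_left(29)
    }

    let mut groups = blocks.chunks_exact(8);
    for group in &mut groups {
        // Fold the eight blocks of the group with the cheap mixer, like the
        // chain of v0 = compress_fast(v0, vN) calls.
        let folded = group.iter().copied().reduce(cheap_mix).unwrap();
        // Then fold the group into the running hash with the strong mixer,
        // mirroring hash_vector = compress(hash_vector, v0).
        hash = strong_mix(hash, folded);
    }
    // Leftover blocks take the one-at-a-time path (gxhash_process_1).
    for &block in groups.remainder() {
        hash = strong_mix(hash, block);
    }
    hash
}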